"""Service functions for querying, creating and running simulations."""

import asyncio
import os
import sys
from subprocess import PIPE
from typing import Dict, List, Optional
from uuid import UUID

from sqlalchemy import Delete, Select, func
from sqlalchemy.inspection import inspect as sa_inspect
from sqlalchemy.orm import selectinload

from src.database.core import DbSession
from src.database.service import search_filter_sort_paginate
from src.masterdata.model import MasterData
from src.masterdata_simulations.model import MasterDataSimulation
from src.plant_transaction_data.model import PlantTransactionData
from src.plant_transaction_data_simulations.model import PlantTransactionDataSimulations
from src.auth.service import CurrentUser

from .model import Simulation
from .schema import (
    MasterDataOverride,
    SimulationCreate,
    SimulationRunPayload,
    SimulationUpdate,
)

# Directory of the plant module; used as the working directory of the
# calculation subprocess.
MODULES_PLANT_PATH = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "../modules/plant")
)
SIMULATION_SCRIPT_PATH = os.path.join(MODULES_PLANT_PATH, "run_plant_simulation.py")

# Every mapped column except the primary key is copied from the live tables
# into the per-simulation tables.
MASTERDATA_COPY_COLUMNS = [
    column.key
    for column in sa_inspect(MasterData).mapper.column_attrs
    if column.key != "id"
]
PLANT_COPY_COLUMNS = [
    column.key
    for column in sa_inspect(PlantTransactionData).mapper.column_attrs
    if column.key != "id"
]


async def get(
    *, db_session: DbSession, simulation_id: str, owner: str
) -> Optional[Simulation]:
    """Return the simulation with ``simulation_id`` created by ``owner``.

    The ``masterdata_entries`` and ``plant_transactions`` relationships are
    eagerly loaded. Returns ``None`` when no matching row exists.
    """
    query = (
        Select(Simulation)
        .options(
            selectinload(Simulation.masterdata_entries),
            selectinload(Simulation.plant_transactions),
        )
        .where(
            Simulation.id == simulation_id,
            Simulation.created_by == owner,
        )
    )
    result = await db_session.execute(query)
    return result.scalars().one_or_none()


async def get_all(
    *,
    db_session: DbSession,
    items_per_page: Optional[int],
    search: Optional[str],
    common,
    owner: str,
):
    """Return a paginated listing of ``owner``'s simulations, newest first.

    Args:
        db_session: Accepted for interface symmetry; not used directly here.
        items_per_page: Page size forwarded to the pagination helper.
        search: Optional case-insensitive substring matched against the label.
        common: Mapping of keyword arguments forwarded to
            ``search_filter_sort_paginate``. It is not modified.
        owner: Only simulations whose ``created_by`` equals this are listed.
    """
    query = (
        Select(Simulation)
        .where(Simulation.created_by == owner)
        .order_by(Simulation.created_at.desc())
    )

    if search:
        query = query.filter(Simulation.label.ilike(f"%{search}%"))

    # Merge into a copy so the caller's mapping is left untouched.
    paginate_kwargs = {**common, "items_per_page": items_per_page}
    return await search_filter_sort_paginate(model=query, **paginate_kwargs)


async def create(*, db_session: DbSession, simulation_in: SimulationCreate) -> Simulation:
    """Persist a new simulation built from ``simulation_in``.

    A missing ``version`` is replaced by the next free version number, and an
    empty ``label`` by ``"Simulation <version>"``.
    """
    data = simulation_in.model_dump()

    if data.get("version") is None:
        data["version"] = await _get_next_version(db_session)

    if not data.get("label"):
        data["label"] = f"Simulation {data['version']}"

    simulation = Simulation(**data)
    db_session.add(simulation)
    await db_session.commit()
    return simulation


async def update(
    *, db_session: DbSession, simulation: Simulation, simulation_in: SimulationUpdate
) -> Simulation:
    """Apply the non-default fields of ``simulation_in`` to ``simulation`` and commit."""
    # NOTE(review): exclude_defaults drops any field whose value equals its
    # schema default, so a field cannot be reset to its default through this
    # path - confirm that this is intended (vs. exclude_unset).
    update_data = simulation_in.model_dump(exclude_defaults=True)

    for field, value in update_data.items():
        setattr(simulation, field, value)

    await db_session.commit()
    return simulation


async def delete(*, db_session: DbSession, simulation_id: str) -> None:
    """Delete the simulation with ``simulation_id`` and commit."""
    # NOTE(review): unlike get(), this is not scoped to an owner; presumably
    # the caller has already checked ownership - verify.
    query = Delete(Simulation).where(Simulation.id == simulation_id)
    await db_session.execute(query)
    await db_session.commit()


async def run_simulation(
    *,
    db_session: DbSession,
    payload: SimulationRunPayload,
    current_user: CurrentUser,
) -> Simulation:
    """Create a simulation, snapshot its input data and run the plant calculation.

    The simulation row and the copies of master data and plant transaction
    data are committed together, so a failure while copying leaves nothing
    behind. The calculation script is started only after that commit.

    Args:
        db_session: Session used for all reads and writes.
        payload: Supplies the optional ``label`` and the master-data ``overrides``.
        current_user: Its ``name`` is recorded as creator/updater.

    Returns:
        The refreshed ``Simulation`` instance.

    Raises:
        RuntimeError: If master data or plant transaction data is empty, or if
            the calculation script exits with a non-zero status.
    """
    next_version = await _get_next_version(db_session)
    label = payload.label or f"Simulation {next_version}"

    simulation = Simulation(
        label=label,
        version=next_version,
        created_by=current_user.name,
        updated_by=current_user.name,
    )
    db_session.add(simulation)

    try:
        # Flush instead of commit: the primary key becomes available for the
        # copies below while everything stays in a single transaction.
        await db_session.flush()
        # Capture the key now so later use does not depend on attribute
        # expiry after the commit.
        simulation_id = simulation.id

        await _copy_masterdata_to_simulation(
            db_session=db_session,
            simulation_id=simulation_id,
            overrides=payload.overrides,
            actor=current_user.name,
        )
        await _copy_plant_transactions_to_simulation(
            db_session=db_session,
            simulation_id=simulation_id,
            actor=current_user.name,
        )
        await db_session.commit()
    except Exception:
        # Do not leave a simulation without its input snapshot.
        await db_session.rollback()
        raise

    # A failure here propagates as the RuntimeError raised by the helper; the
    # committed simulation and its copied data are left in place.
    await _run_plant_calculation_for_simulation(simulation_id)

    await db_session.refresh(simulation)
    return simulation


async def _get_next_version(db_session: DbSession) -> int:
    """Return one more than the highest stored simulation version (1 if none)."""
    query = Select(func.max(Simulation.version))
    result = await db_session.execute(query)
    max_version = result.scalar()
    return (max_version or 0) + 1


async def _copy_masterdata_to_simulation(
    *,
    db_session: DbSession,
    simulation_id: UUID,
    overrides: Optional[List[MasterDataOverride]],
    actor: Optional[str],
):
    """Stage a copy of every master-data row for ``simulation_id``.

    Rows are added to the session but not committed. For a row whose ``name``
    matches an override, the override's non-``None`` ``value_num`` and
    ``value_str`` replace the copied values. When ``actor`` is truthy it
    replaces ``created_by``/``updated_by`` on every copy.

    Raises:
        RuntimeError: If the master-data table has no rows.
    """
    override_map: Dict[str, MasterDataOverride] = {
        item.name: item for item in overrides or []
    }

    result = await db_session.execute(Select(MasterData))
    records = result.scalars().all()

    if not records:
        raise RuntimeError("Master data is empty; cannot run simulation.")

    entries: List[MasterDataSimulation] = []
    for record in records:
        payload = {column: getattr(record, column) for column in MASTERDATA_COPY_COLUMNS}
        payload["simulation_id"] = simulation_id

        if actor:
            payload["created_by"] = actor
            payload["updated_by"] = actor

        override = override_map.get(record.name)
        if override:
            if getattr(override, "value_num", None) is not None:
                payload["value_num"] = override.value_num
            if getattr(override, "value_str", None) is not None:
                payload["value_str"] = override.value_str

        entries.append(MasterDataSimulation(**payload))

    db_session.add_all(entries)


async def _copy_plant_transactions_to_simulation(
    *,
    db_session: DbSession,
    simulation_id: UUID,
    actor: Optional[str],
):
    """Stage a copy of every plant-transaction row for ``simulation_id``.

    Rows are added to the session but not committed. When ``actor`` is truthy
    it replaces ``created_by``/``updated_by`` on every copy.

    Raises:
        RuntimeError: If the plant-transaction table has no rows.
    """
    result = await db_session.execute(Select(PlantTransactionData))
    rows = result.scalars().all()

    if not rows:
        raise RuntimeError("Plant transaction data is empty; cannot run simulation.")

    entries: List[PlantTransactionDataSimulations] = []
    for row in rows:
        payload = {column: getattr(row, column) for column in PLANT_COPY_COLUMNS}
        payload["simulation_id"] = simulation_id

        if actor:
            payload["created_by"] = actor
            payload["updated_by"] = actor

        entries.append(PlantTransactionDataSimulations(**payload))

    db_session.add_all(entries)


async def _run_plant_calculation_for_simulation(simulation_id: UUID) -> None:
    """Run the plant simulation script in a subprocess and wait for it.

    The simulation id is passed to the script through the
    ``PLANT_SIMULATION_ID`` environment variable.

    Raises:
        RuntimeError: If the script exits with a non-zero status; the message
            carries its stderr, or stdout when stderr is empty.
    """
    env = os.environ.copy()
    env["PLANT_SIMULATION_ID"] = str(simulation_id)

    # Use the interpreter running this service so the script sees the same
    # environment; fall back to PATH lookup if it is unknown.
    interpreter = sys.executable or "python"

    process = await asyncio.create_subprocess_exec(
        interpreter,
        SIMULATION_SCRIPT_PATH,
        stdout=PIPE,
        stderr=PIPE,
        cwd=MODULES_PLANT_PATH,
        env=env,
    )
    stdout, stderr = await process.communicate()

    if process.returncode != 0:
        # errors="replace": undecodable output must not hide the real failure.
        error_output = (
            stderr.decode(errors="replace").strip()
            or stdout.decode(errors="replace").strip()
        )
        raise RuntimeError(
            f"Plant calculation failed for simulation {simulation_id}: {error_output}"
        )