You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

243 lines
6.9 KiB
Python

import asyncio
import os
import sys
from subprocess import PIPE
from typing import Dict, List, Optional
from uuid import UUID

from sqlalchemy import Delete, Select, func
from sqlalchemy.inspection import inspect as sa_inspect
from sqlalchemy.orm import selectinload

from src.database.core import DbSession
from src.database.service import search_filter_sort_paginate
from src.masterdata.model import MasterData
from src.masterdata_simulations.model import MasterDataSimulation
from src.plant_transaction_data.model import PlantTransactionData
from src.plant_transaction_data_simulations.model import PlantTransactionDataSimulations
from src.auth.service import CurrentUser

from .model import Simulation
from .schema import (
    MasterDataOverride,
    SimulationCreate,
    SimulationRunPayload,
    SimulationUpdate,
)
# Absolute path of the plant module directory (resolved relative to this file);
# it is also used as the working directory of the calculation subprocess.
MODULES_PLANT_PATH = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "../modules/plant")
)

# Script launched by `_run_plant_calculation_for_simulation`.
SIMULATION_SCRIPT_PATH = os.path.join(MODULES_PLANT_PATH, "run_plant_simulation.py")

# Mapped column attribute names copied from the live tables into the simulation
# tables. The "id" attribute is skipped (presumably so each copy gets its own
# primary key — confirm against the simulation models).
MASTERDATA_COPY_COLUMNS = [
    attr.key
    for attr in sa_inspect(MasterData).mapper.column_attrs
    if attr.key != "id"
]
PLANT_COPY_COLUMNS = [
    attr.key
    for attr in sa_inspect(PlantTransactionData).mapper.column_attrs
    if attr.key != "id"
]
async def get(
    *, db_session: DbSession, simulation_id: str, owner: str
) -> Optional[Simulation]:
    """Return simulation ``simulation_id`` if it was created by ``owner``.

    The ``masterdata_entries`` and ``plant_transactions`` relationships are
    eager-loaded with ``selectinload``.

    Returns:
        The matching Simulation, or None when no row matches both the id and
        the owner.
    """
    eager_loads = (
        selectinload(Simulation.masterdata_entries),
        selectinload(Simulation.plant_transactions),
    )
    ownership_filters = (
        Simulation.id == simulation_id,
        Simulation.created_by == owner,
    )
    statement = Select(Simulation).options(*eager_loads).where(*ownership_filters)
    rows = await db_session.execute(statement)
    return rows.scalars().one_or_none()
async def get_all(
    *,
    db_session: DbSession,
    items_per_page: Optional[int],
    search: Optional[str],
    common,
    owner: str,
):
    """Return a paginated list of ``owner``'s simulations, newest first.

    Args:
        db_session: Not referenced in this function; pagination is delegated to
            ``search_filter_sort_paginate`` using ``common``.
        items_per_page: Page size forwarded to the paginator. It overrides any
            ``items_per_page`` already present in ``common``.
        search: Optional case-insensitive substring matched against
            ``Simulation.label``. ``%`` and ``_`` are matched literally.
        common: Mapping of extra keyword arguments for
            ``search_filter_sort_paginate``. It is not modified.
        owner: Only simulations whose ``created_by`` equals this are returned.

    Returns:
        Whatever ``search_filter_sort_paginate`` returns for the built query.
    """
    query = (
        Select(Simulation)
        .where(Simulation.created_by == owner)
        .order_by(Simulation.created_at.desc())
    )
    if search:
        # Escape LIKE metacharacters so user input is matched as plain text
        # ("/" is the escape character; it must itself be escaped first).
        escaped = search.replace("/", "//").replace("%", "/%").replace("_", "/_")
        query = query.filter(Simulation.label.ilike(f"%{escaped}%", escape="/"))
    # Build a fresh kwargs dict instead of mutating the caller's `common`.
    paginate_kwargs = {**common, "items_per_page": items_per_page}
    return await search_filter_sort_paginate(model=query, **paginate_kwargs)
async def create(*, db_session: DbSession, simulation_in: SimulationCreate) -> Simulation:
    """Persist a new Simulation built from ``simulation_in`` and commit it.

    A ``version`` of None is replaced by the next free version number, and an
    empty or missing ``label`` defaults to ``"Simulation <version>"``.
    """
    fields = simulation_in.model_dump()
    if fields.get("version") is None:
        fields["version"] = await _get_next_version(db_session)
    fields["label"] = fields.get("label") or f"Simulation {fields['version']}"

    new_simulation = Simulation(**fields)
    db_session.add(new_simulation)
    await db_session.commit()
    return new_simulation
async def update(
    *, db_session: DbSession, simulation: Simulation, simulation_in: SimulationUpdate
) -> Simulation:
    """Copy fields from ``simulation_in`` onto ``simulation`` and commit.

    The payload is dumped with ``exclude_defaults=True``. Any field whose value
    equals its schema default is therefore skipped, so this function cannot
    reset a field back to its default value.

    Returns:
        The same ``simulation`` instance, after the commit.
    """
    changes = simulation_in.model_dump(exclude_defaults=True)
    for attr_name in changes:
        setattr(simulation, attr_name, changes[attr_name])
    await db_session.commit()
    return simulation
async def delete(*, db_session: DbSession, simulation_id: str) -> None:
    """Delete the simulation whose id is ``simulation_id`` and commit.

    Issues a bulk DELETE statement. No ownership check happens here; callers
    that need one must do it first (e.g. via ``get``).
    """
    await db_session.execute(
        Delete(Simulation).where(Simulation.id == simulation_id)
    )
    await db_session.commit()
async def run_simulation(
    *,
    db_session: DbSession,
    payload: SimulationRunPayload,
    current_user: CurrentUser,
) -> Simulation:
    """Snapshot the live data into a new simulation and run the plant calculation.

    Steps:
        1. Allocate the next version number.
        2. Stage a Simulation row labelled ``payload.label`` or
           ``"Simulation <version>"``.
        3. Copy every MasterData row (applying ``payload.overrides``) and every
           PlantTransactionData row into the simulation tables.
        4. Commit.
        5. Run the plant calculation script in a subprocess.
        6. Refresh and return the simulation.

    Raises:
        RuntimeError: if master data or plant transaction data is empty. In
            that case the Simulation row has only been flushed, not committed.
            Also raised if the calculation subprocess exits non-zero; by then
            the snapshot is already committed.
    """
    next_version = await _get_next_version(db_session)
    simulation = Simulation(
        label=payload.label or f"Simulation {next_version}",
        version=next_version,
        created_by=current_user.name,
        updated_by=current_user.name,
    )
    db_session.add(simulation)
    # Flush (not commit) so the primary key is available, while the header row
    # and its copied data stay in one transaction. A failing copy step then no
    # longer leaves an empty simulation committed.
    await db_session.flush()
    simulation_id = simulation.id

    await _copy_masterdata_to_simulation(
        db_session=db_session,
        simulation_id=simulation_id,
        overrides=payload.overrides,
        actor=current_user.name,
    )
    await _copy_plant_transactions_to_simulation(
        db_session=db_session,
        simulation_id=simulation_id,
        actor=current_user.name,
    )
    # The calculation runs in a separate process (presumably reading the
    # snapshot from the database), so the data must be committed first.
    await db_session.commit()

    # RuntimeError from the subprocess propagates unchanged to the caller.
    await _run_plant_calculation_for_simulation(simulation_id)

    await db_session.refresh(simulation)
    return simulation
async def _get_next_version(db_session: DbSession) -> int:
    """Return one more than the highest ``Simulation.version`` (1 if none exist).

    NOTE(review): this is a read-then-use pattern with no locking. Concurrent
    callers can compute the same version unless a constraint or lock elsewhere
    prevents it.
    """
    outcome = await db_session.execute(Select(func.max(Simulation.version)))
    highest = outcome.scalar()
    if not highest:
        return 1
    return highest + 1
async def _copy_masterdata_to_simulation(
    *,
    db_session: DbSession,
    simulation_id: UUID,
    overrides: List[MasterDataOverride],
    actor: Optional[str],
):
    """Stage a MasterDataSimulation copy of every MasterData row.

    Each copy takes the columns listed in ``MASTERDATA_COPY_COLUMNS``, is linked
    to ``simulation_id``, and has ``created_by`` / ``updated_by`` set to
    ``actor`` when ``actor`` is truthy.

    Overrides are matched by ``MasterData.name``; when several share a name,
    the last one wins. Only non-None ``value_num`` / ``value_str`` values
    replace the copied ones. Override names with no matching record are
    ignored.

    Rows are only added to the session; the caller is responsible for
    committing.

    Raises:
        RuntimeError: if the MasterData table has no rows.
    """
    overrides_by_name: Dict[str, MasterDataOverride] = {}
    for item in overrides or []:
        overrides_by_name[item.name] = item

    fetched = await db_session.execute(Select(MasterData))
    source_rows = fetched.scalars().all()
    if not source_rows:
        raise RuntimeError("Master data is empty; cannot run simulation.")

    def _snapshot(record) -> MasterDataSimulation:
        # Build one simulation row from a live MasterData record.
        values = {name: getattr(record, name) for name in MASTERDATA_COPY_COLUMNS}
        values["simulation_id"] = simulation_id
        if actor:
            values["created_by"] = values["updated_by"] = actor
        override = overrides_by_name.get(record.name)
        if override:
            for field in ("value_num", "value_str"):
                replacement = getattr(override, field, None)
                if replacement is not None:
                    values[field] = replacement
        return MasterDataSimulation(**values)

    snapshots: List[MasterDataSimulation] = [_snapshot(record) for record in source_rows]
    db_session.add_all(snapshots)
async def _copy_plant_transactions_to_simulation(
    *,
    db_session: DbSession,
    simulation_id: UUID,
    actor: Optional[str],
):
    """Stage a PlantTransactionDataSimulations copy of every PlantTransactionData row.

    Each copy takes the columns listed in ``PLANT_COPY_COLUMNS``, is linked to
    ``simulation_id``, and has ``created_by`` / ``updated_by`` set to ``actor``
    when ``actor`` is truthy.

    Rows are only added to the session; the caller is responsible for
    committing.

    Raises:
        RuntimeError: if the PlantTransactionData table has no rows.
    """
    fetched = await db_session.execute(Select(PlantTransactionData))
    source_rows = fetched.scalars().all()
    if not source_rows:
        raise RuntimeError("Plant transaction data is empty; cannot run simulation.")

    # Audit fields applied on top of the copied columns (empty when no actor).
    audit = {"created_by": actor, "updated_by": actor} if actor else {}

    def _clone(row) -> PlantTransactionDataSimulations:
        fields = {name: getattr(row, name) for name in PLANT_COPY_COLUMNS}
        fields["simulation_id"] = simulation_id
        fields.update(audit)
        return PlantTransactionDataSimulations(**fields)

    clones: List[PlantTransactionDataSimulations] = [_clone(row) for row in source_rows]
    db_session.add_all(clones)
async def _run_plant_calculation_for_simulation(simulation_id: UUID) -> None:
    """Run the plant simulation script for ``simulation_id`` in a subprocess.

    The script at ``SIMULATION_SCRIPT_PATH`` runs with the current interpreter
    and working directory ``MODULES_PLANT_PATH``. It inherits this process's
    environment plus ``PLANT_SIMULATION_ID`` set to the simulation id.
    stdout and stderr are captured; the call waits for the process to exit.

    Raises:
        RuntimeError: if the script exits with a non-zero status. The message
            contains stderr, or stdout when stderr is empty.
    """
    env = os.environ.copy()
    env["PLANT_SIMULATION_ID"] = str(simulation_id)
    process = await asyncio.create_subprocess_exec(
        # Use the running interpreter so the script sees the same virtualenv
        # and installed packages. Fall back to a PATH lookup when
        # sys.executable is unavailable (empty string).
        sys.executable or "python",
        SIMULATION_SCRIPT_PATH,
        stdout=PIPE,
        stderr=PIPE,
        cwd=MODULES_PLANT_PATH,
        env=env,
    )
    stdout, stderr = await process.communicate()
    if process.returncode != 0:
        # errors="replace": undecodable output must not hide the real failure
        # behind a UnicodeDecodeError.
        error_output = (
            stderr.decode(errors="replace").strip()
            or stdout.decode(errors="replace").strip()
        )
        raise RuntimeError(
            f"Plant calculation failed for simulation {simulation_id}: {error_output}"
        )