You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
243 lines
6.9 KiB
Python
243 lines
6.9 KiB
Python
import asyncio
import os
import sys
from subprocess import PIPE
from typing import Dict, List, Optional
from uuid import UUID

from sqlalchemy import Delete, Select, func
from sqlalchemy.inspection import inspect as sa_inspect
from sqlalchemy.orm import selectinload

from src.auth.service import CurrentUser
from src.database.core import DbSession
from src.database.service import search_filter_sort_paginate
from src.masterdata.model import MasterData
from src.masterdata_simulations.model import MasterDataSimulation
from src.plant_transaction_data.model import PlantTransactionData
from src.plant_transaction_data_simulations.model import PlantTransactionDataSimulations

from .model import Simulation
from .schema import (
    MasterDataOverride,
    SimulationCreate,
    SimulationRunPayload,
    SimulationUpdate,
)
|
|
|
|
|
|
# Absolute path of the plant module directory, resolved relative to this
# file's directory as ../modules/plant. Also used as the working directory of
# the simulation subprocess.
MODULES_PLANT_PATH = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, "modules", "plant")
)

# Standalone script launched in a subprocess to run the plant calculation.
SIMULATION_SCRIPT_PATH = os.path.join(MODULES_PLANT_PATH, "run_plant_simulation.py")
|
|
|
|
def _non_id_column_keys(model) -> List[str]:
    """Return the keys of ``model``'s mapped column attributes, minus ``id``."""
    return [
        attr.key
        for attr in sa_inspect(model).mapper.column_attrs
        if attr.key != "id"
    ]


# Columns copied from each MasterData row into its MasterDataSimulation copy.
# "id" is left out (presumably so every copy receives its own primary key).
MASTERDATA_COPY_COLUMNS = _non_id_column_keys(MasterData)

# Columns copied from each PlantTransactionData row into its
# PlantTransactionDataSimulations copy, again excluding "id".
PLANT_COPY_COLUMNS = _non_id_column_keys(PlantTransactionData)
|
|
|
|
|
|
async def get(
    *, db_session: DbSession, simulation_id: str, owner: str
) -> Optional[Simulation]:
    """Fetch a single simulation owned by ``owner``.

    The simulation's ``masterdata_entries`` and ``plant_transactions``
    relationships are loaded eagerly with ``selectinload``.

    Args:
        db_session: Active database session.
        simulation_id: Primary key of the simulation.
        owner: Value that must equal ``Simulation.created_by``.

    Returns:
        The matching ``Simulation``, or ``None`` when no row matches both the
        id and the owner.
    """
    eager_loads = (
        selectinload(Simulation.masterdata_entries),
        selectinload(Simulation.plant_transactions),
    )
    ownership_filter = (
        Simulation.id == simulation_id,
        Simulation.created_by == owner,
    )
    statement = Select(Simulation).options(*eager_loads).where(*ownership_filter)

    rows = await db_session.execute(statement)
    return rows.scalars().one_or_none()
|
|
|
|
|
|
async def get_all(
    *,
    db_session: DbSession,
    items_per_page: Optional[int],
    search: Optional[str],
    common,
    owner: str,
):
    """Return a paginated list of simulations created by ``owner``.

    Results are ordered newest first by ``created_at``. When ``search`` is
    non-empty, only simulations whose label contains it (case-insensitively)
    are kept. LIKE metacharacters (``%``, ``_``) in ``search`` match
    literally instead of acting as wildcards.

    Args:
        db_session: Active database session. Not referenced directly here.
            Presumably ``common`` carries what the paginator needs (verify
            against the caller).
        items_per_page: Page size forwarded to ``search_filter_sort_paginate``.
        search: Optional substring to look for in ``Simulation.label``.
        common: Mapping of shared parameters forwarded to
            ``search_filter_sort_paginate``. It is not modified.
        owner: Value that must equal ``Simulation.created_by``.

    Returns:
        Whatever ``search_filter_sort_paginate`` returns for the built query.
    """
    query = (
        Select(Simulation)
        .where(Simulation.created_by == owner)
        .order_by(Simulation.created_at.desc())
    )

    if search:
        # Escape the escape character first, then the LIKE wildcards. Without
        # this, "%" or "_" typed by a user would match arbitrary text.
        escaped = search.replace("/", "//").replace("%", "/%").replace("_", "/_")
        query = query.filter(Simulation.label.ilike(f"%{escaped}%", escape="/"))

    # Build a fresh mapping instead of writing into the caller's ``common``.
    params = {**common, "items_per_page": items_per_page}
    return await search_filter_sort_paginate(model=query, **params)
|
|
|
|
|
|
async def create(*, db_session: DbSession, simulation_in: SimulationCreate) -> Simulation:
    """Insert a new simulation built from ``simulation_in`` and commit it.

    Two values are filled in before the insert:

    - A ``None`` (or missing) version becomes the next version number.
    - An empty or missing label becomes ``"Simulation <version>"``.

    Args:
        db_session: Active database session. This function commits it.
        simulation_in: Validated creation payload.

    Returns:
        The newly added ``Simulation`` instance.
    """
    fields = simulation_in.model_dump()

    version = fields.get("version")
    if version is None:
        version = await _get_next_version(db_session)
        fields["version"] = version

    if not fields.get("label"):
        fields["label"] = f"Simulation {version}"

    simulation = Simulation(**fields)
    db_session.add(simulation)
    await db_session.commit()
    return simulation
|
|
|
|
|
|
async def update(
    *, db_session: DbSession, simulation: Simulation, simulation_in: SimulationUpdate
) -> Simulation:
    """Copy the non-default fields of ``simulation_in`` onto ``simulation`` and commit.

    Fields come from ``model_dump(exclude_defaults=True)``, so any field whose
    value equals its schema default is skipped. As a consequence, a field
    cannot be explicitly reset to its default value through this function.

    Args:
        db_session: Active database session. This function commits it.
        simulation: Persistent simulation instance to modify in place.
        simulation_in: Partial update payload.

    Returns:
        The same ``simulation`` object, after the commit.
    """
    changes = simulation_in.model_dump(exclude_defaults=True)
    for attribute_name, new_value in changes.items():
        setattr(simulation, attribute_name, new_value)

    await db_session.commit()
    return simulation
|
|
|
|
|
|
async def delete(*, db_session: DbSession, simulation_id: str) -> None:
    """Delete the simulation whose primary key is ``simulation_id`` and commit.

    No ownership check is done here.
    NOTE(review): confirm that every caller verifies access first.

    An id that matches no row is not an error; the DELETE simply affects
    nothing.
    """
    await db_session.execute(Delete(Simulation).where(Simulation.id == simulation_id))
    await db_session.commit()
|
|
|
|
|
|
async def run_simulation(
    *,
    db_session: DbSession,
    payload: SimulationRunPayload,
    current_user: CurrentUser,
) -> Simulation:
    """Create a simulation, snapshot the source data into it, and run the plant calculation.

    Steps:
      1. Add a new ``Simulation`` with the next version number. The label
         defaults to ``"Simulation <version>"``.
      2. Copy every ``MasterData`` row, with ``payload.overrides`` applied,
         and every ``PlantTransactionData`` row into the simulation tables.
      3. Commit, then run the external plant-calculation script for the new
         simulation id.

    Steps 1 and 2 share one transaction. The simulation row is only flushed
    (to obtain its id) before copying. If anything in steps 1-2 fails, for
    example because a source table is empty, the transaction is rolled back
    and no orphaned simulation row is left behind.

    The commit happens before step 3 because the script runs in a separate OS
    process. That process can only see committed rows, presumably reading
    them by ``PLANT_SIMULATION_ID`` (verify).

    Args:
        db_session: Active database session. Committed (or rolled back) here.
        payload: Run request: optional label and master-data overrides.
        current_user: User recorded as ``created_by`` / ``updated_by``.

    Returns:
        The simulation, refreshed from the database after the calculation.

    Raises:
        RuntimeError: If master data or plant transaction data is empty
            (nothing is persisted), or if the plant calculation exits with a
            non-zero status. In the latter case the simulation and its copied
            data remain committed.
    """
    # NOTE(review): max(version) + 1 is read-then-write. Two concurrent runs
    # can pick the same version unless the database enforces uniqueness.
    next_version = await _get_next_version(db_session)
    label = payload.label or f"Simulation {next_version}"

    simulation = Simulation(
        label=label,
        version=next_version,
        created_by=current_user.name,
        updated_by=current_user.name,
    )
    db_session.add(simulation)

    try:
        # Flush rather than commit: the row gets its id but stays in the same
        # transaction as the copied data.
        await db_session.flush()
        # Capture the id now. Reading it after commit would trigger a lazy
        # reload on an expired instance, which async sessions cannot do
        # implicitly.
        simulation_id = simulation.id

        await _copy_masterdata_to_simulation(
            db_session=db_session,
            simulation_id=simulation_id,
            overrides=payload.overrides,
            actor=current_user.name,
        )
        await _copy_plant_transactions_to_simulation(
            db_session=db_session,
            simulation_id=simulation_id,
            actor=current_user.name,
        )
        await db_session.commit()
    except Exception:
        await db_session.rollback()
        raise

    # Raises RuntimeError on failure; it propagates unchanged to the caller.
    await _run_plant_calculation_for_simulation(simulation_id)

    await db_session.refresh(simulation)
    return simulation
|
|
|
|
|
|
async def _get_next_version(db_session: DbSession) -> int:
    """Return one more than the highest ``Simulation.version``, or 1 if there is none.

    NOTE(review): read-then-increment. This is not safe against concurrent
    callers unless the database enforces unique versions.
    """
    highest = (await db_session.execute(Select(func.max(Simulation.version)))).scalar()
    return (highest or 0) + 1
|
|
|
|
|
|
async def _copy_masterdata_to_simulation(
    *,
    db_session: DbSession,
    simulation_id: UUID,
    overrides: List[MasterDataOverride],
    actor: Optional[str],
):
    """Stage a ``MasterDataSimulation`` copy of every ``MasterData`` row.

    How each copy is built:

    - It takes the ``MASTERDATA_COPY_COLUMNS`` values of its source row and
      is linked to ``simulation_id``.
    - When ``actor`` is truthy, ``created_by`` and ``updated_by`` are set to
      ``actor``.
    - If an override has the same ``name`` as the source row, that override's
      ``value_num`` and/or ``value_str`` replace the copied values, but only
      the ones that are not ``None``.

    Override behaviour:

    - Overrides whose name matches no row are ignored.
    - If several overrides share a name, the last one wins.

    The new rows are only added to the session; committing is up to the
    caller.

    Raises:
        RuntimeError: If ``MasterData`` has no rows. Nothing is added.
    """
    overrides_by_name: Dict[str, MasterDataOverride] = {}
    for item in overrides or []:
        overrides_by_name[item.name] = item

    fetched = await db_session.execute(Select(MasterData))
    source_rows = fetched.scalars().all()
    if not source_rows:
        raise RuntimeError("Master data is empty; cannot run simulation.")

    def _clone(source):
        """Build the simulation copy of one MasterData row."""
        values = {name: getattr(source, name) for name in MASTERDATA_COPY_COLUMNS}
        values["simulation_id"] = simulation_id
        if actor:
            values["created_by"] = actor
            values["updated_by"] = actor

        match = overrides_by_name.get(source.name)
        if match:
            # getattr with a default tolerates override objects lacking a field.
            new_num = getattr(match, "value_num", None)
            if new_num is not None:
                values["value_num"] = match.value_num
            new_str = getattr(match, "value_str", None)
            if new_str is not None:
                values["value_str"] = match.value_str

        return MasterDataSimulation(**values)

    db_session.add_all([_clone(row) for row in source_rows])
|
|
|
|
|
|
async def _copy_plant_transactions_to_simulation(
    *,
    db_session: DbSession,
    simulation_id: UUID,
    actor: Optional[str],
):
    """Stage a ``PlantTransactionDataSimulations`` copy of every ``PlantTransactionData`` row.

    Each copy takes the ``PLANT_COPY_COLUMNS`` values of its source row and
    is linked to ``simulation_id``. When ``actor`` is truthy, ``created_by``
    and ``updated_by`` are set to ``actor``.

    The new rows are only added to the session; committing is up to the
    caller.

    Raises:
        RuntimeError: If ``PlantTransactionData`` has no rows. Nothing is
            added.
    """
    fetched = await db_session.execute(Select(PlantTransactionData))
    source_rows = fetched.scalars().all()
    if not source_rows:
        raise RuntimeError("Plant transaction data is empty; cannot run simulation.")

    # The audit override is identical for every row, so compute it once.
    audit_fields = {"created_by": actor, "updated_by": actor} if actor else {}

    def _clone(source):
        """Build the simulation copy of one PlantTransactionData row."""
        values = {name: getattr(source, name) for name in PLANT_COPY_COLUMNS}
        values["simulation_id"] = simulation_id
        values.update(audit_fields)
        return PlantTransactionDataSimulations(**values)

    db_session.add_all([_clone(row) for row in source_rows])
|
|
|
|
|
|
async def _run_plant_calculation_for_simulation(simulation_id: UUID) -> None:
    """Run the plant simulation script for ``simulation_id`` in a subprocess.

    The script at ``SIMULATION_SCRIPT_PATH`` is run as follows:

    - Interpreter: the one running this process (``sys.executable``).
    - Working directory: ``MODULES_PLANT_PATH``.
    - Environment: the current environment plus
      ``PLANT_SIMULATION_ID=<simulation_id>``.

    stdout and stderr are captured and discarded on success.

    Raises:
        RuntimeError: If the script exits with a non-zero status. The message
            includes the script's stderr, or its stdout if stderr is empty,
            or the exit code if both are empty.
    """
    env = os.environ.copy()
    env["PLANT_SIMULATION_ID"] = str(simulation_id)

    # Use the running interpreter rather than whatever "python" resolves to on
    # PATH, so the script sees the same virtualenv and installed packages.
    # sys.executable may be empty in embedded interpreters; fall back then.
    interpreter = sys.executable or "python"

    process = await asyncio.create_subprocess_exec(
        interpreter,
        SIMULATION_SCRIPT_PATH,
        stdout=PIPE,
        stderr=PIPE,
        cwd=MODULES_PLANT_PATH,
        env=env,
    )
    stdout, stderr = await process.communicate()

    if process.returncode != 0:
        # errors="replace": a stray non-UTF-8 byte in the output must not turn
        # the real failure into a UnicodeDecodeError.
        error_output = (
            stderr.decode(errors="replace").strip()
            or stdout.decode(errors="replace").strip()
            or f"exit code {process.returncode}"
        )
        raise RuntimeError(
            f"Plant calculation failed for simulation {simulation_id}: {error_output}"
        )
|