from datetime import datetime
from typing import Optional
from uuid import UUID

import httpx
from fastapi import HTTPException, status
from sqlalchemy import delete, select
from sqlalchemy.orm import selectinload

from src.config import AEROS_BASE_URL
from src.database.core import DbSession
from src.database.service import CommonParameters, search_filter_sort_paginate

from .model import (
    AerosNode,
    AerosSimulation,
    AerosSimulationCalcResult,
    AerosSimulationPlotResult,
)
from .schema import SimulationInput

# Shared async HTTP client for the AEROS service. The 300 s timeout presumably
# accommodates long-running simulation calls.
client = httpx.AsyncClient(timeout=300.0)

# Module-level registry kept for external importers; no function in this module
# reads or writes it.
active_simulations = {}


async def get_all(common: CommonParameters):
    """Return a searched/filtered/sorted/paginated listing of simulations.

    ``common`` is unpacked as keyword arguments into ``search_filter_sort_paginate``.
    """
    query = select(AerosSimulation)
    return await search_filter_sort_paginate(model=query, **common)


async def get_simulation_by_id(
    *,
    db_session: DbSession,
    simulation_id: Optional[UUID] = None,
    is_completed: bool = False,
):
    """Fetch a single simulation.

    Args:
        db_session: Active database session.
        simulation_id: Simulation to fetch. When omitted, the simulation with the
            lowest id (after the optional status filter) is returned instead.
        is_completed: When true, only simulations with status "completed" match.

    Returns:
        The matching ``AerosSimulation``, or ``None`` if nothing matches.
    """
    query = select(AerosSimulation)
    if is_completed:
        query = query.where(AerosSimulation.status == "completed")
    if simulation_id:
        query = query.where(AerosSimulation.id == simulation_id)
    else:
        query = query.order_by(AerosSimulation.id.asc()).limit(1)
    results = await db_session.execute(query)
    return results.scalar()


async def get_simulation_node_by(*, db_session: DbSession, **kwargs):
    """Return the first ``AerosNode`` whose columns equal the given keyword values.

    Keyword names that are not attributes of ``AerosNode`` are ignored.

    Raises:
        ValueError: if none of the keywords names an ``AerosNode`` attribute.
    """
    conditions = [
        getattr(AerosNode, key) == value
        for key, value in kwargs.items()
        if hasattr(AerosNode, key)
    ]
    if not conditions:
        raise ValueError("No valid column conditions provided")

    query = select(AerosNode).where(*conditions)
    result = await db_session.execute(query)
    return result.scalar()


async def get_or_save_node(*, db_session: DbSession, node_data: dict):
    """Return the ``AerosNode`` with ``node_data["nodeId"]``, creating it if absent.

    ``node_data`` is an AEROS node payload (camelCase keys). A newly created node
    is committed before being returned.
    """
    node = await get_simulation_node_by(
        db_session=db_session, node_id=node_data["nodeId"]
    )
    if not node:
        node = AerosNode(
            node_type=node_data["nodeType"],
            original_node_id=node_data["originalNodeId"],
            node_id=node_data["nodeId"],
            node_name=node_data["nodeName"],
            structure_name=node_data["structureName"],
            schematic_name=node_data["schematicName"],
            schematic_id=node_data["schematicId"],
            original_schematic_id=node_data["originalSchematicId"],
            ref_schematic_id=node_data["refSchematicId"],
            orignal_ref_schematic_id=node_data["orignalRefSchematicId"],
        )
        db_session.add(node)
        await db_session.commit()

    return node


async def _mark_simulation_failed(
    *, db_session: DbSession, simulation_id: Optional[UUID], error: Exception
) -> None:
    """Roll back pending work and record ``error`` on the given simulation row.

    Does nothing beyond the rollback when no ``simulation_id`` is given or the row
    does not exist. Without this guard, ``get_simulation_by_id`` would fall back
    to the oldest simulation, and an unrelated row would be marked failed.
    """
    # Discard whatever the failed operation left pending so the status update
    # below can be committed on a clean transaction.
    await db_session.rollback()
    if simulation_id is None:
        return
    simulation = await get_simulation_by_id(
        db_session=db_session, simulation_id=simulation_id
    )
    if simulation is None:
        return
    simulation.status = "failed"
    simulation.error = str(error)
    await db_session.commit()


async def execute_simulation(
    *,
    db_session: DbSession,
    simulation_id: Optional[UUID] = None,
    sim_data: dict,
    is_saved: bool = False,
):
    """Run a simulation on the AEROS service and return its JSON response.

    Args:
        db_session: Active database session.
        simulation_id: Simulation row to update; required when ``is_saved`` is true.
        sim_data: Request body posted to ``/api/Simulation/RunSimulation``.
        is_saved: When true, store the raw response on the simulation row and
            persist per-node results via ``save_simulation_result``.

    Returns:
        The decoded JSON response from AEROS.

    Raises:
        HTTPException: 500 on any failure. When ``simulation_id`` refers to an
            existing row, that row is marked "failed" first.
    """
    try:
        response = await client.post(
            f"{AEROS_BASE_URL}/api/Simulation/RunSimulation",
            json=sim_data,
            headers={"Content-Type": "application/json"},
        )
        response.raise_for_status()
        result = response.json()

        if is_saved:
            simulation = await get_simulation_by_id(
                db_session=db_session, simulation_id=simulation_id
            )
            # NOTE(review): "proccessing" is misspelled but kept as-is because
            # other code may compare against this stored value.
            simulation.status = "proccessing"
            simulation.result = result
            await db_session.commit()

            await save_simulation_result(
                db_session=db_session, simulation_id=simulation_id, result=result
            )

        return result
    except HTTPException:
        # Raised by save_simulation_result after it has already recorded the
        # failure; propagate it unchanged instead of re-wrapping it.
        raise
    except Exception as e:
        await _mark_simulation_failed(
            db_session=db_session, simulation_id=simulation_id, error=e
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        ) from e


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or ``0.0`` when the denominator is 0/None.

    NOTE(review): 0.0 is chosen so that a node with no ideal production does not
    fail the whole save; confirm this is the desired EAF value for such nodes.
    """
    if not denominator:
        return 0.0
    return numerator / denominator


async def _load_nodes_by_aeros_id(*, db_session: DbSession, node_ids) -> dict:
    """Look up each distinct AEROS node id once; ids with no stored node map to ``None``."""
    nodes_by_id: dict = {}
    for node_id in node_ids:
        if node_id not in nodes_by_id:
            nodes_by_id[node_id] = await get_simulation_node_by(
                db_session=db_session, node_id=node_id
            )
    return nodes_by_id


def _build_calc_result(
    *, simulation_id: UUID, node_id, entry: dict
) -> AerosSimulationCalcResult:
    """Map one ``nodeResultOuts`` payload entry onto an ``AerosSimulationCalcResult``."""
    return AerosSimulationCalcResult(
        aeros_simulation_id=simulation_id,
        aeros_node_id=node_id,
        total_downtime=entry["totalDowntime"],
        total_uptime=entry["totalUpTime"],
        num_events=entry["numEvents"],
        production=entry["production"],
        production_std=entry["productionStd"],
        ideal_production=entry["idealProduction"],
        availability=entry["availability"],
        efficiency=entry["efficiency"],
        effective_loss=entry["effectiveLoss"],
        num_cm=entry["numCM"],
        cm_waiting_time=entry["cmWaitingTime"],
        total_cm_downtime=entry["totalCMDowntime"],
        num_pm=entry["numPM"],
        total_pm_downtime=entry["totalPMDowntime"],
        num_ip=entry["numIP"],
        total_ip_downtime=entry["totalIPDowntime"],
        num_oh=entry["numOH"],
        total_oh_downtime=entry["totalOHDowntime"],
        t_wait_for_crew=entry["tWaitForCrew"],
        t_wait_for_spare=entry["tWaitForSpare"],
        duration_at_full=entry["durationAtFull"],
        duration_above_hh=entry["durationAboveHH"],
        duration_above_h=entry["durationAboveH"],
        duration_below_l=entry["durationBelowL"],
        duration_below_ll=entry["durationBelowLL"],
        duration_at_empty=entry["durationAtEmpty"],
        stg_input=entry["stgInput"],
        stg_output=entry["stgOutput"],
        average_level=entry["averageLevel"],
        potential_production=entry["potentialProduction"],
        # Equivalent availability factor: production relative to ideal production.
        eaf=_safe_ratio(entry["production"], entry["idealProduction"]),
    )


def _build_plot_result(
    *, simulation_id: UUID, node_id, entry: dict
) -> AerosSimulationPlotResult:
    """Map one ``plotNodeOuts`` payload entry onto an ``AerosSimulationPlotResult``."""
    return AerosSimulationPlotResult(
        aeros_simulation_id=simulation_id,
        aeros_node_id=node_id,
        max_flow_rate=entry["maxFlowrate"],
        storage_capacity=entry["storageCapacity"],
        point_availabilities=entry["pointAvailabilities"],
        point_flowrates=entry["pointFlowrates"],
        timestamp_outs=entry["timeStampOuts"],
    )


async def save_simulation_result(
    *, db_session: DbSession, simulation_id: UUID, result: dict
):
    """Persist per-node calc/plot results and mark the simulation "completed".

    Entries are matched to stored ``AerosNode`` rows by ``nodeId``. An entry whose
    node is unknown is skipped, unless its ``nodeName`` is "Main"; in that case
    the node is looked up by name instead.

    Raises:
        HTTPException: 500 if the payload cannot be converted (or the "Main" node
            is missing). The simulation is marked "failed" with the error first.
    """
    calc_objects = []
    plot_objects = []
    try:
        calc_entries = result["nodeResultOuts"]
        plot_entries = result["plotNodeOuts"]

        # Resolve ids from both lists so plot-only nodes are not dropped.
        nodes_by_id = await _load_nodes_by_aeros_id(
            db_session=db_session,
            node_ids=[entry["nodeId"] for entry in (*calc_entries, *plot_entries)],
        )
        nodes_by_name: dict = {}

        async def resolve_node(entry: dict):
            """Return the stored node for ``entry``, or ``None`` to skip it."""
            node = nodes_by_id.get(entry["nodeId"])
            if node is not None:
                return node
            node_name = entry["nodeName"]
            if node_name != "Main":
                return None
            if node_name not in nodes_by_name:
                nodes_by_name[node_name] = await get_simulation_node_by(
                    db_session=db_session, node_name=node_name
                )
            node = nodes_by_name[node_name]
            if node is None:
                raise ValueError(f"No AerosNode named {node_name!r} found")
            return node

        for entry in calc_entries:
            node = await resolve_node(entry)
            if node is None:
                continue
            calc_objects.append(
                _build_calc_result(
                    simulation_id=simulation_id, node_id=node.id, entry=entry
                )
            )

        for entry in plot_entries:
            node = await resolve_node(entry)
            if node is None:
                continue
            plot_objects.append(
                _build_plot_result(
                    simulation_id=simulation_id, node_id=node.id, entry=entry
                )
            )
    except Exception as e:
        # Record the failure in `error`; `result` keeps the raw AEROS payload.
        await _mark_simulation_failed(
            db_session=db_session, simulation_id=simulation_id, error=e
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        ) from e

    db_session.add_all(calc_objects)
    db_session.add_all(plot_objects)
    simulation = await get_simulation_by_id(
        db_session=db_session, simulation_id=simulation_id
    )
    simulation.status = "completed"
    simulation.completed_at = datetime.now()
    await db_session.commit()


async def save_default_simulation_node(
    *, db_session: DbSession, project_name: str = "trialapi"
):
    """Rebuild the ``AerosNode`` table from a short reference simulation run.

    Runs a 3-year, single-run "Boiler" simulation for ``project_name`` without
    saving it. Then deletes every existing ``AerosNode`` and inserts one node per
    ``nodeResultOuts`` entry, plus one per ``plotNodeOuts`` entry whose
    (name, node id) pair is not already present. Negative ids are stored as
    ``None``.
    """
    sim_data = {
        "projectName": project_name,
        "SchematicName": "Boiler",
        "SimSeed": 1,
        "SimDuration": 3,
        "DurationUnit": "UYear",
        "SimNumRun": 1,
    }

    results = await execute_simulation(db_session=db_session, sim_data=sim_data)
    nodes = []

    # delete old data
    await db_session.execute(delete(AerosNode))

    plot_entries = results["plotNodeOuts"]
    node_entries = results["nodeResultOuts"]

    for entry in node_entries:
        nodes.append(
            AerosNode(
                node_name=entry["nodeName"],
                node_type=entry["nodeType"],
                node_id=convert_id_to_none_if_negative(entry["nodeId"]),
                original_node_id=convert_id_to_none_if_negative(
                    entry["originalNodeId"]
                ),
                structure_name=entry["structureName"],
                schematic_name=entry["schematicName"],
                schematic_id=convert_id_to_none_if_negative(entry["schematicId"]),
                original_schematic_id=convert_id_to_none_if_negative(
                    entry["originalSchematicId"]
                ),
                ref_schematic_id=convert_id_to_none_if_negative(
                    entry["refSchematicId"]
                ),
                # NOTE(review): get_or_save_node reads "orignalRefSchematicId" for
                # the same field; confirm which key AEROS actually sends.
                orignal_ref_schematic_id=convert_id_to_none_if_negative(
                    entry["orinalRefSchematic"]
                ),
            )
        )

    # (name, node id) pairs already queued; a set keeps the duplicate check O(1).
    seen_nodes = {(node.node_name, node.node_id) for node in nodes}
    for entry in plot_entries:
        node_id = convert_id_to_none_if_negative(entry["nodeId"])
        node_name = entry["nodeName"]
        key = (node_name, node_id)
        if key in seen_nodes:
            continue
        seen_nodes.add(key)
        nodes.append(
            AerosNode(
                node_name=node_name,
                node_type=entry["nodeType"],
                node_id=node_id,
                original_node_id=convert_id_to_none_if_negative(
                    entry["originalNodeId"]
                ),
                schematic_name=entry["schematicName"],
                schematic_id=convert_id_to_none_if_negative(
                    entry["parentSchematicId"]
                ),
                original_schematic_id=convert_id_to_none_if_negative(
                    entry["originalParentSchematicId"]
                ),
                ref_schematic_id=convert_id_to_none_if_negative(
                    entry["targetSchematicId"]
                ),
                orignal_ref_schematic_id=convert_id_to_none_if_negative(
                    entry["originalTargetSchematicId"]
                ),
            )
        )

    db_session.add_all(nodes)
    await db_session.commit()


def convert_id_to_none_if_negative(value):
    """Return ``None`` for a missing or negative id, otherwise the id unchanged."""
    if value is None or value < 0:
        return None
    return value


async def create_simulation(*, db_session: DbSession, simulation_in: SimulationInput):
    """Create and commit a new simulation row in "running" status.

    Everything except ``SimulationName`` is stored as the simulation's ``input``.
    """
    sim_input = simulation_in.model_dump(exclude={"SimulationName"})
    simulation = AerosSimulation(
        status="running",
        started_at=datetime.now(),
        input=sim_input,
        simulation_name=simulation_in.SimulationName,
        schematic_name=sim_input["SchematicName"],
    )
    db_session.add(simulation)
    await db_session.commit()
    return simulation


async def get_simulation_with_calc_result(
    *, db_session: DbSession, simulation_id: UUID, aeros_node_id: Optional[UUID] = None
):
    """Fetch a simulation with its calc results (and each result's node) eagerly loaded.

    NOTE(review): ``aeros_node_id`` is accepted but currently not applied; all
    calc results of the simulation are loaded.
    """
    query = (
        select(AerosSimulation)
        .where(AerosSimulation.id == simulation_id)
        .options(
            selectinload(AerosSimulation.calc_results).options(
                selectinload(AerosSimulationCalcResult.aeros_node)
            )
        )
    )
    simulation = await db_session.execute(query)
    return simulation.scalar()


async def get_simulation_with_plot_result(
    *, db_session: DbSession, simulation_id: UUID
):
    """Fetch a simulation with its plot results (and each result's node) eagerly loaded."""
    query = (
        select(AerosSimulation)
        .where(AerosSimulation.id == simulation_id)
        .options(
            selectinload(AerosSimulation.plot_results).options(
                selectinload(AerosSimulationPlotResult.aeros_node)
            )
        )
    )
    simulation = await db_session.execute(query)
    return simulation.scalar()


async def get_calc_result_by(
    *, db_session: DbSession, simulation_id: UUID, aeros_node_id: Optional[UUID] = None
):
    """Return the first calc result of a simulation, optionally restricted to one node."""
    query = select(AerosSimulationCalcResult).where(
        AerosSimulationCalcResult.aeros_simulation_id == simulation_id
    )
    if aeros_node_id:
        query = query.where(AerosSimulationCalcResult.aeros_node_id == aeros_node_id)

    result = await db_session.execute(query)
    return result.scalar()


async def get_custom_parameters(*, db_session: DbSession, simulation_id: UUID):
    """Return up to 20 "RegularNode" calc results of a simulation, highest EAF first.

    Each result's ``aeros_node`` relationship is eagerly loaded.
    """
    query = (
        select(AerosSimulationCalcResult)
        .where(AerosSimulationCalcResult.aeros_simulation_id == simulation_id)
        .join(AerosNode, AerosNode.id == AerosSimulationCalcResult.aeros_node_id)
        .where(AerosNode.node_type == "RegularNode")
        .order_by(AerosSimulationCalcResult.eaf.desc())
        .limit(20)
        .options(selectinload(AerosSimulationCalcResult.aeros_node))
    )
    result = await db_session.execute(query)
    return result.scalars().all()