diff --git a/src/aeros_simulation/router.py b/src/aeros_simulation/router.py index 4915832..4988e70 100644 --- a/src/aeros_simulation/router.py +++ b/src/aeros_simulation/router.py @@ -18,6 +18,7 @@ from src.aeros_equipment.service import update_equipment_for_simulation from src.aeros_project.service import get_project from temporal.workflow import SimulationWorkflow from .schema import ( + AhmMetricInput, SimulationCalcResult, SimulationInput, SimulationPagination, @@ -324,14 +325,15 @@ async def get_custom_parameters_controller(db_session: DbSession, simulation_id: } @router.post("/ahm_metrics", response_model=StandardResponse[dict]) -async def get_ahm_metrics_controller(db_session: DbSession, simulation_id:UUID): +async def get_ahm_metrics_controller(db_session: DbSession, metrics_in:AhmMetricInput): simulation_result = await get_plant_calc_result( - db_session=db_session, simulation_id=simulation_id + db_session=db_session, simulation_id=metrics_in.target_simulation_id ) - default_simulation = await get_default_simulation(db_session=db_session) + default_simulation_id = metrics_in.baseline_simulation_id if metrics_in.baseline_simulation_id else await get_default_simulation(db_session=db_session).id + default_simulation_result = await get_plant_calc_result( - db_session=db_session, simulation_id=default_simulation.id + db_session=db_session, simulation_id=default_simulation_id ) result = { diff --git a/src/aeros_simulation/schema.py b/src/aeros_simulation/schema.py index 4487dbc..74c5e40 100644 --- a/src/aeros_simulation/schema.py +++ b/src/aeros_simulation/schema.py @@ -115,3 +115,8 @@ class SimulationRankingParameters(EquipmentWithCustomParameters): class SimulationPagination(Pagination): items: List[SimulationData] = [] + + +class AhmMetricInput(BaseModel): + target_simulation_id: str + baseline_simulation_id: Optional[str] \ No newline at end of file