From 0cec12e0c5670a4b0e25d4117b0ddd831293f098 Mon Sep 17 00:00:00 2001 From: Cizz22 Date: Tue, 19 Aug 2025 17:28:53 +0700 Subject: [PATCH] add regular node filter --- src/aeros_simulation/router.py | 4 ++-- src/aeros_simulation/service.py | 15 +++++++++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/aeros_simulation/router.py b/src/aeros_simulation/router.py index 2147c7f..c491c4e 100644 --- a/src/aeros_simulation/router.py +++ b/src/aeros_simulation/router.py @@ -111,10 +111,10 @@ async def run_simulations( "/result/calc/{simulation_id}", response_model=StandardResponse[List[SimulationCalc]], ) -async def get_simulation_result(db_session: DbSession, simulation_id, schematic_name: Optional[str] = Query(None)): +async def get_simulation_result(db_session: DbSession, simulation_id, schematic_name: Optional[str] = Query(None), node_type = Query(None, alias="nodetype")): """Get simulation result.""" simulation_result = await get_simulation_with_calc_result( - db_session=db_session, simulation_id=simulation_id, schematic_name=schematic_name + db_session=db_session, simulation_id=simulation_id, schematic_name=schematic_name, node_type=node_type ) return { diff --git a/src/aeros_simulation/service.py b/src/aeros_simulation/service.py index f8e92fa..5999b19 100644 --- a/src/aeros_simulation/service.py +++ b/src/aeros_simulation/service.py @@ -483,7 +483,7 @@ async def create_simulation(*, db_session: DbSession, simulation_in: SimulationI async def get_simulation_with_calc_result( - *, db_session: DbSession, simulation_id: UUID, aeros_node_id: Optional[UUID] = None, schematic_name: Optional[str] = None + *, db_session: DbSession, simulation_id: UUID, aeros_node_id: Optional[UUID] = None, schematic_name: Optional[str] = None, node_type: Optional[str] = None ): """Get a simulation by id.""" query = (select(AerosSimulationCalcResult).filter( @@ -499,6 +499,11 @@ async def get_simulation_with_calc_result( AerosNode, AerosNode.id == AerosSimulationCalcResult.aeros_node_id ).filter(AerosNode.structure_name.contains(schematic_name)) + if node_type: + query = query.join( + AerosNode, AerosNode.id == AerosSimulationCalcResult.aeros_node_id + ).filter(AerosNode.node_type == node_type) + query = query.options( selectinload(AerosSimulationCalcResult.aeros_node).options( selectinload(AerosNode.equipment) @@ -545,7 +550,7 @@ async def get_result_ranking(*, db_session: DbSession, simulation_id: UUID): async def get_simulation_with_plot_result( - *, db_session: DbSession, simulation_id: UUID + *, db_session: DbSession, simulation_id: UUID, node_type: Optional[str] = None ): """Get a simulation by id.""" query = ( @@ -557,6 +562,12 @@ async def get_simulation_with_plot_result( ) ) ) + + if node_type: + query = query.join( + AerosNode, AerosNode.id == AerosSimulation.plot_results.aeros_node_id + ).filter(AerosNode.node_type == node_type) + simulation = await db_session.execute(query) return simulation.scalar()