From 1a1500fcbe2d72da00b2390284864f9424fa7986 Mon Sep 17 00:00:00 2001 From: Cizz22 Date: Mon, 9 Mar 2026 12:27:51 +0700 Subject: [PATCH] feat: Implement `with_for_update` in simulation and node retrieval to prevent TOCTOU race conditions. --- src/aeros_simulation/service.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/aeros_simulation/service.py b/src/aeros_simulation/service.py index fc382ce..4cd4b9a 100644 --- a/src/aeros_simulation/service.py +++ b/src/aeros_simulation/service.py @@ -66,6 +66,7 @@ async def get_simulation_by_id( db_session: DbSession, simulation_id: Optional[UUID] = None, is_completed: bool = False, + for_update: bool = False, ): """Get a simulation by id.""" query = select(AerosSimulation) @@ -78,6 +79,9 @@ async def get_simulation_by_id( else: query = query.order_by(AerosSimulation.created_at.desc()).limit(1) + if for_update: + query = query.with_for_update() + results = await db_session.execute(query) return results.scalar() @@ -93,7 +97,7 @@ async def get_default_simulation( return results.scalar() -async def get_simulation_node_by(*, db_session: DbSession, **kwargs): +async def get_simulation_node_by(*, db_session: DbSession, for_update: bool = False, **kwargs): """Get a simulation node by column.""" # Build WHERE conditions from kwargs conditions = [] @@ -105,13 +109,16 @@ async def get_simulation_node_by(*, db_session: DbSession, **kwargs): raise ValueError("No valid column conditions provided") query = select(AerosNode).where(*conditions) + if for_update: + query = query.with_for_update() + result = await db_session.execute(query) return result.scalar() async def get_or_save_node(*, db_session: DbSession, node_data: dict, type: str = "calc"): """Get a simulation node by column.""" node = await get_simulation_node_by( - db_session=db_session, node_name=node_data["nodeName"] + db_session=db_session, node_name=node_data["nodeName"], for_update=True ) @@ -410,9 +417,13 @@ async def execute_simulation(*, db_session: DbSession, simulation_id: Optional[U return result simulation = await get_simulation_by_id( - db_session=db_session, simulation_id=simulation_id + db_session=db_session, simulation_id=simulation_id, for_update=True ) + if simulation.status in ["processing", "completed"]: + # Prevent TOCTOU concurrent duplicate running + print(f"Simulation {simulation_id} is already {simulation.status}") + return True simulation.status = "processing" await db_session.commit()