feat: Implement `with_for_update` in simulation and node retrieval to prevent TOCTOU race conditions.

main
Cizz22 6 hours ago
parent 8b2388f5fc
commit 1a1500fcbe

@ -66,6 +66,7 @@ async def get_simulation_by_id(
db_session: DbSession, db_session: DbSession,
simulation_id: Optional[UUID] = None, simulation_id: Optional[UUID] = None,
is_completed: bool = False, is_completed: bool = False,
for_update: bool = False,
): ):
"""Get a simulation by id.""" """Get a simulation by id."""
query = select(AerosSimulation) query = select(AerosSimulation)
@ -78,6 +79,9 @@ async def get_simulation_by_id(
else: else:
query = query.order_by(AerosSimulation.created_at.desc()).limit(1) query = query.order_by(AerosSimulation.created_at.desc()).limit(1)
if for_update:
query = query.with_for_update()
results = await db_session.execute(query) results = await db_session.execute(query)
return results.scalar() return results.scalar()
@ -93,7 +97,7 @@ async def get_default_simulation(
return results.scalar() return results.scalar()
async def get_simulation_node_by(*, db_session: DbSession, **kwargs): async def get_simulation_node_by(*, db_session: DbSession, for_update: bool = False, **kwargs):
"""Get a simulation node by column.""" """Get a simulation node by column."""
# Build WHERE conditions from kwargs # Build WHERE conditions from kwargs
conditions = [] conditions = []
@ -105,13 +109,16 @@ async def get_simulation_node_by(*, db_session: DbSession, **kwargs):
raise ValueError("No valid column conditions provided") raise ValueError("No valid column conditions provided")
query = select(AerosNode).where(*conditions) query = select(AerosNode).where(*conditions)
if for_update:
query = query.with_for_update()
result = await db_session.execute(query) result = await db_session.execute(query)
return result.scalar() return result.scalar()
async def get_or_save_node(*, db_session: DbSession, node_data: dict, type: str = "calc"): async def get_or_save_node(*, db_session: DbSession, node_data: dict, type: str = "calc"):
"""Get a simulation node by column.""" """Get a simulation node by column."""
node = await get_simulation_node_by( node = await get_simulation_node_by(
db_session=db_session, node_name=node_data["nodeName"] db_session=db_session, node_name=node_data["nodeName"], for_update=True
) )
@ -410,9 +417,13 @@ async def execute_simulation(*, db_session: DbSession, simulation_id: Optional[U
return result return result
simulation = await get_simulation_by_id( simulation = await get_simulation_by_id(
db_session=db_session, simulation_id=simulation_id db_session=db_session, simulation_id=simulation_id, for_update=True
) )
if simulation.status in ["processing", "completed"]:
# Prevent TOCTOU concurrent duplicate running
print(f"Simulation {simulation_id} is already {simulation.status}")
return True
simulation.status = "processing" simulation.status = "processing"
await db_session.commit() await db_session.commit()

Loading…
Cancel
Save