diff --git a/src/simulations/router.py b/src/simulations/router.py index 05eac0c..267fa0a 100644 --- a/src/simulations/router.py +++ b/src/simulations/router.py @@ -23,6 +23,7 @@ router = APIRouter() async def get_simulations( db_session: DbSession, common: CommonParameters, + current_user: CurrentUser, items_per_page: Optional[int] = Query(5), search: Optional[str] = Query(None), ): @@ -31,13 +32,22 @@ async def get_simulations( items_per_page=items_per_page, search=search, common=common, + owner=current_user.name, ) return StandardResponse(data=simulations, message="Data retrieved successfully") @router.get("/{simulation_id}", response_model=StandardResponse[SimulationRead]) -async def get_simulation(db_session: DbSession, simulation_id: str): - simulation = await get(db_session=db_session, simulation_id=simulation_id) +async def get_simulation( + db_session: DbSession, + simulation_id: str, + current_user: CurrentUser, +): + simulation = await get( + db_session=db_session, + simulation_id=simulation_id, + owner=current_user.name, + ) if not simulation: raise HTTPException( @@ -88,7 +98,11 @@ async def update_simulation( simulation_in: SimulationUpdate, current_user: CurrentUser, ): - simulation = await get(db_session=db_session, simulation_id=simulation_id) + simulation = await get( + db_session=db_session, + simulation_id=simulation_id, + owner=current_user.name, + ) if not simulation: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -105,8 +119,16 @@ async def update_simulation( @router.delete("/{simulation_id}", response_model=StandardResponse[SimulationBase]) -async def delete_simulation(db_session: DbSession, simulation_id: str): - simulation = await get(db_session=db_session, simulation_id=simulation_id) +async def delete_simulation( + db_session: DbSession, + simulation_id: str, + current_user: CurrentUser, +): + simulation = await get( + db_session=db_session, + simulation_id=simulation_id, + owner=current_user.name, + ) if not simulation: raise HTTPException( diff --git a/src/simulations/service.py b/src/simulations/service.py index 0dc6865..30977fd 100644 --- a/src/simulations/service.py +++ b/src/simulations/service.py @@ -39,14 +39,19 @@ PLANT_COPY_COLUMNS = [ ] -async def get(*, db_session: DbSession, simulation_id: str) -> Optional[Simulation]: +async def get( + *, db_session: DbSession, simulation_id: str, owner: str +) -> Optional[Simulation]: query = ( Select(Simulation) .options( selectinload(Simulation.masterdata_entries), selectinload(Simulation.plant_transactions), ) - .where(Simulation.id == simulation_id) + .where( + Simulation.id == simulation_id, + Simulation.created_by == owner, + ) ) result = await db_session.execute(query) return result.scalars().one_or_none() @@ -58,8 +63,13 @@ async def get_all( items_per_page: Optional[int], search: Optional[str], common, + owner: str, ): - query = Select(Simulation).order_by(Simulation.created_at.desc()) + query = ( + Select(Simulation) + .where(Simulation.created_by == owner) + .order_by(Simulation.created_at.desc()) + ) if search: query = query.filter(Simulation.label.ilike(f"%{search}%"))