refactor get all simulations per user

main
MrWaradana 4 weeks ago
parent 87d334ed73
commit b2b191dd78

@ -23,6 +23,7 @@ router = APIRouter()
async def get_simulations( async def get_simulations(
db_session: DbSession, db_session: DbSession,
common: CommonParameters, common: CommonParameters,
current_user: CurrentUser,
items_per_page: Optional[int] = Query(5), items_per_page: Optional[int] = Query(5),
search: Optional[str] = Query(None), search: Optional[str] = Query(None),
): ):
@ -31,13 +32,22 @@ async def get_simulations(
items_per_page=items_per_page, items_per_page=items_per_page,
search=search, search=search,
common=common, common=common,
owner=current_user.name,
) )
return StandardResponse(data=simulations, message="Data retrieved successfully") return StandardResponse(data=simulations, message="Data retrieved successfully")
@router.get("/{simulation_id}", response_model=StandardResponse[SimulationRead]) @router.get("/{simulation_id}", response_model=StandardResponse[SimulationRead])
async def get_simulation(db_session: DbSession, simulation_id: str): async def get_simulation(
simulation = await get(db_session=db_session, simulation_id=simulation_id) db_session: DbSession,
simulation_id: str,
current_user: CurrentUser,
):
simulation = await get(
db_session=db_session,
simulation_id=simulation_id,
owner=current_user.name,
)
if not simulation: if not simulation:
raise HTTPException( raise HTTPException(
@ -88,7 +98,11 @@ async def update_simulation(
simulation_in: SimulationUpdate, simulation_in: SimulationUpdate,
current_user: CurrentUser, current_user: CurrentUser,
): ):
simulation = await get(db_session=db_session, simulation_id=simulation_id) simulation = await get(
db_session=db_session,
simulation_id=simulation_id,
owner=current_user.name,
)
if not simulation: if not simulation:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
@ -105,8 +119,16 @@ async def update_simulation(
@router.delete("/{simulation_id}", response_model=StandardResponse[SimulationBase]) @router.delete("/{simulation_id}", response_model=StandardResponse[SimulationBase])
async def delete_simulation(db_session: DbSession, simulation_id: str): async def delete_simulation(
simulation = await get(db_session=db_session, simulation_id=simulation_id) db_session: DbSession,
simulation_id: str,
current_user: CurrentUser,
):
simulation = await get(
db_session=db_session,
simulation_id=simulation_id,
owner=current_user.name,
)
if not simulation: if not simulation:
raise HTTPException( raise HTTPException(

@ -39,14 +39,19 @@ PLANT_COPY_COLUMNS = [
] ]
async def get(*, db_session: DbSession, simulation_id: str) -> Optional[Simulation]: async def get(
*, db_session: DbSession, simulation_id: str, owner: str
) -> Optional[Simulation]:
query = ( query = (
Select(Simulation) Select(Simulation)
.options( .options(
selectinload(Simulation.masterdata_entries), selectinload(Simulation.masterdata_entries),
selectinload(Simulation.plant_transactions), selectinload(Simulation.plant_transactions),
) )
.where(Simulation.id == simulation_id) .where(
Simulation.id == simulation_id,
Simulation.created_by == owner,
)
) )
result = await db_session.execute(query) result = await db_session.execute(query)
return result.scalars().one_or_none() return result.scalars().one_or_none()
@ -58,8 +63,13 @@ async def get_all(
items_per_page: Optional[int], items_per_page: Optional[int],
search: Optional[str], search: Optional[str],
common, common,
owner: str,
): ):
query = Select(Simulation).order_by(Simulation.created_at.desc()) query = (
Select(Simulation)
.where(Simulation.created_by == owner)
.order_by(Simulation.created_at.desc())
)
if search: if search:
query = query.filter(Simulation.label.ilike(f"%{search}%")) query = query.filter(Simulation.label.ilike(f"%{search}%"))

Loading…
Cancel
Save