refactor get all simulations per user

main
MrWaradana 4 weeks ago
parent 87d334ed73
commit b2b191dd78

@ -23,6 +23,7 @@ router = APIRouter()
async def get_simulations(
db_session: DbSession,
common: CommonParameters,
current_user: CurrentUser,
items_per_page: Optional[int] = Query(5),
search: Optional[str] = Query(None),
):
@ -31,13 +32,22 @@ async def get_simulations(
items_per_page=items_per_page,
search=search,
common=common,
owner=current_user.name,
)
return StandardResponse(data=simulations, message="Data retrieved successfully")
@router.get("/{simulation_id}", response_model=StandardResponse[SimulationRead])
async def get_simulation(db_session: DbSession, simulation_id: str):
simulation = await get(db_session=db_session, simulation_id=simulation_id)
async def get_simulation(
db_session: DbSession,
simulation_id: str,
current_user: CurrentUser,
):
simulation = await get(
db_session=db_session,
simulation_id=simulation_id,
owner=current_user.name,
)
if not simulation:
raise HTTPException(
@ -88,7 +98,11 @@ async def update_simulation(
simulation_in: SimulationUpdate,
current_user: CurrentUser,
):
simulation = await get(db_session=db_session, simulation_id=simulation_id)
simulation = await get(
db_session=db_session,
simulation_id=simulation_id,
owner=current_user.name,
)
if not simulation:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@ -105,8 +119,16 @@ async def update_simulation(
@router.delete("/{simulation_id}", response_model=StandardResponse[SimulationBase])
async def delete_simulation(db_session: DbSession, simulation_id: str):
simulation = await get(db_session=db_session, simulation_id=simulation_id)
async def delete_simulation(
db_session: DbSession,
simulation_id: str,
current_user: CurrentUser,
):
simulation = await get(
db_session=db_session,
simulation_id=simulation_id,
owner=current_user.name,
)
if not simulation:
raise HTTPException(

@ -39,14 +39,19 @@ PLANT_COPY_COLUMNS = [
]
async def get(*, db_session: DbSession, simulation_id: str) -> Optional[Simulation]:
async def get(
*, db_session: DbSession, simulation_id: str, owner: str
) -> Optional[Simulation]:
query = (
Select(Simulation)
.options(
selectinload(Simulation.masterdata_entries),
selectinload(Simulation.plant_transactions),
)
.where(Simulation.id == simulation_id)
.where(
Simulation.id == simulation_id,
Simulation.created_by == owner,
)
)
result = await db_session.execute(query)
return result.scalars().one_or_none()
@ -58,8 +63,13 @@ async def get_all(
items_per_page: Optional[int],
search: Optional[str],
common,
owner: str,
):
query = Select(Simulation).order_by(Simulation.created_at.desc())
query = (
Select(Simulation)
.where(Simulation.created_by == owner)
.order_by(Simulation.created_at.desc())
)
if search:
query = query.filter(Simulation.label.ilike(f"%{search}%"))

Loading…
Cancel
Save