diff --git a/src/aeros_project/router.py b/src/aeros_project/router.py index fdc4d3e..c18f121 100644 --- a/src/aeros_project/router.py +++ b/src/aeros_project/router.py @@ -16,11 +16,18 @@ router = APIRouter() @router.post("", response_model=StandardResponse[None]) async def import_aro( db_session: DbSession, + current_user: CurrentUser, schematic_name: str = Form(...), aro_file: UploadFile = File(..., description="ARO file"), project_name: str = "trialapi" ): + if current_user.role.lower() != "admin": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Admin access required to modify Aeros files." + ) + # Create the input object manually aeros_project_input = AerosProjectInput(schematic_name=schematic_name, aro_file=aro_file) diff --git a/src/aeros_simulation/model.py b/src/aeros_simulation/model.py index 51cf5fb..bf23a8b 100644 --- a/src/aeros_simulation/model.py +++ b/src/aeros_simulation/model.py @@ -19,6 +19,7 @@ class AerosSimulation(Base, DefaultMixin): duration = Column(Integer, nullable=True) offset = Column(Integer, nullable=True) is_default = Column(Boolean, default=False) + created_by = Column(UUID(as_uuid=True), nullable=True) calc_results = relationship( "AerosSimulationCalcResult", back_populates="aeros_simulation", lazy="raise" diff --git a/src/aeros_simulation/router.py b/src/aeros_simulation/router.py index 0822190..1449a6e 100644 --- a/src/aeros_simulation/router.py +++ b/src/aeros_simulation/router.py @@ -50,10 +50,10 @@ active_simulations = {} @router.get("", response_model=StandardResponse[SimulationPagination]) -async def get_all_simulation(db_session: DbSession, common: CommonParameters, status: Optional[str] = Query(None)): +async def get_all_simulation(db_session: DbSession, current_user:CurrentUser,common: CommonParameters, status: Optional[str] = Query(None)): """Get all simulation.""" - results = await get_all(common, status) + results = await get_all(common, status, current_user) return { "data": results, @@ -76,14 +76,14 @@ async def get_simulation(db_session: DbSession, simulation_id): async def run_simulations( db_session: DbSession, simulation_in: SimulationInput, - background_tasks: BackgroundTasks + current_user:CurrentUser ): """RUN Simulation""" temporal_client = await Client.connect(TEMPORAL_URL) simulation = await create_simulation( - db_session=db_session, simulation_in=simulation_in + db_session=db_session, simulation_in=simulation_in, current_user=current_user ) simulation_id = simulation.id diff --git a/src/aeros_simulation/service.py b/src/aeros_simulation/service.py index c0f666e..43624df 100644 --- a/src/aeros_simulation/service.py +++ b/src/aeros_simulation/service.py @@ -37,10 +37,13 @@ active_simulations = {} # Get Data Service -async def get_all(common: CommonParameters, status): +async def get_all(common: CommonParameters, status, current_user): query = select(AerosSimulation).order_by(desc(AerosSimulation.created_at)) if status: query = query.where(AerosSimulation.status == "completed") + + if current_user.role.lower() != "admin": + query = query.where(AerosSimulation.created_by == current_user.user_id) results = await search_filter_sort_paginate(model=query, **common) @@ -225,7 +228,7 @@ async def get_result_ranking(*, db_session: DbSession, simulation_id: UUID, limi ) ) - query = query.order_by(AerosSimulationCalcResult.availability.desc()) + query = query.order_by(AerosSimulationCalcResult.availability.asc()) if limit: query = query.limit(limit) @@ -987,9 +990,10 @@ def convert_id_to_none_if_negative(value): return None if value < 0 else value -async def create_simulation(*, db_session: DbSession, simulation_in: SimulationInput): +async def create_simulation(*, db_session: DbSession, simulation_in: SimulationInput, current_user): """Create a new simulation.""" input = simulation_in.model_dump(exclude={"SimulationName"}) + user_id = current_user.get("user_id") # Check if is default if simulation_in.IsDefault: @@ -1005,7 +1009,8 @@ async def create_simulation(*, db_session: DbSession, simulation_in: SimulationI "schematic_name": "- TJB - Unit 3 -", "is_default": simulation_in.IsDefault, "duration": simulation_in.SimDuration, - "offset": simulation_in.OffSet + "offset": simulation_in.OffSet, + "created_by": user_id } simulation = AerosSimulation(**active_simulations)