From 0de7bc9bbe0cf198f12c0f6ffcd96b22d75a1e7b Mon Sep 17 00:00:00 2001 From: Cizz22 Date: Mon, 12 Jan 2026 11:21:52 +0700 Subject: [PATCH] add validation --- src/aeros_simulation/router.py | 17 ++-- src/aeros_simulation/schema.py | 33 ++++--- src/aeros_simulation/service.py | 6 +- src/dashboard_model/router.py | 8 +- src/dashboard_model/schema.py | 8 +- src/database/schema.py | 22 +++++ src/database/service.py | 50 ++++++---- src/main.py | 4 + src/middleware.py | 170 ++++++++++++++++++++++++++++++++ src/models.py | 2 + 10 files changed, 274 insertions(+), 46 deletions(-) create mode 100644 src/database/schema.py create mode 100644 src/middleware.py diff --git a/src/aeros_simulation/router.py b/src/aeros_simulation/router.py index f9abcda..660fcf2 100644 --- a/src/aeros_simulation/router.py +++ b/src/aeros_simulation/router.py @@ -1,9 +1,9 @@ from collections import defaultdict from datetime import datetime -from typing import List, Optional +from typing import Annotated, List, Optional from uuid import UUID from sqlalchemy.orm import selectinload -from fastapi import APIRouter, BackgroundTasks, HTTPException, background, status, Query +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, background, status, Query from sqlalchemy import select, text from temporalio.client import Client from src.aeros_contribution.service import update_contribution_bulk_mappings @@ -13,7 +13,7 @@ from src.aeros_simulation.utils import date_to_utc, hours_between, year_window_u from src.auth.service import CurrentUser from src.config import TEMPORAL_URL from src.database.core import CollectorDbSession, DbSession -from src.database.service import CommonParameters +from src.database.service import CommonParameters, get_params_factory from src.models import StandardResponse from src.aeros_equipment.service import update_equipment_for_simulation from src.aeros_project.service import get_project @@ -21,12 +21,14 @@ from temporal.workflow import SimulationWorkflow from .schema import ( AhmMetricInput, SimulationCalcResult, + SimulationCalcResultQuery, SimulationInput, SimulationPagination, SimulationPlot, SimulationPlotResult, SimulationCalc, SimulationData, + SimulationQueryModel, SimulationRankingParameters, YearlySimulationInput ) @@ -53,10 +55,9 @@ active_simulations = {} @router.get("", response_model=StandardResponse[SimulationPagination]) -async def get_all_simulation(db_session: DbSession, current_user:CurrentUser,common: CommonParameters, status: Optional[str] = Query(None)): +async def get_all_simulation(db_session: DbSession, current_user:CurrentUser, common:Annotated[dict, Depends(get_params_factory(SimulationQueryModel))]): """Get all simulation.""" - - results = await get_all(common, status, current_user) + results = await get_all(common, current_user) return { "data": results, @@ -223,8 +224,10 @@ async def run_yearly_simulation( "/result/calc/{simulation_id}", response_model=StandardResponse[List[SimulationCalc]], ) -async def get_simulation_result(db_session: DbSession, simulation_id, schematic_name: Optional[str] = Query(None), node_type = Query(None, alias="nodetype")): +async def get_simulation_result(db_session: DbSession, simulation_id, params:Annotated[SimulationCalcResultQuery, Query()]): """Get simulation result.""" + schematic_name = params.schematic_name + node_type = params.node_type if simulation_id == 'default': simulation = await get_default_simulation(db_session=db_session) simulation_id = simulation.id diff --git a/src/aeros_simulation/schema.py b/src/aeros_simulation/schema.py index 22b11be..e433bfd 100644 --- a/src/aeros_simulation/schema.py +++ b/src/aeros_simulation/schema.py @@ -4,11 +4,13 @@ from uuid import UUID from pydantic import Field -from src.models import BaseModel, Pagination +from src.database.schema import CommonParams +from src.database.service import CommonParameters +from src.models import DefultBase, Pagination from src.aeros_equipment.schema import MasterEquipment, EquipmentWithCustomParameters # Pydantic models for request/response validation -class SimulationInput(BaseModel): +class SimulationInput(DefultBase): SchematicName: str = "- TJB - Unit 3 -" SimSeed: int = 1 SimDuration: int = 3 @@ -25,7 +27,7 @@ class SimulationInput(BaseModel): OverhaulDuration: Optional[int] = Field(1200) AhmJobId: Optional[str] = Field(None) -class SimulationNode(BaseModel): +class SimulationNode(DefultBase): id: UUID node_type: Optional[str] node_id: Optional[int] @@ -36,7 +38,7 @@ class SimulationNode(BaseModel): model_image: Optional[list] = Field(None) equipment:Optional[MasterEquipment] = None -class SimulationCalc(BaseModel): +class SimulationCalc(DefultBase): id: UUID total_downtime: float total_uptime: float @@ -75,7 +77,7 @@ class SimulationCalc(BaseModel): contribution_factor: Optional[float] sof: Optional[float] -class SimulationPlot(BaseModel): +class SimulationPlot(DefultBase): id: UUID max_flow_rate: float storage_capacity: float @@ -90,16 +92,16 @@ class SimulationNodeWithResult(SimulationNode): calc_results: List[SimulationCalc] -class SimulationCalcResult(BaseModel): +class SimulationCalcResult(DefultBase): id: UUID calc_results: List[SimulationCalc] -class SimulationPlotResult(BaseModel): +class SimulationPlotResult(DefultBase): id: UUID plot_results: List[SimulationPlot] -class SimulationData(BaseModel): +class SimulationData(DefultBase): id: UUID simulation_name: str status: str @@ -117,9 +119,18 @@ class SimulationPagination(Pagination): items: List[SimulationData] = [] -class AhmMetricInput(BaseModel): +class AhmMetricInput(DefultBase): target_simulation_id: str baseline_simulation_id: Optional[str] = Field(None) -class YearlySimulationInput(BaseModel): - year: int \ No newline at end of file +class YearlySimulationInput(DefultBase): + year: int + + +class SimulationQueryModel(CommonParams): + status: Optional[str] = Field() + + +class SimulationCalcResultQuery(DefultBase): + schematic_name: Optional[str] = None + node_type: Optional[str] = Field(None, alias="nodetype") \ No newline at end of file diff --git a/src/aeros_simulation/service.py b/src/aeros_simulation/service.py index d37cf7f..c6765b3 100644 --- a/src/aeros_simulation/service.py +++ b/src/aeros_simulation/service.py @@ -37,11 +37,9 @@ active_simulations = {} # Get Data Service -async def get_all(common: CommonParameters, status, current_user): +async def get_all(common, current_user): query = select(AerosSimulation).order_by(desc(AerosSimulation.created_at)) - if status: - query = query.where(AerosSimulation.status == "completed") - + query = query.where(AerosSimulation.status == "completed") if current_user.role.lower() != "admin": query = query.where(AerosSimulation.created_by == current_user.user_id) diff --git a/src/dashboard_model/router.py b/src/dashboard_model/router.py index 967e049..83796b9 100644 --- a/src/dashboard_model/router.py +++ b/src/dashboard_model/router.py @@ -1,9 +1,10 @@ -from typing import List, Optional +from typing import Annotated, List, Optional from uuid import UUID -from fastapi import APIRouter, HTTPException, Query, status +from fastapi import APIRouter, Depends, HTTPException, Query, status from src.auth.service import CurrentUser +from src.dashboard_model.schema import DashboardModelQuery from src.database.core import DbSession from src.database.service import CommonParameters from src.models import StandardResponse @@ -16,8 +17,9 @@ router = APIRouter() @router.get("", response_model=StandardResponse[dict]) async def get_dashboard_model_data( db_session: DbSession, - simulation_id: Optional[UUID] = Query(None), + query:Annotated[DashboardModelQuery, Query()] ): + simulation_id = query.simulation_id result = await get_model_data(db_session=db_session, simulation_id=simulation_id) return StandardResponse( diff --git a/src/dashboard_model/schema.py b/src/dashboard_model/schema.py index b68f0a2..0183e06 100644 --- a/src/dashboard_model/schema.py +++ b/src/dashboard_model/schema.py @@ -4,7 +4,9 @@ # from pydantic import Field -# from src.models import DefultBase, Pagination +from typing import Optional +from uuid import UUID +from src.models import DefultBase # from src.overhaul_scope.schema import ScopeRead # from src.scope_equipment_job.schema import ScopeEquipmentJobRead # from src.job.schema import ActivityMasterRead @@ -41,3 +43,7 @@ # class OverhaulSchedulePagination(Pagination): # items: List[OverhaulScheduleRead] = [] + + +class DashboardModelQuery(DefultBase): + simulation_id : Optional[UUID] = None \ No newline at end of file diff --git a/src/database/schema.py b/src/database/schema.py new file mode 100644 index 0000000..67cb3c7 --- /dev/null +++ b/src/database/schema.py @@ -0,0 +1,22 @@ +from typing import Optional, List + +from pydantic import Field +from src.models import DefultBase + + +class CommonParams(DefultBase): + # This ensures no extra query params are allowed + current_user: Optional[str] = Field(None, alias="currentUser") + page: int = Field(1, gt=0, lt=2147483647) + items_per_page: int = Field(5, gt=-2, lt=2147483647) + query_str: Optional[str] = Field(None, alias="q") + filter_spec: Optional[str] = Field(None, alias="filter") + sort_by: List[str] = Field(default_factory=list, alias="sortBy[]") + descending: List[bool] = Field(default_factory=list, alias="descending[]") + exclude: List[str] = Field(default_factory=list, alias="exclude[]") + all_params: int = Field(0, alias="all") + + # Property to mirror your original return dict's bool conversion + @property + def is_all(self) -> bool: + return bool(self.all_params) \ No newline at end of file diff --git a/src/database/service.py b/src/database/service.py index 34850a8..20d3291 100644 --- a/src/database/service.py +++ b/src/database/service.py @@ -1,12 +1,14 @@ import logging -from typing import Annotated, List +from typing import Annotated, List, Type, TypeVar -from fastapi import Depends, Query +from fastapi import Depends, Query, Request from pydantic.types import Json, constr from sqlalchemy import Select, desc, func, or_ from sqlalchemy.exc import ProgrammingError from sqlalchemy_filters import apply_pagination +from src.database.schema import CommonParams + from .core import DbSession log = logging.getLogger(__name__) @@ -17,27 +19,19 @@ QueryStr = constr(pattern=r"^[ -~]+$", min_length=1) def common_parameters( db_session: DbSession, # type: ignore - current_user: QueryStr = Query(None, alias="currentUser"), # type: ignore - page: int = Query(1, gt=0, lt=2147483647), - items_per_page: int = Query(5, alias="itemsPerPage", gt=-2, lt=2147483647), - query_str: QueryStr = Query(None, alias="q"), # type: ignore - filter_spec: QueryStr = Query(None, alias="filter"), # type: ignore - sort_by: List[str] = Query([], alias="sortBy[]"), - descending: List[bool] = Query([], alias="descending[]"), - exclude: List[str] = Query([], alias="exclude[]"), - all: int = Query(0), + params: Annotated[CommonParams, Query()] # role: QueryStr = Depends(get_current_role), -): +): return { "db_session": db_session, - "page": page, - "items_per_page": items_per_page, - "query_str": query_str, - "filter_spec": filter_spec, - "sort_by": sort_by, - "descending": descending, - "current_user": current_user, - "all": bool(all), + "page": params.page, + "items_per_page": params.items_per_page, + "query_str": params.query_str, + "filter_spec": params.filter_spec, + "sort_by": params.sort_by, + "descending": params.descending, + "current_user": params.current_user, + "all": params.is_all, # "role": role, } @@ -47,6 +41,21 @@ CommonParameters = Annotated[ Depends(common_parameters), ] +T = TypeVar("T", bound=CommonParams) + +def get_params_factory(model_type: Type[T]): + async def wrapper( + db_session: DbSession, + params: Annotated[model_type, Query()] # type: ignore + ): + res = params.model_dump() + return { + "db_session": db_session, + "all": params.is_all, + **res + } + return wrapper + def search(*, query_str: str, query: Query, model, sort=False): """Perform a search based on the query.""" @@ -89,6 +98,7 @@ async def search_filter_sort_paginate( current_user: str = None, exclude: List[str] = None, all: bool = False, + **extra_params, ): """Common functionality for searching, filtering, sorting, and pagination.""" # try: diff --git a/src/main.py b/src/main.py index b49bc28..6f276c6 100644 --- a/src/main.py +++ b/src/main.py @@ -28,6 +28,7 @@ from src.database.core import async_session, engine,async_aeros_session from src.enums import ResponseStatus from src.exceptions import handle_exception from src.logging import setup_logging +from src.middleware import RequestValidationMiddleware from src.rate_limiter import limiter log = logging.getLogger(__name__) @@ -64,6 +65,8 @@ def get_request_id() -> Optional[str]: return _request_id_ctx_var.get() +app.add_middleware(RequestValidationMiddleware) + @app.middleware("http") async def db_session_middleware(request: Request, call_next): request_id = str(uuid1()) @@ -98,6 +101,7 @@ async def add_security_headers(request: Request, call_next): ) return response + app.mount("/model", StaticFiles(directory="model"), name="model") # class MetricsMiddleware(BaseHTTPMiddleware): diff --git a/src/middleware.py b/src/middleware.py new file mode 100644 index 0000000..64e38b9 --- /dev/null +++ b/src/middleware.py @@ -0,0 +1,170 @@ +import json +import re +from collections import Counter +from fastapi import Request, HTTPException +from starlette.middleware.base import BaseHTTPMiddleware + +# ========================= +# Configuration +# ========================= + +ALLOWED_MULTI_PARAMS = { + "sortBy[]", + "descending[]", + "exclude[]", +} + +MAX_QUERY_PARAMS = 50 +MAX_QUERY_LENGTH = 2000 +MAX_JSON_BODY_SIZE = 1024 * 100 # 100 KB + +# Very targeted patterns. Avoid catastrophic regex nonsense. +XSS_PATTERN = re.compile( + r"( bool: + return any(ord(c) < 32 and c not in ("\n", "\r", "\t") for c in value) + + +def inspect_value(value: str, source: str): + if XSS_PATTERN.search(value): + raise HTTPException( + status_code=400, + detail=f"Potential XSS payload detected in {source}", + ) + + if SQLI_PATTERN.search(value): + raise HTTPException( + status_code=400, + detail=f"Potential SQL injection payload detected in {source}", + ) + + if has_control_chars(value): + raise HTTPException( + status_code=400, + detail=f"Invalid control characters detected in {source}", + ) + + +def inspect_json(obj, path="body"): + if isinstance(obj, dict): + for key, value in obj.items(): + if key in FORBIDDEN_JSON_KEYS: + raise HTTPException( + status_code=400, + detail=f"Forbidden JSON key detected: {path}.{key}", + ) + inspect_json(value, f"{path}.{key}") + elif isinstance(obj, list): + for i, item in enumerate(obj): + inspect_json(item, f"{path}[{i}]") + elif isinstance(obj, str): + inspect_value(obj, path) + + +# ========================= +# Middleware +# ========================= + +class RequestValidationMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + # ------------------------- + # 1. Query string limits + # ------------------------- + if len(request.url.query) > MAX_QUERY_LENGTH: + raise HTTPException( + status_code=414, + detail="Query string too long", + ) + + params = request.query_params.multi_items() + + if len(params) > MAX_QUERY_PARAMS: + raise HTTPException( + status_code=400, + detail="Too many query parameters", + ) + + # ------------------------- + # 2. Duplicate parameters + # ------------------------- + counter = Counter(key for key, _ in params) + duplicates = [ + key for key, count in counter.items() + if count > 1 and key not in ALLOWED_MULTI_PARAMS + ] + + if duplicates: + raise HTTPException( + status_code=400, + detail=f"Duplicate query parameters are not allowed: {duplicates}", + ) + + # ------------------------- + # 3. Query param inspection + # ------------------------- + for key, value in params: + if value: + inspect_value(value, f"query param '{key}'") + + # ------------------------- + # 4. Content-Type sanity + # ------------------------- + content_type = request.headers.get("content-type", "") + if content_type and not any( + content_type.startswith(t) + for t in ( + "application/json", + "multipart/form-data", + "application/x-www-form-urlencoded", + ) + ): + raise HTTPException( + status_code=415, + detail="Unsupported Content-Type", + ) + + # ------------------------- + # 5. JSON body inspection + # ------------------------- + if content_type.startswith("application/json"): + body = await request.body() + + if len(body) > MAX_JSON_BODY_SIZE: + raise HTTPException( + status_code=413, + detail="JSON body too large", + ) + + if body: + try: + payload = json.loads(body) + except json.JSONDecodeError: + raise HTTPException( + status_code=400, + detail="Invalid JSON body", + ) + + inspect_json(payload) + + # Re-inject body for downstream handlers + async def receive(): + return {"type": "http.request", "body": body} + + request._receive = receive # noqa: protected-access + + return await call_next(request) diff --git a/src/models.py b/src/models.py index c938ef5..c7fe804 100644 --- a/src/models.py +++ b/src/models.py @@ -79,6 +79,8 @@ class DefultBase(BaseModel): # forbid extra/unexpected fields in input (prevents silent injection/mass assignment) extra = 'forbid' + + populate_by_name = True # secure JSON serialization: custom formatting for sensitive types json_encoders = {