add validation

main
Cizz22 10 hours ago
parent 3f940b9a4f
commit 0de7bc9bbe

@ -1,9 +1,9 @@
from collections import defaultdict
from datetime import datetime
from typing import List, Optional
from typing import Annotated, List, Optional
from uuid import UUID
from sqlalchemy.orm import selectinload
from fastapi import APIRouter, BackgroundTasks, HTTPException, background, status, Query
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, background, status, Query
from sqlalchemy import select, text
from temporalio.client import Client
from src.aeros_contribution.service import update_contribution_bulk_mappings
@ -13,7 +13,7 @@ from src.aeros_simulation.utils import date_to_utc, hours_between, year_window_u
from src.auth.service import CurrentUser
from src.config import TEMPORAL_URL
from src.database.core import CollectorDbSession, DbSession
from src.database.service import CommonParameters
from src.database.service import CommonParameters, get_params_factory
from src.models import StandardResponse
from src.aeros_equipment.service import update_equipment_for_simulation
from src.aeros_project.service import get_project
@ -21,12 +21,14 @@ from temporal.workflow import SimulationWorkflow
from .schema import (
AhmMetricInput,
SimulationCalcResult,
SimulationCalcResultQuery,
SimulationInput,
SimulationPagination,
SimulationPlot,
SimulationPlotResult,
SimulationCalc,
SimulationData,
SimulationQueryModel,
SimulationRankingParameters,
YearlySimulationInput
)
@ -53,10 +55,9 @@ active_simulations = {}
@router.get("", response_model=StandardResponse[SimulationPagination])
async def get_all_simulation(db_session: DbSession, current_user:CurrentUser,common: CommonParameters, status: Optional[str] = Query(None)):
async def get_all_simulation(db_session: DbSession, current_user:CurrentUser, common:Annotated[dict, Depends(get_params_factory(SimulationQueryModel))]):
"""Get all simulation."""
results = await get_all(common, status, current_user)
results = await get_all(common, current_user)
return {
"data": results,
@ -223,8 +224,10 @@ async def run_yearly_simulation(
"/result/calc/{simulation_id}",
response_model=StandardResponse[List[SimulationCalc]],
)
async def get_simulation_result(db_session: DbSession, simulation_id, schematic_name: Optional[str] = Query(None), node_type = Query(None, alias="nodetype")):
async def get_simulation_result(db_session: DbSession, simulation_id, params:Annotated[SimulationCalcResultQuery, Query()]):
"""Get simulation result."""
schematic_name = params.schematic_name
node_type = params.node_type
if simulation_id == 'default':
simulation = await get_default_simulation(db_session=db_session)
simulation_id = simulation.id

@ -4,11 +4,13 @@ from uuid import UUID
from pydantic import Field
from src.models import BaseModel, Pagination
from src.database.schema import CommonParams
from src.database.service import CommonParameters
from src.models import DefultBase, Pagination
from src.aeros_equipment.schema import MasterEquipment, EquipmentWithCustomParameters
# Pydantic models for request/response validation
class SimulationInput(BaseModel):
class SimulationInput(DefultBase):
SchematicName: str = "- TJB - Unit 3 -"
SimSeed: int = 1
SimDuration: int = 3
@ -25,7 +27,7 @@ class SimulationInput(BaseModel):
OverhaulDuration: Optional[int] = Field(1200)
AhmJobId: Optional[str] = Field(None)
class SimulationNode(BaseModel):
class SimulationNode(DefultBase):
id: UUID
node_type: Optional[str]
node_id: Optional[int]
@ -36,7 +38,7 @@ class SimulationNode(BaseModel):
model_image: Optional[list] = Field(None)
equipment:Optional[MasterEquipment] = None
class SimulationCalc(BaseModel):
class SimulationCalc(DefultBase):
id: UUID
total_downtime: float
total_uptime: float
@ -75,7 +77,7 @@ class SimulationCalc(BaseModel):
contribution_factor: Optional[float]
sof: Optional[float]
class SimulationPlot(BaseModel):
class SimulationPlot(DefultBase):
id: UUID
max_flow_rate: float
storage_capacity: float
@ -90,16 +92,16 @@ class SimulationNodeWithResult(SimulationNode):
calc_results: List[SimulationCalc]
class SimulationCalcResult(BaseModel):
class SimulationCalcResult(DefultBase):
id: UUID
calc_results: List[SimulationCalc]
class SimulationPlotResult(BaseModel):
class SimulationPlotResult(DefultBase):
id: UUID
plot_results: List[SimulationPlot]
class SimulationData(BaseModel):
class SimulationData(DefultBase):
id: UUID
simulation_name: str
status: str
@ -117,9 +119,18 @@ class SimulationPagination(Pagination):
items: List[SimulationData] = []
class AhmMetricInput(BaseModel):
class AhmMetricInput(DefultBase):
target_simulation_id: str
baseline_simulation_id: Optional[str] = Field(None)
class YearlySimulationInput(BaseModel):
year: int
class YearlySimulationInput(DefultBase):
year: int
class SimulationQueryModel(CommonParams):
    """Query parameters for the simulation list endpoint.

    Adds an optional ``status`` filter on top of the shared paging, search
    and sort parameters inherited from ``CommonParams``.
    """

    # An explicit ``None`` default keeps the parameter optional. A bare
    # ``Field()`` carries no default, which makes the field required and
    # rejects every request that omits ``?status=``.
    status: Optional[str] = Field(None)
class SimulationCalcResultQuery(DefultBase):
    """Query parameters accepted by the simulation calc-result endpoint."""

    # Optional filter; absent means no schematic name was supplied.
    schematic_name: Optional[str] = Field(default=None)
    # Exposed to clients under the query key ``nodetype``.
    node_type: Optional[str] = Field(default=None, alias="nodetype")

@ -37,11 +37,9 @@ active_simulations = {}
# Get Data Service
async def get_all(common: CommonParameters, status, current_user):
async def get_all(common, current_user):
query = select(AerosSimulation).order_by(desc(AerosSimulation.created_at))
if status:
query = query.where(AerosSimulation.status == "completed")
query = query.where(AerosSimulation.status == "completed")
if current_user.role.lower() != "admin":
query = query.where(AerosSimulation.created_by == current_user.user_id)

@ -1,9 +1,10 @@
from typing import List, Optional
from typing import Annotated, List, Optional
from uuid import UUID
from fastapi import APIRouter, HTTPException, Query, status
from fastapi import APIRouter, Depends, HTTPException, Query, status
from src.auth.service import CurrentUser
from src.dashboard_model.schema import DashboardModelQuery
from src.database.core import DbSession
from src.database.service import CommonParameters
from src.models import StandardResponse
@ -16,8 +17,9 @@ router = APIRouter()
@router.get("", response_model=StandardResponse[dict])
async def get_dashboard_model_data(
db_session: DbSession,
simulation_id: Optional[UUID] = Query(None),
query:Annotated[DashboardModelQuery, Query()]
):
simulation_id = query.simulation_id
result = await get_model_data(db_session=db_session, simulation_id=simulation_id)
return StandardResponse(

@ -4,7 +4,9 @@
# from pydantic import Field
# from src.models import DefultBase, Pagination
from typing import Optional
from uuid import UUID
from src.models import DefultBase
# from src.overhaul_scope.schema import ScopeRead
# from src.scope_equipment_job.schema import ScopeEquipmentJobRead
# from src.job.schema import ActivityMasterRead
@ -41,3 +43,7 @@
# class OverhaulSchedulePagination(Pagination):
# items: List[OverhaulScheduleRead] = []
class DashboardModelQuery(DefultBase):
    """Query parameters accepted by the dashboard model data endpoint."""

    # Optional simulation to load; ``None`` when the client omits it.
    simulation_id: Optional[UUID] = None

@ -0,0 +1,22 @@
from typing import Optional, List
from pydantic import Field
from src.models import DefultBase
class CommonParams(DefultBase):
    """Shared query parameters for list endpoints.

    Covers paging, free-text search, filtering, sorting and exclusion.
    Field aliases are the names clients send in the query string; the
    attribute names are what the service layer reads.
    """

    # NOTE(review): rejection of unknown query parameters appears to rely on
    # the base class configuration (``extra = 'forbid'``) - confirm.
    current_user: Optional[str] = Field(None, alias="currentUser")
    page: int = Field(1, gt=0, lt=2147483647)
    # The ``itemsPerPage`` alias is kept so the public query key stays the
    # same as in the function-based ``common_parameters`` dependency; without
    # it existing clients sending ``itemsPerPage`` stop being recognised.
    items_per_page: int = Field(5, alias="itemsPerPage", gt=-2, lt=2147483647)
    query_str: Optional[str] = Field(None, alias="q")
    filter_spec: Optional[str] = Field(None, alias="filter")
    sort_by: List[str] = Field(default_factory=list, alias="sortBy[]")
    descending: List[bool] = Field(default_factory=list, alias="descending[]")
    exclude: List[str] = Field(default_factory=list, alias="exclude[]")
    # Sent by clients as ``all``; stored as an int and exposed as a bool
    # through ``is_all``.
    all_params: int = Field(0, alias="all")

    @property
    def is_all(self) -> bool:
        """Return ``all_params`` converted to a boolean flag."""
        return bool(self.all_params)

@ -1,12 +1,14 @@
import logging
from typing import Annotated, List
from typing import Annotated, List, Type, TypeVar
from fastapi import Depends, Query
from fastapi import Depends, Query, Request
from pydantic.types import Json, constr
from sqlalchemy import Select, desc, func, or_
from sqlalchemy.exc import ProgrammingError
from sqlalchemy_filters import apply_pagination
from src.database.schema import CommonParams
from .core import DbSession
log = logging.getLogger(__name__)
@ -17,27 +19,19 @@ QueryStr = constr(pattern=r"^[ -~]+$", min_length=1)
def common_parameters(
db_session: DbSession, # type: ignore
current_user: QueryStr = Query(None, alias="currentUser"), # type: ignore
page: int = Query(1, gt=0, lt=2147483647),
items_per_page: int = Query(5, alias="itemsPerPage", gt=-2, lt=2147483647),
query_str: QueryStr = Query(None, alias="q"), # type: ignore
filter_spec: QueryStr = Query(None, alias="filter"), # type: ignore
sort_by: List[str] = Query([], alias="sortBy[]"),
descending: List[bool] = Query([], alias="descending[]"),
exclude: List[str] = Query([], alias="exclude[]"),
all: int = Query(0),
params: Annotated[CommonParams, Query()]
# role: QueryStr = Depends(get_current_role),
):
):
return {
"db_session": db_session,
"page": page,
"items_per_page": items_per_page,
"query_str": query_str,
"filter_spec": filter_spec,
"sort_by": sort_by,
"descending": descending,
"current_user": current_user,
"all": bool(all),
"page": params.page,
"items_per_page": params.items_per_page,
"query_str": params.query_str,
"filter_spec": params.filter_spec,
"sort_by": params.sort_by,
"descending": params.descending,
"current_user": params.current_user,
"all": params.is_all,
# "role": role,
}
@ -47,6 +41,21 @@ CommonParameters = Annotated[
Depends(common_parameters),
]
T = TypeVar("T", bound=CommonParams)
def get_params_factory(model_type: Type[T]):
    """Build a dependency that parses query parameters into ``model_type``.

    The returned coroutine yields a plain dict holding the database session,
    the boolean ``all`` flag, and every field of the validated model as
    produced by ``model_dump()``.
    """

    async def wrapper(
        db_session: DbSession,
        params: Annotated[model_type, Query()],  # type: ignore
    ):
        dumped = params.model_dump()
        # Fixed keys go in first; dumped model fields are merged afterwards,
        # so a model field with the same name would take precedence.
        payload = {"db_session": db_session, "all": params.is_all}
        payload.update(dumped)
        return payload

    return wrapper
def search(*, query_str: str, query: Query, model, sort=False):
"""Perform a search based on the query."""
@ -89,6 +98,7 @@ async def search_filter_sort_paginate(
current_user: str = None,
exclude: List[str] = None,
all: bool = False,
**extra_params,
):
"""Common functionality for searching, filtering, sorting, and pagination."""
# try:

@ -28,6 +28,7 @@ from src.database.core import async_session, engine,async_aeros_session
from src.enums import ResponseStatus
from src.exceptions import handle_exception
from src.logging import setup_logging
from src.middleware import RequestValidationMiddleware
from src.rate_limiter import limiter
log = logging.getLogger(__name__)
@ -64,6 +65,8 @@ def get_request_id() -> Optional[str]:
return _request_id_ctx_var.get()
app.add_middleware(RequestValidationMiddleware)
@app.middleware("http")
async def db_session_middleware(request: Request, call_next):
request_id = str(uuid1())
@ -98,6 +101,7 @@ async def add_security_headers(request: Request, call_next):
)
return response
app.mount("/model", StaticFiles(directory="model"), name="model")
# class MetricsMiddleware(BaseHTTPMiddleware):

@ -0,0 +1,170 @@
import json
import re
from collections import Counter
from fastapi import Request, HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
# =========================
# Configuration
# =========================
# Query parameter names that are allowed to appear more than once in a single
# request. Any other key that repeats is rejected by the middleware's
# duplicate-parameter check.
ALLOWED_MULTI_PARAMS = {
"sortBy[]",
"descending[]",
"exclude[]",
}
# Maximum number of key/value pairs accepted in the query string.
MAX_QUERY_PARAMS = 50
# Maximum length of the raw query string (compared against len(request.url.query)).
MAX_QUERY_LENGTH = 2000
# Maximum size of a JSON request body (compared against len() of the raw body bytes).
MAX_JSON_BODY_SIZE = 1024 * 100 # 100 KB
# Small, literal-based patterns: plain alternations with no nested
# quantifiers, matched case-insensitively.
# XSS_PATTERN flags script/svg/img tag openers, "javascript:" URLs and the
# onerror=/onload= attribute forms.
XSS_PATTERN = re.compile(
r"(<script|</script|javascript:|onerror\s*=|onload\s*=|<svg|<img)",
re.IGNORECASE,
)
# SQLI_PATTERN flags the whole words UNION/SELECT/INSERT/DELETE/DROP, the
# "--" comment marker and "OR 1=1".
# NOTE(review): this also matches ordinary text containing those words or a
# double hyphen - confirm that such values are never legitimate input.
SQLI_PATTERN = re.compile(
r"(\bUNION\b|\bSELECT\b|\bINSERT\b|\bDELETE\b|\bDROP\b|--|\bOR\b\s+1=1)",
re.IGNORECASE,
)
# JSON object keys associated with prototype pollution; a body containing any
# of them as a key is rejected by inspect_json.
FORBIDDEN_JSON_KEYS = {"__proto__", "constructor", "prototype"}
# =========================
# Helpers
# =========================
def has_control_chars(value: str) -> bool:
    """Return True if ``value`` contains a disallowed ASCII control character.

    Newline, carriage return and tab are permitted. Every other character
    below U+0020 is rejected, and so is DEL (U+007F), which is also an ASCII
    control character.
    """
    permitted = ("\n", "\r", "\t")
    return any(
        (ord(char) < 32 or ord(char) == 127) and char not in permitted
        for char in value
    )
def inspect_value(value: str, source: str):
    """Validate one string value and raise HTTP 400 on the first violation.

    Checks run in a fixed order: XSS pattern, SQL injection pattern, then
    control characters. ``source`` is only used to label the error detail.
    """

    def _reject(detail: str):
        # Every violation is reported to the client as 400 Bad Request.
        raise HTTPException(status_code=400, detail=detail)

    if XSS_PATTERN.search(value) is not None:
        _reject(f"Potential XSS payload detected in {source}")

    if SQLI_PATTERN.search(value) is not None:
        _reject(f"Potential SQL injection payload detected in {source}")

    if has_control_chars(value):
        _reject(f"Invalid control characters detected in {source}")
def inspect_json(obj, path="body"):
    """Recursively validate a decoded JSON value.

    Strings are passed to ``inspect_value``; lists and dicts are walked
    depth-first. A dict key listed in ``FORBIDDEN_JSON_KEYS`` raises HTTP 400.
    ``path`` tracks the location (e.g. ``body.items[0].name``) for error
    messages. Other value types (numbers, booleans, null) are ignored.
    """
    if isinstance(obj, str):
        inspect_value(obj, path)
        return

    if isinstance(obj, list):
        for index, element in enumerate(obj):
            inspect_json(element, f"{path}[{index}]")
        return

    if not isinstance(obj, dict):
        return

    for name, child in obj.items():
        # The key is checked before descending into its value.
        if name in FORBIDDEN_JSON_KEYS:
            raise HTTPException(
                status_code=400,
                detail=f"Forbidden JSON key detected: {path}.{name}",
            )
        inspect_json(child, f"{path}.{name}")
# =========================
# Middleware
# =========================
class RequestValidationMiddleware(BaseHTTPMiddleware):
    """Reject oversized, duplicated or suspicious input before routing.

    Checks, in order: query string length and parameter count, duplicate
    query keys, query value content, Content-Type, and (for JSON requests)
    body size and body content.
    """

    async def dispatch(self, request: Request, call_next):
        """Validate the request, then forward it or answer with an error.

        Returns the downstream response when every check passes, otherwise a
        JSON error response carrying the status code and detail of the
        failed check.
        """
        # Imported locally so this block does not depend on a change to the
        # module's import section.
        from starlette.responses import JSONResponse

        try:
            await self._validate(request)
        except HTTPException as exc:
            # An HTTPException raised from a BaseHTTPMiddleware is not seen by
            # the application's HTTPException handlers (they run inside the
            # middleware stack), so it would surface as a 500. Build the
            # error response here instead.
            # NOTE(review): the payload mirrors FastAPI's default
            # {"detail": ...} shape - confirm it matches the project's
            # error format.
            return JSONResponse(
                status_code=exc.status_code,
                content={"detail": exc.detail},
                headers=exc.headers,
            )

        return await call_next(request)

    async def _validate(self, request: Request) -> None:
        """Run every check; raise HTTPException on the first failure."""
        # -------------------------
        # 1. Query string limits
        # -------------------------
        if len(request.url.query) > MAX_QUERY_LENGTH:
            raise HTTPException(
                status_code=414,
                detail="Query string too long",
            )

        params = request.query_params.multi_items()
        if len(params) > MAX_QUERY_PARAMS:
            raise HTTPException(
                status_code=400,
                detail="Too many query parameters",
            )

        # -------------------------
        # 2. Duplicate parameters
        # -------------------------
        counter = Counter(key for key, _ in params)
        duplicates = [
            key
            for key, count in counter.items()
            if count > 1 and key not in ALLOWED_MULTI_PARAMS
        ]
        if duplicates:
            raise HTTPException(
                status_code=400,
                detail=f"Duplicate query parameters are not allowed: {duplicates}",
            )

        # -------------------------
        # 3. Query param inspection
        # -------------------------
        for key, value in params:
            if value:
                inspect_value(value, f"query param '{key}'")

        # -------------------------
        # 4. Content-Type sanity
        # -------------------------
        content_type = request.headers.get("content-type", "")
        allowed_types = (
            "application/json",
            "multipart/form-data",
            "application/x-www-form-urlencoded",
        )
        if content_type and not content_type.startswith(allowed_types):
            raise HTTPException(
                status_code=415,
                detail="Unsupported Content-Type",
            )

        # -------------------------
        # 5. JSON body inspection
        # -------------------------
        if content_type.startswith("application/json"):
            body = await request.body()
            if len(body) > MAX_JSON_BODY_SIZE:
                raise HTTPException(
                    status_code=413,
                    detail="JSON body too large",
                )

            if body:
                try:
                    payload = json.loads(body)
                except (ValueError, RecursionError):
                    # ValueError covers json.JSONDecodeError and the
                    # UnicodeDecodeError raised for bytes that are not valid
                    # UTF-8/16/32; RecursionError covers extreme nesting.
                    # All of them are client errors, not server faults.
                    raise HTTPException(
                        status_code=400,
                        detail="Invalid JSON body",
                    ) from None
                inspect_json(payload)

            # Re-inject body for downstream handlers
            # NOTE(review): whether this override is needed (or honoured)
            # depends on the installed Starlette version, and it replays the
            # same message on every call - confirm against the pinned version.
            async def receive():
                return {"type": "http.request", "body": body}

            request._receive = receive  # noqa: protected-access

@ -79,6 +79,8 @@ class DefultBase(BaseModel):
# forbid extra/unexpected fields in input (prevents silent injection/mass assignment)
extra = 'forbid'
populate_by_name = True
# secure JSON serialization: custom formatting for sensitive types
json_encoders = {

Loading…
Cancel
Save