"""Request validation middleware for a FastAPI/Starlette application.

Rejects requests whose query string, Content-Type, or JSON body trips one of
the checks defined below, before the request reaches any route handler.
"""

import json
import re
from collections import Counter

from fastapi import HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

# =========================
# Configuration
# =========================

# Query parameters that are allowed to appear more than once.
ALLOWED_MULTI_PARAMS = {
    "sortBy[]",
    "descending[]",
    "exclude[]",
}

MAX_QUERY_PARAMS = 50
MAX_QUERY_LENGTH = 2000
MAX_JSON_BODY_SIZE = 1024 * 100  # 100 KB

# Media-type prefixes accepted when a Content-Type header is present.
_ALLOWED_CONTENT_TYPE_PREFIXES = (
    "application/json",
    "multipart/form-data",
    "application/x-www-form-urlencoded",
)

# Very targeted patterns. Avoid catastrophic regex nonsense.
#
# NOTE(review): the literal contents of XSS_PATTERN, SQLI_PATTERN and
# FORBIDDEN_JSON_KEYS were lost in the copy of this file that was reviewed
# (the text between the opening of XSS_PATTERN and the body of
# has_control_chars was stripped). The three values below are conservative
# placeholders -- restore the original definitions from version control
# before merging.
XSS_PATTERN = re.compile(
    r"(<script\b|javascript:|\bon(?:error|load)\s*=)",
    re.IGNORECASE,
)

SQLI_PATTERN = re.compile(
    r"(\bunion\s+(?:all\s+)?select\b|\bdrop\s+table\b|\bor\s+1\s*=\s*1\b)",
    re.IGNORECASE,
)

FORBIDDEN_JSON_KEYS = {
    "__proto__",
    "constructor",
    "prototype",
}


# =========================
# Helpers
# =========================

def has_control_chars(value: str) -> bool:
    """Return True if *value* contains a control character.

    A control character here is any code point below 32 other than
    newline, carriage return and tab.
    """
    return any(ord(c) < 32 and c not in ("\n", "\r", "\t") for c in value)


def inspect_value(value: str, source: str):
    """Check a single string against the XSS, SQLi and control-char rules.

    Args:
        value: The string to inspect.
        source: Human-readable origin of the value; embedded in the error
            detail.

    Raises:
        HTTPException: 400 on the first rule that matches. Rules are tried
            in the order XSS, SQL injection, control characters.
    """
    if XSS_PATTERN.search(value):
        raise HTTPException(
            status_code=400,
            detail=f"Potential XSS payload detected in {source}",
        )

    if SQLI_PATTERN.search(value):
        raise HTTPException(
            status_code=400,
            detail=f"Potential SQL injection payload detected in {source}",
        )

    if has_control_chars(value):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid control characters detected in {source}",
        )


def inspect_json(obj, path="body"):
    """Recursively validate a decoded JSON document.

    Dict keys are checked against FORBIDDEN_JSON_KEYS and every string value
    is passed to inspect_value. Other scalar types are ignored.

    Args:
        obj: The decoded JSON value (dict, list, str or another scalar).
        path: Dotted/indexed location of *obj* inside the document; used in
            error details.

    Raises:
        HTTPException: 400 for a forbidden key or a string value rejected
            by inspect_value.
    """
    if isinstance(obj, dict):
        for key, value in obj.items():
            if key in FORBIDDEN_JSON_KEYS:
                raise HTTPException(
                    status_code=400,
                    detail=f"Forbidden JSON key detected: {path}.{key}",
                )
            inspect_json(value, f"{path}.{key}")

    elif isinstance(obj, list):
        for i, item in enumerate(obj):
            inspect_json(item, f"{path}[{i}]")

    elif isinstance(obj, str):
        inspect_value(obj, path)


def _validate_query(request: Request) -> None:
    """Enforce query-string size, count, duplicate and content rules."""
    # 1. Query string limits
    if len(request.url.query) > MAX_QUERY_LENGTH:
        raise HTTPException(
            status_code=414,
            detail="Query string too long",
        )

    params = request.query_params.multi_items()

    if len(params) > MAX_QUERY_PARAMS:
        raise HTTPException(
            status_code=400,
            detail="Too many query parameters",
        )

    # 2. Duplicate parameters
    counter = Counter(key for key, _ in params)
    duplicates = [
        key
        for key, count in counter.items()
        if count > 1 and key not in ALLOWED_MULTI_PARAMS
    ]
    if duplicates:
        raise HTTPException(
            status_code=400,
            detail=f"Duplicate query parameters are not allowed: {duplicates}",
        )

    # 3. Query param inspection (empty values are skipped)
    for key, value in params:
        if value:
            inspect_value(value, f"query param '{key}'")


def _validated_content_type(request: Request) -> str:
    """Return the normalised Content-Type header, or raise 415 if unsupported.

    The header is stripped and lower-cased before comparison so that a
    differently-cased media type (e.g. ``Application/JSON``) is accepted.
    A missing header yields an empty string and is allowed.
    """
    content_type = request.headers.get("content-type", "").strip().lower()
    if content_type and not content_type.startswith(
        _ALLOWED_CONTENT_TYPE_PREFIXES
    ):
        raise HTTPException(
            status_code=415,
            detail="Unsupported Content-Type",
        )
    return content_type


def _check_declared_length(request: Request) -> None:
    """Reject an oversized JSON body from Content-Length, before reading it."""
    try:
        declared_length = int(request.headers.get("content-length", ""))
    except ValueError:
        # Header missing or not a number: fall back to the post-read check.
        return
    if declared_length > MAX_JSON_BODY_SIZE:
        raise HTTPException(
            status_code=413,
            detail="JSON body too large",
        )


def _inspect_json_body(body: bytes) -> None:
    """Decode *body* as JSON and run inspect_json on the result."""
    try:
        payload = json.loads(body)
    except (ValueError, RecursionError) as exc:
        # ValueError covers json.JSONDecodeError and the UnicodeDecodeError
        # raised for bytes that are not valid UTF-8/16/32; RecursionError is
        # raised for excessively nested documents.
        raise HTTPException(
            status_code=400,
            detail="Invalid JSON body",
        ) from exc

    try:
        inspect_json(payload)
    except RecursionError as exc:
        raise HTTPException(
            status_code=400,
            detail="JSON body is nested too deeply",
        ) from exc


# =========================
# Middleware
# =========================

class RequestValidationMiddleware(BaseHTTPMiddleware):
    """Validate query parameters, Content-Type and JSON bodies.

    A request that fails validation is answered directly with a JSON error
    response of the form ``{"detail": ...}``; it never reaches ``call_next``.
    """

    async def dispatch(self, request: Request, call_next):
        """Validate *request*, then forward it to the rest of the app.

        Args:
            request: The incoming request.
            call_next: Callable that passes the request downstream and
                returns its response.

        Returns:
            The downstream response, or a JSON error response when
            validation fails.
        """
        try:
            await self._validate(request)
        except HTTPException as exc:
            # An HTTPException raised from middleware is outside the app's
            # exception handlers and would surface as a 500, so it is
            # converted into the intended error response here.
            return JSONResponse(
                status_code=exc.status_code,
                content={"detail": exc.detail},
                headers=getattr(exc, "headers", None),
            )

        return await call_next(request)

    async def _validate(self, request: Request) -> None:
        """Run every check; raise HTTPException on the first failure."""
        _validate_query(request)

        # 4. Content-Type sanity
        content_type = _validated_content_type(request)

        # 5. JSON body inspection
        if not content_type.startswith("application/json"):
            return

        _check_declared_length(request)

        body = await request.body()

        # Content-Length may be absent or wrong, so check the real size too.
        if len(body) > MAX_JSON_BODY_SIZE:
            raise HTTPException(
                status_code=413,
                detail="JSON body too large",
            )

        if body:
            _inspect_json_body(body)

        # Re-inject body for downstream handlers.
        # NOTE(review): newer Starlette releases appear to replay a body
        # that was read inside BaseHTTPMiddleware on their own -- confirm
        # whether this override is still needed for the pinned version.
        async def receive():
            return {"type": "http.request", "body": body}

        request._receive = receive  # noqa: protected-access