From da2ce29c459e212748d7a250983ca3cb32068626 Mon Sep 17 00:00:00 2001 From: CIzz22 Date: Mon, 12 Jan 2026 06:34:54 +0000 Subject: [PATCH] add middleware --- src/middleware.py | 170 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 170 insertions(+) create mode 100644 src/middleware.py diff --git a/src/middleware.py b/src/middleware.py new file mode 100644 index 0000000..64e38b9 --- /dev/null +++ b/src/middleware.py @@ -0,0 +1,170 @@ +import json +import re +from collections import Counter +from fastapi import Request, HTTPException +from starlette.middleware.base import BaseHTTPMiddleware + +# ========================= +# Configuration +# ========================= + +ALLOWED_MULTI_PARAMS = { + "sortBy[]", + "descending[]", + "exclude[]", +} + +MAX_QUERY_PARAMS = 50 +MAX_QUERY_LENGTH = 2000 +MAX_JSON_BODY_SIZE = 1024 * 100 # 100 KB + +# Very targeted patterns. Avoid catastrophic regex nonsense. +XSS_PATTERN = re.compile( + r"( bool: + return any(ord(c) < 32 and c not in ("\n", "\r", "\t") for c in value) + + +def inspect_value(value: str, source: str): + if XSS_PATTERN.search(value): + raise HTTPException( + status_code=400, + detail=f"Potential XSS payload detected in {source}", + ) + + if SQLI_PATTERN.search(value): + raise HTTPException( + status_code=400, + detail=f"Potential SQL injection payload detected in {source}", + ) + + if has_control_chars(value): + raise HTTPException( + status_code=400, + detail=f"Invalid control characters detected in {source}", + ) + + +def inspect_json(obj, path="body"): + if isinstance(obj, dict): + for key, value in obj.items(): + if key in FORBIDDEN_JSON_KEYS: + raise HTTPException( + status_code=400, + detail=f"Forbidden JSON key detected: {path}.{key}", + ) + inspect_json(value, f"{path}.{key}") + elif isinstance(obj, list): + for i, item in enumerate(obj): + inspect_json(item, f"{path}[{i}]") + elif isinstance(obj, str): + inspect_value(obj, path) + + +# ========================= +# Middleware +# ========================= + +class RequestValidationMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + # ------------------------- + # 1. Query string limits + # ------------------------- + if len(request.url.query) > MAX_QUERY_LENGTH: + raise HTTPException( + status_code=414, + detail="Query string too long", + ) + + params = request.query_params.multi_items() + + if len(params) > MAX_QUERY_PARAMS: + raise HTTPException( + status_code=400, + detail="Too many query parameters", + ) + + # ------------------------- + # 2. Duplicate parameters + # ------------------------- + counter = Counter(key for key, _ in params) + duplicates = [ + key for key, count in counter.items() + if count > 1 and key not in ALLOWED_MULTI_PARAMS + ] + + if duplicates: + raise HTTPException( + status_code=400, + detail=f"Duplicate query parameters are not allowed: {duplicates}", + ) + + # ------------------------- + # 3. Query param inspection + # ------------------------- + for key, value in params: + if value: + inspect_value(value, f"query param '{key}'") + + # ------------------------- + # 4. Content-Type sanity + # ------------------------- + content_type = request.headers.get("content-type", "") + if content_type and not any( + content_type.startswith(t) + for t in ( + "application/json", + "multipart/form-data", + "application/x-www-form-urlencoded", + ) + ): + raise HTTPException( + status_code=415, + detail="Unsupported Content-Type", + ) + + # ------------------------- + # 5. JSON body inspection + # ------------------------- + if content_type.startswith("application/json"): + body = await request.body() + + if len(body) > MAX_JSON_BODY_SIZE: + raise HTTPException( + status_code=413, + detail="JSON body too large", + ) + + if body: + try: + payload = json.loads(body) + except json.JSONDecodeError: + raise HTTPException( + status_code=400, + detail="Invalid JSON body", + ) + + inspect_json(payload) + + # Re-inject body for downstream handlers + async def receive(): + return {"type": "http.request", "body": body} + + request._receive = receive # noqa: protected-access + + return await call_next(request)