add middleware

main
Cizz22 9 hours ago
parent 3c4c534930
commit c27cef35eb

@ -27,6 +27,7 @@ from src.database.core import async_session, engine, async_collector_session
from src.enums import ResponseStatus
from src.exceptions import handle_exception
from src.logging import configure_logging
from src.middleware import RequestValidationMiddleware
from src.rate_limiter import limiter
log = logging.getLogger(__name__)
@ -120,6 +121,10 @@ def security_headers_middleware(app: FastAPI):
security_headers_middleware(app)
app.add_middleware(RequestValidationMiddleware)
@app.middleware("http")
async def db_session_middleware(request: Request, call_next):
request_id = str(uuid1())

@ -0,0 +1,170 @@
import json
import re
from collections import Counter
from fastapi import Request, HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
# =========================
# Configuration
# =========================
ALLOWED_MULTI_PARAMS = {
"sortBy[]",
"descending[]",
"exclude[]",
}
MAX_QUERY_PARAMS = 50
MAX_QUERY_LENGTH = 2000
MAX_JSON_BODY_SIZE = 1024 * 100 # 100 KB
# Very targeted patterns. Avoid catastrophic regex nonsense.
XSS_PATTERN = re.compile(
r"(<script|</script|javascript:|onerror\s*=|onload\s*=|<svg|<img)",
re.IGNORECASE,
)
SQLI_PATTERN = re.compile(
r"(\bUNION\b|\bSELECT\b|\bINSERT\b|\bDELETE\b|\bDROP\b|--|\bOR\b\s+1=1)",
re.IGNORECASE,
)
# JSON prototype pollution keys
FORBIDDEN_JSON_KEYS = {"__proto__", "constructor", "prototype"}
# =========================
# Helpers
# =========================
def has_control_chars(value: str) -> bool:
return any(ord(c) < 32 and c not in ("\n", "\r", "\t") for c in value)
def inspect_value(value: str, source: str):
if XSS_PATTERN.search(value):
raise HTTPException(
status_code=400,
detail=f"Potential XSS payload detected in {source}",
)
if SQLI_PATTERN.search(value):
raise HTTPException(
status_code=400,
detail=f"Potential SQL injection payload detected in {source}",
)
if has_control_chars(value):
raise HTTPException(
status_code=400,
detail=f"Invalid control characters detected in {source}",
)
def inspect_json(obj, path="body"):
if isinstance(obj, dict):
for key, value in obj.items():
if key in FORBIDDEN_JSON_KEYS:
raise HTTPException(
status_code=400,
detail=f"Forbidden JSON key detected: {path}.{key}",
)
inspect_json(value, f"{path}.{key}")
elif isinstance(obj, list):
for i, item in enumerate(obj):
inspect_json(item, f"{path}[{i}]")
elif isinstance(obj, str):
inspect_value(obj, path)
# =========================
# Middleware
# =========================
class RequestValidationMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# -------------------------
# 1. Query string limits
# -------------------------
if len(request.url.query) > MAX_QUERY_LENGTH:
raise HTTPException(
status_code=414,
detail="Query string too long",
)
params = request.query_params.multi_items()
if len(params) > MAX_QUERY_PARAMS:
raise HTTPException(
status_code=400,
detail="Too many query parameters",
)
# -------------------------
# 2. Duplicate parameters
# -------------------------
counter = Counter(key for key, _ in params)
duplicates = [
key for key, count in counter.items()
if count > 1 and key not in ALLOWED_MULTI_PARAMS
]
if duplicates:
raise HTTPException(
status_code=400,
detail=f"Duplicate query parameters are not allowed: {duplicates}",
)
# -------------------------
# 3. Query param inspection
# -------------------------
for key, value in params:
if value:
inspect_value(value, f"query param '{key}'")
# -------------------------
# 4. Content-Type sanity
# -------------------------
content_type = request.headers.get("content-type", "")
if content_type and not any(
content_type.startswith(t)
for t in (
"application/json",
"multipart/form-data",
"application/x-www-form-urlencoded",
)
):
raise HTTPException(
status_code=415,
detail="Unsupported Content-Type",
)
# -------------------------
# 5. JSON body inspection
# -------------------------
if content_type.startswith("application/json"):
body = await request.body()
if len(body) > MAX_JSON_BODY_SIZE:
raise HTTPException(
status_code=413,
detail="JSON body too large",
)
if body:
try:
payload = json.loads(body)
except json.JSONDecodeError:
raise HTTPException(
status_code=400,
detail="Invalid JSON body",
)
inspect_json(payload)
# Re-inject body for downstream handlers
async def receive():
return {"type": "http.request", "body": body}
request._receive = receive # noqa: protected-access
return await call_next(request)
Loading…
Cancel
Save