add validation
parent
3f940b9a4f
commit
0de7bc9bbe
@ -0,0 +1,22 @@
|
||||
from typing import Optional, List
|
||||
|
||||
from pydantic import Field
|
||||
from src.models import DefultBase
|
||||
|
||||
|
||||
class CommonParams(DefultBase):
|
||||
# This ensures no extra query params are allowed
|
||||
current_user: Optional[str] = Field(None, alias="currentUser")
|
||||
page: int = Field(1, gt=0, lt=2147483647)
|
||||
items_per_page: int = Field(5, gt=-2, lt=2147483647)
|
||||
query_str: Optional[str] = Field(None, alias="q")
|
||||
filter_spec: Optional[str] = Field(None, alias="filter")
|
||||
sort_by: List[str] = Field(default_factory=list, alias="sortBy[]")
|
||||
descending: List[bool] = Field(default_factory=list, alias="descending[]")
|
||||
exclude: List[str] = Field(default_factory=list, alias="exclude[]")
|
||||
all_params: int = Field(0, alias="all")
|
||||
|
||||
# Property to mirror your original return dict's bool conversion
|
||||
@property
|
||||
def is_all(self) -> bool:
|
||||
return bool(self.all_params)
|
||||
@ -0,0 +1,170 @@
|
||||
import json
|
||||
import re
|
||||
from collections import Counter
|
||||
from fastapi import Request, HTTPException
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
# =========================
|
||||
# Configuration
|
||||
# =========================
|
||||
|
||||
ALLOWED_MULTI_PARAMS = {
|
||||
"sortBy[]",
|
||||
"descending[]",
|
||||
"exclude[]",
|
||||
}
|
||||
|
||||
MAX_QUERY_PARAMS = 50
|
||||
MAX_QUERY_LENGTH = 2000
|
||||
MAX_JSON_BODY_SIZE = 1024 * 100 # 100 KB
|
||||
|
||||
# Very targeted patterns. Avoid catastrophic regex nonsense.
|
||||
XSS_PATTERN = re.compile(
|
||||
r"(<script|</script|javascript:|onerror\s*=|onload\s*=|<svg|<img)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
SQLI_PATTERN = re.compile(
|
||||
r"(\bUNION\b|\bSELECT\b|\bINSERT\b|\bDELETE\b|\bDROP\b|--|\bOR\b\s+1=1)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# JSON prototype pollution keys
|
||||
FORBIDDEN_JSON_KEYS = {"__proto__", "constructor", "prototype"}
|
||||
|
||||
# =========================
|
||||
# Helpers
|
||||
# =========================
|
||||
|
||||
def has_control_chars(value: str) -> bool:
|
||||
return any(ord(c) < 32 and c not in ("\n", "\r", "\t") for c in value)
|
||||
|
||||
|
||||
def inspect_value(value: str, source: str):
|
||||
if XSS_PATTERN.search(value):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Potential XSS payload detected in {source}",
|
||||
)
|
||||
|
||||
if SQLI_PATTERN.search(value):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Potential SQL injection payload detected in {source}",
|
||||
)
|
||||
|
||||
if has_control_chars(value):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid control characters detected in {source}",
|
||||
)
|
||||
|
||||
|
||||
def inspect_json(obj, path="body"):
|
||||
if isinstance(obj, dict):
|
||||
for key, value in obj.items():
|
||||
if key in FORBIDDEN_JSON_KEYS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Forbidden JSON key detected: {path}.{key}",
|
||||
)
|
||||
inspect_json(value, f"{path}.{key}")
|
||||
elif isinstance(obj, list):
|
||||
for i, item in enumerate(obj):
|
||||
inspect_json(item, f"{path}[{i}]")
|
||||
elif isinstance(obj, str):
|
||||
inspect_value(obj, path)
|
||||
|
||||
|
||||
# =========================
|
||||
# Middleware
|
||||
# =========================
|
||||
|
||||
class RequestValidationMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# -------------------------
|
||||
# 1. Query string limits
|
||||
# -------------------------
|
||||
if len(request.url.query) > MAX_QUERY_LENGTH:
|
||||
raise HTTPException(
|
||||
status_code=414,
|
||||
detail="Query string too long",
|
||||
)
|
||||
|
||||
params = request.query_params.multi_items()
|
||||
|
||||
if len(params) > MAX_QUERY_PARAMS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Too many query parameters",
|
||||
)
|
||||
|
||||
# -------------------------
|
||||
# 2. Duplicate parameters
|
||||
# -------------------------
|
||||
counter = Counter(key for key, _ in params)
|
||||
duplicates = [
|
||||
key for key, count in counter.items()
|
||||
if count > 1 and key not in ALLOWED_MULTI_PARAMS
|
||||
]
|
||||
|
||||
if duplicates:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Duplicate query parameters are not allowed: {duplicates}",
|
||||
)
|
||||
|
||||
# -------------------------
|
||||
# 3. Query param inspection
|
||||
# -------------------------
|
||||
for key, value in params:
|
||||
if value:
|
||||
inspect_value(value, f"query param '{key}'")
|
||||
|
||||
# -------------------------
|
||||
# 4. Content-Type sanity
|
||||
# -------------------------
|
||||
content_type = request.headers.get("content-type", "")
|
||||
if content_type and not any(
|
||||
content_type.startswith(t)
|
||||
for t in (
|
||||
"application/json",
|
||||
"multipart/form-data",
|
||||
"application/x-www-form-urlencoded",
|
||||
)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=415,
|
||||
detail="Unsupported Content-Type",
|
||||
)
|
||||
|
||||
# -------------------------
|
||||
# 5. JSON body inspection
|
||||
# -------------------------
|
||||
if content_type.startswith("application/json"):
|
||||
body = await request.body()
|
||||
|
||||
if len(body) > MAX_JSON_BODY_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail="JSON body too large",
|
||||
)
|
||||
|
||||
if body:
|
||||
try:
|
||||
payload = json.loads(body)
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid JSON body",
|
||||
)
|
||||
|
||||
inspect_json(payload)
|
||||
|
||||
# Re-inject body for downstream handlers
|
||||
async def receive():
|
||||
return {"type": "http.request", "body": body}
|
||||
|
||||
request._receive = receive # noqa: protected-access
|
||||
|
||||
return await call_next(request)
|
||||
Loading…
Reference in New Issue