You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
223 lines
7.5 KiB
Python
223 lines
7.5 KiB
Python
import json
import re
from collections import Counter

from fastapi import Request, HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
|
|
|
|
# =========================
# Configuration
# =========================
|
|
|
|
# Query parameters that may legitimately appear more than once in a single
# request; any other repeated key is rejected by the middleware.
ALLOWED_MULTI_PARAMS = {
    "sortBy[]",
    "descending[]",
    "exclude[]",
}

# Upper bounds applied to the incoming request.
MAX_QUERY_PARAMS = 50  # number of key/value pairs in the query string
MAX_QUERY_LENGTH = 2000  # characters in the raw query string
MAX_JSON_BODY_SIZE = 1024 * 100  # 100 KB

# Markup tags, script URL schemes, inline event handlers and JavaScript
# evaluation calls. Matching is case-insensitive and unanchored, so a hit
# anywhere in the inspected value counts.
XSS_PATTERN = re.compile(
    r"(<script|<iframe|<embed|<object|<svg|<img|<video|<audio|<base|<link|<meta|<form|<button|"
    r"javascript:|vbscript:|data:text/html|onerror\s*=|onload\s*=|onmouseover\s*=|onfocus\s*=|"
    r"onclick\s*=|onscroll\s*=|ondblclick\s*=|onkeydown\s*=|onkeypress\s*=|onkeyup\s*=|"
    r"onloadstart\s*=|onpageshow\s*=|onresize\s*=|onunload\s*=|style\s*=\s*['\"].*expression\s*\(|"
    r"eval\s*\(|setTimeout\s*\(|setInterval\s*\(|Function\s*\()",
    re.IGNORECASE,
)

# SQL keywords, comment markers and numeric tautologies ("OR 1=1").
# NOTE(review): the bare `#`, `--` and single keywords such as SELECT/UPDATE
# also match ordinary text (e.g. "#ff0000", "please update"); confirm that
# this false-positive rate is acceptable for the API's inputs.
SQLI_PATTERN = re.compile(
    r"(\bUNION\b|\bSELECT\b|\bINSERT\b|\bUPDATE\b|\bDELETE\b|\bDROP\b|\bALTER\b|\bCREATE\b|\bTRUNCATE\b|"
    r"\bEXEC\b|\bEXECUTE\b|\bDECLARE\b|\bWAITFOR\b|\bDELAY\b|\bGROUP\b\s+\bBY\b|\bHAVING\b|\bORDER\b\s+\bBY\b|"
    r"\bINFORMATION_SCHEMA\b|\bSYS\b\.|\bSYSOBJECTS\b|\bPG_SLEEP\b|\bSLEEP\b\(|--|/\*|\*/|#|\bOR\b\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+|"
    r"\bAND\b\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+|"
    r"\bXP_CMDSHELL\b|\bLOAD_FILE\b|\bINTO\s+OUTFILE\b)",
    re.IGNORECASE,
)

# Shell substitution syntax, command names (both after a `;`/`&`/`|`
# separator and as stand-alone words) and sensitive system paths.
# NOTE(review): the stand-alone word list contains short tokens such as
# `id`, `ip`, `cat` and `sh`, so values like "user id" match; confirm intent.
RCE_PATTERN = re.compile(
    r"(\$\(|`.*`|[;&|]\s*(cat|ls|id|whoami|pwd|ifconfig|ip|netstat|nc|netcat|nmap|curl|wget|python|php|perl|ruby|bash|sh|cmd|powershell|pwsh|sc\s+|tasklist|taskkill|base64|sudo|crontab|ssh|ftp|tftp)|"
    r"\b(cat|ls|id|whoami|pwd|ifconfig|ip|netstat|nc|netcat|nmap|curl|wget|python|php|perl|ruby|bash|sh|cmd|powershell|pwsh|base64|sudo|crontab)\b|"
    r"/etc/passwd|/etc/shadow|/etc/group|/etc/issue|/proc/self/|/windows/system32/|C:\\Windows\\)",
    re.IGNORECASE,
)

# Parent-directory sequences in plain and percent-encoded form.
TRAVERSAL_PATTERN = re.compile(
    r"(\.\./|\.\.\\|%2e%2e%2f|%2e%2e/|\.\.%2f|%2e%2e%5c)",
    re.IGNORECASE,
)

# JSON prototype pollution keys
FORBIDDEN_JSON_KEYS = {"__proto__", "constructor", "prototype"}
|
|
|
|
# =========================
# Helpers
# =========================
|
|
|
|
def has_control_chars(value: str) -> bool:
    """Return True if *value* contains a disallowed ASCII control character.

    Newline, carriage return and tab are tolerated. Every other character
    below U+0020, and DEL (U+007F), counts as a control character.

    Args:
        value: Text to scan.

    Returns:
        True when at least one disallowed control character is present,
        False otherwise (including for the empty string).
    """
    tolerated = ("\n", "\r", "\t")
    return any(
        # DEL sits above the C0 range, so it needs its own comparison.
        (ord(ch) < 32 or ord(ch) == 127) and ch not in tolerated
        for ch in value
    )
|
|
|
|
|
|
def inspect_value(value: str, source: str):
    """Reject *value* if it matches any known attack signature.

    The checks run in a fixed order (XSS, SQL injection, RCE, path
    traversal, then control characters) and the first one that matches
    determines the error that is raised.

    Args:
        value: The string to inspect.
        source: Human-readable origin of the value; it is embedded in the
            error detail.

    Raises:
        HTTPException: With status 400 when any check matches.
    """
    signature_checks = (
        (XSS_PATTERN, "XSS"),
        (SQLI_PATTERN, "SQL injection"),
        (RCE_PATTERN, "RCE"),
        (TRAVERSAL_PATTERN, "Path Traversal"),
    )
    for pattern, label in signature_checks:
        if pattern.search(value) is not None:
            raise HTTPException(
                status_code=400,
                detail=f"Potential {label} payload detected in {source}",
            )

    # Control characters are checked last, after all signature patterns.
    if not has_control_chars(value):
        return
    raise HTTPException(
        status_code=400,
        detail=f"Invalid control characters detected in {source}",
    )
|
|
|
|
|
|
def _iter_json_children(obj, path):
    """Yield ``(child, child_path)`` for each direct child of a dict or list.

    Dict keys are checked against FORBIDDEN_JSON_KEYS lazily, just before the
    corresponding value is yielded, so a forbidden key is reported only after
    all earlier siblings have been inspected in full.
    """
    if isinstance(obj, dict):
        for key, value in obj.items():
            if key in FORBIDDEN_JSON_KEYS:
                raise HTTPException(
                    status_code=400,
                    detail=f"Forbidden JSON key detected: {path}.{key}",
                )
            yield value, f"{path}.{key}"
    elif isinstance(obj, list):
        for index, item in enumerate(obj):
            yield item, f"{path}[{index}]"


def inspect_json(obj, path="body"):
    """Walk a decoded JSON document and validate every key and string value.

    Dicts and lists are traversed depth-first in document order. String
    leaves are passed to ``inspect_value``; other scalars (numbers, booleans,
    ``None``) are ignored. The traversal uses an explicit stack instead of
    recursion so that very deeply nested payloads cannot exhaust the
    interpreter's recursion limit.

    Args:
        obj: The decoded JSON value (dict, list, str or other scalar).
        path: Location label for *obj*, used in error details.

    Raises:
        HTTPException: With status 400 for a forbidden key, or whatever
            ``inspect_value`` raises for a rejected string.
    """
    # Each stack entry is an iterator over the pending children of one
    # container; the top of the stack is the container being walked.
    pending = [iter(((obj, path),))]
    while pending:
        try:
            node, node_path = next(pending[-1])
        except StopIteration:
            pending.pop()
            continue

        if isinstance(node, str):
            inspect_value(node, node_path)
        elif isinstance(node, (dict, list)):
            pending.append(_iter_json_children(node, node_path))
|
|
|
|
|
|
# =========================
# Middleware
# =========================
|
|
|
|
class RequestValidationMiddleware(BaseHTTPMiddleware):
    """Validate the query string and JSON body of every incoming request.

    The checks run in this order: query-string limits, duplicate parameter
    detection, per-parameter payload inspection (with pagination-size rules),
    Content-Type allow-listing and, for JSON requests, inspection of the
    decoded body. The first failing check ends the request with a JSON error
    response of the form ``{"detail": ...}``.
    """

    # Media types accepted in the Content-Type header (prefix match, so
    # parameters such as "; charset=utf-8" are allowed).
    _ALLOWED_CONTENT_TYPES = (
        "application/json",
        "multipart/form-data",
        "application/x-www-form-urlencoded",
    )

    # Query parameters interpreted as a page size.
    _PAGINATION_SIZE_KEYS = frozenset(
        {"size", "itemsPerPage", "per_page", "limit", "items_per_page"}
    )

    # Pagination constraint: multiples of 5, max 50.
    _MAX_PAGE_SIZE = 50
    _PAGE_SIZE_STEP = 5

    async def dispatch(self, request: Request, call_next):
        """Run all validations, then hand the request to the next handler.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request downstream.

        Returns:
            The downstream response, or a JSON error response when a
            validation fails.
        """
        try:
            self._validate_query(request)
            content_type = self._validate_content_type(request)
            if content_type.startswith("application/json"):
                await self._validate_json_body(request)
        except HTTPException as exc:
            # Exceptions raised inside middleware do not pass through the
            # application's HTTPException handlers, so an uncaught one would
            # surface as a 500. Convert it to the intended response here.
            return JSONResponse(
                status_code=exc.status_code,
                content={"detail": exc.detail},
                headers=getattr(exc, "headers", None),
            )

        # Deliberately outside the try block: errors raised by downstream
        # handlers are not this middleware's to translate.
        return await call_next(request)

    def _validate_query(self, request: Request) -> None:
        """Enforce query-string limits and inspect every query parameter."""
        # 1. Query string limits
        if len(request.url.query) > MAX_QUERY_LENGTH:
            raise HTTPException(
                status_code=414,
                detail="Query string too long",
            )

        params = request.query_params.multi_items()
        if len(params) > MAX_QUERY_PARAMS:
            raise HTTPException(
                status_code=400,
                detail="Too many query parameters",
            )

        # 2. Duplicate parameters
        occurrences = Counter(key for key, _ in params)
        duplicates = [
            key
            for key, count in occurrences.items()
            if count > 1 and key not in ALLOWED_MULTI_PARAMS
        ]
        if duplicates:
            raise HTTPException(
                status_code=400,
                detail=f"Duplicate query parameters are not allowed: {duplicates}",
            )

        # 3. Query param inspection (empty values are skipped)
        for key, value in params:
            if not value:
                continue
            inspect_value(value, f"query param '{key}'")
            if key in self._PAGINATION_SIZE_KEYS:
                self._validate_page_size(key, value)

    def _validate_page_size(self, key: str, value: str) -> None:
        """Require a page size to be an integer multiple of 5, at most 50."""
        try:
            size_val = int(value)
        except ValueError:
            raise HTTPException(
                status_code=400,
                detail=f"Pagination size '{key}' must be an integer",
            ) from None

        if size_val > self._MAX_PAGE_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"Pagination size '{key}' cannot exceed {self._MAX_PAGE_SIZE}",
            )
        if size_val % self._PAGE_SIZE_STEP != 0:
            raise HTTPException(
                status_code=400,
                detail=(
                    f"Pagination size '{key}' must be a multiple of "
                    f"{self._PAGE_SIZE_STEP}"
                ),
            )

    def _validate_content_type(self, request: Request) -> str:
        """Reject unsupported media types; return the normalized header value.

        A missing or empty Content-Type header is accepted and yields "".
        """
        # 4. Content-Type sanity. Media type names are case-insensitive, so
        # compare on a lower-cased copy (the header itself is left untouched).
        content_type = request.headers.get("content-type", "").strip().lower()
        if content_type and not content_type.startswith(self._ALLOWED_CONTENT_TYPES):
            raise HTTPException(
                status_code=415,
                detail="Unsupported Content-Type",
            )
        return content_type

    async def _validate_json_body(self, request: Request) -> None:
        """Decode and inspect a JSON body, then make it readable again."""
        # 5. JSON body inspection
        body = await request.body()

        # NOTE(review): MAX_JSON_BODY_SIZE is not enforced here; the size
        # check was commented out in the original code. Confirm whether a
        # 413 "JSON body too large" response should be restored.

        if body:
            try:
                payload = json.loads(body)
            except RecursionError:
                raise HTTPException(
                    status_code=400,
                    detail="JSON body is nested too deeply",
                ) from None
            except ValueError:
                # Covers json.JSONDecodeError and the UnicodeDecodeError that
                # json.loads raises for bytes that are not valid UTF-8/16/32.
                raise HTTPException(
                    status_code=400,
                    detail="Invalid JSON body",
                ) from None

            try:
                inspect_json(payload)
            except RecursionError:
                raise HTTPException(
                    status_code=400,
                    detail="JSON body is nested too deeply",
                ) from None

        # Re-inject body for downstream handlers
        async def receive():
            return {"type": "http.request", "body": body}

        request._receive = receive  # noqa: protected-access
|