feat: Strengthen security middleware by updating XSS, SQLi, RCE, and path traversal detection patterns, introducing security violation logging, and refining header and query parameter validation.

main
Cizz22 1 week ago
parent ad53324311
commit d9a6687ba7

@ -161,22 +161,38 @@ def handle_exception(request: Request, exc: Exception):
) )
if isinstance(exc, (HTTPException, StarletteHTTPException)): if isinstance(exc, (HTTPException, StarletteHTTPException)):
# Log as warning for 4xx, error for 5xx
status_code = exc.status_code if hasattr(exc, "status_code") else 500
detail = exc.detail if hasattr(exc, "detail") else str(exc)
if 400 <= status_code < 500:
log.warning(
f"HTTP {status_code} occurred | Error ID: {error_id} | Detail: {detail}",
extra={
"error_id": error_id,
"error_category": "http",
"status_code": status_code,
"detail": detail,
"request": request_info,
},
)
else:
log.error( log.error(
f"HTTP exception occurred | Error ID: {error_id}", f"HTTP {status_code} occurred | Error ID: {error_id} | Detail: {detail}",
extra={ extra={
"error_id": error_id, "error_id": error_id,
"error_category": "http", "error_category": "http",
"status_code": exc.status_code, "status_code": status_code,
"detail": exc.detail if hasattr(exc, "detail") else str(exc), "detail": detail,
"request": request_info, "request": request_info,
}, },
) )
return JSONResponse( return JSONResponse(
status_code=exc.status_code, status_code=status_code,
content={ content={
"data": None, "data": None,
"message": str(exc.detail) if hasattr(exc, "detail") else str(exc), "message": str(detail),
"status": ResponseStatus.ERROR, "status": ResponseStatus.ERROR,
"error_id": error_id "error_id": error_id
}, },
@ -184,6 +200,19 @@ def handle_exception(request: Request, exc: Exception):
if isinstance(exc, SQLAlchemyError): if isinstance(exc, SQLAlchemyError):
error_message, status_code = handle_sqlalchemy_error(exc) error_message, status_code = handle_sqlalchemy_error(exc)
# Log client-side (4xx) database errors, e.g. integrity/validation failures, as warnings; everything else as errors
if 400 <= status_code < 500:
log.warning(
f"Database integrity/validation error occurred | Error ID: {error_id}",
extra={
"error_id": error_id,
"error_category": "database",
"error_message": error_message,
"request": request_info,
"exception": str(exc),
},
)
else:
log.error( log.error(
f"Database error occurred | Error ID: {error_id}", f"Database error occurred | Error ID: {error_id}",
extra={ extra={

@ -1,5 +1,6 @@
import json import json
import re import re
import logging
from collections import Counter from collections import Counter
from fastapi import Request, HTTPException from fastapi import Request, HTTPException
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
@ -99,7 +100,9 @@ ALLOWED_HEADERS = {
"x-csrf-token", "x-csrf-token",
"x-xsrf-token", "x-xsrf-token",
"postman-token", "postman-token",
"x-internal-key", "x-forwarded-path",
"x-forwarded-prefix",
"cookie",
} }
MAX_QUERY_PARAMS = 50 MAX_QUERY_PARAMS = 50
@ -107,32 +110,56 @@ MAX_QUERY_LENGTH = 2000
MAX_JSON_BODY_SIZE = 1024 * 500 # 500 KB MAX_JSON_BODY_SIZE = 1024 * 500 # 500 KB
XSS_PATTERN = re.compile( XSS_PATTERN = re.compile(
r"(<script|<iframe|<embed|<object|<svg|<img|<video|<audio|<base|<link|<meta|<form|<button|" r"("
r"javascript:|vbscript:|data:text/html|onerror\s*=|onload\s*=|onmouseover\s*=|onfocus\s*=|" r"<(script|iframe|embed|object|svg|img|video|audio|base|link|meta|form|button|details|animate)\b|"
r"onclick\s*=|onscroll\s*=|ondblclick\s*=|onkeydown\s*=|onkeypress\s*=|onkeyup\s*=|" r"javascript\s*:|vbscript\s*:|data\s*:[^,]*base64[^,]*|data\s*:text/html|"
r"onloadstart\s*=|onpageshow\s*=|onresize\s*=|onunload\s*=|style\s*=\s*['\"].*expression\s*\(|" r"\bon[a-z]+\s*=|" # Catch-all for any 'on' event (onerror, onclick, etc.)
r"eval\s*\(|setTimeout\s*\(|setInterval\s*\(|Function\s*\()", r"style\s*=.*expression\s*\(|" # Old IE specific
r"\b(eval|setTimeout|setInterval|Function)\s*\("
r")",
re.IGNORECASE, re.IGNORECASE,
) )
SQLI_PATTERN = re.compile( SQLI_PATTERN = re.compile(
r"(\bUNION\b|\bSELECT\b|\bINSERT\b|\bUPDATE\b|\bDELETE\b|\bDROP\b|\bALTER\b|\bCREATE\b|\bTRUNCATE\b|" r"("
r"\bEXEC\b|\bEXECUTE\b|\bDECLARE\b|\bWAITFOR\b|\bDELAY\b|\bGROUP\b\s+\bBY\b|\bHAVING\b|\bORDER\b\s+\bBY\b|" # 1. Keywords followed by whitespace and common SQL characters
r"\bINFORMATION_SCHEMA\b|\bSYS\b\.|\bSYSOBJECTS\b|\bPG_SLEEP\b|\bSLEEP\b\(|--|/\*|\*/|#|\bOR\b\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+|" r"\b(UNION|SELECT|INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|TRUNCATE|EXEC(UTE)?|DECLARE)\b\s+[\w\*\(\']|"
r"\bAND\b\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+|"
r"\bXP_CMDSHELL\b|\bLOAD_FILE\b|\bINTO\s+OUTFILE\b)", # 2. Time-based attacks (more specific than just 'SLEEP')
re.IGNORECASE, r"\b(WAITFOR\b\s+DELAY|PG_SLEEP|SLEEP\s*\()|"
# 3. System tables/functions
r"\b(INFORMATION_SCHEMA|SYS\.|SYSOBJECTS|XP_CMDSHELL|LOAD_FILE|INTO\s+OUTFILE)\b|"
# 4. Logical Tautologies (OR 1=1) - Optimized for boundaries
r"\b(OR|AND)\b\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+|"
# 5. Comments
# Match '--' if at start or preceded by whitespace
r"(?<!\S)--|"
# Match block comments, ensuring they aren't part of mime patterns like */*
r"(?<!\*)/\*|(?<!\*)\*/(?!\*)|"
# Match '#' if at start or preceded by whitespace
r"(?<!\S)#|"
# 6. Stacked queries (a ';' followed by a statement keyword)
r";\s*\b(SELECT|DROP|DELETE|UPDATE|INSERT)\b"
r")",
re.IGNORECASE
) )
RCE_PATTERN = re.compile( RCE_PATTERN = re.compile(
r"(\$\(|`.*`|[;&|]\s*(cat|ls|id|whoami|pwd|ifconfig|ip|netstat|nc|netcat|nmap|curl|wget|python|php|perl|ruby|bash|sh|cmd|powershell|pwsh|sc\s+|tasklist|taskkill|base64|sudo|crontab|ssh|ftp|tftp)|" r"("
r"\b(cat|ls|id|whoami|pwd|ifconfig|ip|netstat|nc|netcat|nmap|curl|wget|python|php|perl|ruby|bash|sh|cmd|powershell|pwsh|base64|sudo|crontab)\b|" r"\$\(.*\)|`.*`|" # Command substitution $(...) or `...`
r"/etc/passwd|/etc/shadow|/etc/group|/etc/issue|/proc/self/|/windows/system32/|C:\\Windows\\)", r"[;&|]\s*(cat|ls|id|whoami|pwd|ifconfig|ip|netstat|nc|netcat|nmap|curl|wget|python|php|perl|ruby|bash|sh|cmd|powershell|pwsh|sc\s+|tasklist|taskkill|base64|sudo|crontab|ssh|ftp|tftp)|"
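# Bare command names are flagged only after a shell separator (previous alternative); otherwise only well-known system paths are matched.
# NOTE(review): the leading \b needs a word character immediately before the '/', so a path at the very start of a value (or after a space or '=') may not match — confirm this is intended.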
r"\b(/etc/passwd|/etc/shadow|/etc/group|/etc/issue|/proc/self/|/windows/system32/|C:\\Windows\\)\b"
r")",
re.IGNORECASE, re.IGNORECASE,
) )
TRAVERSAL_PATTERN = re.compile( TRAVERSAL_PATTERN = re.compile(
r"(\.\./|\.\.\\|%2e%2e%2f|%2e%2e/|\.\.%2f|%2e%2e%5c)", r"(\.\.[/\\]|%2e%2e%2f|%2e%2e/|\.\.%2f|%2e%2e%5c|%252e%252e%252f|\\00)",
re.IGNORECASE, re.IGNORECASE,
) )
@ -153,39 +180,70 @@ DYNAMIC_KEYS = {
"program_data" "program_data"
} }
# ========================= log = logging.getLogger("security_logger")
# Helpers
# =========================
def has_control_chars(value: str) -> bool: def has_control_chars(value: str) -> bool:
return any(ord(c) < 32 and c not in ("\n", "\r", "\t") for c in value) return any(ord(c) < 32 and c not in ("\n", "\r", "\t") for c in value)
def inspect_value(value: str, source: str): def inspect_value(value: str, source: str):
if not isinstance(value, str) or value == "*/*":
return
if XSS_PATTERN.search(value): if XSS_PATTERN.search(value):
raise HTTPException(status_code=422, detail=f"Potential XSS payload detected in {source}") log.warning(f"Security violation: Potential XSS payload detected in {source}")
raise HTTPException(
status_code=422,
detail=f"Potential XSS payload detected in {source}",
)
if SQLI_PATTERN.search(value): if SQLI_PATTERN.search(value):
raise HTTPException(status_code=422, detail=f"Potential SQL injection payload detected in {source}") log.warning(f"Security violation: Potential SQL injection payload detected in {source}")
raise HTTPException(
status_code=422,
detail=f"Potential SQL injection payload detected in {source}",
)
if RCE_PATTERN.search(value): if RCE_PATTERN.search(value):
raise HTTPException(status_code=422, detail=f"Potential RCE payload detected in {source}") log.warning(f"Security violation: Potential RCE payload detected in {source}")
raise HTTPException(
status_code=422,
detail=f"Potential RCE payload detected in {source}",
)
if TRAVERSAL_PATTERN.search(value): if TRAVERSAL_PATTERN.search(value):
raise HTTPException(status_code=422, detail=f"Potential Path Traversal payload detected in {source}") log.warning(f"Security violation: Potential Path Traversal payload detected in {source}")
raise HTTPException(
status_code=422,
detail=f"Potential Path Traversal payload detected in {source}",
)
if has_control_chars(value): if has_control_chars(value):
raise HTTPException(status_code=422, detail=f"Invalid control characters detected in {source}") log.warning(f"Security violation: Invalid control characters detected in {source}")
raise HTTPException(
status_code=422,
detail=f"Invalid control characters detected in {source}",
)
def inspect_json(obj, path="body", check_whitelist=True): def inspect_json(obj, path="body", check_whitelist=True):
if isinstance(obj, dict): if isinstance(obj, dict):
for key, value in obj.items(): for key, value in obj.items():
if key in FORBIDDEN_JSON_KEYS: if key in FORBIDDEN_JSON_KEYS:
raise HTTPException(status_code=422, detail=f"Forbidden JSON key detected: {path}.{key}") log.warning(f"Security violation: Forbidden JSON key detected: {path}.{key}")
raise HTTPException(
status_code=422,
detail=f"Forbidden JSON key detected: {path}.{key}",
)
if check_whitelist and key not in ALLOWED_DATA_PARAMS: if check_whitelist and key not in ALLOWED_DATA_PARAMS:
raise HTTPException(status_code=422, detail=f"Unknown JSON key detected: {path}.{key}") log.warning(f"Security violation: Unknown JSON key detected: {path}.{key}")
raise HTTPException(
status_code=422,
detail=f"Unknown JSON key detected: {path}.{key}",
)
# Recurse. If the key is a dynamic container, we stop whitelist checking for children. # Recurse. If the key is a dynamic container, we stop whitelist checking for children.
should_check_subkeys = check_whitelist and (key not in DYNAMIC_KEYS) should_check_subkeys = check_whitelist and (key not in DYNAMIC_KEYS)
@ -215,14 +273,22 @@ class RequestValidationMiddleware(BaseHTTPMiddleware):
ALLOW_DUPLICATE_HEADERS = {'accept', 'accept-encoding', 'accept-language', 'accept-charset', 'cookie'} ALLOW_DUPLICATE_HEADERS = {'accept', 'accept-encoding', 'accept-language', 'accept-charset', 'cookie'}
real_duplicates = [h for h in duplicate_headers if h not in ALLOW_DUPLICATE_HEADERS] real_duplicates = [h for h in duplicate_headers if h not in ALLOW_DUPLICATE_HEADERS]
if real_duplicates: if real_duplicates:
raise HTTPException(status_code=422, detail=f"Duplicate headers are not allowed: {real_duplicates}") log.warning(f"Security violation: Duplicate headers detected: {real_duplicates}")
raise HTTPException(
status_code=422,
detail=f"Duplicate headers are not allowed: {real_duplicates}",
)
# Whitelist headers # Whitelist headers
unknown_headers = [key for key in header_keys if key not in ALLOWED_HEADERS] unknown_headers = [key for key in header_keys if key not in ALLOWED_HEADERS]
if unknown_headers: if unknown_headers:
filtered_unknown = [h for h in unknown_headers if not h.startswith('sec-')] filtered_unknown = [h for h in unknown_headers if not h.startswith('sec-')]
if filtered_unknown: if filtered_unknown:
raise HTTPException(status_code=422, detail=f"Unknown headers detected: {filtered_unknown}") log.warning(f"Security violation: Unknown headers detected: {filtered_unknown}")
raise HTTPException(
status_code=422,
detail=f"Unknown headers detected: {filtered_unknown}",
)
# Inspect header values # Inspect header values
for key, value in request.headers.items(): for key, value in request.headers.items():
@ -233,17 +299,29 @@ class RequestValidationMiddleware(BaseHTTPMiddleware):
# 1. Query string limits # 1. Query string limits
# ------------------------- # -------------------------
if len(request.url.query) > MAX_QUERY_LENGTH: if len(request.url.query) > MAX_QUERY_LENGTH:
raise HTTPException(status_code=422, detail="Query string too long") log.warning(f"Security violation: Query string too long")
raise HTTPException(
status_code=422,
detail="Query string too long",
)
params = request.query_params.multi_items() params = request.query_params.multi_items()
if len(params) > MAX_QUERY_PARAMS: if len(params) > MAX_QUERY_PARAMS:
raise HTTPException(status_code=422, detail="Too many query parameters") log.warning(f"Security violation: Too many query parameters")
raise HTTPException(
status_code=422,
detail="Too many query parameters",
)
# Check for unknown query parameters # Check for unknown query parameters
unknown_params = [key for key, _ in params if key not in ALLOWED_DATA_PARAMS] unknown_params = [key for key, _ in params if key not in ALLOWED_DATA_PARAMS]
if unknown_params: if unknown_params:
raise HTTPException(status_code=422, detail=f"Unknown query parameters detected: {unknown_params}") log.warning(f"Security violation: Unknown query parameters detected: {unknown_params}")
raise HTTPException(
status_code=422,
detail=f"Unknown query parameters detected: {unknown_params}",
)
# ------------------------- # -------------------------
# 2. Duplicate parameters # 2. Duplicate parameters
@ -255,7 +333,11 @@ class RequestValidationMiddleware(BaseHTTPMiddleware):
] ]
if duplicates: if duplicates:
raise HTTPException(status_code=422, detail=f"Duplicate query parameters are not allowed: {duplicates}") log.warning(f"Security violation: Duplicate query parameters detected: {duplicates}")
raise HTTPException(
status_code=422,
detail=f"Duplicate query parameters are not allowed: {duplicates}",
)
# ------------------------- # -------------------------
# 3. Query param inspection & Pagination # 3. Query param inspection & Pagination
@ -269,11 +351,23 @@ class RequestValidationMiddleware(BaseHTTPMiddleware):
try: try:
size_val = int(value) size_val = int(value)
if size_val > 50: if size_val > 50:
raise HTTPException(status_code=422, detail=f"Pagination size '{key}' cannot exceed 50") log.warning(f"Security violation: Pagination size too large ({size_val})")
raise HTTPException(
status_code=422,
detail=f"Pagination size '{key}' cannot exceed 50",
)
if size_val % 5 != 0: if size_val % 5 != 0:
raise HTTPException(status_code=422, detail=f"Pagination size '{key}' must be a multiple of 5") log.warning(f"Security violation: Pagination size not multiple of 5 ({size_val})")
raise HTTPException(
status_code=422,
detail=f"Pagination size '{key}' must be a multiple of 5",
)
except ValueError: except ValueError:
raise HTTPException(status_code=422, detail=f"Pagination size '{key}' must be an integer") log.warning(f"Security violation: Pagination size invalid value ({value})")
raise HTTPException(
status_code=422,
detail=f"Pagination size '{key}' must be an integer",
)
# ------------------------- # -------------------------
# 4. Content-Type sanity # 4. Content-Type sanity
@ -283,6 +377,7 @@ class RequestValidationMiddleware(BaseHTTPMiddleware):
content_type.startswith(t) content_type.startswith(t)
for t in ("application/json", "multipart/form-data", "application/x-www-form-urlencoded") for t in ("application/json", "multipart/form-data", "application/x-www-form-urlencoded")
): ):
log.warning(f"Security violation: Unsupported Content-Type: {content_type}")
raise HTTPException(status_code=422, detail="Unsupported Content-Type") raise HTTPException(status_code=422, detail="Unsupported Content-Type")
# ------------------------- # -------------------------
@ -299,7 +394,11 @@ class RequestValidationMiddleware(BaseHTTPMiddleware):
has_body = True has_body = True
if has_query and has_body: if has_query and has_body:
raise HTTPException(status_code=422, detail="Parameters must be from a single source (query string or JSON body), mixed sources are not allowed") log.warning(f"Security violation: Mixed parameters (query + JSON body)")
raise HTTPException(
status_code=422,
detail="Parameters must be from a single source (query string or JSON body), mixed sources are not allowed",
)
# ------------------------- # -------------------------
# 6. JSON body inspection # 6. JSON body inspection

Loading…
Cancel
Save