You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

322 lines
13 KiB
Python

import json
import re
import logging
from collections import Counter
from fastapi import Request, HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
# =========================
# Configuration
# =========================
# Query parameters that are allowed to occur more than once in a single
# request; every other repeated parameter name is rejected by the middleware.
ALLOWED_MULTI_PARAMS = {"sortBy[]", "descending[]", "exclude[]"}
# Whitelist of parameter names accepted by the API. The middleware checks
# query-string parameter names against this set, and inspect_json() checks
# JSON object keys against it (except below containers listed in
# DYNAMIC_KEYS). Names are case-sensitive.
ALLOWED_DATA_PARAMS = {
    "AerosData", "AhmJobId", "CustomInput", "DurationUnit", "IsDefault", "Konkin_offset",
    "MaintenanceOutages", "MasterData", "OffSet", "OverhaulDuration", "OverhaulInterval",
    "PlannedOutages", "SchematicName", "SecretStr", "SimDuration", "SimNumRun", "SimSeed",
    "SimulationName", "actual_shutdown", "aeros_node", "all_params", "aro_file",
    "aro_file_path", "availability", "baseline_simulation_id", "calc_results",
    "cm_dis_p1", "cm_dis_p2", "cm_dis_p3", "cm_dis_type", "cm_dis_unit_code",
    "cm_waiting_time", "contribution", "contribution_factor", "created_at",
    "criticality", "current_user", "custom_parameters", "data", "datetime",
    "derating_hours", "descending", "design_flowrate", "duration", "duration_above_h",
    "duration_above_hh", "duration_at_empty", "duration_at_full", "duration_below_l",
    "duration_below_ll", "eaf", "eaf_konkin", "effective_loss", "efficiency", "efor",
    "equipment", "equipment_name", "exclude", "failure_rates", "filter_spec", "finish",
    "flowrate_unit", "id", "ideal_production", "ip_dis_p1", "ip_dis_p2", "ip_dis_p3",
    "ip_dis_type", "ip_dis_unit_code", "items", "itemsPerPage", "items_per_page",
    "level", "location_tag", "master_equipment", "max_flow_rate", "max_flowrate",
    "message", "model_image", "mttr", "name", "node_id", "node_name", "node_type",
    "num_cm", "num_events", "num_ip", "num_oh", "num_pm", "offset", "oh_dis_p1",
    "oh_dis_p2", "oh_dis_p3", "oh_dis_type", "oh_dis_unit_code", "page",
    "plan_duration", "planned_outage", "plot_results", "pm_dis_p1", "pm_dis_p2",
    "pm_dis_p3", "pm_dis_type", "pm_dis_unit_code", "point_availabilities",
    "point_flowrates", "production", "production_std", "project_name", "query_str",
    "rel_dis_p1", "rel_dis_p2", "rel_dis_p3", "rel_dis_type", "rel_dis_unit_code",
    "remark", "schematic_id", "schematic_name", "simulation_id", "simulation_name",
    "sof", "sort_by", "sortBy[]", "descending[]", "exclude[]", "start", "started_at",
    "status", "stg_input", "storage_capacity", "structure_name", "t_wait_for_crew",
    "t_wait_for_spare", "target_simulation_id", "timestamp_outs", "total",
    "totalPages", "total_cm_downtime", "total_downtime", "total_ip_downtime",
    "total_oh_downtime", "total_pm_downtime", "total_uptime", "updated_at", "year",
    # Short / generic names (cache busters, search and filter parameters).
    "_", "t", "timestamp", "q", "filter", "currentUser",
}
# Whitelist of request header names (lower-case). The middleware rejects any
# header that is not listed here, except names starting with "sec-".
ALLOWED_HEADERS = {
    # Core request / content negotiation
    "host", "user-agent", "accept", "accept-language", "accept-encoding",
    "connection", "upgrade-insecure-requests",
    # Caching and conditional requests
    "if-modified-since", "if-none-match", "cache-control", "pragma",
    # Credentials and body description
    "authorization", "cookie", "content-type", "content-length",
    # Request origin
    "origin", "referer",
    # Fetch metadata and client hints
    "sec-fetch-dest", "sec-fetch-mode", "sec-fetch-site", "sec-fetch-user",
    "sec-ch-ua", "sec-ch-ua-mobile", "sec-ch-ua-platform",
    # Browser preferences
    "dnt", "priority",
    # Reverse-proxy forwarding
    "x-forwarded-for", "x-forwarded-proto", "x-forwarded-host",
    "x-forwarded-port", "x-forwarded-path", "x-forwarded-prefix", "x-real-ip",
    # Request tracing
    "x-request-id", "x-correlation-id",
    # AJAX / CSRF / tooling
    "x-requested-with", "x-csrf-token", "x-xsrf-token", "postman-token",
}
# Size limits applied to incoming requests.
MAX_QUERY_PARAMS = 50  # maximum number of query parameters per request
MAX_QUERY_LENGTH = 2000  # maximum length of the raw query string
MAX_JSON_BODY_SIZE = 500 * 1024  # 500 KB
# Case-insensitive signature for common cross-site-scripting payloads.
# The alternatives are joined with "|" inside a single capturing group.
XSS_PATTERN = re.compile(
    "("
    + "|".join(
        (
            # Opening tags of elements that can load or run active content
            r"<(script|iframe|embed|object|svg|img|video|audio|base|link|meta|form|button|details|animate)\b",
            # Script-capable URL schemes
            r"javascript\s*:",
            r"vbscript\s*:",
            r"data\s*:[^,]*base64[^,]*",
            r"data\s*:text/html",
            # Any inline "on<event>=" handler (onerror, onclick, ...)
            r"\bon[a-z]+\s*=",
            # CSS expression() inside a style attribute (old IE)
            r"style\s*=.*expression\s*\(",
            # Calls to JavaScript code-evaluation functions
            r"\b(eval|setTimeout|setInterval|Function)\s*\(",
        )
    )
    + ")",
    re.IGNORECASE,
)
# Case-insensitive signature for common SQL-injection payloads.
# The alternatives are joined with "|" inside a single capturing group.
SQLI_PATTERN = re.compile(
    "("
    + "|".join(
        (
            # 1. SQL keyword followed by whitespace and a typical operand start
            r"\b(UNION|SELECT|INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|TRUNCATE|EXEC(UTE)?|DECLARE)\b\s+[\w\*\(\']",
            # 2. Time-based probes (WAITFOR DELAY, PG_SLEEP, SLEEP(...))
            r"\b(WAITFOR\b\s+DELAY|PG_SLEEP|SLEEP\s*\()",
            # 3. System tables / functions
            r"\b(INFORMATION_SCHEMA|SYS\.|SYSOBJECTS|XP_CMDSHELL|LOAD_FILE|INTO\s+OUTFILE)\b",
            # 4. Numeric tautologies such as OR 1=1
            r"\b(OR|AND)\b\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+",
            # 5. Comment markers; the lookarounds keep a bare "*/*" from matching
            r"(?<!\S)--",
            r"(?<!\*)/\*",
            r"(?<!\*)\*/(?!\*)",
            r"(?<!\S)#",
            # 6. Stacked queries: ";" followed by a statement keyword
            r";\s*\b(SELECT|DROP|DELETE|UPDATE|INSERT)\b",
        )
    )
    + ")",
    re.IGNORECASE,
)
# Case-insensitive signature for OS command injection and for references to
# sensitive system files or directories.
RCE_PATTERN = re.compile(
    r"("
    # Command substitution: $(...) or `...`
    r"\$\(.*\)|`.*`|"
    # A shell separator (; & |) followed by a well-known command name
    r"[;&|]\s*(cat|ls|id|whoami|pwd|ifconfig|ip|netstat|nc|netcat|nmap|curl|wget|python|php|perl|ruby|bash|sh|cmd|powershell|pwsh|sc\s+|tasklist|taskkill|base64|sudo|crontab|ssh|ftp|tftp)|"
    # Sensitive files. No leading \b here: "/" is not a word character, so a
    # \b in front of it only matches when a word character precedes the slash,
    # which let a bare "/etc/passwd" (or "cat /etc/passwd") slip through.
    # The trailing \b still prevents matching longer names ("/etc/passwords").
    r"(/etc/passwd|/etc/shadow|/etc/group|/etc/issue)\b|"
    # Sensitive directory prefixes, flagged wherever they appear
    r"/proc/self/|/windows/system32/|C:\\Windows\\"
    r")",
    re.IGNORECASE,
)
# Case-insensitive signature for directory-traversal sequences, in plain,
# URL-encoded and double-URL-encoded form, plus a literal backslash
# followed by "00".
TRAVERSAL_PATTERN = re.compile(
    "("
    + "|".join(
        (
            r"\.\.[/\\]",        # ../ or ..\
            r"%2e%2e%2f",        # fully URL-encoded ../
            r"%2e%2e/",          # encoded dots, plain slash
            r"\.\.%2f",          # plain dots, encoded slash
            r"%2e%2e%5c",        # fully URL-encoded ..\
            r"%252e%252e%252f",  # double-encoded ../
            r"\\00",             # literal "\00"
        )
    )
    + ")",
    re.IGNORECASE,
)
# JSON object keys that are always rejected, at any nesting depth.
FORBIDDEN_JSON_KEYS = {
    "__proto__",
    "constructor",
    "prototype",
}

# Keys whose values are free-form containers: inspect_json() stops applying
# the ALLOWED_DATA_PARAMS whitelist to keys nested below them.
DYNAMIC_KEYS = {
    "data",
    "calc_results",
    "plot_results",
    "point_availabilities",
    "point_flowrates",
    "failure_rates",
    "custom_parameters",
    "parameters",
    "results",
    "all_params",
}

# Module-level logger.
log = logging.getLogger("security_logger")
def has_control_chars(value: str) -> bool:
    """Return True if *value* contains a disallowed ASCII control character.

    Newline, carriage return and tab are tolerated. Every other character
    with a code point below 32, and DEL (0x7F), counts as a control
    character.
    """
    return any(
        (ord(ch) < 32 or ord(ch) == 0x7F) and ch not in ("\n", "\r", "\t")
        for ch in value
    )
def inspect_value(value: str, source: str):
    """Scan one string for attack signatures and raise on the first hit.

    Args:
        value: The text to scan. Non-string values and the exact string
            "*/*" are skipped without any check.
        source: Description of where the value came from; it is embedded in
            the error detail.

    Raises:
        HTTPException: Status 422 when the value matches the XSS, SQL
            injection, RCE or path-traversal pattern (checked in that
            order), or contains disallowed control characters.
    """
    if not isinstance(value, str) or value == "*/*":
        return

    # Checked in order; the first matching signature decides the message.
    signatures = (
        (XSS_PATTERN, "XSS"),
        (SQLI_PATTERN, "SQL injection"),
        (RCE_PATTERN, "RCE"),
        (TRAVERSAL_PATTERN, "Path Traversal"),
    )
    for pattern, label in signatures:
        if pattern.search(value):
            raise HTTPException(
                status_code=422,
                detail=f"Potential {label} payload detected in {source}",
            )

    if has_control_chars(value):
        raise HTTPException(
            status_code=422,
            detail=f"Invalid control characters detected in {source}",
        )
def inspect_json(obj, path="body", check_whitelist=True):
    """Recursively validate a decoded JSON value.

    Strings are passed to inspect_value(); lists are walked element by
    element; dict keys are checked against FORBIDDEN_JSON_KEYS and, while
    *check_whitelist* is true, against ALLOWED_DATA_PARAMS. Other value
    types (numbers, booleans, None) are accepted as they are.

    Args:
        obj: The decoded JSON value to inspect.
        path: Dotted/indexed location of *obj*, used in error details.
        check_whitelist: Whether dict keys at this level must be whitelisted.

    Raises:
        HTTPException: Status 422 for a forbidden key, an unknown key, or
            any error raised by inspect_value().
    """
    if isinstance(obj, str):
        inspect_value(obj, path)
        return

    if isinstance(obj, list):
        for index, element in enumerate(obj):
            inspect_json(element, f"{path}[{index}]", check_whitelist=check_whitelist)
        return

    if not isinstance(obj, dict):
        return

    for key, value in obj.items():
        child_path = f"{path}.{key}"
        if key in FORBIDDEN_JSON_KEYS:
            raise HTTPException(status_code=422, detail=f"Forbidden JSON key detected: {child_path}")
        if check_whitelist and key not in ALLOWED_DATA_PARAMS:
            raise HTTPException(status_code=422, detail=f"Unknown JSON key detected: {child_path}")
        # Below a dynamic container the whitelist no longer applies; once it
        # is switched off it stays off for the whole subtree.
        check_children = check_whitelist and key not in DYNAMIC_KEYS
        inspect_json(value, child_path, check_whitelist=check_children)
# =========================
# Middleware
# =========================
class RequestValidationMiddleware(BaseHTTPMiddleware):
    """Validate headers, query string and JSON body before routing.

    Checks run in this order and the first failure raises ``HTTPException``
    with status 422:

    0. headers: no unexpected duplicates, names whitelisted, values inspected
    1. query string length, parameter count and parameter names
    2. duplicate query parameters (except ``ALLOWED_MULTI_PARAMS``)
    3. query values inspected; pagination sizes validated
    4. Content-Type restricted to JSON, multipart or urlencoded forms
    5. parameters must not come from both the query string and a JSON body
    6. JSON body parsed and inspected with ``inspect_json``

    NOTE(review): these exceptions are raised from inside a
    ``BaseHTTPMiddleware``. Confirm the application turns them into 422
    responses; exceptions raised in middleware may not reach the route-level
    exception handlers.
    """

    # Headers that may legitimately appear more than once in a request.
    _REPEATABLE_HEADERS = frozenset(
        {"accept", "accept-encoding", "accept-language", "accept-charset", "cookie"}
    )
    # Query parameters interpreted as a page size.
    _PAGE_SIZE_PARAMS = frozenset(
        {"size", "itemsPerPage", "per_page", "limit", "items_per_page"}
    )
    # Accepted Content-Type prefixes.
    _SUPPORTED_CONTENT_TYPES = (
        "application/json",
        "multipart/form-data",
        "application/x-www-form-urlencoded",
    )

    async def dispatch(self, request: Request, call_next):
        """Run all validation steps, then hand the request to *call_next*.

        Raises:
            HTTPException: Status 422 as soon as any check fails.
        """
        # Steps 0-3.
        self._validate_headers(request)
        params = self._validate_query(request)

        # Step 4: Content-Type sanity.
        content_type = request.headers.get("content-type", "")
        if content_type and not content_type.startswith(self._SUPPORTED_CONTENT_TYPES):
            raise HTTPException(status_code=422, detail="Unsupported Content-Type")

        if content_type.startswith("application/json"):
            # Step 5: a JSON body (as declared by Content-Length) must not be
            # combined with query parameters.
            has_body = self._declared_body_length(request) > 0
            if params and has_body:
                raise HTTPException(status_code=422, detail="Parameters must be from a single source (query string or JSON body), mixed sources are not allowed")
            # Step 6.
            await self._inspect_json_body(request)

        return await call_next(request)

    def _validate_headers(self, request: Request) -> None:
        """Check header duplicates, header names and header values."""
        header_items = list(request.headers.items())
        names = [name.lower() for name, _ in header_items]

        repeated = [
            name
            for name, seen in Counter(names).items()
            if seen > 1 and name not in self._REPEATABLE_HEADERS
        ]
        if repeated:
            raise HTTPException(status_code=422, detail=f"Duplicate headers are not allowed: {repeated}")

        # Names starting with "sec-" are tolerated even when not whitelisted.
        unknown = [
            name
            for name in names
            if name not in ALLOWED_HEADERS and not name.startswith("sec-")
        ]
        if unknown:
            raise HTTPException(status_code=422, detail=f"Unknown headers detected: {unknown}")

        for name, value in header_items:
            if value:
                inspect_value(value, f"header '{name}'")

    def _validate_query(self, request: Request):
        """Check query-string limits, names, duplicates and values.

        Returns:
            The list of ``(name, value)`` query parameter pairs.
        """
        if len(request.url.query) > MAX_QUERY_LENGTH:
            raise HTTPException(status_code=422, detail="Query string too long")

        params = request.query_params.multi_items()
        if len(params) > MAX_QUERY_PARAMS:
            raise HTTPException(status_code=422, detail="Too many query parameters")

        unknown = [name for name, _ in params if name not in ALLOWED_DATA_PARAMS]
        if unknown:
            raise HTTPException(status_code=422, detail=f"Unknown query parameters detected: {unknown}")

        duplicates = [
            name
            for name, seen in Counter(name for name, _ in params).items()
            if seen > 1 and name not in ALLOWED_MULTI_PARAMS
        ]
        if duplicates:
            raise HTTPException(status_code=422, detail=f"Duplicate query parameters are not allowed: {duplicates}")

        for name, value in params:
            if not value:
                continue  # empty values are neither inspected nor validated
            inspect_value(value, f"query param '{name}'")
            if name in self._PAGE_SIZE_PARAMS:
                self._check_page_size(name, value)
        return params

    @staticmethod
    def _check_page_size(key: str, value: str) -> None:
        """Require an integer page size of at most 50 that is a multiple of 5."""
        try:
            size_val = int(value)
        except ValueError:
            raise HTTPException(status_code=422, detail=f"Pagination size '{key}' must be an integer") from None
        if size_val > 50:
            raise HTTPException(status_code=422, detail=f"Pagination size '{key}' cannot exceed 50")
        if size_val % 5 != 0:
            raise HTTPException(status_code=422, detail=f"Pagination size '{key}' must be a multiple of 5")

    @staticmethod
    def _declared_body_length(request: Request) -> int:
        """Return the Content-Length header as an int (0 when absent/empty).

        A non-numeric value is reported as a 422 instead of escaping as an
        unhandled ``ValueError``.
        """
        raw_length = request.headers.get("content-length")
        if not raw_length:
            return 0
        try:
            return int(raw_length)
        except ValueError:
            raise HTTPException(status_code=422, detail="Invalid Content-Length header") from None

    @staticmethod
    async def _inspect_json_body(request: Request) -> None:
        """Read, parse and inspect the JSON body, then make it readable again."""
        body = await request.body()
        # NOTE(review): MAX_JSON_BODY_SIZE is not enforced here; the size
        # check was commented out in the original code. Confirm whether it
        # should be re-enabled.
        if body:
            try:
                payload = json.loads(body)
            except (json.JSONDecodeError, UnicodeDecodeError, RecursionError):
                # Besides malformed JSON, json.loads raises UnicodeDecodeError
                # for bytes that are not valid text and RecursionError for
                # extremely deep nesting; all are client errors.
                raise HTTPException(status_code=422, detail="Invalid JSON body") from None
            inspect_json(payload)

        # Re-inject the already-consumed body for downstream handlers.
        async def receive():
            return {"type": "http.request", "body": body}

        request._receive = receive