"""Request-validation middleware for FastAPI.

Validates every incoming request before it reaches a route handler:

* header name whitelisting, duplicate-header detection, header-value screening
* query-string length / count limits, parameter whitelisting, duplicate checks
* pagination size constraints
* Content-Type sanity and single-source (query XOR JSON body) enforcement
* JSON body size limit, key whitelisting, and recursive value screening

All violations are rejected with ``HTTPException(status_code=422)``.
"""

import json
import logging
import re
from collections import Counter

from fastapi import HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware

logger = logging.getLogger(__name__)

# =========================
# Configuration
# =========================

# Query parameters that may legitimately be repeated (array-style params).
ALLOWED_MULTI_PARAMS = {
    "sortBy[]",
    "descending[]",
    "exclude[]",
}

# Whitelist of every key accepted in the query string or a JSON body.
ALLOWED_DATA_PARAMS = {
    "AerosData", "AhmJobId", "CustomInput", "DurationUnit", "IsDefault",
    "Konkin_offset", "MaintenanceOutages", "MasterData", "OffSet",
    "OverhaulDuration", "OverhaulInterval", "PlannedOutages", "SchematicName",
    "SecretStr", "SimDuration", "SimNumRun", "SimSeed", "SimulationName",
    "actual_shutdown", "aeros_node", "all_params", "aro_file", "aro_file_path",
    "availability", "baseline_simulation_id", "calc_results", "cm_dis_p1",
    "cm_dis_p2", "cm_dis_p3", "cm_dis_type", "cm_dis_unit_code",
    "cm_waiting_time", "contribution", "contribution_factor", "created_at",
    "criticality", "current_user", "custom_parameters", "data", "datetime",
    "derating_hours", "descending", "design_flowrate", "duration",
    "duration_above_h", "duration_above_hh", "duration_at_empty",
    "duration_at_full", "duration_below_l", "duration_below_ll", "eaf",
    "eaf_konkin", "effective_loss", "efficiency", "efor", "equipment",
    "equipment_name", "exclude", "failure_rates", "filter_spec", "finish",
    "flowrate_unit", "id", "ideal_production", "ip_dis_p1", "ip_dis_p2",
    "ip_dis_p3", "ip_dis_type", "ip_dis_unit_code", "items", "itemsPerPage",
    "items_per_page", "level", "location_tag", "master_equipment",
    "max_flow_rate", "max_flowrate", "message", "model_image", "mttr", "name",
    "node_id", "node_name", "node_type", "num_cm", "num_events", "num_ip",
    "num_oh", "num_pm", "offset", "oh_dis_p1", "oh_dis_p2", "oh_dis_p3",
    "oh_dis_type", "oh_dis_unit_code", "page", "plan_duration",
    "planned_outage", "plot_results", "pm_dis_p1", "pm_dis_p2", "pm_dis_p3",
    "pm_dis_type", "pm_dis_unit_code", "point_availabilities",
    "point_flowrates", "production", "production_std", "project_name",
    "query_str", "rel_dis_p1", "rel_dis_p2", "rel_dis_p3", "rel_dis_type",
    "rel_dis_unit_code", "remark", "schematic_id", "schematic_name",
    "simulation_id", "simulation_name", "sof", "sort_by", "sortBy[]",
    "descending[]", "exclude[]", "start", "started_at", "status", "stg_input",
    "storage_capacity", "structure_name", "t_wait_for_crew",
    "t_wait_for_spare", "target_simulation_id", "timestamp_outs", "total",
    "totalPages", "total_cm_downtime", "total_downtime", "total_ip_downtime",
    "total_oh_downtime", "total_pm_downtime", "total_uptime", "updated_at",
    "year", "_", "t", "timestamp", "q", "filter", "currentUser",
}

# Request header names the API accepts (compared lower-cased).
ALLOWED_HEADERS = {
    "host", "user-agent", "accept", "accept-language", "accept-encoding",
    "connection", "upgrade-insecure-requests", "if-modified-since",
    "if-none-match", "cache-control", "authorization", "content-type",
    "content-length", "origin", "referer", "sec-fetch-dest", "sec-fetch-mode",
    "sec-fetch-site", "sec-fetch-user", "sec-ch-ua", "sec-ch-ua-mobile",
    "sec-ch-ua-platform", "pragma", "dnt", "priority", "x-forwarded-for",
    "x-forwarded-proto", "x-forwarded-host", "x-forwarded-port", "x-real-ip",
    "x-request-id", "x-correlation-id", "x-requested-with", "x-csrf-token",
    "x-xsrf-token", "postman-token", "x-forwarded-path", "x-forwarded-prefix",
    "cookie",
}

MAX_QUERY_PARAMS = 50
MAX_QUERY_LENGTH = 2000
MAX_JSON_BODY_SIZE = 1024 * 500  # 500 KB

XSS_PATTERN = re.compile(
    r"("
    r"<(script|iframe|embed|object|svg|img|video|audio|base|link|meta|form|button|details|animate)\b|"
    r"javascript\s*:|vbscript\s*:|data\s*:[^,]*base64[^,]*|data\s*:text/html|"
    r"\bon[a-z]+\s*=|"  # Catch-all for any 'on' event (onerror, onclick, etc.)
    r"style\s*=.*expression\s*\(|"  # Old IE specific
    r"\b(eval|setTimeout|setInterval|Function)\s*\("
    r")",
    re.IGNORECASE,
)

SQLI_PATTERN = re.compile(
    r"("
    # 1. Keywords followed by whitespace and common SQL characters
    r"\b(UNION|SELECT|INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|TRUNCATE|EXEC(UTE)?|DECLARE)\b\s+[\w\*\(\']|"
    # 2. Time-based attacks (more specific than just 'SLEEP')
    r"\b(WAITFOR\b\s+DELAY|PG_SLEEP|SLEEP\s*\()|"
    # 3. System tables/functions
    r"\b(INFORMATION_SCHEMA|SYS\.|SYSOBJECTS|XP_CMDSHELL|LOAD_FILE|INTO\s+OUTFILE)\b|"
    # 4. Logical Tautologies (OR 1=1) - Optimized for boundaries
    r"\b(OR|AND)\b\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+|"
    # 5. Comments — TODO(review): this alternative is a reconstruction; the
    # original sub-pattern (it began with a lookbehind '(?<') was unreadable
    # in the reviewed copy. Confirm against the original source.
    r"(?<![\w-])--(?!-)|/\*|\*/"
    r")",
    re.IGNORECASE,
)

# NOTE(review): RCE_PATTERN, TRAVERSAL_PATTERN, FORBIDDEN_JSON_KEYS and
# DYNAMIC_KEYS were corrupted in the reviewed copy. The definitions below are
# conservative reconstructions that keep the module importable and the checks
# meaningful — confirm every one of them against the original file.
RCE_PATTERN = re.compile(
    r"("
    r"[;|&`]\s*(cat|ls|id|whoami|pwd|wget|curl|nc|netcat|bash|sh|cmd|powershell)\b|"
    r"\$\([^)]*\)|`[^`]+`|"
    r"\b(os\.system|subprocess\.|popen|__import__)\s*\("
    r")",
    re.IGNORECASE,
)

TRAVERSAL_PATTERN = re.compile(
    r"(\.\./|\.\.\\|%2e%2e[%/\\]|\.\.%2f|\.\.%5c|/etc/passwd|c:\\windows)",
    re.IGNORECASE,
)

# JSON keys that are never acceptable (prototype-pollution / dunder abuse).
FORBIDDEN_JSON_KEYS = {"__proto__", "constructor", "prototype", "__class__", "__globals__"}

# Keys whose children hold user-defined names and therefore skip the
# key whitelist during recursive inspection.
DYNAMIC_KEYS = {"custom_parameters", "all_params", "data"}

# =========================
# Inspection helpers
# =========================


def has_control_chars(value: str) -> bool:
    """Return True if *value* contains an ASCII control character other than
    newline, carriage return, or tab."""
    return any(ord(c) < 32 and c not in ("\n", "\r", "\t") for c in value)


def inspect_value(value: str, source: str) -> None:
    """Raise HTTPException(422) if *value* matches any attack pattern.

    The literal ``*/*`` is exempted: it is the standard Accept wildcard and
    would otherwise trip the SQL comment detector.
    """
    if not isinstance(value, str) or value == "*/*":
        return
    if XSS_PATTERN.search(value):
        raise HTTPException(status_code=422, detail=f"Potential XSS payload detected in {source}")
    if SQLI_PATTERN.search(value):
        raise HTTPException(status_code=422, detail=f"Potential SQL injection payload detected in {source}")
    if RCE_PATTERN.search(value):
        raise HTTPException(status_code=422, detail=f"Potential RCE payload detected in {source}")
    if TRAVERSAL_PATTERN.search(value):
        raise HTTPException(status_code=422, detail=f"Potential Path Traversal payload detected in {source}")
    if has_control_chars(value):
        raise HTTPException(status_code=422, detail=f"Invalid control characters detected in {source}")


def inspect_json(obj, path="body", check_whitelist=True):
    """Recursively validate a decoded JSON payload.

    Dict keys must not be in FORBIDDEN_JSON_KEYS and, while *check_whitelist*
    is True, must appear in ALLOWED_DATA_PARAMS. String leaves are screened
    with :func:`inspect_value`. *path* is threaded through purely for error
    messages (e.g. ``body.items[0].name``).
    """
    if isinstance(obj, dict):
        for key, value in obj.items():
            if key in FORBIDDEN_JSON_KEYS:
                raise HTTPException(status_code=422, detail=f"Forbidden JSON key detected: {path}.{key}")
            if check_whitelist and key not in ALLOWED_DATA_PARAMS:
                raise HTTPException(status_code=422, detail=f"Unknown JSON key detected: {path}.{key}")
            # Recurse. If the key is a dynamic container, we stop whitelist
            # checking for children.
            should_check_subkeys = check_whitelist and (key not in DYNAMIC_KEYS)
            inspect_json(value, f"{path}.{key}", check_whitelist=should_check_subkeys)
    elif isinstance(obj, list):
        for i, item in enumerate(obj):
            inspect_json(item, f"{path}[{i}]", check_whitelist=check_whitelist)
    elif isinstance(obj, str):
        inspect_value(obj, path)


# =========================
# Middleware
# =========================


class RequestValidationMiddleware(BaseHTTPMiddleware):
    """Rejects malformed or suspicious requests with HTTPException(422).

    NOTE(review): on some Starlette versions exceptions raised inside a
    ``BaseHTTPMiddleware.dispatch`` bypass FastAPI's exception handlers and
    surface as HTTP 500 — verify the deployed version converts the
    HTTPException into a proper 422 response.
    """

    # Headers that HTTP legitimately allows to repeat.
    ALLOW_DUPLICATE_HEADERS = {
        "accept", "accept-encoding", "accept-language", "accept-charset", "cookie",
    }

    # Query params interpreted as a page size. NOTE(review): "size",
    # "per_page" and "limit" are not in ALLOWED_DATA_PARAMS, so requests
    # using them are rejected as unknown before this check can ever apply —
    # confirm whether they should be whitelisted.
    PAGINATION_SIZE_KEYS = {"size", "itemsPerPage", "per_page", "limit", "items_per_page"}

    async def dispatch(self, request: Request, call_next):
        """Run all validations in order, then forward the request."""
        # 0. Header validation
        self._check_headers(request)
        # 1-3. Query-string limits, whitelisting, duplicates, inspection
        params = self._check_query(request)
        # 4. Content-Type sanity
        content_type = self._check_content_type(request)
        # 5-6. Single-source rule + JSON body inspection
        await self._check_json_body(request, content_type, has_query=bool(params))
        return await call_next(request)

    def _check_headers(self, request: Request) -> None:
        """Reject duplicate, unknown, or payload-carrying headers."""
        header_keys = [key.lower() for key, _ in request.headers.items()]

        counts = Counter(header_keys)
        real_duplicates = [
            key for key, count in counts.items()
            if count > 1 and key not in self.ALLOW_DUPLICATE_HEADERS
        ]
        if real_duplicates:
            raise HTTPException(status_code=422, detail=f"Duplicate headers are not allowed: {real_duplicates}")

        # Whitelist header names; any 'sec-*' browser header is tolerated.
        filtered_unknown = [
            key for key in header_keys
            if key not in ALLOWED_HEADERS and not key.startswith("sec-")
        ]
        if filtered_unknown:
            raise HTTPException(status_code=422, detail=f"Unknown headers detected: {filtered_unknown}")

        # Screen header values for injection payloads.
        for key, value in request.headers.items():
            if value:
                inspect_value(value, f"header '{key}'")

    def _check_query(self, request: Request):
        """Validate the query string; return its (key, value) multi-items."""
        if len(request.url.query) > MAX_QUERY_LENGTH:
            raise HTTPException(status_code=422, detail="Query string too long")

        params = request.query_params.multi_items()
        if len(params) > MAX_QUERY_PARAMS:
            raise HTTPException(status_code=422, detail="Too many query parameters")

        unknown_params = [key for key, _ in params if key not in ALLOWED_DATA_PARAMS]
        if unknown_params:
            raise HTTPException(status_code=422, detail=f"Unknown query parameters detected: {unknown_params}")

        # Repeats are only allowed for declared array-style parameters.
        counter = Counter(key for key, _ in params)
        duplicates = [
            key for key, count in counter.items()
            if count > 1 and key not in ALLOWED_MULTI_PARAMS
        ]
        if duplicates:
            raise HTTPException(status_code=422, detail=f"Duplicate query parameters are not allowed: {duplicates}")

        for key, value in params:
            if value:
                inspect_value(value, f"query param '{key}'")
                if key in self.PAGINATION_SIZE_KEYS:
                    self._check_pagination_size(key, value)

        return params

    @staticmethod
    def _check_pagination_size(key: str, value: str) -> None:
        """Enforce page-size bounds: integer, <= 50, multiple of 5."""
        try:
            size_val = int(value)
        except ValueError:
            # FIX: suppress the ValueError context so clients see a clean 422.
            raise HTTPException(status_code=422, detail=f"Pagination size '{key}' must be an integer") from None
        if size_val > 50:
            raise HTTPException(status_code=422, detail=f"Pagination size '{key}' cannot exceed 50")
        if size_val % 5 != 0:
            raise HTTPException(status_code=422, detail=f"Pagination size '{key}' must be a multiple of 5")

    @staticmethod
    def _check_content_type(request: Request) -> str:
        """Reject unsupported Content-Types; return the raw header value."""
        content_type = request.headers.get("content-type", "")
        supported = ("application/json", "multipart/form-data", "application/x-www-form-urlencoded")
        if content_type and not content_type.startswith(supported):
            raise HTTPException(status_code=422, detail="Unsupported Content-Type")
        return content_type

    @staticmethod
    async def _check_json_body(request: Request, content_type: str, has_query: bool) -> None:
        """Enforce the single-source rule, size limit, and JSON inspection.

        Consumes the body and re-injects it so downstream handlers can
        still read it.
        """
        if not content_type.startswith("application/json"):
            return

        declared_size = 0
        content_length = request.headers.get("content-length")
        if content_length:
            try:
                declared_size = int(content_length)
            except ValueError:
                # FIX: a malformed Content-Length used to raise an unhandled
                # ValueError (HTTP 500); reject it explicitly instead.
                raise HTTPException(status_code=422, detail="Invalid Content-Length header") from None

        # Parameters must come from exactly one source: query string OR body.
        if has_query and declared_size > 0:
            raise HTTPException(status_code=422, detail="Parameters must be from a single source (query string or JSON body), mixed sources are not allowed")

        body = await request.body()
        # FIX: this limit existed as a constant but its enforcement was
        # commented out, letting arbitrarily large payloads reach json.loads.
        if len(body) > MAX_JSON_BODY_SIZE:
            raise HTTPException(status_code=422, detail="JSON body too large")

        if body:
            try:
                payload = json.loads(body)
            except json.JSONDecodeError:
                raise HTTPException(status_code=422, detail="Invalid JSON body") from None
            inspect_json(payload)

        # Re-inject body for downstream handlers
        async def receive():
            return {"type": "http.request", "body": body}

        request._receive = receive