From 64d1fcf4dd966526746acdad9702105e83ba2840 Mon Sep 17 00:00:00 2001 From: Cizz22 Date: Mon, 2 Mar 2026 14:40:49 +0700 Subject: [PATCH] feat: Implement comprehensive request validation by adding header and data parameter whitelisting, enhancing JSON body inspection, and standardizing validation error codes to 422. --- src/exceptions.py | 16 ++-- src/middleware.py | 239 ++++++++++++++++++++++++++++++---------------- 2 files changed, 165 insertions(+), 90 deletions(-) diff --git a/src/exceptions.py b/src/exceptions.py index 25b345e..4d99326 100644 --- a/src/exceptions.py +++ b/src/exceptions.py @@ -75,22 +75,22 @@ def handle_sqlalchemy_error(error: SQLAlchemyError): if isinstance(error, IntegrityError): if "unique constraint" in str(error).lower(): - return "This record already exists.", 409 + return "This record already exists.", 422 elif "foreign key constraint" in str(error).lower(): - return "Related record not found.", 400 + return "Related record not found.", 422 else: - return "Data integrity error.", 400 + return "Data integrity error.", 422 elif isinstance(error, DataError) or isinstance(original_error, AsyncPGDataError): - return "Invalid data provided.", 400 + return "Invalid data provided.", 422 elif isinstance(error, DBAPIError): if "unique constraint" in str(error).lower(): - return "This record already exists.", 409 + return "This record already exists.", 422 elif "foreign key constraint" in str(error).lower(): - return "Related record not found.", 400 + return "Related record not found.", 422 elif "null value in column" in str(error).lower(): - return "Required data missing.", 400 + return "Required data missing.", 422 elif "invalid input for query argument" in str(error).lower(): - return "Invalid data provided.", 400 + return "Invalid data provided.", 422 else: return "Database error.", 500 else: diff --git a/src/middleware.py b/src/middleware.py index 242d049..a308b64 100644 --- a/src/middleware.py +++ b/src/middleware.py @@ -14,15 +14,87 @@ ALLOWED_MULTI_PARAMS = { "exclude[]", } +ALLOWED_DATA_PARAMS = { + "AerosData", "AhmJobId", "CustomInput", "DurationUnit", "IsDefault", "Konkin_offset", + "MaintenanceOutages", "MasterData", "OffSet", "OverhaulDuration", "OverhaulInterval", + "PlannedOutages", "SchematicName", "SecretStr", "SimDuration", "SimNumRun", "SimSeed", + "SimulationName", "actual_shutdown", "aeros_node", "all_params", "aro_file", + "aro_file_path", "availability", "baseline_simulation_id", "calc_results", + "cm_dis_p1", "cm_dis_p2", "cm_dis_p3", "cm_dis_type", "cm_dis_unit_code", + "cm_waiting_time", "contribution", "contribution_factor", "created_at", + "criticality", "current_user", "custom_parameters", "data", "datetime", + "derating_hours", "descending", "design_flowrate", "duration", "duration_above_h", + "duration_above_hh", "duration_at_empty", "duration_at_full", "duration_below_l", + "duration_below_ll", "eaf", "eaf_konkin", "effective_loss", "efficiency", "efor", + "equipment", "equipment_name", "exclude", "failure_rates", "filter_spec", "finish", + "flowrate_unit", "id", "ideal_production", "ip_dis_p1", "ip_dis_p2", "ip_dis_p3", + "ip_dis_type", "ip_dis_unit_code", "items", "itemsPerPage", "items_per_page", + "level", "location_tag", "master_equipment", "max_flow_rate", "max_flowrate", + "message", "model_image", "mttr", "name", "node_id", "node_name", "node_type", + "num_cm", "num_events", "num_ip", "num_oh", "num_pm", "offset", "oh_dis_p1", + "oh_dis_p2", "oh_dis_p3", "oh_dis_type", "oh_dis_unit_code", "page", + "plan_duration", "planned_outage", "plot_results", "pm_dis_p1", "pm_dis_p2", + "pm_dis_p3", "pm_dis_type", "pm_dis_unit_code", "point_availabilities", + "point_flowrates", "production", "production_std", "project_name", "query_str", + "rel_dis_p1", "rel_dis_p2", "rel_dis_p3", "rel_dis_type", "rel_dis_unit_code", + "remark", "schematic_id", "schematic_name", "simulation_id", "simulation_name", + "sof", "sort_by", "sortBy[]", "descending[]", "exclude[]", "start", "started_at", + "status", "stg_input", "storage_capacity", "structure_name", "t_wait_for_crew", + "t_wait_for_spare", "target_simulation_id", "timestamp_outs", "total", + "totalPages", "total_cm_downtime", "total_downtime", "total_ip_downtime", + "total_oh_downtime", "total_pm_downtime", "total_uptime", "updated_at", "year", + "_", "t", "timestamp", "q", "filter", "currentUser" +} + +ALLOWED_HEADERS = { + "host", + "user-agent", + "accept", + "accept-language", + "accept-encoding", + "connection", + "upgrade-insecure-requests", + "if-modified-since", + "if-none-match", + "cache-control", + "authorization", + "content-type", + "content-length", + "origin", + "referer", + "sec-fetch-dest", + "sec-fetch-mode", + "sec-fetch-site", + "sec-fetch-user", + "sec-ch-ua", + "sec-ch-ua-mobile", + "sec-ch-ua-platform", + "pragma", + "dnt", + "priority", + "x-forwarded-for", + "x-forwarded-proto", + "x-forwarded-host", + "x-forwarded-port", + "x-real-ip", + "x-request-id", + "x-correlation-id", + "x-requested-with", + "x-csrf-token", + "x-xsrf-token", + "postman-token", + "x-internal-key", +} + MAX_QUERY_PARAMS = 50 MAX_QUERY_LENGTH = 2000 -MAX_JSON_BODY_SIZE = 1024 * 100 # 100 KB +MAX_JSON_BODY_SIZE = 1024 * 500 # 500 KB XSS_PATTERN = re.compile( r"( bool: def inspect_value(value: str, source: str): if XSS_PATTERN.search(value): - raise HTTPException( - status_code=400, - detail=f"Potential XSS payload detected in {source}", - ) + raise HTTPException(status_code=422, detail=f"Potential XSS payload detected in {source}") if SQLI_PATTERN.search(value): - raise HTTPException( - status_code=400, - detail=f"Potential SQL injection payload detected in {source}", - ) - + raise HTTPException(status_code=422, detail=f"Potential SQL injection payload detected in {source}") + if RCE_PATTERN.search(value): - raise HTTPException( - status_code=400, - detail=f"Potential RCE payload detected in {source}", - ) - + raise HTTPException(status_code=422, detail=f"Potential RCE payload detected in {source}") + if TRAVERSAL_PATTERN.search(value): - raise HTTPException( - status_code=400, - detail=f"Potential traversal payload detected in {source}", - ) + raise HTTPException(status_code=422, detail=f"Potential Path Traversal payload detected in {source}") if has_control_chars(value): - raise HTTPException( - status_code=400, - detail=f"Invalid control characters detected in {source}", - ) + raise HTTPException(status_code=422, detail=f"Invalid control characters detected in {source}") -def inspect_json(obj, path="body"): +def inspect_json(obj, path="body", check_whitelist=True): if isinstance(obj, dict): for key, value in obj.items(): if key in FORBIDDEN_JSON_KEYS: - raise HTTPException( - status_code=400, - detail=f"Forbidden JSON key detected: {path}.{key}", - ) - inspect_json(value, f"{path}.{key}") + raise HTTPException(status_code=422, detail=f"Forbidden JSON key detected: {path}.{key}") + + if check_whitelist and key not in ALLOWED_DATA_PARAMS: + raise HTTPException(status_code=422, detail=f"Unknown JSON key detected: {path}.{key}") + + # Recurse. If the key is a dynamic container, we stop whitelist checking for children. + should_check_subkeys = check_whitelist and (key not in DYNAMIC_KEYS) + inspect_json(value, f"{path}.{key}", check_whitelist=should_check_subkeys) elif isinstance(obj, list): for i, item in enumerate(obj): - inspect_json(item, f"{path}[{i}]") + inspect_json(item, f"{path}[{i}]", check_whitelist=check_whitelist) elif isinstance(obj, str): inspect_value(obj, path) @@ -113,22 +177,47 @@ def inspect_json(obj, path="body"): class RequestValidationMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): + # ------------------------- + # 0. Header validation + # ------------------------- + header_keys = [key.lower() for key, _ in request.headers.items()] + + # Check for duplicate headers + header_counter = Counter(header_keys) + duplicate_headers = [key for key, count in header_counter.items() if count > 1] + + ALLOW_DUPLICATE_HEADERS = {'accept', 'accept-encoding', 'accept-language', 'accept-charset', 'cookie'} + real_duplicates = [h for h in duplicate_headers if h not in ALLOW_DUPLICATE_HEADERS] + if real_duplicates: + raise HTTPException(status_code=422, detail=f"Duplicate headers are not allowed: {real_duplicates}") + + # Whitelist headers + unknown_headers = [key for key in header_keys if key not in ALLOWED_HEADERS] + if unknown_headers: + filtered_unknown = [h for h in unknown_headers if not h.startswith('sec-')] + if filtered_unknown: + raise HTTPException(status_code=422, detail=f"Unknown headers detected: {filtered_unknown}") + + # Inspect header values + for key, value in request.headers.items(): + if value: + inspect_value(value, f"header '{key}'") + # ------------------------- # 1. Query string limits # ------------------------- if len(request.url.query) > MAX_QUERY_LENGTH: - raise HTTPException( - status_code=414, - detail="Query string too long", - ) + raise HTTPException(status_code=422, detail="Query string too long") params = request.query_params.multi_items() if len(params) > MAX_QUERY_PARAMS: - raise HTTPException( - status_code=400, - detail="Too many query parameters", - ) + raise HTTPException(status_code=422, detail="Too many query parameters") + + # Check for unknown query parameters + unknown_params = [key for key, _ in params if key not in ALLOWED_DATA_PARAMS] + if unknown_params: + raise HTTPException(status_code=422, detail=f"Unknown query parameters detected: {unknown_params}") # ------------------------- # 2. Duplicate parameters @@ -140,38 +229,25 @@ class RequestValidationMiddleware(BaseHTTPMiddleware): ] if duplicates: - raise HTTPException( - status_code=400, - detail=f"Duplicate query parameters are not allowed: {duplicates}", - ) + raise HTTPException(status_code=422, detail=f"Duplicate query parameters are not allowed: {duplicates}") # ------------------------- - # 3. Query param inspection + # 3. Query param inspection & Pagination # ------------------------- pagination_size_keys = {"size", "itemsPerPage", "per_page", "limit", "items_per_page"} for key, value in params: if value: inspect_value(value, f"query param '{key}'") - - # Pagination constraint: multiples of 5, max 50 + if key in pagination_size_keys and value: try: size_val = int(value) if size_val > 50: - raise HTTPException( - status_code=400, - detail=f"Pagination size '{key}' cannot exceed 50", - ) + raise HTTPException(status_code=422, detail=f"Pagination size '{key}' cannot exceed 50") if size_val % 5 != 0: - raise HTTPException( - status_code=400, - detail=f"Pagination size '{key}' must be a multiple of 5", - ) + raise HTTPException(status_code=422, detail=f"Pagination size '{key}' must be a multiple of 5") except ValueError: - raise HTTPException( - status_code=400, - detail=f"Pagination size '{key}' must be an integer", - ) + raise HTTPException(status_code=422, detail=f"Pagination size '{key}' must be an integer") # ------------------------- # 4. Content-Type sanity @@ -179,44 +255,43 @@ class RequestValidationMiddleware(BaseHTTPMiddleware): content_type = request.headers.get("content-type", "") if content_type and not any( content_type.startswith(t) - for t in ( - "application/json", - "multipart/form-data", - "application/x-www-form-urlencoded", - ) + for t in ("application/json", "multipart/form-data", "application/x-www-form-urlencoded") ): - raise HTTPException( - status_code=415, - detail="Unsupported Content-Type", - ) + raise HTTPException(status_code=422, detail="Unsupported Content-Type") # ------------------------- - # 5. JSON body inspection + # 5. Single source check (Query vs JSON Body) # ------------------------- + has_query = len(params) > 0 + has_body = False + if content_type.startswith("application/json"): - body = await request.body() + content_length = request.headers.get("content-length") + if content_length and int(content_length) > 0: + has_body = True + + if has_query and has_body: + raise HTTPException(status_code=422, detail="Parameters must be from a single source (query string or JSON body), mixed sources are not allowed") - #if len(body) > MAX_JSON_BODY_SIZE: - # raise HTTPException( - # status_code=413, - # detail="JSON body too large", - # ) + # ------------------------- + # 6. JSON body inspection + # ------------------------- + if content_type.startswith("application/json"): + body = await request.body() + # if len(body) > MAX_JSON_BODY_SIZE: + # raise HTTPException(status_code=422, detail="JSON body too large") if body: try: payload = json.loads(body) except json.JSONDecodeError: - raise HTTPException( - status_code=400, - detail="Invalid JSON body", - ) + raise HTTPException(status_code=422, detail="Invalid JSON body") inspect_json(payload) # Re-inject body for downstream handlers async def receive(): return {"type": "http.request", "body": body} - - request._receive = receive # noqa: protected-access + request._receive = receive return await call_next(request)