Normalise client-error responses to HTTP 422 and tighten request validation.

src/exceptions.py: integrity/data errors that returned 400 or 409 now return 422.
src/middleware.py: adds header validation (duplicates, whitelist, payload
inspection), query/JSON key whitelists, a single-source check (query string
vs JSON body), and raises the JSON body size constant to 500 KB.

Review fixes folded into the added lines (line count per hunk unchanged):
- ALLOWED_DATA_PARAMS: duplicate entries replaced by "plant_ids", "size",
  "per_page" and "limit", which the multi-param and pagination checks expect
  but the whitelist rejected.
- Header whitelist: headers listed in ALLOW_DUPLICATE_HEADERS (e.g. "cookie")
  are no longer reported as unknown.
- Content-Length is only parsed when it is a decimal string, so a malformed
  value cannot raise ValueError.

NOTE: whitespace-only context lines were lost in extraction and are restored
on a best-effort basis; the hunk boundary after the XSS_PATTERN context was
also lost and that span is reproduced as found.

diff --git a/src/exceptions.py b/src/exceptions.py
index 852e8ea..23b5003 100644
--- a/src/exceptions.py
+++ b/src/exceptions.py
@@ -78,22 +78,22 @@ def handle_sqlalchemy_error(error: SQLAlchemyError):
 
     if isinstance(error, IntegrityError):
         if "unique constraint" in str(error).lower():
-            return "This record already exists.", 409
+            return "This record already exists.", 422
         elif "foreign key constraint" in str(error).lower():
-            return "Related record not found.", 400
+            return "Related record not found.", 422
         else:
-            return "Data integrity error.", 400
+            return "Data integrity error.", 422
     elif isinstance(error, DataError) or isinstance(original_error, AsyncPGDataError):
-        return "Invalid data provided.", 400
+        return "Invalid data provided.", 422
     elif isinstance(error, DBAPIError):
         if "unique constraint" in str(error).lower():
-            return "This record already exists.", 409
+            return "This record already exists.", 422
         elif "foreign key constraint" in str(error).lower():
-            return "Related record not found.", 400
+            return "Related record not found.", 422
         elif "null value in column" in str(error).lower():
-            return "Required data missing.", 400
+            return "Required data missing.", 422
         elif "invalid input for query argument" in str(error).lower():
-            return "Invalid data provided.", 400
+            return "Invalid data provided.", 422
         else:
             return "Database error.", 500
     else:
diff --git a/src/middleware.py b/src/middleware.py
index 0ca720b..312696e 100644
--- a/src/middleware.py
+++ b/src/middleware.py
@@ -12,17 +12,105 @@ ALLOWED_MULTI_PARAMS = {
     "sortBy[]",
     "descending[]",
     "exclude[]",
+    "assetnums",
+    "plant_ids",
+    "job_ids",
+}
+
+ALLOWED_DATA_PARAMS = {
+    "actual_shutdown", "all_params", "analysis_metadata", "asset_contributions",
+    "assetnum", "assetnums", "assigned_date", "availability", "availableScopes",
+    "avg_cost", "birbaum", "calculation_type", "capacity_weight", "code",
+    "contribution", "corrective_cost", "corrective_costs", "cost", "costPerFailure",
+    "cost_savings_vs_planned", "cost_threshold", "cost_trend", "created_at",
+    "crew_number", "criticalParts", "critical_procurement_items", "current_eaf",
+    "current_plant_eaf", "current_stock", "current_user", "cut_hours",
+    "daily_failures", "data", "datetime", "day", "days", "descending",
+    "description", "down_time", "duration", "duration_oh", "eaf_gap",
+    "eaf_improvement_text", "eaf_input", "efficiency", "end_date",
+    "equipment_name", "equipment_results", "equipment_with_sparepart_constraints",
+    "exclude", "excluded_equipment", "expected_delivery_date", "filter_spec",
+    "finish", "fleet_statistics", "id", "improvement_impact", "included_equipment",
+    "included_in_optimization", "intervalDays", "is_included", "itemnum",
+    "items", "itemsPerPage", "items_per_page", "job", "job_ids",
+    "last_overhaul_date", "lead_time", "location", "location_tag", "location_tags",
+    "maintenance_type", "master_equipment", "material_cost", "max_interval",
+    "max_interval_months", "message", "month", "months_from_planned", "name",
+    "next_planned_overhaul", "node", "num_failures", "num_of_failures",
+    "ohSessionId", "oh_scope", "oh_session_id", "oh_type", "oh_types",
+    "optimal_analysis", "optimal_breakdown", "optimal_month", "optimal_total_cost",
+    "optimization_success", "optimum_analysis", "optimum_day", "optimum_oh",
+    "optimum_oh_day", "optimum_oh_month", "order_date", "overhaulCost",
+    "overhaul_activity", "overhaul_cost", "overhaul_costs",
+    "overhaul_reference_type", "overhaul_scope", "overhaul_scope_id", "overview",
+    "page", "parent", "parent_id", "plan_duration", "planned_month",
+    "planned_outage", "plant_level_benefit", "po_pr_id", "po_vendor_delivery_date",
+    "possible_plant_eaf", "priority_score", "procurement_cost", "procurement_costs",
+    "procurement_details", "projected_eaf_improvement", "quantity",
+    "quantity_required", "query_str", "recommendedScope", "recommended_reduced_outage",
+    "reference", "reference_id", "remark", "removal_date", "required_improvement",
+    "results", "schedules", "scope", "scope_calculation_id", "scope_equipment_job",
+    "scope_name", "scope_overhaul", "service_cost", "session", "simulation",
+    "simulation_id", "sort_by", "sortBy[]", "descending[]", "exclude[]",
+    "sparepart_id", "sparepart_impact", "sparepart_name", "sparepart_summary",
+    "spreadsheet_link", "start", "start_date", "status", "subsystem", "system",
+    "systemComponents", "target_plant_eaf", "tasks", "timing_recommendation",
+    "total", "totalPages", "total_cost", "total_equipment", "total_equipment_analyzed",
+    "total_procurement_items", "type", "unit_cost", "warning_message", "with_results",
+    "workscope", "workscope_group", "year", "_", "t", "timestamp",
+    "q", "filter", "currentUser", "risk_cost", "all", "plant_ids",
+    "eaf_threshold", "calculation_id", "size", "per_page", "limit"
+}
+
+ALLOWED_HEADERS = {
+    "host",
+    "user-agent",
+    "accept",
+    "accept-language",
+    "accept-encoding",
+    "connection",
+    "upgrade-insecure-requests",
+    "if-modified-since",
+    "if-none-match",
+    "cache-control",
+    "authorization",
+    "content-type",
+    "content-length",
+    "origin",
+    "referer",
+    "sec-fetch-dest",
+    "sec-fetch-mode",
+    "sec-fetch-site",
+    "sec-fetch-user",
+    "sec-ch-ua",
+    "sec-ch-ua-mobile",
+    "sec-ch-ua-platform",
+    "pragma",
+    "dnt",
+    "priority",
+    "x-forwarded-for",
+    "x-forwarded-proto",
+    "x-forwarded-host",
+    "x-forwarded-port",
+    "x-real-ip",
+    "x-request-id",
+    "x-correlation-id",
+    "x-requested-with",
+    "x-csrf-token",
+    "x-xsrf-token",
+    "postman-token",
+    "x-internal-key",
 }
 
 MAX_QUERY_PARAMS = 50
 MAX_QUERY_LENGTH = 2000
-MAX_JSON_BODY_SIZE = 1024 * 100  # 100 KB
+MAX_JSON_BODY_SIZE = 1024 * 500  # 500 KB
 
 XSS_PATTERN = re.compile(
     r"( bool:
 def inspect_value(value: str, source: str):
     if XSS_PATTERN.search(value):
-        raise HTTPException(
-            status_code=400,
-            detail=f"Potential XSS payload detected in {source}",
-        )
+        raise HTTPException(status_code=422, detail=f"Potential XSS payload detected in {source}")
 
     if SQLI_PATTERN.search(value):
-        raise HTTPException(
-            status_code=400,
-            detail=f"Potential SQL injection payload detected in {source}",
-        )
-
+        raise HTTPException(status_code=422, detail=f"Potential SQL injection payload detected in {source}")
+
     if RCE_PATTERN.search(value):
-        raise HTTPException(
-            status_code=400,
-            detail=f"Potential Remote Code Execution payload detected in {source}",
-        )
+        raise HTTPException(status_code=422, detail=f"Potential RCE payload detected in {source}")
 
     if TRAVERSAL_PATTERN.search(value):
-        raise HTTPException(
-            status_code=400,
-            detail=f"Path traversal detected in {source}",
-        )
+        raise HTTPException(status_code=422, detail=f"Potential Path Traversal payload detected in {source}")
 
     if has_control_chars(value):
-        raise HTTPException(
-            status_code=400,
-            detail=f"Invalid control characters detected in {source}",
-        )
+        raise HTTPException(status_code=422, detail=f"Invalid control characters detected in {source}")
 
 
-def inspect_json(obj, path="body"):
+def inspect_json(obj, path="body", check_whitelist=True):
     if isinstance(obj, dict):
         for key, value in obj.items():
             if key in FORBIDDEN_JSON_KEYS:
-                raise HTTPException(
-                    status_code=400,
-                    detail=f"Forbidden JSON key detected: {path}.{key}",
-                )
-            inspect_json(value, f"{path}.{key}")
+                raise HTTPException(status_code=422, detail=f"Forbidden JSON key detected: {path}.{key}")
+
+            if check_whitelist and key not in ALLOWED_DATA_PARAMS:
+                raise HTTPException(status_code=422, detail=f"Unknown JSON key detected: {path}.{key}")
+
+            # Recurse. If the key is a dynamic container, we stop whitelist checking for children.
+            should_check_subkeys = check_whitelist and (key not in DYNAMIC_KEYS)
+            inspect_json(value, f"{path}.{key}", check_whitelist=should_check_subkeys)
     elif isinstance(obj, list):
         for i, item in enumerate(obj):
-            inspect_json(item, f"{path}[{i}]")
+            inspect_json(item, f"{path}[{i}]", check_whitelist=check_whitelist)
     elif isinstance(obj, str):
         inspect_value(obj, path)
 
@@ -112,22 +203,47 @@ def inspect_json(obj, path="body"):
 
 class RequestValidationMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
+        # -------------------------
+        # 0. Header validation
+        # -------------------------
+        header_keys = [key.lower() for key, _ in request.headers.items()]
+
+        # Check for duplicate headers
+        header_counter = Counter(header_keys)
+        duplicate_headers = [key for key, count in header_counter.items() if count > 1]
+
+        ALLOW_DUPLICATE_HEADERS = {'accept', 'accept-encoding', 'accept-language', 'accept-charset', 'cookie'}
+        real_duplicates = [h for h in duplicate_headers if h not in ALLOW_DUPLICATE_HEADERS]
+        if real_duplicates:
+            raise HTTPException(status_code=422, detail=f"Duplicate headers are not allowed: {real_duplicates}")
+
+        # Whitelist headers
+        unknown_headers = [key for key in header_keys if key not in ALLOWED_HEADERS]
+        if unknown_headers:
+            filtered_unknown = [h for h in unknown_headers if not h.startswith('sec-') and h not in ALLOW_DUPLICATE_HEADERS]
+            if filtered_unknown:
+                raise HTTPException(status_code=422, detail=f"Unknown headers detected: {filtered_unknown}")
+
+        # Inspect header values
+        for key, value in request.headers.items():
+            if value:
+                inspect_value(value, f"header '{key}'")
+
         # -------------------------
         # 1. Query string limits
         # -------------------------
         if len(request.url.query) > MAX_QUERY_LENGTH:
-            raise HTTPException(
-                status_code=414,
-                detail="Query string too long",
-            )
+            raise HTTPException(status_code=422, detail="Query string too long")
 
         params = request.query_params.multi_items()
         if len(params) > MAX_QUERY_PARAMS:
-            raise HTTPException(
-                status_code=400,
-                detail="Too many query parameters",
-            )
+            raise HTTPException(status_code=422, detail="Too many query parameters")
+
+        # Check for unknown query parameters
+        unknown_params = [key for key, _ in params if key not in ALLOWED_DATA_PARAMS]
+        if unknown_params:
+            raise HTTPException(status_code=422, detail=f"Unknown query parameters detected: {unknown_params}")
 
         # -------------------------
         # 2. Duplicate parameters
@@ -139,38 +255,25 @@ class RequestValidationMiddleware(BaseHTTPMiddleware):
         ]
 
         if duplicates:
-            raise HTTPException(
-                status_code=400,
-                detail=f"Duplicate query parameters are not allowed: {duplicates}",
-            )
+            raise HTTPException(status_code=422, detail=f"Duplicate query parameters are not allowed: {duplicates}")
 
         # -------------------------
-        # 3. Query param inspection
+        # 3. Query param inspection & Pagination
         # -------------------------
-        pagination_size_keys = {"size", "itemsPerPage", "per_page", "limit"}
+        pagination_size_keys = {"size", "itemsPerPage", "per_page", "limit", "items_per_page"}
         for key, value in params:
             if value:
                 inspect_value(value, f"query param '{key}'")
-
-            # Pagination constraint: multiples of 5, max 50
+
             if key in pagination_size_keys and value:
                 try:
                     size_val = int(value)
                     if size_val > 50:
-                        raise HTTPException(
-                            status_code=400,
-                            detail=f"Pagination size '{key}' cannot exceed 50",
-                        )
+                        raise HTTPException(status_code=422, detail=f"Pagination size '{key}' cannot exceed 50")
                     if size_val % 5 != 0:
-                        raise HTTPException(
-                            status_code=400,
-                            detail=f"Pagination size '{key}' must be a multiple of 5",
-                        )
+                        raise HTTPException(status_code=422, detail=f"Pagination size '{key}' must be a multiple of 5")
                 except ValueError:
-                    raise HTTPException(
-                        status_code=400,
-                        detail=f"Pagination size '{key}' must be an integer",
-                    )
+                    raise HTTPException(status_code=422, detail=f"Pagination size '{key}' must be an integer")
 
         # -------------------------
         # 4. Content-Type sanity
@@ -178,44 +281,45 @@ class RequestValidationMiddleware(BaseHTTPMiddleware):
         content_type = request.headers.get("content-type", "")
         if content_type and not any(
             content_type.startswith(t)
-            for t in (
-                "application/json",
-                "multipart/form-data",
-                "application/x-www-form-urlencoded",
-            )
+            for t in ("application/json", "multipart/form-data", "application/x-www-form-urlencoded")
         ):
-            raise HTTPException(
-                status_code=415,
-                detail="Unsupported Content-Type",
-            )
+            raise HTTPException(status_code=422, detail="Unsupported Content-Type")
 
         # -------------------------
-        # 5. JSON body inspection
+        # 5. Single source check (Query vs JSON Body)
         # -------------------------
+        has_query = len(params) > 0
+        has_body = False
+
         if content_type.startswith("application/json"):
-            body = await request.body()
+            # We can't easily check body existence without consuming it,
+            # so we rely on a decimal Content-Length > 0 (malformed values are ignored)
+            content_length = request.headers.get("content-length")
+            if content_length and content_length.isdecimal() and int(content_length) > 0:
+                has_body = True
+
+        if has_query and has_body:
+            raise HTTPException(status_code=422, detail="Parameters must be from a single source (query string or JSON body), mixed sources are not allowed")
 
+        # -------------------------
+        # 6. JSON body inspection
+        # -------------------------
+        if content_type.startswith("application/json"):
+            body = await request.body()
             # if len(body) > MAX_JSON_BODY_SIZE:
-            #     raise HTTPException(
-            #         status_code=413,
-            #         detail="JSON body too large",
-            #     )
+            #     raise HTTPException(status_code=422, detail="JSON body too large")
 
             if body:
                 try:
                     payload = json.loads(body)
                 except json.JSONDecodeError:
-                    raise HTTPException(
-                        status_code=400,
-                        detail="Invalid JSON body",
-                    )
+                    raise HTTPException(status_code=422, detail="Invalid JSON body")
 
                 inspect_json(payload)
 
             # Re-inject body for downstream handlers
             async def receive():
                 return {"type": "http.request", "body": body}
-
-            request._receive = receive  # noqa: protected-access
+            request._receive = receive
 
         return await call_next(request)