diff --git a/src/exceptions.py b/src/exceptions.py index 4d99326..f0bf6ed 100644 --- a/src/exceptions.py +++ b/src/exceptions.py @@ -151,22 +151,38 @@ def handle_exception(request: Request, exc: Exception): ) if isinstance(exc, (HTTPException, StarletteHTTPException)): - log.error( - f"HTTP exception occurred | Error ID: {error_id}", - extra={ - "error_id": error_id, - "error_category": "http", - "status_code": exc.status_code, - "detail": exc.detail if hasattr(exc, "detail") else str(exc), - "request": request_info, - }, - ) + # Log as warning for 4xx, error for 5xx + status_code = exc.status_code if hasattr(exc, "status_code") else 500 + detail = exc.detail if hasattr(exc, "detail") else str(exc) + + if 400 <= status_code < 500: + log.warning( + f"HTTP {status_code} occurred | Error ID: {error_id} | Detail: {detail}", + extra={ + "error_id": error_id, + "error_category": "http", + "status_code": status_code, + "detail": detail, + "request": request_info, + }, + ) + else: + log.error( + f"HTTP {status_code} occurred | Error ID: {error_id} | Detail: {detail}", + extra={ + "error_id": error_id, + "error_category": "http", + "status_code": status_code, + "detail": detail, + "request": request_info, + }, + ) return JSONResponse( - status_code=exc.status_code, + status_code=status_code, content={ "data": None, - "message": str(exc.detail) if hasattr(exc, "detail") else str(exc), + "message": str(detail), "status": ResponseStatus.ERROR, "error_id": error_id }, @@ -174,16 +190,29 @@ def handle_exception(request: Request, exc: Exception): if isinstance(exc, SQLAlchemyError): error_message, status_code = handle_sqlalchemy_error(exc) - log.error( - f"Database error occurred | Error ID: {error_id}", - extra={ - "error_id": error_id, - "error_category": "database", - "error_message": error_message, - "request": request_info, - "exception": str(exc), - }, - ) + # Log integrity errors as warning, others as error + if 400 <= status_code < 500: + log.warning( + f"Database integrity/validation error occurred | Error ID: {error_id}", + extra={ + "error_id": error_id, + "error_category": "database", + "error_message": error_message, + "request": request_info, + "exception": str(exc), + }, + ) + else: + log.error( + f"Database error occurred | Error ID: {error_id}", + extra={ + "error_id": error_id, + "error_category": "database", + "error_message": error_message, + "request": request_info, + "exception": str(exc), + }, + ) return JSONResponse( status_code=status_code, diff --git a/src/middleware.py b/src/middleware.py index a308b64..5931688 100644 --- a/src/middleware.py +++ b/src/middleware.py @@ -1,5 +1,6 @@ import json import re +import logging from collections import Counter from fastapi import Request, HTTPException from starlette.middleware.base import BaseHTTPMiddleware @@ -83,7 +84,9 @@ ALLOWED_HEADERS = { "x-csrf-token", "x-xsrf-token", "postman-token", - "x-internal-key", + "x-forwarded-path", + "x-forwarded-prefix", + "cookie", } MAX_QUERY_PARAMS = 50 @@ -91,32 +94,51 @@ MAX_QUERY_LENGTH = 2000 MAX_JSON_BODY_SIZE = 1024 * 500 # 500 KB XSS_PATTERN = re.compile( - r"( bool: return any(ord(c) < 32 and c not in ("\n", "\r", "\t") for c in value) def inspect_value(value: str, source: str): + if not isinstance(value, str) or value == "*/*": + return + if XSS_PATTERN.search(value): + log.warning(f"Security violation: Potential XSS payload detected in {source}") raise HTTPException(status_code=422, detail=f"Potential XSS payload detected in {source}") if SQLI_PATTERN.search(value): + log.warning(f"Security violation: Potential SQL injection payload detected in {source}") raise HTTPException(status_code=422, detail=f"Potential SQL injection payload detected in {source}") if RCE_PATTERN.search(value): + log.warning(f"Security violation: Potential RCE payload detected in {source}") raise HTTPException(status_code=422, detail=f"Potential RCE payload detected in {source}") if TRAVERSAL_PATTERN.search(value): + log.warning(f"Security violation: Potential Path Traversal payload detected in {source}") raise HTTPException(status_code=422, detail=f"Potential Path Traversal payload detected in {source}") if has_control_chars(value): + log.warning(f"Security violation: Invalid control characters detected in {source}") raise HTTPException(status_code=422, detail=f"Invalid control characters detected in {source}") @@ -156,9 +185,11 @@ def inspect_json(obj, path="body", check_whitelist=True): if isinstance(obj, dict): for key, value in obj.items(): if key in FORBIDDEN_JSON_KEYS: + log.warning(f"Security violation: Forbidden JSON key detected: {path}.{key}") raise HTTPException(status_code=422, detail=f"Forbidden JSON key detected: {path}.{key}") if check_whitelist and key not in ALLOWED_DATA_PARAMS: + log.warning(f"Security violation: Unknown JSON key detected: {path}.{key}") raise HTTPException(status_code=422, detail=f"Unknown JSON key detected: {path}.{key}") # Recurse. If the key is a dynamic container, we stop whitelist checking for children. @@ -189,6 +220,7 @@ class RequestValidationMiddleware(BaseHTTPMiddleware): ALLOW_DUPLICATE_HEADERS = {'accept', 'accept-encoding', 'accept-language', 'accept-charset', 'cookie'} real_duplicates = [h for h in duplicate_headers if h not in ALLOW_DUPLICATE_HEADERS] if real_duplicates: + log.warning(f"Security violation: Duplicate headers detected: {real_duplicates}") raise HTTPException(status_code=422, detail=f"Duplicate headers are not allowed: {real_duplicates}") # Whitelist headers @@ -196,6 +228,7 @@ class RequestValidationMiddleware(BaseHTTPMiddleware): if unknown_headers: filtered_unknown = [h for h in unknown_headers if not h.startswith('sec-')] if filtered_unknown: + log.warning(f"Security violation: Unknown headers detected: {filtered_unknown}") raise HTTPException(status_code=422, detail=f"Unknown headers detected: {filtered_unknown}") # Inspect header values @@ -207,16 +240,19 @@ class RequestValidationMiddleware(BaseHTTPMiddleware): # 1. Query string limits # ------------------------- if len(request.url.query) > MAX_QUERY_LENGTH: + log.warning(f"Security violation: Query string too long") raise HTTPException(status_code=422, detail="Query string too long") params = request.query_params.multi_items() if len(params) > MAX_QUERY_PARAMS: + log.warning(f"Security violation: Too many query parameters") raise HTTPException(status_code=422, detail="Too many query parameters") # Check for unknown query parameters unknown_params = [key for key, _ in params if key not in ALLOWED_DATA_PARAMS] if unknown_params: + log.warning(f"Security violation: Unknown query parameters detected: {unknown_params}") raise HTTPException(status_code=422, detail=f"Unknown query parameters detected: {unknown_params}") # ------------------------- @@ -229,6 +265,7 @@ class RequestValidationMiddleware(BaseHTTPMiddleware): ] if duplicates: + log.warning(f"Security violation: Duplicate query parameters detected: {duplicates}") raise HTTPException(status_code=422, detail=f"Duplicate query parameters are not allowed: {duplicates}") # ------------------------- @@ -243,10 +280,13 @@ class RequestValidationMiddleware(BaseHTTPMiddleware): try: size_val = int(value) if size_val > 50: + log.warning(f"Security violation: Pagination size too large ({size_val})") raise HTTPException(status_code=422, detail=f"Pagination size '{key}' cannot exceed 50") if size_val % 5 != 0: + log.warning(f"Security violation: Pagination size not multiple of 5 ({size_val})") raise HTTPException(status_code=422, detail=f"Pagination size '{key}' must be a multiple of 5") except ValueError: + log.warning(f"Security violation: Pagination size invalid value ({value})") raise HTTPException(status_code=422, detail=f"Pagination size '{key}' must be an integer") # ------------------------- @@ -271,6 +311,7 @@ class RequestValidationMiddleware(BaseHTTPMiddleware): has_body = True if has_query and has_body: + log.warning(f"Security violation: Mixed parameters (query + JSON body)") raise HTTPException(status_code=422, detail="Parameters must be from a single source (query string or JSON body), mixed sources are not allowed") # -------------------------