From be3e1e6ae47e212c792316f685a7f637660eabbb Mon Sep 17 00:00:00 2001 From: Cizz22 Date: Tue, 3 Mar 2026 12:13:33 +0700 Subject: [PATCH] fix exception logging --- src/exceptions.py | 76 ++++++++++++++++------------------------------- src/logging.py | 18 +++-------- src/main.py | 30 +++++++------------ src/middleware.py | 17 ----------- 4 files changed, 40 insertions(+), 101 deletions(-) diff --git a/src/exceptions.py b/src/exceptions.py index f0bf6ed..600cadf 100644 --- a/src/exceptions.py +++ b/src/exceptions.py @@ -112,12 +112,12 @@ def handle_exception(request: Request, exc: Exception): if isinstance(exc, RateLimitExceeded): log.warning( - f"Rate limit exceeded | Error ID: {error_id}", + f"Rate limit exceeded: {str(exc.description) if hasattr(exc, 'description') else str(exc)} | Error ID: {error_id}", extra={ "error_id": error_id, "error_category": "rate_limit", - "request": request_info, "detail": str(exc.description) if hasattr(exc, "description") else str(exc), + "request": request_info, }, ) return JSONResponse( @@ -132,7 +132,7 @@ def handle_exception(request: Request, exc: Exception): if isinstance(exc, RequestValidationError): log.warning( - f"Validation error occurred | Error ID: {error_id}", + f"Validation error: {exc.errors()} | Error ID: {error_id}", extra={ "error_id": error_id, "error_category": "validation", @@ -155,28 +155,17 @@ def handle_exception(request: Request, exc: Exception): status_code = exc.status_code if hasattr(exc, "status_code") else 500 detail = exc.detail if hasattr(exc, "detail") else str(exc) - if 400 <= status_code < 500: - log.warning( - f"HTTP {status_code} occurred | Error ID: {error_id} | Detail: {detail}", - extra={ - "error_id": error_id, - "error_category": "http", - "status_code": status_code, - "detail": detail, - "request": request_info, - }, - ) - else: - log.error( - f"HTTP {status_code} occurred | Error ID: {error_id} | Detail: {detail}", - extra={ - "error_id": error_id, - "error_category": "http", - "status_code": status_code, - "detail": detail, - "request": request_info, - }, - ) + log_level = logging.WARNING if 400 <= status_code < 500 else logging.ERROR + log.log( + log_level, + f"HTTP {status_code}: {detail} | Error ID: {error_id}", + extra={ + "error_id": error_id, + "error_category": "http", + "status_code": status_code, + "request": request_info, + }, + ) return JSONResponse( status_code=status_code, @@ -191,28 +180,17 @@ def handle_exception(request: Request, exc: Exception): if isinstance(exc, SQLAlchemyError): error_message, status_code = handle_sqlalchemy_error(exc) # Log integrity errors as warning, others as error - if 400 <= status_code < 500: - log.warning( - f"Database integrity/validation error occurred | Error ID: {error_id}", - extra={ - "error_id": error_id, - "error_category": "database", - "error_message": error_message, - "request": request_info, - "exception": str(exc), - }, - ) - else: - log.error( - f"Database error occurred | Error ID: {error_id}", - extra={ - "error_id": error_id, - "error_category": "database", - "error_message": error_message, - "request": request_info, - "exception": str(exc), - }, - ) + log_level = logging.WARNING if 400 <= status_code < 500 else logging.ERROR + log.log( + log_level, + f"Database error: {error_message} | Error ID: {error_id}", + extra={ + "error_id": error_id, + "error_category": "database", + "violation": str(exc).split('\n')[0], + "request": request_info, + }, + ) return JSONResponse( status_code=status_code, @@ -226,14 +204,12 @@ def handle_exception(request: Request, exc: Exception): # Log unexpected errors log.error( - f"Unexpected error occurred | Error ID: {error_id}", + f"Unexpected error: {str(exc)} | Error ID: {error_id}", extra={ "error_id": error_id, "error_category": "unexpected", - "error_message": str(exc), "request": request_info, }, - exc_info=True, ) return JSONResponse( diff --git a/src/logging.py b/src/logging.py index 412a1c2..22a954a 100644 --- a/src/logging.py +++ b/src/logging.py @@ -54,33 +54,23 @@ class JSONFormatter(logging.Formatter): log_record = { "timestamp": datetime.datetime.fromtimestamp(record.created).strftime("%Y-%m-%d %H:%M:%S"), "level": record.levelname, - "name": record.name, "message": record.getMessage(), } # Add Context information if available - if user_id: - log_record["user_id"] = user_id - if username: - log_record["username"] = username - if role: - log_record["role"] = role if request_id: log_record["request_id"] = request_id + if role: + log_record["role"] = role # Add Error context if available if hasattr(record, "error_id"): log_record["error_id"] = record.error_id elif "error_id" in record.__dict__: log_record["error_id"] = record.error_id - - # Capture exception info if available - if record.exc_info: - log_record["exception"] = self.formatException(record.exc_info) - # Capture stack info if available - if record.stack_info: - log_record["stack_trace"] = self.formatStack(record.stack_info) + if user_id: + log_record["user_id"] = user_id # Add any extra attributes passed to the log call standard_attrs = { diff --git a/src/main.py b/src/main.py index dcb6524..0126a57 100644 --- a/src/main.py +++ b/src/main.py @@ -80,6 +80,10 @@ async def db_session_middleware(request: Request, call_next): response = await call_next(request) process_time = (time.time() - start_time) * 1000 + # Skip logging in middleware if it's an error (already logged in handle_exception) + if response.status_code >= 400: + return response + from src.context import get_username, get_role, get_user_id, set_user_id, set_username, set_role # Pull from context or fallback to request.state @@ -105,8 +109,13 @@ async def db_session_middleware(request: Request, call_next): if role: user_info_str += f" ({role})" + error_id = getattr(request.state, "error_id", None) + log_msg = f"HTTP {request.method} {request.url.path} completed in {round(process_time, 2)}ms{user_info_str}" + if error_id: + log_msg += f" | Error ID: {error_id}" + log.info( - f"HTTP {request.method} {request.url.path} completed in {round(process_time, 2)}ms{user_info_str}", + log_msg, extra={ "method": request.method, "path": request.url.path, @@ -114,28 +123,9 @@ async def db_session_middleware(request: Request, call_next): "duration_ms": round(process_time, 2), "user_id": user_id, "role": role, - }, - ) - - except Exception as e: - # Generate an error_id here if it hasn't been generated yet (e.g., if it failed before the handler) - error_id = getattr(request.state, "error_id", None) - if not error_id: - import uuid - error_id = str(uuid.uuid1()) - request.state.error_id = error_id - - log.error( - f"Request failed | Error ID: {error_id}", - extra={ - "method": request.method, - "path": request.url.path, - "error": str(e), "error_id": error_id, }, - exc_info=True, ) - raise e from None finally: await request.state.db.close() await request.state.aeros_db.close() diff --git a/src/middleware.py b/src/middleware.py index 5931688..5857074 100644 --- a/src/middleware.py +++ b/src/middleware.py @@ -161,23 +161,18 @@ def inspect_value(value: str, source: str): return if XSS_PATTERN.search(value): - log.warning(f"Security violation: Potential XSS payload detected in {source}") raise HTTPException(status_code=422, detail=f"Potential XSS payload detected in {source}") if SQLI_PATTERN.search(value): - log.warning(f"Security violation: Potential SQL injection payload detected in {source}") raise HTTPException(status_code=422, detail=f"Potential SQL injection payload detected in {source}") if RCE_PATTERN.search(value): - log.warning(f"Security violation: Potential RCE payload detected in {source}") raise HTTPException(status_code=422, detail=f"Potential RCE payload detected in {source}") if TRAVERSAL_PATTERN.search(value): - log.warning(f"Security violation: Potential Path Traversal payload detected in {source}") raise HTTPException(status_code=422, detail=f"Potential Path Traversal payload detected in {source}") if has_control_chars(value): - log.warning(f"Security violation: Invalid control characters detected in {source}") raise HTTPException(status_code=422, detail=f"Invalid control characters detected in {source}") @@ -185,11 +180,9 @@ def inspect_json(obj, path="body", check_whitelist=True): if isinstance(obj, dict): for key, value in obj.items(): if key in FORBIDDEN_JSON_KEYS: - log.warning(f"Security violation: Forbidden JSON key detected: {path}.{key}") raise HTTPException(status_code=422, detail=f"Forbidden JSON key detected: {path}.{key}") if check_whitelist and key not in ALLOWED_DATA_PARAMS: - log.warning(f"Security violation: Unknown JSON key detected: {path}.{key}") raise HTTPException(status_code=422, detail=f"Unknown JSON key detected: {path}.{key}") # Recurse. If the key is a dynamic container, we stop whitelist checking for children. @@ -220,7 +213,6 @@ class RequestValidationMiddleware(BaseHTTPMiddleware): ALLOW_DUPLICATE_HEADERS = {'accept', 'accept-encoding', 'accept-language', 'accept-charset', 'cookie'} real_duplicates = [h for h in duplicate_headers if h not in ALLOW_DUPLICATE_HEADERS] if real_duplicates: - log.warning(f"Security violation: Duplicate headers detected: {real_duplicates}") raise HTTPException(status_code=422, detail=f"Duplicate headers are not allowed: {real_duplicates}") # Whitelist headers @@ -228,7 +220,6 @@ class RequestValidationMiddleware(BaseHTTPMiddleware): if unknown_headers: filtered_unknown = [h for h in unknown_headers if not h.startswith('sec-')] if filtered_unknown: - log.warning(f"Security violation: Unknown headers detected: {filtered_unknown}") raise HTTPException(status_code=422, detail=f"Unknown headers detected: {filtered_unknown}") # Inspect header values @@ -240,19 +231,16 @@ class RequestValidationMiddleware(BaseHTTPMiddleware): # 1. Query string limits # ------------------------- if len(request.url.query) > MAX_QUERY_LENGTH: - log.warning(f"Security violation: Query string too long") raise HTTPException(status_code=422, detail="Query string too long") params = request.query_params.multi_items() if len(params) > MAX_QUERY_PARAMS: - log.warning(f"Security violation: Too many query parameters") raise HTTPException(status_code=422, detail="Too many query parameters") # Check for unknown query parameters unknown_params = [key for key, _ in params if key not in ALLOWED_DATA_PARAMS] if unknown_params: - log.warning(f"Security violation: Unknown query parameters detected: {unknown_params}") raise HTTPException(status_code=422, detail=f"Unknown query parameters detected: {unknown_params}") # ------------------------- @@ -265,7 +253,6 @@ class RequestValidationMiddleware(BaseHTTPMiddleware): ] if duplicates: - log.warning(f"Security violation: Duplicate query parameters detected: {duplicates}") raise HTTPException(status_code=422, detail=f"Duplicate query parameters are not allowed: {duplicates}") # ------------------------- @@ -280,13 +267,10 @@ class RequestValidationMiddleware(BaseHTTPMiddleware): try: size_val = int(value) if size_val > 50: - log.warning(f"Security violation: Pagination size too large ({size_val})") raise HTTPException(status_code=422, detail=f"Pagination size '{key}' cannot exceed 50") if size_val % 5 != 0: - log.warning(f"Security violation: Pagination size not multiple of 5 ({size_val})") raise HTTPException(status_code=422, detail=f"Pagination size '{key}' must be a multiple of 5") except ValueError: - log.warning(f"Security violation: Pagination size invalid value ({value})") raise HTTPException(status_code=422, detail=f"Pagination size '{key}' must be an integer") # ------------------------- @@ -311,7 +295,6 @@ class RequestValidationMiddleware(BaseHTTPMiddleware): has_body = True if has_query and has_body: - log.warning(f"Security violation: Mixed parameters (query + JSON body)") raise HTTPException(status_code=422, detail="Parameters must be from a single source (query string or JSON body), mixed sources are not allowed") # -------------------------