From 8d9ce1ae909bb1870a75447832625d142f60d635 Mon Sep 17 00:00:00 2001 From: Cizz22 Date: Mon, 23 Feb 2026 15:46:13 +0700 Subject: [PATCH] feat: Improve logging with user context and error IDs, enhance request tracing, and strengthen security middleware with RCE/path traversal detection and pagination validation. --- src/auth/service.py | 13 ++++- src/config.py | 2 +- src/context.py | 34 ++++++++++++ src/exceptions.py | 122 ++++++++++++++++++++++++-------------------- src/logging.py | 45 ++++++++++------ src/main.py | 70 ++++++++++++++++++++++--- src/middleware.py | 60 +++++++++++++++++++--- 7 files changed, 261 insertions(+), 85 deletions(-) diff --git a/src/auth/service.py b/src/auth/service.py index dc49e59..2d06f00 100644 --- a/src/auth/service.py +++ b/src/auth/service.py @@ -28,6 +28,17 @@ class JWTBearer(HTTPBearer): ) request.state.user = user_info + + from src.context import set_user_id, set_username, set_role + if hasattr(user_info, "user_id"): + set_user_id(str(user_info.user_id)) + if hasattr(user_info, "username"): + set_username(user_info.username) + elif hasattr(user_info, "name"): + set_username(user_info.name) + if hasattr(user_info, "role"): + set_role(user_info.role) + return user_info else: raise HTTPException(status_code=403, detail="Invalid authorization code.") @@ -46,7 +57,7 @@ class JWTBearer(HTTPBearer): return UserBase(**user_data["data"]) except Exception as e: - print(f"Token verification error: {str(e)}") + logging.error(f"Token verification error: {str(e)}") return None diff --git a/src/config.py b/src/config.py index 03ab027..b92f09f 100644 --- a/src/config.py +++ b/src/config.py @@ -51,7 +51,7 @@ def get_config(): config = get_config() -LOG_LEVEL = config("LOG_LEVEL", default=logging.WARNING) +LOG_LEVEL = config("LOG_LEVEL", default="INFO") ENV = config("ENV", default="local") PORT = config("PORT", cast=int, default=8000) HOST = config("HOST", default="localhost") diff --git a/src/context.py b/src/context.py index 4c968a2..47e0e62 100644 --- a/src/context.py +++ b/src/context.py @@ -2,8 +2,18 @@ from contextvars import ContextVar from typing import Optional, Final REQUEST_ID_CTX_KEY: Final[str] = "request_id" +USER_ID_CTX_KEY: Final[str] = "user_id" +USERNAME_CTX_KEY: Final[str] = "username" +ROLE_CTX_KEY: Final[str] = "role" + _request_id_ctx_var: ContextVar[Optional[str]] = ContextVar( REQUEST_ID_CTX_KEY, default=None) +_user_id_ctx_var: ContextVar[Optional[str]] = ContextVar( + USER_ID_CTX_KEY, default=None) +_username_ctx_var: ContextVar[Optional[str]] = ContextVar( + USERNAME_CTX_KEY, default=None) +_role_ctx_var: ContextVar[Optional[str]] = ContextVar( + ROLE_CTX_KEY, default=None) def get_request_id() -> Optional[str]: @@ -16,3 +26,27 @@ def set_request_id(request_id: str): def reset_request_id(token): _request_id_ctx_var.reset(token) + + +def get_user_id() -> Optional[str]: + return _user_id_ctx_var.get() + + +def set_user_id(user_id: str): + return _user_id_ctx_var.set(user_id) + + +def get_username() -> Optional[str]: + return _username_ctx_var.get() + + +def set_username(username: str): + return _username_ctx_var.set(username) + + +def get_role() -> Optional[str]: + return _role_ctx_var.get() + + +def set_role(role: str): + return _role_ctx_var.set(role) diff --git a/src/exceptions.py b/src/exceptions.py index bab7e8d..18377cb 100644 --- a/src/exceptions.py +++ b/src/exceptions.py @@ -96,58 +96,86 @@ def handle_exception(request: Request, exc: Exception): """ Global exception handler for Fastapi application. """ + import uuid + error_id = str(uuid.uuid1()) request_info = get_request_context(request) + + # Store error_id in request.state for middleware/logging + request.state.error_id = error_id if isinstance(exc, RateLimitExceeded): - return _rate_limit_exceeded_handler(request, exc) + logging.warning( + f"Rate limit exceeded | Error ID: {error_id}", + extra={ + "error_id": error_id, + "error_category": "rate_limit", + "request": request_info, + "detail": str(exc.description) if hasattr(exc, "description") else str(exc), + }, + ) + return JSONResponse( + status_code=429, + content={ + "data": None, + "message": "Rate limit exceeded", + "status": ResponseStatus.ERROR, + "error_id": error_id + } + ) if isinstance(exc, RequestValidationError): - logging.error( - f"Validation error | Error: {str(exc.errors())} | Request: {request_info}", - extra={"error_category": "validation"}, + logging.warning( + f"Validation error occurred | Error ID: {error_id}", + extra={ + "error_id": error_id, + "error_category": "validation", + "errors": exc.errors(), + "request": request_info, + }, ) return JSONResponse( status_code=422, content={ - "data": None, - "message": "Validation error", + "data": exc.errors(), + "message": "Validation Error", "status": ResponseStatus.ERROR, - "errors": [ - ErrorDetail( - field=".".join(map(str, err["loc"])), - message=err["msg"], - code=err["type"], - ).model_dump() - for err in exc.errors() - ] - } + "error_id": error_id + }, ) if isinstance(exc, HTTPException): logging.error( - f"HTTP exception | Code: {exc.status_code} | Error: {exc.detail} | Request: {request_info}", - extra={"error_category": "http"}, + f"HTTP exception occurred | Error ID: {error_id}", + extra={ + "error_id": error_id, + "error_category": "http", + "status_code": exc.status_code, + "detail": exc.detail if hasattr(exc, "detail") else str(exc), + "request": request_info, + }, ) return JSONResponse( status_code=exc.status_code, content={ "data": None, - "message": str(exc.detail), + "message": str(exc.detail) if hasattr(exc, "detail") else str(exc), "status": ResponseStatus.ERROR, - "errors": [ - ErrorDetail( - message=str(exc.detail) - ).model_dump() - ] - } + "error_id": error_id + }, ) if isinstance(exc, SQLAlchemyError): error_message, status_code = handle_sqlalchemy_error(exc) logging.error( - f"Database Error | Error: {str(error_message)} | Request: {request_info}", - extra={"error_category": "database"}, + f"Database error occurred | Error ID: {error_id}", + extra={ + "error_id": error_id, + "error_category": "database", + "error_message": error_message, + "request": request_info, + "exception": str(exc), + }, ) return JSONResponse( @@ -156,42 +184,28 @@ def handle_exception(request: Request, exc: Exception): "data": None, "message": error_message, "status": ResponseStatus.ERROR, - "errors": [ - ErrorDetail( - message=error_message - ).model_dump() - ] - } + "error_id": error_id + }, ) # Log unexpected errors - error_message = f"{exc.__class__.__name__}: {str(exc)}" - error_traceback = exc.__traceback__ - - # Get file and line info if available - if error_traceback: - tb = error_traceback - while tb.tb_next: - tb = tb.tb_next - file_name = tb.tb_frame.f_code.co_filename - line_num = tb.tb_lineno - error_message = f"{error_message}\nFile {file_name}, line {line_num}" - logging.error( - f"Unexpected Error | Error: {error_message} | Request: {request_info}", - extra={"error_category": "unexpected"}, + f"Unexpected error occurred | Error ID: {error_id}", + extra={ + "error_id": error_id, + "error_category": "unexpected", + "error_message": str(exc), + "request": request_info, + }, + exc_info=True, ) - + return JSONResponse( status_code=500, content={ "data": None, - "message": error_message, + "message": "An unexpected error occurred", "status": ResponseStatus.ERROR, - "errors": [ - ErrorDetail( - message=error_message - ).model_dump() - ] - } + "error_id": error_id + }, ) diff --git a/src/logging.py b/src/logging.py index 2207241..b9fc41c 100644 --- a/src/logging.py +++ b/src/logging.py @@ -35,29 +35,45 @@ class JSONFormatter(logging.Formatter): Custom formatter to output logs in JSON format. """ def format(self, record): - from src.context import get_request_id - + from src.context import get_request_id, get_user_id, get_username, get_role request_id = None + user_id = None + username = None + role = None + try: request_id = get_request_id() + user_id = get_user_id() + username = get_username() + role = get_role() except Exception: pass + # Standard fields from requirements log_record = { - "timestamp": datetime.datetime.fromtimestamp(record.created).astimezone().isoformat(), + "timestamp": datetime.datetime.fromtimestamp(record.created).strftime("%Y-%m-%d %H:%M:%S"), "level": record.levelname, + "name": record.name, "message": record.getMessage(), - "logger_name": record.name, - "location": f"{record.module}:{record.funcName}:{record.lineno}", - "module": record.module, - "funcName": record.funcName, - "lineno": record.lineno, - "pid": os.getpid(), - "request_id": request_id or "SYSTEM", # request id assigned per request or SYSTEM for system logs } - + # Add Context information if available + if user_id: + log_record["user_id"] = user_id + if username: + log_record["username"] = username + if role: + log_record["role"] = role + if request_id: + log_record["request_id"] = request_id + + # Add Error context if available + if hasattr(record, "error_id"): + log_record["error_id"] = record.error_id + elif "error_id" in record.__dict__: + log_record["error_id"] = record.error_id + # Capture exception info if available if record.exc_info: log_record["exception"] = self.formatException(record.exc_info) @@ -67,18 +83,17 @@ class JSONFormatter(logging.Formatter): log_record["stack_trace"] = self.formatStack(record.stack_info) # Add any extra attributes passed to the log call - # We skip standard and internal uvicorn/fastapi attributes to avoid duplication or mess standard_attrs = { "args", "asctime", "created", "exc_info", "exc_text", "filename", "funcName", "levelname", "levelno", "lineno", "module", "msecs", "message", "msg", "name", "pathname", "process", "processName", - "relativeCreated", "stack_info", "thread", "threadName", + "relativeCreated", "stack_info", "thread", "threadName", "error_id", "color_message", "request", "scope" } for key, value in record.__dict__.items(): - if key not in standard_attrs: + if key not in standard_attrs and not key.startswith("_"): log_record[key] = value - + log_json = json.dumps(log_record) # Apply color if the output is a terminal diff --git a/src/main.py b/src/main.py index 407f5b0..265deee 100644 --- a/src/main.py +++ b/src/main.py @@ -50,7 +50,7 @@ app.state.limiter = limiter app.add_exception_handler(Exception, handle_exception) app.add_exception_handler(HTTPException, handle_exception) app.add_exception_handler(RequestValidationError, handle_exception) -app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) +app.add_exception_handler(RateLimitExceeded, handle_exception) app.add_exception_handler(SQLAlchemyError, handle_exception) from src.context import set_request_id, reset_request_id, get_request_id @@ -68,18 +68,74 @@ async def db_session_middleware(request: Request, call_next): try: - log.info(f"Incoming request: {request.method} {request.url.path}") + start_time = time.time() session = async_scoped_session(async_session, scopefunc=get_request_id) request.state.db = session() collector_session = async_scoped_session(collector_async_session, scopefunc=get_request_id) request.state.collector_db = collector_session() + response = await call_next(request) - if response.status_code >= 400: - log.error(f"Request completed: {response.status_code}") - else: - log.info(f"Request completed: {response.status_code}") + process_time = (time.time() - start_time) * 1000 + + from src.context import get_username, get_role, get_user_id, set_user_id, set_username, set_role + + # Pull from context or fallback to request.state.user + username = get_username() + role = get_role() + user_id = get_user_id() + + user_obj = getattr(request.state, "user", None) + if user_obj: + # UserBase in this project + u_id = getattr(user_obj, "user_id", None) + u_name = getattr(user_obj, "name", None) or getattr(user_obj, "username", None) + u_role = getattr(user_obj, "role", None) + + if not user_id and u_id: + user_id = str(u_id) + set_user_id(user_id) + if not username and u_name: + username = u_name + set_username(username) + if not role and u_role: + role = u_role + set_role(role) + + user_info_str = "" + if username: + user_info_str = f" | User: {username}" + if role: + user_info_str += f" ({role})" + + log.info( + f"HTTP {request.method} {request.url.path} completed in {round(process_time, 2)}ms{user_info_str}", + extra={ + "method": request.method, + "path": request.url.path, + "status_code": response.status_code, + "duration_ms": round(process_time, 2), + "user_id": user_id, + "role": role, + }, + ) except Exception as e: - log.error(f"Request failed: {type(e).__name__} - {str(e)}") + # Generate an error_id here if it hasn't been generated yet + error_id = getattr(request.state, "error_id", None) + if not error_id: + import uuid + error_id = str(uuid.uuid1()) + request.state.error_id = error_id + + log.error( + f"Request failed | Error ID: {error_id}", + extra={ + "method": request.method, + "path": request.url.path, + "error": str(e), + "error_id": error_id, + }, + exc_info=True, + ) raise e from None finally: await request.state.db.close() diff --git a/src/middleware.py b/src/middleware.py index 3127020..5599a59 100644 --- a/src/middleware.py +++ b/src/middleware.py @@ -18,13 +18,35 @@ MAX_QUERY_PARAMS = 50 MAX_QUERY_LENGTH = 2000 MAX_JSON_BODY_SIZE = 1024 * 100 # 100 KB -# Very targeted patterns. Avoid catastrophic regex nonsense. -XSS_PATTERN_STR = r"( 50: + raise HTTPException(status_code=400, detail=f"Pagination size '{key}' cannot exceed 50") + if size_val % 5 != 0: + raise HTTPException(status_code=400, detail=f"Pagination size '{key}' must be a multiple of 5") + except ValueError: + raise HTTPException(status_code=400, detail=f"Pagination size '{key}' must be an integer") # ------------------------- # 4. Content-Type sanity