From abb7e8d27b19225a4a7a18c799295767c387e988 Mon Sep 17 00:00:00 2001 From: Cizz22 Date: Tue, 10 Feb 2026 12:22:28 +0700 Subject: [PATCH] feat: Implement JSON logging with request context and centralize logging configuration. --- .../simulation_save_service.py | 6 +- src/aeros_simulation/utils.py | 5 - src/context.py | 18 +++ src/logging.py | 133 ++++++++++++++++-- src/main.py | 89 ++++-------- 5 files changed, 172 insertions(+), 79 deletions(-) create mode 100644 src/context.py diff --git a/src/aeros_simulation/simulation_save_service.py b/src/aeros_simulation/simulation_save_service.py index 40fec74..28a68bd 100644 --- a/src/aeros_simulation/simulation_save_service.py +++ b/src/aeros_simulation/simulation_save_service.py @@ -1,6 +1,8 @@ from collections import defaultdict import json import logging + +log = logging.getLogger(__name__) import os import tempfile from datetime import datetime @@ -17,10 +19,6 @@ from src.aeros_simulation.service import get_all_aeros_node, get_or_save_node, g from src.aeros_simulation.utils import calculate_eaf, calculate_eaf_konkin from src.config import AEROS_BASE_URL from src.database.core import DbSession -from src.logging import setup_logging - -log = logging.getLogger(__name__) -setup_logging(logger=log) async def execute_simulation( diff --git a/src/aeros_simulation/utils.py b/src/aeros_simulation/utils.py index efb9d5d..cc4ec55 100644 --- a/src/aeros_simulation/utils.py +++ b/src/aeros_simulation/utils.py @@ -2,12 +2,7 @@ from datetime import datetime import json import logging -from src.logging import setup_logging - - log = logging.getLogger(__name__) -setup_logging(log) - def date_to_utc(date_val): return datetime.combine( diff --git a/src/context.py b/src/context.py new file mode 100644 index 0000000..4c968a2 --- /dev/null +++ b/src/context.py @@ -0,0 +1,18 @@ +from contextvars import ContextVar +from typing import Optional, Final + +REQUEST_ID_CTX_KEY: Final[str] = "request_id" +_request_id_ctx_var: ContextVar[Optional[str]] = ContextVar( + REQUEST_ID_CTX_KEY, default=None) + + +def get_request_id() -> Optional[str]: + return _request_id_ctx_var.get() + + +def set_request_id(request_id: str): + return _request_id_ctx_var.set(request_id) + + +def reset_request_id(token): + _request_id_ctx_var.reset(token) diff --git a/src/logging.py b/src/logging.py index a19c97e..8ecf884 100644 --- a/src/logging.py +++ b/src/logging.py @@ -1,16 +1,129 @@ import logging +import json +import datetime +import os import sys -from fastapi import FastAPI +from typing import Optional + from src.config import LOG_LEVEL +from src.enums import RBDEnum + + +LOG_FORMAT_DEBUG = "%(levelname)s:%(message)s:%(pathname)s:%(funcName)s:%(lineno)d" + +# ANSI Color Codes +RESET = "\033[0m" +COLORS = { + "DEBUG": "\033[36m", # Cyan + "INFO": "\033[32m", # Green + "WARNING": "\033[33m", # Yellow + "WARN": "\033[33m", # Yellow + "ERROR": "\033[31m", # Red + "CRITICAL": "\033[1;31m", # Bold Red +} + + +class LogLevels(RBDEnum): + info = "INFO" + warn = "WARN" + error = "ERROR" + debug = "DEBUG" + + +class JSONFormatter(logging.Formatter): + """ + Custom formatter to output logs in JSON format. + """ + def format(self, record): + from src.context import get_request_id + + + request_id = None + try: + request_id = get_request_id() + except Exception: + pass + + log_record = { + "timestamp": datetime.datetime.fromtimestamp(record.created).astimezone().isoformat(), + "level": record.levelname, + "message": record.getMessage(), + "logger_name": record.name, + "location": f"{record.module}:{record.funcName}:{record.lineno}", + "module": record.module, + "funcName": record.funcName, + "lineno": record.lineno, + "pid": os.getpid(), + "request_id": request_id, + } + + + # Capture exception info if available + if record.exc_info: + log_record["exception"] = self.formatException(record.exc_info) + + # Capture stack info if available + if record.stack_info: + log_record["stack_trace"] = self.formatStack(record.stack_info) + + # Add any extra attributes passed to the log call + # We skip standard attributes to avoid duplication + standard_attrs = { + "args", "asctime", "created", "exc_info", "exc_text", "filename", + "funcName", "levelname", "levelno", "lineno", "module", "msecs", + "message", "msg", "name", "pathname", "process", "processName", + "relativeCreated", "stack_info", "thread", "threadName" + } + for key, value in record.__dict__.items(): + if key not in standard_attrs: + log_record[key] = value + + log_json = json.dumps(log_record) -def setup_logging(logger): - # Your logging configuration here - logger.setLevel(logging.DEBUG) - # Create formatter - formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + # Apply color if the output is a terminal + if sys.stdout.isatty(): + level_color = COLORS.get(record.levelname, "") + return f"{level_color}{log_json}{RESET}" + + return log_json + + +def configure_logging(): + log_level = str(LOG_LEVEL).upper() # cast to string + log_levels = list(LogLevels) + + if log_level not in log_levels: + log_level = LogLevels.error + + # Get the root logger + root_logger = logging.getLogger() + root_logger.setLevel(log_level) + + # Clear existing handlers to avoid duplicate logs + if root_logger.hasHandlers(): + root_logger.handlers.clear() + + # Create a stream handler that outputs to stdout + handler = logging.StreamHandler(sys.stdout) + + # Use JSONFormatter for all environments, or could be conditional + # For now, let's assume the user wants JSON everywhere as requested + formatter = JSONFormatter() + + # If debug mode is specifically requested and we want the old format for debug: + # if log_level == LogLevels.debug: + # formatter = logging.Formatter(LOG_FORMAT_DEBUG) - # Create console handler - stream_handler = logging.StreamHandler(sys.stdout) - stream_handler.setFormatter(formatter) - logger.addHandler(stream_handler) + handler.setFormatter(formatter) + root_logger.addHandler(handler) + + # Reconfigure uvicorn loggers to use our JSON formatter + for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "fastapi"]: + logger = logging.getLogger(logger_name) + logger.handlers = [] + logger.propagate = True + + # sometimes the slack client can be too verbose + logging.getLogger("slack_sdk.web.base_client").setLevel(logging.CRITICAL) + \ No newline at end of file diff --git a/src/main.py b/src/main.py index 6f276c6..d64404b 100644 --- a/src/main.py +++ b/src/main.py @@ -2,78 +2,63 @@ import logging import os import sys import time -from contextvars import ContextVar from os import path -from typing import Final, Optional from uuid import uuid1 +from typing import Optional, Final -from fastapi import FastAPI, HTTPException, Path, status -from fastapi.middleware.cors import CORSMiddleware +from fastapi import FastAPI, HTTPException, status, Path from fastapi.responses import JSONResponse from pydantic import ValidationError + from slowapi import _rate_limit_exceeded_handler from slowapi.errors import RateLimitExceeded from sqlalchemy import inspect -from sqlalchemy.ext.asyncio import async_scoped_session from sqlalchemy.orm import scoped_session +from sqlalchemy.ext.asyncio import async_scoped_session from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint -from starlette.middleware.gzip import GZipMiddleware from starlette.requests import Request -from starlette.responses import FileResponse, Response, StreamingResponse from starlette.routing import compile_path +from starlette.middleware.gzip import GZipMiddleware +from fastapi.middleware.cors import CORSMiddleware + +from starlette.responses import Response, StreamingResponse, FileResponse from starlette.staticfiles import StaticFiles -from src.api import api_router -from src.database.core import async_session, engine,async_aeros_session from src.enums import ResponseStatus +from src.logging import configure_logging +from src.rate_limiter import limiter +from src.api import api_router +from src.database.core import engine, async_session, async_aeros_session from src.exceptions import handle_exception -from src.logging import setup_logging from src.middleware import RequestValidationMiddleware -from src.rate_limiter import limiter +from src.context import set_request_id, reset_request_id, get_request_id log = logging.getLogger(__name__) # we configure the logging level and format - +configure_logging() # we define the exception handlers exception_handlers = {Exception: handle_exception} # we create the ASGI for the app -app = FastAPI( - exception_handlers=exception_handlers, - openapi_url="", - title="LCCA API", - description="Welcome to RBD's API documentation!", - version="0.1.0", -) +app = FastAPI(exception_handlers=exception_handlers, openapi_url="", title="LCCA API", + description="Welcome to RBD's API documentation!", + version="0.1.0") app.state.limiter = limiter app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) app.add_middleware(GZipMiddleware, minimum_size=2000) -# credentials: "include", -setup_logging(logger=log) -log.info('API is starting up') - - -REQUEST_ID_CTX_KEY: Final[str] = "request_id" -_request_id_ctx_var: ContextVar[Optional[str]] = ContextVar( - REQUEST_ID_CTX_KEY, default=None -) - - -def get_request_id() -> Optional[str]: - return _request_id_ctx_var.get() - app.add_middleware(RequestValidationMiddleware) + @app.middleware("http") async def db_session_middleware(request: Request, call_next): request_id = str(uuid1()) # we create a per-request id such that we can ensure that our session is scoped for a particular request. # see: https://github.com/tiangolo/fastapi/issues/726 - ctx_token = _request_id_ctx_var.set(request_id) + ctx_token = set_request_id(request_id) try: session = async_scoped_session(async_session, scopefunc=get_request_id) @@ -82,50 +67,34 @@ async def db_session_middleware(request: Request, call_next): collector_session = async_scoped_session(async_aeros_session, scopefunc=get_request_id) request.state.aeros_db = collector_session() + start_time = time.time() response = await call_next(request) + process_time = (time.time() - start_time) * 1000 + + log.info( + f"Request: {request.method} {request.url.path} Status: {response.status_code} Duration: {process_time:.2f}ms" + ) + except Exception as e: + log.error(f"Request failed: {request.method} {request.url.path} Error: {str(e)}") raise e from None finally: await request.state.db.close() await request.state.aeros_db.close() - _request_id_ctx_var.reset(ctx_token) + reset_request_id(ctx_token) return response @app.middleware("http") async def add_security_headers(request: Request, call_next): response = await call_next(request) - response.headers["Strict-Transport-Security"] = ( - "max-age=31536000 ; includeSubDomains" - ) + response.headers["Strict-Transport-Security"] = "max-age=31536000 ; includeSubDomains" return response app.mount("/model", StaticFiles(directory="model"), name="model") -# class MetricsMiddleware(BaseHTTPMiddleware): -# async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: -# method = request.method -# endpoint = request.url.path -# tags = {"method": method, "endpoint": endpoint} - -# try: -# start = time.perf_counter() -# response = await call_next(request) -# elapsed_time = time.perf_counter() - start -# tags.update({"status_code": response.status_code}) -# metric_provider.counter("server.call.counter", tags=tags) -# metric_provider.timer("server.call.elapsed", value=elapsed_time, tags=tags) -# log.debug(f"server.call.elapsed.{endpoint}: {elapsed_time}") -# except Exception as e: -# metric_provider.counter("server.call.exception.counter", tags=tags) -# raise e from None -# return response - - -# app.add_middleware(ExceptionMiddleware) - @app.get("/images/{image_path:path}") async def get_image(image_path: str = Path(...)): # Extract filename from the full path