feat: Implement JSON logging with request context and centralize logging configuration.

main
Cizz22 4 weeks ago
parent d694dafa8f
commit abb7e8d27b

@ -1,6 +1,8 @@
from collections import defaultdict from collections import defaultdict
import json import json
import logging import logging
log = logging.getLogger(__name__)
import os import os
import tempfile import tempfile
from datetime import datetime from datetime import datetime
@ -17,10 +19,6 @@ from src.aeros_simulation.service import get_all_aeros_node, get_or_save_node, g
from src.aeros_simulation.utils import calculate_eaf, calculate_eaf_konkin from src.aeros_simulation.utils import calculate_eaf, calculate_eaf_konkin
from src.config import AEROS_BASE_URL from src.config import AEROS_BASE_URL
from src.database.core import DbSession from src.database.core import DbSession
from src.logging import setup_logging
log = logging.getLogger(__name__)
setup_logging(logger=log)
async def execute_simulation( async def execute_simulation(

@ -2,12 +2,7 @@ from datetime import datetime
import json import json
import logging import logging
from src.logging import setup_logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
setup_logging(log)
def date_to_utc(date_val): def date_to_utc(date_val):
return datetime.combine( return datetime.combine(

@ -0,0 +1,18 @@
from contextvars import ContextVar, Token
from typing import Final, Optional

# Name under which the per-request correlation id is stored.
REQUEST_ID_CTX_KEY: Final[str] = "request_id"

# Context-local storage: each request / asyncio task sees its own value,
# which lets the JSON log formatter attach the current request id.
_request_id_ctx_var: ContextVar[Optional[str]] = ContextVar(
    REQUEST_ID_CTX_KEY, default=None)


def get_request_id() -> Optional[str]:
    """Return the request id bound to the current context, or None."""
    return _request_id_ctx_var.get()


def set_request_id(request_id: str) -> Token:
    """Bind *request_id* to the current context.

    Returns the Token that reset_request_id() needs to restore the
    previous value — callers must keep it (see the db session middleware).
    """
    return _request_id_ctx_var.set(request_id)


def reset_request_id(token: Token) -> None:
    """Restore the request id that was in effect before set_request_id()."""
    _request_id_ctx_var.reset(token)

@ -1,16 +1,129 @@
import logging import logging
import json
import datetime
import os
import sys import sys
from fastapi import FastAPI from typing import Optional
from src.config import LOG_LEVEL from src.config import LOG_LEVEL
from src.enums import RBDEnum
# %-style layout used when plain-text debug output is wanted instead of JSON.
LOG_FORMAT_DEBUG = "%(levelname)s:%(message)s:%(pathname)s:%(funcName)s:%(lineno)d"

# ANSI Color Codes
# RESET returns the terminal to its default colors after a colored line.
RESET = "\033[0m"
# Per-level ANSI prefixes, applied only when stdout is an interactive terminal.
COLORS = {
    "DEBUG": "\033[36m",  # Cyan
    "INFO": "\033[32m",  # Green
    "WARNING": "\033[33m",  # Yellow
    "WARN": "\033[33m",  # Yellow
    "ERROR": "\033[31m",  # Red
    "CRITICAL": "\033[1;31m",  # Bold Red
}
class LogLevels(RBDEnum):
    """Log level names accepted via the LOG_LEVEL config setting.

    Values mirror the stdlib logging level names; "WARN" is the legacy
    alias logging also accepts for "WARNING".
    """
    info = "INFO"
    warn = "WARN"
    error = "ERROR"
    debug = "DEBUG"
class JSONFormatter(logging.Formatter):
    """Render each log record as a single JSON line.

    Includes the request correlation id (from src.context), source
    location, pid, and any attributes passed via ``extra=``. Output is
    ANSI-colored per level when stdout is an interactive terminal.
    """

    # LogRecord attributes present on every record; anything else found in
    # record.__dict__ was supplied via ``extra=`` and is emitted verbatim.
    # Built once at class level instead of per record. "taskName" is set by
    # asyncio-aware logging from Python 3.12 on — excluded so it does not
    # show up as a spurious extra.
    _STANDARD_ATTRS = frozenset({
        "args", "asctime", "created", "exc_info", "exc_text", "filename",
        "funcName", "levelname", "levelno", "lineno", "module", "msecs",
        "message", "msg", "name", "pathname", "process", "processName",
        "relativeCreated", "stack_info", "taskName", "thread", "threadName",
    })

    def format(self, record):
        """Return *record* serialized as a JSON string (possibly colored)."""
        # The context import is lazy (avoids an import cycle) and lives
        # inside the try: a formatter must never raise from the logging
        # path, so a missing/broken context module degrades to None
        # instead of killing every log call.
        request_id = None
        try:
            from src.context import get_request_id
            request_id = get_request_id()
        except Exception:
            pass

        log_record = {
            "timestamp": datetime.datetime.fromtimestamp(
                record.created).astimezone().isoformat(),
            "level": record.levelname,
            "message": record.getMessage(),
            "logger_name": record.name,
            "location": f"{record.module}:{record.funcName}:{record.lineno}",
            "module": record.module,
            "funcName": record.funcName,
            "lineno": record.lineno,
            "pid": os.getpid(),
            "request_id": request_id,
        }

        # Capture exception / stack info if available.
        if record.exc_info:
            log_record["exception"] = self.formatException(record.exc_info)
        if record.stack_info:
            log_record["stack_trace"] = self.formatStack(record.stack_info)

        # Merge ``extra=`` attributes, skipping the standard ones to avoid
        # duplication.
        for key, value in record.__dict__.items():
            if key not in self._STANDARD_ATTRS:
                log_record[key] = value

        # default=str keeps logging alive when an extra value is not
        # JSON-serializable: stringifying beats raising from a formatter.
        log_json = json.dumps(log_record, default=str)

        # Apply color only when the output is an interactive terminal.
        if sys.stdout.isatty():
            level_color = COLORS.get(record.levelname, "")
            return f"{level_color}{log_json}{RESET}"
        return log_json
def configure_logging():
    """Configure root logging: JSON lines on stdout at LOG_LEVEL.

    Replaces any pre-existing root handlers (avoids duplicate lines on
    reconfigure) and strips uvicorn/fastapi handlers so their records
    propagate to the root and come out as JSON too.
    """
    requested = str(LOG_LEVEL).upper()  # config value may not be a string
    # Validate against the enum *values* ("INFO", ...): comparing the raw
    # string against enum members — and handing a member to setLevel() —
    # only works if RBDEnum happens to be a str enum, so use .value
    # explicitly. logging accepts these names, including legacy "WARN".
    valid_levels = {level.value for level in LogLevels}
    log_level = requested if requested in valid_levels else LogLevels.error.value

    # The root logger owns the single JSON stream handler.
    root_logger = logging.getLogger()
    root_logger.setLevel(log_level)

    # Clear existing handlers to avoid duplicate logs.
    if root_logger.hasHandlers():
        root_logger.handlers.clear()

    handler = logging.StreamHandler(sys.stdout)
    # JSON everywhere; switch to LOG_FORMAT_DEBUG here if a plain-text
    # debug mode is ever wanted.
    handler.setFormatter(JSONFormatter())
    root_logger.addHandler(handler)

    # Route uvicorn/fastapi loggers through the root JSON handler.
    for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "fastapi"]:
        logger = logging.getLogger(logger_name)
        logger.handlers = []
        logger.propagate = True
def setup_logging(logger):
    """Attach a plain-text stdout handler to *logger* at DEBUG level.

    Legacy per-module setup kept for callers that have not migrated to
    configure_logging(). Now safe to call repeatedly: previously every
    call appended another StreamHandler, so each record was emitted once
    per call.
    """
    logger.setLevel(logging.DEBUG)

    # Guard against duplicate handlers on repeated setup calls.
    if any(isinstance(h, logging.StreamHandler) for h in logger.handlers):
        return

    # Create formatter
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    # Create console handler
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

@ -2,78 +2,63 @@ import logging
import os import os
import sys import sys
import time import time
from contextvars import ContextVar
from os import path from os import path
from typing import Final, Optional
from uuid import uuid1 from uuid import uuid1
from typing import Optional, Final
from fastapi import FastAPI, HTTPException, Path, status from fastapi import FastAPI, HTTPException, status, Path
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from pydantic import ValidationError from pydantic import ValidationError
from slowapi import _rate_limit_exceeded_handler from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded from slowapi.errors import RateLimitExceeded
from sqlalchemy import inspect from sqlalchemy import inspect
from sqlalchemy.ext.asyncio import async_scoped_session
from sqlalchemy.orm import scoped_session from sqlalchemy.orm import scoped_session
from sqlalchemy.ext.asyncio import async_scoped_session
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.middleware.gzip import GZipMiddleware
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import FileResponse, Response, StreamingResponse
from starlette.routing import compile_path from starlette.routing import compile_path
from starlette.middleware.gzip import GZipMiddleware
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import Response, StreamingResponse, FileResponse
from starlette.staticfiles import StaticFiles from starlette.staticfiles import StaticFiles
from src.api import api_router
from src.database.core import async_session, engine,async_aeros_session
from src.enums import ResponseStatus from src.enums import ResponseStatus
from src.logging import configure_logging
from src.rate_limiter import limiter
from src.api import api_router
from src.database.core import engine, async_session, async_aeros_session
from src.exceptions import handle_exception from src.exceptions import handle_exception
from src.logging import setup_logging
from src.middleware import RequestValidationMiddleware from src.middleware import RequestValidationMiddleware
from src.rate_limiter import limiter from src.context import set_request_id, reset_request_id, get_request_id
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
# we configure the logging level and format # we configure the logging level and format
configure_logging()
# we define the exception handlers # we define the exception handlers
exception_handlers = {Exception: handle_exception} exception_handlers = {Exception: handle_exception}
# we create the ASGI for the app # we create the ASGI for the app
app = FastAPI( app = FastAPI(exception_handlers=exception_handlers, openapi_url="", title="LCCA API",
exception_handlers=exception_handlers,
openapi_url="",
title="LCCA API",
description="Welcome to RBD's API documentation!", description="Welcome to RBD's API documentation!",
version="0.1.0", version="0.1.0")
)
app.state.limiter = limiter app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_middleware(GZipMiddleware, minimum_size=2000) app.add_middleware(GZipMiddleware, minimum_size=2000)
# credentials: "include",
setup_logging(logger=log)
log.info('API is starting up')
REQUEST_ID_CTX_KEY: Final[str] = "request_id"
_request_id_ctx_var: ContextVar[Optional[str]] = ContextVar(
REQUEST_ID_CTX_KEY, default=None
)
def get_request_id() -> Optional[str]:
return _request_id_ctx_var.get()
app.add_middleware(RequestValidationMiddleware) app.add_middleware(RequestValidationMiddleware)
@app.middleware("http") @app.middleware("http")
async def db_session_middleware(request: Request, call_next): async def db_session_middleware(request: Request, call_next):
request_id = str(uuid1()) request_id = str(uuid1())
# we create a per-request id such that we can ensure that our session is scoped for a particular request. # we create a per-request id such that we can ensure that our session is scoped for a particular request.
# see: https://github.com/tiangolo/fastapi/issues/726 # see: https://github.com/tiangolo/fastapi/issues/726
ctx_token = _request_id_ctx_var.set(request_id) ctx_token = set_request_id(request_id)
try: try:
session = async_scoped_session(async_session, scopefunc=get_request_id) session = async_scoped_session(async_session, scopefunc=get_request_id)
@ -82,50 +67,34 @@ async def db_session_middleware(request: Request, call_next):
collector_session = async_scoped_session(async_aeros_session, scopefunc=get_request_id) collector_session = async_scoped_session(async_aeros_session, scopefunc=get_request_id)
request.state.aeros_db = collector_session() request.state.aeros_db = collector_session()
start_time = time.time()
response = await call_next(request) response = await call_next(request)
process_time = (time.time() - start_time) * 1000
log.info(
f"Request: {request.method} {request.url.path} Status: {response.status_code} Duration: {process_time:.2f}ms"
)
except Exception as e: except Exception as e:
log.error(f"Request failed: {request.method} {request.url.path} Error: {str(e)}")
raise e from None raise e from None
finally: finally:
await request.state.db.close() await request.state.db.close()
await request.state.aeros_db.close() await request.state.aeros_db.close()
_request_id_ctx_var.reset(ctx_token) reset_request_id(ctx_token)
return response return response
@app.middleware("http") @app.middleware("http")
async def add_security_headers(request: Request, call_next): async def add_security_headers(request: Request, call_next):
response = await call_next(request) response = await call_next(request)
response.headers["Strict-Transport-Security"] = ( response.headers["Strict-Transport-Security"] = "max-age=31536000 ; includeSubDomains"
"max-age=31536000 ; includeSubDomains"
)
return response return response
app.mount("/model", StaticFiles(directory="model"), name="model") app.mount("/model", StaticFiles(directory="model"), name="model")
# class MetricsMiddleware(BaseHTTPMiddleware):
# async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
# method = request.method
# endpoint = request.url.path
# tags = {"method": method, "endpoint": endpoint}
# try:
# start = time.perf_counter()
# response = await call_next(request)
# elapsed_time = time.perf_counter() - start
# tags.update({"status_code": response.status_code})
# metric_provider.counter("server.call.counter", tags=tags)
# metric_provider.timer("server.call.elapsed", value=elapsed_time, tags=tags)
# log.debug(f"server.call.elapsed.{endpoint}: {elapsed_time}")
# except Exception as e:
# metric_provider.counter("server.call.exception.counter", tags=tags)
# raise e from None
# return response
# app.add_middleware(ExceptionMiddleware)
@app.get("/images/{image_path:path}") @app.get("/images/{image_path:path}")
async def get_image(image_path: str = Path(...)): async def get_image(image_path: str = Path(...)):
# Extract filename from the full path # Extract filename from the full path

Loading…
Cancel
Save