feat: Introduce request context management and refactor logging to JSON format with request IDs.

main
Cizz22 4 weeks ago
parent 17c0bce713
commit 24ecc2f5f4

@ -12,7 +12,6 @@ from sqlalchemy import and_, case, func, select, update
from sqlalchemy.orm import joinedload, selectinload
from src.database.core import DbSession
from src.logging import setup_logging
from src.overhaul_activity.service import get_all as get_all_by_session_id
from src.overhaul_scope.service import get as get_scope, get_prev_oh
from src.sparepart.service import get_spareparts_paginated, load_sparepart_data_from_db
@ -46,7 +45,6 @@ import json
client = httpx.AsyncClient(timeout=300.0)
log = logging.getLogger(__name__)
setup_logging(logger=log)
class OptimumCostModelWithSpareparts:
def __init__(self, token: str, last_oh_date: date, next_oh_date: date,

@ -0,0 +1,18 @@
"""Per-request correlation-id storage.

Backed by a ``contextvars.ContextVar`` so the id is isolated per
task/request and survives ``await`` boundaries.
"""
from contextvars import ContextVar
from typing import Final, Optional

REQUEST_ID_CTX_KEY: Final[str] = "request_id"

_request_id_ctx_var: ContextVar[Optional[str]] = ContextVar(
    REQUEST_ID_CTX_KEY, default=None)


def get_request_id() -> Optional[str]:
    """Return the request id bound to the current context, or ``None``."""
    return _request_id_ctx_var.get()


def set_request_id(request_id: str):
    """Bind *request_id* to the current context.

    Returns the ``Token`` needed by :func:`reset_request_id` to undo it.
    """
    return _request_id_ctx_var.set(request_id)


def reset_request_id(token):
    """Restore the context variable to its state before ``set_request_id``."""
    _request_id_ctx_var.reset(token)

@ -1,11 +1,27 @@
import logging
import json
import datetime
import os
import sys
from typing import Optional
from src.config import LOG_LEVEL
from src.enums import OptimumOHEnum
# Plain-text record format used when the configured level is DEBUG
# (see configure_logging); all other levels emit JSON.
LOG_FORMAT_DEBUG = "%(levelname)s:%(message)s:%(pathname)s:%(funcName)s:%(lineno)d"
# ANSI Color Codes
RESET = "\033[0m"
# Per-level ANSI prefixes applied by JSONFormatter when stdout is a terminal.
COLORS = {
    "DEBUG": "\033[36m",  # Cyan
    "INFO": "\033[32m",  # Green
    "WARNING": "\033[33m",  # Yellow
    "WARN": "\033[33m",  # Yellow
    "ERROR": "\033[31m",  # Red
    "CRITICAL": "\033[1;31m",  # Bold Red
}
class LogLevels(OptimumOHEnum):
info = "INFO"
@ -14,32 +30,98 @@ class LogLevels(OptimumOHEnum):
debug = "DEBUG"
class JSONFormatter(logging.Formatter):
    """
    Custom formatter to output logs in JSON format.

    Each record becomes a single-line JSON object with timestamp, level,
    message, source location, pid, the current request id (if any), plus
    any ``extra`` attributes passed to the log call. When stdout is a
    terminal the line is wrapped in a level-specific ANSI color.
    """

    # LogRecord attributes already mapped into the payload (or deliberately
    # omitted) -- anything else on the record was supplied via ``extra=``
    # and is forwarded verbatim. "taskName" (added to LogRecord in Python
    # 3.12) is listed so it does not leak into every record as a spurious
    # extra field. Built once at class level instead of per format() call.
    _STANDARD_ATTRS = frozenset({
        "args", "asctime", "created", "exc_info", "exc_text", "filename",
        "funcName", "levelname", "levelno", "lineno", "module", "msecs",
        "message", "msg", "name", "pathname", "process", "processName",
        "relativeCreated", "stack_info", "taskName", "thread", "threadName",
    })

    def format(self, record):
        """Return *record* serialized as a JSON string (possibly colored)."""
        # Import lazily AND inside the guard: logging must never crash
        # because the request-context module is unavailable (scripts,
        # tests, circular-import windows). Previously the import sat
        # outside the try/except and could raise from every log call.
        request_id = None
        try:
            from src.context import get_request_id
            request_id = get_request_id()
        except Exception:
            pass
        log_record = {
            "timestamp": datetime.datetime.fromtimestamp(record.created).astimezone().isoformat(),
            "level": record.levelname,
            "message": record.getMessage(),
            "logger_name": record.name,
            "location": f"{record.module}:{record.funcName}:{record.lineno}",
            "module": record.module,
            "funcName": record.funcName,
            "lineno": record.lineno,
            "pid": os.getpid(),
            "request_id": request_id,
        }
        # Capture exception info if available
        if record.exc_info:
            log_record["exception"] = self.formatException(record.exc_info)
        # Capture stack info if available
        if record.stack_info:
            log_record["stack_trace"] = self.formatStack(record.stack_info)
        # Add any extra attributes passed to the log call; standard
        # attributes are skipped to avoid duplication.
        for key, value in record.__dict__.items():
            if key not in self._STANDARD_ATTRS:
                log_record[key] = value
        # default=str keeps the formatter from raising when an ``extra``
        # value is not JSON-serializable -- a log call must not crash.
        log_json = json.dumps(log_record, default=str)
        # Apply color if the output is a terminal
        if sys.stdout.isatty():
            level_color = COLORS.get(record.levelname, "")
            return f"{level_color}{log_json}{RESET}"
        return log_json
def configure_logging():
    """Configure root logging for the service from the LOG_LEVEL setting.

    Behavior by configured level:
      * unrecognized -> fall back to ERROR via basicConfig
      * DEBUG        -> human-readable plain-text format via basicConfig
      * otherwise    -> single-line JSON on stdout via JSONFormatter
    Also silences the noisy slack_sdk web client and routes uvicorn/fastapi
    loggers through the root handler so they emit JSON too.

    NOTE(review): the membership/equality checks below assume LogLevels is
    a str-based enum (so "INFO" in list(LogLevels) works) -- confirm
    against OptimumOHEnum.
    """
    log_level = str(LOG_LEVEL).upper()  # cast to string
    if log_level not in list(LogLevels):
        # Unknown level configured -- default to ERROR.
        # (The original also had an unreachable assignment after this
        # return; removed.)
        logging.basicConfig(level=LogLevels.error)
        return
    if log_level == LogLevels.debug:
        # Keep the human-readable format for local debugging.
        logging.basicConfig(level=log_level, format=LOG_FORMAT_DEBUG)
        return
    # Get the root logger
    root_logger = logging.getLogger()
    root_logger.setLevel(log_level)
    # Clear existing handlers to avoid duplicate logs. (The previous
    # redundant basicConfig() call -- whose handler was cleared right
    # here anyway -- has been dropped.)
    if root_logger.hasHandlers():
        root_logger.handlers.clear()
    # sometimes the slack client can be too verbose
    logging.getLogger("slack_sdk.web.base_client").setLevel(logging.CRITICAL)
    # Create a stream handler that outputs JSON to stdout.
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(JSONFormatter())
    root_logger.addHandler(handler)
    # Reconfigure uvicorn/fastapi loggers to propagate into the root
    # handler instead of using their own formatters.
    for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "fastapi"]:
        logger = logging.getLogger(logger_name)
        logger.handlers = []
        logger.propagate = True
def setup_logging(logger):
    """Attach a plain-text stdout handler to *logger* at DEBUG level.

    Idempotent: this is invoked at import time from several modules
    (e.g. optimum-cost and overview services), and the original appended
    a new handler on every call, so each log line was emitted once per
    caller. A handler is now attached only if the logger has none.
    """
    logger.setLevel(logging.DEBUG)
    if not logger.handlers:
        # Create formatter and console handler only on first setup.
        formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        stream_handler = logging.StreamHandler(sys.stdout)
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)
    # sometimes the slack client can be too verbose
    logging.getLogger("slack_sdk.web.base_client").setLevel(logging.CRITICAL)

@ -1,5 +1,6 @@
import logging
import time
from src.context import set_request_id, reset_request_id, get_request_id
from contextvars import ContextVar
from os import path
from typing import Final, Optional
@ -51,14 +52,8 @@ app.add_middleware(GZipMiddleware, minimum_size=1000)
# credentials: "include",
REQUEST_ID_CTX_KEY: Final[str] = "request_id"
_request_id_ctx_var: ContextVar[Optional[str]] = ContextVar(
REQUEST_ID_CTX_KEY, default=None
)
def get_request_id() -> Optional[str]:
return _request_id_ctx_var.get()
def security_headers_middleware(app: FastAPI):
is_production = False
@ -125,13 +120,15 @@ security_headers_middleware(app)
app.add_middleware(RequestValidationMiddleware)
@app.middleware("http")
async def db_session_middleware(request: Request, call_next):
request_id = str(uuid1())
# we create a per-request id such that we can ensure that our session is scoped for a particular request.
# see: https://github.com/tiangolo/fastapi/issues/726
ctx_token = _request_id_ctx_var.set(request_id)
ctx_token = set_request_id(request_id)
try:
session = async_scoped_session(async_session, scopefunc=get_request_id)
@ -140,14 +137,21 @@ async def db_session_middleware(request: Request, call_next):
collector_session = async_scoped_session(async_collector_session, scopefunc=get_request_id)
request.state.collector_db = collector_session()
start_time = time.time()
response = await call_next(request)
process_time = (time.time() - start_time) * 1000
log.info(
f"Request: {request.method} {request.url.path} Status: {response.status_code} Duration: {process_time:.2f}ms"
)
except Exception as e:
log.error(f"Request failed: {request.method} {request.url.path} Error: {str(e)}")
raise e from None
finally:
await request.state.db.close()
await request.state.collector_db.close()
_request_id_ctx_var.reset(ctx_token)
reset_request_id(ctx_token)
return response

@ -15,7 +15,6 @@ from sqlalchemy.orm import joinedload, selectinload
from src.auth.service import CurrentUser
from src.database.core import DbSession
from src.database.service import CommonParameters, search_filter_sort_paginate
from src.logging import setup_logging
from src.overhaul_activity.service import get_standard_scope_by_session_id
from src.overhaul_scope.service import get as get_scope, get_overview_overhaul
from src.overhaul_scope.service import get_prev_oh
@ -24,7 +23,6 @@ from src.sparepart.schema import ProcurementRecord, ProcurementStatus, Sparepart
log = logging.getLogger(__name__)
setup_logging(logger=log)
from sqlalchemy import text
import math

Loading…
Cancel
Save