import time
import logging
from os import path
from uuid import uuid1
from typing import Optional, Final

from fastapi import FastAPI, HTTPException, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import ValidationError
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from sqlalchemy import inspect
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import scoped_session
from sqlalchemy.ext.asyncio import async_scoped_session
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.routing import compile_path
from starlette.middleware.gzip import GZipMiddleware
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import Response, StreamingResponse, FileResponse
from starlette.staticfiles import StaticFiles

from src.enums import ResponseStatus
from src.logging import configure_logging
from src.rate_limiter import limiter
from src.api import api_router
from src.database.core import engine, async_session, collector_async_session
from src.exceptions import handle_exception
from src.middleware import RequestValidationMiddleware

log = logging.getLogger(__name__)

# Apply the project's logging level/format configuration.
configure_logging()

# The ASGI application object. NOTE(review): an empty ``openapi_url`` is falsy,
# which presumably disables the OpenAPI schema (and the docs built on it) —
# confirm this is intended for all environments.
app = FastAPI(
    openapi_url="",
    title="LCCA API",
    description="Welcome to LCCA's API documentation!",
    version="0.1.0",
)
# Expose the shared rate limiter on the app state.
app.state.limiter = limiter

# Route each of these exception types through the single project-wide handler.
# NOTE(review): ``fastapi.HTTPException`` subclasses Starlette's HTTPException;
# HTTP errors raised by Starlette itself (e.g. routing 404/405) are Starlette
# instances and will not reach this handler — confirm that is intended.
for _exc_class in (
    Exception,
    HTTPException,
    RequestValidationError,
    RateLimitExceeded,
    SQLAlchemyError,
):
    app.add_exception_handler(_exc_class, handle_exception)
del _exc_class

# Request-id context helpers used by the per-request DB session middleware.
# Kept after app construction to preserve the original import order.
from src.context import set_request_id, reset_request_id, get_request_id
app.add_middleware(RequestValidationMiddleware)


def _resolve_user_context(request: Request):
    """Return ``(username, role, user_id)`` describing the caller of *request*.

    Values are read from the request context first. Any value that is missing
    (falsy) there is filled in from ``request.state.user`` when such an object
    was attached, and the filled-in value is written back to the context via
    the matching setter.
    """
    # Imported lazily, as the original middleware body did.
    from src.context import get_username, get_role, get_user_id, set_user_id, set_username, set_role

    username = get_username()
    role = get_role()
    user_id = get_user_id()

    user_obj = getattr(request.state, "user", None)
    if user_obj:
        # Presumably a UserBase instance in this project; it may expose either
        # ``name`` or ``username``.
        u_id = getattr(user_obj, "user_id", None)
        u_name = getattr(user_obj, "name", None) or getattr(user_obj, "username", None)
        u_role = getattr(user_obj, "role", None)
        if not user_id and u_id:
            user_id = str(u_id)
            set_user_id(user_id)
        if not username and u_name:
            username = u_name
            set_username(username)
        if not role and u_role:
            role = u_role
            set_role(role)

    return username, role, user_id


@app.middleware("http")
async def db_session_middleware(request: Request, call_next):
    """Scope DB sessions to the request, log its outcome, and always clean up.

    For each request this:

    * binds a fresh uuid1 request id into the context; ``get_request_id`` is
      the ``scopefunc`` of both ``async_scoped_session`` registries, so the
      sessions are scoped to this request
      (see: https://github.com/tiangolo/fastapi/issues/726);
    * exposes the sessions as ``request.state.db`` and
      ``request.state.collector_db``;
    * on success, logs method, path, status code, duration and user info;
    * on failure, logs with an ``error_id`` (reusing ``request.state.error_id``
      when already set) and re-raises the original exception;
    * closes whichever sessions were actually opened and resets the request-id
      context, even if session creation or one of the closes fails.
    """
    request_id = str(uuid1())
    ctx_token = set_request_id(request_id)
    try:
        # perf_counter is monotonic, so the duration is immune to clock changes.
        start_time = time.perf_counter()
        session = async_scoped_session(async_session, scopefunc=get_request_id)
        request.state.db = session()
        collector_session = async_scoped_session(collector_async_session, scopefunc=get_request_id)
        request.state.collector_db = collector_session()

        response = await call_next(request)
        process_time = (time.perf_counter() - start_time) * 1000

        username, role, user_id = _resolve_user_context(request)

        user_info_str = ""
        if username:
            user_info_str = f" | User: {username}"
            if role:
                user_info_str += f" ({role})"

        log.info(
            f"HTTP {request.method} {request.url.path} completed in {round(process_time, 2)}ms{user_info_str}",
            extra={
                "method": request.method,
                "path": request.url.path,
                "status_code": response.status_code,
                "duration_ms": round(process_time, 2),
                "user_id": user_id,
                "role": role,
            },
        )
    except Exception as e:
        # Generate an error_id here if an inner layer has not already done so.
        error_id = getattr(request.state, "error_id", None)
        if not error_id:
            error_id = str(uuid1())
            request.state.error_id = error_id
        log.error(
            f"Request failed | Error ID: {error_id}",
            extra={
                "method": request.method,
                "path": request.url.path,
                "error": str(e),
                "error_id": error_id,
            },
            exc_info=True,
        )
        # Bare raise keeps the original exception and its context intact.
        raise
    finally:
        # Sessions may be missing if creation failed part-way. The nested
        # try/finally chain ensures a failing close() neither skips the other
        # close nor leaks the request-id context.
        try:
            db = getattr(request.state, "db", None)
            if db is not None:
                await db.close()
        finally:
            try:
                collector_db = getattr(request.state, "collector_db", None)
                if collector_db is not None:
                    await collector_db.close()
            finally:
                reset_request_id(ctx_token)
    return response


@app.middleware("http")
async def add_security_headers(request: Request, call_next):
    """Attach an HSTS header (max-age of one year, including subdomains) to every response."""
    response = await call_next(request)
    response.headers["Strict-Transport-Security"] = "max-age=31536000 ; includeSubDomains"
    return response


# TODO: a metrics middleware and an ExceptionMiddleware registration used to be
# sketched here as commented-out code; recover them from version control if needed.
app.include_router(api_router)