import logging
import time
from contextlib import AsyncExitStack
from contextvars import ContextVar
from os import path
from typing import Final, Optional
from uuid import uuid1

from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import ValidationError
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from sqlalchemy import inspect
from sqlalchemy.ext.asyncio import async_scoped_session
from sqlalchemy.orm import scoped_session
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.middleware.gzip import GZipMiddleware
from starlette.requests import Request
from starlette.responses import FileResponse, Response, StreamingResponse
from starlette.routing import compile_path
from starlette.staticfiles import StaticFiles

from src.api import api_router
from src.database.core import async_session, engine, async_collector_session
from src.enums import ResponseStatus
from src.exceptions import handle_exception
from src.logging import configure_logging
from src.rate_limiter import limiter

log = logging.getLogger(__name__)

# Configure the logging level and format before the app is built.
configure_logging()

# Every otherwise-unhandled exception is routed through the project-wide handler.
exception_handlers = {Exception: handle_exception}

# Create the ASGI application.
app = FastAPI(
    exception_handlers=exception_handlers,
    openapi_url="",
    title="LCCA API",
    description="Welcome to LCCA's API documentation!",
    version="0.1.0",
)

app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Compress responses; minimum_size is the byte threshold below which
# responses are sent uncompressed.
app.add_middleware(GZipMiddleware, minimum_size=2000)

# NOTE(review): stray note 'credentials: "include"'. CORSMiddleware is imported
# but never added in this module; confirm whether credentialed CORS was intended.

# NOTE: Starlette wraps middleware so that the most recently added one is the
# outermost. Registration order below therefore matters: `add_security_headers`
# (added last) runs its header assignments after every inner middleware.

REQUEST_ID_CTX_KEY: Final[str] = "request_id"
_request_id_ctx_var: ContextVar[Optional[str]] = ContextVar(
    REQUEST_ID_CTX_KEY, default=None
)


def get_request_id() -> Optional[str]:
    """Return the id of the request being handled in the current context.

    Used as the ``scopefunc`` of the per-request scoped sessions. Returns
    ``None`` outside of ``db_session_middleware``.
    """
    return _request_id_ctx_var.get()


def security_headers_middleware(app: FastAPI, is_production: bool = False) -> None:
    """Register a middleware that attaches security headers to every response.

    Args:
        app: The application to register the middleware on.
        is_production: When True, the strict header set (HSTS, X-Frame-Options,
            X-Content-Type-Options, X-XSS-Protection, Referrer-Policy,
            Content-Security-Policy, Permissions-Policy) is applied. When False
            (the default, matching the previous hard-coded value), only a
            permissive Content-Security-Policy is set, for local development.
    """
    # Content-Security-Policy directives (production only).
    # TODO: the your-domain.com hosts in connect-src are placeholders.
    csp_policy = {
        "default-src": "'self'",
        "script-src": "'self' 'unsafe-inline' https://cdnjs.cloudflare.com https://cdn.jsdelivr.net",
        "style-src": "'self' 'unsafe-inline' https://fonts.googleapis.com https://cdn.jsdelivr.net",
        "img-src": "'self' data: https: blob:",
        "font-src": "'self' https://fonts.gstatic.com data:",
        "connect-src": "'self' https://api.your-domain.com wss://ws.your-domain.com",
        "frame-src": "'none'",
        "object-src": "'none'",
        "base-uri": "'self'",
        "form-action": "'self'",
    }

    # Permissions-Policy directives (production only). This header uses
    # structured-field syntax: `feature=()` disables a feature for all origins,
    # `feature=(self)` allows it for the same origin, and entries are
    # comma-separated. The legacy Feature-Policy form ('none' / 'self',
    # ';'-separated) that was used before is not valid for this header.
    # NOTE(review): some names (e.g. notifications, push, speaker, vibrate) may
    # not be recognised by browsers and would then be ignored; confirm.
    permissions_policy = {
        "geolocation": "()",
        "midi": "()",
        "notifications": "()",
        "push": "()",
        "sync-xhr": "()",
        "microphone": "()",
        "camera": "()",
        "magnetometer": "()",
        "gyroscope": "()",
        "speaker": "()",
        "vibrate": "()",
        "fullscreen": "(self)",
        "payment": "()",
    }

    # Header values are computed once at registration time, not per request.
    csp_header_value = "; ".join(f"{k} {v}" for k, v in csp_policy.items())
    permissions_header_value = ", ".join(
        f"{k}={v}" for k, v in permissions_policy.items()
    )

    class SecurityHeadersMiddleware(BaseHTTPMiddleware):
        """Set security headers on each outgoing response."""

        async def dispatch(
            self, request: Request, call_next: RequestResponseEndpoint
        ) -> Response:
            response: Response = await call_next(request)
            if is_production:
                # NOTE: `add_security_headers` is registered after this
                # middleware and overwrites Strict-Transport-Security.
                response.headers["Strict-Transport-Security"] = (
                    "max-age=15724800; includeSubDomains; preload"
                )
                response.headers["X-Frame-Options"] = "DENY"
                response.headers["X-Content-Type-Options"] = "nosniff"
                # NOTE(review): X-XSS-Protection is deprecated in modern
                # browsers; consider "0" or dropping it.
                response.headers["X-XSS-Protection"] = "1; mode=block"
                response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
                response.headers["Content-Security-Policy"] = csp_header_value
                response.headers["Permissions-Policy"] = permissions_header_value
            else:
                # Relaxed settings for local development; the other headers
                # are intentionally skipped.
                response.headers["Content-Security-Policy"] = (
                    "default-src 'self' 'unsafe-inline' 'unsafe-eval' *"
                )
            return response

    app.add_middleware(SecurityHeadersMiddleware)


security_headers_middleware(app)


@app.middleware("http")
async def db_session_middleware(
    request: Request, call_next: RequestResponseEndpoint
) -> Response:
    """Attach request-scoped database sessions to ``request.state``.

    Sets ``request.state.db`` (from ``async_session``) and
    ``request.state.collector_db`` (from ``async_collector_session``). Both are
    scoped to a fresh per-request id. Every acquired resource is released even
    if session creation, the handler, or a ``close()`` raises:
    the sessions are closed and the request-id context var is reset.
    """
    # A per-request id is used as the session scope so that each request gets
    # its own sessions. See: https://github.com/tiangolo/fastapi/issues/726
    request_id = str(uuid1())

    # The exit stack runs callbacks in reverse registration order and keeps
    # running them even if one raises, so no cleanup step can be skipped and
    # a failure in session creation is never masked by the cleanup itself.
    async with AsyncExitStack() as cleanup:
        ctx_token = _request_id_ctx_var.set(request_id)
        cleanup.callback(_request_id_ctx_var.reset, ctx_token)

        session = async_scoped_session(async_session, scopefunc=get_request_id)
        db = session()
        request.state.db = db
        cleanup.push_async_callback(db.close)

        collector_session = async_scoped_session(
            async_collector_session, scopefunc=get_request_id
        )
        collector_db = collector_session()
        request.state.collector_db = collector_db
        cleanup.push_async_callback(collector_db.close)

        # Exceptions propagate unchanged (cause/context preserved) after cleanup.
        return await call_next(request)


@app.middleware("http")
async def add_security_headers(
    request: Request, call_next: RequestResponseEndpoint
) -> Response:
    """Set Strict-Transport-Security on every response.

    Registered last, so it is the outermost of the middlewares in this module.
    Its value therefore replaces any HSTS header set by
    ``SecurityHeadersMiddleware``.
    """
    response = await call_next(request)
    response.headers["Strict-Transport-Security"] = (
        "max-age=31536000 ; includeSubDomains"
    )
    return response


# NOTE: disabled metrics middleware; it depends on a `metric_provider` that is
# not defined or imported in this module.
# class MetricsMiddleware(BaseHTTPMiddleware):
#     async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
#         method = request.method
#         endpoint = request.url.path
#         tags = {"method": method, "endpoint": endpoint}
#         try:
#             start = time.perf_counter()
#             response = await call_next(request)
#             elapsed_time = time.perf_counter() - start
#             tags.update({"status_code": response.status_code})
#             metric_provider.counter("server.call.counter", tags=tags)
#             metric_provider.timer("server.call.elapsed", value=elapsed_time, tags=tags)
#             log.debug(f"server.call.elapsed.{endpoint}: {elapsed_time}")
#         except Exception as e:
#             metric_provider.counter("server.call.exception.counter", tags=tags)
#             raise e from None
#         return response
# app.add_middleware(ExceptionMiddleware)

app.include_router(api_router)