You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

123 lines
4.5 KiB
Python

import time
import logging
from os import path
from uuid import uuid1
from typing import Optional, Final
from fastapi import FastAPI, HTTPException, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import ValidationError
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from sqlalchemy import inspect
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import scoped_session
from sqlalchemy.ext.asyncio import async_scoped_session
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.routing import compile_path
from starlette.middleware.gzip import GZipMiddleware
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import Response, StreamingResponse, FileResponse
from starlette.staticfiles import StaticFiles
import logging
from src.enums import ResponseStatus
from src.logging import configure_logging
from src.rate_limiter import limiter
from src.api import api_router
from src.database.core import engine, async_session, collector_async_session
from src.exceptions import handle_exception
from src.middleware import RequestValidationMiddleware
# FastAPI's HTTPException subclasses Starlette's; registering a handler for the
# Starlette base class also covers HTTP errors Starlette raises itself (e.g. 404
# for unknown routes). This name was used below but never imported.
from starlette.exceptions import HTTPException as StarletteHTTPException

from src.context import set_request_id, reset_request_id, get_request_id

log = logging.getLogger(__name__)

# we configure the logging level and format
configure_logging()

# we create the ASGI application for the API.
# An empty ``openapi_url`` is falsy, so FastAPI does not register the OpenAPI
# schema route (nor the docs UIs that depend on it).
app = FastAPI(
    openapi_url="",
    title="LCCA API",
    description="Welcome to LCCA's API documentation!",
    version="0.1.0",
)
# slowapi looks the limiter up on ``app.state.limiter``.
app.state.limiter = limiter

# we define the exception handlers; everything except rate-limit errors is
# funnelled through the project's ``handle_exception``.
app.add_exception_handler(Exception, handle_exception)
app.add_exception_handler(HTTPException, handle_exception)
app.add_exception_handler(StarletteHTTPException, handle_exception)
app.add_exception_handler(RequestValidationError, handle_exception)
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_exception_handler(SQLAlchemyError, handle_exception)

app.add_middleware(RequestValidationMiddleware)
@app.middleware("http")
async def db_session_middleware(request: Request, call_next):
    """Attach request-scoped database sessions and log the request lifecycle.

    A fresh ``uuid1`` request id is published through ``set_request_id`` and
    used as the ``scopefunc`` of two ``async_scoped_session`` registries, so
    each request gets its own sessions (see
    https://github.com/tiangolo/fastapi/issues/726). Handlers reach them as
    ``request.state.db`` and ``request.state.collector_db``.

    Cleanup always runs, in reverse order of setup: the collector session is
    removed, then the main session, then the request id is reset. Each step
    runs even if an earlier step fails or if setup failed part-way.

    Args:
        request: The incoming request.
        call_next: Forwards the request to the next middleware / route.

    Returns:
        The downstream response.

    Raises:
        Any exception from session setup or downstream handling. It is logged
        with its traceback and then re-raised unchanged.
    """
    # Local import keeps this block self-contained; after the first request it
    # is just a cached module lookup.
    from contextlib import AsyncExitStack

    request_id = str(uuid1())
    async with AsyncExitStack() as cleanup:
        # Registered first so it runs last: the scoped registries below need
        # the current request id to find the sessions they must close.
        cleanup.callback(reset_request_id, set_request_id(request_id))
        try:
            log.info("Incoming request: %s %s", request.method, request.url.path)

            db_scope = async_scoped_session(async_session, scopefunc=get_request_id)
            # remove() closes the session only if one was created, then clears
            # the registry, so registering it before creation is safe.
            cleanup.push_async_callback(db_scope.remove)
            request.state.db = db_scope()

            collector_scope = async_scoped_session(
                collector_async_session, scopefunc=get_request_id
            )
            cleanup.push_async_callback(collector_scope.remove)
            request.state.collector_db = collector_scope()

            # NOTE(review): with @app.middleware("http"), call_next returns
            # before a streaming body is sent, so a streaming handler that
            # lazily reads from these sessions would find them closed. Confirm
            # no route does that.
            response = await call_next(request)
            if response.status_code >= 400:
                log.error("Request completed: %s", response.status_code)
            else:
                log.info("Request completed: %s", response.status_code)
        except Exception as e:
            log.exception("Request failed: %s - %s", type(e).__name__, e)
            # Bare raise keeps the original traceback and exception context.
            raise
    return response
@app.middleware("http")
async def add_security_headers(request: Request, call_next):
    """Stamp an HSTS header onto every response.

    The downstream response is produced first, and then its
    ``Strict-Transport-Security`` header is set. Any value set further down the
    stack is overwritten.
    """
    hsts_policy = "max-age=31536000 ; includeSubDomains"
    outgoing = await call_next(request)
    outgoing.headers["Strict-Transport-Security"] = hsts_policy
    return outgoing
# NOTE: this metrics middleware is disabled. Re-enabling it requires importing
# `metric_provider` and `ExceptionMiddleware`, neither of which is imported here.
# class MetricsMiddleware(BaseHTTPMiddleware):
#     async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
#         method = request.method
#         endpoint = request.url.path
#         tags = {"method": method, "endpoint": endpoint}
#         try:
#             start = time.perf_counter()
#             response = await call_next(request)
#             elapsed_time = time.perf_counter() - start
#             tags.update({"status_code": response.status_code})
#             metric_provider.counter("server.call.counter", tags=tags)
#             metric_provider.timer("server.call.elapsed", value=elapsed_time, tags=tags)
#             log.debug(f"server.call.elapsed.{endpoint}: {elapsed_time}")
#         except Exception as e:
#             metric_provider.counter("server.call.exception.counter", tags=tags)
#             raise e from None
#         return response
# app.add_middleware(ExceptionMiddleware)
# Mount every API route collected in ``api_router`` onto the application.
app.include_router(router=api_router)