You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

178 lines
6.2 KiB
Python

import time
import logging
from os import path
from uuid import uuid1
from typing import Optional, Final
from fastapi import FastAPI, HTTPException, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import ValidationError
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from sqlalchemy import inspect
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import scoped_session
from sqlalchemy.ext.asyncio import async_scoped_session
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.routing import compile_path
from starlette.middleware.gzip import GZipMiddleware
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import Response, StreamingResponse, FileResponse
from starlette.staticfiles import StaticFiles
import logging
from src.enums import ResponseStatus
from src.logging import configure_logging
from src.rate_limiter import limiter
from src.api import api_router
from src.database.core import engine, async_session, collector_async_session
from src.exceptions import handle_exception
from src.middleware import RequestValidationMiddleware
log = logging.getLogger(__name__)

# Apply the project's logging level and format before the app is built.
configure_logging()

# Build the ASGI application object.
# NOTE(review): an empty ``openapi_url`` presumably turns off the generated
# OpenAPI schema route -- confirm against the FastAPI documentation.
app = FastAPI(
    title="LCCA API",
    description="Welcome to LCCA's API documentation!",
    version="0.1.0",
    openapi_url="",
)

# Expose the shared rate limiter on the application state.
app.state.limiter = limiter

# Send every handled exception family to the single project-wide handler.
# The registration order matches the original explicit calls.
for _handled_exception in (
    Exception,
    HTTPException,
    RequestValidationError,
    RateLimitExceeded,
    SQLAlchemyError,
):
    app.add_exception_handler(_handled_exception, handle_exception)
del _handled_exception

from src.context import set_request_id, reset_request_id, get_request_id

app.add_middleware(RequestValidationMiddleware)
def _resolve_request_user(request: Request) -> tuple:
    """Return ``(user_id, username, role)`` for the request being processed.

    Values are read from the request context first. Any value that is still
    empty is filled from ``request.state.user`` (when present) and written
    back to the context so later readers see the same value.
    """
    from src.context import get_username, get_role, get_user_id, set_user_id, set_username, set_role

    username = get_username()
    role = get_role()
    user_id = get_user_id()

    user_obj = getattr(request.state, "user", None)
    if user_obj:
        # The original code notes this object is a ``UserBase`` in this
        # project; attributes are read defensively in case one is missing.
        u_id = getattr(user_obj, "user_id", None)
        u_name = getattr(user_obj, "name", None) or getattr(user_obj, "username", None)
        u_role = getattr(user_obj, "role", None)
        if not user_id and u_id:
            user_id = str(u_id)
            set_user_id(user_id)
        if not username and u_name:
            username = u_name
            set_username(username)
        if not role and u_role:
            role = u_role
            set_role(role)

    return user_id, username, role


@app.middleware("http")
async def db_session_middleware(request: Request, call_next):
    """Attach per-request database sessions and log the request outcome.

    For every HTTP request this middleware:

    * sets a fresh request id in the context and uses it as the scope key
      for two ``async_scoped_session`` registries;
    * stores one session from each registry on ``request.state.db`` and
      ``request.state.collector_db``;
    * logs method, path, status code, duration and user details on success;
    * on failure, logs the error with an ``error_id`` (reusing
      ``request.state.error_id`` when already set) and re-raises;
    * always closes both sessions and resets the request id.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request down the stack and
            returns the response.

    Returns:
        The response produced by ``call_next``.

    Raises:
        Exception: Any exception raised while handling the request is
            logged and re-raised unchanged in type.
    """
    request_id = str(uuid1())
    # we create a per-request id such that we can ensure that our session is scoped for a particular request.
    # see: https://github.com/tiangolo/fastapi/issues/726
    ctx_token = set_request_id(request_id)
    try:
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by system clock adjustments.
        start_time = time.perf_counter()
        session = async_scoped_session(async_session, scopefunc=get_request_id)
        request.state.db = session()
        collector_session = async_scoped_session(collector_async_session, scopefunc=get_request_id)
        request.state.collector_db = collector_session()

        response = await call_next(request)
        process_time = (time.perf_counter() - start_time) * 1000

        # Pull from context or fallback to request.state.user
        user_id, username, role = _resolve_request_user(request)

        user_info_str = ""
        if username:
            user_info_str = f" | User: {username}"
            if role:
                user_info_str += f" ({role})"

        log.info(
            f"HTTP {request.method} {request.url.path} completed in {round(process_time, 2)}ms{user_info_str}",
            extra={
                "method": request.method,
                "path": request.url.path,
                "status_code": response.status_code,
                "duration_ms": round(process_time, 2),
                "user_id": user_id,
                "role": role,
            },
        )
    except Exception as e:
        # Generate an error_id here if it hasn't been generated yet
        error_id = getattr(request.state, "error_id", None)
        if not error_id:
            error_id = str(uuid1())
            request.state.error_id = error_id

        log.error(
            f"Request failed | Error ID: {error_id}",
            extra={
                "method": request.method,
                "path": request.url.path,
                "error": str(e),
                "error_id": error_id,
            },
            exc_info=True,
        )
        raise e from None
    finally:
        # Each cleanup step runs even if an earlier one fails. A session that
        # was never created (setup failed part-way) is skipped, so the
        # cleanup cannot mask the original error with an AttributeError.
        try:
            db = getattr(request.state, "db", None)
            if db is not None:
                await db.close()
        finally:
            try:
                collector_db = getattr(request.state, "collector_db", None)
                if collector_db is not None:
                    await collector_db.close()
            finally:
                reset_request_id(ctx_token)

    return response
@app.middleware("http")
async def add_security_headers(request: Request, call_next):
    """Set the ``Strict-Transport-Security`` header on every response.

    Args:
        request: The incoming request, forwarded unchanged.
        call_next: Callable that produces the downstream response.

    Returns:
        The downstream response with the header added.
    """
    hsts_policy = "max-age=31536000 ; includeSubDomains"
    downstream_response = await call_next(request)
    downstream_response.headers["Strict-Transport-Security"] = hsts_policy
    return downstream_response
# class MetricsMiddleware(BaseHTTPMiddleware):
# async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
# method = request.method
# endpoint = request.url.path
# tags = {"method": method, "endpoint": endpoint}
# try:
# start = time.perf_counter()
# response = await call_next(request)
# elapsed_time = time.perf_counter() - start
# tags.update({"status_code": response.status_code})
# metric_provider.counter("server.call.counter", tags=tags)
# metric_provider.timer("server.call.elapsed", value=elapsed_time, tags=tags)
# log.debug(f"server.call.elapsed.{endpoint}: {elapsed_time}")
# except Exception as e:
# metric_provider.counter("server.call.exception.counter", tags=tags)
# raise e from None
# return response
# app.add_middleware(ExceptionMiddleware)
# Register the routes of the project's ``api_router`` on the application.
app.include_router(api_router)