You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

168 lines
5.8 KiB
Python

import logging
import os
import sys
import time
from os import path
from uuid import uuid1
from typing import Optional, Final
from fastapi import FastAPI, HTTPException, status, Path
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import ValidationError
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from sqlalchemy import inspect
from sqlalchemy.orm import scoped_session
from sqlalchemy.ext.asyncio import async_scoped_session
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.routing import compile_path
from starlette.middleware.gzip import GZipMiddleware
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import Response, StreamingResponse, FileResponse
from starlette.staticfiles import StaticFiles
from src.enums import ResponseStatus
from src.logging import configure_logging
from src.rate_limiter import limiter
from src.api import api_router
from src.database.core import engine, async_session, async_aeros_session
from src.exceptions import handle_exception
from src.middleware import RequestValidationMiddleware
from src.context import set_request_id, reset_request_id, get_request_id
from sqlalchemy.exc import SQLAlchemyError
from starlette.exceptions import HTTPException as StarletteHTTPException

log = logging.getLogger(__name__)

# Set up the logging level and format for the whole application.
configure_logging()

# The ASGI application object. An empty openapi_url disables the generated
# OpenAPI schema endpoint.
app = FastAPI(
    openapi_url="",
    title="LCCA API",
    description="Welcome to RBD's API documentation!",
    version="0.1.0",
)

# Every exception type below is routed to the same shared handler; they are
# registered in this exact order.
_HANDLED_EXCEPTION_TYPES = (
    Exception,
    HTTPException,
    StarletteHTTPException,
    RequestValidationError,
    RateLimitExceeded,
    SQLAlchemyError,
)
for _exc_type in _HANDLED_EXCEPTION_TYPES:
    app.add_exception_handler(_exc_type, handle_exception)
del _exc_type

# Rate limiter instance made available on the application state.
app.state.limiter = limiter

# Compress responses of at least 2000 bytes, then add request validation.
app.add_middleware(GZipMiddleware, minimum_size=2000)
app.add_middleware(RequestValidationMiddleware)
def _resolve_user_context(request: Request):
    """Return ``(username, role, user_id)`` for request logging.

    Values already stored in ``src.context`` take precedence. Any value that is
    missing there is read from ``request.state.user`` (when that attribute is
    present and truthy) and written back into the context.
    """
    # Kept as a function-level import, as in the original code; presumably to
    # avoid an import cycle with src.context — TODO confirm.
    from src.context import get_username, get_role, get_user_id, set_user_id, set_username, set_role

    username = get_username()
    role = get_role()
    user_id = get_user_id()

    user_obj = getattr(request.state, "user", None)
    if user_obj:
        if not user_id and hasattr(user_obj, "user_id"):
            user_id = str(user_obj.user_id)
            set_user_id(user_id)
        if not username and hasattr(user_obj, "name"):
            username = user_obj.name
            set_username(username)
        if not role and hasattr(user_obj, "role"):
            role = user_obj.role
            set_role(role)
    return username, role, user_id


async def _release_request_resources(request: Request, ctx_token) -> None:
    """Close whichever per-request sessions were opened, then reset the request id.

    Both sessions get a close attempt even if the first close fails, and the
    request-id context is always reset. The first close error, if any, is
    re-raised after all cleanup has run.
    """
    first_error = None
    try:
        for attr in ("db", "aeros_db"):
            # The attribute may be missing if session creation itself failed.
            db = getattr(request.state, attr, None)
            if db is None:
                continue
            try:
                await db.close()
            except Exception as exc:
                # Keep going so the remaining session is still closed.
                if first_error is None:
                    first_error = exc
    finally:
        reset_request_id(ctx_token)
    if first_error is not None:
        raise first_error


@app.middleware("http")
async def db_session_middleware(request: Request, call_next):
    """Bind per-request DB sessions, log the request, and always clean up.

    A fresh ``uuid1`` request id is stored with ``set_request_id``. Both
    ``async_scoped_session`` registries use ``get_request_id`` as their scope
    function, so the sessions placed on ``request.state.db`` and
    ``request.state.aeros_db`` are scoped to this request
    (see: https://github.com/tiangolo/fastapi/issues/726).

    On success the request is logged with its duration and, when known, the
    user's name, role and id. On failure an ``error_id`` is ensured on
    ``request.state`` and the error is logged with its traceback before being
    re-raised unchanged. Sessions are closed and the request id is reset in
    every case.
    """
    ctx_token = set_request_id(str(uuid1()))
    try:
        session = async_scoped_session(async_session, scopefunc=get_request_id)
        request.state.db = session()
        collector_session = async_scoped_session(async_aeros_session, scopefunc=get_request_id)
        request.state.aeros_db = collector_session()

        # perf_counter is monotonic, so wall-clock adjustments cannot skew the duration.
        start_time = time.perf_counter()
        response = await call_next(request)
        process_time = (time.perf_counter() - start_time) * 1000

        username, role, user_id = _resolve_user_context(request)
        user_info_str = ""
        if username:
            user_info_str = f" | User: {username}"
            if role:
                user_info_str += f" ({role})"

        log.info(
            f"HTTP {request.method} {request.url.path} completed in {round(process_time, 2)}ms{user_info_str}",
            extra={
                "method": request.method,
                "path": request.url.path,
                "status_code": response.status_code,
                "duration_ms": round(process_time, 2),
                "user_id": user_id,
                "role": role,
            },
        )
    except Exception as e:
        # Reuse an error_id if an exception handler already set one. Otherwise
        # create it here (e.g. the failure happened before any handler ran).
        error_id = getattr(request.state, "error_id", None)
        if not error_id:
            error_id = str(uuid1())
            request.state.error_id = error_id
        log.error(
            f"Request failed | Error ID: {error_id}",
            extra={
                "method": request.method,
                "path": request.url.path,
                "error": str(e),
                "error_id": error_id,
            },
            exc_info=True,
        )
        # A bare raise keeps the original traceback and exception chain intact.
        raise
    finally:
        await _release_request_resources(request, ctx_token)
    return response
@app.middleware("http")
async def add_security_headers(request: Request, call_next):
    """Stamp an HTTP Strict-Transport-Security header onto every response.

    The header value asks clients to use HTTPS for one year (31536000 seconds)
    for this host and its subdomains.
    """
    hsts_value = "max-age=31536000 ; includeSubDomains"
    outgoing = await call_next(request)
    outgoing.headers["Strict-Transport-Security"] = hsts_value
    return outgoing
# Serve everything under the local "model" directory as static files at /model.
# NOTE(review): this publicly exposes the whole directory, not just images — confirm intended.
app.mount(path="/model", app=StaticFiles(directory="model"), name="model")
@app.get("/images/{image_path:path}")
async def get_image(image_path: str = Path(...)):
    """Serve an image file from ``model/RBD Model/Image`` by name.

    Only the last path component of ``image_path`` is used. Any leading
    directories in the request therefore cannot move the lookup out of the
    image folder.

    Raises:
        HTTPException: 404 when no regular file with that name exists.
    """
    filename = os.path.basename(image_path)
    full_image_path = os.path.join("model", "RBD Model", "Image", filename)
    # isfile rather than exists: an empty name or ".." resolves to a
    # directory, which FileResponse cannot serve (it would fail instead of 404).
    if os.path.isfile(full_image_path):
        return FileResponse(full_image_path)
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Image not found")
# Attach all API routes defined in src.api to the application.
app.include_router(router=api_router)