You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
168 lines
5.8 KiB
Python
168 lines
5.8 KiB
Python
import logging
|
|
import os
|
|
import sys
|
|
import time
|
|
from os import path
|
|
from uuid import uuid1
|
|
from typing import Optional, Final
|
|
|
|
from fastapi import FastAPI, HTTPException, status, Path
|
|
from fastapi.exceptions import RequestValidationError
|
|
from fastapi.responses import JSONResponse
|
|
from pydantic import ValidationError
|
|
|
|
from slowapi import _rate_limit_exceeded_handler
|
|
from slowapi.errors import RateLimitExceeded
|
|
from sqlalchemy import inspect
|
|
from sqlalchemy.orm import scoped_session
|
|
from sqlalchemy.ext.asyncio import async_scoped_session
|
|
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
|
from starlette.requests import Request
|
|
from starlette.routing import compile_path
|
|
from starlette.middleware.gzip import GZipMiddleware
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
from starlette.responses import Response, StreamingResponse, FileResponse
|
|
from starlette.staticfiles import StaticFiles
|
|
|
|
from src.enums import ResponseStatus
|
|
from src.logging import configure_logging
|
|
from src.rate_limiter import limiter
|
|
from src.api import api_router
|
|
from src.database.core import engine, async_session, async_aeros_session
|
|
from src.exceptions import handle_exception
|
|
from src.middleware import RequestValidationMiddleware
|
|
from src.context import set_request_id, reset_request_id, get_request_id
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
|
|
from starlette.exceptions import HTTPException as StarletteHTTPException

# Module-level logger used by the middleware defined below.
log = logging.getLogger(__name__)

# Configure logging level/format before the application object is built.
configure_logging()

# The ASGI application. An empty ``openapi_url`` stops FastAPI from serving
# the OpenAPI schema (and therefore the docs pages that depend on it).
app = FastAPI(
    title="LCCA API",
    description="Welcome to RBD's API documentation!",
    version="0.1.0",
    openapi_url="",
)

# Route every listed exception type to the shared ``handle_exception``
# handler. Registration order is the same as one explicit call per type.
for _exc_class in (
    Exception,
    HTTPException,
    StarletteHTTPException,
    RequestValidationError,
    RateLimitExceeded,
    SQLAlchemyError,
):
    app.add_exception_handler(_exc_class, handle_exception)
del _exc_class

# slowapi expects the limiter instance to be available on ``app.state``.
app.state.limiter = limiter

# Gzip-compress responses of at least 2000 bytes.
app.add_middleware(GZipMiddleware, minimum_size=2000)
app.add_middleware(RequestValidationMiddleware)
|
|
|
|
|
|
async def _close_request_sessions(request: Request) -> None:
    """Close every per-request DB session this middleware attached.

    Looks up ``request.state.db`` and ``request.state.aeros_db`` and closes
    each one that exists. A missing attribute means that session was never
    opened, for example because creating it raised. Both closes are attempted
    even if the first one fails. The first close error is re-raised
    afterwards so it is not silently swallowed.
    """
    first_error: Optional[Exception] = None
    for attr_name in ("db", "aeros_db"):
        db_session = getattr(request.state, attr_name, None)
        if db_session is None:
            continue
        try:
            await db_session.close()
        except Exception as exc:  # keep going so the other session is closed too
            if first_error is None:
                first_error = exc
    if first_error is not None:
        raise first_error


def _resolve_user_context(request: Request) -> tuple:
    """Return ``(user_id, username, role)`` for the current request.

    Values already present in the request context take precedence. Missing
    ones are filled from ``request.state.user`` (when set and carrying the
    matching attribute). They are then written back to the context with the
    corresponding ``set_*`` helper.
    """
    # Imported lazily as in the original middleware body; presumably to avoid
    # an import cycle with src.context — confirm before hoisting.
    from src.context import get_username, get_role, get_user_id, set_user_id, set_username, set_role

    username = get_username()
    role = get_role()
    user_id = get_user_id()

    user_obj = getattr(request.state, "user", None)
    if user_obj:
        if not user_id and hasattr(user_obj, "user_id"):
            user_id = str(user_obj.user_id)
            set_user_id(user_id)
        if not username and hasattr(user_obj, "name"):
            username = user_obj.name
            set_username(username)
        if not role and hasattr(user_obj, "role"):
            role = user_obj.role
            set_role(role)

    return user_id, username, role


@app.middleware("http")
async def db_session_middleware(request: Request, call_next):
    """Open per-request DB sessions, time the request and log its outcome.

    A fresh request id is registered via ``set_request_id``. It is used as
    the scope key of two ``async_scoped_session`` registries, whose sessions
    are exposed to handlers as ``request.state.db`` and
    ``request.state.aeros_db``.

    On success a completion line (method, path, status, duration, user) is
    logged and the downstream response is returned. On failure an
    ``error_id`` is ensured on ``request.state``, the error is logged with its
    traceback, and the exception is re-raised unchanged.

    Whatever happens, the sessions that were opened are closed and the
    request-id context is reset. This holds even if opening a session or
    closing one of them fails.

    NOTE(review): the sessions are closed as soon as ``call_next`` returns.
    A streaming response body that still uses the session would then run
    after close — confirm no handler relies on that.
    """
    # we create a per-request id such that we can ensure that our session is scoped for a particular request.
    # see: https://github.com/tiangolo/fastapi/issues/726
    request_id = str(uuid1())
    ctx_token = set_request_id(request_id)

    try:
        session = async_scoped_session(async_session, scopefunc=get_request_id)
        request.state.db = session()

        collector_session = async_scoped_session(async_aeros_session, scopefunc=get_request_id)
        request.state.aeros_db = collector_session()

        # perf_counter is monotonic, so wall-clock adjustments cannot skew
        # (or make negative) the measured duration.
        start_time = time.perf_counter()
        response = await call_next(request)
        duration_ms = round((time.perf_counter() - start_time) * 1000, 2)

        user_id, username, role = _resolve_user_context(request)

        user_info_str = ""
        if username:
            user_info_str = f" | User: {username}"
            if role:
                user_info_str += f" ({role})"

        log.info(
            f"HTTP {request.method} {request.url.path} completed in {duration_ms}ms{user_info_str}",
            extra={
                "method": request.method,
                "path": request.url.path,
                "status_code": response.status_code,
                "duration_ms": duration_ms,
                "user_id": user_id,
                "role": role,
            },
        )

    except Exception as e:
        # Generate an error_id here if it hasn't been generated yet (e.g., if it failed before the handler)
        error_id = getattr(request.state, "error_id", None)
        if not error_id:
            error_id = str(uuid1())
            request.state.error_id = error_id

        log.error(
            f"Request failed | Error ID: {error_id}",
            extra={
                "method": request.method,
                "path": request.url.path,
                "error": str(e),
                "error_id": error_id,
            },
            exc_info=True,
        )
        # Bare raise keeps the original traceback and exception context intact.
        raise

    finally:
        try:
            await _close_request_sessions(request)
        finally:
            # Always restore the previous request id, even if a close failed.
            reset_request_id(ctx_token)

    return response
|
|
|
|
|
|
@app.middleware("http")
async def add_security_headers(request: Request, call_next):
    """Attach a Strict-Transport-Security (HSTS) header to every response.

    The policy tells browsers to use HTTPS only for this host and its
    subdomains for 31536000 seconds (one year).
    """
    hsts_policy = "max-age=31536000 ; includeSubDomains"
    downstream_response = await call_next(request)
    downstream_response.headers["Strict-Transport-Security"] = hsts_policy
    return downstream_response
|
|
|
|
|
|
# Expose the local ``model`` directory as static files under ``/model``.
app.mount(
    path="/model",
    app=StaticFiles(directory="model"),
    name="model",
)
|
|
|
|
@app.get("/images/{image_path:path}")
async def get_image(image_path: str = Path(...)):
    """Serve an image from ``model/RBD Model/Image`` by file name.

    Only the final component of ``image_path`` is used. Any directory part of
    the requested path is discarded, so requests cannot escape the image
    directory.

    Raises:
        HTTPException: 404 when no regular file with that name exists.
    """
    # Extract filename from the full path
    filename = os.path.basename(image_path)
    full_image_path = os.path.join("model", "RBD Model", "Image", filename)

    # isfile rather than exists: an empty name, "." or ".." resolves to a
    # directory, which FileResponse cannot send (it would fail with a 500
    # while streaming instead of returning a clean 404).
    if not filename or not os.path.isfile(full_image_path):
        raise HTTPException(status_code=404, detail="Image not found")

    return FileResponse(full_image_path)
|
|
|
|
# Register every API route collected on ``api_router`` with the application.
app.include_router(router=api_router)
|