You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

137 lines
4.5 KiB
Python

import logging
import os
import sys
import time
from contextvars import ContextVar
from os import path
from typing import Final, Optional
from uuid import uuid1
from fastapi import FastAPI, HTTPException, Path, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import ValidationError
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from sqlalchemy import inspect
from sqlalchemy.ext.asyncio import async_scoped_session
from sqlalchemy.orm import scoped_session
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.middleware.gzip import GZipMiddleware
from starlette.requests import Request
from starlette.responses import FileResponse, Response, StreamingResponse
from starlette.routing import compile_path
from starlette.staticfiles import StaticFiles
from src.api import api_router
from src.database.core import async_session, engine,async_aeros_session
from src.enums import ResponseStatus
from src.exceptions import handle_exception
from src.logging import setup_logging
from src.rate_limiter import limiter
log = logging.getLogger(__name__)

# Route every otherwise-unhandled exception to the project's shared
# handle_exception handler.
exception_handlers = {Exception: handle_exception}

# Create the ASGI application.
# NOTE(review): openapi_url="" appears intended to disable the generated
# OpenAPI schema (and the docs pages built from it) — confirm.
# NOTE(review): the title says "LCCA" while the description mentions RBD —
# verify which product name is correct.
app = FastAPI(
    exception_handlers=exception_handlers,
    openapi_url="",
    title="LCCA API",
    description="Welcome to RBD's API documentation!",
    version="0.1.0",
)

# Rate limiting: expose the shared slowapi limiter on app.state and handle
# RateLimitExceeded with slowapi's built-in handler.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Gzip responses; bodies smaller than minimum_size (2000 bytes) are left as-is.
app.add_middleware(GZipMiddleware, minimum_size=2000)

# NOTE(review): leftover fragment of a CORS setting. CORSMiddleware is
# imported but not registered in the visible code — restore it or remove this.
# credentials: "include",

# Configure the logging level and format for this module's logger.
setup_logging(logger=log)
log.info('API is starting up')
# Name of the context variable that carries the current request's id.
REQUEST_ID_CTX_KEY: Final[str] = "request_id"

# Id of the request currently being handled; None until one has been set.
_request_id_ctx_var: ContextVar[Optional[str]] = ContextVar(REQUEST_ID_CTX_KEY, default=None)


def get_request_id() -> Optional[str]:
    """Return the id of the in-flight request, or ``None`` if none is set.

    Used as the ``scopefunc`` of the per-request scoped DB sessions.
    """
    current_id: Optional[str] = _request_id_ctx_var.get()
    return current_id
@app.middleware("http")
async def db_session_middleware(request: Request, call_next):
    """Open per-request database sessions and always release them.

    A fresh request id is stored in ``_request_id_ctx_var``. The scoped-session
    registries below are keyed on it via ``get_request_id``, so the sessions
    they return belong to this request. Two sessions are exposed to handlers:

    * ``request.state.db``       -- created from ``async_session``
    * ``request.state.aeros_db`` -- created from ``async_aeros_session``

    Each session that was actually opened is closed, and the context variable
    is reset, even if opening the other session, the downstream handler, or a
    close call raises. Exceptions propagate unchanged.
    """
    # Per-request id so the scoped sessions are isolated per request.
    # see: https://github.com/tiangolo/fastapi/issues/726
    request_id = str(uuid1())
    ctx_token = _request_id_ctx_var.set(request_id)
    try:
        db_registry = async_scoped_session(async_session, scopefunc=get_request_id)
        request.state.db = db_registry()
        try:
            aeros_registry = async_scoped_session(
                async_aeros_session, scopefunc=get_request_id
            )
            request.state.aeros_db = aeros_registry()
            try:
                return await call_next(request)
            finally:
                await request.state.aeros_db.close()
        finally:
            # Runs even if the aeros session failed to open or close, so the
            # primary session is never leaked.
            await request.state.db.close()
    finally:
        # Restore the previous value so this request's id cannot leak into
        # anything that later runs in the same context.
        _request_id_ctx_var.reset(ctx_token)
@app.middleware("http")
async def add_security_headers(request: Request, call_next):
    """Attach an HTTP Strict-Transport-Security header to every response.

    The policy covers this host and its subdomains for 31536000 seconds
    (365 days).
    """
    hsts_policy = "max-age=31536000 ; includeSubDomains"
    downstream_response = await call_next(request)
    downstream_response.headers["Strict-Transport-Security"] = hsts_policy
    return downstream_response
# Serve the contents of the "model" directory (a path relative to the process
# working directory) as static files under the /model URL prefix.
app.mount("/model", StaticFiles(directory="model"), name="model")
# class MetricsMiddleware(BaseHTTPMiddleware):
# async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
# method = request.method
# endpoint = request.url.path
# tags = {"method": method, "endpoint": endpoint}
# try:
# start = time.perf_counter()
# response = await call_next(request)
# elapsed_time = time.perf_counter() - start
# tags.update({"status_code": response.status_code})
# metric_provider.counter("server.call.counter", tags=tags)
# metric_provider.timer("server.call.elapsed", value=elapsed_time, tags=tags)
# log.debug(f"server.call.elapsed.{endpoint}: {elapsed_time}")
# except Exception as e:
# metric_provider.counter("server.call.exception.counter", tags=tags)
# raise e from None
# return response
# app.add_middleware(ExceptionMiddleware)
@app.get("/images/{image_path:path}")
async def get_image(image_path: str = Path(...)):
    """Serve an image file from the RBD model's image directory.

    Only the final path component of ``image_path`` is used, so a request for
    ``a/b/pic.png`` serves ``model/RBD Model/Image/pic.png``. The lookup can
    never step into a sub- or parent directory.

    Raises:
        HTTPException: 404 when no regular file with that name exists.
    """
    # Keep only the file name so clients cannot add directory components.
    filename = os.path.basename(image_path)
    full_image_path = os.path.join("model", "RBD Model", "Image", filename)
    # Use isfile rather than exists: "..", "." or a trailing "/" resolve to a
    # directory, which FileResponse cannot serve and which would surface as a
    # 500 instead of a clean 404.
    if os.path.isfile(full_image_path):
        return FileResponse(full_image_path)
    raise HTTPException(status_code=404, detail="Image not found")
app.include_router(api_router)