You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
137 lines
4.5 KiB
Python
137 lines
4.5 KiB
Python
import logging
import os
import sys
import time
from contextlib import AsyncExitStack
from contextvars import ContextVar
from os import path
from typing import Final, Optional
from uuid import uuid1

from fastapi import FastAPI, HTTPException, Path, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import ValidationError
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from sqlalchemy import inspect
from sqlalchemy.ext.asyncio import async_scoped_session
from sqlalchemy.orm import scoped_session
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.middleware.gzip import GZipMiddleware
from starlette.requests import Request
from starlette.responses import FileResponse, Response, StreamingResponse
from starlette.routing import compile_path
from starlette.staticfiles import StaticFiles

from src.api import api_router
from src.database.core import async_session, engine, async_aeros_session
from src.enums import ResponseStatus
from src.exceptions import handle_exception
from src.logging import setup_logging
from src.rate_limiter import limiter
|
|
|
|
log = logging.getLogger(__name__)

# Any exception not handled elsewhere is routed to the project-wide handler.
exception_handlers = {Exception: handle_exception}

# The ASGI application object.
# NOTE(review): openapi_url="" is falsy; FastAPI presumably registers no schema
# endpoint (and therefore no docs UI) in that case -- confirm for the version in use.
app = FastAPI(
    title="LCCA API",
    description="Welcome to RBD's API documentation!",
    version="0.1.0",
    openapi_url="",
    exception_handlers=exception_handlers,
)

# Rate limiting: the slowapi limiter is exposed on app.state, and its
# RateLimitExceeded error is answered by slowapi's stock handler.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Gzip responses; minimum_size=2000 is the byte threshold handed to Starlette.
app.add_middleware(GZipMiddleware, minimum_size=2000)

# NOTE(review): a stray comment here read `credentials: "include",` and
# CORSMiddleware is imported but not registered in this module -- confirm
# whether CORS configuration was intended.

# We configure the logging level and format, then announce startup.
setup_logging(logger=log)
log.info('API is starting up')
|
|
|
|
|
|
# Name of the context variable that carries the current request's id.
REQUEST_ID_CTX_KEY: Final[str] = "request_id"

# Per-context request id; reads as None until something sets it.
_request_id_ctx_var: ContextVar[Optional[str]] = ContextVar(REQUEST_ID_CTX_KEY, default=None)
|
|
|
|
|
|
def get_request_id() -> Optional[str]:
    """Look up the request id bound to the current execution context.

    Used as the ``scopefunc`` for the per-request scoped sessions.

    Returns:
        The id stored in ``_request_id_ctx_var`` for this context, or ``None``
        when no id has been set.
    """
    current_request_id: Optional[str] = _request_id_ctx_var.get()
    return current_request_id
|
|
|
|
|
|
@app.middleware("http")
async def db_session_middleware(request: Request, call_next):
    """Attach per-request database sessions to ``request.state`` and clean them up.

    A fresh request id is stored in ``_request_id_ctx_var`` so that the
    ``async_scoped_session`` registries (keyed via ``get_request_id``) hand out
    sessions scoped to this request. Two sessions are attached:

    * ``request.state.db`` -- from ``async_session``;
    * ``request.state.aeros_db`` -- from ``async_aeros_session``.

    Each cleanup step is registered on an ``AsyncExitStack`` right after its
    resource is acquired. So every session that was actually opened is closed
    and the context variable is always reset, even when session creation, the
    downstream handler, or one of the ``close()`` calls raises. Callbacks run
    in reverse registration order: ``aeros_db`` is closed, then ``db``, then
    the request id is reset.

    Args:
        request: The incoming request.
        call_next: The next handler in the middleware chain.

    Returns:
        The response produced by ``call_next``.
    """
    # We create a per-request id such that we can ensure that our session is
    # scoped for a particular request.
    # see: https://github.com/tiangolo/fastapi/issues/726
    request_id = str(uuid1())

    async with AsyncExitStack() as cleanup:
        ctx_token = _request_id_ctx_var.set(request_id)
        # Registered first so it runs last, after both sessions are closed.
        cleanup.callback(_request_id_ctx_var.reset, ctx_token)

        session = async_scoped_session(async_session, scopefunc=get_request_id)
        db = session()
        request.state.db = db
        cleanup.push_async_callback(db.close)

        collector_session = async_scoped_session(
            async_aeros_session, scopefunc=get_request_id
        )
        aeros_db = collector_session()
        request.state.aeros_db = aeros_db
        cleanup.push_async_callback(aeros_db.close)

        # No catch-and-re-raise: exceptions propagate with their original
        # context once the cleanup stack has unwound.
        return await call_next(request)
|
|
|
|
|
|
@app.middleware("http")
async def add_security_headers(request: Request, call_next):
    """Stamp a Strict-Transport-Security (HSTS) header onto the response.

    The header carries ``max-age=31536000`` (31536000 seconds, i.e. 365 days)
    together with ``includeSubDomains``.

    Args:
        request: The incoming request.
        call_next: The next handler in the middleware chain.

    Returns:
        The downstream response with the HSTS header set.
    """
    hsts_policy = "max-age=31536000 ; includeSubDomains"
    outgoing = await call_next(request)
    outgoing.headers["Strict-Transport-Security"] = hsts_policy
    return outgoing
|
|
|
|
# Serve the contents of the local "model" directory as static files under /model.
app.mount(path="/model", app=StaticFiles(directory="model"), name="model")
|
|
|
|
# class MetricsMiddleware(BaseHTTPMiddleware):
|
|
# async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
|
|
# method = request.method
|
|
# endpoint = request.url.path
|
|
# tags = {"method": method, "endpoint": endpoint}
|
|
|
|
# try:
|
|
# start = time.perf_counter()
|
|
# response = await call_next(request)
|
|
# elapsed_time = time.perf_counter() - start
|
|
# tags.update({"status_code": response.status_code})
|
|
# metric_provider.counter("server.call.counter", tags=tags)
|
|
# metric_provider.timer("server.call.elapsed", value=elapsed_time, tags=tags)
|
|
# log.debug(f"server.call.elapsed.{endpoint}: {elapsed_time}")
|
|
# except Exception as e:
|
|
# metric_provider.counter("server.call.exception.counter", tags=tags)
|
|
# raise e from None
|
|
# return response
|
|
|
|
|
|
# app.add_middleware(ExceptionMiddleware)
|
|
|
|
@app.get("/images/{image_path:path}")
async def get_image(image_path: str = Path(...)):
    """Serve an image from the ``model/RBD Model/Image`` directory.

    Only the final path component of ``image_path`` is used. Any directory
    parts in the request are discarded before the lookup.

    Args:
        image_path: Requested path; may contain slashes (``:path`` converter).

    Returns:
        A ``FileResponse`` for the matching file.

    Raises:
        HTTPException: 404 when no regular file with that name exists.
    """
    # Extract filename from the full path (drops any directory components).
    filename = os.path.basename(image_path)
    full_image_path = os.path.join("model", "RBD Model", "Image", filename)

    # isfile rather than exists: an empty basename ("foo/") or "."/".." points
    # at a directory, which FileResponse cannot send (it fails with a 500).
    if os.path.isfile(full_image_path):
        return FileResponse(full_image_path)
    raise HTTPException(status_code=404, detail="Image not found")
|
|
|
|
# Register every route defined by the project's top-level API router.
app.include_router(router=api_router)
|