You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
243 lines
8.7 KiB
Python
243 lines
8.7 KiB
Python
import logging
|
|
import time
|
|
from src.context import set_request_id, reset_request_id, get_request_id
|
|
from contextvars import ContextVar
|
|
from os import path
|
|
from typing import Final, Optional
|
|
from uuid import uuid1
|
|
|
|
from fastapi import FastAPI, HTTPException, status
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import JSONResponse
|
|
from pydantic import ValidationError
|
|
from slowapi import _rate_limit_exceeded_handler
|
|
from slowapi.errors import RateLimitExceeded
|
|
from sqlalchemy import inspect
|
|
from sqlalchemy.ext.asyncio import async_scoped_session
|
|
from sqlalchemy.orm import scoped_session
|
|
from starlette.middleware.base import (BaseHTTPMiddleware,
|
|
RequestResponseEndpoint)
|
|
from starlette.middleware.gzip import GZipMiddleware
|
|
from starlette.requests import Request
|
|
from starlette.responses import FileResponse, Response, StreamingResponse
|
|
from starlette.routing import compile_path
|
|
from starlette.staticfiles import StaticFiles
|
|
|
|
from src.api import api_router
|
|
from src.database.core import async_session, engine, async_collector_session
|
|
from src.enums import ResponseStatus
|
|
from src.exceptions import handle_exception
|
|
from src.logging import configure_logging
|
|
from src.middleware import RequestValidationMiddleware
|
|
from src.rate_limiter import limiter
|
|
from fastapi.exceptions import RequestValidationError
|
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
|
|
|
log = logging.getLogger(__name__)

# Logging is configured first so that everything below can emit records.
configure_logging()

# The ASGI application. An empty (falsy) openapi_url means FastAPI does not
# expose the generated OpenAPI schema or the interactive docs routes.
app = FastAPI(
    title="LCCA API",
    description="Welcome to LCCA's API documentation!",
    version="0.1.0",
    openapi_url="",
)

# Every error type is funnelled through the shared project handler; the
# registration order below matches the original one-call-per-type listing.
for _exc_type in (
    Exception,
    HTTPException,
    StarletteHTTPException,
    RequestValidationError,
    RateLimitExceeded,
):
    app.add_exception_handler(_exc_type, handle_exception)
del _exc_type

# slowapi expects the limiter instance to be reachable via app.state.
app.state.limiter = limiter

# Compress response bodies of at least 1000 bytes.
app.add_middleware(GZipMiddleware, minimum_size=1000)

# NOTE(review): a leftover `credentials: "include"` remark sat here, which
# suggests a CORS configuration was removed. CORSMiddleware is imported but
# not registered in the code shown; confirm whether cross-origin access is needed.
|
|
|
|
|
|
|
|
|
|
|
|
def security_headers_middleware(app: FastAPI, *, is_production: bool = False) -> None:
    """Register a middleware on *app* that adds security-related response headers.

    Args:
        app: The FastAPI application to register the middleware on.
        is_production: When True, apply the strict header set (HSTS,
            X-Frame-Options, X-Content-Type-Options, X-XSS-Protection,
            Referrer-Policy, Content-Security-Policy and Permissions-Policy).
            When False (the default, identical to the previously hard-coded
            behaviour), only a permissive Content-Security-Policy is set so
            local development tooling keeps working.
    """
    # Content-Security-Policy directives (production only).
    # TODO(review): the connect-src hosts look like template placeholders
    # ("your-domain.com"); replace them with the real API / WebSocket origins.
    csp_policy = {
        "default-src": "'self'",
        "script-src": "'self' 'unsafe-inline' https://cdnjs.cloudflare.com https://cdn.jsdelivr.net",
        "style-src": "'self' 'unsafe-inline' https://fonts.googleapis.com https://cdn.jsdelivr.net",
        "img-src": "'self' data: https: blob:",
        "font-src": "'self' https://fonts.gstatic.com data:",
        "connect-src": "'self' https://api.your-domain.com wss://ws.your-domain.com",
        "frame-src": "'none'",
        "object-src": "'none'",
        "base-uri": "'self'",
        "form-action": "'self'",
    }

    # Permissions-Policy uses Structured Field syntax: each feature maps to an
    # allowlist in parentheses, "()" = no origin, "(self)" = same origin only,
    # and entries are separated by commas. (The old Feature-Policy style
    # "feature='none'; ..." is not valid for this header.)
    # NOTE(review): some keys (e.g. notifications, push, speaker, vibrate) are
    # not standardized Permissions-Policy features; browsers ignore unknown
    # keys, possibly with a console warning.
    permissions_policy = {
        "geolocation": "()",
        "midi": "()",
        "notifications": "()",
        "push": "()",
        "sync-xhr": "()",
        "microphone": "()",
        "camera": "()",
        "magnetometer": "()",
        "gyroscope": "()",
        "speaker": "()",
        "vibrate": "()",
        "fullscreen": "(self)",
        "payment": "()",
    }

    csp_header_value = "; ".join(f"{k} {v}" for k, v in csp_policy.items())
    feature_header_value = ", ".join(f"{k}={v}" for k, v in permissions_policy.items())

    class SecurityHeadersMiddleware(BaseHTTPMiddleware):
        """Attach the security headers computed above to every response."""

        async def dispatch(
            self, request: Request, call_next: RequestResponseEndpoint
        ) -> Response:
            response: Response = await call_next(request)

            if is_production:
                headers = response.headers
                # HSTS preload submission requires max-age >= 1 year (31536000s).
                headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains; preload"
                headers["X-Frame-Options"] = "DENY"
                headers["X-Content-Type-Options"] = "nosniff"
                # Legacy XSS auditors are removed from modern browsers and can
                # themselves introduce information leaks; "0" explicitly turns
                # them off. The CSP below is the real XSS mitigation.
                headers["X-XSS-Protection"] = "0"
                headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
                headers["Content-Security-Policy"] = csp_header_value
                headers["Permissions-Policy"] = feature_header_value
            else:
                # Relaxed settings for development; other headers are skipped.
                response.headers["Content-Security-Policy"] = "default-src 'self' 'unsafe-inline' 'unsafe-eval' *"

            return response

    app.add_middleware(SecurityHeadersMiddleware)
|
|
|
|
# Register the security-header middleware defined above on the app.
# NOTE: Starlette places middleware registered later outside middleware
# registered earlier, so the order of these calls determines which layer
# sees the request first and the response last.
security_headers_middleware(app)

# Project request-validation middleware (behaviour defined in src.middleware).
app.add_middleware(RequestValidationMiddleware)
|
|
|
|
|
|
|
|
|
|
def _resolve_user_id(request: Request) -> Optional[str]:
    """Return the current user id, back-filling user context vars from request.state.user.

    Values already present in the request context take precedence. Any that
    are missing are read from ``request.state.user`` (a dict or an object with
    attributes) and written back into the context.
    """
    from src.context import get_username, get_role, get_user_id, set_user_id, set_username, set_role

    username = get_username()
    role = get_role()
    user_id = get_user_id()

    user_obj = getattr(request.state, "user", None)
    if not user_obj:
        return user_id

    # The user may be stored as a plain dict or as a model-like object.
    if isinstance(user_obj, dict):
        u_id = user_obj.get("user_id")
        u_name = user_obj.get("name") or user_obj.get("username")
        u_role = user_obj.get("role")
    else:
        u_id = getattr(user_obj, "user_id", None)
        u_name = getattr(user_obj, "name", None) or getattr(user_obj, "username", None)
        u_role = getattr(user_obj, "role", None)

    if not user_id and u_id:
        user_id = str(u_id)
        set_user_id(user_id)
    if not username and u_name:
        set_username(u_name)
    if not role and u_role:
        set_role(u_role)

    return user_id


async def _close_request_sessions(request: Request) -> None:
    """Close whichever per-request sessions were actually opened.

    Sessions missing from ``request.state`` (because creation failed) are
    skipped rather than raising AttributeError over the original error. The
    collector session is closed even if closing the primary session raises.
    """
    db = getattr(request.state, "db", None)
    collector_db = getattr(request.state, "collector_db", None)
    try:
        if db is not None:
            await db.close()
    finally:
        if collector_db is not None:
            await collector_db.close()


@app.middleware("http")
async def db_session_middleware(request: Request, call_next):
    """Open per-request DB sessions, log successful requests, and always clean up.

    A fresh request id is stored in the request context and used as the scope
    key for the two async scoped sessions exposed as ``request.state.db`` and
    ``request.state.collector_db``. On every exit path (normal return, the
    early return for error responses, or an exception) both sessions are
    closed and the request-id context var is reset.

    Responses with status >= 400 are returned without an access-log line,
    because those are already logged in ``handle_exception``.
    """
    # we create a per-request id such that we can ensure that our session is scoped for a particular request.
    # see: https://github.com/tiangolo/fastapi/issues/726
    ctx_token = set_request_id(str(uuid1()))

    try:
        session = async_scoped_session(async_session, scopefunc=get_request_id)
        request.state.db = session()

        collector_session = async_scoped_session(async_collector_session, scopefunc=get_request_id)
        request.state.collector_db = collector_session()

        # perf_counter is monotonic, so wall-clock adjustments cannot skew durations.
        start_time = time.perf_counter()
        response = await call_next(request)
        duration_ms = round((time.perf_counter() - start_time) * 1000, 2)

        # Skip logging in middleware if it's an error (already logged in handle_exception)
        if response.status_code >= 400:
            return response

        user_id = _resolve_user_id(request)
        error_id = getattr(request.state, "error_id", None)

        log_msg = f"HTTP {request.method} {request.url.path} completed in {duration_ms}ms"
        if user_id:
            log_msg += f" | User ID: {user_id}"
        if error_id:
            log_msg += f" | Error ID: {error_id}"

        log.info(
            log_msg,
            extra={
                "method": request.method,
                "path": request.url.path,
                "status_code": response.status_code,
                "duration_ms": duration_ms,
                "user_id": user_id,
                "error_id": error_id,
            },
        )
        return response
    finally:
        # Reset the request id even if closing a session fails, and even on
        # the early error-response return above (previously skipped).
        try:
            await _close_request_sessions(request)
        finally:
            reset_request_id(ctx_token)
|
|
|
|
|
|
@app.middleware("http")
async def add_security_headers(request: Request, call_next):
    """Ensure every response carries a Strict-Transport-Security header.

    This middleware is registered after the other security middleware, so it
    runs last on the response path. ``setdefault`` lets it provide a fallback
    HSTS policy without overwriting a stricter value (e.g. one including
    ``preload``) that an inner middleware has already set.
    """
    response = await call_next(request)
    response.headers.setdefault(
        "Strict-Transport-Security",
        "max-age=31536000 ; includeSubDomains",
    )
    return response
|
|
|
|
|
|
# class MetricsMiddleware(BaseHTTPMiddleware):
|
|
# async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
|
|
# method = request.method
|
|
# endpoint = request.url.path
|
|
# tags = {"method": method, "endpoint": endpoint}
|
|
|
|
# try:
|
|
# start = time.perf_counter()
|
|
# response = await call_next(request)
|
|
# elapsed_time = time.perf_counter() - start
|
|
# tags.update({"status_code": response.status_code})
|
|
# metric_provider.counter("server.call.counter", tags=tags)
|
|
# metric_provider.timer("server.call.elapsed", value=elapsed_time, tags=tags)
|
|
# log.debug(f"server.call.elapsed.{endpoint}: {elapsed_time}")
|
|
# except Exception as e:
|
|
# metric_provider.counter("server.call.exception.counter", tags=tags)
|
|
# raise e from None
|
|
# return response
|
|
|
|
|
|
# app.add_middleware(ExceptionMiddleware)
|
|
|
|
# Mount every API route collected in src.api onto the application.
app.include_router(router=api_router)
|