diff --git a/src/auth/service.py b/src/auth/service.py index 034e87b..b293521 100644 --- a/src/auth/service.py +++ b/src/auth/service.py @@ -11,7 +11,6 @@ import src.config as config from .model import UserBase from .util import extract_template -from tomlkit.items import Bool class JWTBearer(HTTPBearer): def __init__(self, auto_error: bool = True): diff --git a/src/main.py b/src/main.py index cb31470..c7611b2 100644 --- a/src/main.py +++ b/src/main.py @@ -28,7 +28,7 @@ from src.enums import ResponseStatus from src.exceptions import handle_exception from src.logging import configure_logging from src.rate_limiter import limiter - +import config log = logging.getLogger(__name__) # we configure the logging level and format @@ -60,6 +60,66 @@ _request_id_ctx_var: ContextVar[Optional[str]] = ContextVar( def get_request_id() -> Optional[str]: return _request_id_ctx_var.get() +def security_headers_middleware(app: FastAPI): + is_production = config.ENV == "PROD" + + # CSP rules + csp_policy = { + "default-src": "'self'", + "script-src": "'self' 'unsafe-inline' https://cdnjs.cloudflare.com https://cdn.jsdelivr.net", + "style-src": "'self' 'unsafe-inline' https://fonts.googleapis.com https://cdn.jsdelivr.net", + "img-src": "'self' data: https: blob:", + "font-src": "'self' https://fonts.gstatic.com data:", + "connect-src": "'self' https://api.your-domain.com wss://ws.your-domain.com", + "frame-src": "'none'", + "object-src": "'none'", + "base-uri": "'self'", + "form-action": "'self'", + } + + # Feature / Permissions Policy + feature_policy = { + "geolocation": "'none'", + "midi": "'none'", + "notifications": "'none'", + "push": "'none'", + "sync-xhr": "'none'", + "microphone": "'none'", + "camera": "'none'", + "magnetometer": "'none'", + "gyroscope": "'none'", + "speaker": "'none'", + "vibrate": "'none'", + "fullscreen": "'self'", + "payment": "'none'", + } + + csp_header_value = "; ".join(f"{k} {v}" for k, v in csp_policy.items()) + feature_header_value = "; ".join(f"{k}={v}" for k, v in feature_policy.items()) + + # Middleware definition + class SecurityHeadersMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + response: Response = await call_next(request) + + if is_production: + response.headers["Strict-Transport-Security"] = "max-age=15724800; includeSubDomains; preload" + response.headers["X-Frame-Options"] = "DENY" + response.headers["X-Content-Type-Options"] = "nosniff" + response.headers["X-XSS-Protection"] = "1; mode=block" + response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" + response.headers["Content-Security-Policy"] = csp_header_value + response.headers["Permissions-Policy"] = feature_header_value + else: + # Relaxed settings for development + response.headers["Content-Security-Policy"] = "default-src 'self' 'unsafe-inline' 'unsafe-eval' *" + # You can skip some headers here for local testing + + return response + + app.add_middleware(SecurityHeadersMiddleware) + +security_headers_middleware(app) @app.middleware("http") async def db_session_middleware(request: Request, call_next):