Compare commits

...

2 Commits

@ -28,6 +28,17 @@ class JWTBearer(HTTPBearer):
) )
request.state.user = user_info request.state.user = user_info
from src.context import set_user_id, set_username, set_role
if hasattr(user_info, "user_id"):
set_user_id(str(user_info.user_id))
if hasattr(user_info, "username"):
set_username(user_info.username)
elif hasattr(user_info, "name"):
set_username(user_info.name)
if hasattr(user_info, "role"):
set_role(user_info.role)
return user_info return user_info
else: else:
raise HTTPException(status_code=403, detail="Invalid authorization code.") raise HTTPException(status_code=403, detail="Invalid authorization code.")
@ -46,7 +57,7 @@ class JWTBearer(HTTPBearer):
return UserBase(**user_data["data"]) return UserBase(**user_data["data"])
except Exception as e: except Exception as e:
print(f"Token verification error: {str(e)}") logging.error(f"Token verification error: {str(e)}")
return None return None

@ -51,7 +51,7 @@ def get_config():
config = get_config() config = get_config()
LOG_LEVEL = config("LOG_LEVEL", default=logging.WARNING) LOG_LEVEL = config("LOG_LEVEL", default="INFO")
ENV = config("ENV", default="local") ENV = config("ENV", default="local")
PORT = config("PORT", cast=int, default=8000) PORT = config("PORT", cast=int, default=8000)
HOST = config("HOST", default="localhost") HOST = config("HOST", default="localhost")

@ -2,8 +2,18 @@ from contextvars import ContextVar
from typing import Optional, Final from typing import Optional, Final
REQUEST_ID_CTX_KEY: Final[str] = "request_id" REQUEST_ID_CTX_KEY: Final[str] = "request_id"
USER_ID_CTX_KEY: Final[str] = "user_id"
USERNAME_CTX_KEY: Final[str] = "username"
ROLE_CTX_KEY: Final[str] = "role"
_request_id_ctx_var: ContextVar[Optional[str]] = ContextVar( _request_id_ctx_var: ContextVar[Optional[str]] = ContextVar(
REQUEST_ID_CTX_KEY, default=None) REQUEST_ID_CTX_KEY, default=None)
_user_id_ctx_var: ContextVar[Optional[str]] = ContextVar(
USER_ID_CTX_KEY, default=None)
_username_ctx_var: ContextVar[Optional[str]] = ContextVar(
USERNAME_CTX_KEY, default=None)
_role_ctx_var: ContextVar[Optional[str]] = ContextVar(
ROLE_CTX_KEY, default=None)
def get_request_id() -> Optional[str]: def get_request_id() -> Optional[str]:
@ -16,3 +26,27 @@ def set_request_id(request_id: str):
def reset_request_id(token): def reset_request_id(token):
_request_id_ctx_var.reset(token) _request_id_ctx_var.reset(token)
def get_user_id() -> Optional[str]:
return _user_id_ctx_var.get()
def set_user_id(user_id: str):
return _user_id_ctx_var.set(user_id)
def get_username() -> Optional[str]:
return _username_ctx_var.get()
def set_username(username: str):
return _username_ctx_var.set(username)
def get_role() -> Optional[str]:
return _role_ctx_var.get()
def set_role(role: str):
return _role_ctx_var.set(role)

@ -96,58 +96,86 @@ def handle_exception(request: Request, exc: Exception):
""" """
Global exception handler for Fastapi application. Global exception handler for Fastapi application.
""" """
import uuid
error_id = str(uuid.uuid1())
request_info = get_request_context(request) request_info = get_request_context(request)
# Store error_id in request.state for middleware/logging
request.state.error_id = error_id
if isinstance(exc, RateLimitExceeded): if isinstance(exc, RateLimitExceeded):
return _rate_limit_exceeded_handler(request, exc) logging.warning(
f"Rate limit exceeded | Error ID: {error_id}",
extra={
"error_id": error_id,
"error_category": "rate_limit",
"request": request_info,
"detail": str(exc.description) if hasattr(exc, "description") else str(exc),
},
)
return JSONResponse(
status_code=429,
content={
"data": None,
"message": "Rate limit exceeded",
"status": ResponseStatus.ERROR,
"error_id": error_id
}
)
if isinstance(exc, RequestValidationError): if isinstance(exc, RequestValidationError):
logging.error( logging.warning(
f"Validation error | Error: {str(exc.errors())} | Request: {request_info}", f"Validation error occurred | Error ID: {error_id}",
extra={"error_category": "validation"}, extra={
"error_id": error_id,
"error_category": "validation",
"errors": exc.errors(),
"request": request_info,
},
) )
return JSONResponse( return JSONResponse(
status_code=422, status_code=422,
content={ content={
"data": None, "data": exc.errors(),
"message": "Validation error", "message": "Validation Error",
"status": ResponseStatus.ERROR, "status": ResponseStatus.ERROR,
"errors": [ "error_id": error_id
ErrorDetail( },
field=".".join(map(str, err["loc"])),
message=err["msg"],
code=err["type"],
).model_dump()
for err in exc.errors()
]
}
) )
if isinstance(exc, HTTPException): if isinstance(exc, HTTPException):
logging.error( logging.error(
f"HTTP exception | Code: {exc.status_code} | Error: {exc.detail} | Request: {request_info}", f"HTTP exception occurred | Error ID: {error_id}",
extra={"error_category": "http"}, extra={
"error_id": error_id,
"error_category": "http",
"status_code": exc.status_code,
"detail": exc.detail if hasattr(exc, "detail") else str(exc),
"request": request_info,
},
) )
return JSONResponse( return JSONResponse(
status_code=exc.status_code, status_code=exc.status_code,
content={ content={
"data": None, "data": None,
"message": str(exc.detail), "message": str(exc.detail) if hasattr(exc, "detail") else str(exc),
"status": ResponseStatus.ERROR, "status": ResponseStatus.ERROR,
"errors": [ "error_id": error_id
ErrorDetail( },
message=str(exc.detail)
).model_dump()
]
}
) )
if isinstance(exc, SQLAlchemyError): if isinstance(exc, SQLAlchemyError):
error_message, status_code = handle_sqlalchemy_error(exc) error_message, status_code = handle_sqlalchemy_error(exc)
logging.error( logging.error(
f"Database Error | Error: {str(error_message)} | Request: {request_info}", f"Database error occurred | Error ID: {error_id}",
extra={"error_category": "database"}, extra={
"error_id": error_id,
"error_category": "database",
"error_message": error_message,
"request": request_info,
"exception": str(exc),
},
) )
return JSONResponse( return JSONResponse(
@ -156,42 +184,28 @@ def handle_exception(request: Request, exc: Exception):
"data": None, "data": None,
"message": error_message, "message": error_message,
"status": ResponseStatus.ERROR, "status": ResponseStatus.ERROR,
"errors": [ "error_id": error_id
ErrorDetail( },
message=error_message
).model_dump()
]
}
) )
# Log unexpected errors # Log unexpected errors
error_message = f"{exc.__class__.__name__}: {str(exc)}"
error_traceback = exc.__traceback__
# Get file and line info if available
if error_traceback:
tb = error_traceback
while tb.tb_next:
tb = tb.tb_next
file_name = tb.tb_frame.f_code.co_filename
line_num = tb.tb_lineno
error_message = f"{error_message}\nFile {file_name}, line {line_num}"
logging.error( logging.error(
f"Unexpected Error | Error: {error_message} | Request: {request_info}", f"Unexpected error occurred | Error ID: {error_id}",
extra={"error_category": "unexpected"}, extra={
"error_id": error_id,
"error_category": "unexpected",
"error_message": str(exc),
"request": request_info,
},
exc_info=True,
) )
return JSONResponse( return JSONResponse(
status_code=500, status_code=500,
content={ content={
"data": None, "data": None,
"message": error_message, "message": "An unexpected error occurred",
"status": ResponseStatus.ERROR, "status": ResponseStatus.ERROR,
"errors": [ "error_id": error_id
ErrorDetail( },
message=error_message
).model_dump()
]
}
) )

@ -35,29 +35,45 @@ class JSONFormatter(logging.Formatter):
Custom formatter to output logs in JSON format. Custom formatter to output logs in JSON format.
""" """
def format(self, record): def format(self, record):
from src.context import get_request_id from src.context import get_request_id, get_user_id, get_username, get_role
request_id = None request_id = None
user_id = None
username = None
role = None
try: try:
request_id = get_request_id() request_id = get_request_id()
user_id = get_user_id()
username = get_username()
role = get_role()
except Exception: except Exception:
pass pass
# Standard fields from requirements
log_record = { log_record = {
"timestamp": datetime.datetime.fromtimestamp(record.created).astimezone().isoformat(), "timestamp": datetime.datetime.fromtimestamp(record.created).strftime("%Y-%m-%d %H:%M:%S"),
"level": record.levelname, "level": record.levelname,
"name": record.name,
"message": record.getMessage(), "message": record.getMessage(),
"logger_name": record.name,
"location": f"{record.module}:{record.funcName}:{record.lineno}",
"module": record.module,
"funcName": record.funcName,
"lineno": record.lineno,
"pid": os.getpid(),
"request_id": request_id or "SYSTEM", # request id assigned per request or SYSTEM for system logs
} }
# Add Context information if available
if user_id:
log_record["user_id"] = user_id
if username:
log_record["username"] = username
if role:
log_record["role"] = role
if request_id:
log_record["request_id"] = request_id
# Add Error context if available
if hasattr(record, "error_id"):
log_record["error_id"] = record.error_id
elif "error_id" in record.__dict__:
log_record["error_id"] = record.error_id
# Capture exception info if available # Capture exception info if available
if record.exc_info: if record.exc_info:
log_record["exception"] = self.formatException(record.exc_info) log_record["exception"] = self.formatException(record.exc_info)
@ -67,18 +83,17 @@ class JSONFormatter(logging.Formatter):
log_record["stack_trace"] = self.formatStack(record.stack_info) log_record["stack_trace"] = self.formatStack(record.stack_info)
# Add any extra attributes passed to the log call # Add any extra attributes passed to the log call
# We skip standard and internal uvicorn/fastapi attributes to avoid duplication or mess
standard_attrs = { standard_attrs = {
"args", "asctime", "created", "exc_info", "exc_text", "filename", "args", "asctime", "created", "exc_info", "exc_text", "filename",
"funcName", "levelname", "levelno", "lineno", "module", "msecs", "funcName", "levelname", "levelno", "lineno", "module", "msecs",
"message", "msg", "name", "pathname", "process", "processName", "message", "msg", "name", "pathname", "process", "processName",
"relativeCreated", "stack_info", "thread", "threadName", "relativeCreated", "stack_info", "thread", "threadName", "error_id",
"color_message", "request", "scope" "color_message", "request", "scope"
} }
for key, value in record.__dict__.items(): for key, value in record.__dict__.items():
if key not in standard_attrs: if key not in standard_attrs and not key.startswith("_"):
log_record[key] = value log_record[key] = value
log_json = json.dumps(log_record) log_json = json.dumps(log_record)
# Apply color if the output is a terminal # Apply color if the output is a terminal

@ -50,7 +50,7 @@ app.state.limiter = limiter
app.add_exception_handler(Exception, handle_exception) app.add_exception_handler(Exception, handle_exception)
app.add_exception_handler(HTTPException, handle_exception) app.add_exception_handler(HTTPException, handle_exception)
app.add_exception_handler(RequestValidationError, handle_exception) app.add_exception_handler(RequestValidationError, handle_exception)
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) app.add_exception_handler(RateLimitExceeded, handle_exception)
app.add_exception_handler(SQLAlchemyError, handle_exception) app.add_exception_handler(SQLAlchemyError, handle_exception)
from src.context import set_request_id, reset_request_id, get_request_id from src.context import set_request_id, reset_request_id, get_request_id
@ -68,18 +68,74 @@ async def db_session_middleware(request: Request, call_next):
try: try:
log.info(f"Incoming request: {request.method} {request.url.path}") start_time = time.time()
session = async_scoped_session(async_session, scopefunc=get_request_id) session = async_scoped_session(async_session, scopefunc=get_request_id)
request.state.db = session() request.state.db = session()
collector_session = async_scoped_session(collector_async_session, scopefunc=get_request_id) collector_session = async_scoped_session(collector_async_session, scopefunc=get_request_id)
request.state.collector_db = collector_session() request.state.collector_db = collector_session()
response = await call_next(request) response = await call_next(request)
if response.status_code >= 400: process_time = (time.time() - start_time) * 1000
log.error(f"Request completed: {response.status_code}")
else: from src.context import get_username, get_role, get_user_id, set_user_id, set_username, set_role
log.info(f"Request completed: {response.status_code}")
# Pull from context or fallback to request.state.user
username = get_username()
role = get_role()
user_id = get_user_id()
user_obj = getattr(request.state, "user", None)
if user_obj:
# UserBase in this project
u_id = getattr(user_obj, "user_id", None)
u_name = getattr(user_obj, "name", None) or getattr(user_obj, "username", None)
u_role = getattr(user_obj, "role", None)
if not user_id and u_id:
user_id = str(u_id)
set_user_id(user_id)
if not username and u_name:
username = u_name
set_username(username)
if not role and u_role:
role = u_role
set_role(role)
user_info_str = ""
if username:
user_info_str = f" | User: {username}"
if role:
user_info_str += f" ({role})"
log.info(
f"HTTP {request.method} {request.url.path} completed in {round(process_time, 2)}ms{user_info_str}",
extra={
"method": request.method,
"path": request.url.path,
"status_code": response.status_code,
"duration_ms": round(process_time, 2),
"user_id": user_id,
"role": role,
},
)
except Exception as e: except Exception as e:
log.error(f"Request failed: {type(e).__name__} - {str(e)}") # Generate an error_id here if it hasn't been generated yet
error_id = getattr(request.state, "error_id", None)
if not error_id:
import uuid
error_id = str(uuid.uuid1())
request.state.error_id = error_id
log.error(
f"Request failed | Error ID: {error_id}",
extra={
"method": request.method,
"path": request.url.path,
"error": str(e),
"error_id": error_id,
},
exc_info=True,
)
raise e from None raise e from None
finally: finally:
await request.state.db.close() await request.state.db.close()

@ -18,13 +18,35 @@ MAX_QUERY_PARAMS = 50
MAX_QUERY_LENGTH = 2000 MAX_QUERY_LENGTH = 2000
MAX_JSON_BODY_SIZE = 1024 * 100 # 100 KB MAX_JSON_BODY_SIZE = 1024 * 100 # 100 KB
# Very targeted patterns. Avoid catastrophic regex nonsense. XSS_PATTERN = re.compile(
XSS_PATTERN_STR = r"(<script|</script|javascript:|onerror\s*=|onload\s*=|<svg|<img)" r"(<script|<iframe|<embed|<object|<svg|<img|<video|<audio|<base|<link|<meta|<form|<button|"
XSS_PATTERN = re.compile(XSS_PATTERN_STR, re.IGNORECASE) r"javascript:|vbscript:|data:text/html|onerror\s*=|onload\s*=|onmouseover\s*=|onfocus\s*=|"
r"onclick\s*=|onscroll\s*=|ondblclick\s*=|onkeydown\s*=|onkeypress\s*=|onkeyup\s*=|"
SQLI_PATTERN_STR = r"(\bUNION\b|\bSELECT\b|\bINSERT\b|\bDELETE\b|\bDROP\b|--|\bOR\b\s+1=1)" r"onloadstart\s*=|onpageshow\s*=|onresize\s*=|onunload\s*=|style\s*=\s*['\"].*expression\s*\(|"
SQLI_PATTERN = re.compile(SQLI_PATTERN_STR, re.IGNORECASE) r"eval\s*\(|setTimeout\s*\(|setInterval\s*\(|Function\s*\()",
re.IGNORECASE,
)
SQLI_PATTERN = re.compile(
r"(\bUNION\b|\bSELECT\b|\bINSERT\b|\bUPDATE\b|\bDELETE\b|\bDROP\b|\bALTER\b|\bCREATE\b|\bTRUNCATE\b|"
r"\bEXEC\b|\bEXECUTE\b|\bDECLARE\b|\bWAITFOR\b|\bDELAY\b|\bGROUP\b\s+\bBY\b|\bHAVING\b|\bORDER\b\s+\bBY\b|"
r"\bINFORMATION_SCHEMA\b|\bSYS\b\.|\bSYSOBJECTS\b|\bPG_SLEEP\b|\bSLEEP\b\(|--|/\*|\*/|#|\bOR\b\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+|"
r"\bAND\b\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+|"
r"\bXP_CMDSHELL\b|\bLOAD_FILE\b|\bINTO\s+OUTFILE\b)",
re.IGNORECASE,
)
RCE_PATTERN = re.compile(
r"(\$\(|`.*`|[;&|]\s*(cat|ls|id|whoami|pwd|ifconfig|ip|netstat|nc|netcat|nmap|curl|wget|python|php|perl|ruby|bash|sh|cmd|powershell|pwsh|sc\s+|tasklist|taskkill|base64|sudo|crontab|ssh|ftp|tftp)|"
r"\b(cat|ls|id|whoami|pwd|ifconfig|ip|netstat|nc|netcat|nmap|curl|wget|python|php|perl|ruby|bash|sh|cmd|powershell|pwsh|base64|sudo|crontab)\b|"
r"/etc/passwd|/etc/shadow|/etc/group|/etc/issue|/proc/self/|/windows/system32/|C:\\Windows\\)",
re.IGNORECASE,
)
TRAVERSAL_PATTERN = re.compile(
r"(\.\./|\.\.\\|%2e%2e%2f|%2e%2e/|\.\.%2f|%2e%2e%5c)",
re.IGNORECASE,
)
# JSON prototype pollution keys # JSON prototype pollution keys
FORBIDDEN_JSON_KEYS = {"__proto__", "constructor", "prototype"} FORBIDDEN_JSON_KEYS = {"__proto__", "constructor", "prototype"}
@ -50,6 +72,18 @@ def inspect_value(value: str, source: str):
detail=f"Potential SQL injection payload detected in {source}", detail=f"Potential SQL injection payload detected in {source}",
) )
if RCE_PATTERN.search(value):
raise HTTPException(
status_code=400,
detail=f"Potential RCE payload detected in {source}",
)
if TRAVERSAL_PATTERN.search(value):
raise HTTPException(
status_code=400,
detail=f"Potential Path Traversal payload detected in {source}",
)
if has_control_chars(value): if has_control_chars(value):
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
@ -114,9 +148,21 @@ class RequestValidationMiddleware(BaseHTTPMiddleware):
# ------------------------- # -------------------------
# 3. Query param inspection # 3. Query param inspection
# ------------------------- # -------------------------
pagination_size_keys = {"size", "itemsPerPage", "per_page", "limit", "items_per_page"}
for key, value in params: for key, value in params:
if value: if value:
inspect_value(value, f"query param '{key}'") inspect_value(value, f"query param '{key}'")
# Pagination constraint: multiples of 5, max 50
if key in pagination_size_keys and value:
try:
size_val = int(value)
if size_val > 50:
raise HTTPException(status_code=400, detail=f"Pagination size '{key}' cannot exceed 50")
if size_val % 5 != 0:
raise HTTPException(status_code=400, detail=f"Pagination size '{key}' must be a multiple of 5")
except ValueError:
raise HTTPException(status_code=400, detail=f"Pagination size '{key}' must be an integer")
# ------------------------- # -------------------------
# 4. Content-Type sanity # 4. Content-Type sanity

Loading…
Cancel
Save