Merge pull request 'main' (#10) from CIzz22/rbd-app:main into main

Reviewed-on: DigitalTwin/rbd-app#10
main
CIzz22 2 weeks ago
commit fff11eb527

@ -1,5 +1,5 @@
ENV=development
LOG_LEVEL=ERROR
LOG_LEVEL=INFO
PORT=3020
HOST=0.0.0.0
@ -18,5 +18,5 @@ COLLECTOR_NAME=digital_aeros_fixed
AEROS_LICENSE_ID=20260218-Jre5VZieQfWXTq0G8ClpVSGszMf4UEUMLS5ENpWRVcoVSrNJckVZzXE
AEROS_LICENSE_SECRET=GmLIxf9fr8Ap5m1IYzkk4RPBFcm7UBvcd0eRdRQ03oRdxLHQA0d9oyhUk2ZlM3LVdRh1mkgYy5254bmCjFyWWc0oPFwNWYzNwDwnv50qy6SLRdaFnI0yZcfLbWQ7qCSj
WINDOWS_AEROS_BASE_URL=http://192.168.1.102:8800
WINDOWS_AEROS_BASE_URL=http://192.168.1.102:8080
TEMPORAL_URL=http://192.168.1.86:7233

16
poetry.lock generated

@ -1312,15 +1312,15 @@ i18n = ["Babel (>=2.7)"]
[[package]]
name = "licaeros"
version = "0.1.2"
version = "0.1.7"
description = "License App for Aeros"
optional = false
python-versions = "*"
groups = ["main"]
files = [
{file = "licaeros-0.1.2-cp310-cp310-linux_x86_64.whl", hash = "sha256:4b9bfe2e7ba8ab9edb5db18dcb415476e7ab302e09d72b74b5bfd1ac8938b10c"},
{file = "licaeros-0.1.2-cp311-cp311-linux_x86_64.whl", hash = "sha256:4f3a2251aebe7351e61d6f80d6c7474387f9561fdcfff02103b78bb2168c9791"},
{file = "licaeros-0.1.2-cp312-cp312-linux_x86_64.whl", hash = "sha256:933c24029aec984ccc39baf630fbee10e07c1e28192c499685bec0a11d31321d"},
{file = "licaeros-0.1.7-cp310-cp310-linux_x86_64.whl", hash = "sha256:77bec84f37e02a7aff84f6c45a97a5933a86d99cdfabfd74ede36fa64506bfde"},
{file = "licaeros-0.1.7-cp311-cp311-linux_x86_64.whl", hash = "sha256:48e874645c5892e05c8f26bdea910dcdaa3e7b0e787be77920a4f4fb5504b2c1"},
{file = "licaeros-0.1.7-cp312-cp312-linux_x86_64.whl", hash = "sha256:6a0bf6c1b9094693058d927febdb6799c61aea7b5dd10265b014c9d314844135"},
]
[package.source]
@ -2478,14 +2478,14 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "rich"
version = "14.3.2"
version = "14.3.3"
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
optional = false
python-versions = ">=3.8.0"
groups = ["main"]
files = [
{file = "rich-14.3.2-py3-none-any.whl", hash = "sha256:08e67c3e90884651da3239ea668222d19bea7b589149d8014a21c633420dbb69"},
{file = "rich-14.3.2.tar.gz", hash = "sha256:e712f11c1a562a11843306f5ed999475f09ac31ffb64281f73ab29ffdda8b3b8"},
{file = "rich-14.3.3-py3-none-any.whl", hash = "sha256:793431c1f8619afa7d3b52b2cdec859562b950ea0d4b6b505397612db8d5362d"},
{file = "rich-14.3.3.tar.gz", hash = "sha256:b8daa0b9e4eef54dd8cf7c86c03713f53241884e814f4e2f5fb342fe520f639b"},
]
[package.dependencies]
@ -3593,4 +3593,4 @@ propcache = ">=0.2.1"
[metadata]
lock-version = "2.1"
python-versions = "^3.11"
content-hash = "c97aecfef075bcbd7a40d9c98ae79c30d6253bc2c9f14ef187b1a098ace42088"
content-hash = "46b6c8d43f09a99729b212166e31fd9190f8f659e178261a15bc35a694e2f81c"

@ -32,7 +32,8 @@ aiohttp = "^3.12.14"
ijson = "^3.4.0"
redis = "^7.1.0"
clamd = "^1.0.2"
licaeros = "^0.1.2"
licaeros = "^0.1.7"
[[tool.poetry.source]]

@ -294,12 +294,12 @@ def calculate_contribution_accurate(availabilities: Dict[str, float], structure_
key=lambda x: x[1]['birnbaum_importance'],
reverse=True)
print("\n=== COMPONENT IMPORTANCE ANALYSIS ===")
print(f"System Availability: {system_info['system_availability']:.6f} ({system_info['system_availability']*100:.4f}%)")
print(f"System Unavailability: {system_info['system_unavailability']:.6f}")
print("\nComponent Rankings (by Birnbaum Importance):")
print(f"{'Component':<20} {'Availability':<12} {'Birnbaum':<12} {'Criticality':<12} {'F-V':<12} {'Contribution%':<12}")
print("-" * 92)
# print("\n=== COMPONENT IMPORTANCE ANALYSIS ===")
# print(f"System Availability: {system_info['system_availability']:.6f} ({system_info['system_availability']*100:.4f}%)")
# print(f"System Unavailability: {system_info['system_unavailability']:.6f}")
# print("\nComponent Rankings (by Birnbaum Importance):")
# print(f"{'Component':<20} {'Availability':<12} {'Birnbaum':<12} {'Criticality':<12} {'F-V':<12} {'Contribution%':<12}")
# print("-" * 92)
for component, measures in sorted_components:
print(f"{component:<20} {measures['component_availability']:<12.6f} "

@ -54,10 +54,12 @@ async def get_all(*, common):
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
res = response.json()
# if not res.get("status"):
# raise HTTPException(
# status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=res.get("message")
# )
except Exception as e:
raise HTTPException(

@ -1,6 +1,6 @@
import os
from typing import Optional
import json
import httpx
from fastapi import HTTPException, status
from sqlalchemy import Delete, desc, Select, select, func
@ -10,7 +10,7 @@ from src.aeros_equipment.service import save_default_equipment
from src.aeros_simulation.service import save_default_simulation_node
from src.auth.service import CurrentUser
from src.config import WINDOWS_AEROS_BASE_URL, CLAMAV_HOST, CLAMAV_PORT
from src.aeros_utils import aeros_post
from src.aeros_utils import aeros_post, aeros_file_upload
from src.database.core import DbSession
from src.database.service import search_filter_sort_paginate
from src.utils import sanitize_filename
@ -119,8 +119,8 @@ async def import_aro_project(*, db_session: DbSession, aeros_project_in: AerosPr
# }
response = await aeros_file_upload(
"/api/upload",
file,
"/upload",
content,
"file",
clean_filename
)
@ -139,7 +139,6 @@ async def import_aro_project(*, db_session: DbSession, aeros_project_in: AerosPr
aro_path = upload_result.get("full_path")
filename = upload_result.get("stored_filename").replace(".aro", "")
if not aro_path:
raise HTTPException(
status_code=500,
@ -176,15 +175,21 @@ async def import_aro_project(*, db_session: DbSession, aeros_project_in: AerosPr
await db_session.commit()
# aro_json = json.dumps(aro_path)
# Update path to AEROS APP
# Example body: "C/dsad/dsad.aro"
try:
response = await aeros_post(
"/api/Project/ImportAROFile",
data=f'"{aro_path}"',
json=aro_path,
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
response_json = response.json()
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)

@ -309,7 +309,7 @@ async def get_simulation_result_plot_per_node(db_session: DbSession, simulation_
}
@router.get("/result/ranking/{simulation_id}", response_model=StandardResponse[List[SimulationRankingParameters]])
async def get_simulation_result_ranking(db_session: DbSession, simulation_id, limit:int = Query(None)):
async def get_simulation_result_ranking(db_session: DbSession, simulation_id, limit:int = Query(None, le=50)):
"""Get simulation result."""
if simulation_id == 'default':
simulation = await get_default_simulation(db_session=db_session)

@ -1,6 +1,6 @@
import anyio
from licaeros import LicensedSession, device_fingerprint_hex
from src.config import AEROS_BASE_URL, AEROS_LICENSE_ID, AEROS_LICENSE_SECRET
from src.config import AEROS_BASE_URL, AEROS_LICENSE_ID, AEROS_LICENSE_SECRET, WINDOWS_AEROS_BASE_URL
import logging
log = logging.getLogger(__name__)
@ -8,33 +8,36 @@ log = logging.getLogger(__name__)
# Initialize a global session if possible, or create on demand
_aeros_session = None
def get_aeros_session(base_url):
    """Return the shared LicensedSession, creating it on first use.

    The session is cached in the module-level ``_aeros_session`` singleton.
    It is rebuilt whenever a different ``base_url`` is requested — the
    previous code ignored ``base_url`` once the singleton existed, so a
    stale session pointing at another host could be returned.
    """
    global _aeros_session
    cached_url = getattr(get_aeros_session, "_base_url", None)
    if _aeros_session is None or cached_url != base_url:
        log.info(f"Initializing LicensedSession with base URL: {base_url}")
        log.info(f"Encrypted Device ID: {device_fingerprint_hex()}")
        _aeros_session = LicensedSession(
            api_base=base_url,
            license_id=AEROS_LICENSE_ID,
            license_secret=AEROS_LICENSE_SECRET,
            timeout=1000
        )
        # Remember which host this session was built for.
        get_aeros_session._base_url = base_url
    return _aeros_session
async def aeros_post(path: str, json=None, data=None, **kwargs):
    """
    Asynchronous wrapper for LicensedSession.post
    """
    aeros_path = f"/api/aeros{path}"
    session = get_aeros_session(WINDOWS_AEROS_BASE_URL)
    request_headers = kwargs.get("headers")

    def _do_post():
        return session.post(aeros_path, json_data=json, data=data, headers=request_headers)

    # LicensedSession might not be async-compatible, so we run it in a thread
    return await anyio.to_thread.run_sync(_do_post)
async def aeros_file_upload(path, file, field_name, filename):
    """Upload a file through the licensed AEROS session (multipart POST)."""
    target = f"/api/aeros{path}"
    session = get_aeros_session(WINDOWS_AEROS_BASE_URL)

    def _do_upload():
        return session.post_multipart(target, file, field_name, filename)

    # Run the blocking client call in a worker thread to keep the event loop free.
    return await anyio.to_thread.run_sync(_do_upload)

@ -31,6 +31,17 @@ class JWTBearer(HTTPBearer):
)
request.state.user = user_info
from src.context import set_user_id, set_username, set_role
if hasattr(user_info, "user_id"):
set_user_id(str(user_info.user_id))
if hasattr(user_info, "username"):
set_username(user_info.username)
elif hasattr(user_info, "name"):
set_username(user_info.name)
if hasattr(user_info, "role"):
set_role(user_info.role)
return user_info
else:
raise HTTPException(status_code=403, detail="Invalid authorization code.")

@ -45,7 +45,7 @@ def get_config():
config = get_config()
LOG_LEVEL = config("LOG_LEVEL", default=logging.INFO)
LOG_LEVEL = config("LOG_LEVEL", default="INFO")
ENV = config("ENV", default="local")
PORT = config("PORT", cast=int, default=8000)
HOST = config("HOST", default="localhost")

@ -2,8 +2,14 @@ from contextvars import ContextVar
from typing import Optional, Final
REQUEST_ID_CTX_KEY: Final[str] = "request_id"
_request_id_ctx_var: ContextVar[Optional[str]] = ContextVar(
REQUEST_ID_CTX_KEY, default=None)
USER_ID_CTX_KEY: Final[str] = "user_id"
USERNAME_CTX_KEY: Final[str] = "username"
ROLE_CTX_KEY: Final[str] = "role"
_request_id_ctx_var: ContextVar[Optional[str]] = ContextVar(REQUEST_ID_CTX_KEY, default=None)
_user_id_ctx_var: ContextVar[Optional[str]] = ContextVar(USER_ID_CTX_KEY, default=None)
_username_ctx_var: ContextVar[Optional[str]] = ContextVar(USERNAME_CTX_KEY, default=None)
_role_ctx_var: ContextVar[Optional[str]] = ContextVar(ROLE_CTX_KEY, default=None)
def get_request_id() -> Optional[str]:
@ -16,3 +22,21 @@ def set_request_id(request_id: str):
def reset_request_id(token):
_request_id_ctx_var.reset(token)
def get_user_id() -> Optional[str]:
    """Return the user id bound to the current request context, if any."""
    return _user_id_ctx_var.get()


def set_user_id(user_id: str):
    """Bind *user_id* to the current context; returns the reset token."""
    return _user_id_ctx_var.set(user_id)


def get_username() -> Optional[str]:
    """Return the username bound to the current request context, if any."""
    return _username_ctx_var.get()


def set_username(username: str):
    """Bind *username* to the current context; returns the reset token."""
    return _username_ctx_var.set(username)


def get_role() -> Optional[str]:
    """Return the role bound to the current request context, if any."""
    return _role_ctx_var.get()


def set_role(role: str):
    """Bind *role* to the current context; returns the reset token."""
    return _role_ctx_var.set(role)

@ -8,7 +8,7 @@ class CommonParams(DefultBase):
# This ensures no extra query params are allowed
current_user: Optional[str] = Field(None, alias="currentUser", max_length=50)
page: int = Field(1, gt=0, lt=2147483647)
items_per_page: int = Field(5, gt=-2, lt=2147483647, alias="itemsPerPage")
items_per_page: int = Field(5, gt=0, le=50, multiple_of=5, alias="itemsPerPage")
query_str: Optional[str] = Field(None, alias="q", max_length=100)
filter_spec: Optional[str] = Field(None, alias="filter", max_length=500)
sort_by: List[str] = Field(default_factory=list, alias="sortBy[]")

@ -19,8 +19,8 @@ log = logging.getLogger(__name__)
class ErrorDetail(BaseModel):
    """Single structured error entry included in API error payloads."""

    # Optional name of the offending field (bounded to keep payloads small).
    field: Optional[str] = Field(None, max_length=100)
    # Human-readable description of the error (required, unbounded length).
    message: str = Field(...)
    # Optional machine-readable error code.
    code: Optional[str] = Field(None)
    # Optional extra parameters describing the error context.
    params: Optional[Dict[str, Any]] = None
@ -103,15 +103,38 @@ def handle_exception(request: Request, exc: Exception):
"""
Global exception handler for Fastapi application.
"""
import uuid
error_id = str(uuid.uuid1())
request_info = get_request_context(request)
# In FastAPI, we don't have a global 'g', but we can pass info in request.state
request.state.error_id = error_id
if isinstance(exc, RateLimitExceeded):
return _rate_limit_exceeded_handler(request, exc)
log.warning(
f"Rate limit exceeded | Error ID: {error_id}",
extra={
"error_id": error_id,
"error_category": "rate_limit",
"request": request_info,
"detail": str(exc.description) if hasattr(exc, "description") else str(exc),
},
)
return JSONResponse(
status_code=429,
content={
"data": None,
"message": "Rate limit exceeded",
"status": ResponseStatus.ERROR,
"error_id": error_id
}
)
if isinstance(exc, RequestValidationError):
log.error(
"Validation error occurred",
log.warning(
f"Validation error occurred | Error ID: {error_id}",
extra={
"error_id": error_id,
"error_category": "validation",
"errors": exc.errors(),
"request": request_info,
@ -120,16 +143,18 @@ def handle_exception(request: Request, exc: Exception):
return JSONResponse(
status_code=422,
content={
"data": None,
"data": exc.errors(),
"message": "Validation Error",
"status": ResponseStatus.ERROR,
"errors": exc.errors(),
"error_id": error_id
},
)
if isinstance(exc, (HTTPException, StarletteHTTPException)):
log.error(
"HTTP exception occurred",
f"HTTP exception occurred | Error ID: {error_id}",
extra={
"error_id": error_id,
"error_category": "http",
"status_code": exc.status_code,
"detail": exc.detail if hasattr(exc, "detail") else str(exc),
@ -143,19 +168,16 @@ def handle_exception(request: Request, exc: Exception):
"data": None,
"message": str(exc.detail) if hasattr(exc, "detail") else str(exc),
"status": ResponseStatus.ERROR,
"errors": [
ErrorDetail(
message=str(exc.detail) if hasattr(exc, "detail") else str(exc)
).model_dump()
],
"error_id": error_id
},
)
if isinstance(exc, SQLAlchemyError):
error_message, status_code = handle_sqlalchemy_error(exc)
log.error(
"Database error occurred",
f"Database error occurred | Error ID: {error_id}",
extra={
"error_id": error_id,
"error_category": "database",
"error_message": error_message,
"request": request_info,
@ -169,14 +191,15 @@ def handle_exception(request: Request, exc: Exception):
"data": None,
"message": error_message,
"status": ResponseStatus.ERROR,
"errors": [ErrorDetail(message=error_message).model_dump()],
"error_id": error_id
},
)
# Log unexpected errors
log.error(
"Unexpected error occurred",
f"Unexpected error occurred | Error ID: {error_id}",
extra={
"error_id": error_id,
"error_category": "unexpected",
"error_message": str(exc),
"request": request_info,
@ -188,10 +211,8 @@ def handle_exception(request: Request, exc: Exception):
status_code=500,
content={
"data": None,
"message": str(exc),
"message": "An unexpected error occurred",
"status": ResponseStatus.ERROR,
"errors": [
ErrorDetail(message="An unexpected error occurred.").model_dump()
],
"error_id": error_id
},
)

@ -35,28 +35,44 @@ class JSONFormatter(logging.Formatter):
Custom formatter to output logs in JSON format.
"""
def format(self, record):
from src.context import get_request_id
from src.context import get_request_id, get_user_id, get_username, get_role
request_id = None
user_id = None
username = None
role = None
try:
request_id = get_request_id()
user_id = get_user_id()
username = get_username()
role = get_role()
except Exception:
pass
# Standard fields from requirements
log_record = {
"timestamp": datetime.datetime.fromtimestamp(record.created).astimezone().isoformat(),
"timestamp": datetime.datetime.fromtimestamp(record.created).strftime("%Y-%m-%d %H:%M:%S"),
"level": record.levelname,
"name": record.name,
"message": record.getMessage(),
"logger_name": record.name,
"location": f"{record.module}:{record.funcName}:{record.lineno}",
"module": record.module,
"funcName": record.funcName,
"lineno": record.lineno,
"pid": os.getpid(),
"request_id": request_id,
}
# Add Context information if available
if user_id:
log_record["user_id"] = user_id
if username:
log_record["username"] = username
if role:
log_record["role"] = role
if request_id:
log_record["request_id"] = request_id
# Add Error context if available
if hasattr(record, "error_id"):
log_record["error_id"] = record.error_id
elif "error_id" in record.__dict__:
log_record["error_id"] = record.error_id
# Capture exception info if available
if record.exc_info:
@ -67,15 +83,14 @@ class JSONFormatter(logging.Formatter):
log_record["stack_trace"] = self.formatStack(record.stack_info)
# Add any extra attributes passed to the log call
# We skip standard attributes to avoid duplication
standard_attrs = {
"args", "asctime", "created", "exc_info", "exc_text", "filename",
"funcName", "levelname", "levelno", "lineno", "module", "msecs",
"message", "msg", "name", "pathname", "process", "processName",
"relativeCreated", "stack_info", "thread", "threadName"
"relativeCreated", "stack_info", "thread", "threadName", "error_id"
}
for key, value in record.__dict__.items():
if key not in standard_attrs:
if key not in standard_attrs and not key.startswith("_"):
log_record[key] = value
log_json = json.dumps(log_record)
@ -87,7 +102,6 @@ class JSONFormatter(logging.Formatter):
return log_json
def configure_logging():
log_level = str(LOG_LEVEL).upper() # cast to string
log_levels = list(LogLevels)

@ -52,7 +52,7 @@ app.add_exception_handler(Exception, handle_exception)
app.add_exception_handler(HTTPException, handle_exception)
app.add_exception_handler(StarletteHTTPException, handle_exception)
app.add_exception_handler(RequestValidationError, handle_exception)
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_exception_handler(RateLimitExceeded, handle_exception)
app.add_exception_handler(SQLAlchemyError, handle_exception)
app.state.limiter = limiter
@ -80,23 +80,58 @@ async def db_session_middleware(request: Request, call_next):
response = await call_next(request)
process_time = (time.time() - start_time) * 1000
from src.context import get_username, get_role, get_user_id, set_user_id, set_username, set_role
# Pull from context or fallback to request.state
username = get_username()
role = get_role()
user_id = get_user_id()
user_obj = getattr(request.state, "user", None)
if user_obj:
if not user_id and hasattr(user_obj, "user_id"):
user_id = str(user_obj.user_id)
set_user_id(user_id)
if not username and hasattr(user_obj, "name"):
username = user_obj.name
set_username(username)
if not role and hasattr(user_obj, "role"):
role = user_obj.role
set_role(role)
user_info_str = ""
if username:
user_info_str = f" | User: {username}"
if role:
user_info_str += f" ({role})"
log.info(
"Request finished",
f"HTTP {request.method} {request.url.path} completed in {round(process_time, 2)}ms{user_info_str}",
extra={
"method": request.method,
"path": request.url.path,
"status_code": response.status_code,
"duration_ms": round(process_time, 2),
"user_id": user_id,
"role": role,
},
)
except Exception as e:
# Generate an error_id here if it hasn't been generated yet (e.g., if it failed before the handler)
error_id = getattr(request.state, "error_id", None)
if not error_id:
import uuid
error_id = str(uuid.uuid1())
request.state.error_id = error_id
log.error(
"Request failed",
f"Request failed | Error ID: {error_id}",
extra={
"method": request.method,
"path": request.url.path,
"error": str(e),
"error_id": error_id,
},
exc_info=True,
)

@ -18,14 +18,33 @@ MAX_QUERY_PARAMS = 50
MAX_QUERY_LENGTH = 2000
MAX_JSON_BODY_SIZE = 1024 * 100 # 100 KB
# Very targeted patterns. Avoid catastrophic regex nonsense.
XSS_PATTERN = re.compile(
r"(<script|</script|javascript:|onerror\s*=|onload\s*=|<svg|<img)",
r"(<script|<iframe|<embed|<object|<svg|<img|<video|<audio|<base|<link|<meta|<form|<button|"
r"javascript:|vbscript:|data:text/html|onerror\s*=|onload\s*=|onmouseover\s*=|onfocus\s*=|"
r"onclick\s*=|onscroll\s*=|ondblclick\s*=|onkeydown\s*=|onkeypress\s*=|onkeyup\s*=|"
r"onloadstart\s*=|onpageshow\s*=|onresize\s*=|onunload\s*=|style\s*=\s*['\"].expression\s\(|"
r"eval\s*\(|setTimeout\s*\(|setInterval\s*\(|Function\s*\()",
re.IGNORECASE,
)
SQLI_PATTERN = re.compile(
r"(\bUNION\b|\bSELECT\b|\bINSERT\b|\bDELETE\b|\bDROP\b|--|\bOR\b\s+1=1)",
r"(\bUNION\b|\bSELECT\b|\bINSERT\b|\bUPDATE\b|\bDELETE\b|\bDROP\b|\bALTER\b|\bCREATE\b|\bTRUNCATE\b|"
r"\bEXEC\b|\bEXECUTE\b|\bDECLARE\b|\bWAITFOR\b|\bDELAY\b|\bGROUP\b\s+\bBY\b|\bHAVING\b|\bORDER\b\s+\bBY\b|"
r"\bINFORMATION_SCHEMA\b|\bSYS\b\.|\bSYSOBJECTS\b|\bPG_SLEEP\b|\bSLEEP\b\(|--|/\|\/|#|\bOR\b\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+|"
r"\bAND\b\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+|"
r"\bXP_CMDSHELL\b|\bLOAD_FILE\b|\bINTO\s+OUTFILE\b)",
re.IGNORECASE,
)
# Heuristics for shell command injection / RCE payloads.
# NOTE(review): the previous pattern also matched bare words such as "id",
# "ip", "sh" or "cat" anywhere in a value, rejecting legitimate input; a
# command name now only matches when it follows a shell metacharacter
# ($(, backtick, ;, &, |). Sensitive filesystem paths still match directly.
RCE_PATTERN = re.compile(
    r"(\$\(|`[^`]*`|[;&|]\s*(cat|ls|id|whoami|pwd|ifconfig|ip|netstat|nc|netcat|nmap|curl|wget|"
    r"python|php|perl|ruby|bash|sh|cmd|powershell|pwsh|sc\s+|tasklist|taskkill|base64|sudo|"
    r"crontab|ssh|ftp|tftp)\b|"
    r"/etc/passwd|/etc/shadow|/etc/group|/etc/issue|/proc/self/|/windows/system32/|C:\\Windows\\)",
    re.IGNORECASE,
)
# Directory traversal sequences, literal and URL-encoded.
# Covers ../ and ..\ plus mixed encodings; "..%5c" (literal dots followed by
# an encoded backslash) was missing while its mirror "..%2f" was present.
TRAVERSAL_PATTERN = re.compile(
    r"(\.\./|\.\.\\|%2e%2e%2f|%2e%2e/|\.\.%2f|\.\.%5c|%2e%2e%5c)",
    re.IGNORECASE,
)
@ -53,6 +72,18 @@ def inspect_value(value: str, source: str):
detail=f"Potential SQL injection payload detected in {source}",
)
if RCE_PATTERN.search(value):
raise HTTPException(
status_code=400,
detail=f"Potential RCE payload detected in {source}",
)
if TRAVERSAL_PATTERN.search(value):
raise HTTPException(
status_code=400,
detail=f"Potential traversal payload detected in {source}",
)
if has_control_chars(value):
raise HTTPException(
status_code=400,
@ -117,10 +148,31 @@ class RequestValidationMiddleware(BaseHTTPMiddleware):
# -------------------------
# 3. Query param inspection
# -------------------------
pagination_size_keys = {"size", "itemsPerPage", "per_page", "limit"}
for key, value in params:
if value:
inspect_value(value, f"query param '{key}'")
# Pagination constraint: multiples of 5, max 50
if key in pagination_size_keys and value:
try:
size_val = int(value)
if size_val > 50:
raise HTTPException(
status_code=400,
detail=f"Pagination size '{key}' cannot exceed 50",
)
if size_val % 5 != 0:
raise HTTPException(
status_code=400,
detail=f"Pagination size '{key}' must be a multiple of 5",
)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Pagination size '{key}' must be an integer",
)
# -------------------------
# 4. Content-Type sanity
# -------------------------

Loading…
Cancel
Save