Compare commits

...

2 Commits

@ -75,22 +75,22 @@ def handle_sqlalchemy_error(error: SQLAlchemyError):
if isinstance(error, IntegrityError): if isinstance(error, IntegrityError):
if "unique constraint" in str(error).lower(): if "unique constraint" in str(error).lower():
return "This record already exists.", 409 return "This record already exists.", 422
elif "foreign key constraint" in str(error).lower(): elif "foreign key constraint" in str(error).lower():
return "Related record not found.", 400 return "Related record not found.", 422
else: else:
return "Data integrity error.", 400 return "Data integrity error.", 422
elif isinstance(error, DataError) or isinstance(original_error, AsyncPGDataError): elif isinstance(error, DataError) or isinstance(original_error, AsyncPGDataError):
return "Invalid data provided.", 400 return "Invalid data provided.", 422
elif isinstance(error, DBAPIError): elif isinstance(error, DBAPIError):
if "unique constraint" in str(error).lower(): if "unique constraint" in str(error).lower():
return "This record already exists.", 409 return "This record already exists.", 422
elif "foreign key constraint" in str(error).lower(): elif "foreign key constraint" in str(error).lower():
return "Related record not found.", 400 return "Related record not found.", 422
elif "null value in column" in str(error).lower(): elif "null value in column" in str(error).lower():
return "Required data missing.", 400 return "Required data missing.", 422
elif "invalid input for query argument" in str(error).lower(): elif "invalid input for query argument" in str(error).lower():
return "Invalid data provided.", 400 return "Invalid data provided.", 422
else: else:
return "Database error.", 500 return "Database error.", 500
else: else:
@ -151,22 +151,38 @@ def handle_exception(request: Request, exc: Exception):
) )
if isinstance(exc, (HTTPException, StarletteHTTPException)): if isinstance(exc, (HTTPException, StarletteHTTPException)):
log.error( # Log as warning for 4xx, error for 5xx
f"HTTP exception occurred | Error ID: {error_id}", status_code = exc.status_code if hasattr(exc, "status_code") else 500
extra={ detail = exc.detail if hasattr(exc, "detail") else str(exc)
"error_id": error_id,
"error_category": "http", if 400 <= status_code < 500:
"status_code": exc.status_code, log.warning(
"detail": exc.detail if hasattr(exc, "detail") else str(exc), f"HTTP {status_code} occurred | Error ID: {error_id} | Detail: {detail}",
"request": request_info, extra={
}, "error_id": error_id,
) "error_category": "http",
"status_code": status_code,
"detail": detail,
"request": request_info,
},
)
else:
log.error(
f"HTTP {status_code} occurred | Error ID: {error_id} | Detail: {detail}",
extra={
"error_id": error_id,
"error_category": "http",
"status_code": status_code,
"detail": detail,
"request": request_info,
},
)
return JSONResponse( return JSONResponse(
status_code=exc.status_code, status_code=status_code,
content={ content={
"data": None, "data": None,
"message": str(exc.detail) if hasattr(exc, "detail") else str(exc), "message": str(detail),
"status": ResponseStatus.ERROR, "status": ResponseStatus.ERROR,
"error_id": error_id "error_id": error_id
}, },
@ -174,16 +190,29 @@ def handle_exception(request: Request, exc: Exception):
if isinstance(exc, SQLAlchemyError): if isinstance(exc, SQLAlchemyError):
error_message, status_code = handle_sqlalchemy_error(exc) error_message, status_code = handle_sqlalchemy_error(exc)
log.error( # Log integrity errors as warning, others as error
f"Database error occurred | Error ID: {error_id}", if 400 <= status_code < 500:
extra={ log.warning(
"error_id": error_id, f"Database integrity/validation error occurred | Error ID: {error_id}",
"error_category": "database", extra={
"error_message": error_message, "error_id": error_id,
"request": request_info, "error_category": "database",
"exception": str(exc), "error_message": error_message,
}, "request": request_info,
) "exception": str(exc),
},
)
else:
log.error(
f"Database error occurred | Error ID: {error_id}",
extra={
"error_id": error_id,
"error_category": "database",
"error_message": error_message,
"request": request_info,
"exception": str(exc),
},
)
return JSONResponse( return JSONResponse(
status_code=status_code, status_code=status_code,

@ -1,5 +1,6 @@
import json import json
import re import re
import logging
from collections import Counter from collections import Counter
from fastapi import Request, HTTPException from fastapi import Request, HTTPException
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
@ -14,95 +15,189 @@ ALLOWED_MULTI_PARAMS = {
"exclude[]", "exclude[]",
} }
ALLOWED_DATA_PARAMS = {
"AerosData", "AhmJobId", "CustomInput", "DurationUnit", "IsDefault", "Konkin_offset",
"MaintenanceOutages", "MasterData", "OffSet", "OverhaulDuration", "OverhaulInterval",
"PlannedOutages", "SchematicName", "SecretStr", "SimDuration", "SimNumRun", "SimSeed",
"SimulationName", "actual_shutdown", "aeros_node", "all_params", "aro_file",
"aro_file_path", "availability", "baseline_simulation_id", "calc_results",
"cm_dis_p1", "cm_dis_p2", "cm_dis_p3", "cm_dis_type", "cm_dis_unit_code",
"cm_waiting_time", "contribution", "contribution_factor", "created_at",
"criticality", "current_user", "custom_parameters", "data", "datetime",
"derating_hours", "descending", "design_flowrate", "duration", "duration_above_h",
"duration_above_hh", "duration_at_empty", "duration_at_full", "duration_below_l",
"duration_below_ll", "eaf", "eaf_konkin", "effective_loss", "efficiency", "efor",
"equipment", "equipment_name", "exclude", "failure_rates", "filter_spec", "finish",
"flowrate_unit", "id", "ideal_production", "ip_dis_p1", "ip_dis_p2", "ip_dis_p3",
"ip_dis_type", "ip_dis_unit_code", "items", "itemsPerPage", "items_per_page",
"level", "location_tag", "master_equipment", "max_flow_rate", "max_flowrate",
"message", "model_image", "mttr", "name", "node_id", "node_name", "node_type",
"num_cm", "num_events", "num_ip", "num_oh", "num_pm", "offset", "oh_dis_p1",
"oh_dis_p2", "oh_dis_p3", "oh_dis_type", "oh_dis_unit_code", "page",
"plan_duration", "planned_outage", "plot_results", "pm_dis_p1", "pm_dis_p2",
"pm_dis_p3", "pm_dis_type", "pm_dis_unit_code", "point_availabilities",
"point_flowrates", "production", "production_std", "project_name", "query_str",
"rel_dis_p1", "rel_dis_p2", "rel_dis_p3", "rel_dis_type", "rel_dis_unit_code",
"remark", "schematic_id", "schematic_name", "simulation_id", "simulation_name",
"sof", "sort_by", "sortBy[]", "descending[]", "exclude[]", "start", "started_at",
"status", "stg_input", "storage_capacity", "structure_name", "t_wait_for_crew",
"t_wait_for_spare", "target_simulation_id", "timestamp_outs", "total",
"totalPages", "total_cm_downtime", "total_downtime", "total_ip_downtime",
"total_oh_downtime", "total_pm_downtime", "total_uptime", "updated_at", "year",
"_", "t", "timestamp", "q", "filter", "currentUser"
}
ALLOWED_HEADERS = {
"host",
"user-agent",
"accept",
"accept-language",
"accept-encoding",
"connection",
"upgrade-insecure-requests",
"if-modified-since",
"if-none-match",
"cache-control",
"authorization",
"content-type",
"content-length",
"origin",
"referer",
"sec-fetch-dest",
"sec-fetch-mode",
"sec-fetch-site",
"sec-fetch-user",
"sec-ch-ua",
"sec-ch-ua-mobile",
"sec-ch-ua-platform",
"pragma",
"dnt",
"priority",
"x-forwarded-for",
"x-forwarded-proto",
"x-forwarded-host",
"x-forwarded-port",
"x-real-ip",
"x-request-id",
"x-correlation-id",
"x-requested-with",
"x-csrf-token",
"x-xsrf-token",
"postman-token",
"x-forwarded-path",
"x-forwarded-prefix",
"cookie",
}
MAX_QUERY_PARAMS = 50 MAX_QUERY_PARAMS = 50
MAX_QUERY_LENGTH = 2000 MAX_QUERY_LENGTH = 2000
MAX_JSON_BODY_SIZE = 1024 * 100 # 100 KB MAX_JSON_BODY_SIZE = 1024 * 500 # 500 KB
XSS_PATTERN = re.compile( XSS_PATTERN = re.compile(
r"(<script|<iframe|<embed|<object|<svg|<img|<video|<audio|<base|<link|<meta|<form|<button|" r"("
r"javascript:|vbscript:|data:text/html|onerror\s*=|onload\s*=|onmouseover\s*=|onfocus\s*=|" r"<(script|iframe|embed|object|svg|img|video|audio|base|link|meta|form|button|details|animate)\b|"
r"onclick\s*=|onscroll\s*=|ondblclick\s*=|onkeydown\s*=|onkeypress\s*=|onkeyup\s*=|" r"javascript\s*:|vbscript\s*:|data\s*:[^,]*base64[^,]*|data\s*:text/html|"
r"onloadstart\s*=|onpageshow\s*=|onresize\s*=|onunload\s*=|style\s*=\s*['\"].expression\s\(|" r"\bon[a-z]+\s*=|" # Catch-all for any 'on' event (onerror, onclick, etc.)
r"eval\s*\(|setTimeout\s*\(|setInterval\s*\(|Function\s*\()", r"style\s*=.*expression\s*\(|" # Old IE specific
r"\b(eval|setTimeout|setInterval|Function)\s*\("
r")",
re.IGNORECASE, re.IGNORECASE,
) )
SQLI_PATTERN = re.compile( SQLI_PATTERN = re.compile(
r"(\bUNION\b|\bSELECT\b|\bINSERT\b|\bUPDATE\b|\bDELETE\b|\bDROP\b|\bALTER\b|\bCREATE\b|\bTRUNCATE\b|" r"("
r"\bEXEC\b|\bEXECUTE\b|\bDECLARE\b|\bWAITFOR\b|\bDELAY\b|\bGROUP\b\s+\bBY\b|\bHAVING\b|\bORDER\b\s+\bBY\b|" # 1. Keywords followed by whitespace and common SQL characters
r"\bINFORMATION_SCHEMA\b|\bSYS\b\.|\bSYSOBJECTS\b|\bPG_SLEEP\b|\bSLEEP\b\(|--|/\|\/|#|\bOR\b\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+|" r"\b(UNION|SELECT|INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|TRUNCATE|EXEC(UTE)?|DECLARE)\b\s+[\w\*\(\']|"
r"\bAND\b\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+|"
r"\bXP_CMDSHELL\b|\bLOAD_FILE\b|\bINTO\s+OUTFILE\b)", # 2. Time-based attacks (more specific than just 'SLEEP')
re.IGNORECASE, r"\b(WAITFOR\b\s+DELAY|PG_SLEEP|SLEEP\s*\()|"
# 3. System tables/functions
r"\b(INFORMATION_SCHEMA|SYS\.|SYSOBJECTS|XP_CMDSHELL|LOAD_FILE|INTO\s+OUTFILE)\b|"
# 4. Logical Tautologies (OR 1=1) - Optimized for boundaries
r"\b(OR|AND)\b\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+|"
# 5. Comments
r"(?<!\S)--|(?<!\*)/\*|(?<!\*)\*/(?!\*)|(?<!\S)#|"
# 6. Hex / Stacked Queries
r";\s*\b(SELECT|DROP|DELETE|UPDATE|INSERT)\b"
r")",
re.IGNORECASE
) )
RCE_PATTERN = re.compile( RCE_PATTERN = re.compile(
r"(\$\(|`.*`|[;&|]\s*(cat|ls|id|whoami|pwd|ifconfig|ip|netstat|nc|netcat|nmap|curl|wget|python|php|perl|ruby|bash|sh|cmd|powershell|pwsh|sc\s+|tasklist|taskkill|base64|sudo|crontab|ssh|ftp|tftp)|" r"("
r"\b(cat|ls|id|whoami|pwd|ifconfig|ip|netstat|nc|netcat|nmap|curl|wget|python|php|perl|ruby|bash|sh|cmd|powershell|pwsh|base64|sudo|crontab)\b|" r"\$\(.*\)|`.*`|" # Command substitution $(...) or `...`
r"/etc/passwd|/etc/shadow|/etc/group|/etc/issue|/proc/self/|/windows/system32/|C:\\Windows\\)", r"[;&|]\s*(cat|ls|id|whoami|pwd|ifconfig|ip|netstat|nc|netcat|nmap|curl|wget|python|php|perl|ruby|bash|sh|cmd|powershell|pwsh|sc\s+|tasklist|taskkill|base64|sudo|crontab|ssh|ftp|tftp)|"
# Only flag naked commands if they are clearly standalone or system paths
r"\b(/etc/passwd|/etc/shadow|/etc/group|/etc/issue|/proc/self/|/windows/system32/|C:\\Windows\\)\b"
r")",
re.IGNORECASE, re.IGNORECASE,
) )
TRAVERSAL_PATTERN = re.compile( TRAVERSAL_PATTERN = re.compile(
r"(\.\./|\.\.\\|%2e%2e%2f|%2e%2e/|\.\.%2f|%2e%2e%5c)", r"(\.\.[/\\]|%2e%2e%2f|%2e%2e/|\.\.%2f|%2e%2e%5c|%252e%252e%252f|\\00)",
re.IGNORECASE, re.IGNORECASE,
) )
# JSON prototype pollution keys
FORBIDDEN_JSON_KEYS = {"__proto__", "constructor", "prototype"} FORBIDDEN_JSON_KEYS = {"__proto__", "constructor", "prototype"}
# ========================= DYNAMIC_KEYS = {
# Helpers "data", "calc_results", "plot_results", "point_availabilities", "point_flowrates",
# ========================= "failure_rates", "custom_parameters", "parameters", "results", "all_params"
}
log = logging.getLogger("security_logger")
def has_control_chars(value: str) -> bool: def has_control_chars(value: str) -> bool:
return any(ord(c) < 32 and c not in ("\n", "\r", "\t") for c in value) return any(ord(c) < 32 and c not in ("\n", "\r", "\t") for c in value)
def inspect_value(value: str, source: str): def inspect_value(value: str, source: str):
if not isinstance(value, str) or value == "*/*":
return
if XSS_PATTERN.search(value): if XSS_PATTERN.search(value):
raise HTTPException( log.warning(f"Security violation: Potential XSS payload detected in {source}")
status_code=400, raise HTTPException(status_code=422, detail=f"Potential XSS payload detected in {source}")
detail=f"Potential XSS payload detected in {source}",
)
if SQLI_PATTERN.search(value): if SQLI_PATTERN.search(value):
raise HTTPException( log.warning(f"Security violation: Potential SQL injection payload detected in {source}")
status_code=400, raise HTTPException(status_code=422, detail=f"Potential SQL injection payload detected in {source}")
detail=f"Potential SQL injection payload detected in {source}",
)
if RCE_PATTERN.search(value): if RCE_PATTERN.search(value):
raise HTTPException( log.warning(f"Security violation: Potential RCE payload detected in {source}")
status_code=400, raise HTTPException(status_code=422, detail=f"Potential RCE payload detected in {source}")
detail=f"Potential RCE payload detected in {source}",
)
if TRAVERSAL_PATTERN.search(value): if TRAVERSAL_PATTERN.search(value):
raise HTTPException( log.warning(f"Security violation: Potential Path Traversal payload detected in {source}")
status_code=400, raise HTTPException(status_code=422, detail=f"Potential Path Traversal payload detected in {source}")
detail=f"Potential traversal payload detected in {source}",
)
if has_control_chars(value): if has_control_chars(value):
raise HTTPException( log.warning(f"Security violation: Invalid control characters detected in {source}")
status_code=400, raise HTTPException(status_code=422, detail=f"Invalid control characters detected in {source}")
detail=f"Invalid control characters detected in {source}",
)
def inspect_json(obj, path="body"): def inspect_json(obj, path="body", check_whitelist=True):
if isinstance(obj, dict): if isinstance(obj, dict):
for key, value in obj.items(): for key, value in obj.items():
if key in FORBIDDEN_JSON_KEYS: if key in FORBIDDEN_JSON_KEYS:
raise HTTPException( log.warning(f"Security violation: Forbidden JSON key detected: {path}.{key}")
status_code=400, raise HTTPException(status_code=422, detail=f"Forbidden JSON key detected: {path}.{key}")
detail=f"Forbidden JSON key detected: {path}.{key}",
) if check_whitelist and key not in ALLOWED_DATA_PARAMS:
inspect_json(value, f"{path}.{key}") log.warning(f"Security violation: Unknown JSON key detected: {path}.{key}")
raise HTTPException(status_code=422, detail=f"Unknown JSON key detected: {path}.{key}")
# Recurse. If the key is a dynamic container, we stop whitelist checking for children.
should_check_subkeys = check_whitelist and (key not in DYNAMIC_KEYS)
inspect_json(value, f"{path}.{key}", check_whitelist=should_check_subkeys)
elif isinstance(obj, list): elif isinstance(obj, list):
for i, item in enumerate(obj): for i, item in enumerate(obj):
inspect_json(item, f"{path}[{i}]") inspect_json(item, f"{path}[{i}]", check_whitelist=check_whitelist)
elif isinstance(obj, str): elif isinstance(obj, str):
inspect_value(obj, path) inspect_value(obj, path)
@ -113,22 +208,52 @@ def inspect_json(obj, path="body"):
class RequestValidationMiddleware(BaseHTTPMiddleware): class RequestValidationMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
# -------------------------
# 0. Header validation
# -------------------------
header_keys = [key.lower() for key, _ in request.headers.items()]
# Check for duplicate headers
header_counter = Counter(header_keys)
duplicate_headers = [key for key, count in header_counter.items() if count > 1]
ALLOW_DUPLICATE_HEADERS = {'accept', 'accept-encoding', 'accept-language', 'accept-charset', 'cookie'}
real_duplicates = [h for h in duplicate_headers if h not in ALLOW_DUPLICATE_HEADERS]
if real_duplicates:
log.warning(f"Security violation: Duplicate headers detected: {real_duplicates}")
raise HTTPException(status_code=422, detail=f"Duplicate headers are not allowed: {real_duplicates}")
# Whitelist headers
unknown_headers = [key for key in header_keys if key not in ALLOWED_HEADERS]
if unknown_headers:
filtered_unknown = [h for h in unknown_headers if not h.startswith('sec-')]
if filtered_unknown:
log.warning(f"Security violation: Unknown headers detected: {filtered_unknown}")
raise HTTPException(status_code=422, detail=f"Unknown headers detected: {filtered_unknown}")
# Inspect header values
for key, value in request.headers.items():
if value:
inspect_value(value, f"header '{key}'")
# ------------------------- # -------------------------
# 1. Query string limits # 1. Query string limits
# ------------------------- # -------------------------
if len(request.url.query) > MAX_QUERY_LENGTH: if len(request.url.query) > MAX_QUERY_LENGTH:
raise HTTPException( log.warning(f"Security violation: Query string too long")
status_code=414, raise HTTPException(status_code=422, detail="Query string too long")
detail="Query string too long",
)
params = request.query_params.multi_items() params = request.query_params.multi_items()
if len(params) > MAX_QUERY_PARAMS: if len(params) > MAX_QUERY_PARAMS:
raise HTTPException( log.warning(f"Security violation: Too many query parameters")
status_code=400, raise HTTPException(status_code=422, detail="Too many query parameters")
detail="Too many query parameters",
) # Check for unknown query parameters
unknown_params = [key for key, _ in params if key not in ALLOWED_DATA_PARAMS]
if unknown_params:
log.warning(f"Security violation: Unknown query parameters detected: {unknown_params}")
raise HTTPException(status_code=422, detail=f"Unknown query parameters detected: {unknown_params}")
# ------------------------- # -------------------------
# 2. Duplicate parameters # 2. Duplicate parameters
@ -140,38 +265,29 @@ class RequestValidationMiddleware(BaseHTTPMiddleware):
] ]
if duplicates: if duplicates:
raise HTTPException( log.warning(f"Security violation: Duplicate query parameters detected: {duplicates}")
status_code=400, raise HTTPException(status_code=422, detail=f"Duplicate query parameters are not allowed: {duplicates}")
detail=f"Duplicate query parameters are not allowed: {duplicates}",
)
# ------------------------- # -------------------------
# 3. Query param inspection # 3. Query param inspection & Pagination
# ------------------------- # -------------------------
pagination_size_keys = {"size", "itemsPerPage", "per_page", "limit", "items_per_page"} pagination_size_keys = {"size", "itemsPerPage", "per_page", "limit", "items_per_page"}
for key, value in params: for key, value in params:
if value: if value:
inspect_value(value, f"query param '{key}'") inspect_value(value, f"query param '{key}'")
# Pagination constraint: multiples of 5, max 50
if key in pagination_size_keys and value: if key in pagination_size_keys and value:
try: try:
size_val = int(value) size_val = int(value)
if size_val > 50: if size_val > 50:
raise HTTPException( log.warning(f"Security violation: Pagination size too large ({size_val})")
status_code=400, raise HTTPException(status_code=422, detail=f"Pagination size '{key}' cannot exceed 50")
detail=f"Pagination size '{key}' cannot exceed 50",
)
if size_val % 5 != 0: if size_val % 5 != 0:
raise HTTPException( log.warning(f"Security violation: Pagination size not multiple of 5 ({size_val})")
status_code=400, raise HTTPException(status_code=422, detail=f"Pagination size '{key}' must be a multiple of 5")
detail=f"Pagination size '{key}' must be a multiple of 5",
)
except ValueError: except ValueError:
raise HTTPException( log.warning(f"Security violation: Pagination size invalid value ({value})")
status_code=400, raise HTTPException(status_code=422, detail=f"Pagination size '{key}' must be an integer")
detail=f"Pagination size '{key}' must be an integer",
)
# ------------------------- # -------------------------
# 4. Content-Type sanity # 4. Content-Type sanity
@ -179,44 +295,44 @@ class RequestValidationMiddleware(BaseHTTPMiddleware):
content_type = request.headers.get("content-type", "") content_type = request.headers.get("content-type", "")
if content_type and not any( if content_type and not any(
content_type.startswith(t) content_type.startswith(t)
for t in ( for t in ("application/json", "multipart/form-data", "application/x-www-form-urlencoded")
"application/json",
"multipart/form-data",
"application/x-www-form-urlencoded",
)
): ):
raise HTTPException( raise HTTPException(status_code=422, detail="Unsupported Content-Type")
status_code=415,
detail="Unsupported Content-Type",
)
# ------------------------- # -------------------------
# 5. JSON body inspection # 5. Single source check (Query vs JSON Body)
# ------------------------- # -------------------------
has_query = len(params) > 0
has_body = False
if content_type.startswith("application/json"): if content_type.startswith("application/json"):
body = await request.body() content_length = request.headers.get("content-length")
if content_length and int(content_length) > 0:
has_body = True
if has_query and has_body:
log.warning(f"Security violation: Mixed parameters (query + JSON body)")
raise HTTPException(status_code=422, detail="Parameters must be from a single source (query string or JSON body), mixed sources are not allowed")
#if len(body) > MAX_JSON_BODY_SIZE: # -------------------------
# raise HTTPException( # 6. JSON body inspection
# status_code=413, # -------------------------
# detail="JSON body too large", if content_type.startswith("application/json"):
# ) body = await request.body()
# if len(body) > MAX_JSON_BODY_SIZE:
# raise HTTPException(status_code=422, detail="JSON body too large")
if body: if body:
try: try:
payload = json.loads(body) payload = json.loads(body)
except json.JSONDecodeError: except json.JSONDecodeError:
raise HTTPException( raise HTTPException(status_code=422, detail="Invalid JSON body")
status_code=400,
detail="Invalid JSON body",
)
inspect_json(payload) inspect_json(payload)
# Re-inject body for downstream handlers # Re-inject body for downstream handlers
async def receive(): async def receive():
return {"type": "http.request", "body": body} return {"type": "http.request", "body": body}
request._receive = receive
request._receive = receive # noqa: protected-access
return await call_next(request) return await call_next(request)

Loading…
Cancel
Save