feat: Implement comprehensive request validation by adding header and data parameter whitelisting, enhancing JSON body inspection, and standardizing validation error codes to 422.

main
Cizz22 1 week ago
parent 2797d4c989
commit 64d1fcf4dd

@ -75,22 +75,22 @@ def handle_sqlalchemy_error(error: SQLAlchemyError):
if isinstance(error, IntegrityError):
if "unique constraint" in str(error).lower():
return "This record already exists.", 409
return "This record already exists.", 422
elif "foreign key constraint" in str(error).lower():
return "Related record not found.", 400
return "Related record not found.", 422
else:
return "Data integrity error.", 400
return "Data integrity error.", 422
elif isinstance(error, DataError) or isinstance(original_error, AsyncPGDataError):
return "Invalid data provided.", 400
return "Invalid data provided.", 422
elif isinstance(error, DBAPIError):
if "unique constraint" in str(error).lower():
return "This record already exists.", 409
return "This record already exists.", 422
elif "foreign key constraint" in str(error).lower():
return "Related record not found.", 400
return "Related record not found.", 422
elif "null value in column" in str(error).lower():
return "Required data missing.", 400
return "Required data missing.", 422
elif "invalid input for query argument" in str(error).lower():
return "Invalid data provided.", 400
return "Invalid data provided.", 422
else:
return "Database error.", 500
else:

@ -14,15 +14,87 @@ ALLOWED_MULTI_PARAMS = {
"exclude[]",
}
ALLOWED_DATA_PARAMS = {
"AerosData", "AhmJobId", "CustomInput", "DurationUnit", "IsDefault", "Konkin_offset",
"MaintenanceOutages", "MasterData", "OffSet", "OverhaulDuration", "OverhaulInterval",
"PlannedOutages", "SchematicName", "SecretStr", "SimDuration", "SimNumRun", "SimSeed",
"SimulationName", "actual_shutdown", "aeros_node", "all_params", "aro_file",
"aro_file_path", "availability", "baseline_simulation_id", "calc_results",
"cm_dis_p1", "cm_dis_p2", "cm_dis_p3", "cm_dis_type", "cm_dis_unit_code",
"cm_waiting_time", "contribution", "contribution_factor", "created_at",
"criticality", "current_user", "custom_parameters", "data", "datetime",
"derating_hours", "descending", "design_flowrate", "duration", "duration_above_h",
"duration_above_hh", "duration_at_empty", "duration_at_full", "duration_below_l",
"duration_below_ll", "eaf", "eaf_konkin", "effective_loss", "efficiency", "efor",
"equipment", "equipment_name", "exclude", "failure_rates", "filter_spec", "finish",
"flowrate_unit", "id", "ideal_production", "ip_dis_p1", "ip_dis_p2", "ip_dis_p3",
"ip_dis_type", "ip_dis_unit_code", "items", "itemsPerPage", "items_per_page",
"level", "location_tag", "master_equipment", "max_flow_rate", "max_flowrate",
"message", "model_image", "mttr", "name", "node_id", "node_name", "node_type",
"num_cm", "num_events", "num_ip", "num_oh", "num_pm", "offset", "oh_dis_p1",
"oh_dis_p2", "oh_dis_p3", "oh_dis_type", "oh_dis_unit_code", "page",
"plan_duration", "planned_outage", "plot_results", "pm_dis_p1", "pm_dis_p2",
"pm_dis_p3", "pm_dis_type", "pm_dis_unit_code", "point_availabilities",
"point_flowrates", "production", "production_std", "project_name", "query_str",
"rel_dis_p1", "rel_dis_p2", "rel_dis_p3", "rel_dis_type", "rel_dis_unit_code",
"remark", "schematic_id", "schematic_name", "simulation_id", "simulation_name",
"sof", "sort_by", "sortBy[]", "descending[]", "exclude[]", "start", "started_at",
"status", "stg_input", "storage_capacity", "structure_name", "t_wait_for_crew",
"t_wait_for_spare", "target_simulation_id", "timestamp_outs", "total",
"totalPages", "total_cm_downtime", "total_downtime", "total_ip_downtime",
"total_oh_downtime", "total_pm_downtime", "total_uptime", "updated_at", "year",
"_", "t", "timestamp", "q", "filter", "currentUser"
}
ALLOWED_HEADERS = {
"host",
"user-agent",
"accept",
"accept-language",
"accept-encoding",
"connection",
"upgrade-insecure-requests",
"if-modified-since",
"if-none-match",
"cache-control",
"authorization",
"content-type",
"content-length",
"origin",
"referer",
"sec-fetch-dest",
"sec-fetch-mode",
"sec-fetch-site",
"sec-fetch-user",
"sec-ch-ua",
"sec-ch-ua-mobile",
"sec-ch-ua-platform",
"pragma",
"dnt",
"priority",
"x-forwarded-for",
"x-forwarded-proto",
"x-forwarded-host",
"x-forwarded-port",
"x-real-ip",
"x-request-id",
"x-correlation-id",
"x-requested-with",
"x-csrf-token",
"x-xsrf-token",
"postman-token",
"x-internal-key",
}
MAX_QUERY_PARAMS = 50
MAX_QUERY_LENGTH = 2000
MAX_JSON_BODY_SIZE = 1024 * 100 # 100 KB
MAX_JSON_BODY_SIZE = 1024 * 500 # 500 KB
XSS_PATTERN = re.compile(
r"(<script|<iframe|<embed|<object|<svg|<img|<video|<audio|<base|<link|<meta|<form|<button|"
r"javascript:|vbscript:|data:text/html|onerror\s*=|onload\s*=|onmouseover\s*=|onfocus\s*=|"
r"onclick\s*=|onscroll\s*=|ondblclick\s*=|onkeydown\s*=|onkeypress\s*=|onkeyup\s*=|"
r"onloadstart\s*=|onpageshow\s*=|onresize\s*=|onunload\s*=|style\s*=\s*['\"].expression\s\(|"
r"onloadstart\s*=|onpageshow\s*=|onresize\s*=|onunload\s*=|style\s*=\s*['\"].*expression\s*\(|"
r"eval\s*\(|setTimeout\s*\(|setInterval\s*\(|Function\s*\()",
re.IGNORECASE,
)
@ -30,7 +102,7 @@ XSS_PATTERN = re.compile(
SQLI_PATTERN = re.compile(
r"(\bUNION\b|\bSELECT\b|\bINSERT\b|\bUPDATE\b|\bDELETE\b|\bDROP\b|\bALTER\b|\bCREATE\b|\bTRUNCATE\b|"
r"\bEXEC\b|\bEXECUTE\b|\bDECLARE\b|\bWAITFOR\b|\bDELAY\b|\bGROUP\b\s+\bBY\b|\bHAVING\b|\bORDER\b\s+\bBY\b|"
r"\bINFORMATION_SCHEMA\b|\bSYS\b\.|\bSYSOBJECTS\b|\bPG_SLEEP\b|\bSLEEP\b\(|--|/\|\/|#|\bOR\b\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+|"
r"\bINFORMATION_SCHEMA\b|\bSYS\b\.|\bSYSOBJECTS\b|\bPG_SLEEP\b|\bSLEEP\b\(|--|/\*|\*/|#|\bOR\b\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+|"
r"\bAND\b\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+|"
r"\bXP_CMDSHELL\b|\bLOAD_FILE\b|\bINTO\s+OUTFILE\b)",
re.IGNORECASE,
@ -48,9 +120,13 @@ TRAVERSAL_PATTERN = re.compile(
re.IGNORECASE,
)
# JSON prototype pollution keys
FORBIDDEN_JSON_KEYS = {"__proto__", "constructor", "prototype"}
DYNAMIC_KEYS = {
"data", "calc_results", "plot_results", "point_availabilities", "point_flowrates",
"failure_rates", "custom_parameters", "parameters", "results", "all_params"
}
# =========================
# Helpers
# =========================
@ -61,48 +137,36 @@ def has_control_chars(value: str) -> bool:
def inspect_value(value: str, source: str):
if XSS_PATTERN.search(value):
raise HTTPException(
status_code=400,
detail=f"Potential XSS payload detected in {source}",
)
raise HTTPException(status_code=422, detail=f"Potential XSS payload detected in {source}")
if SQLI_PATTERN.search(value):
raise HTTPException(
status_code=400,
detail=f"Potential SQL injection payload detected in {source}",
)
raise HTTPException(status_code=422, detail=f"Potential SQL injection payload detected in {source}")
if RCE_PATTERN.search(value):
raise HTTPException(
status_code=400,
detail=f"Potential RCE payload detected in {source}",
)
raise HTTPException(status_code=422, detail=f"Potential RCE payload detected in {source}")
if TRAVERSAL_PATTERN.search(value):
raise HTTPException(
status_code=400,
detail=f"Potential traversal payload detected in {source}",
)
raise HTTPException(status_code=422, detail=f"Potential Path Traversal payload detected in {source}")
if has_control_chars(value):
raise HTTPException(
status_code=400,
detail=f"Invalid control characters detected in {source}",
)
raise HTTPException(status_code=422, detail=f"Invalid control characters detected in {source}")
def inspect_json(obj, path="body"):
def inspect_json(obj, path="body", check_whitelist=True):
if isinstance(obj, dict):
for key, value in obj.items():
if key in FORBIDDEN_JSON_KEYS:
raise HTTPException(
status_code=400,
detail=f"Forbidden JSON key detected: {path}.{key}",
)
inspect_json(value, f"{path}.{key}")
raise HTTPException(status_code=422, detail=f"Forbidden JSON key detected: {path}.{key}")
if check_whitelist and key not in ALLOWED_DATA_PARAMS:
raise HTTPException(status_code=422, detail=f"Unknown JSON key detected: {path}.{key}")
# Recurse. If the key is a dynamic container, we stop whitelist checking for children.
should_check_subkeys = check_whitelist and (key not in DYNAMIC_KEYS)
inspect_json(value, f"{path}.{key}", check_whitelist=should_check_subkeys)
elif isinstance(obj, list):
for i, item in enumerate(obj):
inspect_json(item, f"{path}[{i}]")
inspect_json(item, f"{path}[{i}]", check_whitelist=check_whitelist)
elif isinstance(obj, str):
inspect_value(obj, path)
@ -113,22 +177,47 @@ def inspect_json(obj, path="body"):
class RequestValidationMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# -------------------------
# 0. Header validation
# -------------------------
header_keys = [key.lower() for key, _ in request.headers.items()]
# Check for duplicate headers
header_counter = Counter(header_keys)
duplicate_headers = [key for key, count in header_counter.items() if count > 1]
ALLOW_DUPLICATE_HEADERS = {'accept', 'accept-encoding', 'accept-language', 'accept-charset', 'cookie'}
real_duplicates = [h for h in duplicate_headers if h not in ALLOW_DUPLICATE_HEADERS]
if real_duplicates:
raise HTTPException(status_code=422, detail=f"Duplicate headers are not allowed: {real_duplicates}")
# Whitelist headers
unknown_headers = [key for key in header_keys if key not in ALLOWED_HEADERS]
if unknown_headers:
filtered_unknown = [h for h in unknown_headers if not h.startswith('sec-')]
if filtered_unknown:
raise HTTPException(status_code=422, detail=f"Unknown headers detected: {filtered_unknown}")
# Inspect header values
for key, value in request.headers.items():
if value:
inspect_value(value, f"header '{key}'")
# -------------------------
# 1. Query string limits
# -------------------------
if len(request.url.query) > MAX_QUERY_LENGTH:
raise HTTPException(
status_code=414,
detail="Query string too long",
)
raise HTTPException(status_code=422, detail="Query string too long")
params = request.query_params.multi_items()
if len(params) > MAX_QUERY_PARAMS:
raise HTTPException(
status_code=400,
detail="Too many query parameters",
)
raise HTTPException(status_code=422, detail="Too many query parameters")
# Check for unknown query parameters
unknown_params = [key for key, _ in params if key not in ALLOWED_DATA_PARAMS]
if unknown_params:
raise HTTPException(status_code=422, detail=f"Unknown query parameters detected: {unknown_params}")
# -------------------------
# 2. Duplicate parameters
@ -140,38 +229,25 @@ class RequestValidationMiddleware(BaseHTTPMiddleware):
]
if duplicates:
raise HTTPException(
status_code=400,
detail=f"Duplicate query parameters are not allowed: {duplicates}",
)
raise HTTPException(status_code=422, detail=f"Duplicate query parameters are not allowed: {duplicates}")
# -------------------------
# 3. Query param inspection
# 3. Query param inspection & Pagination
# -------------------------
pagination_size_keys = {"size", "itemsPerPage", "per_page", "limit", "items_per_page"}
for key, value in params:
if value:
inspect_value(value, f"query param '{key}'")
# Pagination constraint: multiples of 5, max 50
if key in pagination_size_keys and value:
try:
size_val = int(value)
if size_val > 50:
raise HTTPException(
status_code=400,
detail=f"Pagination size '{key}' cannot exceed 50",
)
raise HTTPException(status_code=422, detail=f"Pagination size '{key}' cannot exceed 50")
if size_val % 5 != 0:
raise HTTPException(
status_code=400,
detail=f"Pagination size '{key}' must be a multiple of 5",
)
raise HTTPException(status_code=422, detail=f"Pagination size '{key}' must be a multiple of 5")
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Pagination size '{key}' must be an integer",
)
raise HTTPException(status_code=422, detail=f"Pagination size '{key}' must be an integer")
# -------------------------
# 4. Content-Type sanity
@ -179,44 +255,43 @@ class RequestValidationMiddleware(BaseHTTPMiddleware):
content_type = request.headers.get("content-type", "")
if content_type and not any(
content_type.startswith(t)
for t in (
"application/json",
"multipart/form-data",
"application/x-www-form-urlencoded",
)
for t in ("application/json", "multipart/form-data", "application/x-www-form-urlencoded")
):
raise HTTPException(
status_code=415,
detail="Unsupported Content-Type",
)
raise HTTPException(status_code=422, detail="Unsupported Content-Type")
# -------------------------
# 5. JSON body inspection
# 5. Single source check (Query vs JSON Body)
# -------------------------
has_query = len(params) > 0
has_body = False
if content_type.startswith("application/json"):
body = await request.body()
content_length = request.headers.get("content-length")
if content_length and int(content_length) > 0:
has_body = True
if has_query and has_body:
raise HTTPException(status_code=422, detail="Parameters must be from a single source (query string or JSON body), mixed sources are not allowed")
#if len(body) > MAX_JSON_BODY_SIZE:
# raise HTTPException(
# status_code=413,
# detail="JSON body too large",
# )
# -------------------------
# 6. JSON body inspection
# -------------------------
if content_type.startswith("application/json"):
body = await request.body()
# if len(body) > MAX_JSON_BODY_SIZE:
# raise HTTPException(status_code=422, detail="JSON body too large")
if body:
try:
payload = json.loads(body)
except json.JSONDecodeError:
raise HTTPException(
status_code=400,
detail="Invalid JSON body",
)
raise HTTPException(status_code=422, detail="Invalid JSON body")
inspect_json(payload)
# Re-inject body for downstream handlers
async def receive():
return {"type": "http.request", "body": body}
request._receive = receive # noqa: protected-access
request._receive = receive
return await call_next(request)

Loading…
Cancel
Save