@ -14,15 +14,87 @@ ALLOWED_MULTI_PARAMS = {
" exclude[] " ,
}
# Whitelist of parameter names accepted from clients. The same set is used for
# query-string keys (in the middleware's dispatch) and for JSON body keys
# (in inspect_json); any name not listed here is rejected.
# Names are matched exactly and case-sensitively, so they must not carry any
# surrounding whitespace.
ALLOWED_DATA_PARAMS = {
    "AerosData", "AhmJobId", "CustomInput", "DurationUnit", "IsDefault", "Konkin_offset",
    "MaintenanceOutages", "MasterData", "OffSet", "OverhaulDuration", "OverhaulInterval",
    "PlannedOutages", "SchematicName", "SecretStr", "SimDuration", "SimNumRun", "SimSeed",
    "SimulationName", "actual_shutdown", "aeros_node", "all_params", "aro_file",
    "aro_file_path", "availability", "baseline_simulation_id", "calc_results",
    "cm_dis_p1", "cm_dis_p2", "cm_dis_p3", "cm_dis_type", "cm_dis_unit_code",
    "cm_waiting_time", "contribution", "contribution_factor", "created_at",
    "criticality", "current_user", "custom_parameters", "data", "datetime",
    "derating_hours", "descending", "design_flowrate", "duration", "duration_above_h",
    "duration_above_hh", "duration_at_empty", "duration_at_full", "duration_below_l",
    "duration_below_ll", "eaf", "eaf_konkin", "effective_loss", "efficiency", "efor",
    "equipment", "equipment_name", "exclude", "failure_rates", "filter_spec", "finish",
    "flowrate_unit", "id", "ideal_production", "ip_dis_p1", "ip_dis_p2", "ip_dis_p3",
    "ip_dis_type", "ip_dis_unit_code", "items", "itemsPerPage", "items_per_page",
    "level", "location_tag", "master_equipment", "max_flow_rate", "max_flowrate",
    "message", "model_image", "mttr", "name", "node_id", "node_name", "node_type",
    "num_cm", "num_events", "num_ip", "num_oh", "num_pm", "offset", "oh_dis_p1",
    "oh_dis_p2", "oh_dis_p3", "oh_dis_type", "oh_dis_unit_code", "page",
    "plan_duration", "planned_outage", "plot_results", "pm_dis_p1", "pm_dis_p2",
    "pm_dis_p3", "pm_dis_type", "pm_dis_unit_code", "point_availabilities",
    "point_flowrates", "production", "production_std", "project_name", "query_str",
    "rel_dis_p1", "rel_dis_p2", "rel_dis_p3", "rel_dis_type", "rel_dis_unit_code",
    "remark", "schematic_id", "schematic_name", "simulation_id", "simulation_name",
    "sof", "sort_by", "sortBy[]", "descending[]", "exclude[]", "start", "started_at",
    "status", "stg_input", "storage_capacity", "structure_name", "t_wait_for_crew",
    "t_wait_for_spare", "target_simulation_id", "timestamp_outs", "total",
    "totalPages", "total_cm_downtime", "total_downtime", "total_ip_downtime",
    "total_oh_downtime", "total_pm_downtime", "total_uptime", "updated_at", "year",
    "_", "t", "timestamp", "q", "filter", "currentUser",
}
# Whitelist of request header names. The middleware lower-cases incoming header
# names before comparing, so every entry here must be lowercase with no
# surrounding whitespace. Names starting with "sec-" are additionally let
# through by the middleware even when not listed.
# NOTE(review): the middleware's duplicate-tolerant header list names "cookie"
# and "accept-charset", but neither is whitelisted here, so requests carrying
# them are rejected as unknown headers - confirm whether that is intended.
ALLOWED_HEADERS = {
    "host",
    "user-agent",
    "accept",
    "accept-language",
    "accept-encoding",
    "connection",
    "upgrade-insecure-requests",
    "if-modified-since",
    "if-none-match",
    "cache-control",
    "authorization",
    "content-type",
    "content-length",
    "origin",
    "referer",
    "sec-fetch-dest",
    "sec-fetch-mode",
    "sec-fetch-site",
    "sec-fetch-user",
    "sec-ch-ua",
    "sec-ch-ua-mobile",
    "sec-ch-ua-platform",
    "pragma",
    "dnt",
    "priority",
    "x-forwarded-for",
    "x-forwarded-proto",
    "x-forwarded-host",
    "x-forwarded-port",
    "x-real-ip",
    "x-request-id",
    "x-correlation-id",
    "x-requested-with",
    "x-csrf-token",
    "x-xsrf-token",
    "postman-token",
    "x-internal-key",
}
# Request size limits.
# MAX_QUERY_PARAMS caps the number of query-string items and MAX_QUERY_LENGTH
# caps the raw query-string length; both are checked by the middleware.
MAX_QUERY_PARAMS = 50
MAX_QUERY_LENGTH = 2000
# NOTE(review): in the visible code the body-size check that would use this
# limit is commented out, so it is currently not enforced - confirm.
MAX_JSON_BODY_SIZE = 1024 * 500  # 500 KB
# Case-insensitive signature for common XSS vectors: dangerous opening tags,
# script-capable URL schemes, inline event-handler attributes, CSS
# `expression(` inside a style attribute, and a few JavaScript call forms.
# The style alternative uses `.*` so that `expression(` is caught anywhere
# after the opening quote of the attribute value, not only one character in.
XSS_PATTERN = re.compile(
    r"(<script|<iframe|<embed|<object|<svg|<img|<video|<audio|<base|<link|<meta|<form|<button|"
    r"javascript:|vbscript:|data:text/html|onerror\s*=|onload\s*=|onmouseover\s*=|onfocus\s*=|"
    r"onclick\s*=|onscroll\s*=|ondblclick\s*=|onkeydown\s*=|onkeypress\s*=|onkeyup\s*=|"
    r"onloadstart\s*=|onpageshow\s*=|onresize\s*=|onunload\s*=|style\s*=\s*['\"].*expression\s*\(|"
    r"eval\s*\(|setTimeout\s*\(|setInterval\s*\(|Function\s*\()",
    re.IGNORECASE,
)
@ -30,7 +102,7 @@ XSS_PATTERN = re.compile(
SQLI_PATTERN = re . compile (
r " ( \ bUNION \ b| \ bSELECT \ b| \ bINSERT \ b| \ bUPDATE \ b| \ bDELETE \ b| \ bDROP \ b| \ bALTER \ b| \ bCREATE \ b| \ bTRUNCATE \ b| "
r " \ bEXEC \ b| \ bEXECUTE \ b| \ bDECLARE \ b| \ bWAITFOR \ b| \ bDELAY \ b| \ bGROUP \ b \ s+ \ bBY \ b| \ bHAVING \ b| \ bORDER \ b \ s+ \ bBY \ b| "
r " \ bINFORMATION_SCHEMA \ b| \ bSYS \ b \ .| \ bSYSOBJECTS \ b| \ bPG_SLEEP \ b| \ bSLEEP \ b \ (|--|/ \ |\ /|#|\ bOR \ b \ s+[ ' \" ]? \ d+[ ' \" ]? \ s*= \ s*[ ' \" ]? \ d+| "
r " \ bINFORMATION_SCHEMA \ b| \ bSYS \ b \ .| \ bSYSOBJECTS \ b| \ bPG_SLEEP \ b| \ bSLEEP \ b \ (|--|/ \ * |\ * /|#|\ bOR \ b \ s+[ ' \" ]? \ d+[ ' \" ]? \ s*= \ s*[ ' \" ]? \ d+| "
r " \ bAND \ b \ s+[ ' \" ]? \ d+[ ' \" ]? \ s*= \ s*[ ' \" ]? \ d+| "
r " \ bXP_CMDSHELL \ b| \ bLOAD_FILE \ b| \ bINTO \ s+OUTFILE \ b) " ,
re . IGNORECASE ,
@ -48,9 +120,13 @@ TRAVERSAL_PATTERN = re.compile(
re . IGNORECASE ,
)
# JSON prototype pollution keys: any object key equal to one of these is
# rejected by inspect_json regardless of whitelist mode.
FORBIDDEN_JSON_KEYS = {"__proto__", "constructor", "prototype"}

# Keys whose values are free-form containers: inspect_json stops applying the
# ALLOWED_DATA_PARAMS whitelist to everything nested beneath them (string
# values underneath are still scanned for attack payloads).
# NOTE(review): "parameters" and "results" are not in ALLOWED_DATA_PARAMS, and
# inspect_json applies the whitelist before consulting this set, so those two
# entries can never take effect at a whitelisted level - confirm whether they
# should be added to the whitelist or dropped from here.
DYNAMIC_KEYS = {
    "data", "calc_results", "plot_results", "point_availabilities", "point_flowrates",
    "failure_rates", "custom_parameters", "parameters", "results", "all_params",
}
# =========================
# Helpers
# =========================
@ -61,48 +137,36 @@ def has_control_chars(value: str) -> bool:
def inspect_value(value: str, source: str) -> None:
    """Reject a string that matches any known attack signature.

    The checks run in a fixed order (XSS, SQL injection, RCE, path traversal,
    control characters) and the first one that matches aborts the request.

    Args:
        value: The text to scan.
        source: Label describing where the value came from; it is embedded
            in the error detail returned to the client.

    Raises:
        HTTPException: With status 422 when ``value`` matches one of the
            signature patterns or contains control characters.
    """
    # (compiled pattern, label used in the error detail), in evaluation order.
    signature_checks = (
        (XSS_PATTERN, "XSS"),
        (SQLI_PATTERN, "SQL injection"),
        (RCE_PATTERN, "RCE"),
        (TRAVERSAL_PATTERN, "Path Traversal"),
    )
    for pattern, label in signature_checks:
        if pattern.search(value):
            raise HTTPException(
                status_code=422,
                detail=f"Potential {label} payload detected in {source}",
            )

    if has_control_chars(value):
        raise HTTPException(
            status_code=422,
            detail=f"Invalid control characters detected in {source}",
        )
def inspect_json(obj, path="body", check_whitelist=True):
    """Recursively validate a decoded JSON value.

    Dict keys are checked against FORBIDDEN_JSON_KEYS and, while whitelist
    mode is active, against ALLOWED_DATA_PARAMS. String leaves are scanned
    with inspect_value. Other scalar types are accepted as-is.

    Args:
        obj: The decoded JSON value (dict, list, str, or any other scalar).
        path: Dotted/indexed location of ``obj`` within the body, used in
            error details.
        check_whitelist: Whether dict keys at this level must appear in
            ALLOWED_DATA_PARAMS. It is switched off for everything nested
            under a key listed in DYNAMIC_KEYS.

    Raises:
        HTTPException: With status 422 on a forbidden key, an unknown key
            (in whitelist mode), or a string value rejected by inspect_value.
    """
    if isinstance(obj, dict):
        for key, value in obj.items():
            if key in FORBIDDEN_JSON_KEYS:
                raise HTTPException(
                    status_code=422,
                    detail=f"Forbidden JSON key detected: {path}.{key}",
                )

            if check_whitelist and key not in ALLOWED_DATA_PARAMS:
                raise HTTPException(
                    status_code=422,
                    detail=f"Unknown JSON key detected: {path}.{key}",
                )

            # Recurse. If the key is a dynamic container, we stop whitelist
            # checking for its children (and, transitively, their children).
            should_check_subkeys = check_whitelist and (key not in DYNAMIC_KEYS)
            inspect_json(value, f"{path}.{key}", check_whitelist=should_check_subkeys)
    elif isinstance(obj, list):
        # List items inherit the whitelist mode of the list itself.
        for i, item in enumerate(obj):
            inspect_json(item, f"{path}[{i}]", check_whitelist=check_whitelist)
    elif isinstance(obj, str):
        inspect_value(obj, path)
@ -113,22 +177,47 @@ def inspect_json(obj, path="body"):
class RequestValidationMiddleware ( BaseHTTPMiddleware ) :
async def dispatch ( self , request : Request , call_next ) :
# -------------------------
# 0. Header validation
# -------------------------
header_keys = [ key . lower ( ) for key , _ in request . headers . items ( ) ]
# Check for duplicate headers
header_counter = Counter ( header_keys )
duplicate_headers = [ key for key , count in header_counter . items ( ) if count > 1 ]
ALLOW_DUPLICATE_HEADERS = { ' accept ' , ' accept-encoding ' , ' accept-language ' , ' accept-charset ' , ' cookie ' }
real_duplicates = [ h for h in duplicate_headers if h not in ALLOW_DUPLICATE_HEADERS ]
if real_duplicates :
raise HTTPException ( status_code = 422 , detail = f " Duplicate headers are not allowed: { real_duplicates } " )
# Whitelist headers
unknown_headers = [ key for key in header_keys if key not in ALLOWED_HEADERS ]
if unknown_headers :
filtered_unknown = [ h for h in unknown_headers if not h . startswith ( ' sec- ' ) ]
if filtered_unknown :
raise HTTPException ( status_code = 422 , detail = f " Unknown headers detected: { filtered_unknown } " )
# Inspect header values
for key , value in request . headers . items ( ) :
if value :
inspect_value ( value , f " header ' { key } ' " )
# -------------------------
# 1. Query string limits
# -------------------------
if len ( request . url . query ) > MAX_QUERY_LENGTH :
raise HTTPException (
status_code = 414 ,
detail = " Query string too long " ,
)
raise HTTPException ( status_code = 422 , detail = " Query string too long " )
params = request . query_params . multi_items ( )
if len ( params ) > MAX_QUERY_PARAMS :
raise HTTPException (
status_code = 400 ,
detail = " Too many query parameters " ,
)
raise HTTPException ( status_code = 422 , detail = " Too many query parameters " )
# Check for unknown query parameters
unknown_params = [ key for key , _ in params if key not in ALLOWED_DATA_PARAMS ]
if unknown_params :
raise HTTPException ( status_code = 422 , detail = f " Unknown query parameters detected: { unknown_params } " )
# -------------------------
# 2. Duplicate parameters
@ -140,38 +229,25 @@ class RequestValidationMiddleware(BaseHTTPMiddleware):
]
if duplicates :
raise HTTPException (
status_code = 400 ,
detail = f " Duplicate query parameters are not allowed: { duplicates } " ,
)
raise HTTPException ( status_code = 422 , detail = f " Duplicate query parameters are not allowed: { duplicates } " )
# -------------------------
# 3. Query param inspection
# 3. Query param inspection & Pagination
# -------------------------
pagination_size_keys = { " size " , " itemsPerPage " , " per_page " , " limit " , " items_per_page " }
for key , value in params :
if value :
inspect_value ( value , f " query param ' { key } ' " )
# Pagination constraint: multiples of 5, max 50
if key in pagination_size_keys and value :
try :
size_val = int ( value )
if size_val > 50 :
raise HTTPException (
status_code = 400 ,
detail = f " Pagination size ' { key } ' cannot exceed 50 " ,
)
raise HTTPException ( status_code = 422 , detail = f " Pagination size ' { key } ' cannot exceed 50 " )
if size_val % 5 != 0 :
raise HTTPException (
status_code = 400 ,
detail = f " Pagination size ' { key } ' must be a multiple of 5 " ,
)
raise HTTPException ( status_code = 422 , detail = f " Pagination size ' { key } ' must be a multiple of 5 " )
except ValueError :
raise HTTPException (
status_code = 400 ,
detail = f " Pagination size ' { key } ' must be an integer " ,
)
raise HTTPException ( status_code = 422 , detail = f " Pagination size ' { key } ' must be an integer " )
# -------------------------
# 4. Content-Type sanity
@ -179,44 +255,43 @@ class RequestValidationMiddleware(BaseHTTPMiddleware):
content_type = request . headers . get ( " content-type " , " " )
if content_type and not any (
content_type . startswith ( t )
for t in (
" application/json " ,
" multipart/form-data " ,
" application/x-www-form-urlencoded " ,
)
for t in ( " application/json " , " multipart/form-data " , " application/x-www-form-urlencoded " )
) :
raise HTTPException (
status_code = 415 ,
detail = " Unsupported Content-Type " ,
)
raise HTTPException ( status_code = 422 , detail = " Unsupported Content-Type " )
# -------------------------
# 5. JSON body inspection
# 5. Single source check (Query vs JSON Body)
# -------------------------
has_query = len ( params ) > 0
has_body = False
if content_type . startswith ( " application/json " ) :
body = await request . body ( )
content_length = request . headers . get ( " content-length " )
if content_length and int ( content_length ) > 0 :
has_body = True
if has_query and has_body :
raise HTTPException ( status_code = 422 , detail = " Parameters must be from a single source (query string or JSON body), mixed sources are not allowed " )
# -------------------------
# 6. JSON body inspection
# -------------------------
if content_type . startswith ( " application/json " ) :
body = await request . body ( )
# if len(body) > MAX_JSON_BODY_SIZE:
# raise HTTPException(
# status_code=413,
# detail="JSON body too large",
# )
# raise HTTPException(status_code=422, detail="JSON body too large")
if body :
try :
payload = json . loads ( body )
except json . JSONDecodeError :
raise HTTPException (
status_code = 400 ,
detail = " Invalid JSON body " ,
)
raise HTTPException ( status_code = 422 , detail = " Invalid JSON body " )
inspect_json ( payload )
# Re-inject body for downstream handlers
async def receive ( ) :
return { " type " : " http.request " , " body " : body }
request . _receive = receive # noqa: protected-access
request . _receive = receive
return await call_next ( request )