Refactor AEROS API integration and dependency management, enhance security middleware, and refine validation rules for query parameters and schema fields.

main
Cizz22 2 weeks ago
parent 4392809e81
commit fadfafc241

@ -18,5 +18,5 @@ COLLECTOR_NAME=digital_aeros_fixed
AEROS_LICENSE_ID=20260218-Jre5VZieQfWXTq0G8ClpVSGszMf4UEUMLS5ENpWRVcoVSrNJckVZzXE AEROS_LICENSE_ID=20260218-Jre5VZieQfWXTq0G8ClpVSGszMf4UEUMLS5ENpWRVcoVSrNJckVZzXE
AEROS_LICENSE_SECRET=GmLIxf9fr8Ap5m1IYzkk4RPBFcm7UBvcd0eRdRQ03oRdxLHQA0d9oyhUk2ZlM3LVdRh1mkgYy5254bmCjFyWWc0oPFwNWYzNwDwnv50qy6SLRdaFnI0yZcfLbWQ7qCSj AEROS_LICENSE_SECRET=GmLIxf9fr8Ap5m1IYzkk4RPBFcm7UBvcd0eRdRQ03oRdxLHQA0d9oyhUk2ZlM3LVdRh1mkgYy5254bmCjFyWWc0oPFwNWYzNwDwnv50qy6SLRdaFnI0yZcfLbWQ7qCSj
WINDOWS_AEROS_BASE_URL=http://192.168.1.102:8800 WINDOWS_AEROS_BASE_URL=http://192.168.1.102:8080
TEMPORAL_URL=http://192.168.1.86:7233 TEMPORAL_URL=http://192.168.1.86:7233

16
poetry.lock generated

@ -1312,15 +1312,15 @@ i18n = ["Babel (>=2.7)"]
[[package]] [[package]]
name = "licaeros" name = "licaeros"
version = "0.1.2" version = "0.1.7"
description = "License App for Aeros" description = "License App for Aeros"
optional = false optional = false
python-versions = "*" python-versions = "*"
groups = ["main"] groups = ["main"]
files = [ files = [
{file = "licaeros-0.1.2-cp310-cp310-linux_x86_64.whl", hash = "sha256:4b9bfe2e7ba8ab9edb5db18dcb415476e7ab302e09d72b74b5bfd1ac8938b10c"}, {file = "licaeros-0.1.7-cp310-cp310-linux_x86_64.whl", hash = "sha256:77bec84f37e02a7aff84f6c45a97a5933a86d99cdfabfd74ede36fa64506bfde"},
{file = "licaeros-0.1.2-cp311-cp311-linux_x86_64.whl", hash = "sha256:4f3a2251aebe7351e61d6f80d6c7474387f9561fdcfff02103b78bb2168c9791"}, {file = "licaeros-0.1.7-cp311-cp311-linux_x86_64.whl", hash = "sha256:48e874645c5892e05c8f26bdea910dcdaa3e7b0e787be77920a4f4fb5504b2c1"},
{file = "licaeros-0.1.2-cp312-cp312-linux_x86_64.whl", hash = "sha256:933c24029aec984ccc39baf630fbee10e07c1e28192c499685bec0a11d31321d"}, {file = "licaeros-0.1.7-cp312-cp312-linux_x86_64.whl", hash = "sha256:6a0bf6c1b9094693058d927febdb6799c61aea7b5dd10265b014c9d314844135"},
] ]
[package.source] [package.source]
@ -2478,14 +2478,14 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]] [[package]]
name = "rich" name = "rich"
version = "14.3.2" version = "14.3.3"
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
optional = false optional = false
python-versions = ">=3.8.0" python-versions = ">=3.8.0"
groups = ["main"] groups = ["main"]
files = [ files = [
{file = "rich-14.3.2-py3-none-any.whl", hash = "sha256:08e67c3e90884651da3239ea668222d19bea7b589149d8014a21c633420dbb69"}, {file = "rich-14.3.3-py3-none-any.whl", hash = "sha256:793431c1f8619afa7d3b52b2cdec859562b950ea0d4b6b505397612db8d5362d"},
{file = "rich-14.3.2.tar.gz", hash = "sha256:e712f11c1a562a11843306f5ed999475f09ac31ffb64281f73ab29ffdda8b3b8"}, {file = "rich-14.3.3.tar.gz", hash = "sha256:b8daa0b9e4eef54dd8cf7c86c03713f53241884e814f4e2f5fb342fe520f639b"},
] ]
[package.dependencies] [package.dependencies]
@ -3593,4 +3593,4 @@ propcache = ">=0.2.1"
[metadata] [metadata]
lock-version = "2.1" lock-version = "2.1"
python-versions = "^3.11" python-versions = "^3.11"
content-hash = "c97aecfef075bcbd7a40d9c98ae79c30d6253bc2c9f14ef187b1a098ace42088" content-hash = "46b6c8d43f09a99729b212166e31fd9190f8f659e178261a15bc35a694e2f81c"

@ -32,7 +32,8 @@ aiohttp = "^3.12.14"
ijson = "^3.4.0" ijson = "^3.4.0"
redis = "^7.1.0" redis = "^7.1.0"
clamd = "^1.0.2" clamd = "^1.0.2"
licaeros = "^0.1.2" licaeros = "^0.1.7"
[[tool.poetry.source]] [[tool.poetry.source]]

@ -54,10 +54,12 @@ async def get_all(*, common):
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )
response.raise_for_status() response.raise_for_status()
res = response.json() res = response.json()
# if not res.get("status"):
# raise HTTPException(
# status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=res.get("message")
# )
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(

@ -1,6 +1,6 @@
import os import os
from typing import Optional from typing import Optional
import json
import httpx import httpx
from fastapi import HTTPException, status from fastapi import HTTPException, status
from sqlalchemy import Delete, desc, Select, select, func from sqlalchemy import Delete, desc, Select, select, func
@ -10,7 +10,7 @@ from src.aeros_equipment.service import save_default_equipment
from src.aeros_simulation.service import save_default_simulation_node from src.aeros_simulation.service import save_default_simulation_node
from src.auth.service import CurrentUser from src.auth.service import CurrentUser
from src.config import WINDOWS_AEROS_BASE_URL, CLAMAV_HOST, CLAMAV_PORT from src.config import WINDOWS_AEROS_BASE_URL, CLAMAV_HOST, CLAMAV_PORT
from src.aeros_utils import aeros_post from src.aeros_utils import aeros_post, aeros_file_upload
from src.database.core import DbSession from src.database.core import DbSession
from src.database.service import search_filter_sort_paginate from src.database.service import search_filter_sort_paginate
from src.utils import sanitize_filename from src.utils import sanitize_filename
@ -119,8 +119,8 @@ async def import_aro_project(*, db_session: DbSession, aeros_project_in: AerosPr
# } # }
response = await aeros_file_upload( response = await aeros_file_upload(
"/api/upload", "/upload",
file, content,
"file", "file",
clean_filename clean_filename
) )
@ -139,7 +139,6 @@ async def import_aro_project(*, db_session: DbSession, aeros_project_in: AerosPr
aro_path = upload_result.get("full_path") aro_path = upload_result.get("full_path")
filename = upload_result.get("stored_filename").replace(".aro", "") filename = upload_result.get("stored_filename").replace(".aro", "")
if not aro_path: if not aro_path:
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
@ -176,15 +175,22 @@ async def import_aro_project(*, db_session: DbSession, aeros_project_in: AerosPr
await db_session.commit() await db_session.commit()
aro_json = json.dumps(aro_path)
# Update path to AEROS APP # Update path to AEROS APP
# Example BODy "C/dsad/dsad.aro" # Example BODy "C/dsad/dsad.aro"
try: try:
response = await aeros_post( response = await aeros_post(
"/api/Project/ImportAROFile", "/api/Project/ImportAROFile",
data=f'"{aro_path}"', json_data=aro_json,
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )
response.raise_for_status()
# response.raise_for_status()
response_json = response.json()
raise Exception(response_json)
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)

@ -309,7 +309,7 @@ async def get_simulation_result_plot_per_node(db_session: DbSession, simulation_
} }
@router.get("/result/ranking/{simulation_id}", response_model=StandardResponse[List[SimulationRankingParameters]]) @router.get("/result/ranking/{simulation_id}", response_model=StandardResponse[List[SimulationRankingParameters]])
async def get_simulation_result_ranking(db_session: DbSession, simulation_id, limit:int = Query(None)): async def get_simulation_result_ranking(db_session: DbSession, simulation_id, limit:int = Query(None, le=50)):
"""Get simulation result.""" """Get simulation result."""
if simulation_id == 'default': if simulation_id == 'default':
simulation = await get_default_simulation(db_session=db_session) simulation = await get_default_simulation(db_session=db_session)

@ -1,6 +1,6 @@
import anyio import anyio
from licaeros import LicensedSession, device_fingerprint_hex from licaeros import LicensedSession, device_fingerprint_hex
from src.config import AEROS_BASE_URL, AEROS_LICENSE_ID, AEROS_LICENSE_SECRET from src.config import AEROS_BASE_URL, AEROS_LICENSE_ID, AEROS_LICENSE_SECRET, WINDOWS_AEROS_BASE_URL
import logging import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -8,33 +8,35 @@ log = logging.getLogger(__name__)
# Initialize a global session if possible, or create on demand # Initialize a global session if possible, or create on demand
_aeros_session = None _aeros_session = None
def get_aeros_session(): def get_aeros_session(base_url):
global _aeros_session global _aeros_session
if _aeros_session is None: if _aeros_session is None:
log.info(f"Initializing LicensedSession with base URL: {AEROS_BASE_URL}") log.info(f"Initializing LicensedSession with base URL: {base_url}")
log.info(f"Encrypted Device ID: {device_fingerprint_hex()}") log.info(f"Encrypted Device ID: {device_fingerprint_hex()}")
_aeros_session = LicensedSession( _aeros_session = LicensedSession(
api_base=AEROS_BASE_URL, api_base=base_url,
license_id=AEROS_LICENSE_ID, license_id=AEROS_LICENSE_ID,
license_secret=AEROS_LICENSE_SECRET, license_secret=AEROS_LICENSE_SECRET,
) )
return _aeros_session return _aeros_session
async def aeros_post(path: str, json: dict = None, **kwargs): async def aeros_post(path: str, json=None, data=None, **kwargs):
""" """
Asynchronous wrapper for LicensedSession.post Asynchronous wrapper for LicensedSession.post
""" """
session = get_aeros_session() session = get_aeros_session(WINDOWS_AEROS_BASE_URL)
url = f"/api/aeros{path}"
# LicensedSession might not be async-compatible, so we run it in a thread # LicensedSession might not be async-compatible, so we run it in a thread
response = await anyio.to_thread.run_sync( response = await anyio.to_thread.run_sync(
lambda: session.post(path, json) lambda: session.post(url, json_data=json, data=data, headers=kwargs.get("headers"))
) )
return response return response
async def aeros_file_upload(path, file, field_name, filename): async def aeros_file_upload(path, file, field_name, filename):
session = get_aeros_session() session = get_aeros_session(WINDOWS_AEROS_BASE_URL)
url = f"/api/aeros{path}"
response = await anyio.to_thread.run_sync( response = await anyio.to_thread.run_sync(
lambda: session.post_multipart(path, file, field_name, filename) lambda: session.post_multipart(url, file, field_name, filename)
) )
return response return response

@ -8,7 +8,7 @@ class CommonParams(DefultBase):
# This ensures no extra query params are allowed # This ensures no extra query params are allowed
current_user: Optional[str] = Field(None, alias="currentUser", max_length=50) current_user: Optional[str] = Field(None, alias="currentUser", max_length=50)
page: int = Field(1, gt=0, lt=2147483647) page: int = Field(1, gt=0, lt=2147483647)
items_per_page: int = Field(5, gt=-2, lt=2147483647, alias="itemsPerPage") items_per_page: int = Field(5, gt=0, le=50, multiple_of=5, alias="itemsPerPage")
query_str: Optional[str] = Field(None, alias="q", max_length=100) query_str: Optional[str] = Field(None, alias="q", max_length=100)
filter_spec: Optional[str] = Field(None, alias="filter", max_length=500) filter_spec: Optional[str] = Field(None, alias="filter", max_length=500)
sort_by: List[str] = Field(default_factory=list, alias="sortBy[]") sort_by: List[str] = Field(default_factory=list, alias="sortBy[]")

@ -19,8 +19,8 @@ log = logging.getLogger(__name__)
class ErrorDetail(BaseModel): class ErrorDetail(BaseModel):
field: Optional[str] = Field(None, max_length=100) field: Optional[str] = Field(None, max_length=100)
message: str = Field(..., max_length=255) message: str = Field(...)
code: Optional[str] = Field(None, max_length=50) code: Optional[str] = Field(None)
params: Optional[Dict[str, Any]] = None params: Optional[Dict[str, Any]] = None

@ -18,14 +18,33 @@ MAX_QUERY_PARAMS = 50
MAX_QUERY_LENGTH = 2000 MAX_QUERY_LENGTH = 2000
MAX_JSON_BODY_SIZE = 1024 * 100 # 100 KB MAX_JSON_BODY_SIZE = 1024 * 100 # 100 KB
# Very targeted patterns. Avoid catastrophic regex nonsense.
XSS_PATTERN = re.compile( XSS_PATTERN = re.compile(
r"(<script|</script|javascript:|onerror\s*=|onload\s*=|<svg|<img)", r"(<script|<iframe|<embed|<object|<svg|<img|<video|<audio|<base|<link|<meta|<form|<button|"
r"javascript:|vbscript:|data:text/html|onerror\s*=|onload\s*=|onmouseover\s*=|onfocus\s*=|"
r"onclick\s*=|onscroll\s*=|ondblclick\s*=|onkeydown\s*=|onkeypress\s*=|onkeyup\s*=|"
r"onloadstart\s*=|onpageshow\s*=|onresize\s*=|onunload\s*=|style\s*=\s*['\"].expression\s\(|"
r"eval\s*\(|setTimeout\s*\(|setInterval\s*\(|Function\s*\()",
re.IGNORECASE, re.IGNORECASE,
) )
SQLI_PATTERN = re.compile( SQLI_PATTERN = re.compile(
r"(\bUNION\b|\bSELECT\b|\bINSERT\b|\bDELETE\b|\bDROP\b|--|\bOR\b\s+1=1)", r"(\bUNION\b|\bSELECT\b|\bINSERT\b|\bUPDATE\b|\bDELETE\b|\bDROP\b|\bALTER\b|\bCREATE\b|\bTRUNCATE\b|"
r"\bEXEC\b|\bEXECUTE\b|\bDECLARE\b|\bWAITFOR\b|\bDELAY\b|\bGROUP\b\s+\bBY\b|\bHAVING\b|\bORDER\b\s+\bBY\b|"
r"\bINFORMATION_SCHEMA\b|\bSYS\b\.|\bSYSOBJECTS\b|\bPG_SLEEP\b|\bSLEEP\b\(|--|/\|\/|#|\bOR\b\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+|"
r"\bAND\b\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+|"
r"\bXP_CMDSHELL\b|\bLOAD_FILE\b|\bINTO\s+OUTFILE\b)",
re.IGNORECASE,
)
RCE_PATTERN = re.compile(
r"(\$\(|.*|[;&|]\s*(cat|ls|id|whoami|pwd|ifconfig|ip|netstat|nc|netcat|nmap|curl|wget|python|php|perl|ruby|bash|sh|cmd|powershell|pwsh|sc\s+|tasklist|taskkill|base64|sudo|crontab|ssh|ftp|tftp)|"
r"\b(cat|ls|id|whoami|pwd|ifconfig|ip|netstat|nc|netcat|nmap|curl|wget|python|php|perl|ruby|bash|sh|cmd|powershell|pwsh|base64|sudo|crontab)\b|"
r"/etc/passwd|/etc/shadow|/etc/group|/etc/issue|/proc/self/|/windows/system32/|C:\\Windows\\)",
re.IGNORECASE,
)
TRAVERSAL_PATTERN = re.compile(
r"(\.\./|\.\.\\|%2e%2e%2f|%2e%2e/|\.\.%2f|%2e%2e%5c)",
re.IGNORECASE, re.IGNORECASE,
) )
@ -53,6 +72,18 @@ def inspect_value(value: str, source: str):
detail=f"Potential SQL injection payload detected in {source}", detail=f"Potential SQL injection payload detected in {source}",
) )
if RCE_PATTERN.search(value):
raise HTTPException(
status_code=400,
detail=f"Potential RCE payload detected in {source}",
)
if TRAVERSAL_PATTERN.search(value):
raise HTTPException(
status_code=400,
detail=f"Potential traversal payload detected in {source}",
)
if has_control_chars(value): if has_control_chars(value):
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
@ -117,10 +148,31 @@ class RequestValidationMiddleware(BaseHTTPMiddleware):
# ------------------------- # -------------------------
# 3. Query param inspection # 3. Query param inspection
# ------------------------- # -------------------------
pagination_size_keys = {"size", "itemsPerPage", "per_page", "limit"}
for key, value in params: for key, value in params:
if value: if value:
inspect_value(value, f"query param '{key}'") inspect_value(value, f"query param '{key}'")
# Pagination constraint: multiples of 5, max 50
if key in pagination_size_keys and value:
try:
size_val = int(value)
if size_val > 50:
raise HTTPException(
status_code=400,
detail=f"Pagination size '{key}' cannot exceed 50",
)
if size_val % 5 != 0:
raise HTTPException(
status_code=400,
detail=f"Pagination size '{key}' must be a multiple of 5",
)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Pagination size '{key}' must be an integer",
)
# ------------------------- # -------------------------
# 4. Content-Type sanity # 4. Content-Type sanity
# ------------------------- # -------------------------

Loading…
Cancel
Save