Merge pull request 'main' (#10) from CIzz22/rbd-app:main into main

Reviewed-on: DigitalTwin/rbd-app#10
main
CIzz22 2 weeks ago
commit fff11eb527

@ -1,5 +1,5 @@
ENV=development ENV=development
LOG_LEVEL=ERROR LOG_LEVEL=INFO
PORT=3020 PORT=3020
HOST=0.0.0.0 HOST=0.0.0.0
@ -18,5 +18,5 @@ COLLECTOR_NAME=digital_aeros_fixed
AEROS_LICENSE_ID=20260218-Jre5VZieQfWXTq0G8ClpVSGszMf4UEUMLS5ENpWRVcoVSrNJckVZzXE AEROS_LICENSE_ID=20260218-Jre5VZieQfWXTq0G8ClpVSGszMf4UEUMLS5ENpWRVcoVSrNJckVZzXE
AEROS_LICENSE_SECRET=GmLIxf9fr8Ap5m1IYzkk4RPBFcm7UBvcd0eRdRQ03oRdxLHQA0d9oyhUk2ZlM3LVdRh1mkgYy5254bmCjFyWWc0oPFwNWYzNwDwnv50qy6SLRdaFnI0yZcfLbWQ7qCSj AEROS_LICENSE_SECRET=GmLIxf9fr8Ap5m1IYzkk4RPBFcm7UBvcd0eRdRQ03oRdxLHQA0d9oyhUk2ZlM3LVdRh1mkgYy5254bmCjFyWWc0oPFwNWYzNwDwnv50qy6SLRdaFnI0yZcfLbWQ7qCSj
WINDOWS_AEROS_BASE_URL=http://192.168.1.102:8800 WINDOWS_AEROS_BASE_URL=http://192.168.1.102:8080
TEMPORAL_URL=http://192.168.1.86:7233 TEMPORAL_URL=http://192.168.1.86:7233

16
poetry.lock generated

@ -1312,15 +1312,15 @@ i18n = ["Babel (>=2.7)"]
[[package]] [[package]]
name = "licaeros" name = "licaeros"
version = "0.1.2" version = "0.1.7"
description = "License App for Aeros" description = "License App for Aeros"
optional = false optional = false
python-versions = "*" python-versions = "*"
groups = ["main"] groups = ["main"]
files = [ files = [
{file = "licaeros-0.1.2-cp310-cp310-linux_x86_64.whl", hash = "sha256:4b9bfe2e7ba8ab9edb5db18dcb415476e7ab302e09d72b74b5bfd1ac8938b10c"}, {file = "licaeros-0.1.7-cp310-cp310-linux_x86_64.whl", hash = "sha256:77bec84f37e02a7aff84f6c45a97a5933a86d99cdfabfd74ede36fa64506bfde"},
{file = "licaeros-0.1.2-cp311-cp311-linux_x86_64.whl", hash = "sha256:4f3a2251aebe7351e61d6f80d6c7474387f9561fdcfff02103b78bb2168c9791"}, {file = "licaeros-0.1.7-cp311-cp311-linux_x86_64.whl", hash = "sha256:48e874645c5892e05c8f26bdea910dcdaa3e7b0e787be77920a4f4fb5504b2c1"},
{file = "licaeros-0.1.2-cp312-cp312-linux_x86_64.whl", hash = "sha256:933c24029aec984ccc39baf630fbee10e07c1e28192c499685bec0a11d31321d"}, {file = "licaeros-0.1.7-cp312-cp312-linux_x86_64.whl", hash = "sha256:6a0bf6c1b9094693058d927febdb6799c61aea7b5dd10265b014c9d314844135"},
] ]
[package.source] [package.source]
@ -2478,14 +2478,14 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]] [[package]]
name = "rich" name = "rich"
version = "14.3.2" version = "14.3.3"
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
optional = false optional = false
python-versions = ">=3.8.0" python-versions = ">=3.8.0"
groups = ["main"] groups = ["main"]
files = [ files = [
{file = "rich-14.3.2-py3-none-any.whl", hash = "sha256:08e67c3e90884651da3239ea668222d19bea7b589149d8014a21c633420dbb69"}, {file = "rich-14.3.3-py3-none-any.whl", hash = "sha256:793431c1f8619afa7d3b52b2cdec859562b950ea0d4b6b505397612db8d5362d"},
{file = "rich-14.3.2.tar.gz", hash = "sha256:e712f11c1a562a11843306f5ed999475f09ac31ffb64281f73ab29ffdda8b3b8"}, {file = "rich-14.3.3.tar.gz", hash = "sha256:b8daa0b9e4eef54dd8cf7c86c03713f53241884e814f4e2f5fb342fe520f639b"},
] ]
[package.dependencies] [package.dependencies]
@ -3593,4 +3593,4 @@ propcache = ">=0.2.1"
[metadata] [metadata]
lock-version = "2.1" lock-version = "2.1"
python-versions = "^3.11" python-versions = "^3.11"
content-hash = "c97aecfef075bcbd7a40d9c98ae79c30d6253bc2c9f14ef187b1a098ace42088" content-hash = "46b6c8d43f09a99729b212166e31fd9190f8f659e178261a15bc35a694e2f81c"

@ -32,7 +32,8 @@ aiohttp = "^3.12.14"
ijson = "^3.4.0" ijson = "^3.4.0"
redis = "^7.1.0" redis = "^7.1.0"
clamd = "^1.0.2" clamd = "^1.0.2"
licaeros = "^0.1.2" licaeros = "^0.1.7"
[[tool.poetry.source]] [[tool.poetry.source]]

@ -294,12 +294,12 @@ def calculate_contribution_accurate(availabilities: Dict[str, float], structure_
key=lambda x: x[1]['birnbaum_importance'], key=lambda x: x[1]['birnbaum_importance'],
reverse=True) reverse=True)
print("\n=== COMPONENT IMPORTANCE ANALYSIS ===") # print("\n=== COMPONENT IMPORTANCE ANALYSIS ===")
print(f"System Availability: {system_info['system_availability']:.6f} ({system_info['system_availability']*100:.4f}%)") # print(f"System Availability: {system_info['system_availability']:.6f} ({system_info['system_availability']*100:.4f}%)")
print(f"System Unavailability: {system_info['system_unavailability']:.6f}") # print(f"System Unavailability: {system_info['system_unavailability']:.6f}")
print("\nComponent Rankings (by Birnbaum Importance):") # print("\nComponent Rankings (by Birnbaum Importance):")
print(f"{'Component':<20} {'Availability':<12} {'Birnbaum':<12} {'Criticality':<12} {'F-V':<12} {'Contribution%':<12}") # print(f"{'Component':<20} {'Availability':<12} {'Birnbaum':<12} {'Criticality':<12} {'F-V':<12} {'Contribution%':<12}")
print("-" * 92) # print("-" * 92)
for component, measures in sorted_components: for component, measures in sorted_components:
print(f"{component:<20} {measures['component_availability']:<12.6f} " print(f"{component:<20} {measures['component_availability']:<12.6f} "

@ -54,10 +54,12 @@ async def get_all(*, common):
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )
response.raise_for_status() response.raise_for_status()
res = response.json() res = response.json()
# if not res.get("status"):
# raise HTTPException(
# status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=res.get("message")
# )
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(

@ -1,6 +1,6 @@
import os import os
from typing import Optional from typing import Optional
import json
import httpx import httpx
from fastapi import HTTPException, status from fastapi import HTTPException, status
from sqlalchemy import Delete, desc, Select, select, func from sqlalchemy import Delete, desc, Select, select, func
@ -10,7 +10,7 @@ from src.aeros_equipment.service import save_default_equipment
from src.aeros_simulation.service import save_default_simulation_node from src.aeros_simulation.service import save_default_simulation_node
from src.auth.service import CurrentUser from src.auth.service import CurrentUser
from src.config import WINDOWS_AEROS_BASE_URL, CLAMAV_HOST, CLAMAV_PORT from src.config import WINDOWS_AEROS_BASE_URL, CLAMAV_HOST, CLAMAV_PORT
from src.aeros_utils import aeros_post from src.aeros_utils import aeros_post, aeros_file_upload
from src.database.core import DbSession from src.database.core import DbSession
from src.database.service import search_filter_sort_paginate from src.database.service import search_filter_sort_paginate
from src.utils import sanitize_filename from src.utils import sanitize_filename
@ -119,8 +119,8 @@ async def import_aro_project(*, db_session: DbSession, aeros_project_in: AerosPr
# } # }
response = await aeros_file_upload( response = await aeros_file_upload(
"/api/upload", "/upload",
file, content,
"file", "file",
clean_filename clean_filename
) )
@ -139,7 +139,6 @@ async def import_aro_project(*, db_session: DbSession, aeros_project_in: AerosPr
aro_path = upload_result.get("full_path") aro_path = upload_result.get("full_path")
filename = upload_result.get("stored_filename").replace(".aro", "") filename = upload_result.get("stored_filename").replace(".aro", "")
if not aro_path: if not aro_path:
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
@ -176,15 +175,21 @@ async def import_aro_project(*, db_session: DbSession, aeros_project_in: AerosPr
await db_session.commit() await db_session.commit()
# aro_json = json.dumps(aro_path)
# Update path to AEROS APP # Update path to AEROS APP
# Example BODy "C/dsad/dsad.aro" # Example BODy "C/dsad/dsad.aro"
try: try:
response = await aeros_post( response = await aeros_post(
"/api/Project/ImportAROFile", "/api/Project/ImportAROFile",
data=f'"{aro_path}"', json=aro_path,
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )
response.raise_for_status() response.raise_for_status()
response_json = response.json()
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)

@ -309,7 +309,7 @@ async def get_simulation_result_plot_per_node(db_session: DbSession, simulation_
} }
@router.get("/result/ranking/{simulation_id}", response_model=StandardResponse[List[SimulationRankingParameters]]) @router.get("/result/ranking/{simulation_id}", response_model=StandardResponse[List[SimulationRankingParameters]])
async def get_simulation_result_ranking(db_session: DbSession, simulation_id, limit:int = Query(None)): async def get_simulation_result_ranking(db_session: DbSession, simulation_id, limit:int = Query(None, le=50)):
"""Get simulation result.""" """Get simulation result."""
if simulation_id == 'default': if simulation_id == 'default':
simulation = await get_default_simulation(db_session=db_session) simulation = await get_default_simulation(db_session=db_session)

@ -1,6 +1,6 @@
import anyio import anyio
from licaeros import LicensedSession, device_fingerprint_hex from licaeros import LicensedSession, device_fingerprint_hex
from src.config import AEROS_BASE_URL, AEROS_LICENSE_ID, AEROS_LICENSE_SECRET from src.config import AEROS_BASE_URL, AEROS_LICENSE_ID, AEROS_LICENSE_SECRET, WINDOWS_AEROS_BASE_URL
import logging import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -8,33 +8,36 @@ log = logging.getLogger(__name__)
# Initialize a global session if possible, or create on demand # Initialize a global session if possible, or create on demand
_aeros_session = None _aeros_session = None
def get_aeros_session(): def get_aeros_session(base_url):
global _aeros_session global _aeros_session
if _aeros_session is None: if _aeros_session is None:
log.info(f"Initializing LicensedSession with base URL: {AEROS_BASE_URL}") log.info(f"Initializing LicensedSession with base URL: {base_url}")
log.info(f"Encrypted Device ID: {device_fingerprint_hex()}") log.info(f"Encrypted Device ID: {device_fingerprint_hex()}")
_aeros_session = LicensedSession( _aeros_session = LicensedSession(
api_base=AEROS_BASE_URL, api_base=base_url,
license_id=AEROS_LICENSE_ID, license_id=AEROS_LICENSE_ID,
license_secret=AEROS_LICENSE_SECRET, license_secret=AEROS_LICENSE_SECRET,
timeout=1000
) )
return _aeros_session return _aeros_session
async def aeros_post(path: str, json: dict = None, **kwargs): async def aeros_post(path: str, json=None, data=None, **kwargs):
""" """
Asynchronous wrapper for LicensedSession.post Asynchronous wrapper for LicensedSession.post
""" """
session = get_aeros_session() session = get_aeros_session(WINDOWS_AEROS_BASE_URL)
url = f"/api/aeros{path}"
# LicensedSession might not be async-compatible, so we run it in a thread # LicensedSession might not be async-compatible, so we run it in a thread
response = await anyio.to_thread.run_sync( response = await anyio.to_thread.run_sync(
lambda: session.post(path, json) lambda: session.post(url, json_data=json, data=data, headers=kwargs.get("headers"))
) )
return response return response
async def aeros_file_upload(path, file, field_name, filename): async def aeros_file_upload(path, file, field_name, filename):
session = get_aeros_session() session = get_aeros_session(WINDOWS_AEROS_BASE_URL)
url = f"/api/aeros{path}"
response = await anyio.to_thread.run_sync( response = await anyio.to_thread.run_sync(
lambda: session.post_multipart(path, file, field_name, filename) lambda: session.post_multipart(url, file, field_name, filename)
) )
return response return response

@ -31,6 +31,17 @@ class JWTBearer(HTTPBearer):
) )
request.state.user = user_info request.state.user = user_info
from src.context import set_user_id, set_username, set_role
if hasattr(user_info, "user_id"):
set_user_id(str(user_info.user_id))
if hasattr(user_info, "username"):
set_username(user_info.username)
elif hasattr(user_info, "name"):
set_username(user_info.name)
if hasattr(user_info, "role"):
set_role(user_info.role)
return user_info return user_info
else: else:
raise HTTPException(status_code=403, detail="Invalid authorization code.") raise HTTPException(status_code=403, detail="Invalid authorization code.")

@ -45,7 +45,7 @@ def get_config():
config = get_config() config = get_config()
LOG_LEVEL = config("LOG_LEVEL", default=logging.INFO) LOG_LEVEL = config("LOG_LEVEL", default="INFO")
ENV = config("ENV", default="local") ENV = config("ENV", default="local")
PORT = config("PORT", cast=int, default=8000) PORT = config("PORT", cast=int, default=8000)
HOST = config("HOST", default="localhost") HOST = config("HOST", default="localhost")

@ -2,8 +2,14 @@ from contextvars import ContextVar
from typing import Optional, Final from typing import Optional, Final
REQUEST_ID_CTX_KEY: Final[str] = "request_id" REQUEST_ID_CTX_KEY: Final[str] = "request_id"
_request_id_ctx_var: ContextVar[Optional[str]] = ContextVar( USER_ID_CTX_KEY: Final[str] = "user_id"
REQUEST_ID_CTX_KEY, default=None) USERNAME_CTX_KEY: Final[str] = "username"
ROLE_CTX_KEY: Final[str] = "role"
_request_id_ctx_var: ContextVar[Optional[str]] = ContextVar(REQUEST_ID_CTX_KEY, default=None)
_user_id_ctx_var: ContextVar[Optional[str]] = ContextVar(USER_ID_CTX_KEY, default=None)
_username_ctx_var: ContextVar[Optional[str]] = ContextVar(USERNAME_CTX_KEY, default=None)
_role_ctx_var: ContextVar[Optional[str]] = ContextVar(ROLE_CTX_KEY, default=None)
def get_request_id() -> Optional[str]: def get_request_id() -> Optional[str]:
@ -16,3 +22,21 @@ def set_request_id(request_id: str):
def reset_request_id(token): def reset_request_id(token):
_request_id_ctx_var.reset(token) _request_id_ctx_var.reset(token)
def get_user_id() -> Optional[str]:
return _user_id_ctx_var.get()
def set_user_id(user_id: str):
return _user_id_ctx_var.set(user_id)
def get_username() -> Optional[str]:
return _username_ctx_var.get()
def set_username(username: str):
return _username_ctx_var.set(username)
def get_role() -> Optional[str]:
return _role_ctx_var.get()
def set_role(role: str):
return _role_ctx_var.set(role)

@ -8,7 +8,7 @@ class CommonParams(DefultBase):
# This ensures no extra query params are allowed # This ensures no extra query params are allowed
current_user: Optional[str] = Field(None, alias="currentUser", max_length=50) current_user: Optional[str] = Field(None, alias="currentUser", max_length=50)
page: int = Field(1, gt=0, lt=2147483647) page: int = Field(1, gt=0, lt=2147483647)
items_per_page: int = Field(5, gt=-2, lt=2147483647, alias="itemsPerPage") items_per_page: int = Field(5, gt=0, le=50, multiple_of=5, alias="itemsPerPage")
query_str: Optional[str] = Field(None, alias="q", max_length=100) query_str: Optional[str] = Field(None, alias="q", max_length=100)
filter_spec: Optional[str] = Field(None, alias="filter", max_length=500) filter_spec: Optional[str] = Field(None, alias="filter", max_length=500)
sort_by: List[str] = Field(default_factory=list, alias="sortBy[]") sort_by: List[str] = Field(default_factory=list, alias="sortBy[]")

@ -19,8 +19,8 @@ log = logging.getLogger(__name__)
class ErrorDetail(BaseModel): class ErrorDetail(BaseModel):
field: Optional[str] = Field(None, max_length=100) field: Optional[str] = Field(None, max_length=100)
message: str = Field(..., max_length=255) message: str = Field(...)
code: Optional[str] = Field(None, max_length=50) code: Optional[str] = Field(None)
params: Optional[Dict[str, Any]] = None params: Optional[Dict[str, Any]] = None
@ -103,15 +103,38 @@ def handle_exception(request: Request, exc: Exception):
""" """
Global exception handler for Fastapi application. Global exception handler for Fastapi application.
""" """
import uuid
error_id = str(uuid.uuid1())
request_info = get_request_context(request) request_info = get_request_context(request)
# In FastAPI, we don't have a global 'g', but we can pass info in request.state
request.state.error_id = error_id
if isinstance(exc, RateLimitExceeded): if isinstance(exc, RateLimitExceeded):
return _rate_limit_exceeded_handler(request, exc) log.warning(
f"Rate limit exceeded | Error ID: {error_id}",
extra={
"error_id": error_id,
"error_category": "rate_limit",
"request": request_info,
"detail": str(exc.description) if hasattr(exc, "description") else str(exc),
},
)
return JSONResponse(
status_code=429,
content={
"data": None,
"message": "Rate limit exceeded",
"status": ResponseStatus.ERROR,
"error_id": error_id
}
)
if isinstance(exc, RequestValidationError): if isinstance(exc, RequestValidationError):
log.error( log.warning(
"Validation error occurred", f"Validation error occurred | Error ID: {error_id}",
extra={ extra={
"error_id": error_id,
"error_category": "validation", "error_category": "validation",
"errors": exc.errors(), "errors": exc.errors(),
"request": request_info, "request": request_info,
@ -120,16 +143,18 @@ def handle_exception(request: Request, exc: Exception):
return JSONResponse( return JSONResponse(
status_code=422, status_code=422,
content={ content={
"data": None, "data": exc.errors(),
"message": "Validation Error", "message": "Validation Error",
"status": ResponseStatus.ERROR, "status": ResponseStatus.ERROR,
"errors": exc.errors(), "error_id": error_id
}, },
) )
if isinstance(exc, (HTTPException, StarletteHTTPException)): if isinstance(exc, (HTTPException, StarletteHTTPException)):
log.error( log.error(
"HTTP exception occurred", f"HTTP exception occurred | Error ID: {error_id}",
extra={ extra={
"error_id": error_id,
"error_category": "http", "error_category": "http",
"status_code": exc.status_code, "status_code": exc.status_code,
"detail": exc.detail if hasattr(exc, "detail") else str(exc), "detail": exc.detail if hasattr(exc, "detail") else str(exc),
@ -143,19 +168,16 @@ def handle_exception(request: Request, exc: Exception):
"data": None, "data": None,
"message": str(exc.detail) if hasattr(exc, "detail") else str(exc), "message": str(exc.detail) if hasattr(exc, "detail") else str(exc),
"status": ResponseStatus.ERROR, "status": ResponseStatus.ERROR,
"errors": [ "error_id": error_id
ErrorDetail(
message=str(exc.detail) if hasattr(exc, "detail") else str(exc)
).model_dump()
],
}, },
) )
if isinstance(exc, SQLAlchemyError): if isinstance(exc, SQLAlchemyError):
error_message, status_code = handle_sqlalchemy_error(exc) error_message, status_code = handle_sqlalchemy_error(exc)
log.error( log.error(
"Database error occurred", f"Database error occurred | Error ID: {error_id}",
extra={ extra={
"error_id": error_id,
"error_category": "database", "error_category": "database",
"error_message": error_message, "error_message": error_message,
"request": request_info, "request": request_info,
@ -169,14 +191,15 @@ def handle_exception(request: Request, exc: Exception):
"data": None, "data": None,
"message": error_message, "message": error_message,
"status": ResponseStatus.ERROR, "status": ResponseStatus.ERROR,
"errors": [ErrorDetail(message=error_message).model_dump()], "error_id": error_id
}, },
) )
# Log unexpected errors # Log unexpected errors
log.error( log.error(
"Unexpected error occurred", f"Unexpected error occurred | Error ID: {error_id}",
extra={ extra={
"error_id": error_id,
"error_category": "unexpected", "error_category": "unexpected",
"error_message": str(exc), "error_message": str(exc),
"request": request_info, "request": request_info,
@ -188,10 +211,8 @@ def handle_exception(request: Request, exc: Exception):
status_code=500, status_code=500,
content={ content={
"data": None, "data": None,
"message": str(exc), "message": "An unexpected error occurred",
"status": ResponseStatus.ERROR, "status": ResponseStatus.ERROR,
"errors": [ "error_id": error_id
ErrorDetail(message="An unexpected error occurred.").model_dump()
],
}, },
) )

@ -35,28 +35,44 @@ class JSONFormatter(logging.Formatter):
Custom formatter to output logs in JSON format. Custom formatter to output logs in JSON format.
""" """
def format(self, record): def format(self, record):
from src.context import get_request_id from src.context import get_request_id, get_user_id, get_username, get_role
request_id = None request_id = None
user_id = None
username = None
role = None
try: try:
request_id = get_request_id() request_id = get_request_id()
user_id = get_user_id()
username = get_username()
role = get_role()
except Exception: except Exception:
pass pass
# Standard fields from requirements
log_record = { log_record = {
"timestamp": datetime.datetime.fromtimestamp(record.created).astimezone().isoformat(), "timestamp": datetime.datetime.fromtimestamp(record.created).strftime("%Y-%m-%d %H:%M:%S"),
"level": record.levelname, "level": record.levelname,
"name": record.name,
"message": record.getMessage(), "message": record.getMessage(),
"logger_name": record.name,
"location": f"{record.module}:{record.funcName}:{record.lineno}",
"module": record.module,
"funcName": record.funcName,
"lineno": record.lineno,
"pid": os.getpid(),
"request_id": request_id,
} }
# Add Context information if available
if user_id:
log_record["user_id"] = user_id
if username:
log_record["username"] = username
if role:
log_record["role"] = role
if request_id:
log_record["request_id"] = request_id
# Add Error context if available
if hasattr(record, "error_id"):
log_record["error_id"] = record.error_id
elif "error_id" in record.__dict__:
log_record["error_id"] = record.error_id
# Capture exception info if available # Capture exception info if available
if record.exc_info: if record.exc_info:
@ -67,15 +83,14 @@ class JSONFormatter(logging.Formatter):
log_record["stack_trace"] = self.formatStack(record.stack_info) log_record["stack_trace"] = self.formatStack(record.stack_info)
# Add any extra attributes passed to the log call # Add any extra attributes passed to the log call
# We skip standard attributes to avoid duplication
standard_attrs = { standard_attrs = {
"args", "asctime", "created", "exc_info", "exc_text", "filename", "args", "asctime", "created", "exc_info", "exc_text", "filename",
"funcName", "levelname", "levelno", "lineno", "module", "msecs", "funcName", "levelname", "levelno", "lineno", "module", "msecs",
"message", "msg", "name", "pathname", "process", "processName", "message", "msg", "name", "pathname", "process", "processName",
"relativeCreated", "stack_info", "thread", "threadName" "relativeCreated", "stack_info", "thread", "threadName", "error_id"
} }
for key, value in record.__dict__.items(): for key, value in record.__dict__.items():
if key not in standard_attrs: if key not in standard_attrs and not key.startswith("_"):
log_record[key] = value log_record[key] = value
log_json = json.dumps(log_record) log_json = json.dumps(log_record)
@ -87,7 +102,6 @@ class JSONFormatter(logging.Formatter):
return log_json return log_json
def configure_logging(): def configure_logging():
log_level = str(LOG_LEVEL).upper() # cast to string log_level = str(LOG_LEVEL).upper() # cast to string
log_levels = list(LogLevels) log_levels = list(LogLevels)

@ -52,7 +52,7 @@ app.add_exception_handler(Exception, handle_exception)
app.add_exception_handler(HTTPException, handle_exception) app.add_exception_handler(HTTPException, handle_exception)
app.add_exception_handler(StarletteHTTPException, handle_exception) app.add_exception_handler(StarletteHTTPException, handle_exception)
app.add_exception_handler(RequestValidationError, handle_exception) app.add_exception_handler(RequestValidationError, handle_exception)
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) app.add_exception_handler(RateLimitExceeded, handle_exception)
app.add_exception_handler(SQLAlchemyError, handle_exception) app.add_exception_handler(SQLAlchemyError, handle_exception)
app.state.limiter = limiter app.state.limiter = limiter
@ -80,23 +80,58 @@ async def db_session_middleware(request: Request, call_next):
response = await call_next(request) response = await call_next(request)
process_time = (time.time() - start_time) * 1000 process_time = (time.time() - start_time) * 1000
from src.context import get_username, get_role, get_user_id, set_user_id, set_username, set_role
# Pull from context or fallback to request.state
username = get_username()
role = get_role()
user_id = get_user_id()
user_obj = getattr(request.state, "user", None)
if user_obj:
if not user_id and hasattr(user_obj, "user_id"):
user_id = str(user_obj.user_id)
set_user_id(user_id)
if not username and hasattr(user_obj, "name"):
username = user_obj.name
set_username(username)
if not role and hasattr(user_obj, "role"):
role = user_obj.role
set_role(role)
user_info_str = ""
if username:
user_info_str = f" | User: {username}"
if role:
user_info_str += f" ({role})"
log.info( log.info(
"Request finished", f"HTTP {request.method} {request.url.path} completed in {round(process_time, 2)}ms{user_info_str}",
extra={ extra={
"method": request.method, "method": request.method,
"path": request.url.path, "path": request.url.path,
"status_code": response.status_code, "status_code": response.status_code,
"duration_ms": round(process_time, 2), "duration_ms": round(process_time, 2),
"user_id": user_id,
"role": role,
}, },
) )
except Exception as e: except Exception as e:
# Generate an error_id here if it hasn't been generated yet (e.g., if it failed before the handler)
error_id = getattr(request.state, "error_id", None)
if not error_id:
import uuid
error_id = str(uuid.uuid1())
request.state.error_id = error_id
log.error( log.error(
"Request failed", f"Request failed | Error ID: {error_id}",
extra={ extra={
"method": request.method, "method": request.method,
"path": request.url.path, "path": request.url.path,
"error": str(e), "error": str(e),
"error_id": error_id,
}, },
exc_info=True, exc_info=True,
) )

@ -18,14 +18,33 @@ MAX_QUERY_PARAMS = 50
MAX_QUERY_LENGTH = 2000 MAX_QUERY_LENGTH = 2000
MAX_JSON_BODY_SIZE = 1024 * 100 # 100 KB MAX_JSON_BODY_SIZE = 1024 * 100 # 100 KB
# Very targeted patterns. Avoid catastrophic regex nonsense.
XSS_PATTERN = re.compile( XSS_PATTERN = re.compile(
r"(<script|</script|javascript:|onerror\s*=|onload\s*=|<svg|<img)", r"(<script|<iframe|<embed|<object|<svg|<img|<video|<audio|<base|<link|<meta|<form|<button|"
r"javascript:|vbscript:|data:text/html|onerror\s*=|onload\s*=|onmouseover\s*=|onfocus\s*=|"
r"onclick\s*=|onscroll\s*=|ondblclick\s*=|onkeydown\s*=|onkeypress\s*=|onkeyup\s*=|"
r"onloadstart\s*=|onpageshow\s*=|onresize\s*=|onunload\s*=|style\s*=\s*['\"].expression\s\(|"
r"eval\s*\(|setTimeout\s*\(|setInterval\s*\(|Function\s*\()",
re.IGNORECASE, re.IGNORECASE,
) )
SQLI_PATTERN = re.compile( SQLI_PATTERN = re.compile(
r"(\bUNION\b|\bSELECT\b|\bINSERT\b|\bDELETE\b|\bDROP\b|--|\bOR\b\s+1=1)", r"(\bUNION\b|\bSELECT\b|\bINSERT\b|\bUPDATE\b|\bDELETE\b|\bDROP\b|\bALTER\b|\bCREATE\b|\bTRUNCATE\b|"
r"\bEXEC\b|\bEXECUTE\b|\bDECLARE\b|\bWAITFOR\b|\bDELAY\b|\bGROUP\b\s+\bBY\b|\bHAVING\b|\bORDER\b\s+\bBY\b|"
r"\bINFORMATION_SCHEMA\b|\bSYS\b\.|\bSYSOBJECTS\b|\bPG_SLEEP\b|\bSLEEP\b\(|--|/\|\/|#|\bOR\b\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+|"
r"\bAND\b\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+|"
r"\bXP_CMDSHELL\b|\bLOAD_FILE\b|\bINTO\s+OUTFILE\b)",
re.IGNORECASE,
)
RCE_PATTERN = re.compile(
r"(\$\(|`.*`|[;&|]\s*(cat|ls|id|whoami|pwd|ifconfig|ip|netstat|nc|netcat|nmap|curl|wget|python|php|perl|ruby|bash|sh|cmd|powershell|pwsh|sc\s+|tasklist|taskkill|base64|sudo|crontab|ssh|ftp|tftp)|"
r"\b(cat|ls|id|whoami|pwd|ifconfig|ip|netstat|nc|netcat|nmap|curl|wget|python|php|perl|ruby|bash|sh|cmd|powershell|pwsh|base64|sudo|crontab)\b|"
r"/etc/passwd|/etc/shadow|/etc/group|/etc/issue|/proc/self/|/windows/system32/|C:\\Windows\\)",
re.IGNORECASE,
)
TRAVERSAL_PATTERN = re.compile(
r"(\.\./|\.\.\\|%2e%2e%2f|%2e%2e/|\.\.%2f|%2e%2e%5c)",
re.IGNORECASE, re.IGNORECASE,
) )
@ -53,6 +72,18 @@ def inspect_value(value: str, source: str):
detail=f"Potential SQL injection payload detected in {source}", detail=f"Potential SQL injection payload detected in {source}",
) )
if RCE_PATTERN.search(value):
raise HTTPException(
status_code=400,
detail=f"Potential RCE payload detected in {source}",
)
if TRAVERSAL_PATTERN.search(value):
raise HTTPException(
status_code=400,
detail=f"Potential traversal payload detected in {source}",
)
if has_control_chars(value): if has_control_chars(value):
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
@ -117,10 +148,31 @@ class RequestValidationMiddleware(BaseHTTPMiddleware):
# ------------------------- # -------------------------
# 3. Query param inspection # 3. Query param inspection
# ------------------------- # -------------------------
pagination_size_keys = {"size", "itemsPerPage", "per_page", "limit"}
for key, value in params: for key, value in params:
if value: if value:
inspect_value(value, f"query param '{key}'") inspect_value(value, f"query param '{key}'")
# Pagination constraint: multiples of 5, max 50
if key in pagination_size_keys and value:
try:
size_val = int(value)
if size_val > 50:
raise HTTPException(
status_code=400,
detail=f"Pagination size '{key}' cannot exceed 50",
)
if size_val % 5 != 0:
raise HTTPException(
status_code=400,
detail=f"Pagination size '{key}' must be a multiple of 5",
)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Pagination size '{key}' must be an integer",
)
# ------------------------- # -------------------------
# 4. Content-Type sanity # 4. Content-Type sanity
# ------------------------- # -------------------------

Loading…
Cancel
Save