def sanitize_filename(filename: str) -> str:
    """Return a safe, length-limited version of a client-supplied filename.

    The following steps are applied in order:

    1. Directory components are dropped. Both ``/`` and ``\\`` are treated
       as separators, so Windows-style upload names are reduced to their
       final component even when this code runs on a POSIX host.
    2. ASCII control characters are removed.
    3. Every character other than ASCII letters, digits, ``_``, ``-``,
       ``.`` and space is replaced with ``_``.
    4. Runs of two or more dots are collapsed into a single dot.
    5. Surrounding whitespace is stripped.
    6. The result is truncated to 200 characters, keeping the extension
       when it is shorter than 20 characters.

    Args:
        filename: The raw filename as received from the client.

    Returns:
        The sanitized filename, at most 200 characters long.

    Raises:
        ValueError: If ``filename`` is empty, or if nothing other than dots
            and whitespace is left after sanitization.
    """
    max_length = 200
    max_kept_ext_length = 20

    if not filename:
        raise ValueError("Filename cannot be empty")

    # os.path.basename only splits on the host OS separator. Normalise
    # backslashes first so "C:\\dir\\file.aro" loses its directory part on
    # POSIX hosts too, instead of being flattened to "C__dir_file.aro".
    filename = os.path.basename(filename.replace("\\", "/"))

    # Remove control characters and other non-printable ASCII characters.
    filename = re.sub(r"[\x00-\x1f\x7f]", "", filename)

    # Whitelist: alphanumerics, underscore, hyphen, dot and space.
    # Anything else is replaced rather than dropped so the name keeps its
    # overall shape.
    filename = re.sub(r"[^a-zA-Z0-9_.\- ]", "_", filename)

    # Collapse consecutive dots so sequences like ".." cannot survive.
    filename = re.sub(r"\.{2,}", ".", filename)

    # Strip before measuring: otherwise leading padding counts toward the
    # length limit and can truncate the real name away.
    filename = filename.strip()

    # A name made only of dots (or nothing at all) is not usable.
    if not filename.replace(".", ""):
        raise ValueError("Filename invalid after sanitization")

    if len(filename) > max_length:
        base, ext = os.path.splitext(filename)
        if len(ext) < max_kept_ext_length:
            # Shorten the stem only, so the extension stays intact.
            filename = base[: max_length - len(ext)] + ext
        else:
            # Implausibly long "extension": fall back to a hard cut.
            filename = filename[:max_length]

    # Truncation may have exposed trailing whitespace from inside the name.
    return filename.strip()