feat: Implement filename sanitization for uploaded files and enforce secure default filenames for downloads.

main
Cizz22 1 month ago
parent 47cacc50d2
commit 3924954900

@ -18,7 +18,6 @@ import asyncio
import re
import requests
import json
from src.utils import save_to_pastebin
import pandas as pd
from src.aeros_simulation.service import get_aeros_schematic_by_name

@ -64,10 +64,9 @@ async def forward_download(db_session: DbSession):
file_bytes = response.content # full file in memory
# Extract headers
content_disposition = response.headers.get(
"content-disposition",
f'attachment; filename="{filename}"'
)
# Force a secure/default filename for the client
secure_filename = "rbd_project_export.aro"
content_disposition = f'attachment; filename="{secure_filename}"'
media_type = response.headers.get(
"content-type",
"application/octet-stream"

@ -12,6 +12,7 @@ from src.auth.service import CurrentUser
from src.config import WINDOWS_AEROS_BASE_URL, AEROS_BASE_URL, CLAMAV_HOST, CLAMAV_PORT
from src.database.core import DbSession
from src.database.service import search_filter_sort_paginate
from src.utils import sanitize_filename
import clamd
import io
@ -26,13 +27,23 @@ client = httpx.AsyncClient(timeout=300.0)
async def import_aro_project(*, db_session: DbSession, aeros_project_in: AerosProjectInput):
# windows_aeros_base_url = WINDOWS_AEROS_BASE_URL
file = aeros_project_in.aro_file
# Sanitize and validate filename
try:
clean_filename = sanitize_filename(file.filename)
except ValueError as e:
raise HTTPException(
status_code=400,
detail=f"Invalid filename: {str(e)}"
)
# Get filename
filename_without_ext = os.path.splitext(file.filename)[0]
filename_without_ext = os.path.splitext(clean_filename)[0]
# Get file extension
file_ext = os.path.splitext(file.filename)[1].lower()
file_ext = os.path.splitext(clean_filename)[1].lower()
# Validate file extension
if file_ext not in ALLOWED_EXTENSIONS:
@ -92,7 +103,7 @@ async def import_aro_project(*, db_session: DbSession, aeros_project_in: AerosPr
# Prepare file for upload
files = {
"file": (file.filename, content, file.content_type or "application/octet-stream")
"file": (clean_filename, content, file.content_type or "application/octet-stream")
}
print("fetch")

@ -1,3 +1,4 @@
import os
import re
from datetime import datetime, timedelta, timezone
from typing import Optional
@ -139,3 +140,42 @@ def save_to_pastebin(data, title="Result Log", expire_date="1H"):
return response.text # This will be the paste URL
else:
return f"Error: {response.status_code} - {response.text}"
def sanitize_filename(filename: str) -> str:
    """
    Sanitize an untrusted filename so it is safe to store or forward.

    Steps, in order:
    - Drop any directory part, for both POSIX (slash) and Windows
      (backslash) style paths.
    - Remove control / non-printable characters.
    - Replace every character outside ``[a-zA-Z0-9_-. ]`` with ``_``.
    - Collapse runs of dots into a single dot.
    - Limit the length to 200 characters, keeping the extension when it is
      shorter than 20 characters.
    - Strip leading and trailing whitespace.

    Args:
        filename: The client-supplied filename, possibly including a path.

    Returns:
        The sanitized filename.

    Raises:
        ValueError: If ``filename`` is empty, or if nothing but dots and
            whitespace remains after sanitization.
    """
    max_length = 200
    max_ext_length = 20

    if not filename:
        raise ValueError("Filename cannot be empty")

    # os.path.basename only understands the host separator. On a POSIX
    # server a Windows-style path ("C:\\dir\\file.aro") would otherwise keep
    # its directory names, so normalise backslashes to "/" first.
    filename = os.path.basename(filename.replace("\\", "/"))

    # Remove control characters and non-printable characters
    filename = re.sub(r'[\x00-\x1f\x7f]', '', filename)

    # Allow alphanumeric, underscore, hyphen, space, and dots;
    # replace every other character with an underscore.
    filename = re.sub(r'[^a-zA-Z0-9_\-\.\ ]', '_', filename)

    # Remove consecutive dots to prevent directory traversal attempts like '..'
    filename = re.sub(r'\.{2,}', '.', filename)

    # Reject names that are only dots and/or whitespace after sanitization
    if not filename.strip().replace('.', ''):
        raise ValueError("Filename invalid after sanitization")

    # Limit length, preserving a reasonably short extension if possible
    if len(filename) > max_length:
        base, ext = os.path.splitext(filename)
        if len(ext) < max_ext_length:
            filename = base[:max_length - len(ext)] + ext
        else:
            filename = filename[:max_length]

    return filename.strip()

Loading…
Cancel
Save