refactor: Overhaul testing framework with extensive unit tests, pytest configuration, and improved schema validation.

main
Cizz22 2 weeks ago
parent e60a26b6a6
commit baa6aeb7e7

poetry.lock (generated) — 40 lines changed

@@ -148,6 +148,25 @@ files = [
 frozenlist = ">=1.1.0"
 typing-extensions = {version = ">=4.2", markers = "python_version < \"3.13\""}
+
+[[package]]
+name = "aiosqlite"
+version = "0.20.0"
+description = "asyncio bridge to the standard sqlite3 module"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+    {file = "aiosqlite-0.20.0-py3-none-any.whl", hash = "sha256:36a1deaca0cac40ebe32aac9977a6e2bbc7f5189f23f4a54d5908986729e5bd6"},
+    {file = "aiosqlite-0.20.0.tar.gz", hash = "sha256:6d35c8c256637f4672f843c31021464090805bf925385ac39473fb16eaaca3d7"},
+]
+
+[package.dependencies]
+typing_extensions = ">=4.0"
+
+[package.extras]
+dev = ["attribution (==1.7.0)", "black (==24.2.0)", "coverage[toml] (==7.4.1)", "flake8 (==7.0.0)", "flake8-bugbear (==24.2.6)", "flit (==3.9.0)", "mypy (==1.8.0)", "ufmt (==2.3.0)", "usort (==1.0.8.post1)"]
+docs = ["sphinx (==7.2.6)", "sphinx-mdinclude (==0.5.3)"]

 [[package]]
 name = "annotated-types"
 version = "0.7.0"
@@ -2050,6 +2069,25 @@ pluggy = ">=1.5,<2"
 [package.extras]
 dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
+
+[[package]]
+name = "pytest-asyncio"
+version = "0.24.0"
+description = "Pytest support for asyncio"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+    {file = "pytest_asyncio-0.24.0-py3-none-any.whl", hash = "sha256:a811296ed596b69bf0b6f3dc40f83bcaf341b155a269052d82efa2b25ac7037b"},
+    {file = "pytest_asyncio-0.24.0.tar.gz", hash = "sha256:d081d828e576d85f875399194281e92bf8a68d60d72d1a2faf2feddb6c46b276"},
+]
+
+[package.dependencies]
+pytest = ">=8.2,<9"
+
+[package.extras]
+docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
+testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]

 [[package]]
 name = "python-dateutil"
 version = "2.9.0.post0"
@@ -3008,4 +3046,4 @@ propcache = ">=0.2.1"
 [metadata]
 lock-version = "2.1"
 python-versions = "^3.11"
-content-hash = "6c2a5a5a8e6a2732bd9e94de4bac3a7c0d3e63d959d5793b23eb327c7a95f3f8"
+content-hash = "256c8104c6eeb5b288dd0cdf02fe7cbad4f75aa93fc71f8d44da8b605d72f886"

@@ -12,6 +12,8 @@ fastapi = { extras = ["standard"], version = "^0.115.4" }
 sqlalchemy = "^2.0.36"
 httpx = "^0.27.2"
 pytest = "^8.3.3"
+pytest-asyncio = "^0.24.0"
+aiosqlite = "^0.20.0"
 faker = "^30.8.2"
 factory-boy = "^3.3.1"
 sqlalchemy-utils = "^0.41.2"

@@ -0,0 +1,6 @@
[pytest]
asyncio_mode = auto
testpaths = tests
python_files = test_*.py
filterwarnings =
    ignore::pydantic.PydanticDeprecatedSince20
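With asyncio_mode = auto, pytest-asyncio collects plain async def tests without an explicit marker, which is why most test files below omit @pytest.mark.asyncio. A minimal sketch (the file name is hypothetical, not part of this commit):

# tests/test_smoke.py — hypothetical example.
# Under asyncio_mode = auto, pytest-asyncio runs this coroutine as a test
# without the @pytest.mark.asyncio decorator.
import asyncio

async def test_event_loop_runs():
    await asyncio.sleep(0)  # yield to the loop; proves async execution works
    assert True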

@@ -8,7 +8,7 @@ class CommonParams(DefultBase):
     # This ensures no extra query params are allowed
     current_user: Optional[str] = Field(None, alias="currentUser")
     page: int = Field(1, gt=0, lt=2147483647)
-    items_per_page: int = Field(5, gt=-2, lt=2147483647, alias="itemsPerPage")
+    items_per_page: int = Field(5, gt=0, le=50, multiple_of=5, alias="itemsPerPage")
     query_str: Optional[str] = Field(None, alias="q")
     filter_spec: Optional[str] = Field(None, alias="filter")
     sort_by: List[str] = Field(default_factory=list, alias="sortBy[]")
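The tightened constraints reject zero and negative sizes, values over 50, and non-multiples of 5 at the schema layer (the old gt=-2 bound even allowed -1 and 0). A quick standalone check with a hypothetical stand-in model, assuming pydantic v2 as implied by the PydanticDeprecatedSince20 filter above:

from pydantic import BaseModel, Field, ValidationError

class PageParams(BaseModel):  # hypothetical stand-in for CommonParams
    items_per_page: int = Field(5, gt=0, le=50, multiple_of=5, alias="itemsPerPage")

print(PageParams(itemsPerPage=10).items_per_page)  # 10 — valid

for bad in (0, 7, 55):
    try:
        PageParams(itemsPerPage=bad)
    except ValidationError as e:
        print(bad, "rejected:", e.errors()[0]["type"])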

@@ -18,17 +18,35 @@ MAX_QUERY_PARAMS = 50
 MAX_QUERY_LENGTH = 2000
 MAX_JSON_BODY_SIZE = 1024 * 100  # 100 KB

-# Very targeted patterns. Avoid catastrophic regex nonsense.
 XSS_PATTERN = re.compile(
-    r"(<script|</script|javascript:|onerror\s*=|onload\s*=|<svg|<img)",
+    r"(<script|<iframe|<embed|<object|<svg|<img|<video|<audio|<base|<link|<meta|<form|<button|"
+    r"javascript:|vbscript:|data:text/html|onerror\s*=|onload\s*=|onmouseover\s*=|onfocus\s*=|"
+    r"onclick\s*=|onscroll\s*=|ondblclick\s*=|onkeydown\s*=|onkeypress\s*=|onkeyup\s*=|"
+    r"onloadstart\s*=|onpageshow\s*=|onresize\s*=|onunload\s*=|style\s*=\s*['\"].*expression\s*\(|"
+    r"eval\s*\(|setTimeout\s*\(|setInterval\s*\(|Function\s*\()",
     re.IGNORECASE,
 )

 SQLI_PATTERN = re.compile(
-    r"(\bUNION\b|\bSELECT\b|\bINSERT\b|\bDELETE\b|\bDROP\b|--|\bOR\b\s+1=1)",
+    r"(\bUNION\b|\bSELECT\b|\bINSERT\b|\bUPDATE\b|\bDELETE\b|\bDROP\b|\bALTER\b|\bCREATE\b|\bTRUNCATE\b|"
+    r"\bEXEC\b|\bEXECUTE\b|\bDECLARE\b|\bWAITFOR\b|\bDELAY\b|\bGROUP\b\s+\bBY\b|\bHAVING\b|\bORDER\b\s+\bBY\b|"
+    r"\bINFORMATION_SCHEMA\b|\bSYS\b\.|\bSYSOBJECTS\b|\bPG_SLEEP\b|\bSLEEP\b\(|--|/\*|\*/|#|\bOR\b\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+|"
+    r"\bAND\b\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+|"
+    r"\bXP_CMDSHELL\b|\bLOAD_FILE\b|\bINTO\s+OUTFILE\b)",
     re.IGNORECASE,
 )

+RCE_PATTERN = re.compile(
+    r"(\$\(|`.*`|[;&|]\s*(cat|ls|id|whoami|pwd|ifconfig|ip|netstat|nc|netcat|nmap|curl|wget|python|php|perl|ruby|bash|sh|cmd|powershell|pwsh|sc\s+|tasklist|taskkill|base64|sudo|crontab|ssh|ftp|tftp)|"
+    r"\b(cat|ls|id|whoami|pwd|ifconfig|ip|netstat|nc|netcat|nmap|curl|wget|python|php|perl|ruby|bash|sh|cmd|powershell|pwsh|base64|sudo|crontab)\b|"
+    r"/etc/passwd|/etc/shadow|/etc/group|/etc/issue|/proc/self/|/windows/system32/|C:\\Windows\\)",
+    re.IGNORECASE,
+)
+
+TRAVERSAL_PATTERN = re.compile(
+    r"(\.\./|\.\.\\|%2e%2e%2f|%2e%2e/|\.\.%2f|%2e%2e%5c)",
+    re.IGNORECASE,
+)
+
 # JSON prototype pollution keys
 FORBIDDEN_JSON_KEYS = {"__proto__", "constructor", "prototype"}
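A quick REPL sanity check of the two new patterns, assuming the reconstructed regexes above compile as shown and the src.middleware module path used by the tests below:

from src.middleware import RCE_PATTERN, TRAVERSAL_PATTERN

# Literal and percent-encoded traversal forms should both match.
assert TRAVERSAL_PATTERN.search("..%2fetc/passwd")
assert TRAVERSAL_PATTERN.search("%2e%2e%2f%2e%2e%2f")

# A shell metacharacter followed by a known binary triggers the RCE check.
assert RCE_PATTERN.search("; cat /etc/passwd")

# Benign text passes the traversal check.
assert not TRAVERSAL_PATTERN.search("report-2024.pdf")

Note the second alternation branch of RCE_PATTERN matches bare words like "id", "sh", or "ip" anywhere in a value, so false positives on ordinary prose are likely; that is a property of the pattern as committed, not of this check.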
@@ -53,6 +71,18 @@ def inspect_value(value: str, source: str):
             detail=f"Potential SQL injection payload detected in {source}",
         )
+
+    if RCE_PATTERN.search(value):
+        raise HTTPException(
+            status_code=400,
+            detail=f"Potential Remote Code Execution payload detected in {source}",
+        )
+
+    if TRAVERSAL_PATTERN.search(value):
+        raise HTTPException(
+            status_code=400,
+            detail=f"Path traversal detected in {source}",
+        )
+
     if has_control_chars(value):
         raise HTTPException(
             status_code=400,
@@ -117,10 +147,31 @@ class RequestValidationMiddleware(BaseHTTPMiddleware):
         # -------------------------
         # 3. Query param inspection
         # -------------------------
+        pagination_size_keys = {"size", "itemsPerPage", "per_page", "limit"}
         for key, value in params:
             if value:
                 inspect_value(value, f"query param '{key}'")
+
+            # Pagination constraint: multiples of 5, max 50
+            if key in pagination_size_keys and value:
+                try:
+                    size_val = int(value)
+                    if size_val > 50:
+                        raise HTTPException(
+                            status_code=400,
+                            detail=f"Pagination size '{key}' cannot exceed 50",
+                        )
+                    if size_val % 5 != 0:
+                        raise HTTPException(
+                            status_code=400,
+                            detail=f"Pagination size '{key}' must be a multiple of 5",
+                        )
+                except ValueError:
+                    raise HTTPException(
+                        status_code=400,
+                        detail=f"Pagination size '{key}' must be an integer",
+                    )
+
         # -------------------------
         # 4. Content-Type sanity
         # -------------------------
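One subtlety: the HTTPExceptions raised inside the try block are not swallowed by the except ValueError clause, since only int(value) can raise ValueError here. If this check grows, it could be factored into a helper — a hypothetical sketch, not part of the commit:

from fastapi import HTTPException

def validate_pagination_size(key: str, value: str, max_size: int = 50, step: int = 5) -> None:
    """Hypothetical helper mirroring the inline middleware check."""
    try:
        size_val = int(value)
    except ValueError:
        raise HTTPException(400, f"Pagination size '{key}' must be an integer") from None
    if size_val > max_size:
        raise HTTPException(400, f"Pagination size '{key}' cannot exceed {max_size}")
    if size_val % step != 0:
        raise HTTPException(400, f"Pagination size '{key}' must be a multiple of {step}")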

@@ -22,7 +22,8 @@ def parse_relative_expression(date_str: str) -> Optional[datetime]:
     unit, offset = match.groups()
     offset = int(offset) if offset else 0

     # Use UTC timezone for consistency
-    today = datetime.now(timezone.tzname("Asia/Jakarta"))
+    jakarta_tz = pytz.timezone("Asia/Jakarta")
+    today = datetime.now(jakarta_tz)

     if unit == "H":
         # For hours, keep minutes and seconds
         result_time = today + timedelta(hours=offset)

@@ -64,7 +65,7 @@ def parse_date_string(date_str: str) -> Optional[datetime]:
         minute=0,
         second=0,
         microsecond=0,
-        tzinfo=pytz.timezone("Asia/Jakarta"),
+        tzinfo=pytz.timezone("Asia/Jakarta"),
     )
     return dt
 except ValueError:
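Two caveats on this fix, assuming import pytz is present elsewhere in the module. First, the surviving comment still says "Use UTC timezone for consistency", which no longer matches the Jakarta zone. Second, datetime.now(pytz.timezone(...)) is fine, but passing a pytz zone straight through tzinfo= (as in the second hunk) attaches the zone's raw LMT offset rather than the intended UTC+7; pytz's documented pattern is localize(). A minimal demonstration:

from datetime import datetime
import pytz

jakarta = pytz.timezone("Asia/Jakarta")
naive = datetime(2024, 11, 8)

wrong = naive.replace(tzinfo=jakarta)  # attaches +07:07:12 (LMT) — classic pytz pitfall
right = jakarta.localize(naive)        # correct +07:00

print(wrong.isoformat())  # 2024-11-08T00:00:00+07:07:12
print(right.isoformat())  # 2024-11-08T00:00:00+07:00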

@@ -1,68 +1,68 @@
-import asyncio
-from typing import AsyncGenerator, Generator
-
-import pytest
-from httpx import AsyncClient
-from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
-from sqlalchemy.orm import sessionmaker
-from sqlalchemy.pool import StaticPool
-from sqlalchemy_utils import database_exists, drop_database
-from starlette.config import environ
-from starlette.testclient import TestClient
-
-# from src.database import Base, get_db
-# from src.main import app
-
-# Test database URL
-TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
-
-engine = create_async_engine(
-    TEST_DATABASE_URL,
-    connect_args={"check_same_thread": False},
-    poolclass=StaticPool,
-)
-
-async_session = sessionmaker(
-    engine,
-    class_=AsyncSession,
-    expire_on_commit=False,
-    autocommit=False,
-    autoflush=False,
-)
-
-async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
-    async with async_session() as session:
-        try:
-            yield session
-            await session.commit()
-        except Exception:
-            await session.rollback()
-            raise
-        finally:
-            await session.close()
-
-app.dependency_overrides[get_db] = override_get_db
-
-@pytest.fixture(scope="session")
-def event_loop() -> Generator:
-    loop = asyncio.get_event_loop_policy().new_event_loop()
-    yield loop
-    loop.close()
-
-@pytest.fixture(autouse=True)
-async def setup_db() -> AsyncGenerator[None, None]:
-    async with engine.begin() as conn:
-        await conn.run_sync(Base.metadata.create_all)
-    yield
-    async with engine.begin() as conn:
-        await conn.run_sync(Base.metadata.drop_all)
-
-@pytest.fixture
-async def client() -> AsyncGenerator[AsyncClient, None]:
-    async with AsyncClient(app=app, base_url="http://test") as client:
-        yield client
+# import asyncio
+# from typing import AsyncGenerator, Generator
+#
+# import pytest
+# from httpx import AsyncClient
+# from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
+# from sqlalchemy.orm import sessionmaker
+# from sqlalchemy.pool import StaticPool
+# from sqlalchemy_utils import database_exists, drop_database
+# from starlette.config import environ
+# from starlette.testclient import TestClient
+#
+# # from src.database import Base, get_db
+# # from src.main import app
+#
+# # Test database URL
+# TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
+#
+# engine = create_async_engine(
+#     TEST_DATABASE_URL,
+#     connect_args={"check_same_thread": False},
+#     poolclass=StaticPool,
+# )
+#
+# async_session = sessionmaker(
+#     engine,
+#     class_=AsyncSession,
+#     expire_on_commit=False,
+#     autocommit=False,
+#     autoflush=False,
+# )
+#
+# async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
+#     async with async_session() as session:
+#         try:
+#             yield session
+#             await session.commit()
+#         except Exception:
+#             await session.rollback()
+#             raise
+#         finally:
+#             await session.close()
+#
+# app.dependency_overrides[get_db] = override_get_db
+#
+# @pytest.fixture(scope="session")
+# def event_loop() -> Generator:
+#     loop = asyncio.get_event_loop_policy().new_event_loop()
+#     yield loop
+#     loop.close()
+#
+# @pytest.fixture(autouse=True)
+# async def setup_db() -> AsyncGenerator[None, None]:
+#     async with engine.begin() as conn:
+#         await conn.run_sync(Base.metadata.create_all)
+#     yield
+#     async with engine.begin() as conn:
+#         await conn.run_sync(Base.metadata.drop_all)
+#
+# @pytest.fixture
+# async def client() -> AsyncGenerator[AsyncClient, None]:
+#     async with AsyncClient(app=app, base_url="http://test") as client:
+#         yield client
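The conftest is commented out rather than deleted. If it is revived, two API changes in the pinned dependencies matter: httpx 0.27 deprecates AsyncClient(app=...) in favor of an explicit ASGITransport, and pytest-asyncio ≥ 0.23 deprecates redefining the event_loop fixture (asyncio_mode = auto plus plain async fixtures covers the same ground). A minimal replacement sketch — the src.main/src.database paths are taken from the commented imports, everything else is an assumption:

import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool

from src.database import Base, get_db  # paths from the commented-out imports
from src.main import app

engine = create_async_engine(
    "sqlite+aiosqlite:///:memory:",
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
async_session = async_sessionmaker(engine, expire_on_commit=False)

async def override_get_db():
    async with async_session() as session:
        yield session

app.dependency_overrides[get_db] = override_get_db

@pytest.fixture(autouse=True)
async def setup_db():
    # Fresh schema per test; async fixtures work under asyncio_mode = auto.
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)

@pytest.fixture
async def client():
    transport = ASGITransport(app=app)  # httpx >= 0.27 style
    async with AsyncClient(transport=transport, base_url="http://test") as c:
        yield c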

@@ -1,3 +0,0 @@
-from sqlalchemy.orm import scoped_session, sessionmaker
-
-Session = scoped_session(sessionmaker())

@@ -1,28 +0,0 @@
-import uuid
-from datetime import datetime
-
-from factory import (LazyAttribute, LazyFunction, SelfAttribute, Sequence,
-                     SubFactory, post_generation)
-from factory.alchemy import SQLAlchemyModelFactory
-from factory.fuzzy import FuzzyChoice, FuzzyDateTime, FuzzyInteger, FuzzyText
-from faker import Faker
-from faker.providers import misc
-
-from .database import Session
-
-# from pytz import UTC
-
-fake = Faker()
-fake.add_provider(misc)
-
-class BaseFactory(SQLAlchemyModelFactory):
-    """Base Factory."""
-
-    class Meta:
-        """Factory configuration."""
-
-        abstract = True
-        sqlalchemy_session = Session
-        sqlalchemy_session_persistence = "commit"

@@ -0,0 +1,44 @@
import pytest

from src.calculation_budget_constrains.service import greedy_selection, knapsack_selection


def test_greedy_selection():
    equipments = [
        {"id": 1, "total_cost": 100, "priority_score": 10, "cost": 100},
        {"id": 2, "total_cost": 50, "priority_score": 20, "cost": 50},
        {"id": 3, "total_cost": 60, "priority_score": 15, "cost": 60},
    ]
    budget = 120

    # Items sorted by priority_score: id 2 (20), id 3 (15), id 1 (10)
    # 2 (50) + 3 (60) = 110. Item 1 (100) won't fit.
    selected, excluded = greedy_selection(equipments, budget)

    selected_ids = [e["id"] for e in selected]
    assert 2 in selected_ids
    assert 3 in selected_ids
    assert len(selected) == 2
    assert excluded[0]["id"] == 1


def test_knapsack_selection_basic():
    # Similar items but where greedy might fail if cost/value ratio is tricky
    # item 1: value 10, cost 60
    # item 2: value 7, cost 35
    # item 3: value 7, cost 35
    # budget: 70
    # Greedy would take item 1 (value 10, remaining budget 10, can't take more)
    # Optimal would take item 2 and 3 (value 14, remaining budget 0)
    scale = 1  # No scaling for simplicity in this test
    equipments = [
        {"id": 1, "total_cost": 60, "priority_score": 10},
        {"id": 2, "total_cost": 35, "priority_score": 7},
        {"id": 3, "total_cost": 35, "priority_score": 7},
    ]
    budget = 70

    selected, excluded = knapsack_selection(equipments, budget, scale=scale)

    selected_ids = [e["id"] for e in selected]
    assert 2 in selected_ids
    assert 3 in selected_ids
    assert len(selected) == 2
    assert 1 not in selected_ids
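For readers unfamiliar with why the second case needs dynamic programming: greedy-by-priority takes item 1 for value 10, while the 0/1-knapsack optimum takes items 2 and 3 for value 14. A self-contained sketch of the two strategies — hypothetical, the real implementations live in src.calculation_budget_constrains.service:

def greedy(items, budget):
    # Take items in descending priority while they fit the remaining budget.
    chosen, remaining = [], budget
    for it in sorted(items, key=lambda x: -x["priority_score"]):
        if it["total_cost"] <= remaining:
            chosen.append(it)
            remaining -= it["total_cost"]
    return chosen

def knapsack(items, budget):
    # Classic 0/1 knapsack over integer costs; dp[b] = (best value, chosen ids).
    dp = [(0, [])] * (budget + 1)
    for it in items:
        cost, val = it["total_cost"], it["priority_score"]
        for b in range(budget, cost - 1, -1):
            cand = (dp[b - cost][0] + val, dp[b - cost][1] + [it["id"]])
            if cand[0] > dp[b][0]:
                dp[b] = cand
    return dp[budget][1]

items = [
    {"id": 1, "total_cost": 60, "priority_score": 10},
    {"id": 2, "total_cost": 35, "priority_score": 7},
    {"id": 3, "total_cost": 35, "priority_score": 7},
]
print([i["id"] for i in greedy(items, 70)])  # [1] — greedy stops at value 10
print(knapsack(items, 70))                   # [2, 3] — optimal value 14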

@@ -0,0 +1,14 @@
from src.context import set_request_id, get_request_id, set_user_id, get_user_id


def test_request_id_context():
    test_id = "test-request-id-123"
    set_request_id(test_id)
    assert get_request_id() == test_id


def test_user_id_context():
    test_uid = "user-456"
    set_user_id(test_uid)
    assert get_user_id() == test_uid


def test_context_default_none():
    assert get_request_id() is None or get_request_id() != ""
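These tests pass if src.context is the usual contextvars-based plumbing. A minimal sketch of what the module presumably looks like — hypothetical, only the four imported names are confirmed by the test:

from contextvars import ContextVar
from typing import Optional

_request_id: ContextVar[Optional[str]] = ContextVar("request_id", default=None)
_user_id: ContextVar[Optional[str]] = ContextVar("user_id", default=None)

def set_request_id(value: str) -> None:
    _request_id.set(value)

def get_request_id() -> Optional[str]:
    return _request_id.get()

def set_user_id(value: str) -> None:
    _user_id.set(value)

def get_user_id() -> Optional[str]:
    return _user_id.get()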

@@ -0,0 +1,53 @@
import pytest
from decimal import Decimal

from src.contribution_util import prod, system_availability, get_all_components, birnbaum_importance


def test_prod():
    assert prod([1, 2, 3]) == 6.0
    assert prod([0.5, 0.5]) == 0.25
    assert prod([]) == 1.0


def test_system_availability_series():
    structure = {"series": ["A", "B"]}
    availabilities = {"A": 0.9, "B": 0.8}
    # 0.9 * 0.8 = 0.72
    assert system_availability(structure, availabilities) == pytest.approx(0.72)


def test_system_availability_parallel():
    structure = {"parallel": ["A", "B"]}
    availabilities = {"A": 0.9, "B": 0.8}
    # 1 - (1-0.9)*(1-0.8) = 1 - 0.1*0.2 = 1 - 0.02 = 0.98
    assert system_availability(structure, availabilities) == pytest.approx(0.98)


def test_system_availability_nested():
    # (A in series with (B in parallel with C))
    structure = {
        "series": [
            "A",
            {"parallel": ["B", "C"]},
        ]
    }
    availabilities = {"A": 0.9, "B": 0.8, "C": 0.7}
    # B||C = 1 - (1-0.8)*(1-0.7) = 1 - 0.2*0.3 = 0.94
    # A && (B||C) = 0.9 * 0.94 = 0.846
    assert system_availability(structure, availabilities) == pytest.approx(0.846)


def test_get_all_components():
    structure = {
        "series": [
            "A",
            {"parallel": ["B", "C"]},
        ]
    }
    assert get_all_components(structure) == {"A", "B", "C"}


def test_birnbaum_importance():
    structure = {"series": ["A", "B"]}
    availabilities = {"A": 0.9, "B": 0.8}
    # I_B(A) = A_sys(A=1) - A_sys(A=0)
    # A_sys(A=1, B=0.8) = 1 * 0.8 = 0.8
    # A_sys(A=0, B=0.8) = 0 * 0.8 = 0
    # I_B(A) = 0.8
    assert birnbaum_importance(structure, availabilities, "A") == pytest.approx(0.8)
    # I_B(B) = A_sys(B=1, A=0.9) - A_sys(B=0, A=0.9) = 0.9 - 0 = 0.9
    assert birnbaum_importance(structure, availabilities, "B") == pytest.approx(0.9)
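The expected values follow directly from the series/parallel recursion, with Birnbaum importance computed by pinning one component's availability to 1 and then 0. A compact reference implementation that reproduces the asserted numbers — a sketch of what src.contribution_util presumably does, not the committed code:

from math import prod  # math.prod([]) == 1, matching the test_prod expectation

def system_availability(structure, avail):
    # A string leaf is a component; dicts combine children in series or parallel.
    if isinstance(structure, str):
        return avail[structure]
    if "series" in structure:
        return prod(system_availability(c, avail) for c in structure["series"])
    children = structure["parallel"]
    return 1 - prod(1 - system_availability(c, avail) for c in children)

def birnbaum_importance(structure, avail, comp):
    # I_B(comp) = A_sys(comp=1) - A_sys(comp=0)
    hi = system_availability(structure, {**avail, comp: 1.0})
    lo = system_availability(structure, {**avail, comp: 0.0})
    return hi - lo

s = {"series": ["A", {"parallel": ["B", "C"]}]}
a = {"A": 0.9, "B": 0.8, "C": 0.7}
print(system_availability(s, a))  # 0.9 * (1 - 0.2 * 0.3) = 0.846
print(birnbaum_importance({"series": ["A", "B"]}, {"A": 0.9, "B": 0.8}, "A"))  # 0.8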

@@ -0,0 +1,31 @@
import pytest
from sqlalchemy.exc import IntegrityError, DataError, DBAPIError

from src.exceptions import handle_sqlalchemy_error


def test_handle_sqlalchemy_error_unique_constraint():
    err = IntegrityError("Unique constraint", params=None, orig=Exception("unique constraint violation"))
    msg, status = handle_sqlalchemy_error(err)
    assert status == 409
    assert "already exists" in msg


def test_handle_sqlalchemy_error_foreign_key():
    err = IntegrityError("Foreign key constraint", params=None, orig=Exception("foreign key constraint violation"))
    msg, status = handle_sqlalchemy_error(err)
    assert status == 400
    assert "Related record not found" in msg


def test_handle_sqlalchemy_error_data_error():
    err = DataError("Invalid data", params=None, orig=None)
    msg, status = handle_sqlalchemy_error(err)
    assert status == 400
    assert "Invalid data" in msg


def test_handle_sqlalchemy_error_generic_dbapi():
    class MockError:
        def __str__(self):
            return "Some generic database error"

    err = DBAPIError("Generic error", params=None, orig=MockError())
    msg, status = handle_sqlalchemy_error(err)
    assert status == 500
    assert "Database error" in msg

@@ -0,0 +1,56 @@
import pytest
from unittest.mock import AsyncMock, MagicMock
from fastapi import HTTPException

from src.middleware import RequestValidationMiddleware


@pytest.mark.asyncio
async def test_request_validation_middleware_query_length():
    middleware = RequestValidationMiddleware(app=MagicMock())
    request = MagicMock()
    request.url.query = "a" * 2001

    with pytest.raises(HTTPException) as excinfo:
        await middleware.dispatch(request, AsyncMock())
    assert excinfo.value.status_code == 414


@pytest.mark.asyncio
async def test_request_validation_middleware_too_many_params():
    middleware = RequestValidationMiddleware(app=MagicMock())
    request = MagicMock()
    request.url.query = "a=1"
    request.query_params.multi_items.return_value = [("param", "val")] * 51

    with pytest.raises(HTTPException) as excinfo:
        await middleware.dispatch(request, AsyncMock())
    assert excinfo.value.status_code == 400
    assert "Too many query parameters" in excinfo.value.detail


@pytest.mark.asyncio
async def test_request_validation_middleware_xss_detection():
    middleware = RequestValidationMiddleware(app=MagicMock())
    request = MagicMock()
    request.url.query = "q=<script>"
    request.query_params.multi_items.return_value = [("q", "<script>")]

    with pytest.raises(HTTPException) as excinfo:
        await middleware.dispatch(request, AsyncMock())
    assert excinfo.value.status_code == 400
    assert "Potential XSS payload" in excinfo.value.detail


@pytest.mark.asyncio
async def test_request_validation_middleware_pagination_logic():
    middleware = RequestValidationMiddleware(app=MagicMock())
    request = MagicMock()
    request.url.query = "size=55"
    request.query_params.multi_items.return_value = [("size", "55")]
    request.headers = {}

    with pytest.raises(HTTPException) as excinfo:
        await middleware.dispatch(request, AsyncMock())
    assert excinfo.value.status_code == 400
    assert "cannot exceed 50" in excinfo.value.detail

    request.query_params.multi_items.return_value = [("size", "7")]
    with pytest.raises(HTTPException) as excinfo:
        await middleware.dispatch(request, AsyncMock())
    assert "must be a multiple of 5" in excinfo.value.detail

@@ -0,0 +1,64 @@
import pytest
import math

from src.calculation_target_reliability.service import calculate_asset_eaf_contributions


def test_calculate_asset_eaf_contributions_basic():
    # Mock plant result
    plant_result = {
        "total_uptime": 7000,
        "total_downtime": 1000,
        "eaf": 85.0,
    }
    # total_hours = 8000

    # Mock equipment results
    eq_results = [
        {
            "aeros_node": {"node_name": "Asset1"},
            "num_events": 5,
            "contribution_factor": 0.5,
            "contribution": 0.1,  # Birnbaum
            "availability": 0.9,
            "total_downtime": 100,
        },
        {
            "aeros_node": {"node_name": "Asset2"},
            "num_events": 2,
            "contribution_factor": 0.3,
            "contribution": 0.05,
            "availability": 0.95,
            "total_downtime": 50,
        },
    ]
    standard_scope = ["Asset1", "Asset2"]
    eaf_gap = 2.0  # 2% gap
    scheduled_outage = 500

    results = calculate_asset_eaf_contributions(
        plant_result, eq_results, standard_scope, eaf_gap, scheduled_outage
    )

    assert len(results) == 2
    # Check sorting (highest birnbaum first)
    assert results[0].node["node_name"] == "Asset1"
    assert results[0].birbaum > results[1].birbaum
    # Check that required_improvement is positive
    assert results[0].required_improvement > 0
    assert results[0].improvement_impact > 0


def test_calculate_asset_eaf_contributions_skipping():
    plant_result = {"total_uptime": 1000, "total_downtime": 0, "eaf": 100}
    eq_results = [{
        "aeros_node": {"node_name": "Asset1"},
        "num_events": 0,
        "contribution_factor": 0.5,
        "contribution": 0.1,
        "availability": 1.0,
        "total_downtime": 0,
    }]

    results = calculate_asset_eaf_contributions(
        plant_result, eq_results, ["Asset1"], 1.0, 0
    )
    assert len(results) == 0

@@ -0,0 +1,49 @@
import pytest
from pydantic import ValidationError

from src.database.schema import CommonParams
from src.overhaul.schema import OverhaulCriticalParts


def test_common_params_valid():
    params = CommonParams(
        page=1,
        itemsPerPage=10,
        q="search test",
        all=1,
    )
    assert params.page == 1
    assert params.items_per_page == 10
    assert params.query_str == "search test"
    assert params.is_all is True


def test_common_params_page_constraints():
    # Test page must be > 0
    with pytest.raises(ValidationError):
        CommonParams(page=0)
    with pytest.raises(ValidationError):
        CommonParams(page=-1)


def test_common_params_items_per_page_constraints():
    # Test items_per_page must be multiple of 5
    with pytest.raises(ValidationError):
        CommonParams(itemsPerPage=7)
    # Test items_per_page maximum
    with pytest.raises(ValidationError):
        CommonParams(itemsPerPage=55)
    # Valid multiples of 5
    assert CommonParams(itemsPerPage=50).items_per_page == 50
    assert CommonParams(itemsPerPage=5).items_per_page == 5


def test_overhaul_critical_parts_valid():
    parts = OverhaulCriticalParts(criticalParts=["Part A", "Part B"])
    assert parts.criticalParts == ["Part A", "Part B"]


def test_overhaul_critical_parts_invalid():
    # criticalParts is required and must be a list
    with pytest.raises(ValidationError):
        OverhaulCriticalParts()
    with pytest.raises(ValidationError):
        OverhaulCriticalParts(criticalParts="Not a list")

@@ -0,0 +1,58 @@
import pytest
from fastapi import HTTPException

from src.middleware import (
    inspect_value,
    inspect_json,
    has_control_chars,
    XSS_PATTERN,
    SQLI_PATTERN,
)


def test_xss_patterns():
    # Test common XSS payloads in be-optimumoh
    payloads = [
        "<script>",
        "javascript:",
        "onerror=",
        "onload=",
        "<svg",
        "<img",
    ]
    for payload in payloads:
        assert XSS_PATTERN.search(payload) is not None


def test_sqli_patterns():
    # Test common SQLi payloads in be-optimumoh
    payloads = [
        "UNION",
        "SELECT",
        "INSERT",
        "DELETE",
        "DROP",
        "--",
        "OR 1=1",
    ]
    for payload in payloads:
        assert SQLI_PATTERN.search(payload) is not None


def test_inspect_value_raises():
    with pytest.raises(HTTPException) as excinfo:
        inspect_value("<script>", "source")
    assert excinfo.value.status_code == 400
    assert "Potential XSS payload" in excinfo.value.detail

    with pytest.raises(HTTPException) as excinfo:
        inspect_value("UNION SELECT", "source")
    assert excinfo.value.status_code == 400
    assert "Potential SQL injection" in excinfo.value.detail


def test_inspect_json_raises():
    with pytest.raises(HTTPException) as excinfo:
        inspect_json({"__proto__": "polluted"})
    assert excinfo.value.status_code == 400
    assert "Forbidden JSON key" in excinfo.value.detail


def test_has_control_chars():
    assert has_control_chars("normal string") is False
    assert has_control_chars("string with \x00 null") is True
    assert has_control_chars("string with \n newline") is False

@@ -0,0 +1,33 @@
import pytest
from datetime import datetime, timedelta

from src.calculation_target_reliability.utils import generate_down_periods


def test_generate_down_periods_count():
    start = datetime(2025, 1, 1)
    end = datetime(2025, 1, 31)

    # Test fixed number of periods
    periods = generate_down_periods(start, end, num_periods=5)
    # It attempts to generate 5, but might be fewer due to overlaps
    assert len(periods) <= 5

    # Check they are within range
    for p_start, p_end in periods:
        assert p_start >= start
        assert p_end <= end
        assert p_start < p_end


def test_generate_down_periods_no_overlap():
    start = datetime(2025, 1, 1)
    end = datetime(2025, 1, 31)
    periods = generate_down_periods(start, end, num_periods=10)

    # Sort and check gaps
    periods = sorted(periods)
    for i in range(len(periods) - 1):
        assert periods[i][1] <= periods[i + 1][0]


def test_generate_down_periods_too_small_range():
    start = datetime(2025, 1, 1)
    end = datetime(2025, 1, 2)

    # Requesting 5 days duration in 1 day range
    periods = generate_down_periods(start, end, num_periods=1, min_duration=5)
    assert len(periods) == 0
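A plausible shape for generate_down_periods, consistent with what these tests assert (attempts num_periods draws, discards overlaps, stays inside the range, returns [] when the range is too small) — a hypothetical sketch only, including the seed parameter, which the committed tests do not use:

import random
from datetime import timedelta

def generate_down_periods(start, end, num_periods=5, min_duration=1,
                          max_duration=3, seed=None):
    """Hypothetical sketch: random, non-overlapping down periods in [start, end]."""
    rng = random.Random(seed)
    total_days = (end - start).days
    periods = []
    for _ in range(num_periods):
        duration = rng.randint(min_duration, max(min_duration, max_duration))
        if duration > total_days:
            continue  # range too small for this duration -> possibly fewer periods
        offset = rng.randint(0, total_days - duration)
        p_start = start + timedelta(days=offset)
        p_end = p_start + timedelta(days=duration)
        # Keep only draws that do not overlap an already accepted period.
        if all(p_end <= s or p_start >= e for s, e in periods):
            periods.append((p_start, p_end))
    return sorted(periods)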

@@ -0,0 +1,36 @@
import pytest
from datetime import datetime, timedelta

from src.utils import parse_relative_expression, parse_date_string


def test_parse_relative_expression_days():
    # Test T, T+n, T-n
    result = parse_relative_expression("T")
    assert result is not None
    assert isinstance(result, datetime)

    result_plus = parse_relative_expression("T+5")
    assert result_plus is not None

    result_minus = parse_relative_expression("T-3")
    assert result_minus is not None


def test_parse_relative_expression_invalid():
    assert parse_relative_expression("abc") is None
    assert parse_relative_expression("123") is None
    assert parse_relative_expression("T++1") is None


def test_parse_date_string_formats():
    # Test various ISO and common formats
    dt = parse_date_string("2024-11-08")
    assert dt.year == 2024
    assert dt.month == 11
    assert dt.day == 8

    dt = parse_date_string("08-11-2024")
    assert dt.year == 2024
    assert dt.month == 11
    assert dt.day == 8


def test_parse_date_string_invalid():
    with pytest.raises(ValueError):
        parse_date_string("invalid-date")