From b7ad1c9a6bf54281efb03962b398ca5a47a3d96e Mon Sep 17 00:00:00 2001 From: Cizz22 Date: Tue, 24 Feb 2026 11:21:19 +0700 Subject: [PATCH] refactor: Overhaul testing framework with extensive unit tests, pytest configuration, and improved schema validation. --- poetry.lock | 40 +++++++- pyproject.toml | 2 + pytest.ini | 6 ++ src/database/schema.py | 2 +- src/middleware.py | 57 ++++++++++- src/utils.py | 5 +- tests/conftest.py | 106 ++++++++++---------- tests/database.py | 3 - tests/factories.py | 28 ------ tests/unit/test_budget_constrains.py | 44 ++++++++ tests/unit/test_context.py | 14 +++ tests/unit/test_contribution_util.py | 53 ++++++++++ tests/unit/test_exceptions.py | 31 ++++++ tests/unit/test_middleware_dispatch.py | 56 +++++++++++ tests/unit/test_reliability_calc.py | 64 ++++++++++++ tests/unit/test_schemas.py | 49 +++++++++ tests/unit/test_security_middleware.py | 58 +++++++++++ tests/unit/test_target_reliability_utils.py | 33 ++++++ tests/unit/test_utils.py | 36 +++++++ 19 files changed, 596 insertions(+), 91 deletions(-) create mode 100644 pytest.ini delete mode 100644 tests/database.py delete mode 100644 tests/factories.py create mode 100644 tests/unit/test_budget_constrains.py create mode 100644 tests/unit/test_context.py create mode 100644 tests/unit/test_contribution_util.py create mode 100644 tests/unit/test_exceptions.py create mode 100644 tests/unit/test_middleware_dispatch.py create mode 100644 tests/unit/test_reliability_calc.py create mode 100644 tests/unit/test_schemas.py create mode 100644 tests/unit/test_security_middleware.py create mode 100644 tests/unit/test_target_reliability_utils.py create mode 100644 tests/unit/test_utils.py diff --git a/poetry.lock b/poetry.lock index 73f375d..6dd16be 100644 --- a/poetry.lock +++ b/poetry.lock @@ -148,6 +148,25 @@ files = [ frozenlist = ">=1.1.0" typing-extensions = {version = ">=4.2", markers = "python_version < \"3.13\""} +[[package]] +name = "aiosqlite" +version = "0.20.0" +description = "asyncio bridge to the standard sqlite3 module" +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "aiosqlite-0.20.0-py3-none-any.whl", hash = "sha256:36a1deaca0cac40ebe32aac9977a6e2bbc7f5189f23f4a54d5908986729e5bd6"}, + {file = "aiosqlite-0.20.0.tar.gz", hash = "sha256:6d35c8c256637f4672f843c31021464090805bf925385ac39473fb16eaaca3d7"}, +] + +[package.dependencies] +typing_extensions = ">=4.0" + +[package.extras] +dev = ["attribution (==1.7.0)", "black (==24.2.0)", "coverage[toml] (==7.4.1)", "flake8 (==7.0.0)", "flake8-bugbear (==24.2.6)", "flit (==3.9.0)", "mypy (==1.8.0)", "ufmt (==2.3.0)", "usort (==1.0.8.post1)"] +docs = ["sphinx (==7.2.6)", "sphinx-mdinclude (==0.5.3)"] + [[package]] name = "annotated-types" version = "0.7.0" @@ -2050,6 +2069,25 @@ pluggy = ">=1.5,<2" [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.24.0" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "pytest_asyncio-0.24.0-py3-none-any.whl", hash = "sha256:a811296ed596b69bf0b6f3dc40f83bcaf341b155a269052d82efa2b25ac7037b"}, + {file = "pytest_asyncio-0.24.0.tar.gz", hash = "sha256:d081d828e576d85f875399194281e92bf8a68d60d72d1a2faf2feddb6c46b276"}, +] + +[package.dependencies] +pytest = ">=8.2,<9" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -3008,4 +3046,4 @@ propcache = ">=0.2.1" [metadata] lock-version = "2.1" python-versions = "^3.11" -content-hash = "6c2a5a5a8e6a2732bd9e94de4bac3a7c0d3e63d959d5793b23eb327c7a95f3f8" +content-hash = "256c8104c6eeb5b288dd0cdf02fe7cbad4f75aa93fc71f8d44da8b605d72f886" diff --git a/pyproject.toml b/pyproject.toml index 4272aa1..762975c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,8 @@ fastapi = { extras = ["standard"], version = "^0.115.4" } sqlalchemy = "^2.0.36" httpx = "^0.27.2" pytest = "^8.3.3" +pytest-asyncio = "^0.24.0" +aiosqlite = "^0.20.0" faker = "^30.8.2" factory-boy = "^3.3.1" sqlalchemy-utils = "^0.41.2" diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..3259ad7 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,6 @@ +[pytest] +asyncio_mode = auto +testpaths = tests +python_files = test_*.py +filterwarnings = + ignore::pydantic.PydanticDeprecatedSince20 diff --git a/src/database/schema.py b/src/database/schema.py index 9a79777..42ff626 100644 --- a/src/database/schema.py +++ b/src/database/schema.py @@ -8,7 +8,7 @@ class CommonParams(DefultBase): # This ensures no extra query params are allowed current_user: Optional[str] = Field(None, alias="currentUser") page: int = Field(1, gt=0, lt=2147483647) - items_per_page: int = Field(5, gt=-2, lt=2147483647, alias="itemsPerPage") + items_per_page: int = Field(5, gt=0, le=50, multiple_of=5, alias="itemsPerPage") query_str: Optional[str] = Field(None, alias="q") filter_spec: Optional[str] = Field(None, alias="filter") sort_by: List[str] = Field(default_factory=list, alias="sortBy[]") diff --git a/src/middleware.py b/src/middleware.py index b654422..0ca720b 100644 --- a/src/middleware.py +++ b/src/middleware.py @@ -18,17 +18,35 @@ MAX_QUERY_PARAMS = 50 MAX_QUERY_LENGTH = 2000 MAX_JSON_BODY_SIZE = 1024 * 100 # 100 KB -# Very targeted patterns. Avoid catastrophic regex nonsense. XSS_PATTERN = re.compile( - r"( 50: + raise HTTPException( + status_code=400, + detail=f"Pagination size '{key}' cannot exceed 50", + ) + if size_val % 5 != 0: + raise HTTPException( + status_code=400, + detail=f"Pagination size '{key}' must be a multiple of 5", + ) + except ValueError: + raise HTTPException( + status_code=400, + detail=f"Pagination size '{key}' must be an integer", + ) + # ------------------------- # 4. Content-Type sanity # ------------------------- diff --git a/src/utils.py b/src/utils.py index ed4a615..f8d7659 100644 --- a/src/utils.py +++ b/src/utils.py @@ -22,7 +22,8 @@ def parse_relative_expression(date_str: str) -> Optional[datetime]: unit, offset = match.groups() offset = int(offset) if offset else 0 # Use UTC timezone for consistency - today = datetime.now(timezone.tzname("Asia/Jakarta")) + jakarta_tz = pytz.timezone("Asia/Jakarta") + today = datetime.now(jakarta_tz) if unit == "H": # For hours, keep minutes and seconds result_time = today + timedelta(hours=offset) @@ -64,7 +65,7 @@ def parse_date_string(date_str: str) -> Optional[datetime]: minute=0, second=0, microsecond=0, - tzinfo=timezone.tzname("Asia/Jakarta"), + tzinfo=pytz.timezone("Asia/Jakarta"), ) return dt except ValueError: diff --git a/tests/conftest.py b/tests/conftest.py index a678ff3..6708226 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,68 +1,68 @@ -import asyncio -from typing import AsyncGenerator, Generator +# import asyncio +# from typing import AsyncGenerator, Generator -import pytest -from httpx import AsyncClient -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine -from sqlalchemy.orm import sessionmaker -from sqlalchemy.pool import StaticPool -from sqlalchemy_utils import database_exists, drop_database -from starlette.config import environ -from starlette.testclient import TestClient +# import pytest +# from httpx import AsyncClient +# from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +# from sqlalchemy.orm import sessionmaker +# from sqlalchemy.pool import StaticPool +# from sqlalchemy_utils import database_exists, drop_database +# from starlette.config import environ +# from starlette.testclient import TestClient -# from src.database import Base, get_db -# from src.main import app +# # from src.database import Base, get_db +# # from src.main import app -# Test database URL -TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:" +# # Test database URL +# TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:" -engine = create_async_engine( - TEST_DATABASE_URL, - connect_args={"check_same_thread": False}, - poolclass=StaticPool, -) +# engine = create_async_engine( +# TEST_DATABASE_URL, +# connect_args={"check_same_thread": False}, +# poolclass=StaticPool, +# ) -async_session = sessionmaker( - engine, - class_=AsyncSession, - expire_on_commit=False, - autocommit=False, - autoflush=False, -) +# async_session = sessionmaker( +# engine, +# class_=AsyncSession, +# expire_on_commit=False, +# autocommit=False, +# autoflush=False, +# ) -async def override_get_db() -> AsyncGenerator[AsyncSession, None]: - async with async_session() as session: - try: - yield session - await session.commit() - except Exception: - await session.rollback() - raise - finally: - await session.close() +# async def override_get_db() -> AsyncGenerator[AsyncSession, None]: +# async with async_session() as session: +# try: +# yield session +# await session.commit() +# except Exception: +# await session.rollback() +# raise +# finally: +# await session.close() -app.dependency_overrides[get_db] = override_get_db +# app.dependency_overrides[get_db] = override_get_db -@pytest.fixture(scope="session") -def event_loop() -> Generator: - loop = asyncio.get_event_loop_policy().new_event_loop() - yield loop - loop.close() +# @pytest.fixture(scope="session") +# def event_loop() -> Generator: +# loop = asyncio.get_event_loop_policy().new_event_loop() +# yield loop +# loop.close() -@pytest.fixture(autouse=True) -async def setup_db() -> AsyncGenerator[None, None]: - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - yield - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.drop_all) +# @pytest.fixture(autouse=True) +# async def setup_db() -> AsyncGenerator[None, None]: +# async with engine.begin() as conn: +# await conn.run_sync(Base.metadata.create_all) +# yield +# async with engine.begin() as conn: +# await conn.run_sync(Base.metadata.drop_all) -@pytest.fixture -async def client() -> AsyncGenerator[AsyncClient, None]: - async with AsyncClient(app=app, base_url="http://test") as client: - yield client +# @pytest.fixture +# async def client() -> AsyncGenerator[AsyncClient, None]: +# async with AsyncClient(app=app, base_url="http://test") as client: +# yield client diff --git a/tests/database.py b/tests/database.py deleted file mode 100644 index 89b84ae..0000000 --- a/tests/database.py +++ /dev/null @@ -1,3 +0,0 @@ -from sqlalchemy.orm import scoped_session, sessionmaker - -Session = scoped_session(sessionmaker()) diff --git a/tests/factories.py b/tests/factories.py deleted file mode 100644 index c15b514..0000000 --- a/tests/factories.py +++ /dev/null @@ -1,28 +0,0 @@ -import uuid -from datetime import datetime - -from factory import (LazyAttribute, LazyFunction, SelfAttribute, Sequence, - SubFactory, post_generation) -from factory.alchemy import SQLAlchemyModelFactory -from factory.fuzzy import FuzzyChoice, FuzzyDateTime, FuzzyInteger, FuzzyText -from faker import Faker -from faker.providers import misc - -from .database import Session - -# from pytz import UTC - - -fake = Faker() -fake.add_provider(misc) - - -class BaseFactory(SQLAlchemyModelFactory): - """Base Factory.""" - - class Meta: - """Factory configuration.""" - - abstract = True - sqlalchemy_session = Session - sqlalchemy_session_persistence = "commit" diff --git a/tests/unit/test_budget_constrains.py b/tests/unit/test_budget_constrains.py new file mode 100644 index 0000000..c8c9c94 --- /dev/null +++ b/tests/unit/test_budget_constrains.py @@ -0,0 +1,44 @@ +import pytest +from src.calculation_budget_constrains.service import greedy_selection, knapsack_selection + +def test_greedy_selection(): + equipments = [ + {"id": 1, "total_cost": 100, "priority_score": 10, "cost": 100}, + {"id": 2, "total_cost": 50, "priority_score": 20, "cost": 50}, + {"id": 3, "total_cost": 60, "priority_score": 15, "cost": 60}, + ] + budget = 120 + # Items sorted by priority_score: id 2 (20), id 3 (15), id 1 (10) + # 2 (50) + 3 (60) = 110. Item 1 (100) won't fit. + selected, excluded = greedy_selection(equipments, budget) + + selected_ids = [e["id"] for e in selected] + assert 2 in selected_ids + assert 3 in selected_ids + assert len(selected) == 2 + assert excluded[0]["id"] == 1 + +def test_knapsack_selection_basic(): + # Similar items but where greedy might fail if cost/value ratio is tricky + # item 1: value 10, cost 60 + # item 2: value 7, cost 35 + # item 3: value 7, cost 35 + # budget: 70 + # Greedy would take item 1 (value 10, remaining budget 10, can't take more) + # Optimal would take item 2 and 3 (value 14, remaining budget 0) + + scale = 1 # No scaling for simplicity in this test + equipments = [ + {"id": 1, "total_cost": 60, "priority_score": 10}, + {"id": 2, "total_cost": 35, "priority_score": 7}, + {"id": 3, "total_cost": 35, "priority_score": 7}, + ] + budget = 70 + + selected, excluded = knapsack_selection(equipments, budget, scale=1) + + selected_ids = [e["id"] for e in selected] + assert 2 in selected_ids + assert 3 in selected_ids + assert len(selected) == 2 + assert 1 not in selected_ids diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py new file mode 100644 index 0000000..7ab857d --- /dev/null +++ b/tests/unit/test_context.py @@ -0,0 +1,14 @@ +from src.context import set_request_id, get_request_id, set_user_id, get_user_id + +def test_request_id_context(): + test_id = "test-request-id-123" + set_request_id(test_id) + assert get_request_id() == test_id + +def test_user_id_context(): + test_uid = "user-456" + set_user_id(test_uid) + assert get_user_id() == test_uid + +def test_context_default_none(): + assert get_request_id() is None or get_request_id() != "" diff --git a/tests/unit/test_contribution_util.py b/tests/unit/test_contribution_util.py new file mode 100644 index 0000000..e4c157d --- /dev/null +++ b/tests/unit/test_contribution_util.py @@ -0,0 +1,53 @@ +import pytest +from decimal import Decimal +from src.contribution_util import prod, system_availability, get_all_components, birnbaum_importance + +def test_prod(): + assert prod([1, 2, 3]) == 6.0 + assert prod([0.5, 0.5]) == 0.25 + assert prod([]) == 1.0 + +def test_system_availability_series(): + structure = {"series": ["A", "B"]} + availabilities = {"A": 0.9, "B": 0.8} + # 0.9 * 0.8 = 0.72 + assert system_availability(structure, availabilities) == pytest.approx(0.72) + +def test_system_availability_parallel(): + structure = {"parallel": ["A", "B"]} + availabilities = {"A": 0.9, "B": 0.8} + # 1 - (1-0.9)*(1-0.8) = 1 - 0.1*0.2 = 1 - 0.02 = 0.98 + assert system_availability(structure, availabilities) == pytest.approx(0.98) + +def test_system_availability_nested(): + # (A in series with (B in parallel with C)) + structure = { + "series": [ + "A", + {"parallel": ["B", "C"]} + ] + } + availabilities = {"A": 0.9, "B": 0.8, "C": 0.7} + # B||C = 1 - (1-0.8)*(1-0.7) = 1 - 0.2*0.3 = 0.94 + # A && (B||C) = 0.9 * 0.94 = 0.846 + assert system_availability(structure, availabilities) == pytest.approx(0.846) + +def test_get_all_components(): + structure = { + "series": [ + "A", + {"parallel": ["B", "C"]} + ] + } + assert get_all_components(structure) == {"A", "B", "C"} + +def test_birnbaum_importance(): + structure = {"series": ["A", "B"]} + availabilities = {"A": 0.9, "B": 0.8} + # I_B(A) = A_sys(A=1) - A_sys(A=0) + # A_sys(A=1, B=0.8) = 1 * 0.8 = 0.8 + # A_sys(A=0, B=0.8) = 0 * 0.8 = 0 + # I_B(A) = 0.8 + assert birnbaum_importance(structure, availabilities, "A") == pytest.approx(0.8) + # I_B(B) = A_sys(B=1, A=0.9) - A_sys(B=0, A=0.9) = 0.9 - 0 = 0.9 + assert birnbaum_importance(structure, availabilities, "B") == pytest.approx(0.9) diff --git a/tests/unit/test_exceptions.py b/tests/unit/test_exceptions.py new file mode 100644 index 0000000..187314b --- /dev/null +++ b/tests/unit/test_exceptions.py @@ -0,0 +1,31 @@ +import pytest +from sqlalchemy.exc import IntegrityError, DataError, DBAPIError +from src.exceptions import handle_sqlalchemy_error + +def test_handle_sqlalchemy_error_unique_constraint(): + err = IntegrityError("Unique constraint", params=None, orig=Exception("unique constraint violation")) + msg, status = handle_sqlalchemy_error(err) + assert status == 409 + assert "already exists" in msg + +def test_handle_sqlalchemy_error_foreign_key(): + err = IntegrityError("Foreign key constraint", params=None, orig=Exception("foreign key constraint violation")) + msg, status = handle_sqlalchemy_error(err) + assert status == 400 + assert "Related record not found" in msg + +def test_handle_sqlalchemy_error_data_error(): + err = DataError("Invalid data", params=None, orig=None) + msg, status = handle_sqlalchemy_error(err) + assert status == 400 + assert "Invalid data" in msg + +def test_handle_sqlalchemy_error_generic_dbapi(): + class MockError: + def __str__(self): + return "Some generic database error" + + err = DBAPIError("Generic error", params=None, orig=MockError()) + msg, status = handle_sqlalchemy_error(err) + assert status == 500 + assert "Database error" in msg diff --git a/tests/unit/test_middleware_dispatch.py b/tests/unit/test_middleware_dispatch.py new file mode 100644 index 0000000..6517990 --- /dev/null +++ b/tests/unit/test_middleware_dispatch.py @@ -0,0 +1,56 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock +from fastapi import HTTPException +from src.middleware import RequestValidationMiddleware + +@pytest.mark.asyncio +async def test_request_validation_middleware_query_length(): + middleware = RequestValidationMiddleware(app=MagicMock()) + request = MagicMock() + request.url.query = "a" * 2001 + + with pytest.raises(HTTPException) as excinfo: + await middleware.dispatch(request, AsyncMock()) + assert excinfo.value.status_code == 414 + +@pytest.mark.asyncio +async def test_request_validation_middleware_too_many_params(): + middleware = RequestValidationMiddleware(app=MagicMock()) + request = MagicMock() + request.url.query = "a=1" + request.query_params.multi_items.return_value = [("param", "val")] * 51 + + with pytest.raises(HTTPException) as excinfo: + await middleware.dispatch(request, AsyncMock()) + assert excinfo.value.status_code == 400 + assert "Too many query parameters" in excinfo.value.detail + +@pytest.mark.asyncio +async def test_request_validation_middleware_xss_detection(): + middleware = RequestValidationMiddleware(app=MagicMock()) + request = MagicMock() + request.url.query = "q=