refactor: Overhaul testing framework with extensive unit tests, pytest configuration, and improved schema validation.
parent
e60a26b6a6
commit
baa6aeb7e7
@ -0,0 +1,6 @@
|
||||
[pytest]
# Run every async test on an auto-managed event loop (pytest-asyncio).
asyncio_mode = auto
# Discover tests only under the tests/ directory.
testpaths = tests
python_files = test_*.py
# Silence pydantic v2 deprecation warnings emitted by dependencies.
filterwarnings =
    ignore::pydantic.PydanticDeprecatedSince20
|
||||
@ -1,68 +1,68 @@
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Generator
|
||||
# import asyncio
|
||||
# from typing import AsyncGenerator, Generator
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlalchemy_utils import database_exists, drop_database
|
||||
from starlette.config import environ
|
||||
from starlette.testclient import TestClient
|
||||
# import pytest
|
||||
# from httpx import AsyncClient
|
||||
# from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
# from sqlalchemy.orm import sessionmaker
|
||||
# from sqlalchemy.pool import StaticPool
|
||||
# from sqlalchemy_utils import database_exists, drop_database
|
||||
# from starlette.config import environ
|
||||
# from starlette.testclient import TestClient
|
||||
|
||||
# from src.database import Base, get_db
|
||||
# from src.main import app
|
||||
# # from src.database import Base, get_db
|
||||
# # from src.main import app
|
||||
|
||||
# Test database URL: in-memory SQLite via the aiosqlite async driver.
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
|
||||
# # Test database URL
|
||||
# TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
|
||||
|
||||
# Async engine over the in-memory SQLite test database.
# StaticPool pins a single connection so the :memory: database survives
# across sessions; check_same_thread=False lets that connection be used
# outside the thread that created it (needed under asyncio).
engine = create_async_engine(
    TEST_DATABASE_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
|
||||
# engine = create_async_engine(
|
||||
# TEST_DATABASE_URL,
|
||||
# connect_args={"check_same_thread": False},
|
||||
# poolclass=StaticPool,
|
||||
# )
|
||||
|
||||
# Session factory producing AsyncSession objects. expire_on_commit=False
# keeps attribute values readable after commit (no lazy refresh round-trip).
async_session = sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
    autocommit=False,
    autoflush=False,
)
|
||||
# async_session = sessionmaker(
|
||||
# engine,
|
||||
# class_=AsyncSession,
|
||||
# expire_on_commit=False,
|
||||
# autocommit=False,
|
||||
# autoflush=False,
|
||||
# )
|
||||
|
||||
|
||||
async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
    """FastAPI dependency override: yield a session bound to the test DB.

    Commits when the request handler finishes normally; rolls back and
    re-raises on any exception. The session is closed in all cases.
    """
    async with async_session() as session:
        try:
            yield session
            # Commit inside the try so a failing commit also triggers rollback.
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()
|
||||
# async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
# async with async_session() as session:
|
||||
# try:
|
||||
# yield session
|
||||
# await session.commit()
|
||||
# except Exception:
|
||||
# await session.rollback()
|
||||
# raise
|
||||
# finally:
|
||||
# await session.close()
|
||||
|
||||
|
||||
# Route every get_db dependency to the in-memory test session factory.
app.dependency_overrides[get_db] = override_get_db
|
||||
# app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def event_loop() -> Generator:
    """Provide one event loop shared by the whole test session."""
    policy = asyncio.get_event_loop_policy()
    fresh_loop = policy.new_event_loop()
    yield fresh_loop
    fresh_loop.close()
|
||||
# @pytest.fixture(scope="session")
|
||||
# def event_loop() -> Generator:
|
||||
# loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
# yield loop
|
||||
# loop.close()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
async def setup_db() -> AsyncGenerator[None, None]:
    """Create all tables before each test and drop them afterwards.

    autouse=True means every test runs against a freshly created schema,
    so tests cannot leak state into one another.
    """
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)
|
||||
# @pytest.fixture(autouse=True)
|
||||
# async def setup_db() -> AsyncGenerator[None, None]:
|
||||
# async with engine.begin() as conn:
|
||||
# await conn.run_sync(Base.metadata.create_all)
|
||||
# yield
|
||||
# async with engine.begin() as conn:
|
||||
# await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
|
||||
@pytest.fixture
async def client() -> AsyncGenerator[AsyncClient, None]:
    """HTTP client wired directly to the FastAPI app (no real network)."""
    async with AsyncClient(app=app, base_url="http://test") as http_client:
        yield http_client
|
||||
# @pytest.fixture
|
||||
# async def client() -> AsyncGenerator[AsyncClient, None]:
|
||||
# async with AsyncClient(app=app, base_url="http://test") as client:
|
||||
# yield client
|
||||
|
||||
@ -1,3 +0,0 @@
|
||||
from sqlalchemy.orm import scoped_session, sessionmaker

# Thread-local session registry shared by the test factories.
Session = scoped_session(sessionmaker())
|
||||
@ -1,28 +0,0 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from factory import (LazyAttribute, LazyFunction, SelfAttribute, Sequence,
|
||||
SubFactory, post_generation)
|
||||
from factory.alchemy import SQLAlchemyModelFactory
|
||||
from factory.fuzzy import FuzzyChoice, FuzzyDateTime, FuzzyInteger, FuzzyText
|
||||
from faker import Faker
|
||||
from faker.providers import misc
|
||||
|
||||
from .database import Session
|
||||
|
||||
# from pytz import UTC
|
||||
|
||||
|
||||
# Shared Faker instance for generating test data; misc adds providers
# such as uuid/boolean/binary generators.
fake = Faker()
fake.add_provider(misc)
|
||||
|
||||
|
||||
class BaseFactory(SQLAlchemyModelFactory):
    """Base Factory."""

    class Meta:
        """Factory configuration."""

        # Never instantiate BaseFactory directly; subclasses set `model`.
        abstract = True
        # Bind all factories to the scoped test session.
        sqlalchemy_session = Session
        # Commit after each created object so it receives a primary key.
        sqlalchemy_session_persistence = "commit"
|
||||
@ -0,0 +1,44 @@
|
||||
import pytest
|
||||
from src.calculation_budget_constrains.service import greedy_selection, knapsack_selection
|
||||
|
||||
def test_greedy_selection():
    """Greedy picker takes the highest-priority items that fit the budget."""
    equipments = [
        {"id": 1, "total_cost": 100, "priority_score": 10, "cost": 100},
        {"id": 2, "total_cost": 50, "priority_score": 20, "cost": 50},
        {"id": 3, "total_cost": 60, "priority_score": 15, "cost": 60},
    ]
    # Priority order: id 2 (20), id 3 (15), id 1 (10).
    # Ids 2 + 3 cost 110 <= 120; id 1 (100) no longer fits.
    selected, excluded = greedy_selection(equipments, 120)

    chosen_ids = [item["id"] for item in selected]
    assert 2 in chosen_ids
    assert 3 in chosen_ids
    assert len(selected) == 2
    assert excluded[0]["id"] == 1
|
||||
|
||||
def test_knapsack_selection_basic():
    """Knapsack finds the optimum where greedy would not.

    item 1: value 10, cost 60
    item 2: value 7,  cost 35
    item 3: value 7,  cost 35
    budget: 70

    Greedy would take item 1 (value 10) and have only 10 budget left;
    the optimal pick is items 2 and 3 (combined value 14).
    """
    equipments = [
        {"id": 1, "total_cost": 60, "priority_score": 10},
        {"id": 2, "total_cost": 35, "priority_score": 7},
        {"id": 3, "total_cost": 35, "priority_score": 7},
    ]
    budget = 70

    # scale=1: no fixed-point scaling needed for these integer costs.
    # (The original assigned an unused local `scale`; pass the literal.)
    selected, excluded = knapsack_selection(equipments, budget, scale=1)

    selected_ids = [e["id"] for e in selected]
    assert 2 in selected_ids
    assert 3 in selected_ids
    assert len(selected) == 2
    assert 1 not in selected_ids
|
||||
@ -0,0 +1,14 @@
|
||||
from src.context import set_request_id, get_request_id, set_user_id, get_user_id
|
||||
|
||||
def test_request_id_context():
    """set_request_id stores a value retrievable via get_request_id."""
    set_request_id("test-request-id-123")
    assert get_request_id() == "test-request-id-123"
|
||||
|
||||
def test_user_id_context():
    """set_user_id stores a value retrievable via get_user_id."""
    set_user_id("user-456")
    assert get_user_id() == "user-456"
|
||||
|
||||
def test_context_default_none():
    """The request-id context is either unset (None) or a non-empty string.

    NOTE(review): a plain `is None` cannot be asserted here because an
    earlier test in the same context may already have set an id; this only
    rules out an empty-string sentinel. The getter is called once so both
    halves of the check see the same value.
    """
    request_id = get_request_id()
    assert request_id is None or request_id != ""
|
||||
@ -0,0 +1,53 @@
|
||||
import pytest
|
||||
from decimal import Decimal
|
||||
from src.contribution_util import prod, system_availability, get_all_components, birnbaum_importance
|
||||
|
||||
def test_prod():
    """prod multiplies a sequence; the empty product is 1.0."""
    cases = [([1, 2, 3], 6.0), ([0.5, 0.5], 0.25), ([], 1.0)]
    for values, expected in cases:
        assert prod(values) == expected
|
||||
|
||||
def test_system_availability_series():
    """Series availability is the product of component availabilities."""
    component_avail = {"A": 0.9, "B": 0.8}
    # 0.9 * 0.8 = 0.72
    result = system_availability({"series": ["A", "B"]}, component_avail)
    assert result == pytest.approx(0.72)
|
||||
|
||||
def test_system_availability_parallel():
    """Parallel availability is 1 minus the product of unavailabilities."""
    component_avail = {"A": 0.9, "B": 0.8}
    # 1 - (1-0.9)*(1-0.8) = 1 - 0.02 = 0.98
    result = system_availability({"parallel": ["A", "B"]}, component_avail)
    assert result == pytest.approx(0.98)
|
||||
|
||||
def test_system_availability_nested():
    """Nested structure: A in series with (B in parallel with C)."""
    layout = {"series": ["A", {"parallel": ["B", "C"]}]}
    component_avail = {"A": 0.9, "B": 0.8, "C": 0.7}
    # B||C = 1 - (1-0.8)*(1-0.7) = 0.94; series with A: 0.9 * 0.94 = 0.846
    assert system_availability(layout, component_avail) == pytest.approx(0.846)
|
||||
|
||||
def test_get_all_components():
    """Component names are collected from every nesting level."""
    layout = {"series": ["A", {"parallel": ["B", "C"]}]}
    assert get_all_components(layout) == {"A", "B", "C"}
|
||||
|
||||
def test_birnbaum_importance():
    """Birnbaum importance = A_sys(component=1) - A_sys(component=0)."""
    layout = {"series": ["A", "B"]}
    component_avail = {"A": 0.9, "B": 0.8}
    # For A: A_sys(A=1) - A_sys(A=0) = 1*0.8 - 0*0.8 = 0.8
    assert birnbaum_importance(layout, component_avail, "A") == pytest.approx(0.8)
    # For B: 0.9*1 - 0.9*0 = 0.9
    assert birnbaum_importance(layout, component_avail, "B") == pytest.approx(0.9)
|
||||
@ -0,0 +1,31 @@
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError, DataError, DBAPIError
|
||||
from src.exceptions import handle_sqlalchemy_error
|
||||
|
||||
def test_handle_sqlalchemy_error_unique_constraint():
    """Unique-constraint violations map to HTTP 409 Conflict."""
    error = IntegrityError(
        "Unique constraint", params=None, orig=Exception("unique constraint violation")
    )
    message, status = handle_sqlalchemy_error(error)
    assert status == 409
    assert "already exists" in message
|
||||
|
||||
def test_handle_sqlalchemy_error_foreign_key():
    """Foreign-key violations map to 400 with a 'related record' message."""
    error = IntegrityError(
        "Foreign key constraint", params=None, orig=Exception("foreign key constraint violation")
    )
    message, status = handle_sqlalchemy_error(error)
    assert status == 400
    assert "Related record not found" in message
|
||||
|
||||
def test_handle_sqlalchemy_error_data_error():
    """DataError (bad value/type for a column) maps to HTTP 400."""
    error = DataError("Invalid data", params=None, orig=None)
    message, status = handle_sqlalchemy_error(error)
    assert status == 400
    assert "Invalid data" in message
|
||||
|
||||
def test_handle_sqlalchemy_error_generic_dbapi():
    """Any other DBAPIError falls through to a generic 500."""

    class _Stringy:
        # Only __str__ matters: the handler stringifies the driver error.
        def __str__(self):
            return "Some generic database error"

    error = DBAPIError("Generic error", params=None, orig=_Stringy())
    message, status = handle_sqlalchemy_error(error)
    assert status == 500
    assert "Database error" in message
|
||||
@ -0,0 +1,56 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from fastapi import HTTPException
|
||||
from src.middleware import RequestValidationMiddleware
|
||||
|
||||
@pytest.mark.asyncio
async def test_request_validation_middleware_query_length():
    """Query strings beyond the length limit are rejected with 414."""
    mw = RequestValidationMiddleware(app=MagicMock())
    fake_request = MagicMock()
    fake_request.url.query = "a" * 2001

    with pytest.raises(HTTPException) as excinfo:
        await mw.dispatch(fake_request, AsyncMock())
    assert excinfo.value.status_code == 414
|
||||
|
||||
@pytest.mark.asyncio
async def test_request_validation_middleware_too_many_params():
    """More than 50 query parameters triggers a 400 response."""
    mw = RequestValidationMiddleware(app=MagicMock())
    fake_request = MagicMock()
    fake_request.url.query = "a=1"
    # 51 repeated parameters — one over the limit.
    fake_request.query_params.multi_items.return_value = [("param", "val")] * 51

    with pytest.raises(HTTPException) as excinfo:
        await mw.dispatch(fake_request, AsyncMock())
    assert excinfo.value.status_code == 400
    assert "Too many query parameters" in excinfo.value.detail
|
||||
|
||||
@pytest.mark.asyncio
async def test_request_validation_middleware_xss_detection():
    """A script tag in a query value is rejected as an XSS attempt."""
    mw = RequestValidationMiddleware(app=MagicMock())
    fake_request = MagicMock()
    fake_request.url.query = "q=<script>"
    fake_request.query_params.multi_items.return_value = [("q", "<script>")]

    with pytest.raises(HTTPException) as excinfo:
        await mw.dispatch(fake_request, AsyncMock())
    assert excinfo.value.status_code == 400
    assert "Potential XSS payload" in excinfo.value.detail
|
||||
|
||||
@pytest.mark.asyncio
async def test_request_validation_middleware_pagination_logic():
    """Pagination size must be at most 50 and a multiple of 5."""
    mw = RequestValidationMiddleware(app=MagicMock())
    fake_request = MagicMock()
    fake_request.url.query = "size=55"
    fake_request.query_params.multi_items.return_value = [("size", "55")]
    fake_request.headers = {}

    # size=55 exceeds the maximum of 50.
    with pytest.raises(HTTPException) as excinfo:
        await mw.dispatch(fake_request, AsyncMock())
    assert excinfo.value.status_code == 400
    assert "cannot exceed 50" in excinfo.value.detail

    # size=7 is within range but not a multiple of 5.
    fake_request.query_params.multi_items.return_value = [("size", "7")]
    with pytest.raises(HTTPException) as excinfo:
        await mw.dispatch(fake_request, AsyncMock())
    assert "must be a multiple of 5" in excinfo.value.detail
|
||||
@ -0,0 +1,64 @@
|
||||
import pytest
|
||||
import math
|
||||
from src.calculation_target_reliability.service import calculate_asset_eaf_contributions
|
||||
|
||||
def test_calculate_asset_eaf_contributions_basic():
    """Contributions are computed per asset and sorted by Birnbaum value."""
    plant = {
        "total_uptime": 7000,
        "total_downtime": 1000,  # total_hours = 8000
        "eaf": 85.0,
    }
    assets = [
        {
            "aeros_node": {"node_name": "Asset1"},
            "num_events": 5,
            "contribution_factor": 0.5,
            "contribution": 0.1,  # Birnbaum importance
            "availability": 0.9,
            "total_downtime": 100,
        },
        {
            "aeros_node": {"node_name": "Asset2"},
            "num_events": 2,
            "contribution_factor": 0.3,
            "contribution": 0.05,
            "availability": 0.95,
            "total_downtime": 50,
        },
    ]

    results = calculate_asset_eaf_contributions(
        plant,
        assets,
        ["Asset1", "Asset2"],  # standard scope
        2.0,                   # 2% EAF gap
        500,                   # scheduled outage hours
    )

    assert len(results) == 2
    # Highest Birnbaum contribution is listed first.
    assert results[0].node["node_name"] == "Asset1"
    assert results[0].birbaum > results[1].birbaum
    # A positive EAF gap yields positive improvement figures.
    assert results[0].required_improvement > 0
    assert results[0].improvement_impact > 0
|
||||
|
||||
def test_calculate_asset_eaf_contributions_skipping():
    """Assets with zero events and zero downtime produce no contribution rows."""
    plant = {"total_uptime": 1000, "total_downtime": 0, "eaf": 100}
    idle_asset = {
        "aeros_node": {"node_name": "Asset1"},
        "num_events": 0,
        "contribution_factor": 0.5,
        "contribution": 0.1,
        "availability": 1.0,
        "total_downtime": 0,
    }
    results = calculate_asset_eaf_contributions(plant, [idle_asset], ["Asset1"], 1.0, 0)
    assert len(results) == 0
|
||||
@ -0,0 +1,49 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from src.database.schema import CommonParams
|
||||
from src.overhaul.schema import OverhaulCriticalParts
|
||||
|
||||
def test_common_params_valid():
    """Aliased query parameters populate the snake_case model fields."""
    params = CommonParams(page=1, itemsPerPage=10, q="search test", all=1)
    assert params.page == 1
    assert params.items_per_page == 10
    assert params.query_str == "search test"
    assert params.is_all is True
|
||||
|
||||
def test_common_params_page_constraints():
    """page must be strictly positive."""
    for bad_page in (0, -1):
        with pytest.raises(ValidationError):
            CommonParams(page=bad_page)
|
||||
|
||||
def test_common_params_items_per_page_constraints():
    """items_per_page must be a multiple of 5 and at most 50."""
    with pytest.raises(ValidationError):
        CommonParams(itemsPerPage=7)   # not a multiple of 5
    with pytest.raises(ValidationError):
        CommonParams(itemsPerPage=55)  # above the maximum

    # Boundary multiples of 5 are accepted.
    assert CommonParams(itemsPerPage=5).items_per_page == 5
    assert CommonParams(itemsPerPage=50).items_per_page == 50
|
||||
|
||||
def test_overhaul_critical_parts_valid():
    """A list of part names round-trips through the schema."""
    payload = ["Part A", "Part B"]
    assert OverhaulCriticalParts(criticalParts=payload).criticalParts == payload
|
||||
|
||||
def test_overhaul_critical_parts_invalid():
    """criticalParts is required and must be a list."""
    with pytest.raises(ValidationError):
        OverhaulCriticalParts()  # field missing entirely
    with pytest.raises(ValidationError):
        OverhaulCriticalParts(criticalParts="Not a list")
|
||||
@ -0,0 +1,58 @@
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from src.middleware import (
|
||||
inspect_value,
|
||||
inspect_json,
|
||||
has_control_chars,
|
||||
XSS_PATTERN,
|
||||
SQLI_PATTERN
|
||||
)
|
||||
|
||||
def test_xss_patterns():
    """XSS_PATTERN flags common markup / event-handler payloads."""
    for payload in ("<script>", "javascript:", "onerror=", "onload=", "<svg", "<img"):
        assert XSS_PATTERN.search(payload) is not None
|
||||
|
||||
def test_sqli_patterns():
    """SQLI_PATTERN flags common SQL keywords and comment markers."""
    for payload in ("UNION", "SELECT", "INSERT", "DELETE", "DROP", "--", "OR 1=1"):
        assert SQLI_PATTERN.search(payload) is not None
|
||||
|
||||
def test_inspect_value_raises():
    """inspect_value raises 400 for XSS and SQL-injection payloads."""
    cases = [
        ("<script>", "Potential XSS payload"),
        ("UNION SELECT", "Potential SQL injection"),
    ]
    for value, expected_detail in cases:
        with pytest.raises(HTTPException) as excinfo:
            inspect_value(value, "source")
        assert excinfo.value.status_code == 400
        assert expected_detail in excinfo.value.detail
|
||||
|
||||
def test_inspect_json_raises():
    """Prototype-pollution style JSON keys are rejected with 400."""
    with pytest.raises(HTTPException) as excinfo:
        inspect_json({"__proto__": "polluted"})
    assert excinfo.value.status_code == 400
    assert "Forbidden JSON key" in excinfo.value.detail
|
||||
|
||||
def test_has_control_chars():
    """Null bytes count as control characters; \\n whitespace does not."""
    assert has_control_chars("string with \x00 null") is True
    assert has_control_chars("normal string") is False
    assert has_control_chars("string with \n newline") is False
|
||||
@ -0,0 +1,33 @@
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from src.calculation_target_reliability.utils import generate_down_periods
|
||||
|
||||
def test_generate_down_periods_count():
    """At most num_periods intervals are produced, all inside the window."""
    window_start = datetime(2025, 1, 1)
    window_end = datetime(2025, 1, 31)
    periods = generate_down_periods(window_start, window_end, num_periods=5)

    # Overlap avoidance may yield fewer than requested.
    assert len(periods) <= 5
    for begin, finish in periods:
        assert begin >= window_start
        assert finish <= window_end
        assert begin < finish
|
||||
|
||||
def test_generate_down_periods_no_overlap():
    """Generated down periods must not overlap one another."""
    start = datetime(2025, 1, 1)
    end = datetime(2025, 1, 31)
    periods = generate_down_periods(start, end, num_periods=10)

    # Sort by start time first: the generator does not promise ordered
    # output, so comparing raw neighbours could fail on valid results.
    # (The original comment said "Sort and check gaps" but never sorted.)
    ordered = sorted(periods)
    for i in range(len(ordered) - 1):
        assert ordered[i][1] <= ordered[i + 1][0]
|
||||
|
||||
def test_generate_down_periods_too_small_range():
    """A window shorter than min_duration yields no periods at all."""
    start = datetime(2025, 1, 1)
    end = datetime(2025, 1, 2)  # one-day window
    # A 5-day minimum duration cannot fit into a 1-day range.
    result = generate_down_periods(start, end, num_periods=1, min_duration=5)
    assert len(result) == 0
|
||||
@ -0,0 +1,36 @@
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from src.utils import parse_relative_expression, parse_date_string
|
||||
|
||||
def test_parse_relative_expression_days():
    """T, T+n and T-n expressions all parse successfully."""
    base = parse_relative_expression("T")
    assert base is not None
    assert isinstance(base, datetime)

    # Offsets in both directions also parse.
    assert parse_relative_expression("T+5") is not None
    assert parse_relative_expression("T-3") is not None
|
||||
|
||||
def test_parse_relative_expression_invalid():
    """Malformed expressions return None instead of raising."""
    for bad_expression in ("abc", "123", "T++1"):
        assert parse_relative_expression(bad_expression) is None
|
||||
|
||||
def test_parse_date_string_formats():
    """ISO (Y-M-D) and D-M-Y inputs resolve to the same calendar date."""
    for text in ("2024-11-08", "08-11-2024"):
        parsed = parse_date_string(text)
        assert parsed.year == 2024
        assert parsed.month == 11
        assert parsed.day == 8
|
||||
|
||||
def test_parse_date_string_invalid():
    """Unparseable input raises ValueError."""
    with pytest.raises(ValueError):
        parse_date_string("invalid-date")
|
||||
Loading…
Reference in New Issue