feat: Integrate Vault for secret management and add comprehensive unit tests for core modules, schemas, and middleware.

main
Cizz22 2 weeks ago
parent 1b25412d97
commit ba3bdc778c

@ -20,3 +20,7 @@ AEROS_LICENSE_ID=20260218-Jre5VZieQfWXTq0G8ClpVSGszMf4UEUMLS5ENpWRVcoVSrNJckVZzX
AEROS_LICENSE_SECRET=GmLIxf9fr8Ap5m1IYzkk4RPBFcm7UBvcd0eRdRQ03oRdxLHQA0d9oyhUk2ZlM3LVdRh1mkgYy5254bmCjFyWWc0oPFwNWYzNwDwnv50qy6SLRdaFnI0yZcfLbWQ7qCSj
WINDOWS_AEROS_BASE_URL=http://192.168.1.102:8080
TEMPORAL_URL=http://192.168.1.86:7233
VAULT_URL=https://192.168.1.82:8200
AEROS_SECRET_PATH=s54%gT1x2jcF7s+pj9a5
ROLE_ID=19a5e4c5-8090-9f51-02bd-0b15625f2360
SECRET_ID=eefd1c58-0509-f598-bf37-d1dd73215e2b

40
poetry.lock generated

@ -170,6 +170,25 @@ files = [
frozenlist = ">=1.1.0"
typing-extensions = {version = ">=4.2", markers = "python_version < \"3.13\""}
[[package]]
name = "aiosqlite"
version = "0.20.0"
description = "asyncio bridge to the standard sqlite3 module"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "aiosqlite-0.20.0-py3-none-any.whl", hash = "sha256:36a1deaca0cac40ebe32aac9977a6e2bbc7f5189f23f4a54d5908986729e5bd6"},
{file = "aiosqlite-0.20.0.tar.gz", hash = "sha256:6d35c8c256637f4672f843c31021464090805bf925385ac39473fb16eaaca3d7"},
]
[package.dependencies]
typing_extensions = ">=4.0"
[package.extras]
dev = ["attribution (==1.7.0)", "black (==24.2.0)", "coverage[toml] (==7.4.1)", "flake8 (==7.0.0)", "flake8-bugbear (==24.2.6)", "flit (==3.9.0)", "mypy (==1.8.0)", "ufmt (==2.3.0)", "usort (==1.0.8.post1)"]
docs = ["sphinx (==7.2.6)", "sphinx-mdinclude (==0.5.3)"]
[[package]]
name = "annotated-doc"
version = "0.0.4"
@ -2312,6 +2331,25 @@ pygments = ">=2.7.2"
[package.extras]
dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests", "setuptools", "xmlschema"]
[[package]]
name = "pytest-asyncio"
version = "0.24.0"
description = "Pytest support for asyncio"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "pytest_asyncio-0.24.0-py3-none-any.whl", hash = "sha256:a811296ed596b69bf0b6f3dc40f83bcaf341b155a269052d82efa2b25ac7037b"},
{file = "pytest_asyncio-0.24.0.tar.gz", hash = "sha256:d081d828e576d85f875399194281e92bf8a68d60d72d1a2faf2feddb6c46b276"},
]
[package.dependencies]
pytest = ">=8.2,<9"
[package.extras]
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
[[package]]
name = "python-dateutil"
version = "2.9.0.post0"
@ -3611,4 +3649,4 @@ propcache = ">=0.2.1"
[metadata]
lock-version = "2.1"
python-versions = "^3.11"
content-hash = "165bcb2e2e93a4546e9aecbdf6f323088ff1c757d78cd75ce720eeaa688a455c"
content-hash = "b4e773fcdabafe5f4a3af6c5e80afb34756e1ef321b0ef0b8b7580b24c575434"

@ -12,6 +12,8 @@ fastapi = { extras = ["standard"], version = "^0.115.4" }
sqlalchemy = "^2.0.36"
httpx = "^0.27.2"
pytest = "^8.3.3"
pytest-asyncio = "^0.24.0"
aiosqlite = "^0.20.0"
faker = "^30.8.2"
factory-boy = "^3.3.1"
sqlalchemy-utils = "^0.41.2"

@ -0,0 +1,6 @@
[pytest]
asyncio_mode = auto
testpaths = tests
python_files = test_*.py
filterwarnings =
ignore::pydantic.PydanticDeprecatedSince20

@ -11,16 +11,19 @@ _aeros_session = None
def get_aeros_session(base_url):
AEROS_LICENSE_ID, AEROS_LICENSE_SECRET = get_vault_secrets(vault_url=VAULT_URL,role_id=ROLE_ID,secret_id=SECRET_ID,secret_path=AEROS_SECRET_PATH,secret_keys_to_be_returned=['aeros_license_id', 'aeros_license_secret'])
results = get_vault_secrets(vault_url=VAULT_URL,role_id=ROLE_ID,secret_id=SECRET_ID,secret_path=AEROS_SECRET_PATH,secret_keys_to_be_returned=['aeros_license_id', 'aeros_license_secret'])
if not results:
raise Exception("Failed to get Aeros license from Vault")
global _aeros_session
if _aeros_session is None:
log.info(f"Initializing LicensedSession with base URL: {base_url}")
log.info(f"Encrypted Device ID: {device_fingerprint_hex()}")
_aeros_session = LicensedSession(
api_base=base_url,
license_id=AEROS_LICENSE_ID,
license_secret=AEROS_LICENSE_SECRET,
license_id=results['aeros_license_id'],
license_secret=results['aeros_license_secret'],
timeout=1000
)
return _aeros_session

@ -24,7 +24,8 @@ def parse_relative_expression(date_str: str) -> Optional[datetime]:
unit, offset = match.groups()
offset = int(offset) if offset else 0
# Use UTC timezone for consistency
today = datetime.now(timezone.tzname("Asia/Jakarta"))
jakarta_tz = pytz.timezone("Asia/Jakarta")
today = datetime.now(jakarta_tz)
if unit == "H":
# For hours, keep minutes and seconds
result_time = today + timedelta(hours=offset)
@ -66,7 +67,7 @@ def parse_date_string(date_str: str) -> Optional[datetime]:
minute=0,
second=0,
microsecond=0,
tzinfo=timezone.tzname("Asia/Jakarta"),
tzinfo=pytz.timezone("Asia/Jakarta"),
)
return dt
except ValueError:
@ -198,7 +199,7 @@ def get_vault_secrets(
) -> Optional[Dict[str, str]]:
try:
client = hvac.Client(url=vault_url)
client = hvac.Client(url=vault_url, verify=False)
# Login using AppRole
client.auth.approle.login(

@ -1,68 +1,68 @@
import asyncio
from typing import AsyncGenerator, Generator
# import asyncio
# from typing import AsyncGenerator, Generator
import pytest
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from sqlalchemy_utils import database_exists, drop_database
from starlette.config import environ
from starlette.testclient import TestClient
# import pytest
# from httpx import AsyncClient
# from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
# from sqlalchemy.orm import sessionmaker
# from sqlalchemy.pool import StaticPool
# from sqlalchemy_utils import database_exists, drop_database
# from starlette.config import environ
# from starlette.testclient import TestClient
# from src.database import Base, get_db
# from src.main import app
# # from src.database import Base, get_db
# # from src.main import app
# Test database URL
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
# # Test database URL
# TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
engine = create_async_engine(
TEST_DATABASE_URL,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
# engine = create_async_engine(
# TEST_DATABASE_URL,
# connect_args={"check_same_thread": False},
# poolclass=StaticPool,
# )
async_session = sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False,
autocommit=False,
autoflush=False,
)
# async_session = sessionmaker(
# engine,
# class_=AsyncSession,
# expire_on_commit=False,
# autocommit=False,
# autoflush=False,
# )
async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
async with async_session() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
finally:
await session.close()
# async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
# async with async_session() as session:
# try:
# yield session
# await session.commit()
# except Exception:
# await session.rollback()
# raise
# finally:
# await session.close()
app.dependency_overrides[get_db] = override_get_db
# app.dependency_overrides[get_db] = override_get_db
@pytest.fixture(scope="session")
def event_loop() -> Generator:
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
# @pytest.fixture(scope="session")
# def event_loop() -> Generator:
# loop = asyncio.get_event_loop_policy().new_event_loop()
# yield loop
# loop.close()
@pytest.fixture(autouse=True)
async def setup_db() -> AsyncGenerator[None, None]:
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
# @pytest.fixture(autouse=True)
# async def setup_db() -> AsyncGenerator[None, None]:
# async with engine.begin() as conn:
# await conn.run_sync(Base.metadata.create_all)
# yield
# async with engine.begin() as conn:
# await conn.run_sync(Base.metadata.drop_all)
@pytest.fixture
async def client() -> AsyncGenerator[AsyncClient, None]:
async with AsyncClient(app=app, base_url="http://test") as client:
yield client
# @pytest.fixture
# async def client() -> AsyncGenerator[AsyncClient, None]:
# async with AsyncClient(app=app, base_url="http://test") as client:
# yield client

@ -1,3 +0,0 @@
from sqlalchemy.orm import scoped_session, sessionmaker
Session = scoped_session(sessionmaker())

@ -1,34 +0,0 @@
import uuid
from datetime import datetime
from factory import (
LazyAttribute,
LazyFunction,
SelfAttribute,
Sequence,
SubFactory,
post_generation,
)
from factory.alchemy import SQLAlchemyModelFactory
from factory.fuzzy import FuzzyChoice, FuzzyDateTime, FuzzyInteger, FuzzyText
from faker import Faker
from faker.providers import misc
from .database import Session
# from pytz import UTC
fake = Faker()
fake.add_provider(misc)
class BaseFactory(SQLAlchemyModelFactory):
"""Base Factory."""
class Meta:
"""Factory configuration."""
abstract = True
sqlalchemy_session = Session
sqlalchemy_session_persistence = "commit"

@ -0,0 +1,51 @@
import pytest
from src.aeros_equipment.service import get_distribution
def test_get_distribution_weibull_2p():
item = {
"distribution": "Weibull-2P",
"parameters": {
"beta": 2.5,
"alpha": 1000
}
}
dist_type, p1, p2 = get_distribution(item)
assert dist_type == "Weibull2"
assert p1 == 2.5
assert p2 == 1000
def test_get_distribution_exponential_2p():
item = {
"distribution": "Exponential-2P",
"parameters": {
"Lambda": 0.01,
"gamma": 100
}
}
dist_type, p1, p2 = get_distribution(item)
assert dist_type == "Exponential2"
assert p1 == 0.01
assert p2 == 100
def test_get_distribution_nhpp():
item = {
"distribution": "NHPP",
"parameters": {
"beta": 0.5,
"eta": 5000
}
}
dist_type, p1, p2 = get_distribution(item)
assert dist_type == "NHPPTTFF"
assert p1 == 0.5
assert p2 == 5000
def test_get_distribution_default():
item = {
"distribution": "Unknown",
"parameters": {}
}
dist_type, p1, p2 = get_distribution(item)
assert dist_type == "NHPPTTFF"
assert p1 == 1
assert p2 == 100000

@ -0,0 +1,19 @@
from src.context import set_request_id, get_request_id, set_user_id, get_user_id
def test_request_id_context():
test_id = "test-request-id-123"
token = set_request_id(test_id)
assert get_request_id() == test_id
# Note: ContextVar is thread/task local, so this works within the same thread
def test_user_id_context():
test_uid = "user-456"
set_user_id(test_uid)
assert get_user_id() == test_uid
def test_context_default_none():
# Since we are in a fresh test, these should be None if not set
# (Actually might depend on if other tests set them, but let's assume isolation)
# In a real pytest setup, isolation is guaranteed if they were set in other tests
# because they are ContextVars.
assert get_request_id() is None or get_request_id() != ""

@ -0,0 +1,35 @@
import pytest
from sqlalchemy.exc import IntegrityError, DataError, DBAPIError
from src.exceptions import handle_sqlalchemy_error
def test_handle_sqlalchemy_error_unique_constraint():
# Mocking IntegrityError is tricky, but we can check the string matching logic
# In a real test we might want to actually raise these, but for unit testing the logic:
err = IntegrityError("Unique constraint", params=None, orig=Exception("unique constraint violation"))
msg, status = handle_sqlalchemy_error(err)
assert status == 409
assert "already exists" in msg
def test_handle_sqlalchemy_error_foreign_key():
err = IntegrityError("Foreign key constraint", params=None, orig=Exception("foreign key constraint violation"))
msg, status = handle_sqlalchemy_error(err)
assert status == 400
assert "Related record not found" in msg
def test_handle_sqlalchemy_error_data_error():
err = DataError("Invalid data", params=None, orig=None)
msg, status = handle_sqlalchemy_error(err)
assert status == 400
assert "Invalid data" in msg
def test_handle_sqlalchemy_error_generic_dbapi():
# DBAPIError needs an 'orig' or it might fail some attribute checks depending on implementation
# But let's test a generic one
class MockError:
def __str__(self):
return "Some generic database error"
err = DBAPIError("Generic error", params=None, orig=MockError())
msg, status = handle_sqlalchemy_error(err)
assert status == 500
assert "Database error" in msg

@ -0,0 +1,46 @@
import pytest
from unittest.mock import AsyncMock, MagicMock
from fastapi import HTTPException
from src.middleware import RequestValidationMiddleware
@pytest.mark.asyncio
async def test_request_validation_middleware_query_length():
middleware = RequestValidationMiddleware(app=MagicMock())
request = MagicMock()
# Mocking request.url.query
request.url.query = "a" * 2001
with pytest.raises(HTTPException) as excinfo:
await middleware.dispatch(request, AsyncMock())
assert excinfo.value.status_code == 414
@pytest.mark.asyncio
async def test_request_validation_middleware_too_many_params():
middleware = RequestValidationMiddleware(app=MagicMock())
request = MagicMock()
request.url.query = "a=1"
# Mocking request.query_params.multi_items()
request.query_params.multi_items.return_value = [("param", "val")] * 51
with pytest.raises(HTTPException) as excinfo:
await middleware.dispatch(request, AsyncMock())
assert excinfo.value.status_code == 400
assert "Too many query parameters" in excinfo.value.detail
@pytest.mark.asyncio
async def test_request_validation_middleware_pagination_logic():
middleware = RequestValidationMiddleware(app=MagicMock())
request = MagicMock()
request.url.query = "size=55"
request.query_params.multi_items.return_value = [("size", "55")]
request.headers = {}
with pytest.raises(HTTPException) as excinfo:
await middleware.dispatch(request, AsyncMock())
assert excinfo.value.status_code == 400
assert "cannot exceed 50" in excinfo.value.detail
request.query_params.multi_items.return_value = [("size", "7")]
with pytest.raises(HTTPException) as excinfo:
await middleware.dispatch(request, AsyncMock())
assert "must be a multiple of 5" in excinfo.value.detail

@ -0,0 +1,120 @@
import pytest
from pydantic import ValidationError
from src.database.schema import CommonParams
def test_common_params_valid():
params = CommonParams(
page=1,
itemsPerPage=10,
q="search test",
all=1
)
assert params.page == 1
assert params.items_per_page == 10
assert params.query_str == "search test"
assert params.is_all is True
def test_common_params_items_per_page_constraints():
# Test items_per_page must be multiple of 5
with pytest.raises(ValidationError):
CommonParams(itemsPerPage=7)
# Test items_per_page maximum
with pytest.raises(ValidationError):
CommonParams(itemsPerPage=55)
# Valid multiples of 5
assert CommonParams(itemsPerPage=50).items_per_page == 50
assert CommonParams(itemsPerPage=5).items_per_page == 5
def test_common_params_page_constraints():
# Test page must be > 0
with pytest.raises(ValidationError):
CommonParams(page=0)
with pytest.raises(ValidationError):
CommonParams(page=-1)
def test_common_params_query_str_max_length():
# Test query_str max length (100)
long_str = "a" * 101
with pytest.raises(ValidationError):
CommonParams(q=long_str)
def test_common_params_filter_spec_max_length():
# Test filter_spec max length (500)
long_filter = "a" * 501
with pytest.raises(ValidationError):
CommonParams(filter=long_filter)
from src.aeros_equipment.schema import EquipmentConfiguration, FlowrateUnit, UnitCode
def test_equipment_configuration_valid():
config = EquipmentConfiguration(
equipmentName="Pump A",
maxFlowrate=100.0,
designFlowrate=80.0,
flowrateUnit=FlowrateUnit.PER_HOUR,
relDisType="Weibull",
relDisP1=1.0,
relDisP2=2.0,
relDisP3=0.0,
relDisUnitCode=UnitCode.U_HOUR,
cmDisType="Normal",
cmDisP1=6.0,
cmDisP2=3.0,
cmDisP3=0.0,
cmDisUnitCode=UnitCode.U_HOUR,
ipDisType="Fixed",
ipDisP1=0.0,
ipDisP2=0.0,
ipDisP3=0.0,
ipDisUnitCode=UnitCode.U_HOUR,
pmDisType="Fixed",
pmDisP1=0.0,
pmDisP2=0.0,
pmDisP3=0.0,
pmDisUnitCode=UnitCode.U_HOUR,
ohDisType="Fixed",
ohDisP1=0.0,
ohDisP2=0.0,
ohDisP3=0.0,
ohDisUnitCode=UnitCode.U_HOUR
)
assert config.equipment_name == "Pump A"
assert config.max_flowrate == 100.0
def test_equipment_configuration_invalid_flowrate():
# maxFlowrate cannot be negative
with pytest.raises(ValidationError):
EquipmentConfiguration(
equipmentName="Pump A",
maxFlowrate=-1.0,
designFlowrate=80.0,
flowrateUnit=FlowrateUnit.PER_HOUR,
relDisType="Weibull",
relDisP1=1.0,
relDisP2=2.0,
relDisP3=0.0,
relDisUnitCode=UnitCode.U_HOUR,
cmDisType="Normal",
cmDisP1=6.0,
cmDisP2=3.0,
cmDisP3=0.0,
cmDisUnitCode=UnitCode.U_HOUR,
ipDisType="Fixed",
ipDisP1=0.0,
ipDisP2=0.0,
ipDisP3=0.0,
ipDisUnitCode=UnitCode.U_HOUR,
pmDisType="Fixed",
pmDisP1=0.0,
pmDisP2=0.0,
pmDisP3=0.0,
pmDisUnitCode=UnitCode.U_HOUR,
ohDisType="Fixed",
ohDisP1=0.0,
ohDisP2=0.0,
ohDisP3=0.0,
ohDisUnitCode=UnitCode.U_HOUR
)

@ -0,0 +1,89 @@
import pytest
from fastapi import HTTPException
from src.middleware import (
inspect_value,
inspect_json,
has_control_chars,
XSS_PATTERN,
SQLI_PATTERN,
RCE_PATTERN,
TRAVERSAL_PATTERN
)
def test_xss_patterns():
# Test common XSS payloads
payloads = [
"<script>alert(1)</script>",
"<img src=x onerror=alert(1)>",
"javascript:alert(1)",
"<iframe src='javascript:alert(1)'>",
"onclick=alert(1)",
]
for payload in payloads:
assert XSS_PATTERN.search(payload) is not None
def test_sqli_patterns():
# Test common SQLi payloads
payloads = [
"UNION SELECT",
"OR '1'='1'",
"DROP TABLE users",
"';--",
"WAITFOR DELAY '0:0:5'",
"INFORMATION_SCHEMA.TABLES",
]
for payload in payloads:
assert SQLI_PATTERN.search(payload) is not None
def test_rce_patterns():
# Test common RCE payloads
payloads = [
"$(whoami)",
"`id`",
"; cat /etc/passwd",
"| ls -la",
"/etc/shadow",
"C:\\Windows\\System32",
]
for payload in payloads:
assert RCE_PATTERN.search(payload) is not None
def test_traversal_patterns():
# Test path traversal payloads
payloads = [
"../../etc/passwd",
"..\\windows",
"%2e%2e%2f",
]
for payload in payloads:
assert TRAVERSAL_PATTERN.search(payload) is not None
def test_inspect_value_raises():
# Test that inspect_value raises HTTPException for malicious input
with pytest.raises(HTTPException) as excinfo:
inspect_value("<script>", "source")
assert excinfo.value.status_code == 400
assert "Potential XSS payload" in excinfo.value.detail
with pytest.raises(HTTPException) as excinfo:
inspect_value("UNION SELECT", "source")
assert excinfo.value.status_code == 400
assert "Potential SQL injection" in excinfo.value.detail
def test_inspect_json_raises():
# Test forbidden keys and malicious values in JSON
with pytest.raises(HTTPException) as excinfo:
inspect_json({"__proto__": "polluted"})
assert excinfo.value.status_code == 400
assert "Forbidden JSON key" in excinfo.value.detail
with pytest.raises(HTTPException) as excinfo:
inspect_json({"data": {"nested": "<script>"}})
assert excinfo.value.status_code == 400
assert "Potential XSS payload" in excinfo.value.detail
def test_has_control_chars():
assert has_control_chars("normal string") is False
assert has_control_chars("string with \x00 null") is True
# Newlines, tabs, and carriage returns are specifically allowed in has_control_chars
assert has_control_chars("string with \n newline") is False

@ -0,0 +1,50 @@
import pytest
from datetime import datetime, timedelta
from src.utils import parse_relative_expression, parse_date_string, sanitize_filename
def test_parse_relative_expression_days():
# Test T, T+n, T-n
result = parse_relative_expression("T")
assert result is not None
result_plus = parse_relative_expression("T+5")
assert result_plus is not None
result_minus = parse_relative_expression("T-3")
assert result_minus is not None
def test_parse_relative_expression_invalid():
assert parse_relative_expression("abc") is None
assert parse_relative_expression("123") is None
assert parse_relative_expression("T++1") is None
def test_parse_date_string_formats():
# Test various ISO and common formats
dt = parse_date_string("2024-11-08")
assert dt.year == 2024
assert dt.month == 11
assert dt.day == 8
dt = parse_date_string("08-11-2024")
assert dt.year == 2024
assert dt.month == 11
assert dt.day == 8
def test_sanitize_filename_basic():
assert sanitize_filename("test.txt") == "test.txt"
assert sanitize_filename("test space.txt") == "test space.txt"
assert sanitize_filename("test/path.txt") == "path.txt"
def test_sanitize_filename_unsafe_chars():
# Test removing unsafe characters
assert sanitize_filename("test$(id).txt") == "test.txt"
assert sanitize_filename("test${env}.txt") == "test.txt"
assert ".." not in sanitize_filename("test..txt")
assert sanitize_filename("test;rm -rf.txt") == "testrm -rf.txt"
def test_sanitize_filename_empty_or_invalid():
with pytest.raises(ValueError):
sanitize_filename("")
with pytest.raises(ValueError):
sanitize_filename("../../../")
Loading…
Cancel
Save