You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

115 lines
3.6 KiB
Python

import os
# Dummy configuration for the test run. These values are assigned before the
# `src` imports further down, presumably because the application reads its
# settings at import time (NOTE(review): confirm in src settings module).
# Existing values are overwritten so tests never pick up real credentials.
os.environ.update(
    {
        "DATABASE_HOSTNAME": "localhost",
        "DATABASE_CREDENTIAL_USER": "test",
        "DATABASE_CREDENTIAL_PASSWORD": "test",
        "COLLECTOR_CREDENTIAL_USER": "test",
        "COLLECTOR_CREDENTIAL_PASSWORD": "test",
        "DEV_USERNAME": "test",
        "DEV_PASSWORD": "test",
    }
)
import asyncio
from typing import AsyncGenerator, Generator
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from fastapi import Request
from src.main import app
from src.database.core import Base, get_db, get_collector_db
from src.auth.service import JWTBearer
from src.auth.model import UserBase
# Import all models to register them with Base
import src.acquisition_cost.model
import src.equipment.model
import src.equipment_master.model
import src.manpower_cost.model
import src.manpower_master.model
import src.masterdata.model
import src.masterdata_simulations.model
import src.plant_fs_transaction_data.model
import src.plant_masterdata.model
import src.plant_transaction_data.model
import src.plant_transaction_data_simulations.model
import src.simulations.model
import src.uploaded_file.model
import src.yeardata.model
# In-memory SQLite database, driven through the aiosqlite async driver.
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"

# StaticPool hands out a single shared connection, so every session talks to
# the same in-memory database (a new connection would get an empty one).
# check_same_thread=False allows that connection to be used from threads other
# than the one that opened it.
engine = create_async_engine(
    TEST_DATABASE_URL,
    poolclass=StaticPool,
    connect_args={"check_same_thread": False},
)

# Factory for AsyncSession objects bound to the test engine. Objects stay
# usable after commit (no expiry), and nothing is flushed implicitly.
TestingSessionLocal = sessionmaker(
    bind=engine,
    class_=AsyncSession,
    autoflush=False,
    autocommit=False,
    expire_on_commit=False,
)
# Strong references to disposal tasks scheduled on an already-running loop.
# asyncio keeps only weak references to tasks, so an unreferenced task could be
# garbage-collected before it finishes.
_pending_dispose_tasks = set()


def pytest_sessionfinish(session, exitstatus):
    """Dispose of every SQLAlchemy engine once the whole test run has finished.

    pytest calls this hook right before returning the exit status to the
    system. Disposing the engines closes their pooled connections, which
    prevents the test run from hanging on exit.

    Disposal is best-effort. Each engine is disposed independently, and any
    failure is reported on stderr instead of being silently discarded or
    failing the run.

    Args:
        session: The pytest session object (unused).
        exitstatus: The exit status pytest is about to return (unused).
    """
    import sys
    import warnings

    from src.database.core import engine as db_engine, collector_engine

    def report(message):
        # Use stderr rather than warnings.warn(): a "-W error" or
        # filterwarnings="error" setup would turn a warning into a crash here.
        print(f"[conftest] {message}", file=sys.stderr)

    async def dispose_all():
        # Dispose the test engine and both production engines. One failure
        # must not prevent the remaining engines from being disposed.
        targets = (
            ("test", engine),
            ("database", db_engine),
            ("collector", collector_engine),
        )
        for label, target in targets:
            try:
                await target.dispose()
            except Exception as exc:
                report(f"failed to dispose {label} engine: {exc!r}")

    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        running_loop = None

    try:
        if running_loop is not None:
            # A synchronous hook cannot block on a running loop. Schedule the
            # disposal instead, and keep the task referenced until it is done.
            task = running_loop.create_task(dispose_all())
            _pending_dispose_tasks.add(task)
            task.add_done_callback(_pending_dispose_tasks.discard)
            return

        # Prefer the loop left over from the test run, since pooled
        # connections may be tied to it. asyncio.get_event_loop() warns
        # (3.12+) or raises (3.14+) when no current loop is set, so treat
        # "no loop" the same as a closed loop.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                loop = None

        if loop is not None and not loop.is_closed():
            loop.run_until_complete(dispose_all())
        else:
            # Check for a closed loop first: handing it a coroutine would raise
            # before awaiting and leak a "never awaited" warning.
            asyncio.run(dispose_all())
    except Exception as exc:
        report(f"engine disposal did not complete: {exc!r}")
# No custom `event_loop` fixture is defined here; event-loop management is left to pytest-asyncio.
@pytest_asyncio.fixture(autouse=True)
async def setup_db():
    """Give every test a freshly created schema in the in-memory test database.

    Before each test, every table registered on ``Base.metadata`` is created.
    After the test, all of those tables are dropped again, so data does not
    leak from one test into the next.
    """
    metadata = Base.metadata

    async with engine.begin() as connection:
        await connection.run_sync(metadata.create_all)

    yield

    async with engine.begin() as connection:
        await connection.run_sync(metadata.drop_all)
async def override_get_db(request: Request = None):
    """Test replacement for the application's database-session dependencies.

    Yields a session from ``TestingSessionLocal`` (bound to the in-memory test
    engine). The ``async with`` block closes the session when the dependency
    is torn down.

    ``request`` is accepted but unused, presumably so the override matches the
    signature of the dependencies it replaces (NOTE(review): confirm against
    ``src.database.core``). Keep its annotation a plain ``Request``.
    """
    session_factory = TestingSessionLocal
    async with session_factory() as db_session:
        yield db_session


# Point both the primary and the collector database dependency at the test
# session, so endpoints depending on either one use the in-memory database.
app.dependency_overrides.update(
    {
        get_db: override_get_db,
        get_collector_db: override_get_db,
    }
)
@pytest.fixture(autouse=True)
def mock_auth(monkeypatch):
    """Bypass JWT authentication for every test.

    Replaces ``JWTBearer.__call__`` with a stub that skips token checks. The
    stub always authenticates a fixed admin user, stores it on
    ``request.state.user``, and returns it. ``monkeypatch`` undoes the patch
    after each test.
    """

    async def fake_jwt_call(self, request: Request):
        # Keep the parameter named ``request``. FastAPI may pass it by keyword
        # based on JWTBearer's original signature (NOTE(review): confirm).
        admin = UserBase(user_id="test-id", name="test-user", role="admin")
        request.state.user = admin
        return admin

    monkeypatch.setattr(JWTBearer, "__call__", fake_jwt_call)
@pytest_asyncio.fixture
async def client() -> AsyncGenerator[AsyncClient, None]:
    """Yield an async HTTP client that calls the FastAPI app in-process.

    Requests are routed through ``ASGITransport`` straight into ``app``, with
    no real network involved. Relative URLs resolve against ``http://test``.
    The client is closed when the fixture is torn down.
    """
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as http_client:
        yield http_client