You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

95 lines
2.9 KiB
Python

import os
# Dummy environment for the test run. These are set before the `src.*`
# imports below, presumably because the application reads its settings at
# import time. Values are assigned unconditionally, replacing anything already
# present in the environment, so tests never pick up real hosts/credentials.
os.environ.update(
    {
        "DATABASE_HOSTNAME": "localhost",
        "DATABASE_CREDENTIAL_USER": "test",
        "DATABASE_CREDENTIAL_PASSWORD": "test",
        "COLLECTOR_CREDENTIAL_USER": "test",
        "COLLECTOR_CREDENTIAL_PASSWORD": "test",
        "DEV_USERNAME": "test",
        "DEV_PASSWORD": "test",
    }
)
import asyncio
from typing import AsyncGenerator, Generator
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from fastapi import Request
from src.main import app
from src.database.core import Base, get_db, get_collector_db
from src.auth.service import JWTBearer
from src.auth.model import UserBase
# Import all models to register them with Base
import src.acquisition_cost.model
import src.equipment.model
import src.equipment_master.model
import src.manpower_cost.model
import src.manpower_master.model
import src.masterdata.model
import src.masterdata_simulations.model
import src.plant_fs_transaction_data.model
import src.plant_masterdata.model
import src.plant_transaction_data.model
import src.plant_transaction_data_simulations.model
import src.simulations.model
import src.uploaded_file.model
import src.yeardata.model
# Test database: an in-memory SQLite database accessed through aiosqlite.
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"

# With ":memory:", every new SQLite connection would open its own empty
# database. StaticPool keeps a single connection and hands that same one out
# every time, so the tables created in `setup_db` are visible to the sessions
# used by the app. check_same_thread=False lifts sqlite3's restriction that a
# connection may only be used from the thread that created it.
engine = create_async_engine(
    TEST_DATABASE_URL,
    poolclass=StaticPool,
    connect_args={"check_same_thread": False},
)

# Factory for AsyncSession objects bound to the test engine. Loaded attributes
# are kept after commit (expire_on_commit=False) and nothing is flushed
# implicitly (autoflush=False).
TestingSessionLocal = sessionmaker(
    bind=engine,
    class_=AsyncSession,
    autoflush=False,
    autocommit=False,
    expire_on_commit=False,
)
@pytest.fixture(scope="session")
def event_loop():
    """Provide a single event loop shared by the whole test session.

    pytest calls this synchronous fixture with no loop running, so there is
    never an existing running loop to reuse; a fresh loop is always created
    here and is therefore owned (and closed) by this fixture.

    Yields:
        The session-wide ``asyncio`` event loop.
    """
    loop = asyncio.new_event_loop()
    yield loop
    try:
        # Release the module-level engine's pooled (StaticPool) aiosqlite
        # connection while a loop is still available; presumably this is the
        # loop the async fixtures ran on. Otherwise the connection would be
        # finalized later, against an already-closed loop.
        loop.run_until_complete(engine.dispose())
    finally:
        # This fixture created the loop, so closing it cannot affect anyone
        # else's loop; leaving it open would leak it (ResourceWarning).
        loop.close()
@pytest_asyncio.fixture(autouse=True)
async def setup_db():
    # Wraps every test: create every table registered on Base.metadata (the
    # model modules imported above register them), let the test run, then
    # drop everything so the next test starts from an empty schema.
    schema = Base.metadata

    async with engine.begin() as connection:
        await connection.run_sync(schema.create_all)

    yield

    async with engine.begin() as connection:
        await connection.run_sync(schema.drop_all)
async def override_get_db(request: Request = None):
    # Test replacement for the application's database dependencies: yields an
    # AsyncSession on the in-memory test engine. The session is closed when
    # the generator is resumed after the request finishes.
    #
    # `request` is unused; presumably it mirrors the real dependencies'
    # signature (confirm). Because it is annotated as `Request`, FastAPI
    # injects the incoming request for it rather than exposing it as a
    # client-supplied parameter, so the annotation must stay as-is.
    async with TestingSessionLocal() as db_session:
        yield db_session
# Route both database dependencies to the in-memory test database for every
# request the app handles during the test run.
app.dependency_overrides.update(
    {
        get_db: override_get_db,
        get_collector_db: override_get_db,
    }
)
@pytest.fixture(autouse=True)
def mock_auth(monkeypatch):
    # Applied to every test: replace JWTBearer.__call__ so no token is checked.
    # The fake builds a fixed admin user, stores it on request.state.user and
    # returns it (presumably mirroring what the real bearer does; confirm).
    # monkeypatch restores the original __call__ after each test.
    async def _fake_bearer_call(self, request: Request):
        current_user = UserBase(user_id="test-id", name="test-user", role="admin")
        request.state.user = current_user
        return current_user

    monkeypatch.setattr(JWTBearer, "__call__", _fake_bearer_call)
@pytest_asyncio.fixture
async def client() -> AsyncGenerator[AsyncClient, None]:
    # In-process HTTP client: ASGITransport passes requests directly to the
    # FastAPI app without opening a socket. base_url only provides the scheme
    # and host used to resolve relative request URLs.
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as http_client:
        yield http_client