You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

69 lines
1.7 KiB
Python

import asyncio
from typing import AsyncGenerator, Generator
import pytest
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from sqlalchemy_utils import database_exists, drop_database
from starlette.config import environ
from starlette.testclient import TestClient
# from src.database import Base, get_db
# from src.main import app
# Test database URL
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
engine = create_async_engine(
TEST_DATABASE_URL,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
async_session = sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False,
autocommit=False,
autoflush=False,
)
async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
async with async_session() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
finally:
await session.close()
app.dependency_overrides[get_db] = override_get_db
@pytest.fixture(scope="session")
def event_loop() -> Generator:
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture(autouse=True)
async def setup_db() -> AsyncGenerator[None, None]:
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
@pytest.fixture
async def client() -> AsyncGenerator[AsyncClient, None]:
async with AsyncClient(app=app, base_url="http://test") as client:
yield client