|
|
|
|
@ -15,7 +15,7 @@ from typing import AsyncGenerator
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
|
|
|
|
from sqlalchemy.orm import DeclarativeBase, sessionmaker
|
|
|
|
|
|
|
|
|
|
from src.config import SQLALCHEMY_DATABASE_URI
|
|
|
|
|
from src.config import SQLALCHEMY_DATABASE_URI, COLLECTOR_URI
|
|
|
|
|
|
|
|
|
|
engine = create_async_engine(
|
|
|
|
|
SQLALCHEMY_DATABASE_URI,
|
|
|
|
|
@ -23,6 +23,12 @@ engine = create_async_engine(
|
|
|
|
|
future=True
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
collector_engine = create_async_engine(
|
|
|
|
|
COLLECTOR_URI,
|
|
|
|
|
echo=False,
|
|
|
|
|
future=True
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
async_session = sessionmaker(
|
|
|
|
|
engine,
|
|
|
|
|
class_=AsyncSession,
|
|
|
|
|
@ -31,13 +37,22 @@ async_session = sessionmaker(
|
|
|
|
|
autoflush=False,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
collector_async_session = sessionmaker(
|
|
|
|
|
collector_engine,
|
|
|
|
|
class_=AsyncSession,
|
|
|
|
|
expire_on_commit=False,
|
|
|
|
|
autocommit=False,
|
|
|
|
|
autoflush=False,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def get_db(request: Request):
|
|
|
|
|
return request.state.db
|
|
|
|
|
|
|
|
|
|
def get_collector_db(request: Request):
|
|
|
|
|
return request.state.collector_db
|
|
|
|
|
|
|
|
|
|
DbSession = Annotated[AsyncSession, Depends(get_db)]
|
|
|
|
|
|
|
|
|
|
CollectorDbSession = Annotated[AsyncSession, Depends(get_collector_db)]
|
|
|
|
|
|
|
|
|
|
class CustomBase:
|
|
|
|
|
__repr_attrs__ = []
|
|
|
|
|
@ -113,6 +128,18 @@ async def get_session():
|
|
|
|
|
finally:
|
|
|
|
|
await session.close()
|
|
|
|
|
|
|
|
|
|
@contextmanager
|
|
|
|
|
async def get_collector_session():
|
|
|
|
|
"""Context manager to ensure the collector session is closed after use."""
|
|
|
|
|
session = collector_async_session()
|
|
|
|
|
try:
|
|
|
|
|
yield session
|
|
|
|
|
await session.commit()
|
|
|
|
|
except:
|
|
|
|
|
await session.rollback()
|
|
|
|
|
raise
|
|
|
|
|
finally:
|
|
|
|
|
await session.close()
|
|
|
|
|
|
|
|
|
|
def resolve_table_name(name):
|
|
|
|
|
"""Resolves table names to their mapped names."""
|
|
|
|
|
@ -122,20 +149,6 @@ def resolve_table_name(name):
|
|
|
|
|
|
|
|
|
|
raise_attribute_error = object()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# def resolve_attr(obj, attr, default=None):
|
|
|
|
|
# """Attempts to access attr via dotted notation, returns none if attr does not exist."""
|
|
|
|
|
# try:
|
|
|
|
|
# return functools.reduce(getattr, attr.split("."), obj)
|
|
|
|
|
# except AttributeError:
|
|
|
|
|
# return default
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# def get_model_name_by_tablename(table_fullname: str) -> str:
|
|
|
|
|
# """Returns the model name of a given table."""
|
|
|
|
|
# return get_class_by_tablename(table_fullname=table_fullname).__name__
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_class_by_tablename(table_fullname: str) -> Any:
|
|
|
|
|
"""Return class reference mapped to table."""
|
|
|
|
|
|
|
|
|
|
@ -149,8 +162,3 @@ def get_class_by_tablename(table_fullname: str) -> Any:
|
|
|
|
|
mapped_class = _find_class(mapped_name)
|
|
|
|
|
|
|
|
|
|
return mapped_class
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# def get_table_name_by_class_instance(class_instance: Base) -> str:
|
|
|
|
|
# """Returns the name of the table for a given class instance."""
|
|
|
|
|
# return class_instance._sa_instance_state.mapper.mapped_table.name
|
|
|
|
|
|