|
|
|
|
@ -1,25 +1,20 @@
|
|
|
|
|
# src/database.py
|
|
|
|
|
import functools
|
|
|
|
|
import re
|
|
|
|
|
from contextlib import contextmanager
|
|
|
|
|
from typing import Annotated, Any, AsyncGenerator
|
|
|
|
|
from contextlib import asynccontextmanager
|
|
|
|
|
from typing import Annotated, Any
|
|
|
|
|
|
|
|
|
|
from fastapi import Depends
|
|
|
|
|
from pydantic import BaseModel
|
|
|
|
|
from sqlalchemy import create_engine, inspect
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
|
|
|
|
from sqlalchemy.ext.declarative import declarative_base, declared_attr
|
|
|
|
|
from sqlalchemy.orm import (DeclarativeBase, Session, object_session,
|
|
|
|
|
sessionmaker)
|
|
|
|
|
from sqlalchemy.sql.expression import true
|
|
|
|
|
from sqlalchemy_utils import get_mapper
|
|
|
|
|
from sqlalchemy.orm import DeclarativeBase
|
|
|
|
|
from starlette.requests import Request
|
|
|
|
|
|
|
|
|
|
from src.config import SQLALCHEMY_DATABASE_URI
|
|
|
|
|
from src.config import SQLALCHEMY_DATABASE_URI, SQLALCHEMY_COLLECTOR_URI
|
|
|
|
|
|
|
|
|
|
engine = create_async_engine(SQLALCHEMY_DATABASE_URI, echo=False, future=True)
|
|
|
|
|
collector_engine = create_async_engine(SQLALCHEMY_COLLECTOR_URI, echo=False, future=True)
|
|
|
|
|
|
|
|
|
|
async_session = sessionmaker(
|
|
|
|
|
async_session = async_sessionmaker(
|
|
|
|
|
engine,
|
|
|
|
|
class_=AsyncSession,
|
|
|
|
|
expire_on_commit=False,
|
|
|
|
|
@ -27,77 +22,50 @@ async_session = sessionmaker(
|
|
|
|
|
autoflush=False,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
async_collector_session = async_sessionmaker(
|
|
|
|
|
collector_engine,
|
|
|
|
|
class_=AsyncSession,
|
|
|
|
|
expire_on_commit=False,
|
|
|
|
|
autocommit=False,
|
|
|
|
|
autoflush=False,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
async def get_collector_db(request: Request):
|
|
|
|
|
return request.state.collector_db
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_db(request: Request):
|
|
|
|
|
return request.state.db
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
DbSession = Annotated[AsyncSession, Depends(get_db)]
|
|
|
|
|
CollectorDbSession = Annotated[AsyncSession, Depends(get_collector_db)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CustomBase:
|
|
|
|
|
__repr_attrs__ = []
|
|
|
|
|
__repr_max_length__ = 15
|
|
|
|
|
class Base(DeclarativeBase):
|
|
|
|
|
@declared_attr.directive
|
|
|
|
|
def __tablename__(cls) -> str:
|
|
|
|
|
return resolve_table_name(cls.__name__)
|
|
|
|
|
|
|
|
|
|
@declared_attr
|
|
|
|
|
def __tablename__(self):
|
|
|
|
|
return resolve_table_name(self.__name__)
|
|
|
|
|
def dict(self):
|
|
|
|
|
"""Returns a dict representation of a model."""
|
|
|
|
|
if hasattr(self, '__table__'):
|
|
|
|
|
return {c.name: getattr(self, c.name) for c in self.__table__.columns}
|
|
|
|
|
return {}
|
|
|
|
|
|
|
|
|
|
class CollectorBase(DeclarativeBase):
|
|
|
|
|
@declared_attr.directive
|
|
|
|
|
def __tablename__(cls) -> str:
|
|
|
|
|
return resolve_table_name(cls.__name__)
|
|
|
|
|
|
|
|
|
|
def dict(self):
|
|
|
|
|
"""Returns a dict representation of a model."""
|
|
|
|
|
return {c.name: getattr(self, c.name) for c in self.__table__.columns}
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def _id_str(self):
|
|
|
|
|
ids = inspect(self).identity
|
|
|
|
|
if ids:
|
|
|
|
|
return "-".join([str(x) for x in ids]) if len(ids) > 1 else str(ids[0])
|
|
|
|
|
else:
|
|
|
|
|
return "None"
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def _repr_attrs_str(self):
|
|
|
|
|
max_length = self.__repr_max_length__
|
|
|
|
|
|
|
|
|
|
values = []
|
|
|
|
|
single = len(self.__repr_attrs__) == 1
|
|
|
|
|
for key in self.__repr_attrs__:
|
|
|
|
|
if not hasattr(self, key):
|
|
|
|
|
raise KeyError(
|
|
|
|
|
"{} has incorrect attribute '{}' in "
|
|
|
|
|
"__repr__attrs__".format(self.__class__, key)
|
|
|
|
|
)
|
|
|
|
|
value = getattr(self, key)
|
|
|
|
|
wrap_in_quote = isinstance(value, str)
|
|
|
|
|
|
|
|
|
|
value = str(value)
|
|
|
|
|
if len(value) > max_length:
|
|
|
|
|
value = value[:max_length] + "..."
|
|
|
|
|
|
|
|
|
|
if wrap_in_quote:
|
|
|
|
|
value = "'{}'".format(value)
|
|
|
|
|
values.append(value if single else "{}:{}".format(key, value))
|
|
|
|
|
|
|
|
|
|
return " ".join(values)
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
|
# get id like '#123'
|
|
|
|
|
id_str = ("#" + self._id_str) if self._id_str else ""
|
|
|
|
|
# join class name, id and repr_attrs
|
|
|
|
|
return "<{} {}{}>".format(
|
|
|
|
|
self.__class__.__name__,
|
|
|
|
|
id_str,
|
|
|
|
|
" " + self._repr_attrs_str if self._repr_attrs_str else "",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Base = declarative_base(cls=CustomBase)
|
|
|
|
|
# make_searchable(Base.metadata)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
|
|
|
|
|
async def get_session():
|
|
|
|
|
"""Context manager to ensure the session is closed after use."""
|
|
|
|
|
if hasattr(self, '__table__'):
|
|
|
|
|
return {c.name: getattr(self, c.name) for c in self.__table__.columns}
|
|
|
|
|
return {}
|
|
|
|
|
|
|
|
|
|
@asynccontextmanager
|
|
|
|
|
async def get_main_session():
|
|
|
|
|
session = async_session()
|
|
|
|
|
try:
|
|
|
|
|
yield session
|
|
|
|
|
@ -108,6 +76,17 @@ async def get_session():
|
|
|
|
|
finally:
|
|
|
|
|
await session.close()
|
|
|
|
|
|
|
|
|
|
@asynccontextmanager
|
|
|
|
|
async def get_collector_session():
|
|
|
|
|
session = async_collector_session()
|
|
|
|
|
try:
|
|
|
|
|
yield session
|
|
|
|
|
await session.commit()
|
|
|
|
|
except:
|
|
|
|
|
await session.rollback()
|
|
|
|
|
raise
|
|
|
|
|
finally:
|
|
|
|
|
await session.close()
|
|
|
|
|
|
|
|
|
|
def resolve_table_name(name):
|
|
|
|
|
"""Resolves table names to their mapped names."""
|
|
|
|
|
@ -115,27 +94,11 @@ def resolve_table_name(name):
|
|
|
|
|
return "_".join([x.lower() for x in names if x])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
raise_attribute_error = object()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# def resolve_attr(obj, attr, default=None):
|
|
|
|
|
# """Attempts to access attr via dotted notation, returns none if attr does not exist."""
|
|
|
|
|
# try:
|
|
|
|
|
# return functools.reduce(getattr, attr.split("."), obj)
|
|
|
|
|
# except AttributeError:
|
|
|
|
|
# return default
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# def get_model_name_by_tablename(table_fullname: str) -> str:
|
|
|
|
|
# """Returns the model name of a given table."""
|
|
|
|
|
# return get_class_by_tablename(table_fullname=table_fullname).__name__
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_class_by_tablename(table_fullname: str) -> Any:
|
|
|
|
|
"""Return class reference mapped to table."""
|
|
|
|
|
|
|
|
|
|
def _find_class(name):
|
|
|
|
|
for c in Base._decl_class_registry.values():
|
|
|
|
|
for c in Base.registry._class_registry.values():
|
|
|
|
|
if hasattr(c, "__table__"):
|
|
|
|
|
if c.__table__.fullname.lower() == name.lower():
|
|
|
|
|
return c
|
|
|
|
|
@ -144,8 +107,3 @@ def get_class_by_tablename(table_fullname: str) -> Any:
|
|
|
|
|
mapped_class = _find_class(mapped_name)
|
|
|
|
|
|
|
|
|
|
return mapped_class
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# def get_table_name_by_class_instance(class_instance: Base) -> str:
|
|
|
|
|
# """Returns the name of the table for a given class instance."""
|
|
|
|
|
# return class_instance._sa_instance_state.mapper.mapped_table.name
|
|
|
|
|
|