# src/database.py
import functools
import re
from contextlib import asynccontextmanager, contextmanager
from typing import Annotated, Any, AsyncGenerator

from fastapi import Depends
from pydantic import BaseModel
from sqlalchemy import create_engine, inspect
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import DeclarativeBase, Session, object_session, sessionmaker
from sqlalchemy.sql.expression import true
from sqlalchemy_utils import get_mapper
from starlette.requests import Request

from src.config import SQLALCHEMY_DATABASE_URI

engine = create_async_engine(SQLALCHEMY_DATABASE_URI, echo=False, future=True)
async_session = sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
    autocommit=False,
    autoflush=False,
)


def get_db(request: Request):
    """Return the object stored on ``request.state.db``.

    NOTE(review): presumably an ``AsyncSession`` attached per request by a
    middleware that is not visible in this module -- confirm.
    """
    return request.state.db


DbSession = Annotated[AsyncSession, Depends(get_db)]


class CustomBase:
    """Declarative base mixin: derived table names, ``dict()`` and a readable repr."""

    # Attribute names whose values are appended to ``repr()`` output.
    __repr_attrs__ = []
    # Values longer than this many characters are truncated in ``repr()``.
    __repr_max_length__ = 15

    @declared_attr
    def __tablename__(cls):
        """Derive the table name from the class name via ``resolve_table_name``."""
        return resolve_table_name(cls.__name__)

    def dict(self):
        """Returns a dict representation of a model."""
        return {c.name: getattr(self, c.name) for c in self.__table__.columns}

    @property
    def _id_str(self):
        """Identity key values joined with ``-``, or ``"None"`` when there is no identity."""
        ids = inspect(self).identity
        if not ids:
            return "None"
        # Joining a single element yields just that element, so one
        # expression covers both simple and composite keys.
        return "-".join(str(x) for x in ids)

    @property
    def _repr_attrs_str(self):
        """Render the attributes named in ``__repr_attrs__`` for use in ``__repr__``.

        Raises:
            KeyError: if ``__repr_attrs__`` names an attribute the instance
                does not have.
        """
        max_length = self.__repr_max_length__
        single = len(self.__repr_attrs__) == 1
        values = []
        for key in self.__repr_attrs__:
            if not hasattr(self, key):
                raise KeyError(
                    f"{self.__class__} has incorrect attribute '{key}' in "
                    "__repr_attrs__"
                )
            raw = getattr(self, key)
            text = str(raw)
            if len(text) > max_length:
                text = text[:max_length] + "..."
            # Only values that were strings to begin with are quoted.
            if isinstance(raw, str):
                text = f"'{text}'"
            # A lone attribute is shown bare; several are shown as key:value.
            values.append(text if single else f"{key}:{text}")
        return " ".join(values)

    def __repr__(self):
        # Evaluate each property once; both run non-trivial code.
        id_part = self._id_str
        attrs_part = self._repr_attrs_str
        # get id like '#123'
        id_str = f"#{id_part}" if id_part else ""
        # join class name, id and repr_attrs
        attrs_str = f" {attrs_part}" if attrs_part else ""
        return f"<{self.__class__.__name__} {id_str}{attrs_str}>"


Base = declarative_base(cls=CustomBase)
# make_searchable(Base.metadata)


@asynccontextmanager
async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """Async context manager to ensure the session is closed after use.

    Commits when the ``async with`` body finishes normally, rolls back and
    re-raises when it raises, and closes the session in either case.
    """
    session = async_session()
    try:
        yield session
        await session.commit()
    except BaseException:
        # Deliberately broad (same reach as a bare ``except``): roll back on
        # any failure, then let the original exception propagate.
        await session.rollback()
        raise
    finally:
        await session.close()


def resolve_table_name(name):
    """Resolves table names to their mapped names."""
    # Split before every uppercase letter, e.g. "UserAccount" -> "user_account".
    names = re.split(r"(?=[A-Z])", name)  # noqa
    return "_".join([x.lower() for x in names if x])


raise_attribute_error = object()


# def resolve_attr(obj, attr, default=None):
#     """Attempts to access attr via dotted notation, returns none if attr does not exist."""
#     try:
#         return functools.reduce(getattr, attr.split("."), obj)
#     except AttributeError:
#         return default


# def get_model_name_by_tablename(table_fullname: str) -> str:
#     """Returns the model name of a given table."""
#     return get_class_by_tablename(table_fullname=table_fullname).__name__


def get_class_by_tablename(table_fullname: str) -> Any:
    """Return class reference mapped to table.

    The name is normalised with ``resolve_table_name`` and compared
    case-insensitively against each mapped class's ``__table__.fullname``.
    Returns ``None`` when no mapped class matches.
    """
    mapped_name = resolve_table_name(table_fullname).lower()
    # ``Base._decl_class_registry`` no longer exists in SQLAlchemy 1.4+;
    # ``registry.mappers`` is the public way to enumerate mapped classes.
    for mapper in Base.registry.mappers:
        mapped_class = mapper.class_
        table = getattr(mapped_class, "__table__", None)
        if table is not None and table.fullname.lower() == mapped_name:
            return mapped_class
    return None


# def get_table_name_by_class_instance(class_instance: Base) -> str:
#     """Returns the name of the table for a given class instance."""
#     return class_instance._sa_instance_state.mapper.mapped_table.name