You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

152 lines
4.3 KiB
Python

# src/database.py
import functools
import re
from contextlib import asynccontextmanager, contextmanager
from typing import Annotated, Any, AsyncGenerator

from fastapi import Depends
from pydantic import BaseModel
from sqlalchemy import create_engine, inspect
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import (
    DeclarativeBase,
    Session,
    object_session,
    sessionmaker,
)
from sqlalchemy.sql.expression import true
from sqlalchemy_utils import get_mapper
from starlette.requests import Request

from src.config import SQLALCHEMY_DATABASE_URI
# Async engine for the configured database URI; SQL statement echo is off.
engine = create_async_engine(
    SQLALCHEMY_DATABASE_URI,
    echo=False,
    future=True,
)

# Factory producing AsyncSession objects bound to ``engine``.
# - expire_on_commit=False: loaded attributes stay readable after commit.
# - autoflush=False: pending changes are not flushed automatically before
#   queries; callers flush/commit explicitly.
async_session = sessionmaker(
    bind=engine,
    class_=AsyncSession,
    autoflush=False,
    autocommit=False,
    expire_on_commit=False,
)
def get_db(request: Request):
    """Return the database session attached to the incoming request.

    Reads ``request.state.db``. Code outside this function (presumably a
    middleware — not visible here) must have set that attribute beforehand;
    otherwise the attribute access fails.
    """
    state = request.state
    return state.db
# FastAPI dependency alias: a parameter annotated ``DbSession`` receives the
# value returned by ``get_db`` for the current request.
DbSession = Annotated[AsyncSession, Depends(dependency=get_db)]
class CustomBase:
    """Mixin for the declarative base shared by all models.

    Provides:
    - ``__tablename__`` derived from the class name via ``resolve_table_name``
      (e.g. ``UserRole`` -> ``user_role``);
    - ``dict()``, a column-name -> value mapping of the instance;
    - a compact ``__repr__`` such as ``<User #1 name:'alice'>``.

    Subclasses customise ``repr()`` through ``__repr_attrs__`` (attribute
    names to show) and ``__repr_max_length__`` (truncation length per value).
    """

    # Attribute names rendered by __repr__; subclasses override this.
    __repr_attrs__ = []
    # Rendered values longer than this are cut and suffixed with "...".
    __repr_max_length__ = 15

    @declared_attr
    def __tablename__(self):
        # declared_attr calls this with the class itself, so ``self.__name__``
        # is the model's class name.
        return resolve_table_name(self.__name__)

    def dict(self):
        """Returns a dict representation of a model."""
        # NOTE(review): keys are table column names and values are read with
        # getattr under that same name, so this assumes every column is mapped
        # to an attribute of the same name.
        return {c.name: getattr(self, c.name) for c in self.__table__.columns}

    @property
    def _id_str(self):
        """Primary-key identity as text.

        Returns ``"1"`` for a single key, ``"1-2"`` for a composite key, and
        ``"None"`` when the instance has no identity (e.g. not yet flushed).
        """
        ids = inspect(self).identity
        if ids:
            return "-".join([str(x) for x in ids]) if len(ids) > 1 else str(ids[0])
        else:
            return "None"

    @property
    def _repr_attrs_str(self):
        """Render the attributes listed in ``__repr_attrs__``.

        Each value is stringified and truncated to ``__repr_max_length__``
        characters (plus "..."); string values are wrapped in single quotes.
        With exactly one attribute only its value is shown, otherwise each
        entry is ``key:value``; entries are space-separated.

        Raises:
            KeyError: if a name in ``__repr_attrs__`` is not an attribute of
                the instance.
        """
        max_length = self.__repr_max_length__
        values = []
        single = len(self.__repr_attrs__) == 1
        for key in self.__repr_attrs__:
            if not hasattr(self, key):
                raise KeyError(
                    "{} has incorrect attribute '{}' in "
                    "__repr_attrs__".format(self.__class__, key)
                )
            value = getattr(self, key)
            # Decide on quoting before str() so only real strings are quoted.
            wrap_in_quote = isinstance(value, str)
            value = str(value)
            if len(value) > max_length:
                value = value[:max_length] + "..."
            if wrap_in_quote:
                value = "'{}'".format(value)
            values.append(value if single else "{}:{}".format(key, value))
        return " ".join(values)

    def __repr__(self):
        # Evaluate each property exactly once: _repr_attrs_str reads instance
        # attributes (which may trigger loads), so re-evaluating it for the
        # emptiness check would repeat that work.
        id_str = self._id_str
        attrs_str = self._repr_attrs_str
        # get id like '#123'
        id_part = ("#" + id_str) if id_str else ""
        # join class name, id and repr_attrs
        return "<{} {}{}>".format(
            self.__class__.__name__,
            id_part,
            (" " + attrs_str) if attrs_str else "",
        )
# Declarative base class for every model; each mapped class inherits the
# table-naming, dict() and __repr__ helpers from CustomBase.
Base = declarative_base(name="Base", cls=CustomBase)
# make_searchable(Base.metadata)
@asynccontextmanager
async def get_session() -> "AsyncGenerator[AsyncSession, None]":
    """Async context manager that yields a session from ``async_session``.

    Usage::

        async with get_session() as session:
            ...

    When the ``async with`` body exits cleanly the session is committed. If
    the body (or the commit itself) raises, the session is rolled back and
    the exception is re-raised. The session is closed in every case.

    The original used ``@contextmanager`` on an ``async def`` generator,
    which cannot drive an async generator. Neither ``with`` nor
    ``async with`` worked with it.
    """
    session = async_session()
    try:
        yield session
        await session.commit()
    except BaseException:
        # Explicit form of the original bare ``except:``. BaseException also
        # covers asyncio.CancelledError, so a cancelled task still rolls back.
        await session.rollback()
        raise
    finally:
        await session.close()
def resolve_table_name(name):
    """Convert a CamelCase class name into a snake_case table name.

    Every uppercase ASCII letter starts a new word, so runs of capitals are
    split letter by letter: ``"UserRole"`` -> ``"user_role"``,
    ``"HTTPLog"`` -> ``"h_t_t_p_log"``.
    """
    # Zero-width split before each capital; the leading empty piece produced
    # when the name starts with a capital is dropped.
    pieces = re.split("(?=[A-Z])", name)  # noqa
    words = [piece.lower() for piece in pieces if piece != ""]
    return "_".join(words)
# Module-level sentinel object (compare against it with ``is``). No live code
# visible in this file references it; it presumably belonged with the
# commented-out ``resolve_attr`` helper that follows.
# TODO(review): confirm whether anything else imports it before removing.
raise_attribute_error = object()
# def resolve_attr(obj, attr, default=None):
# """Attempts to access attr via dotted notation, returns none if attr does not exist."""
# try:
# return functools.reduce(getattr, attr.split("."), obj)
# except AttributeError:
# return default
# def get_model_name_by_tablename(table_fullname: str) -> str:
# """Returns the model name of a given table."""
# return get_class_by_tablename(table_fullname=table_fullname).__name__
def get_class_by_tablename(table_fullname: str) -> Any:
    """Return the mapped class whose table matches ``table_fullname``.

    ``table_fullname`` is first normalised with ``resolve_table_name``
    (CamelCase -> snake_case). It is then compared case-insensitively against
    each mapped class's ``__table__.fullname``, which is ``schema.name`` when
    the table has a schema.

    Returns:
        The matching model class, or ``None`` if no mapped class matches.
    """
    target = resolve_table_name(table_fullname).lower()
    # ``Base._decl_class_registry`` no longer exists in SQLAlchemy 1.4+/2.0.
    # The registry's public ``mappers`` collection lists every mapped class.
    for mapper in Base.registry.mappers:
        # Single-table-inheritance subclasses reuse their parent's table.
        # Skip them so the class that owns the table is returned, regardless
        # of the (unordered) iteration order of ``mappers``.
        if mapper.single:
            continue
        table = getattr(mapper.class_, "__table__", None)
        if table is not None and table.fullname.lower() == target:
            return mapper.class_
    return None
# def get_table_name_by_class_instance(class_instance: Base) -> str:
# """Returns the name of the table for a given class instance."""
# return class_instance._sa_instance_state.mapper.mapped_table.name