From 01a24a39d2adae569599ef42a8a6fbc1665b9b20 Mon Sep 17 00:00:00 2001 From: Cizz22 Date: Wed, 13 Aug 2025 09:37:14 +0700 Subject: [PATCH] Update core.py --- src/database/core.py | 144 +++++++++++++++---------------------------- 1 file changed, 51 insertions(+), 93 deletions(-) diff --git a/src/database/core.py b/src/database/core.py index 28c0334..ef637dd 100644 --- a/src/database/core.py +++ b/src/database/core.py @@ -1,25 +1,20 @@ # src/database.py -import functools import re -from contextlib import contextmanager -from typing import Annotated, Any, AsyncGenerator +from contextlib import asynccontextmanager +from typing import Annotated, Any from fastapi import Depends -from pydantic import BaseModel -from sqlalchemy import create_engine, inspect -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker from sqlalchemy.ext.declarative import declarative_base, declared_attr -from sqlalchemy.orm import (DeclarativeBase, Session, object_session, - sessionmaker) -from sqlalchemy.sql.expression import true -from sqlalchemy_utils import get_mapper +from sqlalchemy.orm import DeclarativeBase from starlette.requests import Request -from src.config import SQLALCHEMY_DATABASE_URI +from src.config import SQLALCHEMY_DATABASE_URI, SQLALCHEMY_COLLECTOR_URI engine = create_async_engine(SQLALCHEMY_DATABASE_URI, echo=False, future=True) +collector_engine = create_async_engine(SQLALCHEMY_COLLECTOR_URI, echo=False, future=True) -async_session = sessionmaker( +async_session = async_sessionmaker( engine, class_=AsyncSession, expire_on_commit=False, @@ -27,77 +22,50 @@ async_session = sessionmaker( autoflush=False, ) +async_collector_session = async_sessionmaker( + collector_engine, + class_=AsyncSession, + expire_on_commit=False, + autocommit=False, + autoflush=False, +) + +async def get_collector_db(request: Request): + return request.state.collector_db + def get_db(request: Request): return request.state.db DbSession = Annotated[AsyncSession, Depends(get_db)] +CollectorDbSession = Annotated[AsyncSession, Depends(get_collector_db)] -class CustomBase: - __repr_attrs__ = [] - __repr_max_length__ = 15 +class Base(DeclarativeBase): + @declared_attr.directive + def __tablename__(cls) -> str: + return resolve_table_name(cls.__name__) - @declared_attr - def __tablename__(self): - return resolve_table_name(self.__name__) + def dict(self): + """Returns a dict representation of a model.""" + if hasattr(self, '__table__'): + return {c.name: getattr(self, c.name) for c in self.__table__.columns} + return {} + +class CollectorBase(DeclarativeBase): + @declared_attr.directive + def __tablename__(cls) -> str: + return resolve_table_name(cls.__name__) def dict(self): """Returns a dict representation of a model.""" - return {c.name: getattr(self, c.name) for c in self.__table__.columns} - - @property - def _id_str(self): - ids = inspect(self).identity - if ids: - return "-".join([str(x) for x in ids]) if len(ids) > 1 else str(ids[0]) - else: - return "None" - - @property - def _repr_attrs_str(self): - max_length = self.__repr_max_length__ - - values = [] - single = len(self.__repr_attrs__) == 1 - for key in self.__repr_attrs__: - if not hasattr(self, key): - raise KeyError( - "{} has incorrect attribute '{}' in " - "__repr__attrs__".format(self.__class__, key) - ) - value = getattr(self, key) - wrap_in_quote = isinstance(value, str) - - value = str(value) - if len(value) > max_length: - value = value[:max_length] + "..." - - if wrap_in_quote: - value = "'{}'".format(value) - values.append(value if single else "{}:{}".format(key, value)) - - return " ".join(values) - - def __repr__(self): - # get id like '#123' - id_str = ("#" + self._id_str) if self._id_str else "" - # join class name, id and repr_attrs - return "<{} {}{}>".format( - self.__class__.__name__, - id_str, - " " + self._repr_attrs_str if self._repr_attrs_str else "", - ) - - -Base = declarative_base(cls=CustomBase) -# make_searchable(Base.metadata) - - -@contextmanager -async def get_session(): - """Context manager to ensure the session is closed after use.""" + if hasattr(self, '__table__'): + return {c.name: getattr(self, c.name) for c in self.__table__.columns} + return {} + +@asynccontextmanager +async def get_main_session(): session = async_session() try: yield session @@ -108,6 +76,17 @@ async def get_session(): finally: await session.close() +@asynccontextmanager +async def get_collector_session(): + session = async_collector_session() + try: + yield session + await session.commit() + except: + await session.rollback() + raise + finally: + await session.close() def resolve_table_name(name): """Resolves table names to their mapped names.""" @@ -115,27 +94,11 @@ def resolve_table_name(name): return "_".join([x.lower() for x in names if x]) -raise_attribute_error = object() - - -# def resolve_attr(obj, attr, default=None): -# """Attempts to access attr via dotted notation, returns none if attr does not exist.""" -# try: -# return functools.reduce(getattr, attr.split("."), obj) -# except AttributeError: -# return default - - -# def get_model_name_by_tablename(table_fullname: str) -> str: -# """Returns the model name of a given table.""" -# return get_class_by_tablename(table_fullname=table_fullname).__name__ - - def get_class_by_tablename(table_fullname: str) -> Any: """Return class reference mapped to table.""" def _find_class(name): - for c in Base._decl_class_registry.values(): + for c in Base.registry._class_registry.values(): if hasattr(c, "__table__"): if c.__table__.fullname.lower() == name.lower(): return c @@ -144,8 +107,3 @@ def get_class_by_tablename(table_fullname: str) -> Any: mapped_class = _find_class(mapped_name) return mapped_class - - -# def get_table_name_by_class_instance(class_instance: Base) -> str: -# """Returns the name of the table for a given class instance.""" -# return class_instance._sa_instance_state.mapper.mapped_table.name