Update core.py

feature/reliability_stat
Cizz22 5 months ago
parent 536c8fc889
commit 01a24a39d2

@ -1,25 +1,20 @@
# src/database.py # src/database.py
import functools
import re import re
from contextlib import contextmanager from contextlib import asynccontextmanager
from typing import Annotated, Any, AsyncGenerator from typing import Annotated, Any
from fastapi import Depends from fastapi import Depends
from pydantic import BaseModel from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy import create_engine, inspect
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.ext.declarative import declarative_base, declared_attr from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import (DeclarativeBase, Session, object_session, from sqlalchemy.orm import DeclarativeBase
sessionmaker)
from sqlalchemy.sql.expression import true
from sqlalchemy_utils import get_mapper
from starlette.requests import Request from starlette.requests import Request
from src.config import SQLALCHEMY_DATABASE_URI from src.config import SQLALCHEMY_DATABASE_URI, SQLALCHEMY_COLLECTOR_URI
engine = create_async_engine(SQLALCHEMY_DATABASE_URI, echo=False, future=True) engine = create_async_engine(SQLALCHEMY_DATABASE_URI, echo=False, future=True)
collector_engine = create_async_engine(SQLALCHEMY_COLLECTOR_URI, echo=False, future=True)
async_session = sessionmaker( async_session = async_sessionmaker(
engine, engine,
class_=AsyncSession, class_=AsyncSession,
expire_on_commit=False, expire_on_commit=False,
@ -27,77 +22,50 @@ async_session = sessionmaker(
autoflush=False, autoflush=False,
) )
async_collector_session = async_sessionmaker(
collector_engine,
class_=AsyncSession,
expire_on_commit=False,
autocommit=False,
autoflush=False,
)
async def get_collector_db(request: Request):
return request.state.collector_db
def get_db(request: Request): def get_db(request: Request):
return request.state.db return request.state.db
DbSession = Annotated[AsyncSession, Depends(get_db)] DbSession = Annotated[AsyncSession, Depends(get_db)]
CollectorDbSession = Annotated[AsyncSession, Depends(get_collector_db)]
class CustomBase: class Base(DeclarativeBase):
__repr_attrs__ = [] @declared_attr.directive
__repr_max_length__ = 15 def __tablename__(cls) -> str:
return resolve_table_name(cls.__name__)
@declared_attr
def __tablename__(self):
return resolve_table_name(self.__name__)
def dict(self): def dict(self):
"""Returns a dict representation of a model.""" """Returns a dict representation of a model."""
if hasattr(self, '__table__'):
return {c.name: getattr(self, c.name) for c in self.__table__.columns} return {c.name: getattr(self, c.name) for c in self.__table__.columns}
return {}
@property class CollectorBase(DeclarativeBase):
def _id_str(self): @declared_attr.directive
ids = inspect(self).identity def __tablename__(cls) -> str:
if ids: return resolve_table_name(cls.__name__)
return "-".join([str(x) for x in ids]) if len(ids) > 1 else str(ids[0])
else:
return "None"
@property
def _repr_attrs_str(self):
max_length = self.__repr_max_length__
values = []
single = len(self.__repr_attrs__) == 1
for key in self.__repr_attrs__:
if not hasattr(self, key):
raise KeyError(
"{} has incorrect attribute '{}' in "
"__repr__attrs__".format(self.__class__, key)
)
value = getattr(self, key)
wrap_in_quote = isinstance(value, str)
value = str(value)
if len(value) > max_length:
value = value[:max_length] + "..."
if wrap_in_quote:
value = "'{}'".format(value)
values.append(value if single else "{}:{}".format(key, value))
return " ".join(values)
def __repr__(self):
# get id like '#123'
id_str = ("#" + self._id_str) if self._id_str else ""
# join class name, id and repr_attrs
return "<{} {}{}>".format(
self.__class__.__name__,
id_str,
" " + self._repr_attrs_str if self._repr_attrs_str else "",
)
Base = declarative_base(cls=CustomBase)
# make_searchable(Base.metadata)
def dict(self):
"""Returns a dict representation of a model."""
if hasattr(self, '__table__'):
return {c.name: getattr(self, c.name) for c in self.__table__.columns}
return {}
@contextmanager @asynccontextmanager
async def get_session(): async def get_main_session():
"""Context manager to ensure the session is closed after use."""
session = async_session() session = async_session()
try: try:
yield session yield session
@ -108,6 +76,17 @@ async def get_session():
finally: finally:
await session.close() await session.close()
@asynccontextmanager
async def get_collector_session():
session = async_collector_session()
try:
yield session
await session.commit()
except:
await session.rollback()
raise
finally:
await session.close()
def resolve_table_name(name): def resolve_table_name(name):
"""Resolves table names to their mapped names.""" """Resolves table names to their mapped names."""
@ -115,27 +94,11 @@ def resolve_table_name(name):
return "_".join([x.lower() for x in names if x]) return "_".join([x.lower() for x in names if x])
raise_attribute_error = object()
# def resolve_attr(obj, attr, default=None):
# """Attempts to access attr via dotted notation, returns none if attr does not exist."""
# try:
# return functools.reduce(getattr, attr.split("."), obj)
# except AttributeError:
# return default
# def get_model_name_by_tablename(table_fullname: str) -> str:
# """Returns the model name of a given table."""
# return get_class_by_tablename(table_fullname=table_fullname).__name__
def get_class_by_tablename(table_fullname: str) -> Any: def get_class_by_tablename(table_fullname: str) -> Any:
"""Return class reference mapped to table.""" """Return class reference mapped to table."""
def _find_class(name): def _find_class(name):
for c in Base._decl_class_registry.values(): for c in Base.registry._class_registry.values():
if hasattr(c, "__table__"): if hasattr(c, "__table__"):
if c.__table__.fullname.lower() == name.lower(): if c.__table__.fullname.lower() == name.lower():
return c return c
@ -144,8 +107,3 @@ def get_class_by_tablename(table_fullname: str) -> Any:
mapped_class = _find_class(mapped_name) mapped_class = _find_class(mapped_name)
return mapped_class return mapped_class
# def get_table_name_by_class_instance(class_instance: Base) -> str:
# """Returns the name of the table for a given class instance."""
# return class_instance._sa_instance_state.mapper.mapped_table.name

Loading…
Cancel
Save