Update core.py

feature/reliability_stat
Cizz22 5 months ago
parent 536c8fc889
commit 01a24a39d2

@ -1,25 +1,20 @@
# src/database.py
import functools
import re
from contextlib import contextmanager
from typing import Annotated, Any, AsyncGenerator
from contextlib import asynccontextmanager
from typing import Annotated, Any
from fastapi import Depends
from pydantic import BaseModel
from sqlalchemy import create_engine, inspect
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import (DeclarativeBase, Session, object_session,
sessionmaker)
from sqlalchemy.sql.expression import true
from sqlalchemy_utils import get_mapper
from sqlalchemy.orm import DeclarativeBase
from starlette.requests import Request
from src.config import SQLALCHEMY_DATABASE_URI
from src.config import SQLALCHEMY_DATABASE_URI, SQLALCHEMY_COLLECTOR_URI
engine = create_async_engine(SQLALCHEMY_DATABASE_URI, echo=False, future=True)
collector_engine = create_async_engine(SQLALCHEMY_COLLECTOR_URI, echo=False, future=True)
async_session = sessionmaker(
async_session = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False,
@ -27,77 +22,50 @@ async_session = sessionmaker(
autoflush=False,
)
async_collector_session = async_sessionmaker(
collector_engine,
class_=AsyncSession,
expire_on_commit=False,
autocommit=False,
autoflush=False,
)
async def get_collector_db(request: Request):
return request.state.collector_db
def get_db(request: Request):
return request.state.db
DbSession = Annotated[AsyncSession, Depends(get_db)]
CollectorDbSession = Annotated[AsyncSession, Depends(get_collector_db)]
class CustomBase:
__repr_attrs__ = []
__repr_max_length__ = 15
class Base(DeclarativeBase):
@declared_attr.directive
def __tablename__(cls) -> str:
return resolve_table_name(cls.__name__)
@declared_attr
def __tablename__(self):
return resolve_table_name(self.__name__)
def dict(self):
"""Returns a dict representation of a model."""
if hasattr(self, '__table__'):
return {c.name: getattr(self, c.name) for c in self.__table__.columns}
return {}
class CollectorBase(DeclarativeBase):
@declared_attr.directive
def __tablename__(cls) -> str:
return resolve_table_name(cls.__name__)
def dict(self):
"""Returns a dict representation of a model."""
return {c.name: getattr(self, c.name) for c in self.__table__.columns}
@property
def _id_str(self):
ids = inspect(self).identity
if ids:
return "-".join([str(x) for x in ids]) if len(ids) > 1 else str(ids[0])
else:
return "None"
@property
def _repr_attrs_str(self):
max_length = self.__repr_max_length__
values = []
single = len(self.__repr_attrs__) == 1
for key in self.__repr_attrs__:
if not hasattr(self, key):
raise KeyError(
"{} has incorrect attribute '{}' in "
"__repr__attrs__".format(self.__class__, key)
)
value = getattr(self, key)
wrap_in_quote = isinstance(value, str)
value = str(value)
if len(value) > max_length:
value = value[:max_length] + "..."
if wrap_in_quote:
value = "'{}'".format(value)
values.append(value if single else "{}:{}".format(key, value))
return " ".join(values)
def __repr__(self):
# get id like '#123'
id_str = ("#" + self._id_str) if self._id_str else ""
# join class name, id and repr_attrs
return "<{} {}{}>".format(
self.__class__.__name__,
id_str,
" " + self._repr_attrs_str if self._repr_attrs_str else "",
)
Base = declarative_base(cls=CustomBase)
# make_searchable(Base.metadata)
@contextmanager
async def get_session():
"""Context manager to ensure the session is closed after use."""
if hasattr(self, '__table__'):
return {c.name: getattr(self, c.name) for c in self.__table__.columns}
return {}
@asynccontextmanager
async def get_main_session():
session = async_session()
try:
yield session
@ -108,6 +76,17 @@ async def get_session():
finally:
await session.close()
@asynccontextmanager
async def get_collector_session():
session = async_collector_session()
try:
yield session
await session.commit()
except:
await session.rollback()
raise
finally:
await session.close()
def resolve_table_name(name):
"""Resolves table names to their mapped names."""
@ -115,27 +94,11 @@ def resolve_table_name(name):
return "_".join([x.lower() for x in names if x])
raise_attribute_error = object()
# def resolve_attr(obj, attr, default=None):
# """Attempts to access attr via dotted notation, returns none if attr does not exist."""
# try:
# return functools.reduce(getattr, attr.split("."), obj)
# except AttributeError:
# return default
# def get_model_name_by_tablename(table_fullname: str) -> str:
# """Returns the model name of a given table."""
# return get_class_by_tablename(table_fullname=table_fullname).__name__
def get_class_by_tablename(table_fullname: str) -> Any:
"""Return class reference mapped to table."""
def _find_class(name):
for c in Base._decl_class_registry.values():
for c in Base.registry._class_registry.values():
if hasattr(c, "__table__"):
if c.__table__.fullname.lower() == name.lower():
return c
@ -144,8 +107,3 @@ def get_class_by_tablename(table_fullname: str) -> Any:
mapped_class = _find_class(mapped_name)
return mapped_class
# def get_table_name_by_class_instance(class_instance: Base) -> str:
# """Returns the name of the table for a given class instance."""
# return class_instance._sa_instance_state.mapper.mapped_table.name

Loading…
Cancel
Save