add aeros connection

main
Cizz22 3 months ago
parent 2d3e32a045
commit d3ec2ff18a

@ -13,7 +13,7 @@ COLLECTOR_HOSTNAME=192.168.1.86
COLLECTOR_PORT=5432
COLLECTOR_CREDENTIAL_USER=postgres
COLLECTOR_CREDENTIAL_PASSWORD=postgres
COLLECTOR_NAME=digital_twin
COLLECTOR_NAME=digital_aeros_fixed
WINDOWS_AEROS_BASE_URL=http://192.168.1.102:8800

@ -3,12 +3,12 @@ from uuid import UUID
import logging
import httpx
from fastapi import HTTPException, status
from sqlalchemy import Delete, Select, func, desc, and_
from sqlalchemy import Delete, Select, func, desc, and_, text
from sqlalchemy.orm import selectinload
from src.auth.service import CurrentUser
from src.config import AEROS_BASE_URL, DEFAULT_PROJECT_NAME
from src.database.core import DbSession
from src.database.core import CollectorDbSession, DbSession
from src.database.service import search_filter_sort_paginate
from .model import AerosEquipment, AerosEquipmentDetail, MasterEquipment, AerosEquipmentGroup, ReliabilityPredictNonRepairable
from .schema import EquipmentConfiguration
@ -313,9 +313,29 @@ def get_distribution(item):
return "Normal", mu, 1000
else:
return name, 0, 0
async def update_equipment_for_simulation(*, db_session: DbSession, project_name: str,overhaul_duration, overhaul_interval, offset ,schematic_name: str, custom_input: Optional[dict] = None):
async def update_oh_interval_offset(*, aeros_db_session, overhaul_offset, overhaul_interval, project_name):
query = text("""
UPDATE public."RegularNodes" rn
SET "OHInterval" = :new_oh_interval,
"OHOffset" = :new_oh_offset
FROM public."Schematics" s
JOIN public."Projects" p
ON s."ProjectId" = p."ProjectId"
WHERE rn."SchematicId" = s."SchematicId"
AND p."ProjectName" = :project_name
""")
aeros_db_session.execute(query, {
"new_oh_interval": overhaul_interval,
"new_oh_offset": overhaul_offset,
"project_name": project_name
})
aeros_db_session.commit()
async def update_equipment_for_simulation(*, db_session: DbSession,aeros_db_session:CollectorDbSession ,project_name: str,overhaul_duration, overhaul_interval, offset ,schematic_name: str, custom_input: Optional[dict] = None):
log.info("Updating equipment for simulation")
aeros_schematic = await get_aeros_schematic_by_name(db_session=db_session, schematic_name=schematic_name)
@ -351,6 +371,9 @@ async def update_equipment_for_simulation(*, db_session: DbSession, project_name
reqNodeInputs = []
results = defaultdict()
print("Updating Overhaul Offset & Overhaul Interval")
await update_oh_interval_offset(aeros_db_session=aeros_db_session, project_name=project_name,overhaul_interval=overhaul_interval, overhaul_offset=offset)
for eq in nodes_data:
try:

@ -10,7 +10,7 @@ from src.aeros_contribution.service import update_contribution_bulk_mappings
from src.aeros_equipment.model import AerosEquipment
from src.aeros_simulation.model import EafContribution
from src.auth.service import CurrentUser
from src.database.core import DbSession
from src.database.core import CollectorDbSession, DbSession
from src.database.service import CommonParameters
from src.models import StandardResponse
from src.aeros_equipment.service import update_equipment_for_simulation

@ -63,10 +63,22 @@ DATABASE_ENGINE_POOL_SIZE = config("DATABASE_ENGINE_POOL_SIZE", cast=int, defaul
DATABASE_ENGINE_MAX_OVERFLOW = config(
"DATABASE_ENGINE_MAX_OVERFLOW", cast=int, default=0
)
COLLECTOR_HOSTNAME = config("COLLECTOR_HOSTNAME")
COLLECTOR_PORT = config("COLLECTOR_PORT", default="5432")
COLLECTOR_CREDENTIAL_USER = config("COLLECTOR_CREDENTIAL_USER")
COLLECTOR_CREDENTIAL_PASSWORD = config("COLLECTOR_CREDENTIAL_PASSWORD")
QUOTED_COLLECTOR_CREDENTIAL_PASSWORD = parse.quote(str(COLLECTOR_CREDENTIAL_PASSWORD))
COLLECTOR_NAME = config("COLLECTOR_NAME")
# Deal with DB disconnects
# https://docs.sqlalchemy.org/en/20/core/pooling.html#pool-disconnects
DATABASE_ENGINE_POOL_PING = config("DATABASE_ENGINE_POOL_PING", default=False)
SQLALCHEMY_DATABASE_URI = f"postgresql+asyncpg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}"
SQLALCHEMY_COLLECTOR_URI = f"postgresql+asyncpg://{COLLECTOR_CREDENTIAL_USER}:{QUOTED_COLLECTOR_CREDENTIAL_PASSWORD}@{COLLECTOR_HOSTNAME}:{COLLECTOR_PORT}/{COLLECTOR_NAME}"
TIMEZONE = "Asia/Jakarta"

@ -7,18 +7,21 @@ from typing import Annotated, Any, AsyncGenerator
from fastapi import Depends
from pydantic import BaseModel
from sqlalchemy import create_engine, inspect
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import DeclarativeBase, Session, object_session, sessionmaker
from sqlalchemy.sql.expression import true
from sqlalchemy_utils import get_mapper
from starlette.requests import Request
from src.config import SQLALCHEMY_DATABASE_URI
from src.config import SQLALCHEMY_COLLECTOR_URI, SQLALCHEMY_DATABASE_URI
engine = create_async_engine(SQLALCHEMY_DATABASE_URI, echo=False, future=True)
collector_engine = create_async_engine(SQLALCHEMY_COLLECTOR_URI, echo=False, future=True)
async_session = sessionmaker(
async_session = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False,
@ -26,77 +29,51 @@ async_session = sessionmaker(
autoflush=False,
)
async_aeros_session = async_sessionmaker(
collector_engine,
class_=AsyncSession,
expire_on_commit=False,
autocommit=False,
autoflush=False,
)
async def get_aeros_db(request: Request):
return request.state.aeros_db
def get_db(request: Request):
return request.state.db
DbSession = Annotated[AsyncSession, Depends(get_db)]
CollectorDbSession = Annotated[AsyncSession, Depends(get_aeros_db)]
class CustomBase:
__repr_attrs__ = []
__repr_max_length__ = 15
class Base(DeclarativeBase):
@declared_attr.directive
def __tablename__(cls) -> str:
return resolve_table_name(cls.__name__)
@declared_attr
def __tablename__(self):
return resolve_table_name(self.__name__)
def dict(self):
"""Returns a dict representation of a model."""
if hasattr(self, '__table__'):
return {c.name: getattr(self, c.name) for c in self.__table__.columns}
return {}
class CollectorBase(DeclarativeBase):
@declared_attr.directive
def __tablename__(cls) -> str:
return resolve_table_name(cls.__name__)
def dict(self):
"""Returns a dict representation of a model."""
return {c.name: getattr(self, c.name) for c in self.__table__.columns}
@property
def _id_str(self):
ids = inspect(self).identity
if ids:
return "-".join([str(x) for x in ids]) if len(ids) > 1 else str(ids[0])
else:
return "None"
@property
def _repr_attrs_str(self):
max_length = self.__repr_max_length__
values = []
single = len(self.__repr_attrs__) == 1
for key in self.__repr_attrs__:
if not hasattr(self, key):
raise KeyError(
"{} has incorrect attribute '{}' in "
"__repr__attrs__".format(self.__class__, key)
)
value = getattr(self, key)
wrap_in_quote = isinstance(value, str)
value = str(value)
if len(value) > max_length:
value = value[:max_length] + "..."
if wrap_in_quote:
value = "'{}'".format(value)
values.append(value if single else "{}:{}".format(key, value))
return " ".join(values)
def __repr__(self):
# get id like '#123'
id_str = ("#" + self._id_str) if self._id_str else ""
# join class name, id and repr_attrs
return "<{} {}{}>".format(
self.__class__.__name__,
id_str,
" " + self._repr_attrs_str if self._repr_attrs_str else "",
)
Base = declarative_base(cls=CustomBase)
# make_searchable(Base.metadata)
@contextmanager
async def get_session():
"""Context manager to ensure the session is closed after use."""
if hasattr(self, '__table__'):
return {c.name: getattr(self, c.name) for c in self.__table__.columns}
return {}
@asynccontextmanager
async def get_main_session():
session = async_session()
try:
yield session
@ -107,6 +84,17 @@ async def get_session():
finally:
await session.close()
@asynccontextmanager
async def get_collector_session():
session = async_aeros_session()
try:
yield session
await session.commit()
except:
await session.rollback()
raise
finally:
await session.close()
def resolve_table_name(name):
"""Resolves table names to their mapped names."""
@ -114,27 +102,11 @@ def resolve_table_name(name):
return "_".join([x.lower() for x in names if x])
raise_attribute_error = object()
# def resolve_attr(obj, attr, default=None):
# """Attempts to access attr via dotted notation, returns none if attr does not exist."""
# try:
# return functools.reduce(getattr, attr.split("."), obj)
# except AttributeError:
# return default
# def get_model_name_by_tablename(table_fullname: str) -> str:
# """Returns the model name of a given table."""
# return get_class_by_tablename(table_fullname=table_fullname).__name__
def get_class_by_tablename(table_fullname: str) -> Any:
"""Return class reference mapped to table."""
def _find_class(name):
for c in Base._decl_class_registry.values():
for c in Base.registry._class_registry.values():
if hasattr(c, "__table__"):
if c.__table__.fullname.lower() == name.lower():
return c
@ -143,8 +115,3 @@ def get_class_by_tablename(table_fullname: str) -> Any:
mapped_class = _find_class(mapped_name)
return mapped_class
# def get_table_name_by_class_instance(class_instance: Base) -> str:
# """Returns the name of the table for a given class instance."""
# return class_instance._sa_instance_state.mapper.mapped_table.name

@ -24,7 +24,7 @@ from starlette.routing import compile_path
from starlette.staticfiles import StaticFiles
from src.api import api_router
from src.database.core import async_session, engine
from src.database.core import async_session, engine,async_aeros_session
from src.enums import ResponseStatus
from src.exceptions import handle_exception
from src.logging import setup_logging
@ -75,11 +75,16 @@ async def db_session_middleware(request: Request, call_next):
try:
session = async_scoped_session(async_session, scopefunc=get_request_id)
request.state.db = session()
collector_session = async_scoped_session(async_aeros_session, scopefunc=get_request_id)
request.state.aeros_db = collector_session()
response = await call_next(request)
except Exception as e:
raise e from None
finally:
await request.state.db.close()
await request.state.aeros_db.close()
_request_id_ctx_var.reset(ctx_token)
return response

@ -6,11 +6,13 @@ from temporalio import activity
async def update_equipment_for_simulation_activity(params: dict):
# ✅ Import inside the activity function
from src.aeros_equipment.service import update_equipment_for_simulation
from src.database.core import async_session
from src.database.core import async_session, async_aeros_session
async with async_session() as db_session:
async with async_session() as db_session, async_aeros_session() as aeros_db_session:
return await update_equipment_for_simulation(
db_session=db_session,
db_session=db_session,
aeros_db_session=aeros_db_session,
project_name=params["projectName"],
overhaul_duration=params["OverhaulDuration"],
overhaul_interval=params["OverhaulInterval"],

Loading…
Cancel
Save