You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
be-optimumoh/temporal/connectors/database.py

105 lines
2.7 KiB
Python

# src/connectors/database.py
import logging
import os
import re
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Optional
from urllib.parse import quote

import pandas as pd
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine
from starlette.config import Config
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
@dataclass
class DBConfig:
    """Connection settings for a single PostgreSQL database.

    Attributes:
        host: Server host name or address.
        port: Server TCP port.
        database: Name of the database to connect to.
        user: Login role.
        password: Password for ``user``.
    """

    host: str
    port: int
    database: str
    user: str
    password: str

    def get_connection_string(self) -> str:
        """Return a ``postgresql+psycopg2`` URL built from this configuration.

        The user name and password are percent-encoded so that reserved
        characters such as ``@``, ``:``, ``/`` or spaces cannot be mistaken
        for URL delimiters. Credentials without such characters yield the same
        string as plain interpolation would.

        Returns:
            The database URL as a string.
        """
        # safe="" also encodes "/", which quote() would otherwise leave as-is.
        user = quote(self.user, safe="")
        password = quote(self.password, safe="")
        return (
            f"postgresql+psycopg2://{user}:{password}"
            f"@{self.host}:{self.port}/{self.database}"
        )
class DatabaseConnector:
    """Read rows from a source database and write them to a target database.

    One SQLAlchemy engine is created per database in ``__init__``. Call
    ``close()`` (or use the instance as a context manager) to release the
    engines' pooled connections when the work is finished.
    """

    # Table and row filter shared by fetch_batch() and get_total_records(), so
    # the reported total always describes the rows that are actually fetched.
    _SOURCE_TABLE = "dl_wo_staging"
    _ROW_FILTER = "worktype IN ('PM', 'CM', 'EM', 'PROACTIVE')"

    # A plain SQL identifier, optionally qualified (``table.column``).
    _IDENTIFIER_RE = re.compile(
        r"[A-Za-z_][A-Za-z0-9_$]*(?:\.[A-Za-z_][A-Za-z0-9_$]*)*"
    )

    def __init__(self, source_config: "DBConfig", target_config: "DBConfig"):
        """Create the source and target engines.

        Args:
            source_config: Connection settings for the database read from.
            target_config: Connection settings for the database written to.
        """
        self.source_engine = create_engine(
            source_config.get_connection_string()
        )
        self.target_engine = create_engine(
            target_config.get_connection_string()
        )

    def close(self) -> None:
        """Dispose both engines, releasing their pooled connections."""
        self.source_engine.dispose()
        self.target_engine.dispose()

    def __enter__(self) -> "DatabaseConnector":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()

    @classmethod
    def _select_clause(cls, columns: Optional[List[str]]) -> str:
        """Return the column list for a SELECT, or ``*`` when none is given.

        Column names are spliced into the SQL text (they cannot be bound
        parameters), so each one must be a plain, optionally qualified
        identifier.

        Raises:
            ValueError: If any entry is not a valid identifier.
        """
        if not columns:
            return "*"
        for name in columns:
            if not isinstance(name, str) or not cls._IDENTIFIER_RE.fullmatch(name):
                raise ValueError(f"Invalid column name: {name!r}")
        return ", ".join(columns)

    def fetch_batch(
        self,
        batch_size: int,
        last_id: int = 0,
        columns: Optional[List[str]] = None,
    ) -> pd.DataFrame:
        """Fetch the next batch of rows from the source database.

        Rows are read in ascending ``id`` order, starting strictly after
        ``last_id``, which gives keyset pagination across calls.

        Args:
            batch_size: Maximum number of rows to return.
            last_id: Only rows with ``id`` greater than this are returned.
            columns: Column names to select; ``None`` or empty selects all.

        Returns:
            A DataFrame holding the fetched rows.

        Raises:
            ValueError: If ``columns`` contains an invalid identifier.
            Exception: Any database error is logged and re-raised.
        """
        # Validate before touching the database: a bad column name is a
        # caller error, not a fetch failure.
        select_clause = self._select_clause(columns)
        query = (
            f"SELECT {select_clause} "
            f"FROM {self._SOURCE_TABLE} "
            "WHERE id > :last_id "
            f"AND {self._ROW_FILTER} "
            "ORDER BY id "
            "LIMIT :batch_size"
        )
        params = {"last_id": last_id, "batch_size": batch_size}
        try:
            return pd.read_sql(text(query), self.source_engine, params=params)
        except Exception as exc:
            logger.exception("Error fetching batch: %s", exc)
            raise

    def load_batch(
        self,
        df: pd.DataFrame,
        target_table: str,
        if_exists: str = 'append'
    ) -> bool:
        """Write a batch of rows to a table in the target database.

        Args:
            df: Rows to write; the DataFrame index is not written.
            target_table: Name of the destination table.
            if_exists: Passed through to ``DataFrame.to_sql``.

        Returns:
            ``True`` when the write completes.

        Raises:
            Exception: Any error from the write is logged and re-raised.
        """
        try:
            df.to_sql(
                target_table,
                self.target_engine,
                if_exists=if_exists,
                index=False,
                method='multi',
                chunksize=1000,
            )
        except Exception as exc:
            logger.exception("Error loading batch: %s", exc)
            raise
        return True

    def get_total_records(self) -> int:
        """Return the number of source rows that fetch_batch() can return.

        The count uses the same table and row filter as ``fetch_batch`` so
        that progress totals match the data being migrated.

        Raises:
            Exception: Any database error is logged and re-raised.
        """
        count_query = (
            f"SELECT COUNT(*) FROM {self._SOURCE_TABLE} "
            f"WHERE {self._ROW_FILTER}"
        )
        try:
            with self.source_engine.connect() as conn:
                total = conn.execute(text(count_query)).scalar()
        except Exception as exc:
            logger.exception("Error getting total records: %s", exc)
            raise
        return int(total or 0)