You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
105 lines
2.7 KiB
Python
105 lines
2.7 KiB
Python
# src/connectors/database.py
|
|
import logging
import os
from dataclasses import dataclass, field
from typing import Any, Dict, Generator, List, Optional
from urllib.parse import quote

import pandas as pd
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine
from starlette.config import Config
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
@dataclass
class DBConfig:
    """Connection settings for a PostgreSQL database reached through psycopg2."""

    host: str
    port: int
    database: str
    user: str
    # Excluded from repr() so that logging/printing a config never leaks it.
    password: str = field(repr=False)

    def get_connection_string(self) -> str:
        """Return an SQLAlchemy URL string for this database.

        ``user`` and ``password`` are percent-encoded (``safe=""``) so that
        characters with a meaning inside URLs -- ``@``, ``:``, ``/``, ``%``,
        spaces, ... -- cannot corrupt the URL. SQLAlchemy percent-decodes
        these components when it parses the URL.
        """
        user = quote(self.user, safe="")
        password = quote(self.password, safe="")
        return (
            f"postgresql+psycopg2://{user}:{password}"
            f"@{self.host}:{self.port}/{self.database}"
        )
|
|
|
|
|
|
class DatabaseConnector:
    """Copy rows from a source database into a target database.

    Rows are read from the source table ``dl_wo_staging`` in ``id`` order,
    one batch at a time, and written to a caller-chosen table of the target
    database with :meth:`pandas.DataFrame.to_sql`.
    """

    # Single definition of which source rows are migrated, shared by
    # fetch_batch and get_total_records so the reported total always matches
    # the rows actually fetched.
    _SOURCE_TABLE = "dl_wo_staging"
    _SOURCE_FILTER = "worktype IN ('PM', 'CM', 'EM', 'PROACTIVE')"

    def __init__(self, source_config: "DBConfig", target_config: "DBConfig"):
        """Create one SQLAlchemy engine for each database.

        Args:
            source_config: Settings of the database rows are read from.
            target_config: Settings of the database rows are written to.
        """
        self.source_engine = create_engine(source_config.get_connection_string())
        self.target_engine = create_engine(target_config.get_connection_string())

    @classmethod
    def _build_fetch_query(cls, columns: Optional[List[str]] = None) -> str:
        """Return the SELECT statement used by fetch_batch."""
        # NOTE: column names cannot be bound parameters, so they are
        # interpolated verbatim -- only pass trusted, code-defined names.
        select_list = ", ".join(columns) if columns else "*"
        return (
            f"SELECT {select_list} "
            f"FROM {cls._SOURCE_TABLE} "
            "WHERE id > :last_id "
            f"AND {cls._SOURCE_FILTER} "
            "ORDER BY id "
            "LIMIT :batch_size"
        )

    @classmethod
    def _build_count_query(cls) -> str:
        """Return a COUNT(*) over exactly the rows fetch_batch pages through."""
        return f"SELECT COUNT(*) FROM {cls._SOURCE_TABLE} WHERE {cls._SOURCE_FILTER}"

    def fetch_batch(
        self,
        batch_size: int,
        last_id: int = 0,
        columns: Optional[List[str]] = None,
    ) -> pd.DataFrame:
        """Fetch the next batch of rows from the source database.

        Selects rows with ``id > last_id`` ordered by ``id``, at most
        ``batch_size`` of them; pass the largest ``id`` of the previous batch
        as ``last_id`` to continue.

        Args:
            batch_size: Maximum number of rows to return.
            last_id: Exclusive lower bound on ``id``.
            columns: Column names to select; ``None`` or empty selects ``*``.

        Returns:
            The selected rows as a DataFrame.

        Raises:
            Exception: Any error is logged (with traceback) and re-raised.
        """
        try:
            query = self._build_fetch_query(columns)
            params = {"last_id": last_id, "batch_size": batch_size}
            return pd.read_sql(text(query), self.source_engine, params=params)
        except Exception as e:
            logger.exception("Error fetching batch: %s", e)
            raise

    def load_batch(
        self,
        df: pd.DataFrame,
        target_table: str,
        if_exists: str = "append",
        chunksize: int = 1000,
    ) -> bool:
        """Write a batch of rows to ``target_table`` in the target database.

        Args:
            df: Rows to write; the DataFrame index is not written.
            target_table: Name of the destination table.
            if_exists: Passed to ``DataFrame.to_sql`` (default ``'append'``).
            chunksize: Rows per multi-row INSERT statement.

        Returns:
            ``True`` once the write completed.

        Raises:
            Exception: Any error is logged (with traceback) and re-raised.
        """
        try:
            df.to_sql(
                target_table,
                self.target_engine,
                if_exists=if_exists,
                index=False,
                method="multi",
                chunksize=chunksize,
            )
        except Exception as e:
            logger.exception("Error loading batch: %s", e)
            raise
        return True

    def get_total_records(self) -> int:
        """Return how many source rows fetch_batch will page through in total.

        Raises:
            Exception: Any error is logged (with traceback) and re-raised.
        """
        try:
            with self.source_engine.connect() as conn:
                result = conn.execute(text(self._build_count_query()))
                return result.scalar()
        except Exception as e:
            logger.exception("Error getting total records: %s", e)
            raise

    def close(self) -> None:
        """Dispose both engines, releasing their pooled connections."""
        self.source_engine.dispose()
        self.target_engine.dispose()

    def __enter__(self) -> "DatabaseConnector":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()
|
|
|
|
|