initialCommit
commit
a8161e7edb
@ -0,0 +1,16 @@
|
|||||||
|
ENV=development
|
||||||
|
LOG_LEVEL=ERROR
|
||||||
|
PORT=3010
|
||||||
|
HOST=0.0.0.0
|
||||||
|
|
||||||
|
DATABASE_HOSTNAME=192.168.1.82
|
||||||
|
DATABASE_PORT=5432
|
||||||
|
DATABASE_CREDENTIAL_USER=postgres
|
||||||
|
DATABASE_CREDENTIAL_PASSWORD=postgres
|
||||||
|
DATABASE_NAME=digital_twin
|
||||||
|
|
||||||
|
COLLECTOR_HOSTNAME=192.168.1.86
|
||||||
|
COLLECTOR_PORT=5432
|
||||||
|
COLLECTOR_CREDENTIAL_USER=postgres
|
||||||
|
COLLECTOR_CREDENTIAL_PASSWORD=postgres
|
||||||
|
COLLECTOR_NAME=digital_twin
|
||||||
@ -0,0 +1,2 @@
|
|||||||
|
env/
|
||||||
|
__pycache__/
|
||||||
@ -0,0 +1,44 @@
|
|||||||
|
# Quick Start:
|
||||||
|
#
|
||||||
|
# pip install pre-commit
|
||||||
|
# pre-commit install && pre-commit install -t pre-push
|
||||||
|
# pre-commit run --all-files
|
||||||
|
#
|
||||||
|
# To Skip Checks:
|
||||||
|
#
|
||||||
|
# git commit --no-verify
|
||||||
|
fail_fast: false
|
||||||
|
|
||||||
|
default_language_version:
|
||||||
|
python: python3.11.2
|
||||||
|
|
||||||
|
repos:
|
||||||
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
|
# ruff version.
|
||||||
|
rev: v0.7.0
|
||||||
|
hooks:
|
||||||
|
# Run the linter.
|
||||||
|
#
|
||||||
|
# When running with --fix, Ruff's lint hook should be placed before Ruff's formatter hook,
|
||||||
|
# and before Black, isort, and other formatting tools, as Ruff's fix behavior can output code changes that require reformatting.
|
||||||
|
- id: ruff
|
||||||
|
args: [--fix]
|
||||||
|
# Run the formatter.
|
||||||
|
- id: ruff-format
|
||||||
|
|
||||||
|
# Typos
|
||||||
|
- repo: https://github.com/crate-ci/typos
|
||||||
|
rev: v1.26.1
|
||||||
|
hooks:
|
||||||
|
- id: typos
|
||||||
|
exclude: ^(data/dispatch-sample-data.dump|src/dispatch/static/dispatch/src/|src/dispatch/database/revisions/)
|
||||||
|
|
||||||
|
# Pytest
|
||||||
|
- repo: local
|
||||||
|
hooks:
|
||||||
|
- id: tests
|
||||||
|
name: run tests
|
||||||
|
entry: pytest -v tests/
|
||||||
|
language: system
|
||||||
|
types: [python]
|
||||||
|
stages: [push]
|
||||||
@ -0,0 +1,50 @@
|
|||||||
|
# Use the official Python 3.11 image from the Docker Hub
|
||||||
|
FROM python:3.11-slim as builder
|
||||||
|
|
||||||
|
# Install Poetry
|
||||||
|
RUN pip install poetry
|
||||||
|
|
||||||
|
# Set environment variables for Poetry
|
||||||
|
ENV POETRY_NO_INTERACTION=1 \
|
||||||
|
POETRY_VIRTUALENVS_IN_PROJECT=1 \
|
||||||
|
POETRY_VIRTUALENVS_CREATE=1 \
|
||||||
|
POETRY_CACHE_DIR=/tmp/poetry_cache
|
||||||
|
|
||||||
|
# Set the working directory
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Copy the Poetry configuration files
|
||||||
|
COPY pyproject.toml poetry.lock ./
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
RUN poetry install --no-dev --no-root
|
||||||
|
|
||||||
|
# Use a new slim image for the runtime
|
||||||
|
FROM python:3.11-slim as runtime
|
||||||
|
|
||||||
|
# Install necessary tools for running the app, including `make`
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
make \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Set environment variables for Poetry
|
||||||
|
ENV POETRY_VIRTUALENVS_IN_PROJECT=1 \
|
||||||
|
PATH="/app/.venv/bin:$PATH"
|
||||||
|
|
||||||
|
# Copy Poetry installation from builder
|
||||||
|
COPY --from=builder /app/.venv /app/.venv
|
||||||
|
|
||||||
|
# Copy application files
|
||||||
|
COPY . /app/
|
||||||
|
|
||||||
|
# Delete Tests for production
|
||||||
|
RUN rm -rf /app/tests/
|
||||||
|
|
||||||
|
# Expose port for the application
|
||||||
|
EXPOSE 3005
|
||||||
|
|
||||||
|
# Set the working directory
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Run `make run` as the entry point
|
||||||
|
CMD ["make", "run"]
|
||||||
@ -0,0 +1,14 @@
|
|||||||
|
# Define variables
|
||||||
|
POETRY = poetry
|
||||||
|
PYTHON = $(POETRY) run python
|
||||||
|
APP = src/server.py
|
||||||
|
|
||||||
|
# Targets and their rules
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
install:
|
||||||
|
$(POETRY) install
|
||||||
|
|
||||||
|
# Run the application
|
||||||
|
run:
|
||||||
|
python run.py
|
||||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,32 @@
|
|||||||
|
[tool.poetry]
|
||||||
|
name = "optimumohservice"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = ""
|
||||||
|
authors = ["Cizz22 <cisatraa@gmail.com>"]
|
||||||
|
license = "MIT"
|
||||||
|
readme = "README.md"
|
||||||
|
|
||||||
|
[tool.poetry.dependencies]
|
||||||
|
python = "^3.11"
|
||||||
|
fastapi = {extras = ["standard"], version = "^0.115.4"}
|
||||||
|
sqlalchemy = "^2.0.36"
|
||||||
|
httpx = "^0.27.2"
|
||||||
|
pytest = "^8.3.3"
|
||||||
|
faker = "^30.8.2"
|
||||||
|
factory-boy = "^3.3.1"
|
||||||
|
sqlalchemy-utils = "^0.41.2"
|
||||||
|
slowapi = "^0.1.9"
|
||||||
|
uvicorn = "^0.32.0"
|
||||||
|
pytz = "^2024.2"
|
||||||
|
sqlalchemy-filters = "^0.13.0"
|
||||||
|
asyncpg = "^0.30.0"
|
||||||
|
requests = "^2.32.3"
|
||||||
|
pydantic = "^2.10.2"
|
||||||
|
temporalio = "^1.8.0"
|
||||||
|
pandas = "^2.2.3"
|
||||||
|
psycopg2-binary = "^2.9.10"
|
||||||
|
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["poetry-core"]
|
||||||
|
build-backend = "poetry.core.masonry.api"
|
||||||
@ -0,0 +1,10 @@
|
|||||||
|
import uvicorn
|
||||||
|
from src.config import PORT, HOST
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
uvicorn.run(
|
||||||
|
"src.main:app",
|
||||||
|
host=HOST,
|
||||||
|
port=PORT,
|
||||||
|
reload=True
|
||||||
|
)
|
||||||
@ -0,0 +1,71 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
from src.auth.service import JWTBearer
|
||||||
|
|
||||||
|
|
||||||
|
from src.scope.router import router as scope_router
|
||||||
|
from src.scope_equipment.router import router as scope_equipment_router
|
||||||
|
from src.overhaul.router import router as overhaul_router
|
||||||
|
from src.calculation_time_constrains.router import router as calculation_time_constrains_router
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorMessage(BaseModel):
|
||||||
|
msg: str
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorResponse(BaseModel):
|
||||||
|
detail: Optional[List[ErrorMessage]]
|
||||||
|
|
||||||
|
|
||||||
|
api_router = APIRouter(
|
||||||
|
default_response_class=JSONResponse,
|
||||||
|
responses={
|
||||||
|
400: {"model": ErrorResponse},
|
||||||
|
401: {"model": ErrorResponse},
|
||||||
|
403: {"model": ErrorResponse},
|
||||||
|
404: {"model": ErrorResponse},
|
||||||
|
500: {"model": ErrorResponse},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@api_router.get("/healthcheck", include_in_schema=False)
|
||||||
|
def healthcheck():
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
authenticated_api_router = APIRouter(dependencies=[Depends(JWTBearer())],
|
||||||
|
)
|
||||||
|
# overhaul data
|
||||||
|
authenticated_api_router.include_router(
|
||||||
|
overhaul_router, prefix="/overhauls", tags=["overhaul"])
|
||||||
|
|
||||||
|
|
||||||
|
# Scope data
|
||||||
|
authenticated_api_router.include_router(
|
||||||
|
scope_router, prefix="/scopes", tags=["scope"])
|
||||||
|
|
||||||
|
authenticated_api_router.include_router(
|
||||||
|
scope_equipment_router, prefix="/scope-equipments", tags=["scope_equipment"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# calculation
|
||||||
|
calculation_router = APIRouter(prefix="/calculation", tags=["calculations"])
|
||||||
|
|
||||||
|
# Time constrains
|
||||||
|
calculation_router.include_router(
|
||||||
|
calculation_time_constrains_router, prefix="/time-constraint", tags=["time_constraint"])
|
||||||
|
|
||||||
|
# Target reliability
|
||||||
|
|
||||||
|
# Budget Constrain
|
||||||
|
|
||||||
|
authenticated_api_router.include_router(
|
||||||
|
calculation_router
|
||||||
|
)
|
||||||
|
|
||||||
|
api_router.include_router(authenticated_api_router)
|
||||||
@ -0,0 +1,9 @@
|
|||||||
|
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class UserBase(BaseModel):
|
||||||
|
name: str
|
||||||
|
role: str
|
||||||
|
user_id: str
|
||||||
@ -0,0 +1,55 @@
|
|||||||
|
# app/auth/auth_bearer.py
|
||||||
|
|
||||||
|
from typing import Annotated, Optional
|
||||||
|
from fastapi import Depends, Request, HTTPException
|
||||||
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||||
|
import requests
|
||||||
|
import src.config as config
|
||||||
|
from .model import UserBase
|
||||||
|
|
||||||
|
|
||||||
|
class JWTBearer(HTTPBearer):
|
||||||
|
def __init__(self, auto_error: bool = True):
|
||||||
|
super(JWTBearer, self).__init__(auto_error=auto_error)
|
||||||
|
|
||||||
|
async def __call__(self, request: Request):
|
||||||
|
credentials: HTTPAuthorizationCredentials = await super(JWTBearer, self).__call__(request)
|
||||||
|
if credentials:
|
||||||
|
if not credentials.scheme == "Bearer":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403, detail="Invalid authentication scheme.")
|
||||||
|
user_info = self.verify_jwt(credentials.credentials)
|
||||||
|
if not user_info:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403, detail="Invalid token or expired token.")
|
||||||
|
|
||||||
|
request.state.user = user_info
|
||||||
|
return user_info
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403, detail="Invalid authorization code.")
|
||||||
|
|
||||||
|
def verify_jwt(self, jwtoken: str) -> Optional[UserBase]:
|
||||||
|
try:
|
||||||
|
response = requests.get(
|
||||||
|
f"{config.AUTH_SERVICE_API}/verify-token",
|
||||||
|
headers={"Authorization": f"Bearer {jwtoken}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
if not response.ok:
|
||||||
|
return None
|
||||||
|
|
||||||
|
user_data = response.json()
|
||||||
|
return UserBase(**user_data['data'])
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Token verification error: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# Create dependency to get current user from request state
|
||||||
|
async def get_current_user(request: Request) -> UserBase:
|
||||||
|
return request.state.user
|
||||||
|
|
||||||
|
|
||||||
|
CurrentUser = Annotated[UserBase, Depends(get_current_user)]
|
||||||
@ -0,0 +1,96 @@
|
|||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
from fastapi import APIRouter, HTTPException, status
|
||||||
|
from .schema import CalculationTimeConstrainsParametersRead, CalculationTimeConstrainsRead, CalculationTimeConstrainsCreate
|
||||||
|
from src.database.core import DbSession
|
||||||
|
from src.models import StandardResponse
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/parameters", response_model=StandardResponse[CalculationTimeConstrainsParametersRead])
|
||||||
|
async def get_calculation_parameters():
|
||||||
|
"""Get all calculation parameter pagination."""
|
||||||
|
|
||||||
|
# {
|
||||||
|
# "costPerFailure": 733.614,
|
||||||
|
# "availableScopes": ["A", "B"],
|
||||||
|
# "recommendedScope": "B",
|
||||||
|
# "historicalData": {
|
||||||
|
# "averageOverhaulCost": 10000000,
|
||||||
|
# "lastCalculation": {
|
||||||
|
# "id": "calc_122",
|
||||||
|
# "date": "2024-10-15",
|
||||||
|
# "scope": "B"
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
|
||||||
|
return StandardResponse(
|
||||||
|
data=CalculationTimeConstrainsParametersRead(
|
||||||
|
costPerFailure=733.614,
|
||||||
|
availableScopes=["A", "B"],
|
||||||
|
recommendedScope="B",
|
||||||
|
historicalData={
|
||||||
|
"averageOverhaulCost": 10000000,
|
||||||
|
"lastCalculation": {
|
||||||
|
"id": "calc_122",
|
||||||
|
"date": "2024-10-15",
|
||||||
|
"scope": "B",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
message="Data retrieved successfully",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=StandardResponse[CalculationTimeConstrainsRead])
|
||||||
|
async def create_calculation_time_constrains(db_session: DbSession, calculation_time_constrains_in: CalculationTimeConstrainsCreate):
|
||||||
|
"""Calculate Here"""
|
||||||
|
calculation_result = {
|
||||||
|
"id": "calc_123",
|
||||||
|
"result": {
|
||||||
|
"summary": {
|
||||||
|
"scope": "B",
|
||||||
|
"numberOfFailures": 59,
|
||||||
|
"optimumOHTime": 90,
|
||||||
|
"optimumTotalCost": 500000000
|
||||||
|
},
|
||||||
|
"chartData": {},
|
||||||
|
"comparisons": {
|
||||||
|
"vsLastCalculation": {
|
||||||
|
"costDifference": -50000000,
|
||||||
|
"timeChange": "+15 days"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"simulationLimits": {
|
||||||
|
"minInterval": 30,
|
||||||
|
"maxInterval": 180
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return StandardResponse(data=CalculationTimeConstrainsRead(
|
||||||
|
id=calculation_result["id"],
|
||||||
|
result=calculation_result["result"],
|
||||||
|
), message="Data created successfully")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/simulation", response_model=StandardResponse[Dict[str, Any]])
|
||||||
|
async def get_simulation_result():
|
||||||
|
|
||||||
|
results = {
|
||||||
|
"simulation": {
|
||||||
|
"intervalDays": 45,
|
||||||
|
"numberOfFailures": 75,
|
||||||
|
"totalCost": 550000000,
|
||||||
|
},
|
||||||
|
"comparison": {
|
||||||
|
"vsOptimal": {
|
||||||
|
"failureDifference": 16,
|
||||||
|
"costDifference": 50000000
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return StandardResponse(data=results, message="Data retrieved successfully")
|
||||||
@ -0,0 +1,78 @@
|
|||||||
|
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from src.models import DefultBase
|
||||||
|
|
||||||
|
# {
|
||||||
|
# "costPerFailure": 733.614,
|
||||||
|
# "availableScopes": ["A", "B"],
|
||||||
|
# "recommendedScope": "B",
|
||||||
|
# "historicalData": {
|
||||||
|
# "averageOverhaulCost": 10000000,
|
||||||
|
# "lastCalculation": {
|
||||||
|
# "id": "calc_122",
|
||||||
|
# "date": "2024-10-15",
|
||||||
|
# "scope": "B"
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
|
||||||
|
|
||||||
|
class CalculationTimeConstrainsBase(DefultBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CalculationTimeConstrainsParametersRead(CalculationTimeConstrainsBase):
|
||||||
|
costPerFailure: float = Field(..., description="Cost per failure")
|
||||||
|
availableScopes: List[str] = Field(..., description="Available scopes")
|
||||||
|
recommendedScope: str = Field(..., description="Recommended scope")
|
||||||
|
historicalData: Dict[str, Any] = Field(..., description="Historical data")
|
||||||
|
|
||||||
|
|
||||||
|
# {
|
||||||
|
# "overhaulCost": 10000000,
|
||||||
|
# "scopeOH": "B",
|
||||||
|
# "costPerFailure": 733.614,
|
||||||
|
# "metadata": {
|
||||||
|
# "unit": "PLTU1",
|
||||||
|
# "calculatedBy": "user123",
|
||||||
|
# "timestamp": "2024-11-28T10:00:00Z"
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
|
||||||
|
class CalculationTimeConstrainsCreate(CalculationTimeConstrainsBase):
|
||||||
|
overhaulCost: float = Field(..., description="Overhaul cost")
|
||||||
|
scopeOH: str = Field(..., description="Scope OH")
|
||||||
|
costPerFailure: float = Field(..., description="Cost per failure")
|
||||||
|
metadata: Dict[str, Any] = Field(..., description="Metadata")
|
||||||
|
|
||||||
|
# {
|
||||||
|
# "calculationId": "calc_123",
|
||||||
|
# "result": {
|
||||||
|
# "summary": {
|
||||||
|
# "scope": "B",
|
||||||
|
# "numberOfFailures": 59,
|
||||||
|
# "optimumOHTime": 90,
|
||||||
|
# "optimumTotalCost": 500000000
|
||||||
|
# },
|
||||||
|
# "chartData": {/* ... */},
|
||||||
|
# "comparisons": {
|
||||||
|
# "vsLastCalculation": {
|
||||||
|
# "costDifference": -50000000,
|
||||||
|
# "timeChange": "+15 days"
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
# },
|
||||||
|
# "simulationLimits": {
|
||||||
|
# "minInterval": 30,
|
||||||
|
# "maxInterval": 180
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
|
||||||
|
class CalculationTimeConstrainsRead(CalculationTimeConstrainsBase):
|
||||||
|
id: Union[UUID, str]
|
||||||
|
result: Dict[str, Any]
|
||||||
@ -0,0 +1,74 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import base64
|
||||||
|
from urllib import parse
|
||||||
|
from typing import List
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from starlette.config import Config
|
||||||
|
from starlette.datastructures import CommaSeparatedStrings
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseConfigurationModel(BaseModel):
|
||||||
|
"""Base configuration model used by all config options."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def get_env_tags(tag_list: List[str]) -> dict:
|
||||||
|
"""Create dictionary of available env tags."""
|
||||||
|
tags = {}
|
||||||
|
for t in tag_list:
|
||||||
|
tag_key, env_key = t.split(":")
|
||||||
|
|
||||||
|
env_value = os.environ.get(env_key)
|
||||||
|
|
||||||
|
if env_value:
|
||||||
|
tags.update({tag_key: env_value})
|
||||||
|
|
||||||
|
return tags
|
||||||
|
|
||||||
|
def get_config():
|
||||||
|
try:
|
||||||
|
# Try to load from .env file first
|
||||||
|
config = Config(".env")
|
||||||
|
except FileNotFoundError:
|
||||||
|
# If .env doesn't exist, use environment variables
|
||||||
|
config = Config(environ=os.environ)
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
config = get_config()
|
||||||
|
|
||||||
|
|
||||||
|
LOG_LEVEL = config("LOG_LEVEL", default=logging.WARNING)
|
||||||
|
ENV = config("ENV", default="local")
|
||||||
|
PORT = config("PORT", cast=int, default=8000)
|
||||||
|
HOST = config("HOST", default="localhost")
|
||||||
|
|
||||||
|
|
||||||
|
# database
|
||||||
|
DATABASE_HOSTNAME = config("DATABASE_HOSTNAME")
|
||||||
|
_DATABASE_CREDENTIAL_USER = config("DATABASE_CREDENTIAL_USER")
|
||||||
|
_DATABASE_CREDENTIAL_PASSWORD = config("DATABASE_CREDENTIAL_PASSWORD")
|
||||||
|
_QUOTED_DATABASE_PASSWORD = parse.quote(str(_DATABASE_CREDENTIAL_PASSWORD))
|
||||||
|
DATABASE_NAME = config("DATABASE_NAME", default="digital_twin")
|
||||||
|
DATABASE_PORT = config("DATABASE_PORT", default="5432")
|
||||||
|
|
||||||
|
DATABASE_ENGINE_POOL_SIZE = config(
|
||||||
|
"DATABASE_ENGINE_POOL_SIZE", cast=int, default=20)
|
||||||
|
DATABASE_ENGINE_MAX_OVERFLOW = config(
|
||||||
|
"DATABASE_ENGINE_MAX_OVERFLOW", cast=int, default=0)
|
||||||
|
# Deal with DB disconnects
|
||||||
|
# https://docs.sqlalchemy.org/en/20/core/pooling.html#pool-disconnects
|
||||||
|
DATABASE_ENGINE_POOL_PING = config("DATABASE_ENGINE_POOL_PING", default=False)
|
||||||
|
SQLALCHEMY_DATABASE_URI = f"postgresql+asyncpg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}"
|
||||||
|
|
||||||
|
TIMEZONE = "Asia/Jakarta"
|
||||||
|
|
||||||
|
|
||||||
|
AUTH_SERVICE_API = config(
|
||||||
|
"AUTH_SERVICE_API", default="http://192.168.1.82:8000/auth")
|
||||||
@ -0,0 +1,156 @@
|
|||||||
|
# src/database.py
|
||||||
|
from starlette.requests import Request
|
||||||
|
from sqlalchemy_utils import get_mapper
|
||||||
|
from sqlalchemy.sql.expression import true
|
||||||
|
from sqlalchemy.orm import object_session, sessionmaker, Session
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base, declared_attr
|
||||||
|
from sqlalchemy import create_engine, inspect
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from fastapi import Depends
|
||||||
|
from typing import Annotated, Any
|
||||||
|
from contextlib import contextmanager
|
||||||
|
import re
|
||||||
|
import functools
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, sessionmaker
|
||||||
|
|
||||||
|
from src.config import SQLALCHEMY_DATABASE_URI
|
||||||
|
|
||||||
|
engine = create_async_engine(
|
||||||
|
SQLALCHEMY_DATABASE_URI,
|
||||||
|
echo=False,
|
||||||
|
future=True
|
||||||
|
)
|
||||||
|
|
||||||
|
async_session = sessionmaker(
|
||||||
|
engine,
|
||||||
|
class_=AsyncSession,
|
||||||
|
expire_on_commit=False,
|
||||||
|
autocommit=False,
|
||||||
|
autoflush=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_db(request: Request):
|
||||||
|
return request.state.db
|
||||||
|
|
||||||
|
|
||||||
|
DbSession = Annotated[AsyncSession, Depends(get_db)]
|
||||||
|
|
||||||
|
|
||||||
|
class CustomBase:
|
||||||
|
__repr_attrs__ = []
|
||||||
|
__repr_max_length__ = 15
|
||||||
|
|
||||||
|
@declared_attr
|
||||||
|
def __tablename__(self):
|
||||||
|
return resolve_table_name(self.__name__)
|
||||||
|
|
||||||
|
def dict(self):
|
||||||
|
"""Returns a dict representation of a model."""
|
||||||
|
return {c.name: getattr(self, c.name) for c in self.__table__.columns}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _id_str(self):
|
||||||
|
ids = inspect(self).identity
|
||||||
|
if ids:
|
||||||
|
return "-".join([str(x) for x in ids]) if len(ids) > 1 else str(ids[0])
|
||||||
|
else:
|
||||||
|
return "None"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _repr_attrs_str(self):
|
||||||
|
max_length = self.__repr_max_length__
|
||||||
|
|
||||||
|
values = []
|
||||||
|
single = len(self.__repr_attrs__) == 1
|
||||||
|
for key in self.__repr_attrs__:
|
||||||
|
if not hasattr(self, key):
|
||||||
|
raise KeyError(
|
||||||
|
"{} has incorrect attribute '{}' in " "__repr__attrs__".format(
|
||||||
|
self.__class__, key
|
||||||
|
)
|
||||||
|
)
|
||||||
|
value = getattr(self, key)
|
||||||
|
wrap_in_quote = isinstance(value, str)
|
||||||
|
|
||||||
|
value = str(value)
|
||||||
|
if len(value) > max_length:
|
||||||
|
value = value[:max_length] + "..."
|
||||||
|
|
||||||
|
if wrap_in_quote:
|
||||||
|
value = "'{}'".format(value)
|
||||||
|
values.append(value if single else "{}:{}".format(key, value))
|
||||||
|
|
||||||
|
return " ".join(values)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
# get id like '#123'
|
||||||
|
id_str = ("#" + self._id_str) if self._id_str else ""
|
||||||
|
# join class name, id and repr_attrs
|
||||||
|
return "<{} {}{}>".format(
|
||||||
|
self.__class__.__name__,
|
||||||
|
id_str,
|
||||||
|
" " + self._repr_attrs_str if self._repr_attrs_str else "",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Base = declarative_base(cls=CustomBase)
|
||||||
|
# make_searchable(Base.metadata)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
async def get_session():
|
||||||
|
"""Context manager to ensure the session is closed after use."""
|
||||||
|
session = async_session()
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
await session.commit()
|
||||||
|
except:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_table_name(name):
|
||||||
|
"""Resolves table names to their mapped names."""
|
||||||
|
names = re.split("(?=[A-Z])", name) # noqa
|
||||||
|
return "_".join([x.lower() for x in names if x])
|
||||||
|
|
||||||
|
|
||||||
|
raise_attribute_error = object()
|
||||||
|
|
||||||
|
|
||||||
|
# def resolve_attr(obj, attr, default=None):
|
||||||
|
# """Attempts to access attr via dotted notation, returns none if attr does not exist."""
|
||||||
|
# try:
|
||||||
|
# return functools.reduce(getattr, attr.split("."), obj)
|
||||||
|
# except AttributeError:
|
||||||
|
# return default
|
||||||
|
|
||||||
|
|
||||||
|
# def get_model_name_by_tablename(table_fullname: str) -> str:
|
||||||
|
# """Returns the model name of a given table."""
|
||||||
|
# return get_class_by_tablename(table_fullname=table_fullname).__name__
|
||||||
|
|
||||||
|
|
||||||
|
def get_class_by_tablename(table_fullname: str) -> Any:
|
||||||
|
"""Return class reference mapped to table."""
|
||||||
|
|
||||||
|
def _find_class(name):
|
||||||
|
for c in Base._decl_class_registry.values():
|
||||||
|
if hasattr(c, "__table__"):
|
||||||
|
if c.__table__.fullname.lower() == name.lower():
|
||||||
|
return c
|
||||||
|
|
||||||
|
mapped_name = resolve_table_name(table_fullname)
|
||||||
|
mapped_class = _find_class(mapped_name)
|
||||||
|
|
||||||
|
return mapped_class
|
||||||
|
|
||||||
|
|
||||||
|
# def get_table_name_by_class_instance(class_instance: Base) -> str:
|
||||||
|
# """Returns the name of the table for a given class instance."""
|
||||||
|
# return class_instance._sa_instance_state.mapper.mapped_table.name
|
||||||
@ -0,0 +1,133 @@
|
|||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Annotated, List
|
||||||
|
|
||||||
|
from sqlalchemy import desc, func, or_, Select
|
||||||
|
from sqlalchemy_filters import apply_pagination
|
||||||
|
from sqlalchemy.exc import ProgrammingError
|
||||||
|
from .core import DbSession
|
||||||
|
|
||||||
|
|
||||||
|
from fastapi import Query, Depends
|
||||||
|
from pydantic.types import Json, constr
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# allows only printable characters
|
||||||
|
QueryStr = constr(pattern=r"^[ -~]+$", min_length=1)
|
||||||
|
|
||||||
|
|
||||||
|
def common_parameters(
|
||||||
|
db_session: DbSession, # type: ignore
|
||||||
|
current_user: QueryStr = Query(None, alias="currentUser"), # type: ignore
|
||||||
|
page: int = Query(1, gt=0, lt=2147483647),
|
||||||
|
items_per_page: int = Query(5, alias="itemsPerPage", gt=-2, lt=2147483647),
|
||||||
|
query_str: QueryStr = Query(None, alias="q"), # type: ignore
|
||||||
|
filter_spec: QueryStr = Query(None, alias="filter"), # type: ignore
|
||||||
|
sort_by: List[str] = Query([], alias="sortBy[]"),
|
||||||
|
descending: List[bool] = Query([], alias="descending[]"),
|
||||||
|
exclude: List[str] = Query([], alias="exclude[]"),
|
||||||
|
# role: QueryStr = Depends(get_current_role),
|
||||||
|
):
|
||||||
|
return {
|
||||||
|
"db_session": db_session,
|
||||||
|
"page": page,
|
||||||
|
"items_per_page": items_per_page,
|
||||||
|
"query_str": query_str,
|
||||||
|
"filter_spec": filter_spec,
|
||||||
|
"sort_by": sort_by,
|
||||||
|
"descending": descending,
|
||||||
|
"current_user": current_user,
|
||||||
|
# "role": role,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
CommonParameters = Annotated[
|
||||||
|
dict[str, int | str | DbSession | QueryStr |
|
||||||
|
Json | List[str] | List[bool]],
|
||||||
|
Depends(common_parameters),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def search(*, query_str: str, query: Query, model, sort=False):
|
||||||
|
"""Perform a search based on the query."""
|
||||||
|
search_model = model
|
||||||
|
|
||||||
|
if not query_str.strip():
|
||||||
|
return query
|
||||||
|
|
||||||
|
search = []
|
||||||
|
if hasattr(search_model, "search_vector"):
|
||||||
|
vector = search_model.search_vector
|
||||||
|
search.append(vector.op("@@")(func.tsq_parse(query_str)))
|
||||||
|
|
||||||
|
if hasattr(search_model, "name"):
|
||||||
|
search.append(
|
||||||
|
search_model.name.ilike(f"%{query_str}%"),
|
||||||
|
)
|
||||||
|
search.append(search_model.name == query_str)
|
||||||
|
|
||||||
|
if not search:
|
||||||
|
raise Exception(f"Search not supported for model: {model}")
|
||||||
|
|
||||||
|
query = query.filter(or_(*search))
|
||||||
|
|
||||||
|
if sort:
|
||||||
|
query = query.order_by(
|
||||||
|
desc(func.ts_rank_cd(vector, func.tsq_parse(query_str))))
|
||||||
|
|
||||||
|
return query.params(term=query_str)
|
||||||
|
|
||||||
|
|
||||||
|
async def search_filter_sort_paginate(
|
||||||
|
db_session: DbSession,
|
||||||
|
model,
|
||||||
|
query_str: str = None,
|
||||||
|
filter_spec: str | dict | None = None,
|
||||||
|
page: int = 1,
|
||||||
|
items_per_page: int = 5,
|
||||||
|
sort_by: List[str] = None,
|
||||||
|
descending: List[bool] = None,
|
||||||
|
current_user: str = None,
|
||||||
|
exclude: List[str] = None,
|
||||||
|
):
|
||||||
|
"""Common functionality for searching, filtering, sorting, and pagination."""
|
||||||
|
# try:
|
||||||
|
query = Select(model)
|
||||||
|
|
||||||
|
if query_str:
|
||||||
|
sort = False if sort_by else True
|
||||||
|
query = search(query_str=query_str, query=query,
|
||||||
|
model=model, sort=sort)
|
||||||
|
|
||||||
|
# Get total count
|
||||||
|
count_query = Select(func.count()).select_from(query.subquery())
|
||||||
|
total = await db_session.scalar(count_query)
|
||||||
|
|
||||||
|
query = (
|
||||||
|
query
|
||||||
|
.offset((page - 1) * items_per_page)
|
||||||
|
.limit(items_per_page)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await db_session.execute(query)
|
||||||
|
items = result.scalars().all()
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# query, pagination = apply_pagination(
|
||||||
|
# query=query, page_number=page, page_size=items_per_page)
|
||||||
|
# except ProgrammingError as e:
|
||||||
|
# log.debug(e)
|
||||||
|
# return {
|
||||||
|
# "items": [],
|
||||||
|
# "itemsPerPage": items_per_page,
|
||||||
|
# "page": page,
|
||||||
|
# "total": 0,
|
||||||
|
# }
|
||||||
|
|
||||||
|
return {
|
||||||
|
"items": items,
|
||||||
|
"itemsPerPage": items_per_page,
|
||||||
|
"page": page,
|
||||||
|
"total": total,
|
||||||
|
}
|
||||||
@ -0,0 +1,24 @@
|
|||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
|
|
||||||
|
class OptimumOHEnum(StrEnum):
|
||||||
|
"""
|
||||||
|
A custom Enum class that extends StrEnum.
|
||||||
|
|
||||||
|
This class inherits all functionality from StrEnum, including
|
||||||
|
string representation and automatic value conversion to strings.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
class Visibility(DispatchEnum):
|
||||||
|
OPEN = "Open"
|
||||||
|
RESTRICTED = "Restricted"
|
||||||
|
|
||||||
|
assert str(Visibility.OPEN) == "Open"
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass # No additional implementation needed
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseStatus(OptimumOHEnum):
|
||||||
|
SUCCESS = "success"
|
||||||
|
ERROR = "error"
|
||||||
@ -0,0 +1,164 @@
|
|||||||
|
# Define base error model
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from src.enums import ResponseStatus
|
||||||
|
from slowapi import _rate_limit_exceeded_handler
|
||||||
|
from slowapi.errors import RateLimitExceeded
|
||||||
|
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, DataError, DBAPIError
|
||||||
|
from asyncpg.exceptions import DataError as AsyncPGDataError, PostgresError
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorDetail(BaseModel):
|
||||||
|
field: Optional[str] = None
|
||||||
|
message: str
|
||||||
|
code: Optional[str] = None
|
||||||
|
params: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorResponse(BaseModel):
|
||||||
|
data: Optional[Any] = None
|
||||||
|
message: str
|
||||||
|
status: ResponseStatus = ResponseStatus.ERROR
|
||||||
|
errors: Optional[List[ErrorDetail]] = None
|
||||||
|
|
||||||
|
# Custom exception handler setup
|
||||||
|
|
||||||
|
|
||||||
|
def get_request_context(request: Request):
|
||||||
|
"""
|
||||||
|
Get detailed request context for logging.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_client_ip():
|
||||||
|
"""
|
||||||
|
Get the real client IP address from Kong Gateway headers.
|
||||||
|
Kong sets X-Real-IP and X-Forwarded-For headers by default.
|
||||||
|
"""
|
||||||
|
# Kong specific headers
|
||||||
|
if "X-Real-IP" in request.headers:
|
||||||
|
return request.headers["X-Real-IP"]
|
||||||
|
|
||||||
|
# Fallback to X-Forwarded-For
|
||||||
|
if "X-Forwarded-For" in request.headers:
|
||||||
|
# Get the first IP (original client)
|
||||||
|
return request.headers["X-Forwarded-For"].split(",")[0].strip()
|
||||||
|
|
||||||
|
# Last resort
|
||||||
|
return request.client.host
|
||||||
|
|
||||||
|
return {
|
||||||
|
"endpoint": request.url.path,
|
||||||
|
"url": request.url,
|
||||||
|
"method": request.method,
|
||||||
|
"remote_addr": get_client_ip(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def handle_sqlalchemy_error(error: SQLAlchemyError):
|
||||||
|
"""
|
||||||
|
Handle SQLAlchemy errors and return user-friendly error messages.
|
||||||
|
"""
|
||||||
|
original_error = getattr(error, 'orig', None)
|
||||||
|
print(original_error)
|
||||||
|
|
||||||
|
if isinstance(error, IntegrityError):
|
||||||
|
if "unique constraint" in str(error).lower():
|
||||||
|
return "This record already exists.", 409
|
||||||
|
elif "foreign key constraint" in str(error).lower():
|
||||||
|
return "Related record not found.", 400
|
||||||
|
else:
|
||||||
|
return "Data integrity error.", 400
|
||||||
|
elif isinstance(error, DataError) or isinstance(original_error, AsyncPGDataError):
|
||||||
|
return "Invalid data provided.", 400
|
||||||
|
elif isinstance(error, DBAPIError):
|
||||||
|
if "unique constraint" in str(error).lower():
|
||||||
|
return "This record already exists.", 409
|
||||||
|
elif "foreign key constraint" in str(error).lower():
|
||||||
|
return "Related record not found.", 400
|
||||||
|
elif "null value in column" in str(error).lower():
|
||||||
|
return "Required data missing.", 400
|
||||||
|
elif "invalid input for query argument" in str(error).lower():
|
||||||
|
return "Invalid data provided.", 400
|
||||||
|
else:
|
||||||
|
return "Database error.", 500
|
||||||
|
else:
|
||||||
|
# Log the full error for debugging purposes
|
||||||
|
logging.error(f"Unexpected database error: {str(error)}")
|
||||||
|
return "An unexpected database error occurred.", 500
|
||||||
|
|
||||||
|
|
||||||
|
def handle_exception(request: Request, exc: Exception):
|
||||||
|
"""
|
||||||
|
Global exception handler for Fastapi application.
|
||||||
|
"""
|
||||||
|
request_info = get_request_context(request)
|
||||||
|
|
||||||
|
if isinstance(exc, RateLimitExceeded):
|
||||||
|
_rate_limit_exceeded_handler(request, exc)
|
||||||
|
if isinstance(exc, HTTPException):
|
||||||
|
logging.error(
|
||||||
|
f"HTTP exception | Code: {exc.status_code} | Error: {exc.detail} | Request: {request_info}",
|
||||||
|
extra={"error_category": "http"},
|
||||||
|
)
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
content={
|
||||||
|
"data": None,
|
||||||
|
"message": str(exc.detail),
|
||||||
|
"status": ResponseStatus.ERROR,
|
||||||
|
"errors": [
|
||||||
|
ErrorDetail(
|
||||||
|
message=str(exc.detail)
|
||||||
|
).model_dump()
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(exc, SQLAlchemyError):
|
||||||
|
error_message, status_code = handle_sqlalchemy_error(exc)
|
||||||
|
logging.error(
|
||||||
|
f"Database Error | Error: {str(error_message)} | Request: {request_info}",
|
||||||
|
extra={"error_category": "database"},
|
||||||
|
)
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status_code,
|
||||||
|
content={
|
||||||
|
"data": None,
|
||||||
|
"message": error_message,
|
||||||
|
"status": ResponseStatus.ERROR,
|
||||||
|
"errors": [
|
||||||
|
ErrorDetail(
|
||||||
|
message=error_message
|
||||||
|
).model_dump()
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log unexpected errors
|
||||||
|
logging.error(
|
||||||
|
f"Unexpected Error | Error: {str(exc)} | Request: {request_info}",
|
||||||
|
extra={"error_category": "unexpected"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content={
|
||||||
|
"data": None,
|
||||||
|
"message": exc.__class__.__name__,
|
||||||
|
"status": ResponseStatus.ERROR,
|
||||||
|
"errors": [
|
||||||
|
ErrorDetail(
|
||||||
|
message="An unexpected error occurred."
|
||||||
|
).model_dump()
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
@ -0,0 +1,33 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from src.config import LOG_LEVEL
|
||||||
|
from src.enums import OptimumOHEnum
|
||||||
|
|
||||||
|
|
||||||
|
LOG_FORMAT_DEBUG = "%(levelname)s:%(message)s:%(pathname)s:%(funcName)s:%(lineno)d"
|
||||||
|
|
||||||
|
|
||||||
|
class LogLevels(OptimumOHEnum):
|
||||||
|
info = "INFO"
|
||||||
|
warn = "WARN"
|
||||||
|
error = "ERROR"
|
||||||
|
debug = "DEBUG"
|
||||||
|
|
||||||
|
|
||||||
|
def configure_logging():
|
||||||
|
log_level = str(LOG_LEVEL).upper() # cast to string
|
||||||
|
log_levels = list(LogLevels)
|
||||||
|
|
||||||
|
if log_level not in log_levels:
|
||||||
|
# we use error as the default log level
|
||||||
|
logging.basicConfig(level=LogLevels.error)
|
||||||
|
return
|
||||||
|
|
||||||
|
if log_level == LogLevels.debug:
|
||||||
|
logging.basicConfig(level=log_level, format=LOG_FORMAT_DEBUG)
|
||||||
|
return
|
||||||
|
|
||||||
|
logging.basicConfig(level=log_level)
|
||||||
|
|
||||||
|
# sometimes the slack client can be too verbose
|
||||||
|
logging.getLogger("slack_sdk.web.base_client").setLevel(logging.CRITICAL)
|
||||||
@ -0,0 +1,112 @@
|
|||||||
|
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
from os import path
|
||||||
|
from uuid import uuid1
|
||||||
|
from typing import Optional, Final
|
||||||
|
from contextvars import ContextVar
|
||||||
|
|
||||||
|
from fastapi import FastAPI, HTTPException, status
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from slowapi import _rate_limit_exceeded_handler
|
||||||
|
from slowapi.errors import RateLimitExceeded
|
||||||
|
from sqlalchemy import inspect
|
||||||
|
from sqlalchemy.orm import scoped_session
|
||||||
|
from sqlalchemy.ext.asyncio import async_scoped_session
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.routing import compile_path
|
||||||
|
from starlette.middleware.gzip import GZipMiddleware
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from starlette.responses import Response, StreamingResponse, FileResponse
|
||||||
|
from starlette.staticfiles import StaticFiles
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from src.enums import ResponseStatus
|
||||||
|
from src.logging import configure_logging
|
||||||
|
from src.rate_limiter import limiter
|
||||||
|
from src.api import api_router
|
||||||
|
from src.database.core import engine, async_session
|
||||||
|
from src.exceptions import handle_exception
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# we configure the logging level and format
|
||||||
|
configure_logging()
|
||||||
|
|
||||||
|
# we define the exception handlers
|
||||||
|
exception_handlers = {Exception: handle_exception}
|
||||||
|
|
||||||
|
# we create the ASGI for the app
|
||||||
|
app = FastAPI(exception_handlers=exception_handlers, openapi_url="", title="LCCA API",
|
||||||
|
description="Welcome to LCCA's API documentation!",
|
||||||
|
version="0.1.0")
|
||||||
|
app.state.limiter = limiter
|
||||||
|
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||||
|
app.add_middleware(GZipMiddleware, minimum_size=2000)
|
||||||
|
# credentials: "include",
|
||||||
|
|
||||||
|
|
||||||
|
REQUEST_ID_CTX_KEY: Final[str] = "request_id"
|
||||||
|
_request_id_ctx_var: ContextVar[Optional[str]] = ContextVar(
|
||||||
|
REQUEST_ID_CTX_KEY, default=None)
|
||||||
|
|
||||||
|
|
||||||
|
def get_request_id() -> Optional[str]:
|
||||||
|
return _request_id_ctx_var.get()
|
||||||
|
|
||||||
|
|
||||||
|
@app.middleware("http")
|
||||||
|
async def db_session_middleware(request: Request, call_next):
|
||||||
|
request_id = str(uuid1())
|
||||||
|
|
||||||
|
# we create a per-request id such that we can ensure that our session is scoped for a particular request.
|
||||||
|
# see: https://github.com/tiangolo/fastapi/issues/726
|
||||||
|
ctx_token = _request_id_ctx_var.set(request_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
session = async_scoped_session(async_session, scopefunc=get_request_id)
|
||||||
|
request.state.db = session()
|
||||||
|
response = await call_next(request)
|
||||||
|
except Exception as e:
|
||||||
|
raise e from None
|
||||||
|
finally:
|
||||||
|
await request.state.db.close()
|
||||||
|
|
||||||
|
_request_id_ctx_var.reset(ctx_token)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@app.middleware("http")
|
||||||
|
async def add_security_headers(request: Request, call_next):
|
||||||
|
response = await call_next(request)
|
||||||
|
response.headers["Strict-Transport-Security"] = "max-age=31536000 ; includeSubDomains"
|
||||||
|
return response
|
||||||
|
|
||||||
|
# class MetricsMiddleware(BaseHTTPMiddleware):
|
||||||
|
# async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
|
||||||
|
# method = request.method
|
||||||
|
# endpoint = request.url.path
|
||||||
|
# tags = {"method": method, "endpoint": endpoint}
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# start = time.perf_counter()
|
||||||
|
# response = await call_next(request)
|
||||||
|
# elapsed_time = time.perf_counter() - start
|
||||||
|
# tags.update({"status_code": response.status_code})
|
||||||
|
# metric_provider.counter("server.call.counter", tags=tags)
|
||||||
|
# metric_provider.timer("server.call.elapsed", value=elapsed_time, tags=tags)
|
||||||
|
# log.debug(f"server.call.elapsed.{endpoint}: {elapsed_time}")
|
||||||
|
# except Exception as e:
|
||||||
|
# metric_provider.counter("server.call.exception.counter", tags=tags)
|
||||||
|
# raise e from None
|
||||||
|
# return response
|
||||||
|
|
||||||
|
|
||||||
|
# app.add_middleware(ExceptionMiddleware)
|
||||||
|
|
||||||
|
app.include_router(api_router)
|
||||||
@ -0,0 +1,46 @@
|
|||||||
|
# import logging
|
||||||
|
|
||||||
|
# from dispatch.plugins.base import plugins
|
||||||
|
|
||||||
|
# from .config import METRIC_PROVIDERS
|
||||||
|
|
||||||
|
# log = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
|
# class Metrics(object):
|
||||||
|
# _providers = []
|
||||||
|
|
||||||
|
# def __init__(self):
|
||||||
|
# if not METRIC_PROVIDERS:
|
||||||
|
# log.info(
|
||||||
|
# "No metric providers defined via METRIC_PROVIDERS env var. Metrics will not be sent."
|
||||||
|
# )
|
||||||
|
# else:
|
||||||
|
# self._providers = METRIC_PROVIDERS
|
||||||
|
|
||||||
|
# def gauge(self, name, value, tags=None):
|
||||||
|
# for provider in self._providers:
|
||||||
|
# log.debug(
|
||||||
|
# f"Sending gauge metric {name} to provider {provider}. Value: {value} Tags: {tags}"
|
||||||
|
# )
|
||||||
|
# p = plugins.get(provider)
|
||||||
|
# p.gauge(name, value, tags=tags)
|
||||||
|
|
||||||
|
# def counter(self, name, value=None, tags=None):
|
||||||
|
# for provider in self._providers:
|
||||||
|
# log.debug(
|
||||||
|
# f"Sending counter metric {name} to provider {provider}. Value: {value} Tags: {tags}"
|
||||||
|
# )
|
||||||
|
# p = plugins.get(provider)
|
||||||
|
# p.counter(name, value=value, tags=tags)
|
||||||
|
|
||||||
|
# def timer(self, name, value, tags=None):
|
||||||
|
# for provider in self._providers:
|
||||||
|
# log.debug(
|
||||||
|
# f"Sending timer metric {name} to provider {provider}. Value: {value} Tags: {tags}"
|
||||||
|
# )
|
||||||
|
# p = plugins.get(provider)
|
||||||
|
# p.timer(name, value, tags=tags)
|
||||||
|
|
||||||
|
|
||||||
|
# provider = Metrics()
|
||||||
@ -0,0 +1,85 @@
|
|||||||
|
# src/common/models.py
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Generic, Optional, TypeVar
|
||||||
|
import uuid
|
||||||
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
from sqlalchemy import Column, DateTime, String, func, event
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
from src.config import TIMEZONE
|
||||||
|
import pytz
|
||||||
|
|
||||||
|
from src.auth.service import CurrentUser
|
||||||
|
from src.enums import ResponseStatus
|
||||||
|
# SQLAlchemy Mixins
|
||||||
|
|
||||||
|
|
||||||
|
class TimeStampMixin(object):
|
||||||
|
"""Timestamping mixin"""
|
||||||
|
|
||||||
|
created_at = Column(
|
||||||
|
DateTime(timezone=True), default=datetime.now(pytz.timezone(TIMEZONE)))
|
||||||
|
created_at._creation_order = 9998
|
||||||
|
updated_at = Column(
|
||||||
|
DateTime(timezone=True), default=datetime.now(pytz.timezone(TIMEZONE)))
|
||||||
|
updated_at._creation_order = 9998
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _updated_at(mapper, connection, target):
|
||||||
|
target.updated_at = datetime.now(pytz.timezone(TIMEZONE))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __declare_last__(cls):
|
||||||
|
event.listen(cls, "before_update", cls._updated_at)
|
||||||
|
|
||||||
|
|
||||||
|
class UUIDMixin:
|
||||||
|
"""UUID mixin"""
|
||||||
|
id = Column(UUID(as_uuid=True), primary_key=True,
|
||||||
|
default=uuid.uuid4, unique=True, nullable=False)
|
||||||
|
|
||||||
|
|
||||||
|
class IdentityMixin:
|
||||||
|
"""Identity mixin"""
|
||||||
|
created_by = Column(String(100), nullable=True)
|
||||||
|
updated_by = Column(String(100), nullable=True)
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultMixin(TimeStampMixin, UUIDMixin):
|
||||||
|
"""Default mixin"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Pydantic Models
|
||||||
|
class DefultBase(BaseModel):
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
validate_assignment = True
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
str_strip_whitespace = True
|
||||||
|
|
||||||
|
json_encoders = {
|
||||||
|
# custom output conversion for datetime
|
||||||
|
datetime: lambda v: v.strftime("%Y-%m-%dT%H:%M:%S.%fZ") if v else None,
|
||||||
|
SecretStr: lambda v: v.get_secret_value() if v else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Pagination(DefultBase):
|
||||||
|
itemsPerPage: int
|
||||||
|
page: int
|
||||||
|
total: int
|
||||||
|
|
||||||
|
|
||||||
|
class PrimaryKeyModel(BaseModel):
|
||||||
|
id: uuid.UUID
|
||||||
|
|
||||||
|
|
||||||
|
# Define data type variable for generic response
|
||||||
|
T = TypeVar('T')
|
||||||
|
|
||||||
|
|
||||||
|
class StandardResponse(BaseModel, Generic[T]):
|
||||||
|
data: Optional[T] = None
|
||||||
|
message: str = "Success"
|
||||||
|
status: ResponseStatus = ResponseStatus.SUCCESS
|
||||||
@ -0,0 +1,59 @@
|
|||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, status
|
||||||
|
|
||||||
|
from src.overhaul.service import get_overhaul_critical_parts, get_overhaul_overview, get_overhaul_schedules, get_overhaul_system_components
|
||||||
|
|
||||||
|
from .schema import OverhaulRead, OverhaulSchedules, OverhaulCriticalParts, OverhaulSystemComponents
|
||||||
|
|
||||||
|
from src.models import StandardResponse
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=StandardResponse[OverhaulRead])
|
||||||
|
async def get_overhaul():
|
||||||
|
"""Get all scope pagination."""
|
||||||
|
overview = get_overhaul_overview()
|
||||||
|
schedules = get_overhaul_schedules()
|
||||||
|
criticalParts = get_overhaul_critical_parts()
|
||||||
|
systemComponents = get_overhaul_system_components()
|
||||||
|
|
||||||
|
return StandardResponse(
|
||||||
|
data=OverhaulRead(
|
||||||
|
overview=overview,
|
||||||
|
schedules=schedules,
|
||||||
|
criticalParts=criticalParts,
|
||||||
|
systemComponents=systemComponents,
|
||||||
|
),
|
||||||
|
message="Data retrieved successfully",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/schedules", response_model=StandardResponse[OverhaulSchedules])
|
||||||
|
async def get_schedules():
|
||||||
|
"""Get all overhaul schedules."""
|
||||||
|
schedules = get_overhaul_schedules()
|
||||||
|
return StandardResponse(
|
||||||
|
data=OverhaulSchedules(schedules=schedules),
|
||||||
|
message="Data retrieved successfully",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/critical-parts", response_model=StandardResponse[OverhaulCriticalParts])
|
||||||
|
async def get_critical_parts():
|
||||||
|
"""Get all overhaul critical parts."""
|
||||||
|
criticalParts = get_overhaul_critical_parts()
|
||||||
|
return StandardResponse(
|
||||||
|
data=OverhaulCriticalParts(criticalParts=criticalParts),
|
||||||
|
message="Data retrieved successfully",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/system-components", response_model=StandardResponse[OverhaulSystemComponents])
|
||||||
|
async def get_system_components():
|
||||||
|
"""Get all overhaul system components."""
|
||||||
|
systemComponents = get_overhaul_system_components()
|
||||||
|
return StandardResponse(
|
||||||
|
data=OverhaulSystemComponents(systemComponents=systemComponents),
|
||||||
|
message="Data retrieved successfully",
|
||||||
|
)
|
||||||
@ -0,0 +1,71 @@
|
|||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import Field, BaseModel
|
||||||
|
from src.models import DefultBase, Pagination
|
||||||
|
|
||||||
|
|
||||||
|
class OverhaulBase(BaseModel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class OverhaulCriticalParts(OverhaulBase):
|
||||||
|
criticalParts: List[str] = Field(..., description="List of critical parts")
|
||||||
|
|
||||||
|
|
||||||
|
class OverhaulSchedules(OverhaulBase):
|
||||||
|
schedules: List[Dict[str, Any]
|
||||||
|
] = Field(..., description="List of schedules")
|
||||||
|
|
||||||
|
|
||||||
|
class OverhaulSystemComponents(OverhaulBase):
|
||||||
|
systemComponents: Dict[str,
|
||||||
|
Any] = Field(..., description="List of system components")
|
||||||
|
|
||||||
|
|
||||||
|
class OverhaulRead(OverhaulBase):
|
||||||
|
overview: Dict[str, Any]
|
||||||
|
criticalParts: List[str]
|
||||||
|
schedules: List[Dict[str, Any]]
|
||||||
|
systemComponents: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
# {
|
||||||
|
# "overview": {
|
||||||
|
# "totalEquipment": 30,
|
||||||
|
# "nextSchedule": {
|
||||||
|
# "date": "2025-01-12",
|
||||||
|
# "Overhaul": "B",
|
||||||
|
# "equipmentCount": 30
|
||||||
|
# }
|
||||||
|
# },
|
||||||
|
# "criticalParts": [
|
||||||
|
# "Boiler feed pump",
|
||||||
|
# "Boiler reheater system",
|
||||||
|
# "Drum Level (Right) Root Valve A",
|
||||||
|
# "BCP A Discharge Valve",
|
||||||
|
# "BFPT A EXH Press HI Root VLV"
|
||||||
|
# ],
|
||||||
|
# "schedules": [
|
||||||
|
# {
|
||||||
|
# "date": "2025-01-12",
|
||||||
|
# "Overhaul": "B",
|
||||||
|
# "status": "upcoming"
|
||||||
|
# }
|
||||||
|
# // ... other scheduled overhauls
|
||||||
|
# ],
|
||||||
|
# "systemComponents": {
|
||||||
|
# "boiler": {
|
||||||
|
# "status": "operational",
|
||||||
|
# "lastOverhaul": "2024-06-15"
|
||||||
|
# },
|
||||||
|
# "turbine": {
|
||||||
|
# "hpt": { "status": "operational" },
|
||||||
|
# "ipt": { "status": "operational" },
|
||||||
|
# "lpt": { "status": "operational" }
|
||||||
|
# }
|
||||||
|
# // ... other major components
|
||||||
|
# }
|
||||||
|
# }
|
||||||
@ -0,0 +1,109 @@
|
|||||||
|
|
||||||
|
|
||||||
|
from sqlalchemy import Select, Delete
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from src.database.core import DbSession
|
||||||
|
from src.auth.service import CurrentUser
|
||||||
|
|
||||||
|
|
||||||
|
def get_overhaul_overview():
|
||||||
|
"""Get all overhaul overview."""
|
||||||
|
return {
|
||||||
|
"totalEquipment": 30,
|
||||||
|
"nextSchedule": {
|
||||||
|
"start_date": "2025-01-12",
|
||||||
|
"end_date": "2025-01-15",
|
||||||
|
"duration": 3,
|
||||||
|
"Overhaul": "B",
|
||||||
|
"equipmentCount": 30
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_overhaul_critical_parts():
|
||||||
|
"""Get all overhaul critical parts."""
|
||||||
|
return [
|
||||||
|
"Boiler feed pump",
|
||||||
|
"Boiler reheater system",
|
||||||
|
"Drum Level (Right) Root Valve A",
|
||||||
|
"BCP A Discharge Valve",
|
||||||
|
"BFPT A EXH Press HI Root VLV"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_overhaul_schedules():
|
||||||
|
"""Get all overhaul schedules."""
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"date": "2025-01-12",
|
||||||
|
"Overhaul": "B",
|
||||||
|
"status": "upcoming"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"date": "2025-02-15",
|
||||||
|
"Overhaul": "A",
|
||||||
|
"status": "upcoming"
|
||||||
|
},
|
||||||
|
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_overhaul_system_components():
|
||||||
|
"""Get all overhaul system components."""
|
||||||
|
return {
|
||||||
|
"boiler": {
|
||||||
|
"efficiency": "90%",
|
||||||
|
"work_hours": "1000",
|
||||||
|
"reliability": "95%",
|
||||||
|
},
|
||||||
|
"HPT": {
|
||||||
|
"efficiency": "90%",
|
||||||
|
"work_hours": "1000",
|
||||||
|
"reliability": "95%",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# async def get(*, db_session: DbSession, scope_id: str) -> Optional[Scope]:
|
||||||
|
# """Returns a document based on the given document id."""
|
||||||
|
# query = Select(Scope).filter(Scope.id == scope_id)
|
||||||
|
# result = await db_session.execute(query)
|
||||||
|
# return result.scalars().one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
# async def get_all(*, db_session: DbSession):
|
||||||
|
# """Returns all documents."""
|
||||||
|
# query = Select(Scope)
|
||||||
|
# result = await db_session.execute(query)
|
||||||
|
# return result.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
|
# async def create(*, db_session: DbSession, scope_id: ScopeCreate):
|
||||||
|
# """Creates a new document."""
|
||||||
|
# scope = Scope(**scope_id.model_dump())
|
||||||
|
# db_session.add(scope)
|
||||||
|
# await db_session.commit()
|
||||||
|
# return scope
|
||||||
|
|
||||||
|
|
||||||
|
# async def update(*, db_session: DbSession, scope: Scope, scope_id: ScopeUpdate):
|
||||||
|
# """Updates a document."""
|
||||||
|
# data = scope_id.model_dump()
|
||||||
|
|
||||||
|
# update_data = scope_id.model_dump(exclude_defaults=True)
|
||||||
|
|
||||||
|
# for field in data:
|
||||||
|
# if field in update_data:
|
||||||
|
# setattr(scope, field, update_data[field])
|
||||||
|
|
||||||
|
# await db_session.commit()
|
||||||
|
|
||||||
|
# return scope
|
||||||
|
|
||||||
|
|
||||||
|
# async def delete(*, db_session: DbSession, scope_id: str):
|
||||||
|
# """Deletes a document."""
|
||||||
|
# query = Delete(Scope).where(Scope.id == scope_id)
|
||||||
|
# await db_session.execute(query)
|
||||||
|
# await db_session.commit()
|
||||||
@ -0,0 +1,5 @@
|
|||||||
|
from slowapi import Limiter
|
||||||
|
from slowapi.util import get_remote_address
|
||||||
|
|
||||||
|
|
||||||
|
limiter = Limiter(key_func=get_remote_address)
|
||||||
@ -0,0 +1,12 @@
|
|||||||
|
|
||||||
|
from sqlalchemy import Column, Float, Integer, String
|
||||||
|
from src.database.core import Base
|
||||||
|
from src.models import DefaultMixin, IdentityMixin, TimeStampMixin
|
||||||
|
|
||||||
|
|
||||||
|
class Scope(Base, DefaultMixin):
|
||||||
|
__tablename__ = "oh_scope"
|
||||||
|
|
||||||
|
scope_name = Column(String, nullable=True)
|
||||||
|
duration_oh = Column(Integer, nullable=True)
|
||||||
|
crew = Column(Integer, nullable=True)
|
||||||
@ -0,0 +1,70 @@
|
|||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, status
|
||||||
|
|
||||||
|
from .model import Scope
|
||||||
|
from .schema import ScopeCreate, ScopeRead, ScopeUpdate, ScopePagination
|
||||||
|
from .service import get, get_all, create, update, delete
|
||||||
|
|
||||||
|
from src.database.service import CommonParameters, search_filter_sort_paginate
|
||||||
|
from src.database.core import DbSession
|
||||||
|
from src.auth.service import CurrentUser
|
||||||
|
from src.models import StandardResponse
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=StandardResponse[ScopePagination])
|
||||||
|
async def get_scopes(common: CommonParameters):
|
||||||
|
"""Get all scope pagination."""
|
||||||
|
# return
|
||||||
|
return StandardResponse(
|
||||||
|
data=await search_filter_sort_paginate(model=Scope, **common),
|
||||||
|
message="Data retrieved successfully",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{scope_id}", response_model=StandardResponse[ScopeRead])
|
||||||
|
async def get_scope(db_session: DbSession, scope_id: str):
|
||||||
|
scope = await get(db_session=db_session, scope_id=scope_id)
|
||||||
|
if not scope:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="A data with this id does not exist.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return StandardResponse(data=scope, message="Data retrieved successfully")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=StandardResponse[ScopeRead])
|
||||||
|
async def create_scope(db_session: DbSession, scope_in: ScopeCreate):
|
||||||
|
scope = await create(db_session=db_session, scope_in=scope_in)
|
||||||
|
|
||||||
|
return StandardResponse(data=scope, message="Data created successfully")
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{scope_id}", response_model=StandardResponse[ScopeRead])
|
||||||
|
async def update_scope(db_session: DbSession, scope_id: str, scope_in: ScopeUpdate, current_user: CurrentUser):
|
||||||
|
scope = await get(db_session=db_session, scope_id=scope_id)
|
||||||
|
|
||||||
|
if not scope:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="A data with this id does not exist.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return StandardResponse(data=await update(db_session=db_session, scope=scope, scope_in=scope_in), message="Data updated successfully")
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{scope_id}", response_model=StandardResponse[ScopeRead])
|
||||||
|
async def delete_scope(db_session: DbSession, scope_id: str):
|
||||||
|
scope = await get(db_session=db_session, scope_id=scope_id)
|
||||||
|
|
||||||
|
if not scope:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=[{"msg": "A data with this id does not exist."}],
|
||||||
|
)
|
||||||
|
|
||||||
|
await delete(db_session=db_session, scope_id=scope_id)
|
||||||
|
|
||||||
|
return StandardResponse(message="Data deleted successfully", data=scope)
|
||||||
@ -0,0 +1,29 @@
|
|||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from src.models import DefultBase, Pagination
|
||||||
|
|
||||||
|
|
||||||
|
class ScopeBase(DefultBase):
|
||||||
|
scope_name: Optional[str] = Field(None, title="Scope Name")
|
||||||
|
duration_oh: Optional[int] = Field(None, title="Duration OH")
|
||||||
|
crew: Optional[int] = Field(None, title="Crew")
|
||||||
|
|
||||||
|
|
||||||
|
class ScopeCreate(ScopeBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ScopeUpdate(ScopeBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ScopeRead(ScopeBase):
|
||||||
|
id: UUID
|
||||||
|
|
||||||
|
|
||||||
|
class ScopePagination(Pagination):
|
||||||
|
items: List[ScopeRead] = []
|
||||||
@ -0,0 +1,60 @@
|
|||||||
|
|
||||||
|
|
||||||
|
from sqlalchemy import Select, Delete
|
||||||
|
from .model import Scope
|
||||||
|
from .schema import ScopeCreate, ScopeUpdate
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from src.database.core import DbSession
|
||||||
|
from src.auth.service import CurrentUser
|
||||||
|
|
||||||
|
|
||||||
|
async def get(*, db_session: DbSession, scope_id: str) -> Optional[Scope]:
|
||||||
|
"""Returns a document based on the given document id."""
|
||||||
|
query = Select(Scope).filter(Scope.id == scope_id)
|
||||||
|
result = await db_session.execute(query)
|
||||||
|
return result.scalars().one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_all(*, db_session: DbSession):
|
||||||
|
"""Returns all documents."""
|
||||||
|
query = Select(Scope)
|
||||||
|
result = await db_session.execute(query)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
|
async def create(*, db_session: DbSession, scope_in: ScopeCreate):
|
||||||
|
"""Creates a new document."""
|
||||||
|
scope = Scope(**scope_in.model_dump())
|
||||||
|
db_session.add(scope)
|
||||||
|
await db_session.commit()
|
||||||
|
return scope
|
||||||
|
|
||||||
|
|
||||||
|
async def update(*, db_session: DbSession, scope: Scope, scope_in: ScopeUpdate):
|
||||||
|
"""Updates a document."""
|
||||||
|
data = scope_in.model_dump()
|
||||||
|
|
||||||
|
update_data = scope_in.model_dump(exclude_defaults=True)
|
||||||
|
|
||||||
|
for field in data:
|
||||||
|
if field in update_data:
|
||||||
|
setattr(scope, field, update_data[field])
|
||||||
|
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
return scope
|
||||||
|
|
||||||
|
|
||||||
|
async def delete(*, db_session: DbSession, scope_id: str):
|
||||||
|
"""Deletes a document."""
|
||||||
|
query = Delete(Scope).where(Scope.id == scope_id)
|
||||||
|
await db_session.execute(query)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_by_scope_name(*, db_session: DbSession, scope_name: str) -> Optional[Scope]:
|
||||||
|
"""Returns a document based on the given document id."""
|
||||||
|
query = Select(Scope).filter(Scope.scope_name == scope_name)
|
||||||
|
result = await db_session.execute(query)
|
||||||
|
return result.scalars().one_or_none()
|
||||||
@ -0,0 +1,14 @@
|
|||||||
|
|
||||||
|
from sqlalchemy import UUID, Column, Float, Integer, String, ForeignKey
|
||||||
|
from src.database.core import Base
|
||||||
|
from src.models import DefaultMixin, IdentityMixin, TimeStampMixin
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
|
||||||
|
class ScopeEquipment(Base, DefaultMixin):
|
||||||
|
__tablename__ = "oh_scope_equip"
|
||||||
|
|
||||||
|
assetnum = Column(String, nullable=True)
|
||||||
|
scope_id = Column(UUID(as_uuid=True), ForeignKey('oh_scope.id'), nullable=False)
|
||||||
|
|
||||||
|
scope = relationship("Scope", backref="scope_equipments", lazy="selectin")
|
||||||
@ -0,0 +1,81 @@
|
|||||||
|
|
||||||
|
from typing import List
|
||||||
|
from fastapi import APIRouter, HTTPException, status
|
||||||
|
from fastapi.params import Query
|
||||||
|
|
||||||
|
from .model import ScopeEquipment
|
||||||
|
from .schema import ScopeEquipmentCreate, ScopeEquipmentPagination, ScopeEquipmentRead, ScopeEquipmentUpdate
|
||||||
|
from .service import get, get_all, create, update, delete, get_by_scope_name, get_exculed_scope_name
|
||||||
|
|
||||||
|
from src.database.service import CommonParameters, search_filter_sort_paginate
|
||||||
|
from src.database.core import DbSession
|
||||||
|
from src.auth.service import CurrentUser
|
||||||
|
from src.models import StandardResponse
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=StandardResponse[ScopeEquipmentPagination])
|
||||||
|
async def get_scope_equipments(common: CommonParameters):
|
||||||
|
"""Get all scope pagination."""
|
||||||
|
# return
|
||||||
|
return StandardResponse(
|
||||||
|
data=await search_filter_sort_paginate(model=ScopeEquipment, **common),
|
||||||
|
message="Data retrieved successfully",
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("/scope/{scope_name}", response_model=StandardResponse[List[ScopeEquipmentRead]])
|
||||||
|
async def get_scope_name(db_session: DbSession, scope_name: str, exclude: bool = Query(False)):
|
||||||
|
if exclude:
|
||||||
|
return StandardResponse(data=await get_exculed_scope_name(db_session=db_session, scope_name=scope_name), message="Data retrieved successfully")
|
||||||
|
|
||||||
|
return StandardResponse(data=await get_by_scope_name(db_session=db_session, scope_name=scope_name), message="Data retrieved successfully")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{scope_equipment_id}", response_model=StandardResponse[ScopeEquipmentRead])
|
||||||
|
async def get_scope_equipment(db_session: DbSession, scope_equipment_id: str):
|
||||||
|
scope = await get(db_session=db_session, scope_equipment_id=scope_equipment_id)
|
||||||
|
if not scope:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="A data with this id does not exist.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return StandardResponse(data=scope, message="Data retrieved successfully")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=StandardResponse[ScopeEquipmentRead])
|
||||||
|
async def create_scope_equipment(db_session: DbSession, scope__equipment_in: ScopeEquipmentCreate):
|
||||||
|
scope = await create(db_session=db_session, scope__equipment_in=scope__equipment_in)
|
||||||
|
|
||||||
|
return StandardResponse(data=scope, message="Data created successfully")
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{scope_equipment_id}", response_model=StandardResponse[ScopeEquipmentRead])
|
||||||
|
async def update_scope_equipment(db_session: DbSession, scope_equipment_id: str, scope__equipment_in: ScopeEquipmentUpdate):
|
||||||
|
scope_equipment = await get(db_session=db_session, scope_equipment_id=scope_equipment_id)
|
||||||
|
|
||||||
|
if not scope_equipment:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="A data with this id does not exist.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return StandardResponse(data=await update(db_session=db_session, scope_equipment=scope_equipment, scope__equipment_in=scope__equipment_in), message="Data updated successfully")
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{scope_equipment_id}", response_model=StandardResponse[ScopeEquipmentRead])
|
||||||
|
async def delete_scope_equipment(db_session: DbSession, scope_equipment_id: str):
|
||||||
|
scope_equipment = await get(db_session=db_session, scope_equipment_id=scope_equipment_id)
|
||||||
|
|
||||||
|
if not scope_equipment:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=[{"msg": "A data with this id does not exist."}],
|
||||||
|
)
|
||||||
|
|
||||||
|
await delete(db_session=db_session, scope_equipment_id=scope_equipment_id)
|
||||||
|
|
||||||
|
return StandardResponse(message="Data deleted successfully", data=scope_equipment)
|
||||||
|
|
||||||
|
|
||||||
@ -0,0 +1,28 @@
|
|||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from src.models import DefultBase, Pagination
|
||||||
|
|
||||||
|
|
||||||
|
class ScopeEquipmentBase(DefultBase):
|
||||||
|
scope_id: Optional[UUID] = Field(None, title="Scope ID")
|
||||||
|
|
||||||
|
|
||||||
|
class ScopeEquipmentCreate(ScopeEquipmentBase):
|
||||||
|
assetnum: str
|
||||||
|
|
||||||
|
|
||||||
|
class ScopeEquipmentUpdate(ScopeEquipmentBase):
|
||||||
|
assetnum: Optional[str] = Field(None, title="Asset Number")
|
||||||
|
|
||||||
|
|
||||||
|
class ScopeEquipmentRead(ScopeEquipmentBase):
|
||||||
|
id: UUID
|
||||||
|
assetnum: str
|
||||||
|
|
||||||
|
|
||||||
|
class ScopeEquipmentPagination(Pagination):
|
||||||
|
items: List[ScopeEquipmentRead] = []
|
||||||
@ -0,0 +1,84 @@
|
|||||||
|
|
||||||
|
|
||||||
|
from sqlalchemy import Select, Delete
|
||||||
|
from .model import ScopeEquipment
|
||||||
|
from src.scope.service import get_by_scope_name as get_scope_by_name_service
|
||||||
|
from .schema import ScopeEquipmentCreate, ScopeEquipmentUpdate
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from src.database.core import DbSession
|
||||||
|
from src.auth.service import CurrentUser
|
||||||
|
|
||||||
|
|
||||||
|
async def get(*, db_session: DbSession, scope_equipment_id: str) -> Optional[ScopeEquipment]:
|
||||||
|
"""Returns a document based on the given document id."""
|
||||||
|
query = Select(ScopeEquipment).filter(
|
||||||
|
ScopeEquipment.id == scope_equipment_id)
|
||||||
|
result = await db_session.execute(query)
|
||||||
|
return result.scalars().one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_all(*, db_session: DbSession):
|
||||||
|
"""Returns all documents."""
|
||||||
|
query = Select(ScopeEquipment)
|
||||||
|
result = await db_session.execute(query)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
|
async def create(*, db_session: DbSession, scope_equipment_in: ScopeEquipmentCreate):
|
||||||
|
"""Creates a new document."""
|
||||||
|
scope_equipment = ScopeEquipment(**scope_equipment_in.model_dump())
|
||||||
|
db_session.add(scope_equipment)
|
||||||
|
await db_session.commit()
|
||||||
|
return scope_equipment
|
||||||
|
|
||||||
|
|
||||||
|
async def update(*, db_session: DbSession, scope_equipment: ScopeEquipment, scope_equipment_in: ScopeEquipmentUpdate):
|
||||||
|
"""Updates a document."""
|
||||||
|
data = scope_equipment_in.model_dump()
|
||||||
|
|
||||||
|
update_data = scope_equipment_in.model_dump(exclude_defaults=True)
|
||||||
|
|
||||||
|
for field in data:
|
||||||
|
if field in update_data:
|
||||||
|
setattr(scope_equipment, field, update_data[field])
|
||||||
|
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
return scope_equipment
|
||||||
|
|
||||||
|
|
||||||
|
async def delete(*, db_session: DbSession, scope_equipment_id: str):
|
||||||
|
"""Deletes a document."""
|
||||||
|
query = Delete(ScopeEquipment).where(
|
||||||
|
ScopeEquipment.id == scope_equipment_id)
|
||||||
|
await db_session.execute(query)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_by_scope_name(*, db_session: DbSession, scope_name: Union[str, list]) -> Optional[ScopeEquipment]:
|
||||||
|
"""Returns a document based on the given document id."""
|
||||||
|
scope = await get_scope_by_name_service(db_session=db_session, scope_name=scope_name)
|
||||||
|
|
||||||
|
query = Select(ScopeEquipment)
|
||||||
|
|
||||||
|
if scope:
|
||||||
|
query = query.filter(ScopeEquipment.scope_id == scope.id)
|
||||||
|
|
||||||
|
result = await db_session.execute(query)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_exculed_scope_name(*, db_session: DbSession, scope_name: Union[str, list]) -> Optional[ScopeEquipment]:
|
||||||
|
scope = await get_scope_by_name_service(db_session=db_session, scope_name=scope_name)
|
||||||
|
|
||||||
|
query = Select(ScopeEquipment)
|
||||||
|
|
||||||
|
if scope:
|
||||||
|
query = query.filter(ScopeEquipment.scope_id != scope.id)
|
||||||
|
|
||||||
|
else:
|
||||||
|
query = query.filter(ScopeEquipment.scope_id != None)
|
||||||
|
|
||||||
|
result = await db_session.execute(query)
|
||||||
|
return result.scalars().all()
|
||||||
@ -0,0 +1,40 @@
|
|||||||
|
from connectors.database import DBConfig
|
||||||
|
from starlette.config import Config
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def get_config():
|
||||||
|
try:
|
||||||
|
# Try to load from .env file first
|
||||||
|
config = Config(".env")
|
||||||
|
except FileNotFoundError:
|
||||||
|
# If .env doesn't exist, use environment variables
|
||||||
|
config = Config(environ=os.environ)
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
env = get_config()
|
||||||
|
|
||||||
|
config = {
|
||||||
|
'batch_size': 1000,
|
||||||
|
'target_table': 'oh_wo_master',
|
||||||
|
'columns': ['assetnum', 'worktype', 'workgroup', 'total_cost_max', 'created_at'],
|
||||||
|
}
|
||||||
|
|
||||||
|
target_config = DBConfig(
|
||||||
|
host=env("DATABASE_HOSTNAME"),
|
||||||
|
port=env("DATABASE_PORT"),
|
||||||
|
database=env("DATABASE_NAME"),
|
||||||
|
user=env("DATABASE_CREDENTIAL_USER"),
|
||||||
|
password=env("DATABASE_CREDENTIAL_PASSWORD")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
source_config = DBConfig(
|
||||||
|
host=env("COLLECTOR_HOSTNAME"),
|
||||||
|
port=env("COLLECTOR_PORT"),
|
||||||
|
database=env("COLLECTOR_NAME"),
|
||||||
|
user=env("COLLECTOR_CREDENTIAL_USER"),
|
||||||
|
password=env("COLLECTOR_CREDENTIAL_PASSWORD")
|
||||||
|
)
|
||||||
@ -0,0 +1,104 @@
|
|||||||
|
# src/connectors/database.py
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Dict, Any, Generator
|
||||||
|
import pandas as pd
|
||||||
|
from sqlalchemy import create_engine, text
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from starlette.config import Config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DBConfig:
|
||||||
|
host: str
|
||||||
|
port: int
|
||||||
|
database: str
|
||||||
|
user: str
|
||||||
|
password: str
|
||||||
|
|
||||||
|
def get_connection_string(self) -> str:
|
||||||
|
return f"postgresql+psycopg2://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseConnector:
|
||||||
|
def __init__(self, source_config: DBConfig, target_config: DBConfig):
|
||||||
|
self.source_engine = create_engine(
|
||||||
|
source_config.get_connection_string())
|
||||||
|
self.target_engine = create_engine(
|
||||||
|
target_config.get_connection_string())
|
||||||
|
|
||||||
|
def fetch_batch(
|
||||||
|
self,
|
||||||
|
batch_size: int,
|
||||||
|
last_id: int = 0,
|
||||||
|
columns: List[str] = None
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Fetch a batch of data from source database
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
query = """
|
||||||
|
SELECT {}
|
||||||
|
FROM dl_wo_staging
|
||||||
|
WHERE id > :last_id
|
||||||
|
AND worktype IN ('PM', 'CM', 'EM', 'PROACTIVE')
|
||||||
|
ORDER BY id
|
||||||
|
LIMIT :batch_size
|
||||||
|
""".format(', '.join(columns) if columns else '*')
|
||||||
|
|
||||||
|
# Execute query
|
||||||
|
params = {
|
||||||
|
"last_id": last_id,
|
||||||
|
"batch_size": batch_size
|
||||||
|
}
|
||||||
|
|
||||||
|
df = pd.read_sql(
|
||||||
|
text(query),
|
||||||
|
self.source_engine,
|
||||||
|
params=params
|
||||||
|
)
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching batch: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def load_batch(
|
||||||
|
self,
|
||||||
|
df: pd.DataFrame,
|
||||||
|
target_table: str,
|
||||||
|
if_exists: str = 'append'
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Load a batch of data to target database
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
df.to_sql(
|
||||||
|
target_table,
|
||||||
|
self.target_engine,
|
||||||
|
if_exists=if_exists,
|
||||||
|
index=False,
|
||||||
|
method='multi',
|
||||||
|
chunksize=1000
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading batch: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_total_records(self) -> int:
|
||||||
|
"""Get total number of records to migrate"""
|
||||||
|
try:
|
||||||
|
with self.source_engine.connect() as conn:
|
||||||
|
result = conn.execute(text("SELECT COUNT(*) FROM sensor_data"))
|
||||||
|
return result.scalar()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting total records: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
@ -0,0 +1,39 @@
|
|||||||
|
# src/scripts/run_migration.py
|
||||||
|
import asyncio
|
||||||
|
from temporalio.client import Client
|
||||||
|
from temporalio.worker import Worker
|
||||||
|
from workflows.historical_migration import DataMigrationWorkflow, fetch_data_activity, transform_data_activity, validate_data_activity, load_data_activity
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
# Create Temporal client
|
||||||
|
client = await Client.connect("192.168.1.82:7233")
|
||||||
|
|
||||||
|
# Create worker
|
||||||
|
worker = Worker(
|
||||||
|
client,
|
||||||
|
task_queue="migration-queue",
|
||||||
|
workflows=[DataMigrationWorkflow],
|
||||||
|
activities=[
|
||||||
|
fetch_data_activity,
|
||||||
|
transform_data_activity,
|
||||||
|
validate_data_activity,
|
||||||
|
load_data_activity
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start worker
|
||||||
|
await worker.run()
|
||||||
|
|
||||||
|
# # Start workflow
|
||||||
|
# handle = await client.start_workflow(
|
||||||
|
# DataMigrationWorkflow.run,
|
||||||
|
# id="data-migration",
|
||||||
|
# task_queue="migration-queue"
|
||||||
|
# )
|
||||||
|
|
||||||
|
# result = await handle.result()
|
||||||
|
# print(f"Migration completed: {result}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
@ -0,0 +1,20 @@
|
|||||||
|
from temporalio.client import Client
|
||||||
|
from workflows.historical_migration import DataMigrationWorkflow
|
||||||
|
|
||||||
|
|
||||||
|
async def run():
|
||||||
|
# Start workflow
|
||||||
|
|
||||||
|
client = await Client.connect("192.168.1.82:7233")
|
||||||
|
|
||||||
|
handle = await client.start_workflow(
|
||||||
|
DataMigrationWorkflow.run,
|
||||||
|
id="data-migration",
|
||||||
|
task_queue="migration-queue"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
asyncio.run(run())
|
||||||
@ -0,0 +1,48 @@
|
|||||||
|
|
||||||
|
|
||||||
|
# src/transformers/sensor_data.py
|
||||||
|
from typing import Dict, Any
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from datetime import datetime
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
from config import config
|
||||||
|
|
||||||
|
|
||||||
|
class WoDataTransformer:
|
||||||
|
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Transform sensor data according to business rules
|
||||||
|
"""
|
||||||
|
# Create a copy to avoid modifying original data
|
||||||
|
transformed = df.copy()
|
||||||
|
|
||||||
|
# 1. Add UUID
|
||||||
|
transformed['id'] = uuid4()
|
||||||
|
|
||||||
|
# # 5. Drop unnecessary columns
|
||||||
|
# columns_to_drop = self.config.get('columns_to_drop', [])
|
||||||
|
# if columns_to_drop:
|
||||||
|
# transformed = transformed.drop(columns=columns_to_drop, errors='ignore')
|
||||||
|
|
||||||
|
return transformed
|
||||||
|
|
||||||
|
def validate(self, df: pd.DataFrame) -> bool:
|
||||||
|
"""
|
||||||
|
Validate transformed data
|
||||||
|
"""
|
||||||
|
if df.empty:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check required columns
|
||||||
|
if not all(col in df.columns for col in config.get('columns')):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# check id column and id is UUID
|
||||||
|
if 'id' not in df.columns:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not all(isinstance(val, UUID) for val in df['id']):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
@ -0,0 +1,123 @@
|
|||||||
|
from datetime import timedelta
|
||||||
|
import pandas as pd
|
||||||
|
from connectors.database import DatabaseConnector
|
||||||
|
from temporalio import workflow, activity
|
||||||
|
from temporalio.common import RetryPolicy
|
||||||
|
from typing import Dict, List
|
||||||
|
from config import source_config, target_config, config
|
||||||
|
from transformation.wo_transform import WoDataTransformer
|
||||||
|
# Activities
|
||||||
|
|
||||||
|
|
||||||
|
@activity.defn
|
||||||
|
async def fetch_data_activity(batch_size: int, last_id: int, columns) -> Dict:
|
||||||
|
db_connector = DatabaseConnector(source_config, target_config)
|
||||||
|
df = db_connector.fetch_batch(batch_size, last_id, columns)
|
||||||
|
return df.to_dict(orient='records')
|
||||||
|
|
||||||
|
|
||||||
|
@activity.defn
|
||||||
|
async def transform_data_activity(data: List[Dict]) -> List[Dict]:
|
||||||
|
transformer = WoDataTransformer()
|
||||||
|
df = pd.DataFrame(data)
|
||||||
|
transformed_df = transformer.transform(df)
|
||||||
|
return transformed_df.to_dict(orient='records')
|
||||||
|
|
||||||
|
|
||||||
|
@activity.defn
|
||||||
|
async def validate_data_activity(data: List[Dict]) -> bool:
|
||||||
|
transformer = WoDataTransformer()
|
||||||
|
df = pd.DataFrame(data)
|
||||||
|
return transformer.validate(df)
|
||||||
|
|
||||||
|
|
||||||
|
@activity.defn
|
||||||
|
async def load_data_activity(data: List[Dict], target_table: str) -> bool:
|
||||||
|
db_connector = DatabaseConnector(source_config, target_config)
|
||||||
|
df = pd.DataFrame(data)
|
||||||
|
return db_connector.load_batch(df, target_table)
|
||||||
|
|
||||||
|
# Workflow
|
||||||
|
|
||||||
|
|
||||||
|
@workflow.defn
|
||||||
|
class DataMigrationWorkflow:
|
||||||
|
def __init__(self):
|
||||||
|
self._total_processed = 0
|
||||||
|
self._last_id = 0
|
||||||
|
|
||||||
|
@workflow.run
|
||||||
|
async def run(self) -> Dict:
|
||||||
|
retry_policy = RetryPolicy(
|
||||||
|
initial_interval=timedelta(seconds=1),
|
||||||
|
maximum_interval=timedelta(minutes=10),
|
||||||
|
maximum_attempts=3
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_size = config.get('batch_size', 1000)
|
||||||
|
target_table = config.get('target_table')
|
||||||
|
columns = config.get('columns')
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# 1. Fetch batch
|
||||||
|
data = await workflow.execute_activity(
|
||||||
|
fetch_data_activity,
|
||||||
|
args=[batch_size, self._last_id, columns],
|
||||||
|
retry_policy=retry_policy,
|
||||||
|
start_to_close_timeout=timedelta(minutes=5)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 2. Transform data
|
||||||
|
transformed_data = await workflow.execute_activity(
|
||||||
|
transform_data_activity,
|
||||||
|
args=[data, config],
|
||||||
|
retry_policy=retry_policy,
|
||||||
|
start_to_close_timeout=timedelta(minutes=10)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Validate data
|
||||||
|
is_valid = await workflow.execute_activity(
|
||||||
|
validate_data_activity,
|
||||||
|
args=[transformed_data],
|
||||||
|
retry_policy=retry_policy,
|
||||||
|
start_to_close_timeout=timedelta(minutes=5)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not is_valid:
|
||||||
|
raise ValueError(
|
||||||
|
f"Data validation failed for batch after ID {self._last_id}")
|
||||||
|
|
||||||
|
# 4. Load data
|
||||||
|
success = await workflow.execute_activity(
|
||||||
|
load_data_activity,
|
||||||
|
args=[transformed_data, target_table],
|
||||||
|
retry_policy=retry_policy,
|
||||||
|
start_to_close_timeout=timedelta(minutes=10)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
raise Exception(
|
||||||
|
f"Failed to load batch after ID {self._last_id}")
|
||||||
|
|
||||||
|
# Update progress
|
||||||
|
self._total_processed += len(data)
|
||||||
|
self._last_id = data[-1]['id']
|
||||||
|
|
||||||
|
# Record progress
|
||||||
|
# await workflow.execute_activity(
|
||||||
|
# record_progress_activity,
|
||||||
|
# args=[{
|
||||||
|
# 'last_id': self._last_id,
|
||||||
|
# 'total_processed': self._total_processed
|
||||||
|
# }],
|
||||||
|
# retry_policy=retry_policy,
|
||||||
|
# start_to_close_timeout=timedelta(minutes=1)
|
||||||
|
# )
|
||||||
|
|
||||||
|
return {
|
||||||
|
'total_processed': self._total_processed,
|
||||||
|
'last_id': self._last_id
|
||||||
|
}
|
||||||
@ -0,0 +1,69 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import AsyncGenerator, Generator
|
||||||
|
import pytest
|
||||||
|
from httpx import AsyncClient
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy_utils import drop_database, database_exists
|
||||||
|
from starlette.config import environ
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
# from src.database import Base, get_db
|
||||||
|
# from src.main import app
|
||||||
|
|
||||||
|
# Test database URL
|
||||||
|
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
|
||||||
|
|
||||||
|
engine = create_async_engine(
|
||||||
|
TEST_DATABASE_URL,
|
||||||
|
connect_args={"check_same_thread": False},
|
||||||
|
poolclass=StaticPool,
|
||||||
|
)
|
||||||
|
|
||||||
|
async_session = sessionmaker(
|
||||||
|
engine,
|
||||||
|
class_=AsyncSession,
|
||||||
|
expire_on_commit=False,
|
||||||
|
autocommit=False,
|
||||||
|
autoflush=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
async with async_session() as session:
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
await session.commit()
|
||||||
|
except Exception:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
|
||||||
|
app.dependency_overrides[get_db] = override_get_db
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def event_loop() -> Generator:
|
||||||
|
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||||
|
yield loop
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
async def setup_db() -> AsyncGenerator[None, None]:
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
yield
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.drop_all)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def client() -> AsyncGenerator[AsyncClient, None]:
|
||||||
|
async with AsyncClient(app=app, base_url="http://test") as client:
|
||||||
|
yield client
|
||||||
@ -0,0 +1,3 @@
|
|||||||
|
from sqlalchemy.orm import scoped_session, sessionmaker
|
||||||
|
|
||||||
|
Session = scoped_session(sessionmaker())
|
||||||
@ -0,0 +1,33 @@
|
|||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from factory import (
|
||||||
|
LazyAttribute,
|
||||||
|
LazyFunction,
|
||||||
|
Sequence,
|
||||||
|
SubFactory,
|
||||||
|
post_generation,
|
||||||
|
SelfAttribute,
|
||||||
|
)
|
||||||
|
from factory.alchemy import SQLAlchemyModelFactory
|
||||||
|
from factory.fuzzy import FuzzyChoice, FuzzyDateTime, FuzzyInteger, FuzzyText
|
||||||
|
from faker import Faker
|
||||||
|
from faker.providers import misc
|
||||||
|
# from pytz import UTC
|
||||||
|
|
||||||
|
|
||||||
|
from .database import Session
|
||||||
|
|
||||||
|
fake = Faker()
|
||||||
|
fake.add_provider(misc)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseFactory(SQLAlchemyModelFactory):
|
||||||
|
"""Base Factory."""
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
"""Factory configuration."""
|
||||||
|
|
||||||
|
abstract = True
|
||||||
|
sqlalchemy_session = Session
|
||||||
|
sqlalchemy_session_persistence = "commit"
|
||||||
Loading…
Reference in New Issue