initialCommit
commit
a8161e7edb
@ -0,0 +1,16 @@
|
||||
ENV=development
|
||||
LOG_LEVEL=ERROR
|
||||
PORT=3010
|
||||
HOST=0.0.0.0
|
||||
|
||||
DATABASE_HOSTNAME=192.168.1.82
|
||||
DATABASE_PORT=5432
|
||||
DATABASE_CREDENTIAL_USER=postgres
|
||||
DATABASE_CREDENTIAL_PASSWORD=postgres
|
||||
DATABASE_NAME=digital_twin
|
||||
|
||||
COLLECTOR_HOSTNAME=192.168.1.86
|
||||
COLLECTOR_PORT=5432
|
||||
COLLECTOR_CREDENTIAL_USER=postgres
|
||||
COLLECTOR_CREDENTIAL_PASSWORD=postgres
|
||||
COLLECTOR_NAME=digital_twin
|
||||
@ -0,0 +1,2 @@
|
||||
env/
|
||||
__pycache__/
|
||||
@ -0,0 +1,44 @@
|
||||
# Quick Start:
|
||||
#
|
||||
# pip install pre-commit
|
||||
# pre-commit install && pre-commit install -t pre-push
|
||||
# pre-commit run --all-files
|
||||
#
|
||||
# To Skip Checks:
|
||||
#
|
||||
# git commit --no-verify
|
||||
fail_fast: false
|
||||
|
||||
default_language_version:
|
||||
python: python3.11.2
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# ruff version.
|
||||
rev: v0.7.0
|
||||
hooks:
|
||||
# Run the linter.
|
||||
#
|
||||
# When running with --fix, Ruff's lint hook should be placed before Ruff's formatter hook,
|
||||
# and before Black, isort, and other formatting tools, as Ruff's fix behavior can output code changes that require reformatting.
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
# Run the formatter.
|
||||
- id: ruff-format
|
||||
|
||||
# Typos
|
||||
- repo: https://github.com/crate-ci/typos
|
||||
rev: v1.26.1
|
||||
hooks:
|
||||
- id: typos
|
||||
exclude: ^(data/dispatch-sample-data.dump|src/dispatch/static/dispatch/src/|src/dispatch/database/revisions/)
|
||||
|
||||
# Pytest
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: tests
|
||||
name: run tests
|
||||
entry: pytest -v tests/
|
||||
language: system
|
||||
types: [python]
|
||||
stages: [push]
|
||||
@ -0,0 +1,50 @@
|
||||
# Use the official Python 3.11 image from the Docker Hub
|
||||
FROM python:3.11-slim as builder
|
||||
|
||||
# Install Poetry
|
||||
RUN pip install poetry
|
||||
|
||||
# Set environment variables for Poetry
|
||||
ENV POETRY_NO_INTERACTION=1 \
|
||||
POETRY_VIRTUALENVS_IN_PROJECT=1 \
|
||||
POETRY_VIRTUALENVS_CREATE=1 \
|
||||
POETRY_CACHE_DIR=/tmp/poetry_cache
|
||||
|
||||
# Set the working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Copy the Poetry configuration files
|
||||
COPY pyproject.toml poetry.lock ./
|
||||
|
||||
# Install dependencies
|
||||
RUN poetry install --no-dev --no-root
|
||||
|
||||
# Use a new slim image for the runtime
|
||||
FROM python:3.11-slim as runtime
|
||||
|
||||
# Install necessary tools for running the app, including `make`
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
make \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Set environment variables for Poetry
|
||||
ENV POETRY_VIRTUALENVS_IN_PROJECT=1 \
|
||||
PATH="/app/.venv/bin:$PATH"
|
||||
|
||||
# Copy Poetry installation from builder
|
||||
COPY --from=builder /app/.venv /app/.venv
|
||||
|
||||
# Copy application files
|
||||
COPY . /app/
|
||||
|
||||
# Delete Tests for production
|
||||
RUN rm -rf /app/tests/
|
||||
|
||||
# Expose port for the application
|
||||
EXPOSE 3005
|
||||
|
||||
# Set the working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Run `make run` as the entry point
|
||||
CMD ["make", "run"]
|
||||
@ -0,0 +1,14 @@
|
||||
# Define variables
|
||||
POETRY = poetry
|
||||
PYTHON = $(POETRY) run python
|
||||
APP = src/server.py
|
||||
|
||||
# Targets and their rules
|
||||
|
||||
# Install dependencies
|
||||
install:
|
||||
$(POETRY) install
|
||||
|
||||
# Run the application
|
||||
run:
|
||||
python run.py
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,32 @@
|
||||
[tool.poetry]
|
||||
name = "optimumohservice"
|
||||
version = "0.1.0"
|
||||
description = ""
|
||||
authors = ["Cizz22 <cisatraa@gmail.com>"]
|
||||
license = "MIT"
|
||||
readme = "README.md"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.11"
|
||||
fastapi = {extras = ["standard"], version = "^0.115.4"}
|
||||
sqlalchemy = "^2.0.36"
|
||||
httpx = "^0.27.2"
|
||||
pytest = "^8.3.3"
|
||||
faker = "^30.8.2"
|
||||
factory-boy = "^3.3.1"
|
||||
sqlalchemy-utils = "^0.41.2"
|
||||
slowapi = "^0.1.9"
|
||||
uvicorn = "^0.32.0"
|
||||
pytz = "^2024.2"
|
||||
sqlalchemy-filters = "^0.13.0"
|
||||
asyncpg = "^0.30.0"
|
||||
requests = "^2.32.3"
|
||||
pydantic = "^2.10.2"
|
||||
temporalio = "^1.8.0"
|
||||
pandas = "^2.2.3"
|
||||
psycopg2-binary = "^2.9.10"
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
@ -0,0 +1,10 @@
|
||||
import uvicorn
|
||||
from src.config import PORT, HOST
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(
|
||||
"src.main:app",
|
||||
host=HOST,
|
||||
port=PORT,
|
||||
reload=True
|
||||
)
|
||||
@ -0,0 +1,71 @@
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
from src.auth.service import JWTBearer
|
||||
|
||||
|
||||
from src.scope.router import router as scope_router
|
||||
from src.scope_equipment.router import router as scope_equipment_router
|
||||
from src.overhaul.router import router as overhaul_router
|
||||
from src.calculation_time_constrains.router import router as calculation_time_constrains_router
|
||||
|
||||
|
||||
class ErrorMessage(BaseModel):
|
||||
msg: str
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
detail: Optional[List[ErrorMessage]]
|
||||
|
||||
|
||||
api_router = APIRouter(
|
||||
default_response_class=JSONResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse},
|
||||
401: {"model": ErrorResponse},
|
||||
403: {"model": ErrorResponse},
|
||||
404: {"model": ErrorResponse},
|
||||
500: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@api_router.get("/healthcheck", include_in_schema=False)
|
||||
def healthcheck():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
authenticated_api_router = APIRouter(dependencies=[Depends(JWTBearer())],
|
||||
)
|
||||
# overhaul data
|
||||
authenticated_api_router.include_router(
|
||||
overhaul_router, prefix="/overhauls", tags=["overhaul"])
|
||||
|
||||
|
||||
# Scope data
|
||||
authenticated_api_router.include_router(
|
||||
scope_router, prefix="/scopes", tags=["scope"])
|
||||
|
||||
authenticated_api_router.include_router(
|
||||
scope_equipment_router, prefix="/scope-equipments", tags=["scope_equipment"]
|
||||
)
|
||||
|
||||
# calculation
|
||||
calculation_router = APIRouter(prefix="/calculation", tags=["calculations"])
|
||||
|
||||
# Time constrains
|
||||
calculation_router.include_router(
|
||||
calculation_time_constrains_router, prefix="/time-constraint", tags=["time_constraint"])
|
||||
|
||||
# Target reliability
|
||||
|
||||
# Budget Constrain
|
||||
|
||||
authenticated_api_router.include_router(
|
||||
calculation_router
|
||||
)
|
||||
|
||||
api_router.include_router(authenticated_api_router)
|
||||
@ -0,0 +1,9 @@
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class UserBase(BaseModel):
|
||||
name: str
|
||||
role: str
|
||||
user_id: str
|
||||
@ -0,0 +1,55 @@
|
||||
# app/auth/auth_bearer.py
|
||||
|
||||
from typing import Annotated, Optional
|
||||
from fastapi import Depends, Request, HTTPException
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
import requests
|
||||
import src.config as config
|
||||
from .model import UserBase
|
||||
|
||||
|
||||
class JWTBearer(HTTPBearer):
|
||||
def __init__(self, auto_error: bool = True):
|
||||
super(JWTBearer, self).__init__(auto_error=auto_error)
|
||||
|
||||
async def __call__(self, request: Request):
|
||||
credentials: HTTPAuthorizationCredentials = await super(JWTBearer, self).__call__(request)
|
||||
if credentials:
|
||||
if not credentials.scheme == "Bearer":
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Invalid authentication scheme.")
|
||||
user_info = self.verify_jwt(credentials.credentials)
|
||||
if not user_info:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Invalid token or expired token.")
|
||||
|
||||
request.state.user = user_info
|
||||
return user_info
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Invalid authorization code.")
|
||||
|
||||
def verify_jwt(self, jwtoken: str) -> Optional[UserBase]:
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{config.AUTH_SERVICE_API}/verify-token",
|
||||
headers={"Authorization": f"Bearer {jwtoken}"},
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
return None
|
||||
|
||||
user_data = response.json()
|
||||
return UserBase(**user_data['data'])
|
||||
|
||||
except Exception as e:
|
||||
print(f"Token verification error: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
# Create dependency to get current user from request state
|
||||
async def get_current_user(request: Request) -> UserBase:
|
||||
return request.state.user
|
||||
|
||||
|
||||
CurrentUser = Annotated[UserBase, Depends(get_current_user)]
|
||||
@ -0,0 +1,96 @@
|
||||
|
||||
from typing import Any, Dict
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from .schema import CalculationTimeConstrainsParametersRead, CalculationTimeConstrainsRead, CalculationTimeConstrainsCreate
|
||||
from src.database.core import DbSession
|
||||
from src.models import StandardResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/parameters", response_model=StandardResponse[CalculationTimeConstrainsParametersRead])
|
||||
async def get_calculation_parameters():
|
||||
"""Get all calculation parameter pagination."""
|
||||
|
||||
# {
|
||||
# "costPerFailure": 733.614,
|
||||
# "availableScopes": ["A", "B"],
|
||||
# "recommendedScope": "B",
|
||||
# "historicalData": {
|
||||
# "averageOverhaulCost": 10000000,
|
||||
# "lastCalculation": {
|
||||
# "id": "calc_122",
|
||||
# "date": "2024-10-15",
|
||||
# "scope": "B"
|
||||
# }
|
||||
# }
|
||||
# }
|
||||
|
||||
return StandardResponse(
|
||||
data=CalculationTimeConstrainsParametersRead(
|
||||
costPerFailure=733.614,
|
||||
availableScopes=["A", "B"],
|
||||
recommendedScope="B",
|
||||
historicalData={
|
||||
"averageOverhaulCost": 10000000,
|
||||
"lastCalculation": {
|
||||
"id": "calc_122",
|
||||
"date": "2024-10-15",
|
||||
"scope": "B",
|
||||
},
|
||||
},
|
||||
),
|
||||
message="Data retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
@router.post("", response_model=StandardResponse[CalculationTimeConstrainsRead])
|
||||
async def create_calculation_time_constrains(db_session: DbSession, calculation_time_constrains_in: CalculationTimeConstrainsCreate):
|
||||
"""Calculate Here"""
|
||||
calculation_result = {
|
||||
"id": "calc_123",
|
||||
"result": {
|
||||
"summary": {
|
||||
"scope": "B",
|
||||
"numberOfFailures": 59,
|
||||
"optimumOHTime": 90,
|
||||
"optimumTotalCost": 500000000
|
||||
},
|
||||
"chartData": {},
|
||||
"comparisons": {
|
||||
"vsLastCalculation": {
|
||||
"costDifference": -50000000,
|
||||
"timeChange": "+15 days"
|
||||
}
|
||||
}
|
||||
},
|
||||
"simulationLimits": {
|
||||
"minInterval": 30,
|
||||
"maxInterval": 180
|
||||
}
|
||||
}
|
||||
|
||||
return StandardResponse(data=CalculationTimeConstrainsRead(
|
||||
id=calculation_result["id"],
|
||||
result=calculation_result["result"],
|
||||
), message="Data created successfully")
|
||||
|
||||
|
||||
@router.get("/simulation", response_model=StandardResponse[Dict[str, Any]])
|
||||
async def get_simulation_result():
|
||||
|
||||
results = {
|
||||
"simulation": {
|
||||
"intervalDays": 45,
|
||||
"numberOfFailures": 75,
|
||||
"totalCost": 550000000,
|
||||
},
|
||||
"comparison": {
|
||||
"vsOptimal": {
|
||||
"failureDifference": 16,
|
||||
"costDifference": 50000000
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return StandardResponse(data=results, message="Data retrieved successfully")
|
||||
@ -0,0 +1,78 @@
|
||||
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import Field
|
||||
from src.models import DefultBase
|
||||
|
||||
# {
|
||||
# "costPerFailure": 733.614,
|
||||
# "availableScopes": ["A", "B"],
|
||||
# "recommendedScope": "B",
|
||||
# "historicalData": {
|
||||
# "averageOverhaulCost": 10000000,
|
||||
# "lastCalculation": {
|
||||
# "id": "calc_122",
|
||||
# "date": "2024-10-15",
|
||||
# "scope": "B"
|
||||
# }
|
||||
# }
|
||||
# }
|
||||
|
||||
|
||||
class CalculationTimeConstrainsBase(DefultBase):
|
||||
pass
|
||||
|
||||
|
||||
class CalculationTimeConstrainsParametersRead(CalculationTimeConstrainsBase):
|
||||
costPerFailure: float = Field(..., description="Cost per failure")
|
||||
availableScopes: List[str] = Field(..., description="Available scopes")
|
||||
recommendedScope: str = Field(..., description="Recommended scope")
|
||||
historicalData: Dict[str, Any] = Field(..., description="Historical data")
|
||||
|
||||
|
||||
# {
|
||||
# "overhaulCost": 10000000,
|
||||
# "scopeOH": "B",
|
||||
# "costPerFailure": 733.614,
|
||||
# "metadata": {
|
||||
# "unit": "PLTU1",
|
||||
# "calculatedBy": "user123",
|
||||
# "timestamp": "2024-11-28T10:00:00Z"
|
||||
# }
|
||||
# }
|
||||
|
||||
class CalculationTimeConstrainsCreate(CalculationTimeConstrainsBase):
|
||||
overhaulCost: float = Field(..., description="Overhaul cost")
|
||||
scopeOH: str = Field(..., description="Scope OH")
|
||||
costPerFailure: float = Field(..., description="Cost per failure")
|
||||
metadata: Dict[str, Any] = Field(..., description="Metadata")
|
||||
|
||||
# {
|
||||
# "calculationId": "calc_123",
|
||||
# "result": {
|
||||
# "summary": {
|
||||
# "scope": "B",
|
||||
# "numberOfFailures": 59,
|
||||
# "optimumOHTime": 90,
|
||||
# "optimumTotalCost": 500000000
|
||||
# },
|
||||
# "chartData": {/* ... */},
|
||||
# "comparisons": {
|
||||
# "vsLastCalculation": {
|
||||
# "costDifference": -50000000,
|
||||
# "timeChange": "+15 days"
|
||||
# }
|
||||
# }
|
||||
# },
|
||||
# "simulationLimits": {
|
||||
# "minInterval": 30,
|
||||
# "maxInterval": 180
|
||||
# }
|
||||
# }
|
||||
|
||||
class CalculationTimeConstrainsRead(CalculationTimeConstrainsBase):
|
||||
id: Union[UUID, str]
|
||||
result: Dict[str, Any]
|
||||
@ -0,0 +1,74 @@
|
||||
import logging
|
||||
import os
|
||||
import base64
|
||||
from urllib import parse
|
||||
from typing import List
|
||||
from pydantic import BaseModel
|
||||
|
||||
from starlette.config import Config
|
||||
from starlette.datastructures import CommaSeparatedStrings
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseConfigurationModel(BaseModel):
|
||||
"""Base configuration model used by all config options."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def get_env_tags(tag_list: List[str]) -> dict:
|
||||
"""Create dictionary of available env tags."""
|
||||
tags = {}
|
||||
for t in tag_list:
|
||||
tag_key, env_key = t.split(":")
|
||||
|
||||
env_value = os.environ.get(env_key)
|
||||
|
||||
if env_value:
|
||||
tags.update({tag_key: env_value})
|
||||
|
||||
return tags
|
||||
|
||||
def get_config():
|
||||
try:
|
||||
# Try to load from .env file first
|
||||
config = Config(".env")
|
||||
except FileNotFoundError:
|
||||
# If .env doesn't exist, use environment variables
|
||||
config = Config(environ=os.environ)
|
||||
|
||||
return config
|
||||
|
||||
config = get_config()
|
||||
|
||||
|
||||
LOG_LEVEL = config("LOG_LEVEL", default=logging.WARNING)
|
||||
ENV = config("ENV", default="local")
|
||||
PORT = config("PORT", cast=int, default=8000)
|
||||
HOST = config("HOST", default="localhost")
|
||||
|
||||
|
||||
# database
|
||||
DATABASE_HOSTNAME = config("DATABASE_HOSTNAME")
|
||||
_DATABASE_CREDENTIAL_USER = config("DATABASE_CREDENTIAL_USER")
|
||||
_DATABASE_CREDENTIAL_PASSWORD = config("DATABASE_CREDENTIAL_PASSWORD")
|
||||
_QUOTED_DATABASE_PASSWORD = parse.quote(str(_DATABASE_CREDENTIAL_PASSWORD))
|
||||
DATABASE_NAME = config("DATABASE_NAME", default="digital_twin")
|
||||
DATABASE_PORT = config("DATABASE_PORT", default="5432")
|
||||
|
||||
DATABASE_ENGINE_POOL_SIZE = config(
|
||||
"DATABASE_ENGINE_POOL_SIZE", cast=int, default=20)
|
||||
DATABASE_ENGINE_MAX_OVERFLOW = config(
|
||||
"DATABASE_ENGINE_MAX_OVERFLOW", cast=int, default=0)
|
||||
# Deal with DB disconnects
|
||||
# https://docs.sqlalchemy.org/en/20/core/pooling.html#pool-disconnects
|
||||
DATABASE_ENGINE_POOL_PING = config("DATABASE_ENGINE_POOL_PING", default=False)
|
||||
SQLALCHEMY_DATABASE_URI = f"postgresql+asyncpg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}"
|
||||
|
||||
TIMEZONE = "Asia/Jakarta"
|
||||
|
||||
|
||||
AUTH_SERVICE_API = config(
|
||||
"AUTH_SERVICE_API", default="http://192.168.1.82:8000/auth")
|
||||
@ -0,0 +1,156 @@
|
||||
# src/database.py
|
||||
from starlette.requests import Request
|
||||
from sqlalchemy_utils import get_mapper
|
||||
from sqlalchemy.sql.expression import true
|
||||
from sqlalchemy.orm import object_session, sessionmaker, Session
|
||||
from sqlalchemy.ext.declarative import declarative_base, declared_attr
|
||||
from sqlalchemy import create_engine, inspect
|
||||
from pydantic import BaseModel
|
||||
from fastapi import Depends
|
||||
from typing import Annotated, Any
|
||||
from contextlib import contextmanager
|
||||
import re
|
||||
import functools
|
||||
from typing import AsyncGenerator
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, sessionmaker
|
||||
|
||||
from src.config import SQLALCHEMY_DATABASE_URI
|
||||
|
||||
engine = create_async_engine(
|
||||
SQLALCHEMY_DATABASE_URI,
|
||||
echo=False,
|
||||
future=True
|
||||
)
|
||||
|
||||
async_session = sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
|
||||
def get_db(request: Request):
|
||||
return request.state.db
|
||||
|
||||
|
||||
DbSession = Annotated[AsyncSession, Depends(get_db)]
|
||||
|
||||
|
||||
class CustomBase:
|
||||
__repr_attrs__ = []
|
||||
__repr_max_length__ = 15
|
||||
|
||||
@declared_attr
|
||||
def __tablename__(self):
|
||||
return resolve_table_name(self.__name__)
|
||||
|
||||
def dict(self):
|
||||
"""Returns a dict representation of a model."""
|
||||
return {c.name: getattr(self, c.name) for c in self.__table__.columns}
|
||||
|
||||
@property
|
||||
def _id_str(self):
|
||||
ids = inspect(self).identity
|
||||
if ids:
|
||||
return "-".join([str(x) for x in ids]) if len(ids) > 1 else str(ids[0])
|
||||
else:
|
||||
return "None"
|
||||
|
||||
@property
|
||||
def _repr_attrs_str(self):
|
||||
max_length = self.__repr_max_length__
|
||||
|
||||
values = []
|
||||
single = len(self.__repr_attrs__) == 1
|
||||
for key in self.__repr_attrs__:
|
||||
if not hasattr(self, key):
|
||||
raise KeyError(
|
||||
"{} has incorrect attribute '{}' in " "__repr__attrs__".format(
|
||||
self.__class__, key
|
||||
)
|
||||
)
|
||||
value = getattr(self, key)
|
||||
wrap_in_quote = isinstance(value, str)
|
||||
|
||||
value = str(value)
|
||||
if len(value) > max_length:
|
||||
value = value[:max_length] + "..."
|
||||
|
||||
if wrap_in_quote:
|
||||
value = "'{}'".format(value)
|
||||
values.append(value if single else "{}:{}".format(key, value))
|
||||
|
||||
return " ".join(values)
|
||||
|
||||
def __repr__(self):
|
||||
# get id like '#123'
|
||||
id_str = ("#" + self._id_str) if self._id_str else ""
|
||||
# join class name, id and repr_attrs
|
||||
return "<{} {}{}>".format(
|
||||
self.__class__.__name__,
|
||||
id_str,
|
||||
" " + self._repr_attrs_str if self._repr_attrs_str else "",
|
||||
)
|
||||
|
||||
|
||||
Base = declarative_base(cls=CustomBase)
|
||||
# make_searchable(Base.metadata)
|
||||
|
||||
|
||||
@contextmanager
|
||||
async def get_session():
|
||||
"""Context manager to ensure the session is closed after use."""
|
||||
session = async_session()
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
def resolve_table_name(name):
|
||||
"""Resolves table names to their mapped names."""
|
||||
names = re.split("(?=[A-Z])", name) # noqa
|
||||
return "_".join([x.lower() for x in names if x])
|
||||
|
||||
|
||||
raise_attribute_error = object()
|
||||
|
||||
|
||||
# def resolve_attr(obj, attr, default=None):
|
||||
# """Attempts to access attr via dotted notation, returns none if attr does not exist."""
|
||||
# try:
|
||||
# return functools.reduce(getattr, attr.split("."), obj)
|
||||
# except AttributeError:
|
||||
# return default
|
||||
|
||||
|
||||
# def get_model_name_by_tablename(table_fullname: str) -> str:
|
||||
# """Returns the model name of a given table."""
|
||||
# return get_class_by_tablename(table_fullname=table_fullname).__name__
|
||||
|
||||
|
||||
def get_class_by_tablename(table_fullname: str) -> Any:
|
||||
"""Return class reference mapped to table."""
|
||||
|
||||
def _find_class(name):
|
||||
for c in Base._decl_class_registry.values():
|
||||
if hasattr(c, "__table__"):
|
||||
if c.__table__.fullname.lower() == name.lower():
|
||||
return c
|
||||
|
||||
mapped_name = resolve_table_name(table_fullname)
|
||||
mapped_class = _find_class(mapped_name)
|
||||
|
||||
return mapped_class
|
||||
|
||||
|
||||
# def get_table_name_by_class_instance(class_instance: Base) -> str:
|
||||
# """Returns the name of the table for a given class instance."""
|
||||
# return class_instance._sa_instance_state.mapper.mapped_table.name
|
||||
@ -0,0 +1,133 @@
|
||||
|
||||
import logging
|
||||
from typing import Annotated, List
|
||||
|
||||
from sqlalchemy import desc, func, or_, Select
|
||||
from sqlalchemy_filters import apply_pagination
|
||||
from sqlalchemy.exc import ProgrammingError
|
||||
from .core import DbSession
|
||||
|
||||
|
||||
from fastapi import Query, Depends
|
||||
from pydantic.types import Json, constr
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# allows only printable characters
|
||||
QueryStr = constr(pattern=r"^[ -~]+$", min_length=1)
|
||||
|
||||
|
||||
def common_parameters(
|
||||
db_session: DbSession, # type: ignore
|
||||
current_user: QueryStr = Query(None, alias="currentUser"), # type: ignore
|
||||
page: int = Query(1, gt=0, lt=2147483647),
|
||||
items_per_page: int = Query(5, alias="itemsPerPage", gt=-2, lt=2147483647),
|
||||
query_str: QueryStr = Query(None, alias="q"), # type: ignore
|
||||
filter_spec: QueryStr = Query(None, alias="filter"), # type: ignore
|
||||
sort_by: List[str] = Query([], alias="sortBy[]"),
|
||||
descending: List[bool] = Query([], alias="descending[]"),
|
||||
exclude: List[str] = Query([], alias="exclude[]"),
|
||||
# role: QueryStr = Depends(get_current_role),
|
||||
):
|
||||
return {
|
||||
"db_session": db_session,
|
||||
"page": page,
|
||||
"items_per_page": items_per_page,
|
||||
"query_str": query_str,
|
||||
"filter_spec": filter_spec,
|
||||
"sort_by": sort_by,
|
||||
"descending": descending,
|
||||
"current_user": current_user,
|
||||
# "role": role,
|
||||
}
|
||||
|
||||
|
||||
CommonParameters = Annotated[
|
||||
dict[str, int | str | DbSession | QueryStr |
|
||||
Json | List[str] | List[bool]],
|
||||
Depends(common_parameters),
|
||||
]
|
||||
|
||||
|
||||
def search(*, query_str: str, query: Query, model, sort=False):
|
||||
"""Perform a search based on the query."""
|
||||
search_model = model
|
||||
|
||||
if not query_str.strip():
|
||||
return query
|
||||
|
||||
search = []
|
||||
if hasattr(search_model, "search_vector"):
|
||||
vector = search_model.search_vector
|
||||
search.append(vector.op("@@")(func.tsq_parse(query_str)))
|
||||
|
||||
if hasattr(search_model, "name"):
|
||||
search.append(
|
||||
search_model.name.ilike(f"%{query_str}%"),
|
||||
)
|
||||
search.append(search_model.name == query_str)
|
||||
|
||||
if not search:
|
||||
raise Exception(f"Search not supported for model: {model}")
|
||||
|
||||
query = query.filter(or_(*search))
|
||||
|
||||
if sort:
|
||||
query = query.order_by(
|
||||
desc(func.ts_rank_cd(vector, func.tsq_parse(query_str))))
|
||||
|
||||
return query.params(term=query_str)
|
||||
|
||||
|
||||
async def search_filter_sort_paginate(
|
||||
db_session: DbSession,
|
||||
model,
|
||||
query_str: str = None,
|
||||
filter_spec: str | dict | None = None,
|
||||
page: int = 1,
|
||||
items_per_page: int = 5,
|
||||
sort_by: List[str] = None,
|
||||
descending: List[bool] = None,
|
||||
current_user: str = None,
|
||||
exclude: List[str] = None,
|
||||
):
|
||||
"""Common functionality for searching, filtering, sorting, and pagination."""
|
||||
# try:
|
||||
query = Select(model)
|
||||
|
||||
if query_str:
|
||||
sort = False if sort_by else True
|
||||
query = search(query_str=query_str, query=query,
|
||||
model=model, sort=sort)
|
||||
|
||||
# Get total count
|
||||
count_query = Select(func.count()).select_from(query.subquery())
|
||||
total = await db_session.scalar(count_query)
|
||||
|
||||
query = (
|
||||
query
|
||||
.offset((page - 1) * items_per_page)
|
||||
.limit(items_per_page)
|
||||
)
|
||||
|
||||
result = await db_session.execute(query)
|
||||
items = result.scalars().all()
|
||||
|
||||
# try:
|
||||
# query, pagination = apply_pagination(
|
||||
# query=query, page_number=page, page_size=items_per_page)
|
||||
# except ProgrammingError as e:
|
||||
# log.debug(e)
|
||||
# return {
|
||||
# "items": [],
|
||||
# "itemsPerPage": items_per_page,
|
||||
# "page": page,
|
||||
# "total": 0,
|
||||
# }
|
||||
|
||||
return {
|
||||
"items": items,
|
||||
"itemsPerPage": items_per_page,
|
||||
"page": page,
|
||||
"total": total,
|
||||
}
|
||||
@ -0,0 +1,24 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class OptimumOHEnum(StrEnum):
|
||||
"""
|
||||
A custom Enum class that extends StrEnum.
|
||||
|
||||
This class inherits all functionality from StrEnum, including
|
||||
string representation and automatic value conversion to strings.
|
||||
|
||||
Example:
|
||||
class Visibility(DispatchEnum):
|
||||
OPEN = "Open"
|
||||
RESTRICTED = "Restricted"
|
||||
|
||||
assert str(Visibility.OPEN) == "Open"
|
||||
"""
|
||||
|
||||
pass # No additional implementation needed
|
||||
|
||||
|
||||
class ResponseStatus(OptimumOHEnum):
|
||||
SUCCESS = "success"
|
||||
ERROR = "error"
|
||||
@ -0,0 +1,164 @@
|
||||
# Define base error model
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from src.enums import ResponseStatus
|
||||
from slowapi import _rate_limit_exceeded_handler
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
|
||||
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, DataError, DBAPIError
|
||||
from asyncpg.exceptions import DataError as AsyncPGDataError, PostgresError
|
||||
|
||||
|
||||
class ErrorDetail(BaseModel):
|
||||
field: Optional[str] = None
|
||||
message: str
|
||||
code: Optional[str] = None
|
||||
params: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
data: Optional[Any] = None
|
||||
message: str
|
||||
status: ResponseStatus = ResponseStatus.ERROR
|
||||
errors: Optional[List[ErrorDetail]] = None
|
||||
|
||||
# Custom exception handler setup
|
||||
|
||||
|
||||
def get_request_context(request: Request):
|
||||
"""
|
||||
Get detailed request context for logging.
|
||||
"""
|
||||
|
||||
def get_client_ip():
|
||||
"""
|
||||
Get the real client IP address from Kong Gateway headers.
|
||||
Kong sets X-Real-IP and X-Forwarded-For headers by default.
|
||||
"""
|
||||
# Kong specific headers
|
||||
if "X-Real-IP" in request.headers:
|
||||
return request.headers["X-Real-IP"]
|
||||
|
||||
# Fallback to X-Forwarded-For
|
||||
if "X-Forwarded-For" in request.headers:
|
||||
# Get the first IP (original client)
|
||||
return request.headers["X-Forwarded-For"].split(",")[0].strip()
|
||||
|
||||
# Last resort
|
||||
return request.client.host
|
||||
|
||||
return {
|
||||
"endpoint": request.url.path,
|
||||
"url": request.url,
|
||||
"method": request.method,
|
||||
"remote_addr": get_client_ip(),
|
||||
}
|
||||
|
||||
|
||||
def handle_sqlalchemy_error(error: SQLAlchemyError):
|
||||
"""
|
||||
Handle SQLAlchemy errors and return user-friendly error messages.
|
||||
"""
|
||||
original_error = getattr(error, 'orig', None)
|
||||
print(original_error)
|
||||
|
||||
if isinstance(error, IntegrityError):
|
||||
if "unique constraint" in str(error).lower():
|
||||
return "This record already exists.", 409
|
||||
elif "foreign key constraint" in str(error).lower():
|
||||
return "Related record not found.", 400
|
||||
else:
|
||||
return "Data integrity error.", 400
|
||||
elif isinstance(error, DataError) or isinstance(original_error, AsyncPGDataError):
|
||||
return "Invalid data provided.", 400
|
||||
elif isinstance(error, DBAPIError):
|
||||
if "unique constraint" in str(error).lower():
|
||||
return "This record already exists.", 409
|
||||
elif "foreign key constraint" in str(error).lower():
|
||||
return "Related record not found.", 400
|
||||
elif "null value in column" in str(error).lower():
|
||||
return "Required data missing.", 400
|
||||
elif "invalid input for query argument" in str(error).lower():
|
||||
return "Invalid data provided.", 400
|
||||
else:
|
||||
return "Database error.", 500
|
||||
else:
|
||||
# Log the full error for debugging purposes
|
||||
logging.error(f"Unexpected database error: {str(error)}")
|
||||
return "An unexpected database error occurred.", 500
|
||||
|
||||
|
||||
def handle_exception(request: Request, exc: Exception):
|
||||
"""
|
||||
Global exception handler for Fastapi application.
|
||||
"""
|
||||
request_info = get_request_context(request)
|
||||
|
||||
if isinstance(exc, RateLimitExceeded):
|
||||
_rate_limit_exceeded_handler(request, exc)
|
||||
if isinstance(exc, HTTPException):
|
||||
logging.error(
|
||||
f"HTTP exception | Code: {exc.status_code} | Error: {exc.detail} | Request: {request_info}",
|
||||
extra={"error_category": "http"},
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={
|
||||
"data": None,
|
||||
"message": str(exc.detail),
|
||||
"status": ResponseStatus.ERROR,
|
||||
"errors": [
|
||||
ErrorDetail(
|
||||
message=str(exc.detail)
|
||||
).model_dump()
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
if isinstance(exc, SQLAlchemyError):
|
||||
error_message, status_code = handle_sqlalchemy_error(exc)
|
||||
logging.error(
|
||||
f"Database Error | Error: {str(error_message)} | Request: {request_info}",
|
||||
extra={"error_category": "database"},
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content={
|
||||
"data": None,
|
||||
"message": error_message,
|
||||
"status": ResponseStatus.ERROR,
|
||||
"errors": [
|
||||
ErrorDetail(
|
||||
message=error_message
|
||||
).model_dump()
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
# Log unexpected errors
|
||||
logging.error(
|
||||
f"Unexpected Error | Error: {str(exc)} | Request: {request_info}",
|
||||
extra={"error_category": "unexpected"},
|
||||
)
|
||||
|
||||
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"data": None,
|
||||
"message": exc.__class__.__name__,
|
||||
"status": ResponseStatus.ERROR,
|
||||
"errors": [
|
||||
ErrorDetail(
|
||||
message="An unexpected error occurred."
|
||||
).model_dump()
|
||||
]
|
||||
}
|
||||
)
|
||||
@ -0,0 +1,33 @@
|
||||
import logging
|
||||
|
||||
from src.config import LOG_LEVEL
|
||||
from src.enums import OptimumOHEnum
|
||||
|
||||
|
||||
LOG_FORMAT_DEBUG = "%(levelname)s:%(message)s:%(pathname)s:%(funcName)s:%(lineno)d"
|
||||
|
||||
|
||||
class LogLevels(OptimumOHEnum):
|
||||
info = "INFO"
|
||||
warn = "WARN"
|
||||
error = "ERROR"
|
||||
debug = "DEBUG"
|
||||
|
||||
|
||||
def configure_logging():
|
||||
log_level = str(LOG_LEVEL).upper() # cast to string
|
||||
log_levels = list(LogLevels)
|
||||
|
||||
if log_level not in log_levels:
|
||||
# we use error as the default log level
|
||||
logging.basicConfig(level=LogLevels.error)
|
||||
return
|
||||
|
||||
if log_level == LogLevels.debug:
|
||||
logging.basicConfig(level=log_level, format=LOG_FORMAT_DEBUG)
|
||||
return
|
||||
|
||||
logging.basicConfig(level=log_level)
|
||||
|
||||
# sometimes the slack client can be too verbose
|
||||
logging.getLogger("slack_sdk.web.base_client").setLevel(logging.CRITICAL)
|
||||
@ -0,0 +1,112 @@
|
||||
|
||||
import time
|
||||
import logging
|
||||
from os import path
|
||||
from uuid import uuid1
|
||||
from typing import Optional, Final
|
||||
from contextvars import ContextVar
|
||||
|
||||
from fastapi import FastAPI, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import ValidationError
|
||||
|
||||
from slowapi import _rate_limit_exceeded_handler
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy.orm import scoped_session
|
||||
from sqlalchemy.ext.asyncio import async_scoped_session
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
from starlette.requests import Request
|
||||
from starlette.routing import compile_path
|
||||
from starlette.middleware.gzip import GZipMiddleware
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from starlette.responses import Response, StreamingResponse, FileResponse
|
||||
from starlette.staticfiles import StaticFiles
|
||||
import logging
|
||||
|
||||
from src.enums import ResponseStatus
|
||||
from src.logging import configure_logging
|
||||
from src.rate_limiter import limiter
|
||||
from src.api import api_router
|
||||
from src.database.core import engine, async_session
|
||||
from src.exceptions import handle_exception
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# we configure the logging level and format
|
||||
configure_logging()
|
||||
|
||||
# we define the exception handlers
|
||||
exception_handlers = {Exception: handle_exception}
|
||||
|
||||
# we create the ASGI for the app
|
||||
app = FastAPI(exception_handlers=exception_handlers, openapi_url="", title="LCCA API",
|
||||
description="Welcome to LCCA's API documentation!",
|
||||
version="0.1.0")
|
||||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
app.add_middleware(GZipMiddleware, minimum_size=2000)
|
||||
# credentials: "include",
|
||||
|
||||
|
||||
REQUEST_ID_CTX_KEY: Final[str] = "request_id"
|
||||
_request_id_ctx_var: ContextVar[Optional[str]] = ContextVar(
|
||||
REQUEST_ID_CTX_KEY, default=None)
|
||||
|
||||
|
||||
def get_request_id() -> Optional[str]:
|
||||
return _request_id_ctx_var.get()
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def db_session_middleware(request: Request, call_next):
|
||||
request_id = str(uuid1())
|
||||
|
||||
# we create a per-request id such that we can ensure that our session is scoped for a particular request.
|
||||
# see: https://github.com/tiangolo/fastapi/issues/726
|
||||
ctx_token = _request_id_ctx_var.set(request_id)
|
||||
|
||||
try:
|
||||
session = async_scoped_session(async_session, scopefunc=get_request_id)
|
||||
request.state.db = session()
|
||||
response = await call_next(request)
|
||||
except Exception as e:
|
||||
raise e from None
|
||||
finally:
|
||||
await request.state.db.close()
|
||||
|
||||
_request_id_ctx_var.reset(ctx_token)
|
||||
return response
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def add_security_headers(request: Request, call_next):
|
||||
response = await call_next(request)
|
||||
response.headers["Strict-Transport-Security"] = "max-age=31536000 ; includeSubDomains"
|
||||
return response
|
||||
|
||||
# class MetricsMiddleware(BaseHTTPMiddleware):
|
||||
# async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
|
||||
# method = request.method
|
||||
# endpoint = request.url.path
|
||||
# tags = {"method": method, "endpoint": endpoint}
|
||||
|
||||
# try:
|
||||
# start = time.perf_counter()
|
||||
# response = await call_next(request)
|
||||
# elapsed_time = time.perf_counter() - start
|
||||
# tags.update({"status_code": response.status_code})
|
||||
# metric_provider.counter("server.call.counter", tags=tags)
|
||||
# metric_provider.timer("server.call.elapsed", value=elapsed_time, tags=tags)
|
||||
# log.debug(f"server.call.elapsed.{endpoint}: {elapsed_time}")
|
||||
# except Exception as e:
|
||||
# metric_provider.counter("server.call.exception.counter", tags=tags)
|
||||
# raise e from None
|
||||
# return response
|
||||
|
||||
|
||||
# app.add_middleware(ExceptionMiddleware)
|
||||
|
||||
app.include_router(api_router)
|
||||
@ -0,0 +1,46 @@
|
||||
# import logging
|
||||
|
||||
# from dispatch.plugins.base import plugins
|
||||
|
||||
# from .config import METRIC_PROVIDERS
|
||||
|
||||
# log = logging.getLogger(__file__)
|
||||
|
||||
|
||||
# class Metrics(object):
|
||||
# _providers = []
|
||||
|
||||
# def __init__(self):
|
||||
# if not METRIC_PROVIDERS:
|
||||
# log.info(
|
||||
# "No metric providers defined via METRIC_PROVIDERS env var. Metrics will not be sent."
|
||||
# )
|
||||
# else:
|
||||
# self._providers = METRIC_PROVIDERS
|
||||
|
||||
# def gauge(self, name, value, tags=None):
|
||||
# for provider in self._providers:
|
||||
# log.debug(
|
||||
# f"Sending gauge metric {name} to provider {provider}. Value: {value} Tags: {tags}"
|
||||
# )
|
||||
# p = plugins.get(provider)
|
||||
# p.gauge(name, value, tags=tags)
|
||||
|
||||
# def counter(self, name, value=None, tags=None):
|
||||
# for provider in self._providers:
|
||||
# log.debug(
|
||||
# f"Sending counter metric {name} to provider {provider}. Value: {value} Tags: {tags}"
|
||||
# )
|
||||
# p = plugins.get(provider)
|
||||
# p.counter(name, value=value, tags=tags)
|
||||
|
||||
# def timer(self, name, value, tags=None):
|
||||
# for provider in self._providers:
|
||||
# log.debug(
|
||||
# f"Sending timer metric {name} to provider {provider}. Value: {value} Tags: {tags}"
|
||||
# )
|
||||
# p = plugins.get(provider)
|
||||
# p.timer(name, value, tags=tags)
|
||||
|
||||
|
||||
# provider = Metrics()
|
||||
@ -0,0 +1,85 @@
|
||||
# src/common/models.py
|
||||
from datetime import datetime
|
||||
from typing import Generic, Optional, TypeVar
|
||||
import uuid
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from sqlalchemy import Column, DateTime, String, func, event
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from src.config import TIMEZONE
|
||||
import pytz
|
||||
|
||||
from src.auth.service import CurrentUser
|
||||
from src.enums import ResponseStatus
|
||||
# SQLAlchemy Mixins
|
||||
|
||||
|
||||
class TimeStampMixin(object):
|
||||
"""Timestamping mixin"""
|
||||
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), default=datetime.now(pytz.timezone(TIMEZONE)))
|
||||
created_at._creation_order = 9998
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True), default=datetime.now(pytz.timezone(TIMEZONE)))
|
||||
updated_at._creation_order = 9998
|
||||
|
||||
@staticmethod
|
||||
def _updated_at(mapper, connection, target):
|
||||
target.updated_at = datetime.now(pytz.timezone(TIMEZONE))
|
||||
|
||||
@classmethod
|
||||
def __declare_last__(cls):
|
||||
event.listen(cls, "before_update", cls._updated_at)
|
||||
|
||||
|
||||
class UUIDMixin:
|
||||
"""UUID mixin"""
|
||||
id = Column(UUID(as_uuid=True), primary_key=True,
|
||||
default=uuid.uuid4, unique=True, nullable=False)
|
||||
|
||||
|
||||
class IdentityMixin:
|
||||
"""Identity mixin"""
|
||||
created_by = Column(String(100), nullable=True)
|
||||
updated_by = Column(String(100), nullable=True)
|
||||
|
||||
|
||||
class DefaultMixin(TimeStampMixin, UUIDMixin):
|
||||
"""Default mixin"""
|
||||
pass
|
||||
|
||||
|
||||
# Pydantic Models
|
||||
class DefultBase(BaseModel):
|
||||
class Config:
|
||||
from_attributes = True
|
||||
validate_assignment = True
|
||||
arbitrary_types_allowed = True
|
||||
str_strip_whitespace = True
|
||||
|
||||
json_encoders = {
|
||||
# custom output conversion for datetime
|
||||
datetime: lambda v: v.strftime("%Y-%m-%dT%H:%M:%S.%fZ") if v else None,
|
||||
SecretStr: lambda v: v.get_secret_value() if v else None,
|
||||
}
|
||||
|
||||
|
||||
class Pagination(DefultBase):
|
||||
itemsPerPage: int
|
||||
page: int
|
||||
total: int
|
||||
|
||||
|
||||
class PrimaryKeyModel(BaseModel):
|
||||
id: uuid.UUID
|
||||
|
||||
|
||||
# Define data type variable for generic response
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class StandardResponse(BaseModel, Generic[T]):
|
||||
data: Optional[T] = None
|
||||
message: str = "Success"
|
||||
status: ResponseStatus = ResponseStatus.SUCCESS
|
||||
@ -0,0 +1,59 @@
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
|
||||
from src.overhaul.service import get_overhaul_critical_parts, get_overhaul_overview, get_overhaul_schedules, get_overhaul_system_components
|
||||
|
||||
from .schema import OverhaulRead, OverhaulSchedules, OverhaulCriticalParts, OverhaulSystemComponents
|
||||
|
||||
from src.models import StandardResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=StandardResponse[OverhaulRead])
|
||||
async def get_overhaul():
|
||||
"""Get all scope pagination."""
|
||||
overview = get_overhaul_overview()
|
||||
schedules = get_overhaul_schedules()
|
||||
criticalParts = get_overhaul_critical_parts()
|
||||
systemComponents = get_overhaul_system_components()
|
||||
|
||||
return StandardResponse(
|
||||
data=OverhaulRead(
|
||||
overview=overview,
|
||||
schedules=schedules,
|
||||
criticalParts=criticalParts,
|
||||
systemComponents=systemComponents,
|
||||
),
|
||||
message="Data retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/schedules", response_model=StandardResponse[OverhaulSchedules])
|
||||
async def get_schedules():
|
||||
"""Get all overhaul schedules."""
|
||||
schedules = get_overhaul_schedules()
|
||||
return StandardResponse(
|
||||
data=OverhaulSchedules(schedules=schedules),
|
||||
message="Data retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/critical-parts", response_model=StandardResponse[OverhaulCriticalParts])
|
||||
async def get_critical_parts():
|
||||
"""Get all overhaul critical parts."""
|
||||
criticalParts = get_overhaul_critical_parts()
|
||||
return StandardResponse(
|
||||
data=OverhaulCriticalParts(criticalParts=criticalParts),
|
||||
message="Data retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/system-components", response_model=StandardResponse[OverhaulSystemComponents])
|
||||
async def get_system_components():
|
||||
"""Get all overhaul system components."""
|
||||
systemComponents = get_overhaul_system_components()
|
||||
return StandardResponse(
|
||||
data=OverhaulSystemComponents(systemComponents=systemComponents),
|
||||
message="Data retrieved successfully",
|
||||
)
|
||||
@ -0,0 +1,71 @@
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import Field, BaseModel
|
||||
from src.models import DefultBase, Pagination
|
||||
|
||||
|
||||
class OverhaulBase(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class OverhaulCriticalParts(OverhaulBase):
|
||||
criticalParts: List[str] = Field(..., description="List of critical parts")
|
||||
|
||||
|
||||
class OverhaulSchedules(OverhaulBase):
|
||||
schedules: List[Dict[str, Any]
|
||||
] = Field(..., description="List of schedules")
|
||||
|
||||
|
||||
class OverhaulSystemComponents(OverhaulBase):
|
||||
systemComponents: Dict[str,
|
||||
Any] = Field(..., description="List of system components")
|
||||
|
||||
|
||||
class OverhaulRead(OverhaulBase):
|
||||
overview: Dict[str, Any]
|
||||
criticalParts: List[str]
|
||||
schedules: List[Dict[str, Any]]
|
||||
systemComponents: Dict[str, Any]
|
||||
|
||||
|
||||
# {
|
||||
# "overview": {
|
||||
# "totalEquipment": 30,
|
||||
# "nextSchedule": {
|
||||
# "date": "2025-01-12",
|
||||
# "Overhaul": "B",
|
||||
# "equipmentCount": 30
|
||||
# }
|
||||
# },
|
||||
# "criticalParts": [
|
||||
# "Boiler feed pump",
|
||||
# "Boiler reheater system",
|
||||
# "Drum Level (Right) Root Valve A",
|
||||
# "BCP A Discharge Valve",
|
||||
# "BFPT A EXH Press HI Root VLV"
|
||||
# ],
|
||||
# "schedules": [
|
||||
# {
|
||||
# "date": "2025-01-12",
|
||||
# "Overhaul": "B",
|
||||
# "status": "upcoming"
|
||||
# }
|
||||
# // ... other scheduled overhauls
|
||||
# ],
|
||||
# "systemComponents": {
|
||||
# "boiler": {
|
||||
# "status": "operational",
|
||||
# "lastOverhaul": "2024-06-15"
|
||||
# },
|
||||
# "turbine": {
|
||||
# "hpt": { "status": "operational" },
|
||||
# "ipt": { "status": "operational" },
|
||||
# "lpt": { "status": "operational" }
|
||||
# }
|
||||
# // ... other major components
|
||||
# }
|
||||
# }
|
||||
@ -0,0 +1,109 @@
|
||||
|
||||
|
||||
from sqlalchemy import Select, Delete
|
||||
from typing import Optional
|
||||
|
||||
from src.database.core import DbSession
|
||||
from src.auth.service import CurrentUser
|
||||
|
||||
|
||||
def get_overhaul_overview():
|
||||
"""Get all overhaul overview."""
|
||||
return {
|
||||
"totalEquipment": 30,
|
||||
"nextSchedule": {
|
||||
"start_date": "2025-01-12",
|
||||
"end_date": "2025-01-15",
|
||||
"duration": 3,
|
||||
"Overhaul": "B",
|
||||
"equipmentCount": 30
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def get_overhaul_critical_parts():
|
||||
"""Get all overhaul critical parts."""
|
||||
return [
|
||||
"Boiler feed pump",
|
||||
"Boiler reheater system",
|
||||
"Drum Level (Right) Root Valve A",
|
||||
"BCP A Discharge Valve",
|
||||
"BFPT A EXH Press HI Root VLV"
|
||||
]
|
||||
|
||||
|
||||
def get_overhaul_schedules():
|
||||
"""Get all overhaul schedules."""
|
||||
return [
|
||||
{
|
||||
"date": "2025-01-12",
|
||||
"Overhaul": "B",
|
||||
"status": "upcoming"
|
||||
},
|
||||
{
|
||||
"date": "2025-02-15",
|
||||
"Overhaul": "A",
|
||||
"status": "upcoming"
|
||||
},
|
||||
|
||||
]
|
||||
|
||||
|
||||
def get_overhaul_system_components():
|
||||
"""Get all overhaul system components."""
|
||||
return {
|
||||
"boiler": {
|
||||
"efficiency": "90%",
|
||||
"work_hours": "1000",
|
||||
"reliability": "95%",
|
||||
},
|
||||
"HPT": {
|
||||
"efficiency": "90%",
|
||||
"work_hours": "1000",
|
||||
"reliability": "95%",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# async def get(*, db_session: DbSession, scope_id: str) -> Optional[Scope]:
|
||||
# """Returns a document based on the given document id."""
|
||||
# query = Select(Scope).filter(Scope.id == scope_id)
|
||||
# result = await db_session.execute(query)
|
||||
# return result.scalars().one_or_none()
|
||||
|
||||
|
||||
# async def get_all(*, db_session: DbSession):
|
||||
# """Returns all documents."""
|
||||
# query = Select(Scope)
|
||||
# result = await db_session.execute(query)
|
||||
# return result.scalars().all()
|
||||
|
||||
|
||||
# async def create(*, db_session: DbSession, scope_id: ScopeCreate):
|
||||
# """Creates a new document."""
|
||||
# scope = Scope(**scope_id.model_dump())
|
||||
# db_session.add(scope)
|
||||
# await db_session.commit()
|
||||
# return scope
|
||||
|
||||
|
||||
# async def update(*, db_session: DbSession, scope: Scope, scope_id: ScopeUpdate):
|
||||
# """Updates a document."""
|
||||
# data = scope_id.model_dump()
|
||||
|
||||
# update_data = scope_id.model_dump(exclude_defaults=True)
|
||||
|
||||
# for field in data:
|
||||
# if field in update_data:
|
||||
# setattr(scope, field, update_data[field])
|
||||
|
||||
# await db_session.commit()
|
||||
|
||||
# return scope
|
||||
|
||||
|
||||
# async def delete(*, db_session: DbSession, scope_id: str):
|
||||
# """Deletes a document."""
|
||||
# query = Delete(Scope).where(Scope.id == scope_id)
|
||||
# await db_session.execute(query)
|
||||
# await db_session.commit()
|
||||
@ -0,0 +1,5 @@
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
@ -0,0 +1,12 @@
|
||||
|
||||
from sqlalchemy import Column, Float, Integer, String
|
||||
from src.database.core import Base
|
||||
from src.models import DefaultMixin, IdentityMixin, TimeStampMixin
|
||||
|
||||
|
||||
class Scope(Base, DefaultMixin):
|
||||
__tablename__ = "oh_scope"
|
||||
|
||||
scope_name = Column(String, nullable=True)
|
||||
duration_oh = Column(Integer, nullable=True)
|
||||
crew = Column(Integer, nullable=True)
|
||||
@ -0,0 +1,70 @@
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
|
||||
from .model import Scope
|
||||
from .schema import ScopeCreate, ScopeRead, ScopeUpdate, ScopePagination
|
||||
from .service import get, get_all, create, update, delete
|
||||
|
||||
from src.database.service import CommonParameters, search_filter_sort_paginate
|
||||
from src.database.core import DbSession
|
||||
from src.auth.service import CurrentUser
|
||||
from src.models import StandardResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=StandardResponse[ScopePagination])
|
||||
async def get_scopes(common: CommonParameters):
|
||||
"""Get all scope pagination."""
|
||||
# return
|
||||
return StandardResponse(
|
||||
data=await search_filter_sort_paginate(model=Scope, **common),
|
||||
message="Data retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{scope_id}", response_model=StandardResponse[ScopeRead])
|
||||
async def get_scope(db_session: DbSession, scope_id: str):
|
||||
scope = await get(db_session=db_session, scope_id=scope_id)
|
||||
if not scope:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="A data with this id does not exist.",
|
||||
)
|
||||
|
||||
return StandardResponse(data=scope, message="Data retrieved successfully")
|
||||
|
||||
|
||||
@router.post("", response_model=StandardResponse[ScopeRead])
|
||||
async def create_scope(db_session: DbSession, scope_in: ScopeCreate):
|
||||
scope = await create(db_session=db_session, scope_in=scope_in)
|
||||
|
||||
return StandardResponse(data=scope, message="Data created successfully")
|
||||
|
||||
|
||||
@router.put("/{scope_id}", response_model=StandardResponse[ScopeRead])
|
||||
async def update_scope(db_session: DbSession, scope_id: str, scope_in: ScopeUpdate, current_user: CurrentUser):
|
||||
scope = await get(db_session=db_session, scope_id=scope_id)
|
||||
|
||||
if not scope:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="A data with this id does not exist.",
|
||||
)
|
||||
|
||||
return StandardResponse(data=await update(db_session=db_session, scope=scope, scope_in=scope_in), message="Data updated successfully")
|
||||
|
||||
|
||||
@router.delete("/{scope_id}", response_model=StandardResponse[ScopeRead])
|
||||
async def delete_scope(db_session: DbSession, scope_id: str):
|
||||
scope = await get(db_session=db_session, scope_id=scope_id)
|
||||
|
||||
if not scope:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=[{"msg": "A data with this id does not exist."}],
|
||||
)
|
||||
|
||||
await delete(db_session=db_session, scope_id=scope_id)
|
||||
|
||||
return StandardResponse(message="Data deleted successfully", data=scope)
|
||||
@ -0,0 +1,29 @@
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import Field
|
||||
from src.models import DefultBase, Pagination
|
||||
|
||||
|
||||
class ScopeBase(DefultBase):
|
||||
scope_name: Optional[str] = Field(None, title="Scope Name")
|
||||
duration_oh: Optional[int] = Field(None, title="Duration OH")
|
||||
crew: Optional[int] = Field(None, title="Crew")
|
||||
|
||||
|
||||
class ScopeCreate(ScopeBase):
|
||||
pass
|
||||
|
||||
|
||||
class ScopeUpdate(ScopeBase):
|
||||
pass
|
||||
|
||||
|
||||
class ScopeRead(ScopeBase):
|
||||
id: UUID
|
||||
|
||||
|
||||
class ScopePagination(Pagination):
|
||||
items: List[ScopeRead] = []
|
||||
@ -0,0 +1,60 @@
|
||||
|
||||
|
||||
from sqlalchemy import Select, Delete
|
||||
from .model import Scope
|
||||
from .schema import ScopeCreate, ScopeUpdate
|
||||
from typing import Optional
|
||||
|
||||
from src.database.core import DbSession
|
||||
from src.auth.service import CurrentUser
|
||||
|
||||
|
||||
async def get(*, db_session: DbSession, scope_id: str) -> Optional[Scope]:
|
||||
"""Returns a document based on the given document id."""
|
||||
query = Select(Scope).filter(Scope.id == scope_id)
|
||||
result = await db_session.execute(query)
|
||||
return result.scalars().one_or_none()
|
||||
|
||||
|
||||
async def get_all(*, db_session: DbSession):
|
||||
"""Returns all documents."""
|
||||
query = Select(Scope)
|
||||
result = await db_session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
async def create(*, db_session: DbSession, scope_in: ScopeCreate):
|
||||
"""Creates a new document."""
|
||||
scope = Scope(**scope_in.model_dump())
|
||||
db_session.add(scope)
|
||||
await db_session.commit()
|
||||
return scope
|
||||
|
||||
|
||||
async def update(*, db_session: DbSession, scope: Scope, scope_in: ScopeUpdate):
|
||||
"""Updates a document."""
|
||||
data = scope_in.model_dump()
|
||||
|
||||
update_data = scope_in.model_dump(exclude_defaults=True)
|
||||
|
||||
for field in data:
|
||||
if field in update_data:
|
||||
setattr(scope, field, update_data[field])
|
||||
|
||||
await db_session.commit()
|
||||
|
||||
return scope
|
||||
|
||||
|
||||
async def delete(*, db_session: DbSession, scope_id: str):
|
||||
"""Deletes a document."""
|
||||
query = Delete(Scope).where(Scope.id == scope_id)
|
||||
await db_session.execute(query)
|
||||
await db_session.commit()
|
||||
|
||||
|
||||
async def get_by_scope_name(*, db_session: DbSession, scope_name: str) -> Optional[Scope]:
|
||||
"""Returns a document based on the given document id."""
|
||||
query = Select(Scope).filter(Scope.scope_name == scope_name)
|
||||
result = await db_session.execute(query)
|
||||
return result.scalars().one_or_none()
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
from sqlalchemy import UUID, Column, Float, Integer, String, ForeignKey
|
||||
from src.database.core import Base
|
||||
from src.models import DefaultMixin, IdentityMixin, TimeStampMixin
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
|
||||
class ScopeEquipment(Base, DefaultMixin):
|
||||
__tablename__ = "oh_scope_equip"
|
||||
|
||||
assetnum = Column(String, nullable=True)
|
||||
scope_id = Column(UUID(as_uuid=True), ForeignKey('oh_scope.id'), nullable=False)
|
||||
|
||||
scope = relationship("Scope", backref="scope_equipments", lazy="selectin")
|
||||
@ -0,0 +1,81 @@
|
||||
|
||||
from typing import List
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from fastapi.params import Query
|
||||
|
||||
from .model import ScopeEquipment
|
||||
from .schema import ScopeEquipmentCreate, ScopeEquipmentPagination, ScopeEquipmentRead, ScopeEquipmentUpdate
|
||||
from .service import get, get_all, create, update, delete, get_by_scope_name, get_exculed_scope_name
|
||||
|
||||
from src.database.service import CommonParameters, search_filter_sort_paginate
|
||||
from src.database.core import DbSession
|
||||
from src.auth.service import CurrentUser
|
||||
from src.models import StandardResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=StandardResponse[ScopeEquipmentPagination])
|
||||
async def get_scope_equipments(common: CommonParameters):
|
||||
"""Get all scope pagination."""
|
||||
# return
|
||||
return StandardResponse(
|
||||
data=await search_filter_sort_paginate(model=ScopeEquipment, **common),
|
||||
message="Data retrieved successfully",
|
||||
)
|
||||
|
||||
@router.get("/scope/{scope_name}", response_model=StandardResponse[List[ScopeEquipmentRead]])
|
||||
async def get_scope_name(db_session: DbSession, scope_name: str, exclude: bool = Query(False)):
|
||||
if exclude:
|
||||
return StandardResponse(data=await get_exculed_scope_name(db_session=db_session, scope_name=scope_name), message="Data retrieved successfully")
|
||||
|
||||
return StandardResponse(data=await get_by_scope_name(db_session=db_session, scope_name=scope_name), message="Data retrieved successfully")
|
||||
|
||||
|
||||
@router.get("/{scope_equipment_id}", response_model=StandardResponse[ScopeEquipmentRead])
|
||||
async def get_scope_equipment(db_session: DbSession, scope_equipment_id: str):
|
||||
scope = await get(db_session=db_session, scope_equipment_id=scope_equipment_id)
|
||||
if not scope:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="A data with this id does not exist.",
|
||||
)
|
||||
|
||||
return StandardResponse(data=scope, message="Data retrieved successfully")
|
||||
|
||||
|
||||
@router.post("", response_model=StandardResponse[ScopeEquipmentRead])
|
||||
async def create_scope_equipment(db_session: DbSession, scope__equipment_in: ScopeEquipmentCreate):
|
||||
scope = await create(db_session=db_session, scope__equipment_in=scope__equipment_in)
|
||||
|
||||
return StandardResponse(data=scope, message="Data created successfully")
|
||||
|
||||
|
||||
@router.put("/{scope_equipment_id}", response_model=StandardResponse[ScopeEquipmentRead])
|
||||
async def update_scope_equipment(db_session: DbSession, scope_equipment_id: str, scope__equipment_in: ScopeEquipmentUpdate):
|
||||
scope_equipment = await get(db_session=db_session, scope_equipment_id=scope_equipment_id)
|
||||
|
||||
if not scope_equipment:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="A data with this id does not exist.",
|
||||
)
|
||||
|
||||
return StandardResponse(data=await update(db_session=db_session, scope_equipment=scope_equipment, scope__equipment_in=scope__equipment_in), message="Data updated successfully")
|
||||
|
||||
|
||||
@router.delete("/{scope_equipment_id}", response_model=StandardResponse[ScopeEquipmentRead])
|
||||
async def delete_scope_equipment(db_session: DbSession, scope_equipment_id: str):
|
||||
scope_equipment = await get(db_session=db_session, scope_equipment_id=scope_equipment_id)
|
||||
|
||||
if not scope_equipment:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=[{"msg": "A data with this id does not exist."}],
|
||||
)
|
||||
|
||||
await delete(db_session=db_session, scope_equipment_id=scope_equipment_id)
|
||||
|
||||
return StandardResponse(message="Data deleted successfully", data=scope_equipment)
|
||||
|
||||
|
||||
@ -0,0 +1,28 @@
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import Field
|
||||
from src.models import DefultBase, Pagination
|
||||
|
||||
|
||||
class ScopeEquipmentBase(DefultBase):
|
||||
scope_id: Optional[UUID] = Field(None, title="Scope ID")
|
||||
|
||||
|
||||
class ScopeEquipmentCreate(ScopeEquipmentBase):
|
||||
assetnum: str
|
||||
|
||||
|
||||
class ScopeEquipmentUpdate(ScopeEquipmentBase):
|
||||
assetnum: Optional[str] = Field(None, title="Asset Number")
|
||||
|
||||
|
||||
class ScopeEquipmentRead(ScopeEquipmentBase):
|
||||
id: UUID
|
||||
assetnum: str
|
||||
|
||||
|
||||
class ScopeEquipmentPagination(Pagination):
|
||||
items: List[ScopeEquipmentRead] = []
|
||||
@ -0,0 +1,84 @@
|
||||
|
||||
|
||||
from sqlalchemy import Select, Delete
|
||||
from .model import ScopeEquipment
|
||||
from src.scope.service import get_by_scope_name as get_scope_by_name_service
|
||||
from .schema import ScopeEquipmentCreate, ScopeEquipmentUpdate
|
||||
from typing import Optional, Union
|
||||
|
||||
from src.database.core import DbSession
|
||||
from src.auth.service import CurrentUser
|
||||
|
||||
|
||||
async def get(*, db_session: DbSession, scope_equipment_id: str) -> Optional[ScopeEquipment]:
|
||||
"""Returns a document based on the given document id."""
|
||||
query = Select(ScopeEquipment).filter(
|
||||
ScopeEquipment.id == scope_equipment_id)
|
||||
result = await db_session.execute(query)
|
||||
return result.scalars().one_or_none()
|
||||
|
||||
|
||||
async def get_all(*, db_session: DbSession):
|
||||
"""Returns all documents."""
|
||||
query = Select(ScopeEquipment)
|
||||
result = await db_session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
async def create(*, db_session: DbSession, scope_equipment_in: ScopeEquipmentCreate):
|
||||
"""Creates a new document."""
|
||||
scope_equipment = ScopeEquipment(**scope_equipment_in.model_dump())
|
||||
db_session.add(scope_equipment)
|
||||
await db_session.commit()
|
||||
return scope_equipment
|
||||
|
||||
|
||||
async def update(*, db_session: DbSession, scope_equipment: ScopeEquipment, scope_equipment_in: ScopeEquipmentUpdate):
|
||||
"""Updates a document."""
|
||||
data = scope_equipment_in.model_dump()
|
||||
|
||||
update_data = scope_equipment_in.model_dump(exclude_defaults=True)
|
||||
|
||||
for field in data:
|
||||
if field in update_data:
|
||||
setattr(scope_equipment, field, update_data[field])
|
||||
|
||||
await db_session.commit()
|
||||
|
||||
return scope_equipment
|
||||
|
||||
|
||||
async def delete(*, db_session: DbSession, scope_equipment_id: str):
|
||||
"""Deletes a document."""
|
||||
query = Delete(ScopeEquipment).where(
|
||||
ScopeEquipment.id == scope_equipment_id)
|
||||
await db_session.execute(query)
|
||||
await db_session.commit()
|
||||
|
||||
|
||||
async def get_by_scope_name(*, db_session: DbSession, scope_name: Union[str, list]) -> Optional[ScopeEquipment]:
|
||||
"""Returns a document based on the given document id."""
|
||||
scope = await get_scope_by_name_service(db_session=db_session, scope_name=scope_name)
|
||||
|
||||
query = Select(ScopeEquipment)
|
||||
|
||||
if scope:
|
||||
query = query.filter(ScopeEquipment.scope_id == scope.id)
|
||||
|
||||
result = await db_session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
async def get_exculed_scope_name(*, db_session: DbSession, scope_name: Union[str, list]) -> Optional[ScopeEquipment]:
|
||||
scope = await get_scope_by_name_service(db_session=db_session, scope_name=scope_name)
|
||||
|
||||
query = Select(ScopeEquipment)
|
||||
|
||||
if scope:
|
||||
query = query.filter(ScopeEquipment.scope_id != scope.id)
|
||||
|
||||
else:
|
||||
query = query.filter(ScopeEquipment.scope_id != None)
|
||||
|
||||
result = await db_session.execute(query)
|
||||
return result.scalars().all()
|
||||
@ -0,0 +1,40 @@
|
||||
from connectors.database import DBConfig
|
||||
from starlette.config import Config
|
||||
import os
|
||||
|
||||
|
||||
def get_config():
|
||||
try:
|
||||
# Try to load from .env file first
|
||||
config = Config(".env")
|
||||
except FileNotFoundError:
|
||||
# If .env doesn't exist, use environment variables
|
||||
config = Config(environ=os.environ)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
env = get_config()
|
||||
|
||||
config = {
|
||||
'batch_size': 1000,
|
||||
'target_table': 'oh_wo_master',
|
||||
'columns': ['assetnum', 'worktype', 'workgroup', 'total_cost_max', 'created_at'],
|
||||
}
|
||||
|
||||
target_config = DBConfig(
|
||||
host=env("DATABASE_HOSTNAME"),
|
||||
port=env("DATABASE_PORT"),
|
||||
database=env("DATABASE_NAME"),
|
||||
user=env("DATABASE_CREDENTIAL_USER"),
|
||||
password=env("DATABASE_CREDENTIAL_PASSWORD")
|
||||
)
|
||||
|
||||
|
||||
source_config = DBConfig(
|
||||
host=env("COLLECTOR_HOSTNAME"),
|
||||
port=env("COLLECTOR_PORT"),
|
||||
database=env("COLLECTOR_NAME"),
|
||||
user=env("COLLECTOR_CREDENTIAL_USER"),
|
||||
password=env("COLLECTOR_CREDENTIAL_PASSWORD")
|
||||
)
|
||||
@ -0,0 +1,104 @@
|
||||
# src/connectors/database.py
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Dict, Any, Generator
|
||||
import pandas as pd
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.engine import Engine
|
||||
import logging
|
||||
import os
|
||||
from starlette.config import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class DBConfig:
|
||||
host: str
|
||||
port: int
|
||||
database: str
|
||||
user: str
|
||||
password: str
|
||||
|
||||
def get_connection_string(self) -> str:
|
||||
return f"postgresql+psycopg2://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
|
||||
|
||||
|
||||
class DatabaseConnector:
|
||||
def __init__(self, source_config: DBConfig, target_config: DBConfig):
|
||||
self.source_engine = create_engine(
|
||||
source_config.get_connection_string())
|
||||
self.target_engine = create_engine(
|
||||
target_config.get_connection_string())
|
||||
|
||||
def fetch_batch(
|
||||
self,
|
||||
batch_size: int,
|
||||
last_id: int = 0,
|
||||
columns: List[str] = None
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Fetch a batch of data from source database
|
||||
"""
|
||||
try:
|
||||
query = """
|
||||
SELECT {}
|
||||
FROM dl_wo_staging
|
||||
WHERE id > :last_id
|
||||
AND worktype IN ('PM', 'CM', 'EM', 'PROACTIVE')
|
||||
ORDER BY id
|
||||
LIMIT :batch_size
|
||||
""".format(', '.join(columns) if columns else '*')
|
||||
|
||||
# Execute query
|
||||
params = {
|
||||
"last_id": last_id,
|
||||
"batch_size": batch_size
|
||||
}
|
||||
|
||||
df = pd.read_sql(
|
||||
text(query),
|
||||
self.source_engine,
|
||||
params=params
|
||||
)
|
||||
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching batch: {str(e)}")
|
||||
raise
|
||||
|
||||
def load_batch(
|
||||
self,
|
||||
df: pd.DataFrame,
|
||||
target_table: str,
|
||||
if_exists: str = 'append'
|
||||
) -> bool:
|
||||
"""
|
||||
Load a batch of data to target database
|
||||
"""
|
||||
try:
|
||||
df.to_sql(
|
||||
target_table,
|
||||
self.target_engine,
|
||||
if_exists=if_exists,
|
||||
index=False,
|
||||
method='multi',
|
||||
chunksize=1000
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading batch: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_total_records(self) -> int:
|
||||
"""Get total number of records to migrate"""
|
||||
try:
|
||||
with self.source_engine.connect() as conn:
|
||||
result = conn.execute(text("SELECT COUNT(*) FROM sensor_data"))
|
||||
return result.scalar()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting total records: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@ -0,0 +1,39 @@
|
||||
# src/scripts/run_migration.py
|
||||
import asyncio
|
||||
from temporalio.client import Client
|
||||
from temporalio.worker import Worker
|
||||
from workflows.historical_migration import DataMigrationWorkflow, fetch_data_activity, transform_data_activity, validate_data_activity, load_data_activity
|
||||
|
||||
|
||||
async def main():
|
||||
# Create Temporal client
|
||||
client = await Client.connect("192.168.1.82:7233")
|
||||
|
||||
# Create worker
|
||||
worker = Worker(
|
||||
client,
|
||||
task_queue="migration-queue",
|
||||
workflows=[DataMigrationWorkflow],
|
||||
activities=[
|
||||
fetch_data_activity,
|
||||
transform_data_activity,
|
||||
validate_data_activity,
|
||||
load_data_activity
|
||||
]
|
||||
)
|
||||
|
||||
# Start worker
|
||||
await worker.run()
|
||||
|
||||
# # Start workflow
|
||||
# handle = await client.start_workflow(
|
||||
# DataMigrationWorkflow.run,
|
||||
# id="data-migration",
|
||||
# task_queue="migration-queue"
|
||||
# )
|
||||
|
||||
# result = await handle.result()
|
||||
# print(f"Migration completed: {result}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@ -0,0 +1,20 @@
|
||||
from temporalio.client import Client
|
||||
from workflows.historical_migration import DataMigrationWorkflow
|
||||
|
||||
|
||||
async def run():
|
||||
# Start workflow
|
||||
|
||||
client = await Client.connect("192.168.1.82:7233")
|
||||
|
||||
handle = await client.start_workflow(
|
||||
DataMigrationWorkflow.run,
|
||||
id="data-migration",
|
||||
task_queue="migration-queue"
|
||||
)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(run())
|
||||
@ -0,0 +1,48 @@
|
||||
|
||||
|
||||
# src/transformers/sensor_data.py
|
||||
from typing import Dict, Any
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
from uuid import UUID, uuid4
|
||||
from config import config
|
||||
|
||||
|
||||
class WoDataTransformer:
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Transform sensor data according to business rules
|
||||
"""
|
||||
# Create a copy to avoid modifying original data
|
||||
transformed = df.copy()
|
||||
|
||||
# 1. Add UUID
|
||||
transformed['id'] = uuid4()
|
||||
|
||||
# # 5. Drop unnecessary columns
|
||||
# columns_to_drop = self.config.get('columns_to_drop', [])
|
||||
# if columns_to_drop:
|
||||
# transformed = transformed.drop(columns=columns_to_drop, errors='ignore')
|
||||
|
||||
return transformed
|
||||
|
||||
def validate(self, df: pd.DataFrame) -> bool:
|
||||
"""
|
||||
Validate transformed data
|
||||
"""
|
||||
if df.empty:
|
||||
return False
|
||||
|
||||
# Check required columns
|
||||
if not all(col in df.columns for col in config.get('columns')):
|
||||
return False
|
||||
|
||||
# check id column and id is UUID
|
||||
if 'id' not in df.columns:
|
||||
return False
|
||||
|
||||
if not all(isinstance(val, UUID) for val in df['id']):
|
||||
return False
|
||||
|
||||
return True
|
||||
@ -0,0 +1,123 @@
|
||||
from datetime import timedelta
|
||||
import pandas as pd
|
||||
from connectors.database import DatabaseConnector
|
||||
from temporalio import workflow, activity
|
||||
from temporalio.common import RetryPolicy
|
||||
from typing import Dict, List
|
||||
from config import source_config, target_config, config
|
||||
from transformation.wo_transform import WoDataTransformer
|
||||
# Activities
|
||||
|
||||
|
||||
@activity.defn
|
||||
async def fetch_data_activity(batch_size: int, last_id: int, columns) -> Dict:
|
||||
db_connector = DatabaseConnector(source_config, target_config)
|
||||
df = db_connector.fetch_batch(batch_size, last_id, columns)
|
||||
return df.to_dict(orient='records')
|
||||
|
||||
|
||||
@activity.defn
|
||||
async def transform_data_activity(data: List[Dict]) -> List[Dict]:
|
||||
transformer = WoDataTransformer()
|
||||
df = pd.DataFrame(data)
|
||||
transformed_df = transformer.transform(df)
|
||||
return transformed_df.to_dict(orient='records')
|
||||
|
||||
|
||||
@activity.defn
|
||||
async def validate_data_activity(data: List[Dict]) -> bool:
|
||||
transformer = WoDataTransformer()
|
||||
df = pd.DataFrame(data)
|
||||
return transformer.validate(df)
|
||||
|
||||
|
||||
@activity.defn
|
||||
async def load_data_activity(data: List[Dict], target_table: str) -> bool:
|
||||
db_connector = DatabaseConnector(source_config, target_config)
|
||||
df = pd.DataFrame(data)
|
||||
return db_connector.load_batch(df, target_table)
|
||||
|
||||
# Workflow
|
||||
|
||||
|
||||
@workflow.defn
|
||||
class DataMigrationWorkflow:
|
||||
def __init__(self):
|
||||
self._total_processed = 0
|
||||
self._last_id = 0
|
||||
|
||||
@workflow.run
|
||||
async def run(self) -> Dict:
|
||||
retry_policy = RetryPolicy(
|
||||
initial_interval=timedelta(seconds=1),
|
||||
maximum_interval=timedelta(minutes=10),
|
||||
maximum_attempts=3
|
||||
)
|
||||
|
||||
batch_size = config.get('batch_size', 1000)
|
||||
target_table = config.get('target_table')
|
||||
columns = config.get('columns')
|
||||
|
||||
while True:
|
||||
# 1. Fetch batch
|
||||
data = await workflow.execute_activity(
|
||||
fetch_data_activity,
|
||||
args=[batch_size, self._last_id, columns],
|
||||
retry_policy=retry_policy,
|
||||
start_to_close_timeout=timedelta(minutes=5)
|
||||
)
|
||||
|
||||
if not data:
|
||||
break
|
||||
|
||||
# 2. Transform data
|
||||
transformed_data = await workflow.execute_activity(
|
||||
transform_data_activity,
|
||||
args=[data, config],
|
||||
retry_policy=retry_policy,
|
||||
start_to_close_timeout=timedelta(minutes=10)
|
||||
)
|
||||
|
||||
# 3. Validate data
|
||||
is_valid = await workflow.execute_activity(
|
||||
validate_data_activity,
|
||||
args=[transformed_data],
|
||||
retry_policy=retry_policy,
|
||||
start_to_close_timeout=timedelta(minutes=5)
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
raise ValueError(
|
||||
f"Data validation failed for batch after ID {self._last_id}")
|
||||
|
||||
# 4. Load data
|
||||
success = await workflow.execute_activity(
|
||||
load_data_activity,
|
||||
args=[transformed_data, target_table],
|
||||
retry_policy=retry_policy,
|
||||
start_to_close_timeout=timedelta(minutes=10)
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise Exception(
|
||||
f"Failed to load batch after ID {self._last_id}")
|
||||
|
||||
# Update progress
|
||||
self._total_processed += len(data)
|
||||
self._last_id = data[-1]['id']
|
||||
|
||||
# Record progress
|
||||
# await workflow.execute_activity(
|
||||
# record_progress_activity,
|
||||
# args=[{
|
||||
# 'last_id': self._last_id,
|
||||
# 'total_processed': self._total_processed
|
||||
# }],
|
||||
# retry_policy=retry_policy,
|
||||
# start_to_close_timeout=timedelta(minutes=1)
|
||||
# )
|
||||
|
||||
return {
|
||||
'total_processed': self._total_processed,
|
||||
'last_id': self._last_id
|
||||
}
|
||||
@ -0,0 +1,69 @@
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Generator
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
import pytest
|
||||
from sqlalchemy_utils import drop_database, database_exists
|
||||
from starlette.config import environ
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
# from src.database import Base, get_db
|
||||
# from src.main import app
|
||||
|
||||
# Test database URL
|
||||
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
|
||||
|
||||
engine = create_async_engine(
|
||||
TEST_DATABASE_URL,
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
|
||||
async_session = sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
|
||||
async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with async_session() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop() -> Generator:
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def setup_db() -> AsyncGenerator[None, None]:
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client() -> AsyncGenerator[AsyncClient, None]:
|
||||
async with AsyncClient(app=app, base_url="http://test") as client:
|
||||
yield client
|
||||
@ -0,0 +1,3 @@
|
||||
from sqlalchemy.orm import scoped_session, sessionmaker
|
||||
|
||||
Session = scoped_session(sessionmaker())
|
||||
@ -0,0 +1,33 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from factory import (
|
||||
LazyAttribute,
|
||||
LazyFunction,
|
||||
Sequence,
|
||||
SubFactory,
|
||||
post_generation,
|
||||
SelfAttribute,
|
||||
)
|
||||
from factory.alchemy import SQLAlchemyModelFactory
|
||||
from factory.fuzzy import FuzzyChoice, FuzzyDateTime, FuzzyInteger, FuzzyText
|
||||
from faker import Faker
|
||||
from faker.providers import misc
|
||||
# from pytz import UTC
|
||||
|
||||
|
||||
from .database import Session
|
||||
|
||||
fake = Faker()
|
||||
fake.add_provider(misc)
|
||||
|
||||
|
||||
class BaseFactory(SQLAlchemyModelFactory):
|
||||
"""Base Factory."""
|
||||
|
||||
class Meta:
|
||||
"""Factory configuration."""
|
||||
|
||||
abstract = True
|
||||
sqlalchemy_session = Session
|
||||
sqlalchemy_session_persistence = "commit"
|
||||
Loading…
Reference in New Issue