You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

195 lines
6.1 KiB
Python

import asyncio
from typing import List, Optional
from uuid import UUID
from sqlalchemy import Delete, Select, func, select
from sqlalchemy import update as sqlUpdate
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import joinedload, selectinload
from src.auth.service import CurrentUser
from src.database.core import DbSession
from src.database.service import CommonParameters, search_filter_sort_paginate
from src.overhaul_activity.utils import get_material_cost, get_service_cost
from src.overhaul_scope.model import OverhaulScope
from src.overhaul_scope.service import get as get_session
from .model import OverhaulActivity
from .schema import (OverhaulActivityCreate, OverhaulActivityRead,
OverhaulActivityUpdate)
async def get(
    *, db_session: DbSession, assetnum: str, overhaul_session_id: Optional[UUID] = None
) -> Optional[OverhaulActivityRead]:
    """Fetch one overhaul activity for the given asset number.

    The activity's ``equipment`` relationship is eager-loaded with a JOIN.
    When ``overhaul_session_id`` is truthy, the lookup is also restricted to
    activities belonging to that overhaul scope.

    Args:
        db_session: Session used to run the query.
        assetnum: Asset number to match on ``OverhaulActivity.assetnum``.
        overhaul_session_id: Optional overhaul scope id to narrow the match.

    Returns:
        The first matching ``OverhaulActivity`` row, or None if nothing matches.
        The query has no ORDER BY, so which row is "first" is unspecified when
        several match.
    """
    conditions = [OverhaulActivity.assetnum == assetnum]
    if overhaul_session_id:
        conditions.append(OverhaulActivity.overhaul_scope_id == overhaul_session_id)

    stmt = (
        Select(OverhaulActivity)
        .options(joinedload(OverhaulActivity.equipment))
        .where(*conditions)
    )
    outcome = await db_session.execute(stmt)
    return outcome.scalar()
async def get_all(
    *,
    common: CommonParameters,
    overhaul_session_id: UUID,
    assetnum: Optional[str] = None,
    scope_name: Optional[str] = None
):
    """Return a paginated listing of the activities in one overhaul scope.

    Args:
        common: Shared query parameters, unpacked into
            ``search_filter_sort_paginate`` (presumably the db session plus
            paging/sort/filter options — confirm against CommonParameters).
        overhaul_session_id: Overhaul scope whose activities are listed.
        assetnum: If truthy, only activities with this asset number.
        scope_name: If truthy, only activities with this scope name.

    Returns:
        Whatever ``search_filter_sort_paginate`` returns for the built query.
    """
    stmt = (
        Select(OverhaulActivity)
        .where(OverhaulActivity.overhaul_scope_id == overhaul_session_id)
        .options(joinedload(OverhaulActivity.equipment))
    )

    if assetnum:
        stmt = stmt.filter(OverhaulActivity.assetnum == assetnum)
    if scope_name:
        stmt = stmt.filter(OverhaulActivity.scope_name == scope_name)

    # Any active optional filter also eager-loads the parent overhaul scope.
    if assetnum or scope_name:
        stmt = stmt.options(joinedload(OverhaulActivity.overhaul_scope))

    return await search_filter_sort_paginate(model=stmt, **common)
async def get_all_by_session_id(*, db_session: DbSession, overhaul_session_id):
    """Return every overhaul activity that belongs to one overhaul scope.

    ``equipment`` is eager-loaded with a JOIN and ``overhaul_scope`` with a
    separate SELECT ... IN query, so both are available on the returned rows.

    Args:
        db_session: Session used to run the query.
        overhaul_session_id: Id of the overhaul scope to list.

    Returns:
        A list of ``OverhaulActivity`` instances (possibly empty).
    """
    stmt = (
        Select(OverhaulActivity)
        .where(OverhaulActivity.overhaul_scope_id == overhaul_session_id)
        .options(
            joinedload(OverhaulActivity.equipment),
            selectinload(OverhaulActivity.overhaul_scope),
        )
    )
    rows = await db_session.execute(stmt)
    return rows.scalars().all()
# async def create(*, db_session: DbSession, overhaul_activty_in: OverhaulActivityCreate, overhaul_session_id: UUID):
# # Check if the combination of assetnum and activity_id already exists
# existing_equipment_query = (
# Select(OverhaulActivity)
# .where(
# OverhaulActivity.assetnum == overhaul_activty_in.assetnum,
# OverhaulActivity.overhaul_scope_id == overhaul_session_id
# )
# )
# result = await db_session.execute(existing_equipment_query)
# existing_activity = result.scalar_one_or_none()
# # If the combination exists, raise an exception or return the existing activity
# if existing_activity:
# raise ValueError("This assetnum already exist.")
# activity = OverhaulActivity(
# **overhaul_activty_in.model_dump(),
# overhaul_scope_id=overhaul_session_id)
# db_session.add(activity)
# await db_session.commit()
# # Refresh and load relationships using joinedload
# query = (
# Select(OverhaulActivity)
# .options(joinedload(OverhaulActivity.equipment))
# .where(OverhaulActivity.id == activity.id)
# )
# result = await db_session.execute(query)
# activity_with_relationship = result.scalar_one()
# return activity_with_relationship
async def create(
    *,
    db_session: DbSession,
    overhaul_activty_in: OverhaulActivityCreate,
    overhaul_session_id: UUID
):
    """Add equipment (by asset number) to an overhaul scope and re-price it.

    Inserts one ``OverhaulActivity`` row per requested asset number that is not
    already part of the scope. It then recomputes material and service cost
    from the scope's ``type`` and the resulting equipment count, and writes
    those costs to every activity in the scope. All of this is committed in
    one transaction.

    Args:
        db_session: Session used for every query; committed at the end.
        overhaul_activty_in: Payload whose ``assetnums`` lists the equipment
            to add.
        overhaul_session_id: Id of the overhaul scope the activities belong to.

    Returns:
        The ``assetnums`` list from the payload, unchanged (``[]`` when empty).
    """
    assetnums = overhaul_activty_in.assetnums
    if not assetnums:
        return []

    # NOTE(review): assumes get_session returns the scope with a ``type``
    # attribute; a missing scope would surface below as an AttributeError.
    session = await get_session(
        db_session=db_session, overhaul_session_id=overhaul_session_id
    )

    # Duplicates inside one request must not be counted (or inserted) twice;
    # dict.fromkeys keeps first-seen order.
    unique_assetnums = list(dict.fromkeys(assetnums))

    # Assetnums already in this scope are skipped by ON CONFLICT DO NOTHING,
    # so they must not be counted as new equipment either.
    existing_result = await db_session.execute(
        select(OverhaulActivity.assetnum).where(
            OverhaulActivity.overhaul_scope_id == overhaul_session_id,
            OverhaulActivity.assetnum.in_(unique_assetnums),
        )
    )
    already_present = set(existing_result.scalars().all())
    new_assetnums = [a for a in unique_assetnums if a not in already_present]

    equipment_count = await db_session.scalar(
        select(func.count())
        .select_from(OverhaulActivity)
        .where(OverhaulActivity.overhaul_scope_id == overhaul_session_id)
    )

    # Costs depend on the scope's total equipment count after this insert.
    total_equipment = equipment_count + len(new_assetnums)
    material_cost = get_material_cost(
        scope=session.type, total_equipment=total_equipment
    )
    service_cost = get_service_cost(scope=session.type, total_equipment=total_equipment)

    if new_assetnums:
        stmt = insert(OverhaulActivity).values(
            [
                {
                    "assetnum": assetnum,
                    "overhaul_scope_id": overhaul_session_id,
                    "material_cost": material_cost,
                    "service_cost": service_cost,
                }
                for assetnum in new_assetnums
            ]
        )
        # Still guard against rows inserted concurrently since the check above.
        stmt = stmt.on_conflict_do_nothing(
            index_elements=["assetnum", "overhaul_scope_id"]
        )
        await db_session.execute(stmt)

    # The per-activity cost depends on the scope size, so re-price every row.
    await db_session.execute(
        sqlUpdate(OverhaulActivity)
        .where(OverhaulActivity.overhaul_scope_id == overhaul_session_id)
        .values(material_cost=material_cost, service_cost=service_cost)
    )
    await db_session.commit()
    return assetnums
async def update(
    *,
    db_session: DbSession,
    activity: OverhaulActivity,
    overhaul_activity_in: OverhaulActivityUpdate
):
    """Apply a partial update to an overhaul activity and commit it.

    Only the fields the caller explicitly set on ``overhaul_activity_in`` are
    copied onto ``activity``. Omitted fields keep their current values.

    Args:
        db_session: Session that owns ``activity``; committed after the update.
        activity: The ORM object to modify in place.
        overhaul_activity_in: Update payload (a Pydantic model).

    Returns:
        The same ``activity`` object, after commit.
    """
    # exclude_unset (not exclude_defaults): a field the client explicitly sets
    # to its default value, e.g. clearing it back to None, must still be applied.
    update_data = overhaul_activity_in.model_dump(exclude_unset=True)
    for field, value in update_data.items():
        setattr(activity, field, value)
    await db_session.commit()
    return activity
async def delete(*, db_session: DbSession, overhaul_activity_id: str):
    """Delete an overhaul activity by primary key and commit.

    If no activity with ``overhaul_activity_id`` exists, nothing is deleted and
    no commit is issued.

    Args:
        db_session: Session used to load, delete and commit.
        overhaul_activity_id: Primary key of the activity to delete.
    """
    activity = await db_session.get(OverhaulActivity, overhaul_activity_id)
    if activity is None:
        # session.delete(None) raises UnmappedInstanceError; treat a missing
        # row as already deleted instead of failing with an opaque ORM error.
        return
    await db_session.delete(activity)
    await db_session.commit()