from datetime import datetime, timedelta
from typing import Optional, Union

from fastapi import HTTPException, status
from sqlalchemy import Select, Delete, and_, desc, func, not_, or_
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import selectinload

from src.overhaul_scope.model import OverhaulScope
from src.scope_equipment.enum import ScopeEquipmentType
from src.workorder.model import MasterWorkOrder
from .model import MasterEquipmentTree, ScopeEquipment, MasterEquipment
from .schema import ScopeEquipmentCreate, ScopeEquipmentUpdate
from src.database.service import CommonParameters, search_filter_sort_paginate
from src.database.core import DbSession
from src.auth.service import CurrentUser


async def get_by_assetnum(*, db_session: DbSession, assetnum: str):
    """Fetch the ScopeEquipment row for ``assetnum`` with ``master_equipment`` eagerly loaded.

    Returns:
        The matching row, or ``None`` when no row has this asset number.

    NOTE(review): ``create`` uses (assetnum, scope_overhaul) as its conflict key, so the
    same assetnum may exist under several scopes; ``one_or_none`` would then raise
    ``MultipleResultsFound``. Confirm whether assetnum is unique on its own.
    """
    query = (
        Select(ScopeEquipment)
        .filter(ScopeEquipment.assetnum == assetnum)
        .options(selectinload(ScopeEquipment.master_equipment))
    )
    result = await db_session.execute(query)
    return result.unique().scalars().one_or_none()


async def get_all(*, common, scope_name: Optional[str] = None):
    """Return a paginated list of scope equipment, newest (``created_at``) first.

    Args:
        common: Shared query parameters, forwarded as keyword arguments to
            ``search_filter_sort_paginate``.
        scope_name: When truthy, only rows whose ``scope_overhaul`` equals it are returned.
    """
    query = Select(ScopeEquipment).options(selectinload(ScopeEquipment.master_equipment))
    query = query.order_by(desc(ScopeEquipment.created_at))
    if scope_name:
        query = query.where(ScopeEquipment.scope_overhaul == scope_name)
    return await search_filter_sort_paginate(model=query, **common)


async def create(*, db_session: DbSession, scope_equipment_in: ScopeEquipmentCreate):
    """Add the requested asset numbers to an overhaul scope and commit.

    For ``ScopeEquipmentType.TEMP`` entries, the removal date is the end date of the
    earliest-starting overhaul of ``scope_name`` that is either ongoing or upcoming. If
    there is none, it falls back to 30 days from now. Other types keep
    ``scope_equipment_in.removal_date`` as given.

    Rows whose (assetnum, scope_overhaul) pair already exists are left untouched
    (``ON CONFLICT DO NOTHING``).

    Returns:
        The requested asset numbers in request order, including any that were skipped
        because they already existed.
    """
    assetnums = list(scope_equipment_in.assetnums)
    removal_date = scope_equipment_in.removal_date

    if scope_equipment_in.type == ScopeEquipmentType.TEMP:
        # One reference instant, so the "ongoing" and "upcoming" tests and the
        # fallback date are all computed against the same moment.
        now = datetime.now()
        stmt = (
            Select(OverhaulScope.end_date)
            .where(
                OverhaulScope.type == scope_equipment_in.scope_name,
                ((OverhaulScope.start_date <= now) & (OverhaulScope.end_date >= now))  # ongoing
                | (OverhaulScope.start_date > now),  # upcoming
            )
            .order_by(OverhaulScope.start_date.asc())
            .limit(1)
        )
        result = await db_session.execute(stmt)
        removal_date = result.scalar_one_or_none()
        if removal_date is None:
            # No ongoing or upcoming overhaul for this scope: default to 30 days out.
            removal_date = now + timedelta(days=30)

    if assetnums:
        # A single multi-row INSERT instead of one round trip per asset; an empty
        # VALUES list is not valid SQL, hence the guard.
        stmt = (
            insert(ScopeEquipment)
            .values(
                [
                    {
                        "assetnum": assetnum,
                        "scope_overhaul": scope_equipment_in.scope_name,
                        "type": scope_equipment_in.type,
                        "removal_date": removal_date,
                    }
                    for assetnum in assetnums
                ]
            )
            .on_conflict_do_nothing(index_elements=["assetnum", "scope_overhaul"])
        )
        await db_session.execute(stmt)

    await db_session.commit()
    return assetnums


async def update(
    *, db_session: DbSession, scope_equipment: ScopeEquipment, scope_equipment_in: ScopeEquipmentUpdate
):
    """Apply a partial update to ``scope_equipment`` and commit.

    Only fields the client actually supplied are written (``exclude_unset``). An
    explicitly sent default or ``null`` is therefore applied rather than silently dropped.

    Returns:
        The same (now updated) ``scope_equipment`` instance.
    """
    update_data = scope_equipment_in.model_dump(exclude_unset=True)
    for field, value in update_data.items():
        setattr(scope_equipment, field, value)
    await db_session.commit()
    return scope_equipment


async def delete(*, db_session: DbSession, assetnum: str):
    """Delete every ScopeEquipment row with ``assetnum`` and commit.

    The filter is on assetnum only, so the asset is removed from every scope it
    belongs to.

    Returns:
        ``assetnum``, whether or not any row matched.
    """
    query = Delete(ScopeEquipment).where(ScopeEquipment.assetnum == assetnum)
    await db_session.execute(query)
    await db_session.commit()
    return assetnum
async def get_by_scope_name(
    *, db_session: DbSession, scope_name: Optional[str]
) -> list[ScopeEquipment]:
    """Return all ScopeEquipment rows for ``scope_name`` with ``master_equipment`` eagerly loaded.

    When ``scope_name`` is falsy (``None`` or ``""``), no filter is applied and every
    row is returned.
    """
    query = Select(ScopeEquipment).options(selectinload(ScopeEquipment.master_equipment))
    if scope_name:
        query = query.filter(ScopeEquipment.scope_overhaul == scope_name)
    result = await db_session.execute(query)
    return result.scalars().all()


async def get_all_master_equipment(*, common: CommonParameters, scope_name):
    """Return paginated master equipment that is not yet part of ``scope_name``.

    Equipment with a NULL assetnum is always excluded. When ``scope_name`` is falsy,
    equipment that appears in *any* scope is excluded.

    Args:
        common: Shared query parameters, forwarded as keyword arguments to
            ``search_filter_sort_paginate``.
        scope_name: Overhaul scope whose members should be excluded.
    """
    # Let the database compute the exclusion set as a subquery, instead of loading every
    # scoped row (plus its relationship) and binding the assetnums back as parameters.
    # NULLs are filtered out because ``NOT IN (..., NULL)`` is never true in SQL, which
    # would otherwise make the whole result empty.
    scoped_assetnums = Select(ScopeEquipment.assetnum).where(
        ScopeEquipment.assetnum.is_not(None)
    )
    if scope_name:
        scoped_assetnums = scoped_assetnums.where(ScopeEquipment.scope_overhaul == scope_name)

    query = Select(MasterEquipment).filter(
        MasterEquipment.assetnum.is_not(None),
        MasterEquipment.assetnum.not_in(scoped_assetnums),
    )
    return await search_filter_sort_paginate(model=query, **common)


async def get_equipment_level_by_no(*, db_session: DbSession, level: int):
    """Return a MasterEquipmentTree row whose ``level_no`` equals ``level``, or ``None``.

    There is no ORDER BY, so if several rows share the level, which one is returned is
    up to the database.
    """
    query = Select(MasterEquipmentTree).filter(MasterEquipmentTree.level_no == level)
    return await db_session.scalar(query)