"""Service functions for AEROS equipment: database access plus calls to the AEROS distribution API."""

from typing import List, Optional
from uuid import UUID

import httpx
from fastapi import HTTPException, status
from fastapi.encoders import jsonable_encoder
from sqlalchemy import Delete, Select, func
from sqlalchemy.orm import selectinload

from src.auth.service import CurrentUser
from src.config import AEROS_BASE_URL, DEFAULT_PROJECT_NAME
from src.database.core import DbSession
from src.database.service import search_filter_sort_paginate

from .model import AerosEquipment, AerosEquipmentDetail, MasterEquipment
from .schema import EquipmentConfiguration

# Shared async HTTP client for every AEROS request in this module (300 s timeout).
client = httpx.AsyncClient(timeout=300.0)


async def _post_to_aeros(path: str, payload):
    """POST ``payload`` as JSON to ``AEROS_BASE_URL + path`` and return the decoded body.

    Raises:
        HTTPException: 500 on any failure (transport error, non-2xx status,
            undecodable JSON), with the original error message as ``detail``
            and the original exception chained as ``__cause__``.
    """
    try:
        response = await client.post(
            f"{AEROS_BASE_URL}{path}",
            json=payload,
            headers={"Content-Type": "application/json"},
        )
        response.raise_for_status()
        return response.json()
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        ) from e


async def get_all(*, common):
    """Return one paginated page of AerosEquipment merged with live AEROS data.

    ``common`` is forwarded unchanged to ``search_filter_sort_paginate``.
    The returned mapping is that helper's result, with ``"items"`` replaced by
    one ``{"AerosData": ..., "MasterData": ...}`` dict per entry in the AEROS
    response. ``MasterData`` is the local AerosEquipment whose ``node_name``
    equals the entry's ``"equipmentName"``, or ``None`` if there is no match.

    Raises:
        HTTPException: 500 if the AEROS request fails.
    """
    query = Select(AerosEquipment).options(
        selectinload(AerosEquipment.master_equipment)
    )
    results = await search_filter_sort_paginate(model=query, **common)

    reg_nodes = [node.node_name for node in results["items"]]
    equipment_by_name = {node.node_name: node for node in results["items"]}

    aeros_data = await _post_to_aeros(
        "/api/UpdateDisParams/GetUpdatedNodeDistributions",
        {"projectName": DEFAULT_PROJECT_NAME, "equipmentNames": reg_nodes},
    )

    results["items"] = [
        {
            "AerosData": data,
            "MasterData": equipment_by_name.get(data["equipmentName"]),
        }
        for data in aeros_data
    ]
    return results


async def get_equipment_by_location_tag(*, db_session: DbSession, location_tag: str):
    """Return the AerosEquipment whose ``location_tag`` matches, or ``None``.

    Its ``aeros_equipment_details`` relationship is eagerly loaded.
    Previously the query was built but never executed, so this always returned ``None``.
    """
    query = (
        Select(AerosEquipment)
        .where(AerosEquipment.location_tag == location_tag)
        .options(selectinload(AerosEquipment.aeros_equipment_details))
    )
    result = await db_session.execute(query)
    return result.scalar()


async def get_by_id(*, db_session: DbSession, id: UUID):
    """Load one AerosEquipment by primary key and fetch its AEROS distributions.

    Returns:
        A ``(equipment, aeros_response)`` tuple, where ``aeros_response`` is the
        decoded JSON returned by AEROS for this equipment's ``node_name``.

    Raises:
        HTTPException: 404 if no row has this id; 500 if the AEROS request fails.
    """
    # ``id`` shadows the builtin but is part of the public keyword interface.
    query = (
        Select(AerosEquipment)
        .where(AerosEquipment.id == id)
        .options(selectinload(AerosEquipment.aeros_equipment_details))
    )
    result = await db_session.execute(query)
    aeros_equipment = result.scalar()

    if not aeros_equipment:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="AerosEquipment not found"
        )

    # "equipmentNames" (plural) matches the other callers of this endpoint;
    # the previous singular "equipmentName" key did not.
    # NOTE(review): project is hard-coded to "ParallelNode" here, while get_all
    # uses DEFAULT_PROJECT_NAME -- confirm this difference is intended.
    aeros_equipment_data = await _post_to_aeros(
        "/api/UpdateDisParams/GetUpdatedNodeDistributions",
        {
            "projectName": "ParallelNode",
            "equipmentNames": [aeros_equipment.node_name],
        },
    )
    return aeros_equipment, aeros_equipment_data


async def update_node(
    *, db_session: DbSession, equipment_nodes: List[EquipmentConfiguration]
):
    """Send updated equipment distributions to AEROS and return its decoded response.

    ``db_session`` is accepted for interface compatibility but is not used.

    Raises:
        HTTPException: 500 if the AEROS request fails.
    """
    # httpx serializes ``json=`` with json.dumps, which rejects model/dataclass
    # instances; jsonable_encoder turns them into plain JSON-compatible data.
    # NOTE(review): jsonable_encoder uses field aliases by default -- confirm the
    # AEROS API expects the aliased names if EquipmentConfiguration defines any.
    payload = {
        "projectName": "ParallelNode",
        "regNodeInputs": jsonable_encoder(equipment_nodes),
    }
    return await _post_to_aeros(
        "/api/UpdateDisParams/UpdateEquipmentDistributions", payload
    )


async def save_default_equipment(*, db_session: DbSession, project_name: str):
    """Rebuild the AerosEquipment table from AEROS data for all tagged master equipment.

    Collects the ``location_tag`` of every MasterEquipment that has one and asks
    AEROS for their distributions. It then replaces all AerosEquipment rows with
    one row per returned entry, using ``"equipmentName"`` as both ``node_name``
    and ``location_tag``.

    Returns:
        The decoded AEROS response.

    Raises:
        HTTPException: 500 if the AEROS request or the database rewrite fails.
            In the database case the session is rolled back first.
    """
    tagged_equipment = Select(MasterEquipment).where(
        MasterEquipment.location_tag.isnot(None)
    )
    equipment_rows = await db_session.execute(tagged_equipment)
    reg_nodes = [node.location_tag for node in equipment_rows.scalars().all()]

    # Call AEROS before touching the table, so a failed request never leaves a
    # pending delete of the existing rows in the session.
    results = await _post_to_aeros(
        "/api/UpdateDisParams/GetUpdatedNodeDistributions",
        {"projectName": project_name, "equipmentNames": reg_nodes},
    )

    try:
        # Delete the old rows and insert the new ones in a single commit.
        await db_session.execute(Delete(AerosEquipment))
        db_session.add_all(
            [
                AerosEquipment(
                    node_name=equipment["equipmentName"],
                    location_tag=equipment["equipmentName"],
                )
                for equipment in results
            ]
        )
        await db_session.commit()
    except Exception as e:
        await db_session.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        ) from e

    return results