You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
644 lines
23 KiB
Python
644 lines
23 KiB
Python
from datetime import datetime
|
|
from typing import Optional
|
|
from uuid import uuid4, uuid4, UUID
|
|
import logging
|
|
import httpx
|
|
from fastapi import HTTPException, status
|
|
from sqlalchemy import delete, select, update, and_
|
|
from sqlalchemy.orm import selectinload
|
|
|
|
from src.config import AEROS_BASE_URL, DEFAULT_PROJECT_NAME
|
|
from src.database.core import DbSession
|
|
from src.database.service import CommonParameters, search_filter_sort_paginate
|
|
from src.utils import save_to_pastebin
|
|
import aiohttp
|
|
import asyncio
|
|
log = logging.getLogger(__name__)
|
|
|
|
from .model import (
|
|
AerosNode,
|
|
AerosSimulation,
|
|
AerosSimulationCalcResult,
|
|
AerosSimulationPlotResult,
|
|
AerosSchematic
|
|
)
|
|
from src.aeros_equipment.model import AerosEquipment, AerosEquipmentCustomParameterData
|
|
from src.aeros_equipment.schema import EquipmentWithCustomParameters
|
|
from .schema import SimulationInput, SimulationRankingParameters
|
|
|
|
# Shared async HTTP client with a 300-second timeout; execute_simulation
# posts simulation requests to the Aeros API through it.
client = httpx.AsyncClient(timeout=300.0)
|
|
# Module-level dict, empty at import time.
# NOTE(review): no function visible in this file mutates it — confirm whether
# it is still needed.
active_simulations = {}
|
|
|
|
|
|
async def get_all(common: CommonParameters):
    """List simulations whose status is ``"completed"``.

    ``common`` is unpacked as keyword arguments into
    ``search_filter_sort_paginate``; whatever that helper returns is
    returned unchanged.
    """
    completed_only = select(AerosSimulation).where(
        AerosSimulation.status == "completed"
    )
    return await search_filter_sort_paginate(model=completed_only, **common)
|
|
|
|
|
|
async def get_simulation_by_id(
    *,
    db_session: DbSession,
    simulation_id: Optional[UUID] = None,
    is_completed: bool = False,
):
    """Fetch a single simulation row.

    With a ``simulation_id`` the row with that id is selected. Without one,
    the row with the greatest ``completed_at`` is selected instead. When
    ``is_completed`` is true only rows with status ``"completed"`` qualify.

    Returns:
        The first matching ``AerosSimulation`` or ``None``.
    """
    stmt = select(AerosSimulation)
    if is_completed:
        stmt = stmt.where(AerosSimulation.status == "completed")

    stmt = (
        stmt.where(AerosSimulation.id == simulation_id)
        if simulation_id
        else stmt.order_by(AerosSimulation.completed_at.desc()).limit(1)
    )

    rows = await db_session.execute(stmt)
    return rows.scalar()
|
|
|
|
|
|
async def get_simulation_node_by(*, db_session: DbSession, **kwargs):
    """Return the first ``AerosNode`` matching every given column value.

    Each keyword whose name is an attribute of ``AerosNode`` becomes an
    equality filter; keywords that are not attributes are ignored.

    Raises:
        ValueError: if none of the keywords names an ``AerosNode`` attribute.
    """
    filters = [
        getattr(AerosNode, column) == wanted
        for column, wanted in kwargs.items()
        if hasattr(AerosNode, column)
    ]
    if not filters:
        raise ValueError("No valid column conditions provided")

    found = await db_session.execute(select(AerosNode).where(*filters))
    return found.scalar()
|
|
|
|
|
|
async def get_or_save_node(*, db_session: DbSession, node_data: dict, type: str = "calc"):
    """Return the node named ``node_data["nodeName"]``, creating it if absent.

    The node is looked up by ``node_name``. When none exists a new
    ``AerosNode`` is built from ``node_data``, added to the session and
    committed.

    Args:
        db_session: Database session used for the lookup and the insert.
        node_data: One result entry from the simulation response. The keys
            read depend on ``type`` (see below).
        type: ``"calc"`` reads the calculation-result key names
            (``schematicId``, ``refSchematicId``, ...); any other value reads
            the plot-result key names (``parentSchematicId``,
            ``targetSchematicId``, ...). The name shadows the builtin but is
            part of the keyword interface, so it is kept.

    Returns:
        The existing or newly created ``AerosNode``.

    Raises:
        KeyError: if ``node_data`` lacks a key required for the chosen type.
    """
    node = await get_simulation_node_by(
        db_session=db_session, node_name=node_data["nodeName"]
    )
    # The previous version raised ``Exception(node_data)`` unconditionally at
    # this point (a debugging leftover), so every call failed and nothing
    # below was reachable. The debug prints and the upload of the payload via
    # ``save_to_pastebin`` were removed together with it.
    if node:
        return node

    to_id = convert_id_to_none_if_negative

    if type == "calc":
        log.debug("Creating calc node %s", node_data["nodeName"])
        node = AerosNode(
            node_name=node_data["nodeName"],
            node_type=node_data["nodeType"],
            node_id=to_id(node_data["nodeId"]),
            original_node_id=to_id(node_data["originalNodeId"]),
            structure_name=node_data["structureName"],
            schematic_name=node_data["schematicName"],
            schematic_id=to_id(node_data["schematicId"]),
            original_schematic_id=to_id(node_data["originalSchematicId"]),
            ref_schematic_id=to_id(node_data["refSchematicId"]),
            # Keyword and key spellings are kept exactly as the original
            # wrote them; they must match the model column and the payload.
            orignal_ref_schematic_id=to_id(node_data["orinalRefSchematic"]),
        )
    else:
        log.debug("Creating plot node %s", node_data["nodeName"])
        node = AerosNode(
            node_name=node_data["nodeName"],
            node_type=node_data["nodeType"],
            node_id=to_id(node_data["nodeId"]),
            original_node_id=to_id(node_data["originalNodeId"]),
            schematic_name=node_data["schematicName"],
            schematic_id=to_id(node_data["parentSchematicId"]),
            original_schematic_id=to_id(node_data["originalParentSchematicId"]),
            ref_schematic_id=to_id(node_data["targetSchematicId"]),
            orignal_ref_schematic_id=to_id(node_data["originalTargetSchematicId"]),
        )

    db_session.add(node)
    await db_session.commit()

    return node
|
|
|
|
|
|
async def execute_simulation(
    *,
    db_session: DbSession,
    simulation_id: Optional[UUID] = None,
    sim_data: dict,
    is_saved: bool = False,
    eq_update: Optional[dict] = None,
):
    """Run one simulation on the Aeros API and optionally persist the result.

    ``sim_data`` is posted as JSON to
    ``{AEROS_BASE_URL}/api/Simulation/RunSimulation``. When ``is_saved`` is
    true the simulation row is updated with the raw response and the per-node
    results are stored through ``save_simulation_result``.

    Args:
        db_session: Database session.
        simulation_id: Id of the simulation row to update.
        sim_data: Request payload; must contain ``"SchematicName"``.
        is_saved: Whether to persist the response.
        eq_update: Per-node values forwarded to ``save_simulation_result``;
            ``None`` is treated as an empty dict.

    Returns:
        The decoded JSON response body.

    Raises:
        KeyError: if ``sim_data`` has no ``"SchematicName"`` key.
        HTTPException: status 500 carrying the error text when the request or
            the persistence step fails.
    """
    # A literal ``{}`` default would be a single dict shared by every call.
    eq_update = {} if eq_update is None else eq_update

    # ``print`` does not interpolate ``%s`` arguments; the logger does.
    log.info(
        "Executing simulation with id: %s (schematic: %s)",
        simulation_id,
        sim_data["SchematicName"],
    )

    try:
        response = await client.post(
            f"{AEROS_BASE_URL}/api/Simulation/RunSimulation",
            json=sim_data,
            headers={"Content-Type": "application/json"},
        )
        response.raise_for_status()
        result = response.json()

        if is_saved:
            simulation = await get_simulation_by_id(
                db_session=db_session, simulation_id=simulation_id
            )
            if simulation is None:
                raise LookupError(f"Simulation {simulation_id} not found")

            # NOTE(review): "proccessing" looks misspelled but is a stored
            # status value; it is left unchanged in case other code compares
            # against it.
            simulation.status = "proccessing"
            simulation.result = result
            await db_session.commit()

            await save_simulation_result(
                db_session=db_session,
                simulation_id=simulation_id,
                result=result,
                schematic_name=sim_data["SchematicName"],
                eq_update=eq_update,
            )

        log.info("Simulation completed with id: %s", simulation_id)
        return result

    except Exception as e:
        log.error("Simulation failed with error: %s", str(e))

        # Without an id, get_simulation_by_id falls back to the most recently
        # completed simulation, so an unrelated row would be marked failed.
        # Only touch the database when a specific simulation was requested,
        # and never let a missing row mask the original error.
        if simulation_id is not None:
            simulation = await get_simulation_by_id(
                db_session=db_session, simulation_id=simulation_id
            )
            if simulation is not None:
                simulation.status = "failed"
                simulation.error = str(e)
                await db_session.commit()

        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        ) from e
|
|
|
|
|
|
async def get_all_aeros_node(*, db_session: DbSession, schematic_name: Optional[str] = None):
    """Return all ``AerosNode`` rows, optionally limited to one schematic.

    When ``schematic_name`` is given, the schematic is resolved with
    ``get_aeros_schematic_by_name`` and only nodes whose
    ``aeros_schematic_id`` equals its id are returned.

    Raises:
        HTTPException: status 404 when the named schematic does not exist.
    """
    stmt = select(AerosNode)

    if schematic_name:
        owner = await get_aeros_schematic_by_name(
            db_session=db_session, schematic_name=schematic_name
        )
        if not owner:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail="Schematic not found"
            )
        stmt = stmt.where(AerosNode.aeros_schematic_id == owner.id)

    rows = await db_session.execute(stmt)
    return rows.scalars().all()
|
|
|
|
|
|
|
|
# AerosSimulationCalcResult column -> key of one "nodeResultOuts" entry.
_CALC_RESULT_FIELDS = {
    "total_downtime": "totalDowntime",
    "total_uptime": "totalUpTime",
    "num_events": "numEvents",
    "production": "production",
    "production_std": "productionStd",
    "ideal_production": "idealProduction",
    "availability": "availability",
    "efficiency": "efficiency",
    "effective_loss": "effectiveLoss",
    "num_cm": "numCM",
    "cm_waiting_time": "cmWaitingTime",
    "total_cm_downtime": "totalCMDowntime",
    "num_pm": "numPM",
    "total_pm_downtime": "totalPMDowntime",
    "num_ip": "numIP",
    "total_ip_downtime": "totalIPDowntime",
    "num_oh": "numOH",
    "total_oh_downtime": "totalOHDowntime",
    "t_wait_for_crew": "tWaitForCrew",
    "t_wait_for_spare": "tWaitForSpare",
    "duration_at_full": "durationAtFull",
    "duration_above_hh": "durationAboveHH",
    "duration_above_h": "durationAboveH",
    "duration_below_l": "durationBelowL",
    "duration_below_ll": "durationBelowLL",
    "duration_at_empty": "durationAtEmpty",
    "stg_input": "stgInput",
    "stg_output": "stgOutput",
    "average_level": "averageLevel",
    "potential_production": "potentialProduction",
}

# AerosSimulationPlotResult column -> key of one "plotNodeOuts" entry.
_PLOT_RESULT_FIELDS = {
    "max_flow_rate": "maxFlowrate",
    "storage_capacity": "storageCapacity",
    "point_availabilities": "pointAvailabilities",
    "point_flowrates": "pointFlowrates",
    "timestamp_outs": "timeStampOuts",
}


async def _find_or_create_result_node(*, db_session, known_nodes, entry, node_type, kind):
    """Resolve the node for one result entry; ``None`` means skip the entry."""
    node = known_nodes.get(f"{node_type}:{entry['nodeName']}")
    if node:
        return node
    # Unknown nodes are only created for these two raw node types.
    if entry["nodeType"] not in ("RegularNode", "Schematic"):
        return None
    return await get_or_save_node(db_session=db_session, node_data=entry, type=kind)


async def save_simulation_result(
    *, db_session: DbSession, simulation_id: UUID, result: dict, schematic_name: str, eq_update: dict
):
    """Persist the calculation and plot results of a finished simulation.

    One ``AerosSimulationCalcResult`` is created per entry of
    ``result["nodeResultOuts"]`` and one ``AerosSimulationPlotResult`` per
    entry of ``result["plotNodeOuts"]``. Entries are matched to the nodes of
    ``schematic_name`` by ``"<node_type>:<node_name>"``; unmatched entries
    are either created through ``get_or_save_node`` or skipped. On success
    the simulation is marked ``"completed"`` and stamped with the current
    time.

    Args:
        db_session: Database session.
        simulation_id: Simulation the results belong to.
        result: Decoded simulation response.
        schematic_name: Schematic whose nodes are used for matching.
        eq_update: Maps node name to a dict with ``"eta"``, ``"beta"`` and
            ``"mttr"``; nodes without an entry get zeros. ``None`` is treated
            as empty.

    Raises:
        KeyError: if ``result`` lacks ``"nodeResultOuts"`` or ``"plotNodeOuts"``.
        HTTPException: status 404 if the schematic is unknown (raised by
            ``get_all_aeros_node``); status 500 if building any result row
            fails, after the simulation has been marked ``"failed"``.
    """
    log.debug("Saving simulation result")

    # Distinct names: the original rebound ``result``, ``calc_result`` and
    # ``plot_result`` inside the very loops that iterated over them.
    calc_entries = result["nodeResultOuts"]
    plot_entries = result["plotNodeOuts"]
    reliability_by_node = eq_update or {}

    known_nodes = {
        f"{node.node_type}:{node.node_name}": node
        for node in await get_all_aeros_node(
            db_session=db_session, schematic_name=schematic_name
        )
    }
    calc_objects = []
    plot_objects = []

    try:
        for entry in calc_entries:
            node_type = "RegularNode" if entry["nodeType"] == "RegularNode" else "SchematicNode"
            node = await _find_or_create_result_node(
                db_session=db_session,
                known_nodes=known_nodes,
                entry=entry,
                node_type=node_type,
                kind="calc",
            )
            if node is None:
                continue

            reliability = reliability_by_node.get(
                entry["nodeName"], {"eta": 0, "beta": 0, "mttr": 0}
            )
            is_regular = node_type == "RegularNode"
            ideal = entry["idealProduction"]

            calc_objects.append(
                AerosSimulationCalcResult(
                    aeros_simulation_id=simulation_id,
                    aeros_node_id=node.id,
                    # Guard against division by zero when nothing was planned.
                    eaf=entry["production"] / ideal if ideal > 0 else 0,
                    # Reliability parameters are stored for regular nodes only.
                    beta=reliability["beta"] if is_regular else None,
                    eta=reliability["eta"] if is_regular else None,
                    mttr=reliability["mttr"] if is_regular else None,
                    **{column: entry[key] for column, key in _CALC_RESULT_FIELDS.items()},
                )
            )

        for entry in plot_entries:
            node_type = "RegularNode" if entry["nodeType"] == "RegularNode" else "SchematicNode"
            node = await _find_or_create_result_node(
                db_session=db_session,
                known_nodes=known_nodes,
                entry=entry,
                node_type=node_type,
                kind="plot",
            )
            if node is None:
                continue

            plot_objects.append(
                AerosSimulationPlotResult(
                    aeros_simulation_id=simulation_id,
                    aeros_node_id=node.id,
                    **{column: entry[key] for column, key in _PLOT_RESULT_FIELDS.items()},
                )
            )

    except Exception as e:
        simulation = await get_simulation_by_id(
            db_session=db_session, simulation_id=simulation_id
        )
        # A missing row must not replace the real error with AttributeError.
        if simulation is not None:
            simulation.status = "failed"
            simulation.result = str(e)
            await db_session.commit()

        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        ) from e

    db_session.add_all(calc_objects)
    db_session.add_all(plot_objects)

    simulation = await get_simulation_by_id(
        db_session=db_session, simulation_id=simulation_id
    )
    simulation.status = "completed"
    simulation.completed_at = datetime.now()

    await db_session.commit()
|
|
|
|
async def process_single_schematic(*, db_session: DbSession, sim_data: dict, schematic):
    """Simulate one schematic and build the (unsaved) node tree for it.

    Runs ``execute_simulation`` with ``sim_data``, creates a root
    ``SchematicNode`` for ``schematic`` and derives its descendants from the
    simulation output. Nothing is added to the session here.

    Returns:
        The descendant nodes followed by the root node.

    Raises:
        Whatever the simulation or node building raises; the error is printed
        and then re-raised unchanged.
    """
    try:
        sim_output = await execute_simulation(db_session=db_session, sim_data=sim_data)

        # The root id is generated up front so children can reference it.
        root_id = uuid4()
        root_node = AerosNode(
            id=root_id,
            node_name=schematic.schematic_name,
            schematic_name=schematic.schematic_name,
            schematic_id=None,
            node_type="SchematicNode",
            aeros_schematic_id=schematic.id,
            structure_name=schematic.schematic_name,
        )

        descendants = await save_recusive_simulation_result_node(
            db_session=db_session,
            data=sim_output,
            schematic_name=root_node.node_name,
            schematic_id=root_id,
            aeros_schematic_id=schematic.id,
        )
        descendants.append(root_node)
        return descendants

    except Exception as e:
        print(f"Error processing schematic {schematic.schematic_name}: {e}")
        raise
|
|
|
|
async def save_recusive_simulation_result_node(*, db_session: DbSession, data, schematic_name: str, aeros_schematic_id ,schematic_id: Optional[UUID] = None):
    """Build ``AerosNode`` objects for everything below one schematic.

    Walks ``data["plotNodeOuts"]`` and keeps the entries whose
    ``schematicName`` equals ``schematic_name``:

    * ``RegularNode`` entries become regular nodes whose ``schematic_id`` is
      the given parent id.
    * ``SubSchematic`` entries become ``SchematicNode`` nodes with a fresh
      id, and the function recurses using the entry's ``nodeName`` as the
      next schematic name and the fresh id as the next parent id.

    Every other entry is ignored. ``structure_name`` is looked up by node
    name in ``data["nodeResultOuts"]``. The nodes are only constructed and
    returned; this function does not add them to ``db_session``.

    NOTE(review): a ``SubSchematic`` entry whose ``nodeName`` equals its own
    ``schematicName`` would recurse without end — confirm the payload never
    contains one.
    """
    plot_entries = data["plotNodeOuts"]
    structure_by_node = {
        row["nodeName"]: row["structureName"] for row in data["nodeResultOuts"]
    }

    collected = []
    for entry in plot_entries:
        if entry["schematicName"] != schematic_name:
            continue

        kind = entry["nodeType"]
        if kind == "RegularNode":
            collected.append(
                AerosNode(
                    node_name=entry["nodeName"],
                    schematic_id=schematic_id,
                    node_type="RegularNode",
                    schematic_name=schematic_name,
                    aeros_schematic_id=aeros_schematic_id,
                    structure_name=structure_by_node.get(entry["nodeName"]),
                )
            )
        elif kind == "SubSchematic":
            child_id = uuid4()
            collected.append(
                AerosNode(
                    id=child_id,
                    node_name=entry["nodeName"],
                    schematic_name=schematic_name,
                    schematic_id=schematic_id,
                    node_type="SchematicNode",
                    aeros_schematic_id=aeros_schematic_id,
                    structure_name=structure_by_node.get(entry["nodeName"]),
                )
            )
            below = await save_recusive_simulation_result_node(
                db_session=db_session,
                data=data,
                schematic_name=entry["nodeName"],
                schematic_id=child_id,
                aeros_schematic_id=aeros_schematic_id,
            )
            collected.extend(below)

    return collected
|
|
|
|
|
|
async def get_aeros_schematic_by_name(*, db_session: DbSession, schematic_name: str):
    """Return the ``AerosSchematic`` with the given name, or ``None``.

    Uses ``scalar_one_or_none``, so more than one matching row raises.
    """
    stmt = select(AerosSchematic).where(
        AerosSchematic.schematic_name == schematic_name
    )
    return (await db_session.execute(stmt)).scalar_one_or_none()
|
|
|
|
async def save_default_simulation_node(
    *, db_session: DbSession, project_name: str = "trialapi"
):
    """Rebuild the ``AerosNode`` table from a minimal run of every schematic.

    For each stored schematic a simulation with seed 1, duration 1
    (``"UMinute"``) and a single run is executed and turned into nodes by
    ``process_single_schematic``. The schematics are processed one after
    another. Once all of them succeeded, every existing ``AerosNode`` row is
    deleted and the new nodes are inserted in the same commit.

    Args:
        db_session: Database session.
        project_name: Value sent as ``"projectName"`` in each request.

    Raises:
        Whatever ``process_single_schematic`` raises; in that case nothing is
        deleted or inserted.
    """
    all_results = []
    schematics = await get_all_schematic_aeros(db_session=db_session)

    for schematic in schematics:
        sim_data = {
            "projectName": project_name,
            "SchematicName": schematic.schematic_name,
            "SimSeed": 1,
            "SimDuration": 1,
            "DurationUnit": "UMinute",
            "SimNumRun": 1,
        }

        nodes = await process_single_schematic(
            db_session=db_session,
            sim_data=sim_data,
            schematic=schematic,
        )
        all_results.extend(nodes)

    # Replace the old nodes only after every simulation has succeeded.
    await db_session.execute(delete(AerosNode))

    db_session.add_all(all_results)
    await db_session.commit()
|
|
|
|
|
|
def convert_id_to_none_if_negative(value):
    """Return ``None`` for a missing or negative id, otherwise the id itself.

    ``None`` is passed through instead of raising ``TypeError`` on the
    comparison, so an absent id and a negative id are treated alike.
    """
    if value is None or value < 0:
        return None
    return value
|
|
|
|
|
|
async def create_simulation(*, db_session: DbSession, simulation_in: SimulationInput):
    """Insert a new simulation row in the ``"running"`` state.

    The row is named after ``simulation_in.SimulationName`` and stamped with
    the current time as ``started_at``.

    NOTE(review): ``schematic_name`` is hard-coded to ``"- TJB - Unit 3 -"``
    rather than taken from ``simulation_in`` — confirm this is intended.

    Returns:
        The committed ``AerosSimulation``.
    """
    # The original also computed an unused ``input`` dict (shadowing the
    # builtin) and named this dict after the module-level
    # ``active_simulations``; both were misleading and are gone.
    simulation_fields = {
        "status": "running",
        "started_at": datetime.now(),
        "simulation_name": simulation_in.SimulationName,
        "schematic_name": "- TJB - Unit 3 -",
    }

    simulation = AerosSimulation(**simulation_fields)
    db_session.add(simulation)
    await db_session.commit()
    return simulation
|
|
|
|
|
|
async def get_simulation_with_calc_result(
    *, db_session: DbSession, simulation_id: UUID, aeros_node_id: Optional[UUID] = None, schematic_name: Optional[str] = None
):
    """Return the calculation results of one simulation.

    Args:
        db_session: Database session.
        simulation_id: Simulation whose results are wanted.
        aeros_node_id: Accepted for compatibility; not used by the query.
        schematic_name: When given, only results whose node has a
            ``structure_name`` containing this text are returned.

    Returns:
        A list of ``AerosSimulationCalcResult`` with ``aeros_node`` and the
        node's ``equipment`` eagerly loaded.
    """
    query = select(AerosSimulationCalcResult).filter(
        AerosSimulationCalcResult.aeros_simulation_id == simulation_id
    )

    if schematic_name:
        # The original special-cased "WTP", but both branches built exactly
        # the same join and filter, so a single branch is equivalent.
        query = query.join(
            AerosNode, AerosNode.id == AerosSimulationCalcResult.aeros_node_id
        ).filter(AerosNode.structure_name.contains(schematic_name))

    query = query.options(
        selectinload(AerosSimulationCalcResult.aeros_node).options(
            selectinload(AerosNode.equipment)
        )
    )

    rows = await db_session.execute(query)
    return rows.scalars().all()
|
|
|
|
|
|
async def get_result_ranking(*, db_session: DbSession, simulation_id: UUID):
    """Rank equipment of one simulation by ``eaf``, highest first.

    Equipment is joined to nodes by ``node_name`` and to calculation results
    by node id. Only ``RegularNode`` nodes whose equipment has at least one
    custom parameter are considered, and at most ten rows are returned.

    Returns:
        A list of ``SimulationRankingParameters``.
    """
    ranking = (
        select(AerosEquipment, AerosSimulationCalcResult.eaf)
        .join(AerosNode, AerosNode.node_name == AerosEquipment.node_name)
        .join(
            AerosSimulationCalcResult,
            AerosSimulationCalcResult.aeros_node_id == AerosNode.id,
        )
        .filter(
            and_(
                AerosSimulationCalcResult.aeros_simulation_id == simulation_id,
                AerosNode.node_type == "RegularNode",
                AerosEquipment.custom_parameters.any(),
            )
        )
        .order_by(AerosSimulationCalcResult.eaf.desc())
        .limit(10)
        .options(selectinload(AerosEquipment.custom_parameters))
        .options(selectinload(AerosEquipment.master_equipment))
    )

    rows = await db_session.execute(ranking)

    ranked = []
    for equipment, eaf in rows:
        ranked.append(
            SimulationRankingParameters(
                location_tag=equipment.location_tag,
                master_equipment=equipment.master_equipment,
                custom_parameters=equipment.custom_parameters,
                eaf=eaf,
            )
        )
    return ranked
|
|
|
|
|
|
async def get_simulation_with_plot_result(
    *, db_session: DbSession, simulation_id: UUID
):
    """Return one simulation with its plot results eagerly loaded.

    Both ``plot_results`` and each plot result's ``aeros_node`` are loaded
    with ``selectinload``. Returns ``None`` when the id is unknown.
    """
    plot_loader = selectinload(AerosSimulation.plot_results).options(
        selectinload(AerosSimulationPlotResult.aeros_node)
    )
    stmt = (
        select(AerosSimulation)
        .where(AerosSimulation.id == simulation_id)
        .options(plot_loader)
    )
    rows = await db_session.execute(stmt)
    return rows.scalar()
|
|
|
|
|
|
async def get_calc_result_by(
    *, db_session: DbSession, simulation_id: UUID, node_name: Optional[str] = None
):
    """Return the first calculation result of a simulation.

    When ``node_name`` is given the result is restricted to the node with
    that name. Returns ``None`` if nothing matches.
    """
    stmt = select(AerosSimulationCalcResult).where(
        AerosSimulationCalcResult.aeros_simulation_id == simulation_id
    )

    if node_name:
        by_node = stmt.join(AerosSimulationCalcResult.aeros_node)
        stmt = by_node.filter(AerosNode.node_name == node_name)

    rows = await db_session.execute(stmt)
    return rows.scalar()
|
|
|
|
|
|
async def get_custom_parameters(*, db_session: DbSession, simulation_id: UUID):
    """Return up to 20 regular-node results of a simulation, best ``eaf`` first.

    Only results whose node has ``node_type == "RegularNode"`` are included;
    each result's ``aeros_node`` is eagerly loaded.
    """
    stmt = (
        select(AerosSimulationCalcResult)
        .where(AerosSimulationCalcResult.aeros_simulation_id == simulation_id)
        .join(AerosNode, AerosNode.id == AerosSimulationCalcResult.aeros_node_id)
        .where(AerosNode.node_type == "RegularNode")
        .order_by(AerosSimulationCalcResult.eaf.desc())
        .limit(20)
        .options(selectinload(AerosSimulationCalcResult.aeros_node))
    )
    rows = await db_session.execute(stmt)
    return rows.scalars().all()
|
|
|
|
async def get_regular_nodes_by_schematic(*, db_session: DbSession, schematic_name: str) -> set[UUID]:
    """Return the ids of all regular nodes below the named schematic node.

    A recursive CTE starts from the node(s) whose ``node_name`` equals
    ``schematic_name`` and repeatedly adds nodes whose ``schematic_id``
    equals the ``ref_schematic_id`` of a node already in the hierarchy.
    Only rows with ``node_type == "RegularNode"`` are returned.

    NOTE(review): nodes built by ``save_recusive_simulation_result_node`` in
    this file link children through ``schematic_id`` = parent node ``id`` and
    do not set ``ref_schematic_id``. If those are the rows being traversed,
    the join below may need ``hierarchy.c.id`` instead — confirm against the
    data before changing it.

    Args:
        db_session: Database session.
        schematic_name: ``node_name`` of the root node(s).

    Returns:
        A set of node ids.
    """
    node_columns = (
        AerosNode.id,
        AerosNode.schematic_id,
        AerosNode.ref_schematic_id,
        AerosNode.node_type,
        AerosNode.node_name,
    )

    # Anchor member: the root node(s) carrying the requested name.
    hierarchy = (
        select(*node_columns)
        .where(AerosNode.node_name == schematic_name)
        .cte(name="hierarchy", recursive=True)
    )

    # Recursive member. The original called ``AerosNode.join(...)``, but a
    # mapped class has no ``join`` method; the join belongs on the select.
    children = select(*node_columns).join(
        hierarchy, AerosNode.schematic_id == hierarchy.c.ref_schematic_id
    )

    hierarchy = hierarchy.union_all(children)

    query = select(hierarchy.c.id).where(hierarchy.c.node_type == "RegularNode")

    result = await db_session.execute(query)
    return set(result.scalars().all())
|
|
|
|
|
|
async def get_all_schematic_aeros(*, db_session: DbSession):
    """Return every ``AerosSchematic`` row as a list."""
    rows = await db_session.execute(select(AerosSchematic))
    return rows.scalars().all()
|
|
|
|
|
|
async def update_simulation(*, db_session: DbSession, simulation_id: UUID, data: dict):
    """Apply ``data`` as column values to one simulation row and commit.

    ``data`` is unpacked into ``values()``, so its keys are used as column
    names. Nothing is returned.
    """
    target_row = update(AerosSimulation).where(AerosSimulation.id == simulation_id)
    await db_session.execute(target_row.values(**data))
    await db_session.commit()
|