import asyncio
import logging
import os
import sys
from subprocess import PIPE
from typing import Dict, List, Optional
from uuid import UUID

from sqlalchemy import Delete, Select

from src.database.core import DbSession
from src.database.service import search_filter_sort_paginate

from .model import MasterDataSimulation
from .schema import MasterDataSimulationCreate, MasterDataSimulationUpdate

log = logging.getLogger(__name__)

# Attributes of a MasterDataSimulation row that an update payload sets directly
# on the row being edited. Any other payload key is treated as the *name* of
# another masterdata row in the same simulation.
MASTERDATA_SIM_ATTR_FIELDS = {
    "name",
    "description",
    "unit_of_measurement",
    "value_num",
    "value_str",
    "created_by",
    "updated_by",
}

# Row names whose modification should trigger a plant recalculation.
_PLANT_RECALC_TRIGGER_NAMES = frozenset(
    {
        "umur_teknis",
        "discount_rate",
        "loan_portion",
        "interest_rate",
        "loan_tenor",
        "corporate_tax_rate",
        "wacc_on_equity",
        "auxiliary",
        "susut_trafo",
        "sfc",
        "electricity_price_a",
        "electricity_price_b",
        "electricity_price_c",
        "electricity_price_d",
        "harga_bahan_bakar",
        "inflation_rate",
        "loan",
        "wacc_on_project",
        "principal_interest_payment",
        "equity",
    }
)


async def _apply_masterdata_simulation_update_logic(
    *,
    db_session: DbSession,
    masterdata: MasterDataSimulation,
    masterdata_in: MasterDataSimulationUpdate,
    records_by_name: Dict[str, MasterDataSimulation],
    simulation_id: UUID,
):
    """Apply an update payload to a simulation's masterdata rows.

    Mirrors the update behaviour of ``src.masterdata.service`` for simulations:

    * keys in ``MASTERDATA_SIM_ATTR_FIELDS`` are set on ``masterdata`` itself;
    * any other key is the name of another row in the same simulation, whose
      ``value_num`` (int/float values) or ``value_str`` (anything else) is set;
    * derived financing values (equity portion, loan, equity, periodic
      principal+interest payment, WACC) are recomputed when their inputs appear
      in the payload and are written to the rows carrying those names.

    Args:
        db_session: Session used to look up rows not already cached.
        masterdata: The row being updated.
        masterdata_in: Update payload; only non-default fields are applied.
        records_by_name: Rows keyed by name. Read and extended in place so
            later lookups (including later calls within a bulk update) see the
            in-memory values set here rather than stale database values.
        simulation_id: Simulation whose rows are looked up.

    Returns:
        ``(masterdata, run_plant_calculation)``; the flag is True when a row
        named in ``_PLANT_RECALC_TRIGGER_NAMES`` was modified.
    """
    update_data = masterdata_in.model_dump(exclude_defaults=True)
    run_plant_calculation = False

    async def find_record(name: str) -> Optional[MasterDataSimulation]:
        """Return the row called ``name`` (cache first, then DB), caching hits."""
        cached = records_by_name.get(name)
        if cached is not None:
            return cached
        query = (
            Select(MasterDataSimulation)
            .where(MasterDataSimulation.simulation_id == simulation_id)
            .where(MasterDataSimulation.name == name)
        )
        result = await db_session.execute(query)
        record = result.scalars().one_or_none()
        if record is not None and record.name:
            records_by_name[record.name] = record
        return record

    async def get_value(name: str) -> float:
        """Return row ``name``'s numeric value as float; 0.0 if missing/None."""
        record = await find_record(name)
        if record is None or record.value_num is None:
            return 0.0
        # value_num may come back as Decimal if the column is Numeric (TODO
        # confirm in .model); coercing keeps arithmetic with the int/float
        # values set from the payload from raising TypeError.
        return float(record.value_num)

    def flag_special(record: MasterDataSimulation) -> None:
        """Request a plant recalculation if ``record`` is a trigger row."""
        nonlocal run_plant_calculation
        if getattr(record, "name", None) in _PLANT_RECALC_TRIGGER_NAMES:
            run_plant_calculation = True

    async def set_derived(name: str, value: float) -> None:
        """Store a computed value on the row called ``name``.

        The value is also set as a plain attribute on ``masterdata`` so any
        caller reading e.g. ``masterdata.loan`` keeps working. Those names are
        not in ``MASTERDATA_SIM_ATTR_FIELDS``, so they are presumably not mapped
        columns (verify against .model) and would not be persisted on their own.
        """
        setattr(masterdata, name, value)
        record = await find_record(name)
        if record is not None:
            record.value_num = value
            flag_special(record)

    for field, val in update_data.items():
        if field in MASTERDATA_SIM_ATTR_FIELDS:
            setattr(masterdata, field, val)
            flag_special(masterdata)
            continue

        other = await find_record(field)
        if other is None:
            continue
        if isinstance(val, (int, float)):
            other.value_num = val
            flag_special(other)
        else:
            other.value_str = str(val)

    if "loan_portion" in update_data:
        loan_portion = await get_value("loan_portion")
        # Portions are percentages: equity is the remainder of 100.
        equity_portion = 100 - loan_portion
        await set_derived("equity_portion", equity_portion)
        total_project_cost = await get_value("total_project_cost")
        await set_derived("loan", total_project_cost * (loan_portion / 100))
        await set_derived("equity", total_project_cost * (equity_portion / 100))

    # "loan_portion" is included because it changes "loan" just above.
    if any(
        field in update_data
        for field in ("loan", "loan_portion", "interest_rate", "loan_tenor")
    ):
        pmt = calculate_pmt(
            rate=await get_value("interest_rate"),
            nper=await get_value("loan_tenor"),
            pv=await get_value("loan"),
        )
        await set_derived("principal_interest_payment", pmt)

    if any(
        field in update_data
        for field in (
            "loan_portion",
            "interest_rate",
            "corporate_tax_rate",
            "wacc_on_equity",
            "equity_portion",
        )
    ):
        loan_portion = await get_value("loan_portion")
        after_tax_debt_cost = await get_value("interest_rate") * (
            1 - await get_value("corporate_tax_rate")
        )
        equity_cost = await get_value("wacc_on_equity")
        # Read after set_derived above, so this sees a freshly computed value.
        equity_portion = await get_value("equity_portion")
        # NOTE(review): the portions are percentages (see the /100 above) but
        # are used unscaled here; confirm the intended units of interest_rate,
        # corporate_tax_rate and wacc_on_equity.
        wacc = (loan_portion * after_tax_debt_cost) + (equity_cost * equity_portion)
        await set_derived("wacc_on_project", wacc)

    return masterdata, run_plant_calculation


async def _trigger_masterdata_simulation_recalculation(
    *, db_session: DbSession, run_plant_calculation_change: bool = False
):
    """Run ``modules/plant/run_plant_simulation.py`` when a recalculation is due.

    Best effort: failures are logged, never raised, so a masterdata update is
    not rolled back or failed by the recalculation. The script runs under the
    current interpreter (``sys.executable``) so it sees the same environment.
    ``db_session`` is accepted for interface compatibility and is unused.
    """
    if not run_plant_calculation_change:
        return

    directory_path = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "../modules/plant")
    )
    script_path = os.path.join(directory_path, "run_plant_simulation.py")
    try:
        process = await asyncio.create_subprocess_exec(
            sys.executable or "python",
            script_path,
            stdout=PIPE,
            stderr=PIPE,
            cwd=directory_path,
        )
        stdout, stderr = await process.communicate()
    except Exception:
        log.exception("Error during simulation masterdata recalculation")
        return

    if process.returncode != 0:
        log.error("Plant recalc error: %s", stderr.decode(errors="replace"))
    else:
        log.info("Plant recalc output: %s", stdout.decode(errors="replace"))


def calculate_pmt(rate, nper, pv):
    """Return the per-period payment of an amortising loan (Excel ``PMT`` style).

    Args:
        rate: Interest rate per period. Values > 1 are read as percentages
            (5 -> 0.05); values <= 1 are read as fractions, so exactly 1 means
            100%.
        nper: Number of periods.
        pv: Present value (principal).

    Returns:
        The payment as a float, negative for a positive ``pv`` (cash-out sign
        convention). Returns 0.0 when ``nper`` is not positive (e.g. a missing
        loan_tenor row read as 0), since no schedule can be computed.
    """
    rate = float(rate)
    if rate > 1:
        rate /= 100
    nper = float(nper)
    pv = float(pv)
    if nper <= 0:
        return 0.0
    if rate == 0:
        return -pv / nper
    growth = (1 + rate) ** nper
    return -pv * (rate * growth) / (growth - 1)


async def get(
    *, db_session: DbSession, masterdata_id: str
) -> Optional[MasterDataSimulation]:
    """Return the masterdata row with ``masterdata_id``, or None if absent."""
    query = Select(MasterDataSimulation).where(MasterDataSimulation.id == masterdata_id)
    result = await db_session.execute(query)
    return result.scalars().one_or_none()


async def get_all(
    *,
    db_session: DbSession,
    items_per_page: int,
    simulation_id: UUID,
    search: Optional[str],
    common,
):
    """Return a paginated page of a simulation's masterdata ordered by ``seq``.

    Args:
        db_session: Unused here; pagination receives its arguments via ``common``.
        items_per_page: Page size, overriding any value in ``common``.
        simulation_id: Simulation whose rows are listed.
        search: Optional case-insensitive substring filter on ``name``.
        common: Shared keyword arguments for ``search_filter_sort_paginate``.
            It is copied, not mutated.
    """
    query = (
        Select(MasterDataSimulation)
        .where(MasterDataSimulation.simulation_id == simulation_id)
        .order_by(MasterDataSimulation.seq.asc())
    )
    if search:
        query = query.where(MasterDataSimulation.name.ilike(f"%{search}%"))

    params = dict(common)
    params["items_per_page"] = items_per_page
    return await search_filter_sort_paginate(model=query, **params)


async def create(*, db_session: DbSession, masterdata_in: MasterDataSimulationCreate):
    """Insert a new masterdata row from ``masterdata_in`` and commit."""
    masterdata = MasterDataSimulation(**masterdata_in.model_dump())
    db_session.add(masterdata)
    await db_session.commit()
    return masterdata


async def update(
    *,
    db_session: DbSession,
    masterdata: MasterDataSimulation,
    masterdata_in: MasterDataSimulationUpdate,
):
    """Apply ``masterdata_in`` to ``masterdata``, commit, then maybe recalculate."""
    records_by_name: Dict[str, MasterDataSimulation] = {}
    if masterdata.name:
        records_by_name[masterdata.name] = masterdata

    _, run_plant_calculation = await _apply_masterdata_simulation_update_logic(
        db_session=db_session,
        masterdata=masterdata,
        masterdata_in=masterdata_in,
        records_by_name=records_by_name,
        simulation_id=masterdata.simulation_id,
    )
    await db_session.commit()
    await _trigger_masterdata_simulation_recalculation(
        db_session=db_session,
        run_plant_calculation_change=run_plant_calculation,
    )
    return masterdata


async def bulk_update(
    *,
    db_session: DbSession,
    updates: List[MasterDataSimulationUpdate],
    ids: List[str],
    simulation_id: UUID,
) -> List[MasterDataSimulation]:
    """Apply several updates in one transaction and recalculate at most once.

    ``ids`` and ``updates`` are paired positionally; if their lengths differ,
    the extra items of the longer list are ignored. Ids that do not belong to
    ``simulation_id`` are skipped. All updates share one name cache, so
    derived values see the in-memory results of earlier updates in the batch.
    """
    query = (
        Select(MasterDataSimulation)
        .where(MasterDataSimulation.id.in_(ids))
        .where(MasterDataSimulation.simulation_id == simulation_id)
    )
    result = await db_session.execute(query)
    records = result.scalars().all()

    records_map = {str(record.id): record for record in records}
    records_by_name = {record.name: record for record in records if record.name}

    run_plant_calculation_change = False
    updated_records: List[MasterDataSimulation] = []
    for masterdata_id, masterdata_in in zip(ids, updates):
        masterdata = records_map.get(masterdata_id)
        if not masterdata:
            continue
        _, run_plant_calculation = await _apply_masterdata_simulation_update_logic(
            db_session=db_session,
            masterdata=masterdata,
            masterdata_in=masterdata_in,
            records_by_name=records_by_name,
            simulation_id=simulation_id,
        )
        run_plant_calculation_change = run_plant_calculation_change or run_plant_calculation
        updated_records.append(masterdata)

    await db_session.commit()
    await _trigger_masterdata_simulation_recalculation(
        db_session=db_session,
        run_plant_calculation_change=run_plant_calculation_change,
    )
    return updated_records


async def delete(*, db_session: DbSession, masterdata_id: str):
    """Delete the masterdata row with ``masterdata_id`` and commit."""
    query = Delete(MasterDataSimulation).where(MasterDataSimulation.id == masterdata_id)
    await db_session.execute(query)
    await db_session.commit()