You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

301 lines
9.2 KiB
Python

import asyncio
import logging
import os
import sys
from subprocess import PIPE
from typing import Dict, List, Optional
from uuid import UUID

from sqlalchemy import Delete, Select

from src.database.core import DbSession
from src.database.service import search_filter_sort_paginate

from .model import MasterDataSimulation
from .schema import MasterDataSimulationCreate, MasterDataSimulationUpdate
# Update-payload fields that are written straight onto the edited row via
# setattr. Any other payload field is interpreted as the ``name`` of a sibling
# masterdata row in the same simulation.
MASTERDATA_SIM_ATTR_FIELDS = set(
    (
        "created_by",
        "description",
        "name",
        "unit_of_measurement",
        "updated_by",
        "value_num",
        "value_str",
    )
)
# Masterdata names that, when touched by an update, set the
# ``run_plant_calculation`` flag returned by the update logic below.
_PLANT_RECALC_TRIGGER_NAMES = frozenset(
    {
        "umur_teknis",
        "discount_rate",
        "loan_portion",
        "interest_rate",
        "loan_tenor",
        "corporate_tax_rate",
        "wacc_on_equity",
        "auxiliary",
        "susut_trafo",
        "sfc",
        "electricity_price_a",
        "electricity_price_b",
        "electricity_price_c",
        "electricity_price_d",
        "harga_bahan_bakar",
        "inflation_rate",
        "loan",
        "wacc_on_project",
        "principal_interest_payment",
        "equity",
    }
)


async def _apply_masterdata_simulation_update_logic(
    *,
    db_session: DbSession,
    masterdata: MasterDataSimulation,
    masterdata_in: MasterDataSimulationUpdate,
    records_by_name: Dict[str, MasterDataSimulation],
    simulation_id: UUID,
):
    """Apply an update payload to a simulation masterdata row and its siblings.

    Mirrors the update behaviour from src.masterdata.service for simulations.

    Only fields that differ from their defaults are applied:

    * fields in ``MASTERDATA_SIM_ATTR_FIELDS`` are set directly on
      ``masterdata``;
    * any other field is treated as the ``name`` of a sibling row of the same
      simulation. An int/float value (bools included, since bool subclasses
      int) goes to its ``value_num``; anything else is stringified into
      ``value_str``. Unknown names are ignored.

    Afterwards a few financial figures (equity/loan split, annuity payment,
    WACC) are derived and attached to ``masterdata``. Figures derived
    earlier in the call take precedence over stored rows for later formulas.
    Nothing is committed here.

    Args:
        db_session: Async session used for look-ups.
        masterdata: The row being edited.
        masterdata_in: The update payload.
        records_by_name: name -> row cache shared with the caller; rows
            loaded here are added to it.
        simulation_id: Simulation all looked-up rows must belong to.

    Returns:
        ``(masterdata, run_plant_calculation)``. The flag is True when a row
        whose name is in ``_PLANT_RECALC_TRIGGER_NAMES`` was written.
    """
    update_data = masterdata_in.model_dump(exclude_defaults=True)
    # Figures computed during this call, keyed by masterdata name. They shadow
    # the stored rows so that, e.g., WACC uses the equity_portion just derived
    # from the new loan_portion instead of the stale stored value.
    derived: Dict[str, float] = {}
    run_plant_calculation = False

    async def fetch_by_name(name: str) -> Optional[MasterDataSimulation]:
        """Load the row called ``name`` in this simulation, or None."""
        query = (
            Select(MasterDataSimulation)
            .where(MasterDataSimulation.simulation_id == simulation_id)
            .where(MasterDataSimulation.name == name)
        )
        result = await db_session.execute(query)
        return result.scalars().one_or_none()

    async def get_value(name: str) -> float:
        """Numeric value of ``name``: derived > cached row > DB row > 0."""
        if name in derived:
            return derived[name]
        record = records_by_name.get(name)
        if record is not None and record.value_num is not None:
            return record.value_num
        row = await fetch_by_name(name)
        if row:
            records_by_name[row.name] = row
            return row.value_num if row.value_num is not None else 0
        return 0

    def set_derived(name: str, value: float) -> None:
        """Attach a derived figure to ``masterdata`` and make it visible to get_value."""
        # NOTE(review): the figure only becomes a plain attribute on the edited
        # row; the sibling row called ``name`` (if any) is not written. Confirm
        # whether it should be persisted to that row's value_num.
        derived[name] = value
        setattr(masterdata, name, value)

    def flag_special(record: MasterDataSimulation) -> None:
        """Set the recalculation flag when a trigger row is written."""
        nonlocal run_plant_calculation
        if getattr(record, "name", None) in _PLANT_RECALC_TRIGGER_NAMES:
            run_plant_calculation = True

    for field, val in update_data.items():
        if field in MASTERDATA_SIM_ATTR_FIELDS:
            setattr(masterdata, field, val)
            flag_special(masterdata)
            continue
        other = await fetch_by_name(field)
        if not other:
            continue
        if isinstance(val, (int, float)):
            other.value_num = val
            flag_special(other)
        else:
            other.value_str = str(val)
        if other.name:
            records_by_name[other.name] = other

    if "loan_portion" in update_data:
        loan_portion = await get_value("loan_portion")
        equity_portion = 100 - loan_portion
        set_derived("equity_portion", equity_portion)
        total_project_cost = await get_value("total_project_cost")
        set_derived("loan", total_project_cost * (loan_portion / 100))
        set_derived("equity", total_project_cost * (equity_portion / 100))

    # NOTE(review): "loan" may have just been re-derived from loan_portion
    # above, but the payment is only recomputed when loan / interest_rate /
    # loan_tenor themselves are in the payload — confirm that is intended.
    if any(field in update_data for field in ("loan", "interest_rate", "loan_tenor")):
        pmt = calculate_pmt(
            rate=await get_value("interest_rate"),
            nper=await get_value("loan_tenor"),
            pv=await get_value("loan"),
        )
        set_derived("principal_interest_payment", pmt)

    wacc_inputs = (
        "loan_portion",
        "interest_rate",
        "corporate_tax_rate",
        "wacc_on_equity",
        "equity_portion",
    )
    if any(field in update_data for field in wacc_inputs):
        # Reads happen in the same order as the formula's terms, left to right.
        loan_portion = await get_value("loan_portion")
        after_tax_cost_of_debt = await get_value("interest_rate") * (
            1 - await get_value("corporate_tax_rate")
        )
        wacc = loan_portion * after_tax_cost_of_debt + (
            await get_value("wacc_on_equity") * await get_value("equity_portion")
        )
        set_derived("wacc_on_project", wacc)

    return masterdata, run_plant_calculation
async def _trigger_masterdata_simulation_recalculation(
    *, db_session: DbSession, run_plant_calculation_change: bool = False
):
    """Run ``../modules/plant/run_plant_simulation.py`` when a recalculation is needed.

    Best-effort: every failure is logged and swallowed, never raised. The
    callers in this module invoke it after committing, so the masterdata
    update itself is not reported as failed.

    Args:
        db_session: Unused; kept for interface compatibility.
        run_plant_calculation_change: When False this is a no-op.
    """
    if not run_plant_calculation_change:
        return
    logger = logging.getLogger(__name__)
    directory_path = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "../modules/plant")
    )
    script_path = os.path.join(directory_path, "run_plant_simulation.py")
    try:
        # Use the interpreter (and venv) running this service rather than
        # whatever "python" happens to resolve to on PATH.
        process = await asyncio.create_subprocess_exec(
            sys.executable or "python",
            script_path,
            stdout=PIPE,
            stderr=PIPE,
            cwd=directory_path,
        )
        stdout, stderr = await process.communicate()
    except Exception:
        logger.exception("Error during simulation masterdata recalculation")
        return
    # errors="replace": child output is not guaranteed to be valid UTF-8.
    if process.returncode != 0:
        logger.error("Plant recalc error: %s", stderr.decode(errors="replace"))
    else:
        logger.info("Plant recalc output: %s", stdout.decode(errors="replace"))
def calculate_pmt(rate, nper, pv, fv=0):
    """Return the constant end-of-period payment of an annuity (spreadsheet ``PMT``).

    Follows the spreadsheet sign convention: for a positive ``pv`` and
    ``fv == 0`` the payment is negative (cash paid out).

    Args:
        rate: Interest rate per period. Values greater than 1 are treated as
            percentages (``5`` -> 5 %); values <= 1 as fractions (``0.05``).
            NOTE(review): with this heuristic ``1`` means 100 % and ``0.5``
            means 50 % — confirm the unit stored for ``interest_rate``.
        nper: Number of periods; must be non-zero.
        pv: Present value (e.g. loan principal).
        fv: Future value remaining after the last payment (default 0).

    Returns:
        The payment per period as a float.

    Raises:
        ValueError: If ``nper`` is zero (no payment schedule exists).
    """
    rate = float(rate) / 100 if rate > 1 else float(rate)
    # Coerce to float so non-float numerics (e.g. Decimal) don't raise
    # TypeError when combined with the float rate.
    nper = float(nper)
    pv = float(pv)
    fv = float(fv)
    if nper == 0:
        raise ValueError("nper must be non-zero to compute a payment")
    if rate == 0:
        return -(pv + fv) / nper
    growth = (1 + rate) ** nper
    return -(pv * growth + fv) * rate / (growth - 1)
async def get(
    *, db_session: DbSession, masterdata_id: str
) -> Optional[MasterDataSimulation]:
    """Fetch one simulation masterdata row by primary key.

    Returns None when no row has ``masterdata_id``.
    """
    by_id = MasterDataSimulation.id == masterdata_id
    outcome = await db_session.execute(Select(MasterDataSimulation).where(by_id))
    scalars = outcome.scalars()
    return scalars.one_or_none()
async def get_all(
    *,
    db_session: DbSession,
    items_per_page: int,
    simulation_id: UUID,
    search: Optional[str],
    common,
):
    """Return one page of a simulation's masterdata rows, ordered by ``seq``.

    Args:
        db_session: Not referenced here. Presumably ``common`` carries whatever
            session the paginator needs — TODO confirm.
        items_per_page: Page size; overrides any ``items_per_page`` in
            ``common``.
        simulation_id: Only rows of this simulation are returned.
        search: Optional case-insensitive substring matched against ``name``.
            ``%`` and ``_`` in it are matched literally.
        common: Mapping of extra keyword arguments forwarded to
            ``search_filter_sort_paginate``. It is not modified.

    Returns:
        Whatever ``search_filter_sort_paginate`` returns for the query.
    """
    query = (
        Select(MasterDataSimulation)
        .where(MasterDataSimulation.simulation_id == simulation_id)
        .order_by(MasterDataSimulation.seq.asc())
    )
    if search:
        # Escape LIKE metacharacters so user input is a literal substring.
        # "/" avoids backend-specific backslash-literal quirks.
        escaped = search.replace("/", "//").replace("%", "/%").replace("_", "/_")
        query = query.where(
            MasterDataSimulation.name.ilike(f"%{escaped}%", escape="/")
        )
    # Build a new mapping instead of mutating the caller's dict.
    paginate_kwargs = {**common, "items_per_page": items_per_page}
    return await search_filter_sort_paginate(model=query, **paginate_kwargs)
async def create(*, db_session: DbSession, masterdata_in: MasterDataSimulationCreate):
    """Insert a new simulation masterdata row built from the payload and commit.

    Returns the newly added ORM instance.
    """
    fields = masterdata_in.model_dump()
    new_row = MasterDataSimulation(**fields)
    db_session.add(new_row)
    await db_session.commit()
    return new_row
async def update(
    *,
    db_session: DbSession,
    masterdata: MasterDataSimulation,
    masterdata_in: MasterDataSimulationUpdate,
):
    """Apply ``masterdata_in`` to one row, commit, then recalculate if needed.

    The edited row seeds the name -> row cache, so the update logic can read
    its own value without another query.

    Returns:
        The same ``masterdata`` instance.
    """
    name_cache: Dict[str, MasterDataSimulation] = (
        {masterdata.name: masterdata} if masterdata.name else {}
    )
    _, needs_recalc = await _apply_masterdata_simulation_update_logic(
        db_session=db_session,
        masterdata=masterdata,
        masterdata_in=masterdata_in,
        records_by_name=name_cache,
        simulation_id=masterdata.simulation_id,
    )
    await db_session.commit()
    await _trigger_masterdata_simulation_recalculation(
        db_session=db_session,
        run_plant_calculation_change=needs_recalc,
    )
    return masterdata
async def bulk_update(
    *,
    db_session: DbSession,
    updates: List[MasterDataSimulationUpdate],
    ids: List[str],
    simulation_id: UUID,
) -> List[MasterDataSimulation]:
    """Apply several updates in one commit, then recalculate at most once.

    ``updates[i]`` is applied to the row whose id is ``ids[i]``. Ids that are
    missing or do not belong to ``simulation_id`` are skipped. All updates
    share one name -> row cache, seeded with the fetched rows.

    Args:
        db_session: Async session; committed once after all updates.
        updates: Update payloads, paired positionally with ``ids``.
        ids: Row ids (str or UUID) to update.
        simulation_id: Simulation the rows must belong to.

    Returns:
        The rows that were updated, in ``ids`` order.

    Raises:
        ValueError: If ``ids`` and ``updates`` differ in length, because
            pairing them by position would be ambiguous.
    """
    if len(ids) != len(updates):
        raise ValueError(
            f"ids and updates must have the same length "
            f"(got {len(ids)} ids and {len(updates)} updates)"
        )
    query = (
        Select(MasterDataSimulation)
        .where(MasterDataSimulation.id.in_(ids))
        .where(MasterDataSimulation.simulation_id == simulation_id)
    )
    result = await db_session.execute(query)
    records = result.scalars().all()
    records_map = {str(record.id): record for record in records}
    records_by_name: Dict[str, MasterDataSimulation] = {
        record.name: record for record in records if record.name
    }
    run_plant_calculation_change = False
    updated_records: List[MasterDataSimulation] = []
    for masterdata_id, masterdata_in in zip(ids, updates):
        # str() so UUID ids match the str-keyed map just like str ids do.
        masterdata = records_map.get(str(masterdata_id))
        if masterdata is None:
            continue
        _, run_plant_calculation = await _apply_masterdata_simulation_update_logic(
            db_session=db_session,
            masterdata=masterdata,
            masterdata_in=masterdata_in,
            records_by_name=records_by_name,
            simulation_id=simulation_id,
        )
        run_plant_calculation_change = (
            run_plant_calculation_change or run_plant_calculation
        )
        updated_records.append(masterdata)
    await db_session.commit()
    await _trigger_masterdata_simulation_recalculation(
        db_session=db_session,
        run_plant_calculation_change=run_plant_calculation_change,
    )
    return updated_records
async def delete(*, db_session: DbSession, masterdata_id: str):
    """Delete the simulation masterdata row with primary key ``masterdata_id`` and commit.

    No error is raised when no row matches.
    """
    matches_id = MasterDataSimulation.id == masterdata_id
    await db_session.execute(Delete(MasterDataSimulation).where(matches_id))
    await db_session.commit()