from typing import Optional

from sqlalchemy import Select, Delete, cast, String

from src.acquisition_cost.model import AcquisitionData
from src.acquisition_cost.schema import AcquisitionCostDataCreate, AcquisitionCostDataUpdate
from src.database.service import search_filter_sort_paginate
from src.database.core import DbSession
from src.auth.service import CurrentUser
from src.equipment.model import Equipment


def _calculate_cost_unit_3(cost_unit_3_n_4: Optional[float]) -> Optional[float]:
    """Derive ``cost_unit_3`` by splitting the combined unit 3 & 4 cost evenly.

    Args:
        cost_unit_3_n_4: Combined cost for units 3 and 4, or ``None``.

    Returns:
        Half of ``cost_unit_3_n_4``, or ``None`` when no combined cost is given.
    """
    if cost_unit_3_n_4 is None:
        return None
    return cost_unit_3_n_4 / 2


async def _sync_equipment_acquisition_costs(
    *, db_session: DbSession, category_no: Optional[str], cost_unit_3: Optional[float]
) -> None:
    """Recompute ``acquisition_cost`` for every equipment row in ``category_no``.

    Each matching equipment row gets ``proportion * 0.01 * cost_unit_3``. Rows whose
    ``proportion`` is ``None`` are skipped. The changes are only staged on the
    session; committing is left to the caller.

    Args:
        db_session: Session used to load and modify the equipment rows.
        category_no: Category whose equipment should be updated. A falsy value
            makes this a no-op.
        cost_unit_3: Unit cost to distribute. ``None`` makes this a no-op.
    """
    if not category_no or cost_unit_3 is None:
        return

    equipment_query = Select(Equipment).filter(Equipment.category_no == category_no)
    equipment_result = await db_session.execute(equipment_query)
    equipments = equipment_result.scalars().all()

    for equipment in equipments:
        # Without a proportion there is nothing to scale the unit cost by.
        if equipment.proportion is None:
            continue
        # NOTE(review): the `* 0.01` suggests `proportion` is stored as a
        # percentage (0-100) -- confirm against the Equipment model.
        equipment.acquisition_cost = (equipment.proportion * 0.01) * cost_unit_3


async def get(*, db_session: DbSession, acquisition_cost_data_id: str) -> Optional[AcquisitionData]:
    """Return the acquisition cost record with the given id, or ``None`` if absent.

    Args:
        db_session: Session used for the lookup.
        acquisition_cost_data_id: Primary-key value of the record.

    Returns:
        The matching ``AcquisitionData`` or ``None``. ``one_or_none`` raises if
        more than one row matches.
    """
    query = Select(AcquisitionData).filter(AcquisitionData.id == acquisition_cost_data_id)
    result = await db_session.execute(query)
    return result.scalars().one_or_none()


async def get_all(
    *,
    db_session: DbSession,
    items_per_page: Optional[int],
    search: Optional[str] = None,
    common,
):
    """Return a paginated listing of acquisition cost records ordered by name.

    Args:
        db_session: Not used directly here. NOTE(review): presumably ``common``
            carries the session used by the paginator -- confirm.
        items_per_page: Page size forwarded to the paginator. It overrides any
            ``items_per_page`` already present in ``common``.
        search: Optional case-insensitive substring (SQL ``ILIKE``) matched
            against ``name``.
        common: Mapping of extra keyword arguments forwarded to
            ``search_filter_sort_paginate``. It is not modified.

    Returns:
        Whatever ``search_filter_sort_paginate`` returns for the built query.
    """
    query = Select(AcquisitionData).order_by(AcquisitionData.name.asc())
    if search:
        query = query.filter(cast(AcquisitionData.name, String).ilike(f"%{search}%"))

    # Build a fresh mapping rather than writing into ``common`` so the caller's
    # dict is not mutated as a side effect of listing.
    paginate_kwargs = {**common, "items_per_page": items_per_page}
    return await search_filter_sort_paginate(model=query, **paginate_kwargs)


async def create(
    *, db_session: DbSession, acquisition_data_in: AcquisitionCostDataCreate
) -> AcquisitionData:
    """Create and commit a new acquisition cost record.

    When ``cost_unit_3_n_4`` is provided, ``cost_unit_3`` is derived as half of
    it, overriding any ``cost_unit_3`` in the input. The acquisition cost of the
    equipment in the record's category is then recomputed in the same commit.

    Args:
        db_session: Session the new record is added to and committed on.
        acquisition_data_in: Validated creation payload.

    Returns:
        The newly created ``AcquisitionData`` instance.
    """
    data = acquisition_data_in.model_dump()

    # ``None`` when the combined unit 3 & 4 cost was not supplied.
    derived_cost_unit = _calculate_cost_unit_3(data.get("cost_unit_3_n_4"))
    if derived_cost_unit is not None:
        data["cost_unit_3"] = derived_cost_unit

    acquisition_data = AcquisitionData(**data)
    db_session.add(acquisition_data)

    if derived_cost_unit is not None:
        await _sync_equipment_acquisition_costs(
            db_session=db_session,
            category_no=acquisition_data.category_no,
            cost_unit_3=acquisition_data.cost_unit_3,
        )

    await db_session.commit()
    return acquisition_data


async def update(
    *,
    db_session: DbSession,
    acquisition_data: AcquisitionData,
    acquisition_data_in: AcquisitionCostDataUpdate,
) -> AcquisitionData:
    """Apply a partial update to an acquisition cost record and commit it.

    Only fields whose value differs from the schema default are applied. If
    ``cost_unit_3_n_4`` is among them, ``cost_unit_3`` is re-derived as half of
    it. When the result is not ``None``, the acquisition cost of the equipment in
    the record's category is recomputed in the same commit.

    Args:
        db_session: Session used to sync equipment and commit.
        acquisition_data: Persistent record to modify in place.
        acquisition_data_in: Validated update payload.

    Returns:
        The updated ``acquisition_data`` instance.
    """
    # NOTE(review): ``exclude_defaults`` drops any field equal to its schema
    # default, so a client cannot explicitly reset a field to that default here.
    # ``exclude_unset=True`` may be the intended PATCH semantics -- confirm.
    update_data = acquisition_data_in.model_dump(exclude_defaults=True)

    cost_unit_changed = False
    if "cost_unit_3_n_4" in update_data:
        derived_cost_unit = _calculate_cost_unit_3(update_data["cost_unit_3_n_4"])
        update_data["cost_unit_3"] = derived_cost_unit
        cost_unit_changed = derived_cost_unit is not None

    # Iterate the update payload itself (not the schema's field list) so the
    # derived ``cost_unit_3`` is applied even if the update schema does not
    # declare that field. Otherwise the sync below would use a stale value.
    for field, value in update_data.items():
        setattr(acquisition_data, field, value)

    if cost_unit_changed:
        await _sync_equipment_acquisition_costs(
            db_session=db_session,
            category_no=acquisition_data.category_no,
            cost_unit_3=acquisition_data.cost_unit_3,
        )

    await db_session.commit()
    return acquisition_data


async def delete(*, db_session: DbSession, acquisition_cost_data_id: str) -> None:
    """Delete the acquisition cost record with the given id and commit.

    Issues a bulk ``DELETE`` statement. It is a silent no-op when no row
    matches the id.

    Args:
        db_session: Session used to execute the delete and commit.
        acquisition_cost_data_id: Primary-key value of the record to delete.
    """
    query = Delete(AcquisitionData).where(AcquisitionData.id == acquisition_cost_data_id)
    await db_session.execute(query)
    await db_session.commit()