You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

120 lines
4.1 KiB
Python

from sqlalchemy import Select, Delete, cast, String
from src.acquisition_cost.model import AcquisitionData
from src.acquisition_cost.schema import AcquisitionCostDataCreate, AcquisitionCostDataUpdate
from src.database.service import search_filter_sort_paginate
from typing import Optional
from src.database.core import DbSession
from src.auth.service import CurrentUser
from src.equipment.model import Equipment
def _calculate_cost_unit_3(cost_unit_3_n_4: Optional[float]) -> Optional[float]:
"""Derive cost_unit_3 by splitting the combined unit 3&4 cost evenly."""
if cost_unit_3_n_4 is None:
return None
return cost_unit_3_n_4 / 2
async def _sync_equipment_acquisition_costs(
    *, db_session: DbSession, category_no: Optional[str], cost_unit_3: Optional[float]
):
    """Recompute ``acquisition_cost`` for the Equipment rows of one category.

    Each matching equipment gets ``(proportion * 0.01) * cost_unit_3``, i.e. its
    ``proportion`` is treated as a percentage of ``cost_unit_3``. Equipment whose
    ``proportion`` is None is left untouched. Nothing happens when ``category_no``
    is falsy or ``cost_unit_3`` is None.

    Only the loaded ORM objects are modified; committing is left to the caller.
    """
    if category_no and cost_unit_3 is not None:
        stmt = Select(Equipment).filter(Equipment.category_no == category_no)
        members = (await db_session.execute(stmt)).scalars().all()
        for member in members:
            share = member.proportion
            # Rows without a proportion cannot be priced; keep their current cost.
            if share is not None:
                member.acquisition_cost = (share * 0.01) * cost_unit_3
async def get(*, db_session: DbSession, acquisition_cost_data_id: str) -> Optional[AcquisitionData]:
    """Fetch the AcquisitionData row whose ``id`` equals ``acquisition_cost_data_id``.

    Returns None when no row matches (via ``one_or_none``, which raises if more
    than one row matches).
    """
    statement = Select(AcquisitionData).filter(AcquisitionData.id == acquisition_cost_data_id)
    rows = (await db_session.execute(statement)).scalars()
    return rows.one_or_none()
async def get_all(
    *,
    db_session: DbSession,
    items_per_page: Optional[int],
    search: Optional[str] = None,
    common,
):
    """Return a paginated listing of AcquisitionData rows ordered by name.

    Args:
        db_session: Kept for interface compatibility; not used directly here.
            NOTE(review): the session presumably travels inside ``common`` —
            confirm against the caller.
        items_per_page: Page size forwarded to ``search_filter_sort_paginate``.
        search: Optional case-insensitive substring matched against ``name``.
        common: Keyword arguments forwarded to ``search_filter_sort_paginate``.
            This mapping is not modified; a copy with ``items_per_page``
            overridden is passed on instead.

    Returns:
        Whatever ``search_filter_sort_paginate`` returns for the built query.
    """
    query = Select(AcquisitionData).order_by(AcquisitionData.name.asc())
    if search:
        query = query.filter(cast(AcquisitionData.name, String).ilike(f"%{search}%"))
    # Build a new mapping rather than writing into the caller's ``common``,
    # so a shared dependency dict is not altered as a side effect of listing.
    paginate_kwargs = {**common, "items_per_page": items_per_page}
    return await search_filter_sort_paginate(model=query, **paginate_kwargs)
async def create(*, db_session: DbSession, acquisition_data_in: AcquisitionCostDataCreate):
    """Persist a new AcquisitionData row built from ``acquisition_data_in``.

    When the payload carries a non-None ``cost_unit_3_n_4``, ``cost_unit_3`` is
    overwritten with half of it, and equipment acquisition costs for the new
    row's ``category_no`` are recomputed before the commit.

    Returns:
        The newly added AcquisitionData instance.
    """
    payload = acquisition_data_in.model_dump()
    should_sync = False

    combined_cost = payload.get("cost_unit_3_n_4")
    if combined_cost is not None:
        # Derived value always wins over any cost_unit_3 supplied directly.
        payload["cost_unit_3"] = _calculate_cost_unit_3(combined_cost)
        should_sync = payload["cost_unit_3"] is not None

    record = AcquisitionData(**payload)
    db_session.add(record)

    if should_sync:
        await _sync_equipment_acquisition_costs(
            db_session=db_session,
            category_no=record.category_no,
            cost_unit_3=record.cost_unit_3,
        )

    await db_session.commit()
    return record
async def update(
    *, db_session: DbSession, acquisition_data: AcquisitionData, acquisition_data_in: AcquisitionCostDataUpdate
):
    """Apply a partial update to ``acquisition_data`` and commit it.

    Only fields whose values differ from the schema defaults are written
    (``model_dump(exclude_defaults=True)``). When ``cost_unit_3_n_4`` is part of
    the update, ``cost_unit_3`` is re-derived as half of it and the equipment
    acquisition costs for the record's ``category_no`` are re-synced.

    Returns:
        The same ``acquisition_data`` instance, after commit.
    """
    update_data = acquisition_data_in.model_dump(exclude_defaults=True)

    cost_unit_changed = False
    if "cost_unit_3_n_4" in update_data:
        derived_cost_unit = _calculate_cost_unit_3(update_data["cost_unit_3_n_4"])
        update_data["cost_unit_3"] = derived_cost_unit
        cost_unit_changed = derived_cost_unit is not None

    # Write every key of update_data, including the derived ``cost_unit_3``.
    # Filtering by the update schema's own field names would silently drop
    # ``cost_unit_3`` when the schema does not declare it, and the sync below
    # would then use the stale stored value.
    for field, value in update_data.items():
        setattr(acquisition_data, field, value)

    if cost_unit_changed:
        await _sync_equipment_acquisition_costs(
            db_session=db_session,
            category_no=acquisition_data.category_no,
            cost_unit_3=acquisition_data.cost_unit_3,
        )
    await db_session.commit()
    return acquisition_data
async def delete(*, db_session: DbSession, acquisition_cost_data_id: str):
    """Delete the AcquisitionData row(s) whose ``id`` matches and commit.

    A DELETE statement is issued directly without an existence check, so a
    non-matching id completes without error.
    """
    await db_session.execute(
        Delete(AcquisitionData).where(AcquisitionData.id == acquisition_cost_data_id)
    )
    await db_session.commit()