diff --git a/src/database/core.py b/src/database/core.py index ef637dd..74845c0 100644 --- a/src/database/core.py +++ b/src/database/core.py @@ -1,5 +1,6 @@ # src/database.py import re +import operator from contextlib import asynccontextmanager from typing import Annotated, Any @@ -50,7 +51,7 @@ class Base(DeclarativeBase): def dict(self): """Returns a dict representation of a model.""" if hasattr(self, '__table__'): - return {c.name: getattr(self, c.name) for c in self.__table__.columns} + return {c.name: operator.attrgetter(c.name)(self) for c in self.__table__.columns} return {} class CollectorBase(DeclarativeBase): @@ -61,7 +62,7 @@ class CollectorBase(DeclarativeBase): def dict(self): """Returns a dict representation of a model.""" if hasattr(self, '__table__'): - return {c.name: getattr(self, c.name) for c in self.__table__.columns} + return {c.name: operator.attrgetter(c.name)(self) for c in self.__table__.columns} return {} @asynccontextmanager diff --git a/src/exceptions.py b/src/exceptions.py index 0565447..6db15c9 100644 --- a/src/exceptions.py +++ b/src/exceptions.py @@ -67,7 +67,10 @@ def handle_sqlalchemy_error(error: SQLAlchemyError): """ Handle SQLAlchemy errors and return user-friendly error messages. """ - original_error = getattr(error, "orig", None) + try: + original_error = error.orig + except AttributeError: + original_error = None print(original_error) if isinstance(error, IntegrityError): diff --git a/src/overhaul_activity/service.py b/src/overhaul_activity/service.py index b07cc7b..e77ef40 100644 --- a/src/overhaul_activity/service.py +++ b/src/overhaul_activity/service.py @@ -12,6 +12,7 @@ from src.auth.service import CurrentUser from src.database.core import DbSession from src.database.service import CommonParameters, search_filter_sort_paginate from src.overhaul_activity.utils import get_material_cost, get_service_cost +from src.utils import update_model from src.overhaul_scope.model import OverhaulScope from src.overhaul_scope.service import get as get_session, get_prev_oh from src.standard_scope.model import MasterEquipment, StandardScope @@ -439,9 +440,7 @@ async def update( update_data = overhaul_activity_in.model_dump(exclude_defaults=True) - for field in data: - if field in update_data: - setattr(activity, field, update_data[field]) + update_model(activity, update_data) await db_session.commit() diff --git a/src/overhaul_schedule/service.py b/src/overhaul_schedule/service.py index 4ea9d37..1c69675 100644 --- a/src/overhaul_schedule/service.py +++ b/src/overhaul_schedule/service.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import selectinload from src.auth.service import CurrentUser from src.database.core import DbSession from src.database.service import search_filter_sort_paginate +from src.utils import update_model from src.scope_equipment_job.model import ScopeEquipmentJob from src.overhaul_activity.model import OverhaulActivity @@ -41,9 +42,7 @@ async def update(*, db_session: DbSession, overhaul_schedule_id: str, overhaul_j update_data = overhaul_job_in.model_dump(exclude_defaults=True) - for field in data: - if field in update_data: - setattr(overhaul_schedule, field, update_data[field]) + update_model(overhaul_schedule, update_data) await db_session.commit() diff --git a/src/overhaul_scope/service.py b/src/overhaul_scope/service.py index 6b5075e..47df892 100644 --- a/src/overhaul_scope/service.py +++ b/src/overhaul_scope/service.py @@ -7,7 +7,7 @@ from src.auth.service import CurrentUser from src.database.core import DbSession from src.database.service import search_filter_sort_paginate from src.overhaul_activity.model import OverhaulActivity -from src.utils import time_now +from src.utils import time_now, update_model from src.standard_scope.model import MasterEquipment, StandardScope, EquipmentOHHistory from src.workscope_group.model import MasterActivity from src.workscope_group_maintenance_type.model import WorkscopeOHType @@ -132,9 +132,7 @@ async def update(*, db_session: DbSession, scope: OverhaulScope, scope_in: Scope update_data = scope_in.model_dump(exclude_defaults=True) - for field in data: - if field in update_data: - setattr(scope, field, update_data[field]) + update_model(scope, update_data) await db_session.commit() diff --git a/src/standard_scope/service.py b/src/standard_scope/service.py index 2f0f2e3..f0779f6 100644 --- a/src/standard_scope/service.py +++ b/src/standard_scope/service.py @@ -8,6 +8,7 @@ from sqlalchemy.orm import selectinload from src.auth.service import CurrentUser from src.database.core import DbSession, CollectorDbSession +from src.utils import update_model from src.database.service import CommonParameters, search_filter_sort_paginate from src.overhaul_scope.model import OverhaulScope from src.standard_scope.enum import ScopeEquipmentType @@ -154,9 +155,7 @@ async def update( update_data = scope_equipment_in.model_dump(exclude_defaults=True) - for field in data: - if field in update_data: - setattr(scope_equipment, field, update_data[field]) + update_model(scope_equipment, update_data) await db_session.commit() diff --git a/src/utils.py b/src/utils.py index 3dc0219..ed4a615 100644 --- a/src/utils.py +++ b/src/utils.py @@ -139,3 +139,13 @@ def save_to_pastebin(data, title="Result Log", expire_date="1H"): return response.text # This will be the paste URL else: return f"Error: {response.status_code} - {response.text}" + + +def update_model(model, update_data: dict): + """ + Update a SQLAlchemy model with data from a dictionary. + """ + for key, value in update_data.items(): + if hasattr(model, key): + setattr(model, key, value) + diff --git a/src/workscope_group/service.py b/src/workscope_group/service.py index 9199703..35085d8 100644 --- a/src/workscope_group/service.py +++ b/src/workscope_group/service.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import joinedload, selectinload from src.auth.service import CurrentUser from src.database.core import DbSession from src.database.service import CommonParameters, search_filter_sort_paginate +from src.utils import update_model from .model import MasterActivity from .schema import ActivityMaster, ActivityMasterCreate @@ -43,9 +44,7 @@ async def update( update_data = activity_in.model_dump(exclude_defaults=True) - for field in data: - if field in update_data: - setattr(activity, field, update_data[field]) + update_model(activity, update_data) await db_session.commit() diff --git a/src/workscope_group_maintenance_type/service.py b/src/workscope_group_maintenance_type/service.py index 9199703..35085d8 100644 --- a/src/workscope_group_maintenance_type/service.py +++ b/src/workscope_group_maintenance_type/service.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import joinedload, selectinload from src.auth.service import CurrentUser from src.database.core import DbSession from src.database.service import CommonParameters, search_filter_sort_paginate +from src.utils import update_model from .model import MasterActivity from .schema import ActivityMaster, ActivityMasterCreate @@ -43,9 +44,7 @@ async def update( update_data = activity_in.model_dump(exclude_defaults=True) - for field in data: - if field in update_data: - setattr(activity, field, update_data[field]) + update_model(activity, update_data) await db_session.commit()