refactor: centralize model update logic into a new utility function and apply it across services, while also refining SQLAlchemy error handling and model attribute access.

main
Cizz22 1 month ago
parent 5ba2a5c607
commit 475e1f9c32

@ -1,5 +1,6 @@
# src/database.py # src/database.py
import re import re
import operator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Annotated, Any from typing import Annotated, Any
@ -50,7 +51,7 @@ class Base(DeclarativeBase):
def dict(self): def dict(self):
"""Returns a dict representation of a model.""" """Returns a dict representation of a model."""
if hasattr(self, '__table__'): if hasattr(self, '__table__'):
return {c.name: getattr(self, c.name) for c in self.__table__.columns} return {c.name: operator.attrgetter(c.name)(self) for c in self.__table__.columns}
return {} return {}
class CollectorBase(DeclarativeBase): class CollectorBase(DeclarativeBase):
@ -61,7 +62,7 @@ class CollectorBase(DeclarativeBase):
def dict(self): def dict(self):
"""Returns a dict representation of a model.""" """Returns a dict representation of a model."""
if hasattr(self, '__table__'): if hasattr(self, '__table__'):
return {c.name: getattr(self, c.name) for c in self.__table__.columns} return {c.name: operator.attrgetter(c.name)(self) for c in self.__table__.columns}
return {} return {}
@asynccontextmanager @asynccontextmanager

@ -67,7 +67,10 @@ def handle_sqlalchemy_error(error: SQLAlchemyError):
""" """
Handle SQLAlchemy errors and return user-friendly error messages. Handle SQLAlchemy errors and return user-friendly error messages.
""" """
original_error = getattr(error, "orig", None) try:
original_error = error.orig
except AttributeError:
original_error = None
print(original_error) print(original_error)
if isinstance(error, IntegrityError): if isinstance(error, IntegrityError):

@ -12,6 +12,7 @@ from src.auth.service import CurrentUser
from src.database.core import DbSession from src.database.core import DbSession
from src.database.service import CommonParameters, search_filter_sort_paginate from src.database.service import CommonParameters, search_filter_sort_paginate
from src.overhaul_activity.utils import get_material_cost, get_service_cost from src.overhaul_activity.utils import get_material_cost, get_service_cost
from src.utils import update_model
from src.overhaul_scope.model import OverhaulScope from src.overhaul_scope.model import OverhaulScope
from src.overhaul_scope.service import get as get_session, get_prev_oh from src.overhaul_scope.service import get as get_session, get_prev_oh
from src.standard_scope.model import MasterEquipment, StandardScope from src.standard_scope.model import MasterEquipment, StandardScope
@ -439,9 +440,7 @@ async def update(
update_data = overhaul_activity_in.model_dump(exclude_defaults=True) update_data = overhaul_activity_in.model_dump(exclude_defaults=True)
for field in data: update_model(activity, update_data)
if field in update_data:
setattr(activity, field, update_data[field])
await db_session.commit() await db_session.commit()

@ -7,6 +7,7 @@ from sqlalchemy.orm import selectinload
from src.auth.service import CurrentUser from src.auth.service import CurrentUser
from src.database.core import DbSession from src.database.core import DbSession
from src.database.service import search_filter_sort_paginate from src.database.service import search_filter_sort_paginate
from src.utils import update_model
from src.scope_equipment_job.model import ScopeEquipmentJob from src.scope_equipment_job.model import ScopeEquipmentJob
from src.overhaul_activity.model import OverhaulActivity from src.overhaul_activity.model import OverhaulActivity
@ -41,9 +42,7 @@ async def update(*, db_session: DbSession, overhaul_schedule_id: str, overhaul_j
update_data = overhaul_job_in.model_dump(exclude_defaults=True) update_data = overhaul_job_in.model_dump(exclude_defaults=True)
for field in data: update_model(overhaul_schedule, update_data)
if field in update_data:
setattr(overhaul_schedule, field, update_data[field])
await db_session.commit() await db_session.commit()

@ -7,7 +7,7 @@ from src.auth.service import CurrentUser
from src.database.core import DbSession from src.database.core import DbSession
from src.database.service import search_filter_sort_paginate from src.database.service import search_filter_sort_paginate
from src.overhaul_activity.model import OverhaulActivity from src.overhaul_activity.model import OverhaulActivity
from src.utils import time_now from src.utils import time_now, update_model
from src.standard_scope.model import MasterEquipment, StandardScope, EquipmentOHHistory from src.standard_scope.model import MasterEquipment, StandardScope, EquipmentOHHistory
from src.workscope_group.model import MasterActivity from src.workscope_group.model import MasterActivity
from src.workscope_group_maintenance_type.model import WorkscopeOHType from src.workscope_group_maintenance_type.model import WorkscopeOHType
@ -132,9 +132,7 @@ async def update(*, db_session: DbSession, scope: OverhaulScope, scope_in: Scope
update_data = scope_in.model_dump(exclude_defaults=True) update_data = scope_in.model_dump(exclude_defaults=True)
for field in data: update_model(scope, update_data)
if field in update_data:
setattr(scope, field, update_data[field])
await db_session.commit() await db_session.commit()

@ -8,6 +8,7 @@ from sqlalchemy.orm import selectinload
from src.auth.service import CurrentUser from src.auth.service import CurrentUser
from src.database.core import DbSession, CollectorDbSession from src.database.core import DbSession, CollectorDbSession
from src.utils import update_model
from src.database.service import CommonParameters, search_filter_sort_paginate from src.database.service import CommonParameters, search_filter_sort_paginate
from src.overhaul_scope.model import OverhaulScope from src.overhaul_scope.model import OverhaulScope
from src.standard_scope.enum import ScopeEquipmentType from src.standard_scope.enum import ScopeEquipmentType
@ -154,9 +155,7 @@ async def update(
update_data = scope_equipment_in.model_dump(exclude_defaults=True) update_data = scope_equipment_in.model_dump(exclude_defaults=True)
for field in data: update_model(scope_equipment, update_data)
if field in update_data:
setattr(scope_equipment, field, update_data[field])
await db_session.commit() await db_session.commit()

@ -139,3 +139,13 @@ def save_to_pastebin(data, title="Result Log", expire_date="1H"):
return response.text # This will be the paste URL return response.text # This will be the paste URL
else: else:
return f"Error: {response.status_code} - {response.text}" return f"Error: {response.status_code} - {response.text}"
def update_model(model, update_data: dict):
"""
Update a SQLAlchemy model with data from a dictionary.
"""
for key, value in update_data.items():
if hasattr(model, key):
setattr(model, key, value)

@ -6,6 +6,7 @@ from sqlalchemy.orm import joinedload, selectinload
from src.auth.service import CurrentUser from src.auth.service import CurrentUser
from src.database.core import DbSession from src.database.core import DbSession
from src.database.service import CommonParameters, search_filter_sort_paginate from src.database.service import CommonParameters, search_filter_sort_paginate
from src.utils import update_model
from .model import MasterActivity from .model import MasterActivity
from .schema import ActivityMaster, ActivityMasterCreate from .schema import ActivityMaster, ActivityMasterCreate
@ -43,9 +44,7 @@ async def update(
update_data = activity_in.model_dump(exclude_defaults=True) update_data = activity_in.model_dump(exclude_defaults=True)
for field in data: update_model(activity, update_data)
if field in update_data:
setattr(activity, field, update_data[field])
await db_session.commit() await db_session.commit()

@ -6,6 +6,7 @@ from sqlalchemy.orm import joinedload, selectinload
from src.auth.service import CurrentUser from src.auth.service import CurrentUser
from src.database.core import DbSession from src.database.core import DbSession
from src.database.service import CommonParameters, search_filter_sort_paginate from src.database.service import CommonParameters, search_filter_sort_paginate
from src.utils import update_model
from .model import MasterActivity from .model import MasterActivity
from .schema import ActivityMaster, ActivityMasterCreate from .schema import ActivityMaster, ActivityMasterCreate
@ -43,9 +44,7 @@ async def update(
update_data = activity_in.model_dump(exclude_defaults=True) update_data = activity_in.model_dump(exclude_defaults=True)
for field in data: update_model(activity, update_data)
if field in update_data:
setattr(activity, field, update_data[field])
await db_session.commit() await db_session.commit()

Loading…
Cancel
Save