refactor: centralize model update logic into a new utility function and apply it across services, while also refining SQLAlchemy error handling and model attribute access.

main
Cizz22 1 month ago
parent 5ba2a5c607
commit 475e1f9c32

@ -1,5 +1,6 @@
# src/database.py
import re
import operator
from contextlib import asynccontextmanager
from typing import Annotated, Any
@ -50,7 +51,7 @@ class Base(DeclarativeBase):
def dict(self):
"""Returns a dict representation of a model."""
if hasattr(self, '__table__'):
return {c.name: getattr(self, c.name) for c in self.__table__.columns}
return {c.name: operator.attrgetter(c.name)(self) for c in self.__table__.columns}
return {}
class CollectorBase(DeclarativeBase):
@ -61,7 +62,7 @@ class CollectorBase(DeclarativeBase):
def dict(self):
"""Returns a dict representation of a model."""
if hasattr(self, '__table__'):
return {c.name: getattr(self, c.name) for c in self.__table__.columns}
return {c.name: operator.attrgetter(c.name)(self) for c in self.__table__.columns}
return {}
@asynccontextmanager

@ -67,7 +67,10 @@ def handle_sqlalchemy_error(error: SQLAlchemyError):
"""
Handle SQLAlchemy errors and return user-friendly error messages.
"""
original_error = getattr(error, "orig", None)
try:
original_error = error.orig
except AttributeError:
original_error = None
print(original_error)
if isinstance(error, IntegrityError):

@ -12,6 +12,7 @@ from src.auth.service import CurrentUser
from src.database.core import DbSession
from src.database.service import CommonParameters, search_filter_sort_paginate
from src.overhaul_activity.utils import get_material_cost, get_service_cost
from src.utils import update_model
from src.overhaul_scope.model import OverhaulScope
from src.overhaul_scope.service import get as get_session, get_prev_oh
from src.standard_scope.model import MasterEquipment, StandardScope
@ -439,9 +440,7 @@ async def update(
update_data = overhaul_activity_in.model_dump(exclude_defaults=True)
for field in data:
if field in update_data:
setattr(activity, field, update_data[field])
update_model(activity, update_data)
await db_session.commit()

@ -7,6 +7,7 @@ from sqlalchemy.orm import selectinload
from src.auth.service import CurrentUser
from src.database.core import DbSession
from src.database.service import search_filter_sort_paginate
from src.utils import update_model
from src.scope_equipment_job.model import ScopeEquipmentJob
from src.overhaul_activity.model import OverhaulActivity
@ -41,9 +42,7 @@ async def update(*, db_session: DbSession, overhaul_schedule_id: str, overhaul_j
update_data = overhaul_job_in.model_dump(exclude_defaults=True)
for field in data:
if field in update_data:
setattr(overhaul_schedule, field, update_data[field])
update_model(overhaul_schedule, update_data)
await db_session.commit()

@ -7,7 +7,7 @@ from src.auth.service import CurrentUser
from src.database.core import DbSession
from src.database.service import search_filter_sort_paginate
from src.overhaul_activity.model import OverhaulActivity
from src.utils import time_now
from src.utils import time_now, update_model
from src.standard_scope.model import MasterEquipment, StandardScope, EquipmentOHHistory
from src.workscope_group.model import MasterActivity
from src.workscope_group_maintenance_type.model import WorkscopeOHType
@ -132,9 +132,7 @@ async def update(*, db_session: DbSession, scope: OverhaulScope, scope_in: Scope
update_data = scope_in.model_dump(exclude_defaults=True)
for field in data:
if field in update_data:
setattr(scope, field, update_data[field])
update_model(scope, update_data)
await db_session.commit()

@ -8,6 +8,7 @@ from sqlalchemy.orm import selectinload
from src.auth.service import CurrentUser
from src.database.core import DbSession, CollectorDbSession
from src.utils import update_model
from src.database.service import CommonParameters, search_filter_sort_paginate
from src.overhaul_scope.model import OverhaulScope
from src.standard_scope.enum import ScopeEquipmentType
@ -154,9 +155,7 @@ async def update(
update_data = scope_equipment_in.model_dump(exclude_defaults=True)
for field in data:
if field in update_data:
setattr(scope_equipment, field, update_data[field])
update_model(scope_equipment, update_data)
await db_session.commit()

@ -139,3 +139,13 @@ def save_to_pastebin(data, title="Result Log", expire_date="1H"):
return response.text # This will be the paste URL
else:
return f"Error: {response.status_code} - {response.text}"
def update_model(model, update_data: dict):
"""
Update a SQLAlchemy model with data from a dictionary.
"""
for key, value in update_data.items():
if hasattr(model, key):
setattr(model, key, value)

@ -6,6 +6,7 @@ from sqlalchemy.orm import joinedload, selectinload
from src.auth.service import CurrentUser
from src.database.core import DbSession
from src.database.service import CommonParameters, search_filter_sort_paginate
from src.utils import update_model
from .model import MasterActivity
from .schema import ActivityMaster, ActivityMasterCreate
@ -43,9 +44,7 @@ async def update(
update_data = activity_in.model_dump(exclude_defaults=True)
for field in data:
if field in update_data:
setattr(activity, field, update_data[field])
update_model(activity, update_data)
await db_session.commit()

@ -6,6 +6,7 @@ from sqlalchemy.orm import joinedload, selectinload
from src.auth.service import CurrentUser
from src.database.core import DbSession
from src.database.service import CommonParameters, search_filter_sort_paginate
from src.utils import update_model
from .model import MasterActivity
from .schema import ActivityMaster, ActivityMasterCreate
@ -43,9 +44,7 @@ async def update(
update_data = activity_in.model_dump(exclude_defaults=True)
for field in data:
if field in update_data:
setattr(activity, field, update_data[field])
update_model(activity, update_data)
await db_session.commit()

Loading…
Cancel
Save