You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
be-optimumoh/src/database/service.py

134 lines
3.7 KiB
Python

import logging
from typing import Annotated, List
from sqlalchemy import desc, func, or_, Select
from sqlalchemy_filters import apply_pagination
from sqlalchemy.exc import ProgrammingError
from .core import DbSession
from fastapi import Query, Depends
from pydantic.types import Json, constr
log = logging.getLogger(name=__name__)

# Constrained string type for free-text query parameters. The character class
# spans space (0x20) through tilde (0x7E), i.e. printable ASCII only; non-ASCII
# characters and control characters fall outside it. Empty strings are rejected.
QueryStr = constr(min_length=1, pattern=r"^[ -~]+$")
def common_parameters(
    db_session: DbSession,  # type: ignore
    current_user: QueryStr = Query(None, alias="currentUser"),  # type: ignore
    page: int = Query(1, gt=0, lt=2147483647),
    items_per_page: int = Query(5, alias="itemsPerPage", gt=-2, lt=2147483647),
    query_str: QueryStr = Query(None, alias="q"),  # type: ignore
    filter_spec: QueryStr = Query(None, alias="filter"),  # type: ignore
    sort_by: List[str] = Query([], alias="sortBy[]"),
    descending: List[bool] = Query([], alias="descending[]"),
    exclude: List[str] = Query([], alias="exclude[]"),
    # TODO: a `role` dependency (get_current_role) is currently disabled.
):
    """Collect the shared list-endpoint query parameters into one dict.

    Intended as a FastAPI dependency (see ``CommonParameters``). Query-string
    aliases: ``currentUser``, ``itemsPerPage``, ``q``, ``filter``,
    ``sortBy[]``, ``descending[]``, ``exclude[]``.

    ``page`` must be in 1..2147483646. ``itemsPerPage`` is validated with
    ``gt=-2``, so -1 and 0 are also accepted.

    NOTE(review): ``exclude`` is parsed but not placed in the returned dict,
    so it never reaches consumers of this dependency — confirm whether that
    is intentional.
    """
    # Key order matches the historical insertion order of this dict.
    return dict(
        db_session=db_session,
        page=page,
        items_per_page=items_per_page,
        query_str=query_str,
        filter_spec=filter_spec,
        sort_by=sort_by,
        descending=descending,
        current_user=current_user,
    )
# Union of the value types that appear in the dict built by common_parameters.
_CommonParameterValue = (
    int | str | DbSession | QueryStr | Json | List[str] | List[bool]
)

# Annotated dependency alias: FastAPI resolves a parameter annotated with this
# type by calling common_parameters.
CommonParameters = Annotated[
    dict[str, _CommonParameterValue],
    Depends(common_parameters),
]
def search(*, query_str: str, query: "Select", model, sort=False):
    """Narrow ``query`` to rows of ``model`` that match ``query_str``.

    The match conditions are OR-ed together and depend on which attributes
    the model exposes:

    * ``search_vector``: full-text match ``search_vector @@ tsq_parse(query_str)``
      (``tsq_parse`` is a database-side function assumed to exist).
    * ``name``: case-insensitive substring match (``ILIKE '%query_str%'``)
      or exact equality.

    Args:
        query_str: Raw search text. ``None``, empty, or whitespace-only input
            returns ``query`` unchanged.
        query: The SQLAlchemy ``Select`` to narrow.
        model: The mapped class being searched.
        sort: When true and the model has ``search_vector``, order results by
            ``ts_rank_cd`` relevance, highest first. Models without
            ``search_vector`` have no rank, so ``sort`` is ignored for them.

    Returns:
        The filtered (and possibly ordered) query.

    Raises:
        ValueError: If the model has neither ``search_vector`` nor ``name``.
    """
    # Also guard against None: callers outside this module may not pre-check.
    if not query_str or not query_str.strip():
        return query

    conditions = []
    vector = None
    if hasattr(model, "search_vector"):
        vector = model.search_vector
        conditions.append(vector.op("@@")(func.tsq_parse(query_str)))
    if hasattr(model, "name"):
        # NOTE(review): '%' and '_' in query_str are not escaped and act as
        # LIKE wildcards here — confirm that is acceptable.
        conditions.append(model.name.ilike(f"%{query_str}%"))
        conditions.append(model.name == query_str)

    if not conditions:
        # ValueError subclasses Exception, so existing handlers still catch it.
        raise ValueError(f"Search not supported for model: {model}")

    query = query.filter(or_(*conditions))
    # Relevance ranking needs the full-text vector; name-only models have
    # nothing to rank by, so ordering is skipped for them.
    if sort and vector is not None:
        query = query.order_by(
            desc(func.ts_rank_cd(vector, func.tsq_parse(query_str)))
        )
    # NOTE(review): no bind parameter named ``term`` is visible in this query,
    # so this call looks like a no-op — confirm before removing.
    return query.params(term=query_str)
async def search_filter_sort_paginate(
    db_session: DbSession,
    model,
    query_str: str | None = None,
    filter_spec: str | dict | None = None,
    page: int = 1,
    items_per_page: int = 5,
    sort_by: List[str] | None = None,
    descending: List[bool] | None = None,
    current_user: str | None = None,
    exclude: List[str] | None = None,
):
    """Run a paginated, optionally text-searched SELECT over ``model``.

    Args:
        db_session: Async session; ``scalar`` and ``execute`` are awaited on it.
        model: The mapped class to select.
        query_str: Optional search text, applied through ``search``.
        filter_spec: Accepted for interface compatibility; NOT currently applied.
        page: 1-based page number.
        items_per_page: Page size. A negative value (``common_parameters``
            admits -1) disables pagination and returns every matching row.
            0 returns no rows but still reports ``total``.
        sort_by: Only decides whether search results are ranked by relevance
            (they are when ``sort_by`` is empty/None). The listed fields are
            NOT applied as an ORDER BY.
        descending: Accepted; NOT currently applied.
        current_user: Accepted; NOT currently applied.
        exclude: Accepted; NOT currently applied.

    Returns:
        A dict with ``items`` (the ``model`` rows for the requested page),
        ``itemsPerPage``, ``page``, and ``total``. ``total`` counts every
        matching row, ignoring pagination.
    """
    query = Select(model)
    if query_str:
        # Rank by relevance only when the caller did not request a sort.
        query = search(
            query_str=query_str, query=query, model=model, sort=not sort_by
        )

    # Count before paginating so ``total`` reflects all matching rows.
    count_query = Select(func.count()).select_from(query.subquery())
    total = await db_session.scalar(count_query)

    # PostgreSQL rejects negative LIMIT/OFFSET values. A negative page size
    # (the -1 "all items" value allowed by common_parameters) therefore skips
    # pagination instead of being passed through.
    if items_per_page >= 0:
        query = query.offset((page - 1) * items_per_page).limit(items_per_page)

    # NOTE(review): database errors raised here (e.g. a ProgrammingError from
    # malformed full-text input) propagate to the caller.
    result = await db_session.execute(query)
    items = result.scalars().all()

    return {
        "items": items,
        "itemsPerPage": items_per_page,
        "page": page,
        "total": total,
    }