import logging
from typing import Annotated, List

from sqlalchemy import desc, func, or_, Select
from sqlalchemy_filters import apply_pagination
from sqlalchemy.exc import ProgrammingError

from .core import DbSession

from fastapi import Query, Depends
from pydantic.types import Json, constr

log = logging.getLogger(__name__)

# Non-empty string made only of printable ASCII characters (space through "~").
QueryStr = constr(pattern=r"^[ -~]+$", min_length=1)


def common_parameters(
    db_session: DbSession,  # type: ignore
    current_user: QueryStr = Query(None, alias="currentUser"),  # type: ignore
    page: int = Query(1, gt=0, lt=2147483647),
    items_per_page: int = Query(5, alias="itemsPerPage", gt=-2, lt=2147483647),
    query_str: QueryStr = Query(None, alias="q"),  # type: ignore
    filter_spec: QueryStr = Query(None, alias="filter"),  # type: ignore
    sort_by: List[str] = Query([], alias="sortBy[]"),
    descending: List[bool] = Query([], alias="descending[]"),
    exclude: List[str] = Query([], alias="exclude[]"),
    # role: QueryStr = Depends(get_current_role),
):
    """Collect the query parameters shared by list/search endpoints.

    Used as a FastAPI dependency through ``CommonParameters``. The keys of the
    returned dict match keyword arguments of ``search_filter_sort_paginate``.

    Returns:
        dict with ``db_session``, ``page``, ``items_per_page``, ``query_str``,
        ``filter_spec``, ``sort_by``, ``descending`` and ``current_user``.
    """
    # NOTE(review): ``exclude`` is parsed from the request but not returned,
    # so it never reaches ``search_filter_sort_paginate`` — confirm intent.
    return {
        "db_session": db_session,
        "page": page,
        "items_per_page": items_per_page,
        "query_str": query_str,
        "filter_spec": filter_spec,
        "sort_by": sort_by,
        "descending": descending,
        "current_user": current_user,
        # "role": role,
    }


CommonParameters = Annotated[
    dict[str, int | str | DbSession | QueryStr | Json | List[str] | List[bool]],
    Depends(common_parameters),
]


def search(*, query_str: str, query: Select, model, sort=False):
    """Narrow ``query`` to rows of ``model`` that match ``query_str``.

    Matching uses whichever of these attributes the model provides:

    * ``search_vector``: matched with the ``@@`` operator against
      ``tsq_parse(query_str)`` (a database-side SQL function);
    * ``name``: matched case-insensitively as a substring, or exactly.

    Args:
        query_str: Raw search text. Blank or whitespace-only text returns
            ``query`` unchanged.
        query: The SELECT statement to narrow.
        model: Mapped class being searched.
        sort: When true and the model has ``search_vector``, order results by
            ``ts_rank_cd`` relevance, highest first. Ignored otherwise.

    Returns:
        The filtered SELECT, with ``term`` bound to ``query_str`` via
        ``params()``.

    Raises:
        ValueError: If the model has neither ``search_vector`` nor ``name``.
    """
    if not query_str.strip():
        return query

    conditions = []
    vector = None
    if hasattr(model, "search_vector"):
        vector = model.search_vector
        conditions.append(vector.op("@@")(func.tsq_parse(query_str)))
    if hasattr(model, "name"):
        conditions.append(model.name.ilike(f"%{query_str}%"))
        conditions.append(model.name == query_str)

    if not conditions:
        raise ValueError(f"Search not supported for model: {model}")

    query = query.filter(or_(*conditions))

    # Relevance ranking needs a search vector; models that are searchable only
    # by ``name`` are left unordered rather than failing on an unset vector.
    # (Use ``is not None``: SQL expressions do not support truth testing.)
    if sort and vector is not None:
        query = query.order_by(
            desc(func.ts_rank_cd(vector, func.tsq_parse(query_str)))
        )

    return query.params(term=query_str)


async def search_filter_sort_paginate(
    db_session: DbSession,
    model,
    query_str: str = None,
    filter_spec: str | dict | None = None,
    page: int = 1,
    items_per_page: int = 5,
    sort_by: List[str] = None,
    descending: List[bool] = None,
    current_user: str = None,
    exclude: List[str] = None,
):
    """Search, count and paginate rows of ``model``; return one page plus metadata.

    Args:
        db_session: Async session used for both the count and the page query.
        model: Mapped class to select from.
        query_str: Optional search text; when truthy, rows are narrowed via
            ``search``.
        filter_spec: Accepted for interface compatibility; not applied yet.
        page: 1-based page number.
        items_per_page: Page size. ``None`` or a negative value
            (``common_parameters`` permits ``-1``) disables pagination and
            returns every matching row.
        sort_by: When non-empty, suppresses relevance ordering in ``search``.
            No explicit ORDER BY is built from it yet.
        descending: Accepted but not applied yet.
        current_user: Accepted but not used yet.
        exclude: Accepted but not used yet.

    Returns:
        dict with ``items`` (instances of ``model`` on this page),
        ``itemsPerPage``, ``page`` and ``total`` (match count before paging).
    """
    query = Select(model)

    if query_str:
        # Rank by relevance only when the caller did not request its own sort.
        query = search(query_str=query_str, query=query, model=model, sort=not sort_by)

    # Count all matches before paginating. ORDER BY has no effect on a count,
    # so drop any relevance ordering from the subquery.
    count_query = Select(func.count()).select_from(query.order_by(None).subquery())
    total = await db_session.scalar(count_query)

    # A negative LIMIT/OFFSET is rejected by PostgreSQL, so a negative (or
    # missing) page size means "no pagination" instead of a database error.
    if items_per_page is not None and items_per_page >= 0:
        query = query.offset((page - 1) * items_per_page).limit(items_per_page)

    # NOTE(review): malformed search text may make ``tsq_parse`` raise a
    # database error (e.g. ProgrammingError). An earlier version caught that
    # and returned an empty page; that handling is disabled, so such errors
    # currently propagate to the caller.
    result = await db_session.execute(query)
    items = result.scalars().all()

    return {
        "items": items,
        "itemsPerPage": items_per_page,
        "page": page,
        "total": total,
    }