You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
134 lines
3.7 KiB
Python
134 lines
3.7 KiB
Python
|
|
import logging
|
|
from typing import Annotated, List
|
|
|
|
from sqlalchemy import desc, func, or_, Select
|
|
from sqlalchemy_filters import apply_pagination
|
|
from sqlalchemy.exc import ProgrammingError
|
|
from .core import DbSession
|
|
|
|
|
|
from fastapi import Query, Depends
|
|
from pydantic.types import Json, constr
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Anchored character class covering space (0x20) through tilde (0x7E), i.e. the
# printable ASCII range; tabs, newlines and any non-ASCII text do not match.
_PRINTABLE_ASCII = r"^[ -~]+$"

# Non-empty string made only of printable ASCII characters; used to validate
# free-text query parameters.
QueryStr = constr(min_length=1, pattern=_PRINTABLE_ASCII)
|
|
|
|
|
|
def common_parameters(
    db_session: DbSession,  # type: ignore
    current_user: QueryStr = Query(None, alias="currentUser"),  # type: ignore
    page: int = Query(1, gt=0, lt=2147483647),
    items_per_page: int = Query(5, alias="itemsPerPage", gt=-2, lt=2147483647),
    query_str: QueryStr = Query(None, alias="q"),  # type: ignore
    filter_spec: QueryStr = Query(None, alias="filter"),  # type: ignore
    sort_by: List[str] = Query([], alias="sortBy[]"),
    descending: List[bool] = Query([], alias="descending[]"),
    exclude: List[str] = Query([], alias="exclude[]"),
):
    """Collect the query parameters shared by list endpoints into one dict.

    Used as a FastAPI dependency (see ``CommonParameters``). Upper bounds of
    2147483647 are the signed 32-bit integer maximum.

    Args:
        db_session: Database session injected via the ``DbSession`` dependency.
        current_user: ``currentUser`` query parameter (optional).
            NOTE(review): this identity is supplied by the client in the query
            string, so it must not be trusted for authorization.
        page: ``page`` query parameter; 1-based, must be > 0.
        items_per_page: ``itemsPerPage``; validation (``gt=-2``) admits -1 and 0
            as well as positive sizes. Presumably -1 means "all items"; confirm
            with callers.
        query_str: ``q`` free-text search string (optional).
        filter_spec: ``filter`` query parameter (optional, raw string).
        sort_by: Repeated ``sortBy[]`` values, in priority order.
        descending: Repeated ``descending[]`` flags, parallel to ``sort_by``.
        exclude: Repeated ``exclude[]`` values.

    Returns:
        dict keyed by the snake_case parameter names above.
    """
    return {
        "db_session": db_session,
        "page": page,
        "items_per_page": items_per_page,
        "query_str": query_str,
        "filter_spec": filter_spec,
        "sort_by": sort_by,
        "descending": descending,
        "current_user": current_user,
        # Previously parsed but dropped; forwarded so consumers such as
        # search_filter_sort_paginate(exclude=...) receive it.
        "exclude": exclude,
    }
|
|
|
|
|
|
# Value types that can appear in the dict produced by ``common_parameters``.
_CommonParamValue = (
    int | str | DbSession | QueryStr | Json | List[str] | List[bool]
)

# Dependency-injected type: annotating an endpoint argument with
# ``CommonParameters`` makes FastAPI call ``common_parameters`` and pass its dict.
CommonParameters = Annotated[
    dict[str, _CommonParamValue],
    Depends(common_parameters),
]
|
|
|
|
|
|
def search(*, query_str: str, query: "Select", model, sort=False):
    """Restrict ``query`` to rows of ``model`` that match ``query_str``.

    The matching criteria are OR-ed together and depend on what ``model`` exposes:

    * ``search_vector``: ``search_vector @@ tsq_parse(query_str)``. ``tsq_parse``
      is a database-side SQL function not defined in this module (presumably a
      helper that builds a tsquery; confirm in the schema/migrations).
    * ``name``: case-insensitive substring match (``ILIKE '%query_str%'``) or
      exact equality.

    Args:
        query_str: Free-text search string. ``None``, empty, or whitespace-only
            input returns ``query`` unchanged.
        query: SQLAlchemy ``Select`` statement to narrow.
        model: Class whose ``search_vector``/``name`` attributes are searched.
        sort: When true and ``model`` has ``search_vector``, order by
            ``ts_rank_cd`` relevance, highest first. Ignored for models without
            ``search_vector``, since there is nothing to rank on.

    Returns:
        The filtered (and possibly ordered) statement.

    Raises:
        ValueError: ``model`` has neither ``search_vector`` nor ``name``.
    """
    if not query_str or not query_str.strip():
        return query

    criteria = []
    vector = None
    if hasattr(model, "search_vector"):
        vector = model.search_vector
        criteria.append(vector.op("@@")(func.tsq_parse(query_str)))

    if hasattr(model, "name"):
        criteria.append(model.name.ilike(f"%{query_str}%"))
        # Kept alongside ILIKE: an exact match still succeeds when query_str
        # contains LIKE metacharacters (% _ \) that ILIKE would interpret.
        criteria.append(model.name == query_str)

    if not criteria:
        # ValueError subclasses Exception, so existing handlers still catch it.
        raise ValueError(f"Search not supported for model: {model}")

    query = query.filter(or_(*criteria))

    # Ranking needs the text-search vector; without this guard a name-only
    # model with sort=True referenced an unbound ``vector``.
    if sort and vector is not None:
        rank = func.ts_rank_cd(vector, func.tsq_parse(query_str))
        query = query.order_by(desc(rank))

    # Supplies a value for any ``:term`` bind parameter already present in
    # ``query``; none is created in this function.
    return query.params(term=query_str)
|
|
|
|
|
|
def _apply_sort(query, model, sort_by, descending):
    """Append one ORDER BY term per recognized column name in ``sort_by``.

    Names are resolved against the columns SQLAlchemy's ``inspect`` reports for
    ``model``. Names that match no column are logged and skipped, so arbitrary
    request input is never turned into a SQL expression. ``descending[i]``
    selects the direction for ``sort_by[i]``; a missing flag means ascending.
    """
    if not sort_by:
        return query

    # Local import: only needed when a sort order was actually requested.
    from sqlalchemy import inspect as sa_inspect

    columns = sa_inspect(model).columns
    directions = descending or []
    for position, field_name in enumerate(sort_by):
        column = columns.get(field_name)
        if column is None:
            log.debug("Ignoring unknown sort field %r for %s", field_name, model)
            continue
        is_descending = position < len(directions) and directions[position]
        query = query.order_by(desc(column) if is_descending else column.asc())
    return query


async def search_filter_sort_paginate(
    db_session: DbSession,
    model,
    query_str: str | None = None,
    filter_spec: str | dict | None = None,
    page: int = 1,
    items_per_page: int = 5,
    sort_by: List[str] | None = None,
    descending: List[bool] | None = None,
    current_user: str | None = None,
    exclude: List[str] | None = None,
):
    """Search, sort and paginate rows of ``model``; return one page plus a total.

    Args:
        db_session: Session used to run the statements. Its ``scalar``/``execute``
            calls are awaited, so presumably an async SQLAlchemy session (confirm
            against ``.core.DbSession``).
        model: Mapped class passed to ``Select``.
        query_str: Optional free-text search, applied through ``search``.
        filter_spec: Accepted for interface compatibility; NOT applied yet.
        page: 1-based page number.
        items_per_page: Page size. A negative value (e.g. -1) disables paging and
            returns every matching row.
        sort_by: Column attribute names to order by, in priority order. Unknown
            names are ignored. When empty, a text search is ordered by relevance.
        descending: Direction flags parallel to ``sort_by``; missing means ascending.
        current_user: Accepted but NOT applied yet.
        exclude: Accepted but NOT applied yet.

    Returns:
        dict with ``items`` (the page, from ``result.scalars().all()``),
        ``itemsPerPage``, ``page`` and ``total`` (all matching rows, ignoring
        sorting and paging).
    """
    query = Select(model)

    if query_str:
        # Relevance ordering is used only when no explicit sort was requested.
        query = search(
            query_str=query_str, query=query, model=model, sort=not sort_by
        )

    # TODO(review): filter_spec, current_user and exclude are not applied here.

    # Count every matching row before ordering and paging are applied.
    count_query = Select(func.count()).select_from(query.subquery())
    total = await db_session.scalar(count_query)

    query = _apply_sort(query, model, sort_by, descending)

    # The itemsPerPage validator admits -1 (gt=-2). Passing a negative size through
    # would emit a negative LIMIT/OFFSET, which PostgreSQL rejects, so it is
    # treated as "no paging".
    if items_per_page is not None and items_per_page >= 0:
        offset = max(page - 1, 0) * items_per_page
        query = query.offset(offset).limit(items_per_page)

    result = await db_session.execute(query)
    items = result.scalars().all()

    return {
        "items": items,
        "itemsPerPage": items_per_page,
        "page": page,
        "total": total,
    }
|