You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

172 lines
4.9 KiB
Python

import logging
from typing import Annotated, List
from sqlalchemy import desc, func, or_, Select
from sqlalchemy_filters import apply_pagination
from sqlalchemy.exc import ProgrammingError
from .core import DbSession
from fastapi import Query, Depends
from pydantic.types import Json, constr
log = logging.getLogger(name=__name__)

# Non-empty string made only of printable ASCII characters: space " " through
# tilde "~" (0x20-0x7E). Tabs, newlines and non-ASCII text do not match.
QueryStr = constr(min_length=1, pattern=r"^[ -~]+$")
def common_parameters(
    db_session: DbSession,  # type: ignore
    current_user: QueryStr = Query(None, alias="currentUser"),  # type: ignore
    page: int = Query(1, gt=0, lt=2147483647),
    items_per_page: int = Query(5, alias="itemsPerPage", gt=-2, lt=2147483647),
    query_str: QueryStr = Query(None, alias="q"),  # type: ignore
    filter_spec: QueryStr = Query(None, alias="filter"),  # type: ignore
    sort_by: List[str] = Query([], alias="sortBy[]"),
    descending: List[bool] = Query([], alias="descending[]"),
    all: int = Query(0),
    # role: QueryStr = Depends(get_current_role),  (currently disabled)
):
    """Gather the shared list-endpoint query parameters into a single dict.

    FastAPI reads each value from the query string under its ``alias``:
    ``currentUser``, ``itemsPerPage``, ``q``, ``filter``, ``sortBy[]`` and
    ``descending[]``. ``page`` and ``all`` use their own names.

    Validation notes:
        * ``page`` must be in 1..2147483646.
        * ``itemsPerPage`` uses ``gt=-2``, so ``-1`` and ``0`` are accepted
          alongside positive sizes.
        * ``all`` arrives as an integer and is returned as a ``bool``.

    Returns:
        A dict keyed by the Python parameter names: ``db_session``, ``page``,
        ``items_per_page``, ``query_str``, ``filter_spec``, ``sort_by``,
        ``descending``, ``current_user`` and ``all``.
    """
    return dict(
        db_session=db_session,
        page=page,
        items_per_page=items_per_page,
        query_str=query_str,
        filter_spec=filter_spec,
        sort_by=sort_by,
        descending=descending,
        current_user=current_user,
        all=bool(all),
    )
# Union of the value types that appear in the dict built by common_parameters.
_CommonParameterValue = (
    int | str | DbSession | QueryStr | Json | List[str] | List[bool] | bool
)

# Route parameters annotated with this type are injected by FastAPI with the
# dict returned from common_parameters.
CommonParameters = Annotated[
    dict[str, _CommonParameterValue],
    Depends(common_parameters),
]
def search(*, query_str: str, query: "Select", model, sort=False):
    """Filter ``query`` to rows of ``model`` that match ``query_str``.

    The match is an OR over whichever of these attributes ``model`` exposes:

    * ``search_vector``: a full-text match, ``search_vector @@ tsq_parse(query_str)``.
      ``tsq_parse`` is a database-side SQL function (presumably defined by the
      project's schema; confirm).
    * ``name``: a case-insensitive substring match (``ILIKE '%text%'``; note
      that ``%``/``_`` in the text act as LIKE wildcards) or an exact match.

    Args:
        query_str: Search text. ``None`` or whitespace-only text leaves
            ``query`` unchanged.
        query: SQLAlchemy ``Select`` to filter.
        model: Mapped class whose attributes are searched.
        sort: When true and ``model`` has ``search_vector``, order results by
            ``ts_rank_cd`` relevance, best first. Models searched only by
            ``name`` have no relevance score, so their order is left as-is.

    Returns:
        The filtered (and possibly ordered) query.

    Raises:
        ValueError: If ``model`` has neither ``search_vector`` nor ``name``.
    """
    if not query_str or not query_str.strip():
        return query

    has_vector = hasattr(model, "search_vector")
    conditions = []
    if has_vector:
        conditions.append(model.search_vector.op("@@")(func.tsq_parse(query_str)))
    if hasattr(model, "name"):
        conditions.append(model.name.ilike(f"%{query_str}%"))
        conditions.append(model.name == query_str)
    if not conditions:
        raise ValueError(f"Search not supported for model: {model}")

    query = query.filter(or_(*conditions))
    # Ranking needs the tsvector. The original code referenced an unbound
    # variable here when the model only had ``name``.
    if sort and has_vector:
        rank = func.ts_rank_cd(model.search_vector, func.tsq_parse(query_str))
        query = query.order_by(desc(rank))
    # Binds ``term`` for any ``:term`` placeholder the caller's query may carry;
    # presumably a no-op otherwise (verify against callers).
    return query.params(term=query_str)
def _extract_result_items(result):
"""Normalize SQLAlchemy result rows into ORM entities with labeled columns attached."""
rows = result.fetchall()
items = []
for row in rows:
row_values = tuple(row)
if not row_values:
continue
primary = row_values[0]
if hasattr(primary, "__table__"):
mapping = getattr(row, "_mapping", None)
if mapping:
for key, value in mapping.items():
if isinstance(key, str):
setattr(primary, key, value)
items.append(primary)
else:
# Fall back to first column (scalar or tuple) if no ORM entity is present.
if len(row_values) == 1:
items.append(primary)
else:
items.append(row_values)
return items
def _select_and_entity(model):
    """Return ``(query, entity)``: a ``Select`` for ``model`` and the mapped class to search."""
    if not isinstance(model, Select):
        return Select(model), model
    # A ready-made statement: search needs the mapped class it selects, not the
    # statement. Fall back to the statement itself (which search() rejects)
    # when no ORM entity can be identified.
    descriptions = model.column_descriptions
    entity = descriptions[0].get("entity") if descriptions else None
    return model, (entity if entity is not None else model)


async def search_filter_sort_paginate(
    db_session: DbSession,
    model,
    query_str: str = None,
    filter_spec: str | dict | None = None,
    page: int = 1,
    items_per_page: int = 5,
    sort_by: List[str] = None,
    descending: List[bool] = None,
    current_user: str = None,
    all: bool = False,
):
    """Search ``model`` and return one page of results plus paging metadata.

    Args:
        db_session: Async SQLAlchemy session that runs the count and data queries.
        model: A mapped class, or an existing ``Select`` to paginate.
        query_str: Optional search text, applied via ``search``.
        filter_spec: Accepted for interface compatibility; not applied here.
        page: 1-based page number; values below 1 are treated as 1.
        items_per_page: Page size. Non-positive values (``common_parameters``
            admits ``-1`` and ``0``) return every row, exactly like ``all=True``.
        sort_by: Only consulted to decide whether search results are ranked
            by relevance (they are when ``sort_by`` is empty); the named
            columns are not applied here.
        descending: Accepted for interface compatibility; not applied here.
        current_user: Accepted for interface compatibility; not applied here.
        all: Return every matching row as a single page.

    Returns:
        A dict with ``items``, ``itemsPerPage``, ``totalPages``, ``page`` and
        ``total``.
    """
    query, entity = _select_and_entity(model)

    if query_str:
        # Relevance ordering only when the caller did not request a sort.
        query = search(query_str=query_str, query=query, model=entity, sort=not sort_by)

    count_query = Select(func.count()).select_from(query.subquery())
    total = await db_session.scalar(count_query)

    # A page size below 1 used to divide by zero (0) or emit a negative
    # LIMIT/OFFSET (-1); treat it as "everything".
    if all or items_per_page is None or items_per_page < 1:
        result = await db_session.execute(query)
        return {
            "items": _extract_result_items(result),
            "itemsPerPage": total,
            "totalPages": 1,
            "page": 1,
            "total": total,
        }

    page = max(page, 1)
    query = query.offset((page - 1) * items_per_page).limit(items_per_page)
    result = await db_session.execute(query)
    return {
        "items": _extract_result_items(result),
        "itemsPerPage": items_per_page,
        "totalPages": (total + items_per_page - 1) // items_per_page,
        "page": page,
        "total": total,
    }