mirror of
https://github.com/flibusta-apps/book_library_server.git
synced 2025-12-06 15:15:36 +01:00
New TRGM search implementation
This commit is contained in:
@@ -1,7 +1,5 @@
|
||||
from typing import Optional, Generic, TypeVar, Union
|
||||
from itertools import permutations
|
||||
from databases import Database
|
||||
import json
|
||||
|
||||
from fastapi_pagination.api import resolve_params
|
||||
from fastapi_pagination.bases import AbstractParams, RawParams
|
||||
@@ -10,17 +8,7 @@ import aioredis
|
||||
import orjson
|
||||
|
||||
from ormar import Model, QuerySet
|
||||
from sqlalchemy import text, func, select, or_, Table, Column, cast, Text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
def join_fields(fields):
    """Chain ``fields`` into one expression with a ``' '`` SQL literal between items.

    The first element is the starting value; every following element is
    appended as ``text("' '") + element``. A one-element sequence is returned
    as its single element, untouched.

    Indexing ``fields[0]`` means an empty list or tuple raises ``IndexError``.
    """
    joined = fields[0]
    separator_sql = "' '"

    for extra in fields[1:]:
        # A fresh text() fragment is built for every gap between two fields.
        joined += text(separator_sql) + extra

    return joined
|
||||
from sqlalchemy import Table
|
||||
|
||||
|
||||
T = TypeVar('T', bound=Model)
|
||||
@@ -28,10 +16,9 @@ T = TypeVar('T', bound=Model)
|
||||
|
||||
class TRGMSearchService(Generic[T]):
|
||||
MODEL_CLASS: Optional[T] = None
|
||||
FIELDS: Optional[list[Column]] = None
|
||||
SELECT_RELATED: Optional[Union[list[str], str]] = None
|
||||
PREFETCH_RELATED: Optional[Union[list[str], str]] = None
|
||||
FILTERS = []
|
||||
GET_OBJECT_IDS_QUERY: Optional[str] = None
|
||||
CACHE_TTL = 5 * 60
|
||||
|
||||
@classmethod
|
||||
@@ -60,52 +47,18 @@ class TRGMSearchService(Generic[T]):
|
||||
|
||||
@classmethod
@property
def fields_combinations(cls):
    """Return an iterator over every ordering of ``cls.FIELDS``.

    Each item is a tuple containing all entries of ``cls.FIELDS`` in one
    particular order. An ``AssertionError`` is raised when ``FIELDS`` is
    ``None`` or empty.
    """
    # NOTE(review): stacking @classmethod on @property is deprecated since
    # Python 3.11 and removed in 3.13 - confirm the target interpreter.
    search_fields = cls.FIELDS

    assert search_fields is not None, f"FIELDS in {cls.__name__} don't set!"
    assert len(search_fields) != 0, f"FIELDS in {cls.__name__} must be not empty!"

    field_count = len(search_fields)
    return permutations(search_fields, field_count)
|
||||
|
||||
@classmethod
def get_similarity_subquery(cls, query: str):
    """Build the ``sml`` score expression for ``query``.

    For every ordering produced by ``cls.fields_combinations`` the fields are
    joined with ``join_fields`` and passed, together with ``query`` cast to
    ``Text``, to the SQL ``similarity`` function. The result is the SQL
    ``greatest`` of all those calls, labelled ``"sml"``.
    """
    scores = []

    for ordering in cls.fields_combinations:
        joined = join_fields(ordering)
        scores.append(func.similarity(joined, cast(query, Text)))

    return func.greatest(*scores).label("sml")
|
||||
|
||||
@classmethod
def get_similarity_filter_subquery(cls, query: str):
    """Build the filter predicate that keeps rows matching ``query``.

    For every ordering produced by ``cls.fields_combinations`` the joined
    fields are compared with ``query`` through the ``%`` operator, and the
    comparisons are OR-ed together.

    NOTE(review): on PostgreSQL with pg_trgm, ``%`` is presumably meant as the
    trigram similarity operator - confirm against the deployed database.
    """
    return or_(
        *[
            # Bind the search text as a TEXT-typed parameter. Formatting a
            # "::text" suffix into the Python string would not cast anything:
            # the suffix would become part of the value being compared.
            join_fields(comb) % cast(query, Text)
            for comb in cls.fields_combinations
        ]
    )
|
||||
def object_ids_query(cls) -> str:
    """Return ``cls.GET_OBJECT_IDS_QUERY``, asserting that it has been set.

    An ``AssertionError`` is raised when the attribute is ``None``.
    """
    # NOTE(review): _get_object_ids reads this as an attribute
    # (cls.object_ids_query), so it presumably relies on the
    # @classmethod/@property decorators - confirm they apply here.
    raw_sql = cls.GET_OBJECT_IDS_QUERY

    assert raw_sql is not None, f"GET_OBJECT_IDS_QUERY in {cls.__name__} don't set!"

    return raw_sql
|
||||
|
||||
@classmethod
async def _get_object_ids(cls, query_data: str) -> list[int]:
    """Fetch the ids of the objects that match ``query_data``.

    Executes ``cls.object_ids_query`` once, with ``query_data`` bound to the
    ``query`` parameter, and returns the ``array`` column of the single row
    that comes back.

    Raises:
        ValueError: if the database returns no row.
    """
    # One round trip only. The earlier body also built and executed a
    # SQLAlchemy query whose result was immediately overwritten, and it
    # carried a second, unreachable return statement.
    row = await cls.database.fetch_one(cls.object_ids_query, {"query": query_data})

    if row is None:
        raise ValueError('Something is wrong!')

    # NOTE(review): assumes GET_OBJECT_IDS_QUERY exposes its result under the
    # column name "array" - confirm against the concrete queries.
    return row['array']
|
||||
|
||||
@classmethod
|
||||
def get_cache_key(cls, query_data: str) -> str:
|
||||
|
||||
Reference in New Issue
Block a user