Add search result cache

2021-11-21 16:48:04 +03:00
parent ef26b979d4
commit cc3ded9a7d
8 changed files with 121 additions and 36 deletions
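
In outline, the change below drops the JSON-building SQL query and instead caches the full list of matching object ids in Redis, keyed per model and query string, for CACHE_TTL seconds; pagination then just slices the cached list. A condensed, hypothetical sketch of that flow (simplified names and signatures, error handling omitted; not the literal implementation):

    import aioredis
    import orjson

    async def search_ids(query: str, redis: aioredis.Redis, fetch_ids, ttl: int = 300) -> list[int]:
        # Hypothetical helper illustrating the cache-aside flow of this commit;
        # fetch_ids stands in for the trigram search query.
        key = f"Example_{query}"                         # per-model, per-query cache key
        cached = await redis.get(key)
        if cached is not None:
            return orjson.loads(cached)                  # cache hit: reuse the stored id list
        ids = await fetch_ids(query)                     # cache miss: run the search
        await redis.set(key, orjson.dumps(ids), ex=ttl)  # store for the TTL
        return ids

Paging then happens on the cached list itself, e.g. ids[offset:offset + limit].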


@@ -1,15 +1,17 @@
 from typing import Optional, Generic, TypeVar, Union
 from itertools import permutations
 from databases import Database
 import json
 from fastapi_pagination.api import resolve_params
 from fastapi_pagination.bases import AbstractParams, RawParams
 from app.utils.pagination import Page, CustomPage
+import aioredis
+import orjson
 from ormar import Model, QuerySet
 from sqlalchemy import text, func, select, or_, Table, Column, cast, Text
 from sqlalchemy.orm import Session
 from databases import Database

 def join_fields(fields):
@@ -30,6 +32,7 @@ class TRGMSearchService(Generic[T]):
     SELECT_RELATED: Optional[Union[list[str], str]] = None
     PREFETCH_RELATED: Optional[Union[list[str], str]] = None
     FILTERS = []
+    CACHE_TTL = 5 * 60

     @classmethod
     def get_params(cls) -> AbstractParams:
@@ -78,15 +81,13 @@ class TRGMSearchService(Generic[T]):
         )

     @classmethod
-    async def get_objects(cls, query_data: str) -> tuple[int, list[T]]:
+    async def _get_object_ids(cls, query_data: str) -> list[int]:
         similarity = cls.get_similarity_subquery(query_data)
         similarity_filter = cls.get_similarity_filter_subquery(query_data)
-        params = cls.get_raw_params()
         session = Session(cls.database.connection())
-        q1 = session.query(
+        filtered_objects_query = session.query(
             cls.table.c.id, similarity
         ).order_by(
             text('sml DESC')
@@ -95,23 +96,57 @@ class TRGMSearchService(Generic[T]):
             *cls.FILTERS
         ).cte('objs')
-        sq = session.query(q1.c.id).limit(params.limit).offset(params.offset).subquery()
-        q2 = session.query(
-            func.json_build_object(
-                text("'total'"), func.count(q1.c.id),
-                text("'items'"), select(func.array_to_json(func.array_agg(sq.c.id)))
-            )
+        object_ids_query = session.query(
+            func.array_agg(filtered_objects_query.c.id)
         ).cte()
-        print(str(q2))
-        row = await cls.database.fetch_one(q2)
+        row = await cls.database.fetch_one(object_ids_query)
         if row is None:
             raise ValueError('Something is wrong!')
-        result = json.loads(row['json_build_object_1'])
+        return row['array_agg_1']
+
+    @classmethod
+    def get_cache_key(cls, query_data: str) -> str:
+        model_class_name = cls.model.__class__.__name__
+        return f"{model_class_name}_{query_data}"
+
+    @classmethod
+    async def get_cached_ids(cls, query_data: str, redis: aioredis.Redis) -> Optional[list[int]]:
+        try:
+            key = cls.get_cache_key(query_data)
+            data = await redis.get(key)
+            if data is None:
+                return data
+            return orjson.loads(data)
+        except aioredis.RedisError as e:
+            print(e)
+            return None
+
+    @classmethod
+    async def cache_object_ids(cls, query_data: str, object_ids: list[int], redis: aioredis.Redis):
+        try:
+            key = cls.get_cache_key(query_data)
+            await redis.set(key, orjson.dumps(object_ids), ex=cls.CACHE_TTL)
+        except aioredis.RedisError as e:
+            print(e)
+
+    @classmethod
+    async def get_objects(cls, query_data: str, redis: aioredis.Redis) -> tuple[int, list[T]]:
+        params = cls.get_raw_params()
+        cached_object_ids = await cls.get_cached_ids(query_data, redis)
+        if cached_object_ids is None:
+            object_ids = await cls._get_object_ids(query_data)
+            await cls.cache_object_ids(query_data, object_ids, redis)
+        else:
+            object_ids = cached_object_ids
+        limited_object_ids = object_ids[params.offset:params.offset + params.limit]
         queryset: QuerySet[T] = cls.model.objects
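
The cached entry itself is small: under the key produced by get_cache_key above, Redis holds the orjson-encoded list of matching ids for five minutes (CACHE_TTL = 5 * 60). A hypothetical round-trip with made-up key and ids:

    import orjson

    # key:   "<model name>_harry potter"   (get_cache_key: model name + "_" + raw query string)
    # value: b"[3,17,42]"                  (orjson-encoded id list)
    # ttl:   300 seconds
    payload = orjson.dumps([3, 17, 42])    # -> b'[3,17,42]'
    assert orjson.loads(payload) == [3, 17, 42]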
@@ -121,14 +156,13 @@ class TRGMSearchService(Generic[T]):
         if cls.SELECT_RELATED:
             queryset = queryset.select_related(cls.SELECT_RELATED)
-        return result['total'], await queryset.filter(id__in=result['items']).all()
+        return len(object_ids), await queryset.filter(id__in=limited_object_ids).all()

     @classmethod
-    async def get(cls, query: str) -> Page[T]:
+    async def get(cls, query: str, redis: aioredis.Redis) -> Page[T]:
         params = cls.get_params()
-        total, objects = await cls.get_objects(query)
+        total, objects = await cls.get_objects(query, redis)
         return CustomPage.create(
             items=objects,
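
With the redis argument added to get(), callers now have to hand the service a Redis connection. A hypothetical FastAPI wiring, assuming aioredis 2.x; the route path, dependency, Redis URL, and BookSearchService subclass are assumptions, not part of this commit:

    import aioredis
    from fastapi import Depends, FastAPI

    app = FastAPI()

    async def get_redis() -> aioredis.Redis:
        # Hypothetical dependency; a real app would create one pool at startup and reuse it.
        return aioredis.from_url("redis://localhost")

    @app.get("/search")
    async def search(query: str, redis: aioredis.Redis = Depends(get_redis)):
        # BookSearchService is a placeholder for a concrete TRGMSearchService subclass.
        return await BookSearchService.get(query, redis)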