Add book filter service

This commit is contained in:
2022-03-19 12:50:24 +03:00
parent 42181b6d4e
commit b517869d86
6 changed files with 101 additions and 70 deletions

View File

@@ -5,7 +5,12 @@ from fastapi import HTTPException, status
from app.models import Author as AuthorDB
from app.models import Book as BookDB
from app.serializers.book import CreateBook, CreateRemoteBook
from app.services.common import TRGMSearchService, MeiliSearchService, GetRandomService
from app.services.common import (
TRGMSearchService,
MeiliSearchService,
GetRandomService,
BaseFilterService,
)
GET_OBJECT_IDS_QUERY = """
@@ -29,6 +34,12 @@ class BookTGRMSearchService(TRGMSearchService):
GET_OBJECT_IDS_QUERY = GET_OBJECT_IDS_QUERY
class BookFilterService(BaseFilterService):
MODEL_CLASS = BookDB
PREFETCH_RELATED = ["source"]
SELECT_RELATED = ["authors", "translators", "annotations"]
class BookCreator:
@classmethod
def _raise_bad_request(cls):

View File

@@ -1,7 +1,7 @@
import abc
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Generic, TypeVar, Union
from typing import Optional, Generic, TypeVar, TypedDict, Union
import aioredis
from databases import Database
@@ -16,11 +16,12 @@ from app.utils.pagination import Page, CustomPage
from core.config import env_config
T = TypeVar("T", bound=Model)
MODEL = TypeVar("MODEL", bound=Model)
QUERY = TypeVar("QUERY", bound=TypedDict)
class BaseSearchService(Generic[T], abc.ABC):
MODEL_CLASS: Optional[T] = None
class BaseSearchService(Generic[MODEL, QUERY], abc.ABC):
MODEL_CLASS: Optional[MODEL] = None
SELECT_RELATED: Optional[Union[list[str], str]] = None
PREFETCH_RELATED: Optional[Union[list[str], str]] = None
CUSTOM_CACHE_PREFIX: Optional[str] = None
@@ -36,7 +37,7 @@ class BaseSearchService(Generic[T], abc.ABC):
@classmethod
@property
def model(cls) -> T:
def model(cls) -> MODEL:
assert cls.MODEL_CLASS is not None, f"MODEL in {cls.__name__} don't set!"
return cls.MODEL_CLASS
@@ -55,27 +56,28 @@ class BaseSearchService(Generic[T], abc.ABC):
def cache_prefix(cls) -> str:
return cls.CUSTOM_CACHE_PREFIX or cls.model.Meta.tablename
@staticmethod
def _get_query_hash(query: QUERY):
return hash(frozenset(query.items()))
@classmethod
async def _get_object_ids(
cls, query_data: str, allowed_langs: list[str]
) -> list[int]:
async def _get_object_ids(cls, query: QUERY) -> list[int]:
...
@classmethod
def get_cache_key(cls, query_data: str, allowed_langs: list[str]) -> str:
def get_cache_key(cls, query: QUERY) -> str:
model_class_name = cls.cache_prefix
allowed_langs_part = ",".join(sorted(allowed_langs))
return f"{model_class_name}_{query_data}_{allowed_langs_part}"
query_hash = cls._get_query_hash(query)
return f"{model_class_name}_{query_hash}"
@classmethod
async def get_cached_ids(
cls,
query_data: str,
allowed_langs: list[str],
query: QUERY,
redis: aioredis.Redis,
) -> Optional[list[int]]:
try:
key = cls.get_cache_key(query_data, allowed_langs)
key = cls.get_cache_key(query)
data = await redis.get(key)
if data is None:
@@ -89,37 +91,33 @@ class BaseSearchService(Generic[T], abc.ABC):
@classmethod
async def cache_object_ids(
cls,
query_data: str,
allowed_langs: list[str],
query: QUERY,
object_ids: list[int],
redis: aioredis.Redis,
):
try:
key = cls.get_cache_key(query_data, allowed_langs)
key = cls.get_cache_key(query)
await redis.set(key, orjson.dumps(object_ids), ex=cls.CACHE_TTL)
except aioredis.RedisError as e:
print(e)
@classmethod
async def get_objects(
cls,
query_data: str,
redis: aioredis.Redis,
allowed_langs: list[str],
) -> tuple[int, list[T]]:
cls, query: QUERY, redis: aioredis.Redis
) -> tuple[int, list[MODEL]]:
params = cls.get_raw_params()
cached_object_ids = await cls.get_cached_ids(query_data, allowed_langs, redis)
cached_object_ids = await cls.get_cached_ids(query, redis)
if cached_object_ids is None:
object_ids = await cls._get_object_ids(query_data, allowed_langs)
await cls.cache_object_ids(query_data, allowed_langs, object_ids, redis)
object_ids = await cls._get_object_ids(query)
await cls.cache_object_ids(query, object_ids, redis)
else:
object_ids = cached_object_ids
limited_object_ids = object_ids[params.offset : params.offset + params.limit]
queryset: QuerySet[T] = cls.model.objects # type: ignore
queryset: QuerySet[MODEL] = cls.model.objects
if cls.PREFETCH_RELATED is not None:
queryset = queryset.prefetch_related(cls.PREFETCH_RELATED)
@@ -130,17 +128,20 @@ class BaseSearchService(Generic[T], abc.ABC):
return len(object_ids), await queryset.filter(id__in=limited_object_ids).all()
@classmethod
async def get(
cls, query: str, redis: aioredis.Redis, allowed_langs: list[str]
) -> Page[T]:
async def get(cls, query: QUERY, redis: aioredis.Redis) -> Page[MODEL]:
params = cls.get_params()
total, objects = await cls.get_objects(query, redis, allowed_langs)
total, objects = await cls.get_objects(query, redis)
return CustomPage.create(items=objects, total=total, params=params)
class TRGMSearchService(BaseSearchService[T]):
class SearchQuery(TypedDict):
query: str
allowed_langs: frozenset[str]
class TRGMSearchService(Generic[MODEL], BaseSearchService[MODEL, SearchQuery]):
GET_OBJECT_IDS_QUERY: Optional[str] = None
@classmethod
@@ -152,11 +153,10 @@ class TRGMSearchService(BaseSearchService[T]):
return cls.GET_OBJECT_IDS_QUERY
@classmethod
async def _get_object_ids(
cls, query_data: str, allowed_langs: list[str]
) -> list[int]:
async def _get_object_ids(cls, query: SearchQuery) -> list[int]:
row = await cls.database.fetch_one(
cls.object_ids_query, {"query": query_data, "langs": allowed_langs}
cls.object_ids_query,
{"query": query["query"], "langs": query["allowed_langs"]},
)
if row is None:
@@ -165,11 +165,11 @@ class TRGMSearchService(BaseSearchService[T]):
return row["array"]
class MeiliSearchService(BaseSearchService[T]):
class MeiliSearchService(Generic[MODEL], BaseSearchService[MODEL, SearchQuery]):
MS_INDEX_NAME: Optional[str] = None
MS_INDEX_LANG_KEY: Optional[str] = None
_executor = ThreadPoolExecutor(4)
_executor = ThreadPoolExecutor(2)
@classmethod
@property
@@ -184,7 +184,7 @@ class MeiliSearchService(BaseSearchService[T]):
return cls.MS_INDEX_NAME
@classmethod
def get_allowed_langs_filter(cls, allowed_langs: list[str]) -> list[list[str]]:
def get_allowed_langs_filter(cls, allowed_langs: frozenset[str]) -> list[list[str]]:
return [[f"{cls.lang_key} = {lang}" for lang in allowed_langs]]
@classmethod
@@ -210,29 +210,27 @@ class MeiliSearchService(BaseSearchService[T]):
return ids
@classmethod
async def _get_object_ids(
cls, query_data: str, allowed_langs: list[str]
) -> list[int]:
async def _get_object_ids(cls, query: SearchQuery) -> list[int]:
params = cls.get_raw_params()
allowed_langs_filter = cls.get_allowed_langs_filter(allowed_langs)
allowed_langs_filter = cls.get_allowed_langs_filter(query["allowed_langs"])
return await asyncio.get_event_loop().run_in_executor(
cls._executor,
cls.make_request,
query_data,
query["query"],
allowed_langs_filter,
params.offset,
)
class GetRandomService(Generic[T]):
MODEL_CLASS: Optional[T] = None
class GetRandomService(Generic[MODEL]):
MODEL_CLASS: Optional[MODEL] = None
GET_RANDOM_OBJECT_ID_QUERY: Optional[str] = None
@classmethod
@property
def model(cls) -> T:
def model(cls) -> MODEL:
assert cls.MODEL_CLASS is not None, f"MODEL in {cls.__name__} don't set!"
return cls.MODEL_CLASS
@@ -250,7 +248,17 @@ class GetRandomService(Generic[T]):
return cls.GET_RANDOM_OBJECT_ID_QUERY
@classmethod
async def get_random_id(cls, allowed_langs: list[str]) -> int:
async def get_random_id(cls, allowed_langs: frozenset[str]) -> int:
return await cls.database.fetch_val(
cls.random_object_id_query, {"langs": allowed_langs}
)
class BaseFilterService(Generic[MODEL, QUERY], BaseSearchService[MODEL, QUERY]):
@classmethod
async def _get_object_ids(cls, query: QUERY) -> list[int]:
return (
await cls.model.objects.filter(**query)
.fields("id")
.values_list(flatten=True)
)