Add book filter service

This commit is contained in:
2022-03-19 12:50:24 +03:00
parent 42181b6d4e
commit b517869d86
6 changed files with 101 additions and 70 deletions

View File

@@ -13,8 +13,10 @@ def check_token(api_key: str = Security(default_security)):
) )
def get_allowed_langs(allowed_langs: Optional[list[str]] = Query(None)) -> list[str]: def get_allowed_langs(
allowed_langs: Optional[list[str]] = Query(None),
) -> frozenset[str]:
if allowed_langs is not None: if allowed_langs is not None:
return allowed_langs return frozenset(allowed_langs)
return ["ru", "be", "uk"] return frozenset(("ru", "be", "uk"))

View File

@@ -5,7 +5,12 @@ from fastapi import HTTPException, status
from app.models import Author as AuthorDB from app.models import Author as AuthorDB
from app.models import Book as BookDB from app.models import Book as BookDB
from app.serializers.book import CreateBook, CreateRemoteBook from app.serializers.book import CreateBook, CreateRemoteBook
from app.services.common import TRGMSearchService, MeiliSearchService, GetRandomService from app.services.common import (
TRGMSearchService,
MeiliSearchService,
GetRandomService,
BaseFilterService,
)
GET_OBJECT_IDS_QUERY = """ GET_OBJECT_IDS_QUERY = """
@@ -29,6 +34,12 @@ class BookTGRMSearchService(TRGMSearchService):
GET_OBJECT_IDS_QUERY = GET_OBJECT_IDS_QUERY GET_OBJECT_IDS_QUERY = GET_OBJECT_IDS_QUERY
class BookFilterService(BaseFilterService):
MODEL_CLASS = BookDB
PREFETCH_RELATED = ["source"]
SELECT_RELATED = ["authors", "translators", "annotations"]
class BookCreator: class BookCreator:
@classmethod @classmethod
def _raise_bad_request(cls): def _raise_bad_request(cls):

View File

@@ -1,7 +1,7 @@
import abc import abc
import asyncio import asyncio
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Generic, TypeVar, Union from typing import Optional, Generic, TypeVar, TypedDict, Union
import aioredis import aioredis
from databases import Database from databases import Database
@@ -16,11 +16,12 @@ from app.utils.pagination import Page, CustomPage
from core.config import env_config from core.config import env_config
T = TypeVar("T", bound=Model) MODEL = TypeVar("MODEL", bound=Model)
QUERY = TypeVar("QUERY", bound=TypedDict)
class BaseSearchService(Generic[T], abc.ABC): class BaseSearchService(Generic[MODEL, QUERY], abc.ABC):
MODEL_CLASS: Optional[T] = None MODEL_CLASS: Optional[MODEL] = None
SELECT_RELATED: Optional[Union[list[str], str]] = None SELECT_RELATED: Optional[Union[list[str], str]] = None
PREFETCH_RELATED: Optional[Union[list[str], str]] = None PREFETCH_RELATED: Optional[Union[list[str], str]] = None
CUSTOM_CACHE_PREFIX: Optional[str] = None CUSTOM_CACHE_PREFIX: Optional[str] = None
@@ -36,7 +37,7 @@ class BaseSearchService(Generic[T], abc.ABC):
@classmethod @classmethod
@property @property
def model(cls) -> T: def model(cls) -> MODEL:
assert cls.MODEL_CLASS is not None, f"MODEL in {cls.__name__} don't set!" assert cls.MODEL_CLASS is not None, f"MODEL in {cls.__name__} don't set!"
return cls.MODEL_CLASS return cls.MODEL_CLASS
@@ -55,27 +56,28 @@ class BaseSearchService(Generic[T], abc.ABC):
def cache_prefix(cls) -> str: def cache_prefix(cls) -> str:
return cls.CUSTOM_CACHE_PREFIX or cls.model.Meta.tablename return cls.CUSTOM_CACHE_PREFIX or cls.model.Meta.tablename
@staticmethod
def _get_query_hash(query: QUERY):
return hash(frozenset(query.items()))
@classmethod @classmethod
async def _get_object_ids( async def _get_object_ids(cls, query: QUERY) -> list[int]:
cls, query_data: str, allowed_langs: list[str]
) -> list[int]:
... ...
@classmethod @classmethod
def get_cache_key(cls, query_data: str, allowed_langs: list[str]) -> str: def get_cache_key(cls, query: QUERY) -> str:
model_class_name = cls.cache_prefix model_class_name = cls.cache_prefix
allowed_langs_part = ",".join(sorted(allowed_langs)) query_hash = cls._get_query_hash(query)
return f"{model_class_name}_{query_data}_{allowed_langs_part}" return f"{model_class_name}_{query_hash}"
@classmethod @classmethod
async def get_cached_ids( async def get_cached_ids(
cls, cls,
query_data: str, query: QUERY,
allowed_langs: list[str],
redis: aioredis.Redis, redis: aioredis.Redis,
) -> Optional[list[int]]: ) -> Optional[list[int]]:
try: try:
key = cls.get_cache_key(query_data, allowed_langs) key = cls.get_cache_key(query)
data = await redis.get(key) data = await redis.get(key)
if data is None: if data is None:
@@ -89,37 +91,33 @@ class BaseSearchService(Generic[T], abc.ABC):
@classmethod @classmethod
async def cache_object_ids( async def cache_object_ids(
cls, cls,
query_data: str, query: QUERY,
allowed_langs: list[str],
object_ids: list[int], object_ids: list[int],
redis: aioredis.Redis, redis: aioredis.Redis,
): ):
try: try:
key = cls.get_cache_key(query_data, allowed_langs) key = cls.get_cache_key(query)
await redis.set(key, orjson.dumps(object_ids), ex=cls.CACHE_TTL) await redis.set(key, orjson.dumps(object_ids), ex=cls.CACHE_TTL)
except aioredis.RedisError as e: except aioredis.RedisError as e:
print(e) print(e)
@classmethod @classmethod
async def get_objects( async def get_objects(
cls, cls, query: QUERY, redis: aioredis.Redis
query_data: str, ) -> tuple[int, list[MODEL]]:
redis: aioredis.Redis,
allowed_langs: list[str],
) -> tuple[int, list[T]]:
params = cls.get_raw_params() params = cls.get_raw_params()
cached_object_ids = await cls.get_cached_ids(query_data, allowed_langs, redis) cached_object_ids = await cls.get_cached_ids(query, redis)
if cached_object_ids is None: if cached_object_ids is None:
object_ids = await cls._get_object_ids(query_data, allowed_langs) object_ids = await cls._get_object_ids(query)
await cls.cache_object_ids(query_data, allowed_langs, object_ids, redis) await cls.cache_object_ids(query, object_ids, redis)
else: else:
object_ids = cached_object_ids object_ids = cached_object_ids
limited_object_ids = object_ids[params.offset : params.offset + params.limit] limited_object_ids = object_ids[params.offset : params.offset + params.limit]
queryset: QuerySet[T] = cls.model.objects # type: ignore queryset: QuerySet[MODEL] = cls.model.objects
if cls.PREFETCH_RELATED is not None: if cls.PREFETCH_RELATED is not None:
queryset = queryset.prefetch_related(cls.PREFETCH_RELATED) queryset = queryset.prefetch_related(cls.PREFETCH_RELATED)
@@ -130,17 +128,20 @@ class BaseSearchService(Generic[T], abc.ABC):
return len(object_ids), await queryset.filter(id__in=limited_object_ids).all() return len(object_ids), await queryset.filter(id__in=limited_object_ids).all()
@classmethod @classmethod
async def get( async def get(cls, query: QUERY, redis: aioredis.Redis) -> Page[MODEL]:
cls, query: str, redis: aioredis.Redis, allowed_langs: list[str]
) -> Page[T]:
params = cls.get_params() params = cls.get_params()
total, objects = await cls.get_objects(query, redis, allowed_langs) total, objects = await cls.get_objects(query, redis)
return CustomPage.create(items=objects, total=total, params=params) return CustomPage.create(items=objects, total=total, params=params)
class TRGMSearchService(BaseSearchService[T]): class SearchQuery(TypedDict):
query: str
allowed_langs: frozenset[str]
class TRGMSearchService(Generic[MODEL], BaseSearchService[MODEL, SearchQuery]):
GET_OBJECT_IDS_QUERY: Optional[str] = None GET_OBJECT_IDS_QUERY: Optional[str] = None
@classmethod @classmethod
@@ -152,11 +153,10 @@ class TRGMSearchService(BaseSearchService[T]):
return cls.GET_OBJECT_IDS_QUERY return cls.GET_OBJECT_IDS_QUERY
@classmethod @classmethod
async def _get_object_ids( async def _get_object_ids(cls, query: SearchQuery) -> list[int]:
cls, query_data: str, allowed_langs: list[str]
) -> list[int]:
row = await cls.database.fetch_one( row = await cls.database.fetch_one(
cls.object_ids_query, {"query": query_data, "langs": allowed_langs} cls.object_ids_query,
{"query": query["query"], "langs": query["allowed_langs"]},
) )
if row is None: if row is None:
@@ -165,11 +165,11 @@ class TRGMSearchService(BaseSearchService[T]):
return row["array"] return row["array"]
class MeiliSearchService(BaseSearchService[T]): class MeiliSearchService(Generic[MODEL], BaseSearchService[MODEL, SearchQuery]):
MS_INDEX_NAME: Optional[str] = None MS_INDEX_NAME: Optional[str] = None
MS_INDEX_LANG_KEY: Optional[str] = None MS_INDEX_LANG_KEY: Optional[str] = None
_executor = ThreadPoolExecutor(4) _executor = ThreadPoolExecutor(2)
@classmethod @classmethod
@property @property
@@ -184,7 +184,7 @@ class MeiliSearchService(BaseSearchService[T]):
return cls.MS_INDEX_NAME return cls.MS_INDEX_NAME
@classmethod @classmethod
def get_allowed_langs_filter(cls, allowed_langs: list[str]) -> list[list[str]]: def get_allowed_langs_filter(cls, allowed_langs: frozenset[str]) -> list[list[str]]:
return [[f"{cls.lang_key} = {lang}" for lang in allowed_langs]] return [[f"{cls.lang_key} = {lang}" for lang in allowed_langs]]
@classmethod @classmethod
@@ -210,29 +210,27 @@ class MeiliSearchService(BaseSearchService[T]):
return ids return ids
@classmethod @classmethod
async def _get_object_ids( async def _get_object_ids(cls, query: SearchQuery) -> list[int]:
cls, query_data: str, allowed_langs: list[str]
) -> list[int]:
params = cls.get_raw_params() params = cls.get_raw_params()
allowed_langs_filter = cls.get_allowed_langs_filter(allowed_langs) allowed_langs_filter = cls.get_allowed_langs_filter(query["allowed_langs"])
return await asyncio.get_event_loop().run_in_executor( return await asyncio.get_event_loop().run_in_executor(
cls._executor, cls._executor,
cls.make_request, cls.make_request,
query_data, query["query"],
allowed_langs_filter, allowed_langs_filter,
params.offset, params.offset,
) )
class GetRandomService(Generic[T]): class GetRandomService(Generic[MODEL]):
MODEL_CLASS: Optional[T] = None MODEL_CLASS: Optional[MODEL] = None
GET_RANDOM_OBJECT_ID_QUERY: Optional[str] = None GET_RANDOM_OBJECT_ID_QUERY: Optional[str] = None
@classmethod @classmethod
@property @property
def model(cls) -> T: def model(cls) -> MODEL:
assert cls.MODEL_CLASS is not None, f"MODEL in {cls.__name__} don't set!" assert cls.MODEL_CLASS is not None, f"MODEL in {cls.__name__} don't set!"
return cls.MODEL_CLASS return cls.MODEL_CLASS
@@ -250,7 +248,17 @@ class GetRandomService(Generic[T]):
return cls.GET_RANDOM_OBJECT_ID_QUERY return cls.GET_RANDOM_OBJECT_ID_QUERY
@classmethod @classmethod
async def get_random_id(cls, allowed_langs: list[str]) -> int: async def get_random_id(cls, allowed_langs: frozenset[str]) -> int:
return await cls.database.fetch_val( return await cls.database.fetch_val(
cls.random_object_id_query, {"langs": allowed_langs} cls.random_object_id_query, {"langs": allowed_langs}
) )
class BaseFilterService(Generic[MODEL, QUERY], BaseSearchService[MODEL, QUERY]):
@classmethod
async def _get_object_ids(cls, query: QUERY) -> list[int]:
return (
await cls.model.objects.filter(**query)
.fields("id")
.values_list(flatten=True)
)

View File

@@ -54,7 +54,7 @@ async def create_author(data: CreateAuthor):
@author_router.get("/random", response_model=Author) @author_router.get("/random", response_model=Author)
async def get_random_author(allowed_langs: list[str] = Depends(get_allowed_langs)): async def get_random_author(allowed_langs: frozenset[str] = Depends(get_allowed_langs)):
author_id = await GetRandomAuthorService.get_random_id(allowed_langs) author_id = await GetRandomAuthorService.get_random_id(allowed_langs)
return ( return (
@@ -118,10 +118,12 @@ async def get_author_books(
"/search/{query}", response_model=CustomPage[Author], dependencies=[Depends(Params)] "/search/{query}", response_model=CustomPage[Author], dependencies=[Depends(Params)]
) )
async def search_authors( async def search_authors(
query: str, request: Request, allowed_langs: list[str] = Depends(get_allowed_langs) query: str,
request: Request,
allowed_langs: frozenset[str] = Depends(get_allowed_langs),
): ):
return await AuthorMeiliSearchService.get( return await AuthorMeiliSearchService.get(
query, request.app.state.redis, allowed_langs {"query": query, "allowed_langs": allowed_langs}, request.app.state.redis
) )
@@ -151,8 +153,10 @@ async def get_translated_books(
"/search/{query}", response_model=CustomPage[Author], dependencies=[Depends(Params)] "/search/{query}", response_model=CustomPage[Author], dependencies=[Depends(Params)]
) )
async def search_translators( async def search_translators(
query: str, request: Request, allowed_langs: list[str] = Depends(get_allowed_langs) query: str,
request: Request,
allowed_langs: frozenset[str] = Depends(get_allowed_langs),
): ):
return await TranslatorMeiliSearchService.get( return await TranslatorMeiliSearchService.get(
query, request.app.state.redis, allowed_langs {"query": query, "allowed_langs": allowed_langs}, request.app.state.redis
) )

View File

@@ -3,7 +3,6 @@ from typing import Union
from fastapi import APIRouter, Depends, Request, HTTPException, status from fastapi import APIRouter, Depends, Request, HTTPException, status
from fastapi_pagination import Params from fastapi_pagination import Params
from fastapi_pagination.ext.ormar import paginate
from app.depends import check_token, get_allowed_langs from app.depends import check_token, get_allowed_langs
from app.filters.book import get_book_filter from app.filters.book import get_book_filter
@@ -19,7 +18,12 @@ from app.serializers.book import (
CreateRemoteBook, CreateRemoteBook,
) )
from app.serializers.book_annotation import BookAnnotation from app.serializers.book_annotation import BookAnnotation
from app.services.book import BookMeiliSearchService, GetRandomBookService, BookCreator from app.services.book import (
BookMeiliSearchService,
BookFilterService,
GetRandomBookService,
BookCreator,
)
from app.utils.pagination import CustomPage from app.utils.pagination import CustomPage
@@ -36,12 +40,8 @@ SELECT_RELATED_FIELDS = ["authors", "translators", "annotations"]
@book_router.get( @book_router.get(
"/", response_model=CustomPage[RemoteBook], dependencies=[Depends(Params)] "/", response_model=CustomPage[RemoteBook], dependencies=[Depends(Params)]
) )
async def get_books(book_filter: dict = Depends(get_book_filter)): async def get_books(request: Request, book_filter: dict = Depends(get_book_filter)):
return await paginate( return await BookFilterService.get(book_filter, request.app.state.redis)
BookDB.objects.select_related(SELECT_RELATED_FIELDS)
.prefetch_related(PREFETCH_RELATED_FIELDS)
.filter(**book_filter)
)
@book_router.post("/", response_model=Book) @book_router.post("/", response_model=Book)
@@ -56,7 +56,7 @@ async def create_book(data: Union[CreateBook, CreateRemoteBook]):
@book_router.get("/random", response_model=BookDetail) @book_router.get("/random", response_model=BookDetail)
async def get_random_book(allowed_langs: list[str] = Depends(get_allowed_langs)): async def get_random_book(allowed_langs: frozenset[str] = Depends(get_allowed_langs)):
book_id = await GetRandomBookService.get_random_id(allowed_langs) book_id = await GetRandomBookService.get_random_id(allowed_langs)
book = ( book = (
@@ -137,8 +137,10 @@ async def get_book_annotation(id: int):
"/search/{query}", response_model=CustomPage[Book], dependencies=[Depends(Params)] "/search/{query}", response_model=CustomPage[Book], dependencies=[Depends(Params)]
) )
async def search_books( async def search_books(
query: str, request: Request, allowed_langs: list[str] = Depends(get_allowed_langs) query: str,
request: Request,
allowed_langs: frozenset[str] = Depends(get_allowed_langs),
): ):
return await BookMeiliSearchService.get( return await BookMeiliSearchService.get(
query, request.app.state.redis, allowed_langs {"query": query, "allowed_langs": allowed_langs}, request.app.state.redis
) )

View File

@@ -27,7 +27,9 @@ async def get_sequences():
@sequence_router.get("/random", response_model=Sequence) @sequence_router.get("/random", response_model=Sequence)
async def get_random_sequence(allowed_langs: list[str] = Depends(get_allowed_langs)): async def get_random_sequence(
allowed_langs: frozenset[str] = Depends(get_allowed_langs),
):
sequence_id = await GetRandomSequenceService.get_random_id(allowed_langs) sequence_id = await GetRandomSequenceService.get_random_id(allowed_langs)
return await SequenceDB.objects.get(id=sequence_id) return await SequenceDB.objects.get(id=sequence_id)
@@ -65,8 +67,10 @@ async def create_sequence(data: CreateSequence):
dependencies=[Depends(Params)], dependencies=[Depends(Params)],
) )
async def search_sequences( async def search_sequences(
query: str, request: Request, allowed_langs: list[str] = Depends(get_allowed_langs) query: str,
request: Request,
allowed_langs: frozenset[str] = Depends(get_allowed_langs),
): ):
return await SequenceMeiliSearchService.get( return await SequenceMeiliSearchService.get(
query, request.app.state.redis, allowed_langs {"query": query, "allowed_langs": allowed_langs}, request.app.state.redis
) )