Files
book_library_server/fastapi_book_server/app/services/common.py
2023-01-31 11:55:00 +01:00

338 lines
10 KiB
Python

import abc
import asyncio
from concurrent.futures import ThreadPoolExecutor
import hashlib
from random import choice
from typing import Generic, Optional, TypedDict, TypeVar, Union
from databases import Database
from fastapi_pagination.api import resolve_params
from fastapi_pagination.bases import AbstractParams, RawParams
import meilisearch
import orjson
from ormar import Model, QuerySet
from redis import asyncio as aioredis
from sqlalchemy import Table
from app.utils.orjson_default import default as orjson_default
from app.utils.pagination import CustomPage, Page
from core.config import env_config
# Type of the ormar model class a service is parameterized over.
MODEL = TypeVar("MODEL", bound=Model)
# Type of the query dict a service accepts. It is hashed into cache keys
# (BaseService._get_query_hash) and, in BaseFilterService, expanded into ORM
# filter keyword arguments.
# NOTE(review): static type checkers may reject TypedDict as a TypeVar bound;
# Mapping[str, Any] may be the intended bound — confirm with the type-check setup.
QUERY = TypeVar("QUERY", bound=TypedDict)
class BaseService(Generic[MODEL, QUERY], abc.ABC):
    """Base class for services bound to one ormar model, with Redis id caching.

    Subclasses set ``MODEL_CLASS``; all behaviour is exposed through
    classmethods/staticmethods.

    NOTE: chaining ``@classmethod`` over ``@property`` (used by the accessors
    below) was deprecated in Python 3.11 and removed in 3.13; these accessors
    need a custom descriptor before upgrading to 3.13.
    """

    # The ormar model class (not an instance) this service works with.
    MODEL_CLASS: Optional[type[MODEL]] = None
    # Optional namespace prepended to every cache key.
    CACHE_PREFIX: str = ""
    # Overrides the model's table name as the cache-key component.
    CUSTOM_MODEL_CACHE_NAME: Optional[str] = None
    # Lifetime of cached id collections, in seconds (6 hours).
    CACHE_TTL = 6 * 60 * 60

    @classmethod
    @property
    def model(cls) -> type[MODEL]:
        """The configured model class; fails if ``MODEL_CLASS`` is unset."""
        assert cls.MODEL_CLASS is not None, f"MODEL in {cls.__name__} don't set!"
        return cls.MODEL_CLASS

    @classmethod
    @property
    def table(cls) -> Table:
        """SQLAlchemy table of the configured model."""
        return cls.model.Meta.table

    @classmethod
    @property
    def database(cls) -> Database:
        """Database handle of the configured model."""
        return cls.model.Meta.database

    @classmethod
    @property
    def cache_prefix(cls) -> str:
        """Cache-key component: ``CUSTOM_MODEL_CACHE_NAME`` or the table name."""
        return cls.CUSTOM_MODEL_CACHE_NAME or cls.model.Meta.tablename

    @staticmethod
    def _get_query_hash(query: QUERY) -> str:
        """Return a stable MD5 hex digest of ``query``.

        Keys are sorted before serialization so equal dicts hash identically.
        MD5 is only a cache-key fingerprint here, not a security primitive.
        """
        json_value = orjson.dumps(query, orjson_default, option=orjson.OPT_SORT_KEYS)
        return hashlib.md5(json_value, usedforsecurity=False).hexdigest()

    @classmethod
    def get_cache_key(cls, query: QUERY) -> str:
        """Build the Redis key for ``query``: ``[CACHE_PREFIX_]<prefix>_<hash>``."""
        model_class_name = cls.cache_prefix
        query_hash = cls._get_query_hash(query)
        cache_key = f"{model_class_name}_{query_hash}"
        return f"{cls.CACHE_PREFIX}_{cache_key}" if cls.CACHE_PREFIX else cache_key

    @classmethod
    async def cache_object_ids(
        cls,
        query: QUERY,
        object_ids: list[int],
        redis: aioredis.Redis,
    ) -> bool:
        """Store ``object_ids`` for ``query`` as a Redis set.

        The set lives under the query's cache key; a companion ``<key>_active``
        flag marks it as fresh. Both keys expire after ``CACHE_TTL`` seconds.
        The commands are queued on one pipeline and sent by ``execute()``.

        Returns:
            True on success; False if ``object_ids`` is empty or Redis
            reported an error (the error is printed, not raised).
        """
        if not object_ids:
            # Redis rejects SADD without members; skip the doomed round trip.
            return False

        key = cls.get_cache_key(query)
        active_key = f"{key}_active"
        try:
            p = redis.pipeline()
            await p.delete(key)
            await p.set(active_key, 1, ex=cls.CACHE_TTL)
            await p.sadd(key, *object_ids)
            # Expire the set together with its flag so queries that are never
            # repeated do not leave orphaned sets behind.
            await p.expire(key, cls.CACHE_TTL)
            await p.execute()
            return True
        except aioredis.RedisError as e:
            print(e)
            return False
class BaseSearchService(Generic[MODEL, QUERY], BaseService[MODEL, QUERY]):
    """Paginated search over a model with a cached, ordered id list.

    Subclasses implement ``_get_object_ids`` to return all matching ids in
    result order. This class slices them into pages, caches the full ordered
    list in Redis, and loads the page's model objects in that order.
    """

    # Relations passed to ``select_related`` / ``prefetch_related`` when loading.
    SELECT_RELATED: Optional[Union[list[str], str]] = None
    PREFETCH_RELATED: Optional[Union[list[str], str]] = None

    @classmethod
    def get_params(cls) -> AbstractParams:
        """Return the pagination params resolved by fastapi_pagination."""
        return resolve_params()

    @classmethod
    def get_raw_params(cls) -> RawParams:
        """Return the resolved pagination params as raw offset/limit."""
        return resolve_params().to_raw_params()

    @classmethod
    async def _get_object_ids(cls, query: QUERY) -> list[int]:
        """Return all ids matching ``query``, in result order.

        Subclasses must override this.
        """
        raise NotImplementedError(
            f"_get_object_ids is not implemented in {cls.__name__}"
        )

    @classmethod
    async def cache_object_ids(
        cls,
        query: QUERY,
        object_ids: list[int],
        redis: aioredis.Redis,
    ) -> bool:
        """Cache ``object_ids`` for ``query`` as an ordered Redis list.

        ``get_cached_ids`` reads the cache with LLEN/LRANGE. Those commands fail
        with WRONGTYPE on the Redis set written by
        ``BaseService.cache_object_ids``, and a set would also drop result
        order. The list and its ``<key>_active`` flag both expire after
        ``CACHE_TTL`` seconds.

        Returns:
            True on success; False if ``object_ids`` is empty or Redis
            reported an error (the error is printed, not raised).
        """
        if not object_ids:
            # RPUSH without values is rejected by Redis; nothing to cache.
            return False

        key = cls.get_cache_key(query)
        active_key = f"{key}_active"
        try:
            p = redis.pipeline()
            await p.delete(key)
            await p.set(active_key, 1, ex=cls.CACHE_TTL)
            await p.rpush(key, *object_ids)
            await p.expire(key, cls.CACHE_TTL)
            await p.execute()
            return True
        except aioredis.RedisError as e:
            print(e)
            return False

    @classmethod
    async def get_cached_ids(
        cls,
        query: QUERY,
        redis: aioredis.Redis,
        params: RawParams,
    ) -> Optional[tuple[int, list[int]]]:
        """Return ``(total, page_ids)`` from the cache.

        Returns None on a cache miss or a Redis error (the error is printed).
        """
        try:
            key = cls.get_cache_key(query)
            active_key = f"{key}_active"
            if not await redis.exists(active_key):
                return None
            # LRANGE's stop index is inclusive, hence the ``- 1``; this keeps
            # pages the same size as the slice in get_object_ids.
            objects_count, objects = await asyncio.gather(
                redis.llen(key),
                redis.lrange(key, params.offset, params.offset + params.limit - 1),
            )
            # int() accepts both bytes and str, whatever decode_responses is.
            return objects_count, [int(item) for item in objects]
        except aioredis.RedisError as e:
            print(e)
            return None

    @classmethod
    async def get_object_ids(
        cls, query: QUERY, redis: aioredis.Redis, no_cache: bool
    ) -> tuple[int, list[int]]:
        """Return ``(total, page_ids)`` for ``query``.

        The Redis cache is used unless ``no_cache`` is set. With ``no_cache``
        the cache is neither read nor written.
        """
        params = cls.get_raw_params()

        if not no_cache and (
            cached_object_ids := await cls.get_cached_ids(query, redis, params)
        ):
            return cached_object_ids

        object_ids = await cls._get_object_ids(query)
        limited_object_ids = object_ids[params.offset : params.offset + params.limit]

        if not no_cache and object_ids:
            await cls.cache_object_ids(query, object_ids, redis)

        return len(object_ids), limited_object_ids

    @classmethod
    async def get_limited_objects(
        cls, query: QUERY, redis: aioredis.Redis, no_cache: bool
    ) -> tuple[int, list[MODEL]]:
        """Return ``(total, objects)`` for the current page, ordered like the ids."""
        count, object_ids = await cls.get_object_ids(query, redis, no_cache)
        if not object_ids:
            # Empty page: skip the database round trip.
            return count, []

        queryset: QuerySet[MODEL] = cls.model.objects
        if cls.PREFETCH_RELATED is not None:
            queryset = queryset.prefetch_related(cls.PREFETCH_RELATED)
        if cls.SELECT_RELATED:
            queryset = queryset.select_related(cls.SELECT_RELATED)

        db_objects = await queryset.filter(id__in=object_ids).all()

        # Rank each id by its first position once, instead of an O(n)
        # list.index call per object inside the sort key.
        position: dict[int, int] = {}
        for index, object_id in enumerate(object_ids):
            position.setdefault(object_id, index)
        return count, sorted(db_objects, key=lambda o: position[o.id])

    @classmethod
    async def get(cls, query: QUERY, redis: aioredis.Redis) -> Page[MODEL]:
        """Run ``query`` and return one page.

        A truthy ``query["no_cache"]`` bypasses Redis.
        """
        no_cache: bool = query.get("no_cache", False)  # type: ignore
        params = cls.get_params()
        total, objects = await cls.get_limited_objects(query, redis, no_cache)
        return CustomPage.create(items=objects, total=total, params=params)
class SearchQuery(TypedDict):
    """Query accepted by TRGMSearchService and MeiliSearchService."""

    # Free-text search string.
    query: str
    # Language codes that results are restricted to.
    allowed_langs: frozenset[str]
class TRGMSearchService(Generic[MODEL], BaseSearchService[MODEL, SearchQuery]):
    """Search service backed by a raw SQL query (presumably pg_trgm similarity).

    Subclasses set ``GET_OBJECT_IDS_QUERY``. It must be SQL that takes bound
    parameters named ``query`` and ``langs`` and returns one row whose
    ``array`` column holds the ordered matching ids.
    """

    GET_OBJECT_IDS_QUERY: Optional[str] = None

    @classmethod
    @property
    def object_ids_query(cls) -> str:
        """The configured SQL; fails if ``GET_OBJECT_IDS_QUERY`` is unset."""
        assert (
            cls.GET_OBJECT_IDS_QUERY is not None
        ), f"GET_OBJECT_IDS_QUERY in {cls.__name__} don't set!"
        return cls.GET_OBJECT_IDS_QUERY

    @classmethod
    async def _get_object_ids(cls, query: SearchQuery) -> list[int]:
        """Run the SQL and return the ids from its ``array`` column.

        Raises:
            ValueError: if the statement produced no row at all.
        """
        row = await cls.database.fetch_one(
            cls.object_ids_query,
            {"query": query["query"], "langs": query["allowed_langs"]},
        )
        if row is None:
            raise ValueError("Something is wrong!")
        # A NULL array (e.g. an aggregate over zero rows) means "no matches";
        # callers slice and len() the result, which would fail on None.
        return row["array"] or []
class MeiliSearchService(Generic[MODEL], BaseSearchService[MODEL, SearchQuery]):
    """Search service backed by a Meilisearch index.

    Meilisearch paginates on its side, so each page is fetched directly and
    the Redis id cache is not used.
    """

    # Name of the Meilisearch index to query.
    MS_INDEX_NAME: Optional[str] = None
    # Index attribute holding the document language, used in the filter.
    MS_INDEX_LANG_KEY: Optional[str] = None
    # make_request is a blocking call; it runs on this pool to keep the
    # event loop free.
    _executor = ThreadPoolExecutor(2)

    @classmethod
    @property
    def lang_key(cls) -> str:
        """The configured language attribute; fails if unset."""
        assert (
            cls.MS_INDEX_LANG_KEY is not None
        ), f"MS_INDEX_LANG_KEY in {cls.__name__} don't set!"
        return cls.MS_INDEX_LANG_KEY

    @classmethod
    @property
    def index_name(cls) -> str:
        """The configured index name; fails if unset."""
        assert (
            cls.MS_INDEX_NAME is not None
        ), f"MS_INDEX_NAME in {cls.__name__} don't set!"
        return cls.MS_INDEX_NAME

    @classmethod
    def get_allowed_langs_filter(cls, allowed_langs: frozenset[str]) -> list[str]:
        """Build a Meilisearch filter: ``<lang_key> IN [lang1, lang2, ...]``."""
        langs_values = ", ".join(allowed_langs)
        return [f"{cls.lang_key} IN [{langs_values}]"]

    @classmethod
    def make_request(
        cls, query: str, allowed_langs_filter: list[str], offset: int, limit: int
    ) -> tuple[int, list[int]]:
        """Perform the blocking search and return ``(estimated_total, page_ids)``."""
        client = meilisearch.Client(env_config.MEILI_HOST, env_config.MEILI_MASTER_KEY)
        index = client.index(cls.index_name)
        result = index.search(
            query,
            {
                "filter": allowed_langs_filter,
                "offset": offset,
                "limit": limit,
                "attributesToRetrieve": ["id"],
            },
        )
        total: int = result["estimatedTotalHits"]
        ids: list[int] = [r["id"] for r in result["hits"][:total]]
        return total, ids

    @classmethod
    async def _get_object_ids(cls, query: SearchQuery) -> tuple[int, list[int]]:
        """Fetch ``(total, page_ids)`` for the current page on the thread pool."""
        params = cls.get_raw_params()
        allowed_langs_filter = cls.get_allowed_langs_filter(query["allowed_langs"])
        # Inside a coroutine, get_running_loop() is the supported way to reach
        # the current loop.
        return await asyncio.get_running_loop().run_in_executor(
            cls._executor,
            cls.make_request,
            query["query"],
            allowed_langs_filter,
            params.offset,
            params.limit,
        )

    @classmethod
    async def get_object_ids(
        cls, query: SearchQuery, redis: aioredis.Redis, no_cache: bool = False
    ) -> tuple[int, list[int]]:
        """Return ``(total, page_ids)`` straight from Meilisearch.

        ``redis`` and ``no_cache`` are accepted only so the signature matches
        ``BaseSearchService.get_object_ids``, which ``get_limited_objects``
        calls with ``no_cache``. Results are never cached.
        """
        return await cls._get_object_ids(query)
class GetRandomService(Generic[MODEL, QUERY], BaseService[MODEL, QUERY]):
    """Pick a random object id matching a query, caching candidates in Redis.

    Subclasses set ``GET_OBJECTS_ID_QUERY``: SQL taking a ``langs`` parameter
    and returning rows with an ``id`` column. Candidate ids are cached with
    ``cache_object_ids`` and sampled with SRANDMEMBER.
    """

    GET_OBJECTS_ID_QUERY: Optional[str] = None
    CACHE_PREFIX: str = "random"

    @classmethod
    @property
    def objects_id_query(cls) -> str:
        """The configured SQL; fails if ``GET_OBJECTS_ID_QUERY`` is unset."""
        assert (
            cls.GET_OBJECTS_ID_QUERY is not None
        ), f"GET_OBJECTS_ID_QUERY in {cls.__name__} don't set!"
        return cls.GET_OBJECTS_ID_QUERY

    @classmethod
    async def _get_objects_from_db(cls, query: QUERY) -> list[int]:
        """Return every candidate id for ``query["allowed_langs"]`` from the DB."""
        objects = await cls.database.fetch_all(
            cls.objects_id_query, {"langs": query["allowed_langs"]}
        )
        return [obj["id"] for obj in objects]

    @classmethod
    async def _get_random_object_from_cache(
        cls, query: QUERY, redis: aioredis.Redis
    ) -> Optional[int]:
        """Return a random cached id.

        Returns None on a cache miss or a Redis error (the error is printed).
        """
        try:
            key = cls.get_cache_key(query)
            active_key = f"{key}_active"
            if not await redis.exists(active_key):
                return None
            data: Optional[Union[bytes, str]] = await redis.srandmember(key)  # type: ignore
        except aioredis.RedisError as e:
            print(e)
            return None
        # SRANDMEMBER returns nil when the set is missing (e.g. evicted) even
        # though the "_active" flag survived; treat that as a miss instead of
        # failing in int(None).
        if data is None:
            return None
        return int(data)

    @classmethod
    async def get_random_id(
        cls,
        query: QUERY,
        redis: aioredis.Redis,
    ) -> int:
        """Return a random id matching ``query``, preferring the Redis cache.

        Raises:
            IndexError: if no objects match ``query`` (from ``random.choice``).
        """
        cached_object_id = await cls._get_random_object_from_cache(query, redis)
        if cached_object_id is not None:
            return cached_object_id

        object_ids = await cls._get_objects_from_db(query)
        if object_ids:
            # Redis cannot store an empty set; only cache real candidates.
            await cls.cache_object_ids(query, object_ids, redis)
        return choice(object_ids)
class BaseFilterService(Generic[MODEL, QUERY], BaseSearchService[MODEL, QUERY]):
    """Search service whose matches come from a plain ORM filter.

    Every key of the query dict is passed to ``Model.objects.filter`` as a
    keyword argument.
    """

    @classmethod
    async def _get_object_ids(cls, query: QUERY) -> list[int]:
        """Return the ids of all rows matching ``filter(**query)``."""
        matching = cls.model.objects.filter(**query)
        ids_only = matching.fields("id")
        return await ids_only.values_list(flatten=True)