Add allowed_langs filter

This commit is contained in:
2022-01-02 20:15:55 +03:00
parent cbba30f2af
commit 017cc05a19
9 changed files with 170 additions and 53 deletions

View File

@@ -54,8 +54,12 @@ class TRGMSearchService(Generic[T]):
return cls.GET_OBJECT_IDS_QUERY
@classmethod
async def _get_object_ids(cls, query_data: str) -> list[int]:
row = await cls.database.fetch_one(cls.object_ids_query, {"query": query_data})
async def _get_object_ids(
cls, query_data: str, allowed_langs: list[str]
) -> list[int]:
row = await cls.database.fetch_one(
cls.object_ids_query, {"query": query_data, "langs": allowed_langs}
)
if row is None:
raise ValueError("Something is wrong!")
@@ -63,16 +67,20 @@ class TRGMSearchService(Generic[T]):
return row["array"]
@classmethod
def get_cache_key(cls, query_data: str) -> str:
def get_cache_key(cls, query_data: str, allowed_langs: list[str]) -> str:
model_class_name = cls.model.__class__.__name__
return f"{model_class_name}_{query_data}"
allowed_langs_part = ",".join(allowed_langs)
return f"{model_class_name}_{query_data}_{allowed_langs_part}"
@classmethod
async def get_cached_ids(
cls, query_data: str, redis: aioredis.Redis
cls,
query_data: str,
allowed_langs: list[str],
redis: aioredis.Redis,
) -> Optional[list[int]]:
try:
key = cls.get_cache_key(query_data)
key = cls.get_cache_key(query_data, allowed_langs)
data = await redis.get(key)
if data is None:
@@ -85,25 +93,32 @@ class TRGMSearchService(Generic[T]):
@classmethod
async def cache_object_ids(
cls, query_data: str, object_ids: list[int], redis: aioredis.Redis
cls,
query_data: str,
allowed_langs: list[str],
object_ids: list[int],
redis: aioredis.Redis,
):
try:
key = cls.get_cache_key(query_data)
key = cls.get_cache_key(query_data, allowed_langs)
await redis.set(key, orjson.dumps(object_ids), ex=cls.CACHE_TTL)
except aioredis.RedisError as e:
print(e)
@classmethod
async def get_objects(
cls, query_data: str, redis: aioredis.Redis
cls,
query_data: str,
redis: aioredis.Redis,
allowed_langs: list[str],
) -> tuple[int, list[T]]:
params = cls.get_raw_params()
cached_object_ids = await cls.get_cached_ids(query_data, redis)
cached_object_ids = await cls.get_cached_ids(query_data, allowed_langs, redis)
if cached_object_ids is None:
object_ids = await cls._get_object_ids(query_data)
await cls.cache_object_ids(query_data, object_ids, redis)
object_ids = await cls._get_object_ids(query_data, allowed_langs)
await cls.cache_object_ids(query_data, allowed_langs, object_ids, redis)
else:
object_ids = cached_object_ids
@@ -120,23 +135,19 @@ class TRGMSearchService(Generic[T]):
return len(object_ids), await queryset.filter(id__in=limited_object_ids).all()
@classmethod
async def get(cls, query: str, redis: aioredis.Redis) -> Page[T]:
async def get(
cls, query: str, redis: aioredis.Redis, allowed_langs: list[str]
) -> Page[T]:
params = cls.get_params()
total, objects = await cls.get_objects(query, redis)
total, objects = await cls.get_objects(query, redis, allowed_langs)
return CustomPage.create(items=objects, total=total, params=params)
GET_RANDOM_OBJECT_ID_QUERY = """
SELECT id FROM {table}
WHERE id >= RANDOM() * (SELECT MAX(id) FROM {table})
ORDER BY id LIMIT 1;
"""
class GetRandomService(Generic[T]):
MODEL_CLASS: Optional[T] = None
GET_RANDOM_OBJECT_ID_QUERY: Optional[str] = None
@classmethod
@property
@@ -150,7 +161,15 @@ class GetRandomService(Generic[T]):
return cls.model.Meta.database
@classmethod
async def get_random_id(cls) -> int:
table_name = cls.model.Meta.tablename
query = GET_RANDOM_OBJECT_ID_QUERY.format(table=table_name)
return await cls.database.fetch_val(query)
@property
def random_object_id_query(cls) -> str:
assert (
cls.GET_RANDOM_OBJECT_ID_QUERY is not None
), f"GET_OBJECT_IDS_QUERY in {cls.__name__} don't set!"
return cls.GET_RANDOM_OBJECT_ID_QUERY
@classmethod
async def get_random_id(cls, allowed_langs: list[str]) -> int:
return await cls.database.fetch_val(
cls.random_object_id_query, {"langs": allowed_langs}
)