Optimize search queries

This commit is contained in:
2021-11-19 15:01:51 +03:00
parent 2cb808ec3c
commit 2a6950e2cf
3 changed files with 75 additions and 52 deletions

View File

@@ -0,0 +1,30 @@
"""Add indexes on books.title and sequences.name.

Revision ID: 08193b547a80
Revises: b44117a41998
Create Date: 2021-11-19 14:04:16.589304

"""
from alembic import op
import sqlalchemy as sa
# Revision identifiers, used by Alembic.
# `revision` is this migration's own ID; `down_revision` is the ID of the
# migration it revises (both match the module docstring header).
revision = '08193b547a80'
down_revision = 'b44117a41998'
# No branch labels or cross-revision dependencies are declared here.
branch_labels = None
depends_on = None
def upgrade():
    """Create non-unique indexes on ``books.title`` and ``sequences.name``.

    Each index name is passed through ``op.f()`` before being handed to
    ``op.create_index``; the indexes are created in the order listed.
    """
    # (index name, table, indexed column)
    new_indexes = (
        ('ix_books_title', 'books', 'title'),
        ('ix_sequences_name', 'sequences', 'name'),
    )
    for index_name, table, column in new_indexes:
        op.create_index(op.f(index_name), table, [column], unique=False)
def downgrade():
    """Drop the ``sequences.name`` index, then the ``books.title`` index.

    Each index name is passed through ``op.f()`` before being handed to
    ``op.drop_index``; the indexes are dropped in the order listed.
    """
    # (index name, table)
    dropped_indexes = (
        ('ix_sequences_name', 'sequences'),
        ('ix_books_title', 'books'),
    )
    for index_name, table in dropped_indexes:
        op.drop_index(op.f(index_name), table_name=table)

View File

@@ -79,7 +79,7 @@ class Sequence(ormar.Model):
source: Source = ormar.ForeignKey(Source, nullable=False) source: Source = ormar.ForeignKey(Source, nullable=False)
remote_id: int = ormar.Integer(minimum=0, nullable=False) # type: ignore remote_id: int = ormar.Integer(minimum=0, nullable=False) # type: ignore
name: str = ormar.String(max_length=256, nullable=False) # type: ignore name: str = ormar.String(max_length=256, nullable=False, index=True) # type: ignore
class BookAuthors(ormar.Model): class BookAuthors(ormar.Model):
@@ -129,7 +129,7 @@ class Book(ormar.Model):
source: Source = ormar.ForeignKey(Source, nullable=False) source: Source = ormar.ForeignKey(Source, nullable=False)
remote_id: int = ormar.Integer(minimum=0, nullable=False) # type: ignore remote_id: int = ormar.Integer(minimum=0, nullable=False) # type: ignore
title: str = ormar.String(max_length=256, nullable=False) # type: ignore title: str = ormar.String(max_length=256, nullable=False, index=True) # type: ignore
lang: str = ormar.String(max_length=3, nullable=False) # type: ignore lang: str = ormar.String(max_length=3, nullable=False) # type: ignore
file_type: str = ormar.String(max_length=4, nullable=False) # type: ignore file_type: str = ormar.String(max_length=4, nullable=False) # type: ignore
uploaded: date = ormar.Date() # type: ignore uploaded: date = ormar.Date() # type: ignore

View File

@@ -1,13 +1,14 @@
from typing import Optional, Generic, TypeVar, Union, cast from typing import Optional, Generic, TypeVar, Union
from itertools import permutations from itertools import permutations
import asyncio import json
from fastapi_pagination.api import resolve_params from fastapi_pagination.api import resolve_params
from fastapi_pagination.bases import AbstractParams, RawParams from fastapi_pagination.bases import AbstractParams, RawParams
from app.utils.pagination import Page, CustomPage from app.utils.pagination import Page, CustomPage
from ormar import Model, QuerySet from ormar import Model, QuerySet
from sqlalchemy import text, func, select, desc, Table, Column from sqlalchemy import text, func, select, or_, Table, Column
from sqlalchemy.orm import Session
from databases import Database from databases import Database
@@ -59,60 +60,57 @@ class TRGMSearchService(Generic[T]):
assert cls.FIELDS is not None, f"FIELDS in {cls.__name__} don't set!" assert cls.FIELDS is not None, f"FIELDS in {cls.__name__} don't set!"
assert len(cls.FIELDS) != 0, f"FIELDS in {cls.__name__} must be not empty!" assert len(cls.FIELDS) != 0, f"FIELDS in {cls.__name__} must be not empty!"
if len(cls.FIELDS) == 1: return permutations(cls.FIELDS, len(cls.FIELDS))
return cls.FIELDS
combinations = []
for i in range(1, len(cls.FIELDS)):
combinations += permutations(cls.FIELDS, i)
return combinations
@classmethod @classmethod
def get_similarity_subquery(cls, query: str): def get_similarity_subquery(cls, query: str):
combs = cls.fields_combinations
return func.greatest( return func.greatest(
*[func.similarity(join_fields(comb), f"{query}::text") for comb in cls.fields_combinations] *[func.similarity(join_fields(comb), f"{query}::text") for comb in combs]
).label("sml") ).label("sml")
@classmethod @classmethod
def get_object_ids_query(cls, query: str): def get_similarity_filter_subquery(cls, query: str):
similarity = cls.get_similarity_subquery(query) return or_(
params = cls.get_raw_params() *[join_fields(comb) % f"{query}::text" for comb in cls.fields_combinations]
return select(
[cls.table.c.id],
).where(
similarity > 0.5
).order_by(
desc(similarity)
).limit(params.limit).offset(params.offset)
@classmethod
def get_objects_count_query(cls, query: str):
similarity = cls.get_similarity_subquery(query)
return select(
func.count(cls.table.c.id)
).where(
similarity > 0.5
) )
@classmethod @classmethod
async def get_objects_count(cls, query: str) -> int: async def get_objects(cls, query_data: str) -> tuple[int, list[T]]:
count_query = cls.get_objects_count_query(query) similarity = cls.get_similarity_subquery(query_data)
similarity_filter = cls.get_similarity_filter_subquery(query_data)
count_row = await cls.database.fetch_one(count_query) params = cls.get_raw_params()
assert count_row is not None session = Session(cls.database.connection())
return cast(int, count_row.get("count_1")) q1 = session.query(
cls.table.c.id, similarity
).order_by(
text('sml DESC')
).filter(
cls.table.c.is_deleted == False,
similarity_filter
).cte('objs')
@classmethod sq = session.query(q1.c.id).limit(params.limit).offset(params.offset).subquery()
async def get_objects(cls, query: str) -> list[T]:
ids_query = cls.get_object_ids_query(query)
ids = await cls.database.fetch_all(ids_query) q2 = session.query(
func.json_build_object(
text("'total'"), func.count(q1.c.id),
text("'items'"), select(func.array_to_json(func.array_agg(sq.c.id)))
)
).cte()
print(str(q2))
row = await cls.database.fetch_one(q2)
if row is None:
raise ValueError('Something is wrong!')
result = json.loads(row['json_build_object_1'])
queryset: QuerySet[T] = cls.model.objects queryset: QuerySet[T] = cls.model.objects
@@ -122,19 +120,14 @@ class TRGMSearchService(Generic[T]):
if cls.SELECT_RELATED: if cls.SELECT_RELATED:
queryset = queryset.select_related(cls.SELECT_RELATED) queryset = queryset.select_related(cls.SELECT_RELATED)
return await queryset.filter(id__in=[r.get("id") for r in ids]).all() return result['total'], await queryset.filter(id__in=result['items']).all()
@classmethod @classmethod
async def get(cls, query: str) -> Page[T]: async def get(cls, query: str) -> Page[T]:
params = cls.get_params() params = cls.get_params()
objects_task = asyncio.create_task(cls.get_objects(query)) total, objects = await cls.get_objects(query)
total_task = asyncio.create_task(cls.get_objects_count(query))
await asyncio.wait({objects_task, total_task})
objects = objects_task.result()
total = total_task.result()
return CustomPage.create( return CustomPage.create(
items=objects, items=objects,