From 2a6950e2cfac03c0dbd2bf30511bc42836f86232 Mon Sep 17 00:00:00 2001 From: Kurbanov Bulat Date: Fri, 19 Nov 2021 15:01:51 +0300 Subject: [PATCH] Optimize search queries --- .../app/alembic/versions/08193b547a80_.py | 30 ++++++ fastapi_book_server/app/models.py | 4 +- fastapi_book_server/app/services/common.py | 93 +++++++++---------- 3 files changed, 75 insertions(+), 52 deletions(-) create mode 100644 fastapi_book_server/app/alembic/versions/08193b547a80_.py diff --git a/fastapi_book_server/app/alembic/versions/08193b547a80_.py b/fastapi_book_server/app/alembic/versions/08193b547a80_.py new file mode 100644 index 0000000..94e3cc1 --- /dev/null +++ b/fastapi_book_server/app/alembic/versions/08193b547a80_.py @@ -0,0 +1,30 @@ +"""empty message + +Revision ID: 08193b547a80 +Revises: b44117a41998 +Create Date: 2021-11-19 14:04:16.589304 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '08193b547a80' +down_revision = 'b44117a41998' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_index(op.f('ix_books_title'), 'books', ['title'], unique=False) + op.create_index(op.f('ix_sequences_name'), 'sequences', ['name'], unique=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_sequences_name'), table_name='sequences') + op.drop_index(op.f('ix_books_title'), table_name='books') + # ### end Alembic commands ### diff --git a/fastapi_book_server/app/models.py b/fastapi_book_server/app/models.py index 251d265..1c58ed5 100644 --- a/fastapi_book_server/app/models.py +++ b/fastapi_book_server/app/models.py @@ -79,7 +79,7 @@ class Sequence(ormar.Model): source: Source = ormar.ForeignKey(Source, nullable=False) remote_id: int = ormar.Integer(minimum=0, nullable=False) # type: ignore - name: str = ormar.String(max_length=256, nullable=False) # type: ignore + name: str = ormar.String(max_length=256, nullable=False, index=True) # type: ignore class BookAuthors(ormar.Model): @@ -129,7 +129,7 @@ class Book(ormar.Model): source: Source = ormar.ForeignKey(Source, nullable=False) remote_id: int = ormar.Integer(minimum=0, nullable=False) # type: ignore - title: str = ormar.String(max_length=256, nullable=False) # type: ignore + title: str = ormar.String(max_length=256, nullable=False, index=True) # type: ignore lang: str = ormar.String(max_length=3, nullable=False) # type: ignore file_type: str = ormar.String(max_length=4, nullable=False) # type: ignore uploaded: date = ormar.Date() # type: ignore diff --git a/fastapi_book_server/app/services/common.py b/fastapi_book_server/app/services/common.py index e9955c6..8a0ce5a 100644 --- a/fastapi_book_server/app/services/common.py +++ b/fastapi_book_server/app/services/common.py @@ -1,13 +1,14 @@ -from typing import Optional, Generic, TypeVar, Union, cast +from typing import Optional, Generic, TypeVar, Union from itertools import permutations -import asyncio +import json from fastapi_pagination.api import resolve_params from fastapi_pagination.bases import AbstractParams, RawParams from app.utils.pagination import Page, CustomPage from ormar import Model, QuerySet -from sqlalchemy import text, func, select, desc, Table, Column +from sqlalchemy import text, func, select, or_, Table, Column +from sqlalchemy.orm import Session from databases import Database @@ -59,60 +60,57 @@ class TRGMSearchService(Generic[T]): assert cls.FIELDS is not None, f"FIELDS in {cls.__name__} don't set!" assert len(cls.FIELDS) != 0, f"FIELDS in {cls.__name__} must be not empty!" - if len(cls.FIELDS) == 1: - return cls.FIELDS - - combinations = [] - - for i in range(1, len(cls.FIELDS)): - combinations += permutations(cls.FIELDS, i) - - return combinations + return permutations(cls.FIELDS, len(cls.FIELDS)) @classmethod def get_similarity_subquery(cls, query: str): + combs = cls.fields_combinations + return func.greatest( - *[func.similarity(join_fields(comb), f"{query}::text") for comb in cls.fields_combinations] + *[func.similarity(join_fields(comb), f"{query}::text") for comb in combs] ).label("sml") @classmethod - def get_object_ids_query(cls, query: str): - similarity = cls.get_similarity_subquery(query) - params = cls.get_raw_params() - - return select( - [cls.table.c.id], - ).where( - similarity > 0.5 - ).order_by( - desc(similarity) - ).limit(params.limit).offset(params.offset) - - @classmethod - def get_objects_count_query(cls, query: str): - similarity = cls.get_similarity_subquery(query) - - return select( - func.count(cls.table.c.id) - ).where( - similarity > 0.5 + def get_similarity_filter_subquery(cls, query: str): + return or_( + *[join_fields(comb) % f"{query}::text" for comb in cls.fields_combinations] ) @classmethod - async def get_objects_count(cls, query: str) -> int: - count_query = cls.get_objects_count_query(query) + async def get_objects(cls, query_data: str) -> tuple[int, list[T]]: + similarity = cls.get_similarity_subquery(query_data) + similarity_filter = cls.get_similarity_filter_subquery(query_data) - count_row = await cls.database.fetch_one(count_query) + params = cls.get_raw_params() + + session = Session(cls.database.connection()) - assert count_row is not None + q1 = session.query( + cls.table.c.id, similarity + ).order_by( + text('sml DESC') + ).filter( + cls.table.c.is_deleted == False, + similarity_filter + ).cte('objs') - return cast(int, count_row.get("count_1")) + sq = session.query(q1.c.id).limit(params.limit).offset(params.offset).subquery() - @classmethod - async def get_objects(cls, query: str) -> list[T]: - ids_query = cls.get_object_ids_query(query) + q2 = session.query( + func.json_build_object( + text("'total'"), func.count(q1.c.id), + text("'items'"), select(func.array_to_json(func.array_agg(sq.c.id))) + ) + ).cte() - ids = await cls.database.fetch_all(ids_query) + print(str(q2)) + + row = await cls.database.fetch_one(q2) + + if row is None: + raise ValueError('Something is wrong!') + + result = json.loads(row['json_build_object_1']) queryset: QuerySet[T] = cls.model.objects @@ -122,19 +120,14 @@ class TRGMSearchService(Generic[T]): if cls.SELECT_RELATED: queryset = queryset.select_related(cls.SELECT_RELATED) - return await queryset.filter(id__in=[r.get("id") for r in ids]).all() + return result['total'], await queryset.filter(id__in=result['items']).all() + @classmethod async def get(cls, query: str) -> Page[T]: params = cls.get_params() - objects_task = asyncio.create_task(cls.get_objects(query)) - total_task = asyncio.create_task(cls.get_objects_count(query)) - - await asyncio.wait({objects_task, total_task}) - - objects = objects_task.result() - total = total_task.result() + total, objects = await cls.get_objects(query) return CustomPage.create( items=objects,