This commit is contained in:
2021-11-14 10:38:47 +03:00
commit 30835e31fa
43 changed files with 2366 additions and 0 deletions

View File

@@ -0,0 +1,13 @@
from app.models import Author
from app.services.common import TRGMSearchService
class AuthorTGRMSearchService(TRGMSearchService):
MODEL = Author
FIELDS = [
Author.Meta.table.c.last_name,
Author.Meta.table.c.first_name,
Author.Meta.table.c.middle_name
]
PREFETCH_RELATED = ["source"]

View File

@@ -0,0 +1,67 @@
from typing import Union
from fastapi import HTTPException, status
from app.models import Book as BookDB, Author as AuthorDB
from app.services.common import TRGMSearchService
from app.serializers.book import CreateBook, CreateRemoteBook
class BookTGRMSearchService(TRGMSearchService):
MODEL = BookDB
FIELDS = [
BookDB.Meta.table.c.title
]
PREFETCH_RELATED = ["source"]
class BookCreator:
@classmethod
def _raise_bad_request(cls):
raise HTTPException(status.HTTP_404_NOT_FOUND)
@classmethod
async def _create_book(cls, data: CreateBook) -> BookDB:
data_dict = data.dict()
author_ids = data_dict.pop("authors", [])
authors = await AuthorDB.objects.filter(id__in=author_ids).all()
if len(author_ids) != len(authors):
cls._raise_bad_request()
book = await BookDB.objects.create(
**data_dict
)
for author in authors:
await book.authors.add(author)
return book
@classmethod
async def _create_remote_book(cls, data: CreateRemoteBook) -> BookDB:
data_dict = data.dict()
author_ids = data_dict.pop("remote_authors", [])
authors = await AuthorDB.objects.filter(source__id=data.source, remote_id__in=author_ids).all()
if len(author_ids) != len(authors):
cls._raise_bad_request()
book = await BookDB.objects.create(
**data_dict
)
for author in authors:
await book.authors.add(author)
return book
@classmethod
async def create(cls, data: Union[CreateBook, CreateRemoteBook]) -> BookDB:
if isinstance(data, CreateBook):
return await cls._create_book(data)
if isinstance(data, CreateRemoteBook):
return await cls._create_remote_book(data)

View File

@@ -0,0 +1,134 @@
from typing import Optional, Generic, TypeVar, Union, Any, cast
from itertools import permutations
from fastapi_pagination.api import resolve_params
from fastapi_pagination.bases import RawParams
from app.utils.pagination import CustomPage
from ormar import Model, QuerySet
from sqlalchemy import text, func, select, desc, Table, Column
from databases import Database
def join_fields(fields):
result = fields[0]
for el in fields[1:]:
result += text("' '") + el
return result
T = TypeVar('T', bound=Model)
class TRGMSearchService(Generic[T]):
MODEL_CLASS: Optional[T] = None
FIELDS: Optional[list[Column]] = None
SELECT_RELATED: Optional[Union[list[str], str]] = None
PREFETCH_RELATED: Optional[Union[list[str], str]] = None
@classmethod
def get_params(cls) -> RawParams:
return resolve_params().to_raw_params()
@classmethod
@property
def model(cls) -> T:
assert cls.MODEL_CLASS is not None, f"MODEL in {cls.__name__} don't set!"
return cls.MODEL_CLASS
@classmethod
@property
def table(cls) -> Table:
return cls.model.Meta.table
@classmethod
@property
def database(cls) -> Database:
return cls.model.Meta.database
@classmethod
@property
def fields_combinations(cls):
assert cls.FIELDS is not None, f"FIELDS in {cls.__name__} don't set!"
assert len(cls.FIELDS) == 0, f"FIELDS in {cls.__name__} must be not empty!"
if len(cls.FIELDS) == 1:
return cls.FIELDS
combinations = []
for i in range(1, len(cls.FIELDS)):
combinations += permutations(cls.FIELDS, i)
return combinations
@classmethod
def get_similarity_subquery(cls, query: str):
return func.greatest(
*[func.similarity(join_fields(comb), f"{query}::text") for comb in cls.fields_combinations]
).label("sml")
@classmethod
def get_object_ids_query(cls, query: str):
similarity = cls.get_similarity_subquery(query)
params = cls.get_params()
return select(
[cls.table.c.id],
).where(
similarity > 0.5
).order_by(
desc(similarity)
).limit(params.limit).offset(params.offset)
@classmethod
def get_objects_count_query(cls, query: str):
similarity = cls.get_similarity_subquery(query)
return select(
func.count(cls.table.c.id)
).where(
similarity > 0.5
)
@classmethod
async def get_objects_count(cls, query: str) -> int:
count_query = cls.get_objects_count_query(query)
count_row = await cls.database.fetch_one(count_query)
assert count_row is not None
return cast(int, count_row.get("count_1"))
@classmethod
async def get_objects(cls, query: str) -> list[T]:
ids_query = cls.get_object_ids_query(query)
ids = await cls.database.fetch_all(ids_query)
queryset: QuerySet[T] = cls.model.objects
if cls.PREFETCH_RELATED is not None:
queryset = queryset.prefetch_related(cls.PREFETCH_RELATED)
if cls.SELECT_RELATED:
queryset = queryset.select_related(cls.SELECT_RELATED)
return await queryset.filter(id__in=[r.get("id") for r in ids]).all()
@classmethod
async def get(cls, query: str) -> CustomPage[T]:
params = cls.get_params()
authors = await cls.get_objects(query)
total = await cls.get_objects_count(query)
return CustomPage(
items=authors,
total=total,
limit=params.limit,
offset=params.offset
)

View File

@@ -0,0 +1,11 @@
from app.models import Sequence
from app.services.common import TRGMSearchService
class SequenceTGRMSearchService(TRGMSearchService):
MODEL = Sequence
FIELDS = [
Sequence.Meta.table.c.name
]
PREFETCH_RELATED = ["source"]

View File

@@ -0,0 +1,46 @@
from typing import Union
from fastapi import HTTPException, status
from app.models import SequenceInfo as SequenceInfoDB, Source as SourceDB, Book as BookDB, Sequence as SequenceDB
from app.serializers.sequence_info import CreateSequenceInfo, CreateRemoteSequenceInfo
class SequenceInfoCreator:
@classmethod
def _raise_bad_request(cls):
raise HTTPException(status.HTTP_404_NOT_FOUND)
@classmethod
async def _create_sequence_info(cls, data: CreateSequenceInfo) -> SequenceInfoDB:
return await SequenceInfoDB.objects.create(**data.dict())
@classmethod
async def _create_remote_sequence_info(cls, data: CreateRemoteSequenceInfo) -> SequenceInfoDB:
source = await SourceDB.objects.get_or_none(id=data.source)
if source is None:
cls._raise_bad_request()
book = await BookDB.objects.get_or_none(source__id=source.id, remote_id=data.remote_book)
if book is None:
cls._raise_bad_request()
sequence = await SequenceDB.objects.get_or_none(source__id=source.id, remote_id=data.remote_sequence)
if sequence is None:
cls._raise_bad_request()
return await SequenceInfoDB.objects.create(
book=book.id,
sequence=sequence.id,
position=data.position,
)
@classmethod
async def create(cls, data: Union[CreateSequenceInfo, CreateRemoteSequenceInfo]) -> SequenceInfoDB:
if isinstance(data, CreateSequenceInfo):
return await cls._create_sequence_info(data)
if isinstance(data, CreateRemoteSequenceInfo):
return await cls._create_remote_sequence_info(data)

View File

@@ -0,0 +1,49 @@
from typing import Union
from fastapi import HTTPException, status
from app.serializers.translation import CreateTranslation, CreateRemoteTranslation
from app.models import Translation as TranslationDB, Source as SourceDB, Book as BookDB, Author as AuthorDB
class TranslationCreator:
@classmethod
def _raise_bad_request(cls):
raise HTTPException(status.HTTP_404_NOT_FOUND)
@classmethod
async def _create_translation(cls, data: CreateTranslation) -> TranslationDB:
return await TranslationDB.objects.create(
**data.dict()
)
@classmethod
async def _create_remote_translation(cls, data: CreateRemoteTranslation) -> TranslationDB:
source = await SourceDB.objects.get_or_none(id=data.source)
if source is None:
cls._raise_bad_request()
book = await BookDB.objects.get_or_none(source__id=source.id, remote_id=data.remote_book)
if book is None:
cls._raise_bad_request()
translator = await AuthorDB.objects.get_or_none(source__id=source.id, remote_id=data.remote_translator)
if translator is None:
cls._raise_bad_request()
return await TranslationDB.objects.create(
book=book.id,
translator=translator.id,
position=data.position,
)
@classmethod
async def create(cls, data: Union[CreateTranslation, CreateRemoteTranslation]) -> TranslationDB:
if isinstance(data, CreateTranslation):
return await cls._create_translation(data)
if isinstance(data, CreateRemoteTranslation):
return await cls._create_remote_translation(data)