Rewrite to Rust
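Replaces the Python service (FastAPI + ormar with Alembic migrations and a Taskiq/Redis worker for background caching) with a Rust implementation: axum for the HTTP API, Prisma Client Rust for storage, reqwest for the downloader, library and telegram-files backends, and axum-prometheus plus tracing for metrics and logging.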

This commit is contained in:
2023-08-09 01:33:30 +02:00
parent 1d1cd63e7b
commit bbaa343547
54 changed files with 12525 additions and 2917 deletions

BIN src/.DS_Store vendored Normal file

Binary file not shown.

View File

@@ -1,98 +0,0 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = ./app/alembic
# template used to generate migration files
# file_template = %%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python-dateutil library that can be
# installed by adding `alembic[tz]` to the pip requirements
# string value is passed to dateutil.tz.gettz()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator"
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. Valid values are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # default: use os.pathsep
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@@ -1 +0,0 @@
Generic single-database configuration.

View File

@@ -1,67 +0,0 @@
import os
import sys
from alembic import context
from sqlalchemy.engine import create_engine
from core.db import DATABASE_URL
myPath = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, myPath + "/../../")
config = context.config
from app.models import BaseMeta
target_metadata = BaseMeta.metadata
def run_migrations_offline():
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online():
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = create_engine(DATABASE_URL)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata, compare_type=True
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@@ -1,24 +0,0 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}
def upgrade():
${upgrades if upgrades else "pass"}
def downgrade():
${downgrades if downgrades else "pass"}

View File

@@ -1,32 +0,0 @@
"""empty message
Revision ID: 62d57916ec53
Revises: f77b0b14f9eb
Create Date: 2022-12-30 23:30:50.867163
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "62d57916ec53"
down_revision = "f77b0b14f9eb"
branch_labels = None
depends_on = None
def upgrade():
op.drop_column("cached_files", "data")
op.create_unique_constraint(
"uc_cached_files_message_id_chat_id", "cached_files", ["message_id", "chat_id"]
)
op.create_index(
op.f("ix_cached_files_message_id"), "cached_files", ["message_id"], unique=True
)
def downgrade():
op.add_column("cached_files", sa.Column("data", sa.JSON(), nullable=False))
op.drop_constraint("uc_cached_files_message_id_chat_id", "cached_files")
op.drop_index("ix_cached_files_message_id", "cached_files")

View File

@@ -1,49 +0,0 @@
"""empty message
Revision ID: 9b7cfb422191
Revises:
Create Date: 2021-11-21 14:09:17.478532
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "9b7cfb422191"
down_revision = None
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"cached_files",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("object_id", sa.Integer(), nullable=False),
sa.Column("object_type", sa.String(length=8), nullable=False),
sa.Column("data", sa.JSON(), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint(
"object_id", "object_type", name="uc_cached_files_object_id_object_type"
),
)
op.create_index(
op.f("ix_cached_files_object_id"), "cached_files", ["object_id"], unique=False
)
op.create_index(
op.f("ix_cached_files_object_type"),
"cached_files",
["object_type"],
unique=False,
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("ix_cached_files_object_id", "cached_files")
op.drop_index("ix_cached_files_object_type", "cached_files")
op.drop_table("cached_files")
# ### end Alembic commands ###

View File

@@ -1,28 +0,0 @@
"""empty message
Revision ID: f77b0b14f9eb
Revises: 9b7cfb422191
Create Date: 2022-12-30 22:53:41.951490
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "f77b0b14f9eb"
down_revision = "9b7cfb422191"
branch_labels = None
depends_on = None
def upgrade():
op.add_column(
"cached_files", sa.Column("message_id", sa.BigInteger(), nullable=True)
)
op.add_column("cached_files", sa.Column("chat_id", sa.BigInteger(), nullable=True))
def downgrade():
op.drop_column("cached_files", "message_id")
op.drop_column("cached_files", "chat_id")

View File

@@ -1,18 +0,0 @@
from fastapi import HTTPException, Request, Security, status
from redis.asyncio import ConnectionPool
from taskiq import TaskiqDepends
from core.auth import default_security
from core.config import env_config
async def check_token(api_key: str = Security(default_security)):
if api_key != env_config.API_KEY:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Wrong api key!"
)
def get_redis_pool(request: Request = TaskiqDepends()) -> ConnectionPool:
return request.app.state.redis_pool

View File

@@ -1,30 +0,0 @@
import ormar
from core.db import database, metadata
class BaseMeta(ormar.ModelMeta):
metadata = metadata
database = database
class CachedFile(ormar.Model):
class Meta(BaseMeta):
tablename = "cached_files"
constraints = [
ormar.UniqueColumns("object_id", "object_type"),
ormar.UniqueColumns("message_id", "chat_id"),
]
id: int = ormar.Integer(primary_key=True) # type: ignore
object_id: int = ormar.Integer(index=True) # type: ignore
object_type: str = ormar.String(
max_length=8, index=True
) # type: ignore
message_id: int = ormar.BigInteger(index=True) # type: ignore
chat_id: int = ormar.BigInteger() # type: ignore
@ormar.property_field
def data(self) -> dict:
return {"message_id": self.message_id, "chat_id": self.chat_id}

View File

@@ -1,14 +0,0 @@
from pydantic import BaseModel, constr
class CachedFile(BaseModel):
id: int
object_id: int
object_type: str
data: dict
class CreateCachedFile(BaseModel):
object_id: int
object_type: constr(max_length=8) # type: ignore
data: dict

View File

@@ -1,167 +0,0 @@
import collections
from datetime import date, timedelta
from io import BytesIO
from typing import Optional
from fastapi import UploadFile
import httpx
from redis.asyncio import ConnectionPool, Redis
from taskiq import TaskiqDepends
from app.depends import get_redis_pool
from app.models import CachedFile
from app.services.caption_getter import get_caption
from app.services.downloader import download
from app.services.files_client import upload_file
from app.services.library_client import Book, get_book, get_books
from core.taskiq_worker import broker
PAGE_SIZE = 100
class Retry(Exception):
pass
class FileTypeNotAllowed(Exception):
def __init__(self, message: str) -> None:
super().__init__(message)
@broker.task
async def check_books_page(
page_number: int, uploaded_gte: str, uploaded_lte: str
) -> bool:
page = await get_books(
page_number,
PAGE_SIZE,
uploaded_gte=date.fromisoformat(uploaded_gte),
uploaded_lte=date.fromisoformat(uploaded_lte),
)
object_ids = [book.id for book in page.items]
cached_files = await CachedFile.objects.filter(object_id__in=object_ids).all()
cached_files_map = collections.defaultdict(set)
for cached_file in cached_files:
cached_files_map[cached_file.object_id].add(cached_file.object_type)
for book in page.items:
for file_type in book.available_types:
if file_type not in cached_files_map[book.id]:
await cache_file_by_book_id.kiq(
book_id=book.id,
file_type=file_type,
by_request=False,
)
return True
@broker.task
async def check_books(*args, **kwargs) -> bool:
uploaded_lte = date.today() + timedelta(days=1)
uploaded_gte = date.today() - timedelta(days=1)
books_page = await get_books(
1, PAGE_SIZE, uploaded_gte=uploaded_gte, uploaded_lte=uploaded_lte
)
for page_number in range(1, books_page.pages + 1):
await check_books_page.kiq(
page_number,
uploaded_gte=uploaded_gte.isoformat(),
uploaded_lte=uploaded_lte.isoformat(),
)
return True
async def cache_file(book: Book, file_type: str) -> Optional[CachedFile]:
if await CachedFile.objects.filter(
object_id=book.id, object_type=file_type
).exists():
return
try:
data = await download(book.source.id, book.remote_id, file_type)
except httpx.HTTPError:
data = None
if data is None:
raise Retry
response, client, filename = data
caption = get_caption(book)
temp_file = UploadFile(BytesIO(), filename=filename)
async for chunk in response.aiter_bytes(2048):
await temp_file.write(chunk)
file_size = temp_file.file.tell()
await temp_file.seek(0)
await response.aclose()
await client.aclose()
upload_data = await upload_file(temp_file.file, file_size, filename, caption)
if upload_data is None:
return None
cached_file, created = await CachedFile.objects.get_or_create(
{
"message_id": upload_data.data["message_id"],
"chat_id": upload_data.data["chat_id"],
},
object_id=book.id,
object_type=file_type,
)
if created:
return cached_file
cached_file.message_id = upload_data.data["message_id"]
cached_file.chat_id = upload_data.data["chat_id"]
return await cached_file.update(["message_id", "chat_id"])
@broker.task(retry_on_error=True)
async def cache_file_by_book_id(
book_id: int,
file_type: str,
by_request: bool = True,
redis_pool: ConnectionPool = TaskiqDepends(get_redis_pool),
) -> Optional[CachedFile]:
book = await get_book(book_id, 3)
if book is None:
if by_request:
return None
raise Retry
if file_type not in book.available_types:
return None
async with Redis(connection_pool=redis_pool) as redis_client:
lock = redis_client.lock(
f"{book_id}_{file_type}", blocking_timeout=5, thread_local=False
)
if await lock.locked() and not by_request:
raise Retry
try:
result = await cache_file(book, file_type)
except Retry as e:
if by_request:
return None
raise e
if by_request:
return result
return None

View File

@@ -1,41 +0,0 @@
from app.services.library_client import Book, BookAuthor
def get_author_string(author: BookAuthor) -> str:
author_parts = []
if author.last_name:
author_parts.append(author.last_name)
if author.first_name:
author_parts.append(author.first_name)
if author.middle_name:
author_parts.append(author.middle_name)
return " ".join(author_parts)
def get_caption(book: Book) -> str:
caption_title = f"📖 {book.title}"
caption_title_length = len(caption_title) + 3
caption_authors_parts = []
authors_caption_length = 0
for author in book.authors:
author_caption = f"👤 {get_author_string(author)}"
if (
caption_title_length + authors_caption_length + len(author_caption) + 3
) <= 1024:
caption_authors_parts.append(author_caption)
authors_caption_length += len(author_caption) + 3
else:
break
if not caption_authors_parts:
return caption_title
caption_authors = "\n".join(caption_authors_parts)
return caption_title + "\n\n" + caption_authors

View File

@@ -1,57 +0,0 @@
from base64 import b64decode
from typing import Optional
import httpx
from sentry_sdk import capture_exception
from core.config import env_config
async def download(
source_id: int, remote_id: int, file_type: str
) -> Optional[tuple[httpx.Response, httpx.AsyncClient, str]]:
headers = {"Authorization": env_config.DOWNLOADER_API_KEY}
client = httpx.AsyncClient(timeout=600)
request = client.build_request(
"GET",
f"{env_config.DOWNLOADER_URL}/download/{source_id}/{remote_id}/{file_type}",
headers=headers,
)
try:
response = await client.send(request, stream=True)
except httpx.ConnectError:
await client.aclose()
return None
if response.status_code != 200:
await response.aclose()
await client.aclose()
return None
name = b64decode(response.headers["x-filename-b64"]).decode()
return response, client, name
async def get_filename(book_id: int, file_type: str) -> Optional[tuple[str, str]]:
headers = {"Authorization": env_config.DOWNLOADER_API_KEY}
try:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{env_config.DOWNLOADER_URL}/filename/{book_id}/{file_type}",
headers=headers,
timeout=5 * 60,
)
if response.status_code != 200:
return None
data = response.json()
return data["filename"], data["filename_ascii"]
except httpx.HTTPError as e:
capture_exception(e)
return None

View File

@@ -1,67 +0,0 @@
from typing import BinaryIO, Optional
import httpx
from pydantic import BaseModel
from typing_extensions import TypedDict
from core.config import env_config
class Data(TypedDict):
chat_id: int
message_id: int
class UploadedFile(BaseModel):
backend: str
data: Data
async def upload_file(
content: BinaryIO, content_size: int, filename: str, caption: str
) -> Optional[UploadedFile]:
headers = {"Authorization": env_config.FILES_SERVER_API_KEY}
async with httpx.AsyncClient() as client:
form = {"caption": caption, "file_size": content_size}
files = {"file": (filename, content)}
response = await client.post(
f"{env_config.FILES_SERVER_URL}/api/v1/files/upload/",
data=form,
files=files,
headers=headers,
timeout=5 * 60,
)
if response.status_code != 200:
return None
return UploadedFile.parse_obj(response.json())
async def download_file(
chat_id: int, message_id: int
) -> Optional[tuple[httpx.Response, httpx.AsyncClient]]:
headers = {"Authorization": env_config.FILES_SERVER_API_KEY}
client = httpx.AsyncClient(timeout=60)
request = client.build_request(
"GET",
f"{env_config.FILES_SERVER_URL}"
f"/api/v1/files/download_by_message/{chat_id}/{message_id}",
headers=headers,
)
try:
response = await client.send(request, stream=True)
except httpx.ConnectError:
await client.aclose()
return None
if response.status_code != 200:
await response.aclose()
await client.aclose()
return None
return response, client

View File

@@ -1,123 +0,0 @@
from datetime import date
from typing import Generic, Optional, TypeVar
from urllib.parse import urlencode
import httpx
from pydantic import BaseModel
from sentry_sdk import capture_exception
from core.config import env_config
T = TypeVar("T")
class Page(BaseModel, Generic[T]):
items: list[T]
total: int
size: int
page: int
pages: int
class BaseBookInfo(BaseModel):
id: int
available_types: list[str]
class BookAuthor(BaseModel):
id: int
first_name: str
last_name: str
middle_name: str
class BookSource(BaseModel):
id: int
name: str
class Book(BaseModel):
id: int
title: str
file_type: str
available_types: list[str]
source: BookSource
remote_id: int
uploaded: date
authors: list[BookAuthor]
class BookDetail(Book):
is_deleted: bool
AUTH_HEADERS = {"Authorization": env_config.LIBRARY_API_KEY}
async def get_book(
book_id: int, retry: int = 3, last_exp: Exception | None = None
) -> Optional[BookDetail]:
if retry == 0:
if last_exp:
capture_exception(last_exp)
return None
try:
async with httpx.AsyncClient(timeout=2 * 60) as client:
response = await client.get(
f"{env_config.LIBRARY_URL}/api/v1/books/{book_id}", headers=AUTH_HEADERS
)
if response.status_code != 200:
return None
return BookDetail.parse_obj(response.json())
except httpx.HTTPError as e:
return await get_book(book_id, retry=retry - 1, last_exp=e)
async def get_books(
page: int,
page_size: int,
uploaded_gte: date | None = None,
uploaded_lte: date | None = None,
) -> Page[BaseBookInfo]:
params: dict[str, str] = {
"page": str(page),
"page_size": str(page_size),
"is_deleted": "false",
}
if uploaded_gte:
params["uploaded_gte"] = uploaded_gte.isoformat()
if uploaded_lte:
params["uploaded_lte"] = uploaded_lte.isoformat()
params_string = urlencode(params)
async with httpx.AsyncClient(timeout=5 * 60) as client:
response = await client.get(
f"{env_config.LIBRARY_URL}/api/v1/books/base/?{params_string}",
headers=AUTH_HEADERS,
)
data = response.json()
page_data = Page[BaseBookInfo].parse_obj(data)
page_data.items = [BaseBookInfo.parse_obj(item) for item in page_data.items]
return page_data
async def get_last_book_id() -> int:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{env_config.LIBRARY_URL}/api/v1/books/last", headers=AUTH_HEADERS
)
return int(response.text)

View File

@@ -1,24 +0,0 @@
from fastapi import HTTPException, status
from redis.asyncio import ConnectionPool
from app.models import CachedFile as CachedFileDB
from app.services.cache_updater import cache_file_by_book_id
async def get_cached_file_or_cache(
object_id: int, object_type: str, connection_pool: ConnectionPool
) -> CachedFileDB:
cached_file = await CachedFileDB.objects.get_or_none(
object_id=object_id, object_type=object_type
)
if not cached_file:
cached_file = await cache_file_by_book_id(
object_id, object_type, redis_pool=connection_pool
)
if not cached_file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return cached_file

View File

@@ -1,145 +0,0 @@
from base64 import b64encode
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import StreamingResponse
from starlette.background import BackgroundTask
from redis.asyncio import ConnectionPool
from app.depends import check_token, get_redis_pool
from app.models import CachedFile as CachedFileDB
from app.serializers import CachedFile, CreateCachedFile
from app.services.cache_updater import cache_file_by_book_id, check_books
from app.services.caption_getter import get_caption
from app.services.downloader import get_filename
from app.services.files_client import download_file as download_file_from_cache
from app.services.library_client import get_book
from app.utils import get_cached_file_or_cache
router = APIRouter(
prefix="/api/v1", tags=["files"], dependencies=[Depends(check_token)]
)
@router.get("/{object_id}/{object_type}", response_model=CachedFile)
async def get_cached_file(
redis_pool: Annotated[ConnectionPool, Depends(get_redis_pool)],
object_id: int,
object_type: str,
):
cached_file = await CachedFileDB.objects.get_or_none(
object_id=object_id, object_type=object_type
)
if not cached_file:
cached_file = await cache_file_by_book_id(
object_id, object_type, by_request=True, redis_pool=redis_pool
)
if not cached_file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return cached_file
@router.get("/download/{object_id}/{object_type}")
async def download_cached_file(request: Request, object_id: int, object_type: str):
cached_file = await get_cached_file_or_cache(
object_id, object_type, request.app.state.redis_pool
)
cache_data: dict = cached_file.data # type: ignore
data = await download_file_from_cache(
cache_data["chat_id"], cache_data["message_id"]
)
if data is None:
await CachedFileDB.objects.filter(id=cached_file.id).delete()
cached_file = await get_cached_file_or_cache(
object_id, object_type, request.app.state.redis_pool
)
cache_data: dict = cached_file.data # type: ignore
data = await download_file_from_cache(
cache_data["chat_id"], cache_data["message_id"]
)
if data is None:
raise HTTPException(status_code=status.HTTP_204_NO_CONTENT)
if (filename_data := await get_filename(object_id, object_type)) is None:
raise HTTPException(status_code=status.HTTP_204_NO_CONTENT)
if (book := await get_book(object_id)) is None:
raise HTTPException(status_code=status.HTTP_204_NO_CONTENT)
response, client = data
async def close():
await response.aclose()
await client.aclose()
filename, filename_ascii = filename_data
return StreamingResponse(
response.aiter_bytes(),
headers={
"Content-Disposition": f"attachment; filename={filename_ascii}",
"X-Caption-B64": b64encode(get_caption(book).encode("utf-8")).decode(),
"X-Filename-B64": b64encode(filename.encode("utf-8")).decode(),
},
background=BackgroundTask(close),
)
@router.delete("/{object_id}/{object_type}", response_model=CachedFile)
async def delete_cached_file(object_id: int, object_type: str):
cached_file = await CachedFileDB.objects.get_or_none(
object_id=object_id, object_type=object_type
)
if not cached_file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
await cached_file.delete()
return cached_file
@router.post("/", response_model=CachedFile)
async def create_or_update_cached_file(data: CreateCachedFile):
cached_file = await CachedFileDB.objects.get_or_none(
object_id=data.object_id, object_type=data.object_type
)
if cached_file is not None:
cached_file.message_id = data.data["message_id"]
cached_file.chat_id = data.data["chat_id"]
return await cached_file.update()
return await CachedFileDB.objects.create(
object_id=data.object_id,
object_type=data.object_type,
message_id=data.data["message_id"],
chat_id=data.data["chat_id"],
)
@router.post("/update_cache")
async def update_cache(request: Request):
await check_books.kiq()
return "Ok!"
healthcheck_router = APIRouter(
tags=["healthcheck"],
)
@healthcheck_router.get("/healthcheck")
async def healthcheck():
return "Ok!"

66 src/config.rs Normal file
View File

@@ -0,0 +1,66 @@
use once_cell::sync::Lazy;
pub struct Config {
pub api_key: String,
pub postgres_user: String,
pub postgres_password: String,
pub postgres_host: String,
pub postgres_port: u32,
pub postgres_db: String,
pub downloader_api_key: String,
pub downloader_url: String,
pub library_api_key: String,
pub library_url: String,
pub files_api_key: String,
pub files_url: String,
pub telegram_api_id: i32,
pub telegram_api_hash: String,
pub telegram_bot_tokens: Vec<String>,
pub sentry_dsn: String
}
fn get_env(env: &'static str) -> String {
std::env::var(env).unwrap_or_else(|_| panic!("Cannot get the {} env variable", env))
}
impl Config {
pub fn load() -> Config {
Config {
api_key: get_env("API_KEY"),
postgres_user: get_env("POSTGRES_USER"),
postgres_password: get_env("POSTGRES_PASSWORD"),
postgres_host: get_env("POSTGRES_HOST"),
postgres_port: get_env("POSTGRES_PORT").parse().unwrap(),
postgres_db: get_env("POSTGRES_DB"),
downloader_api_key: get_env("DOWNLOADER_API_KEY"),
downloader_url: get_env("DOWNLOADER_URL"),
library_api_key: get_env("LIBRARY_API_KEY"),
library_url: get_env("LIBRARY_URL"),
files_api_key: get_env("FILES_SERVER_API_KEY"),
files_url: get_env("FILES_SERVER_URL"),
telegram_api_id: get_env("TELEGRAM_API_ID").parse().unwrap(),
telegram_api_hash: get_env("TELEGRAM_API_HASH"),
telegram_bot_tokens: serde_json::from_str(&get_env("TELEGRAM_BOT_TOKENS")).unwrap(),
sentry_dsn: get_env("SENTRY_DSN")
}
}
}
pub static CONFIG: Lazy<Config> = Lazy::new(|| {
Config::load()
});
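Note on the expected environment (an assumption based on the serde_json::from_str call above): TELEGRAM_BOT_TOKENS has to hold a JSON array of strings so it can deserialize into Vec<String>. A minimal sketch with made-up tokens:

// Hypothetical value, e.g. TELEGRAM_BOT_TOKENS='["111111:aaa","222222:bbb"]'
let tokens: Vec<String> = serde_json::from_str(r#"["111111:aaa","222222:bbb"]"#).unwrap();
assert_eq!(tokens.len(), 2);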

View File

@@ -1,44 +0,0 @@
from fastapi import FastAPI
from fastapi.responses import ORJSONResponse
from prometheus_fastapi_instrumentator import Instrumentator
from redis.asyncio import ConnectionPool
from app.views import healthcheck_router, router
from core.config import REDIS_URL
from core.db import database
from core.taskiq_worker import broker
def start_app() -> FastAPI:
app = FastAPI(default_response_class=ORJSONResponse)
app.state.redis_pool = ConnectionPool.from_url(REDIS_URL)
app.include_router(router)
app.include_router(healthcheck_router)
@app.on_event("startup")
async def app_startup():
if not database.is_connected:
await database.connect()
if not broker.is_worker_process:
await broker.startup()
@app.on_event("shutdown")
async def app_shutdown():
if database.is_connected:
await database.disconnect()
if not broker.is_worker_process:
await broker.shutdown()
await app.state.redis_pool.disconnect()
Instrumentator(
should_ignore_untemplated=True,
excluded_handlers=["/docs", "/metrics", "/healthcheck"],
).instrument(app).expose(app, include_in_schema=True)
return app

View File

@@ -1,4 +0,0 @@
from fastapi.security import APIKeyHeader
default_security = APIKeyHeader(name="Authorization")

View File

@@ -1,33 +0,0 @@
from pydantic import BaseSettings
class EnvConfig(BaseSettings):
API_KEY: str
POSTGRES_USER: str
POSTGRES_PASSWORD: str
POSTGRES_HOST: str
POSTGRES_PORT: int
POSTGRES_DB: str
DOWNLOADER_API_KEY: str
DOWNLOADER_URL: str
LIBRARY_API_KEY: str
LIBRARY_URL: str
FILES_SERVER_API_KEY: str
FILES_SERVER_URL: str
REDIS_HOST: str
REDIS_PORT: int
REDIS_DB: int
SENTRY_DSN: str
env_config = EnvConfig() # type: ignore
REDIS_URL = (
f"redis://{env_config.REDIS_HOST}:{env_config.REDIS_PORT}/{env_config.REDIS_DB}"
)

View File

@@ -1,15 +0,0 @@
from urllib.parse import quote
from databases import Database
from sqlalchemy import MetaData
from core.config import env_config
DATABASE_URL = (
f"postgresql://{env_config.POSTGRES_USER}:{quote(env_config.POSTGRES_PASSWORD)}@"
f"{env_config.POSTGRES_HOST}:{env_config.POSTGRES_PORT}/{env_config.POSTGRES_DB}"
)
metadata = MetaData()
database = Database(DATABASE_URL, min_size=1, max_size=10)

View File

@@ -1,40 +0,0 @@
from inspect import signature
from typing import Any
from taskiq import SimpleRetryMiddleware
from taskiq.message import TaskiqMessage
from taskiq.result import TaskiqResult
from taskiq_dependencies.dependency import Dependency
class FastAPIRetryMiddleware(SimpleRetryMiddleware):
@staticmethod
def _remove_depends(
task_func: Any, message_kwargs: dict[str, Any]
) -> dict[str, Any]:
sig = signature(task_func)
keys_to_remove = []
for key in message_kwargs.keys():
param = sig.parameters.get(key, None)
if param is None:
continue
if isinstance(param.default, Dependency):
keys_to_remove.append(key)
for key in keys_to_remove:
message_kwargs.pop(key)
return message_kwargs
async def on_error(
self, message: TaskiqMessage, result: TaskiqResult[Any], exception: Exception
) -> None:
task_func = self.broker.available_tasks[message.task_name].original_func
message.kwargs = self._remove_depends(task_func, message.kwargs)
return await super().on_error(message, result, exception)

View File

@@ -1,17 +0,0 @@
import taskiq_fastapi
from taskiq_redis import ListQueueBroker, RedisAsyncResultBackend
from core.config import REDIS_URL
from core.taskiq_middlewares import FastAPIRetryMiddleware
broker = (
ListQueueBroker(url=REDIS_URL)
.with_result_backend(
RedisAsyncResultBackend(redis_url=REDIS_URL, result_ex_time=5 * 60)
)
.with_middlewares(FastAPIRetryMiddleware())
)
taskiq_fastapi.init(broker, "main:app")

19 src/db.rs Normal file
View File

@@ -0,0 +1,19 @@
use crate::{prisma::PrismaClient, config::CONFIG};
pub async fn get_prisma_client() -> PrismaClient {
let database_url: String = format!(
"postgresql://{}:{}@{}:{}/{}?connection_limit=1",
CONFIG.postgres_user,
CONFIG.postgres_password,
CONFIG.postgres_host,
CONFIG.postgres_port,
CONFIG.postgres_db
);
PrismaClient::_builder()
.with_url(database_url)
.build()
.await
.unwrap()
}

View File

@@ -1,11 +0,0 @@
import sentry_sdk
from app.services.cache_updater import Retry
from core.app import start_app
from core.config import env_config
if env_config.SENTRY_DSN:
sentry_sdk.init(dsn=env_config.SENTRY_DSN, ignore_errors=[Retry])
app = start_app()

30 src/main.rs Normal file
View File

@@ -0,0 +1,30 @@
pub mod config;
pub mod db;
pub mod prisma;
pub mod views;
pub mod services;
use std::net::SocketAddr;
use tracing::info;
use crate::views::get_router;
#[tokio::main]
async fn main() {
tracing_subscriber::fmt()
.with_target(false)
.compact()
.init();
let addr = SocketAddr::from(([0, 0, 0, 0], 8080));
let app = get_router().await;
info!("Start webserver...");
axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.unwrap();
info!("Webserver shutdown...")
}

1360 src/prisma.rs Normal file

File diff suppressed because one or more lines are too long

BIN src/services/.DS_Store vendored Normal file

Binary file not shown.

View File

@@ -0,0 +1,49 @@
pub mod types;
use serde::de::DeserializeOwned;
use crate::config::CONFIG;
async fn _make_request<T>(
url: &str,
params: Vec<(&str, String)>,
) -> Result<T, Box<dyn std::error::Error + Send + Sync>>
where
T: DeserializeOwned,
{
let client = reqwest::Client::new();
let formatted_url = format!("{}{}", CONFIG.library_url, url);
let response = client
.get(formatted_url)
.query(&params)
.header("Authorization", CONFIG.library_api_key.clone())
.send()
.await;
let response = match response {
Ok(v) => v,
Err(err) => return Err(Box::new(err)),
};
let response = match response.error_for_status() {
Ok(v) => v,
Err(err) => return Err(Box::new(err)),
};
match response.json::<T>().await {
Ok(v) => Ok(v),
Err(err) => Err(Box::new(err)),
}
}
pub async fn get_sources() -> Result<types::Source, Box<dyn std::error::Error + Send + Sync>> {
_make_request("/api/v1/sources", vec![]).await
}
pub async fn get_book(
book_id: i32,
) -> Result<types::BookWithRemote, Box<dyn std::error::Error + Send + Sync>> {
_make_request(format!("/api/v1/books/{book_id}").as_str(), vec![]).await
}

View File

@@ -0,0 +1,108 @@
use serde::Deserialize;
#[derive(Deserialize, Debug, Clone)]
pub struct Source {
// id: u32,
// name: String
}
#[derive(Deserialize, Debug, Clone)]
pub struct BookAuthor {
pub id: u32,
pub first_name: String,
pub last_name: String,
pub middle_name: String,
}
#[derive(Deserialize, Debug, Clone)]
pub struct Book {
pub id: u32,
pub title: String,
pub lang: String,
pub file_type: String,
pub uploaded: String,
pub authors: Vec<BookAuthor>,
}
#[derive(Deserialize, Debug, Clone)]
pub struct BookWithRemote {
pub id: u32,
pub remote_id: u32,
pub title: String,
pub lang: String,
pub file_type: String,
pub uploaded: String,
pub authors: Vec<BookAuthor>,
}
impl BookWithRemote {
pub fn from_book(book: Book, remote_id: u32) -> Self {
Self {
id: book.id,
remote_id,
title: book.title,
lang: book.lang,
file_type: book.file_type,
uploaded: book.uploaded,
authors: book.authors
}
}
}
impl BookAuthor {
pub fn get_caption(self) -> String {
let mut parts: Vec<String> = vec![];
if !self.last_name.is_empty() {
parts.push(self.last_name);
}
if !self.first_name.is_empty() {
parts.push(self.first_name);
}
if !self.middle_name.is_empty() {
parts.push(self.middle_name);
}
let joined_parts = parts.join(" ");
format!("👤 {joined_parts}")
}
}
impl BookWithRemote {
pub fn get_caption(self) -> String {
let BookWithRemote {
title,
authors,
..
} = self;
let caption_title = format!("📖 {title}");
let author_captions: Vec<String> = authors
.into_iter()
.map(|a| a.get_caption())
.collect();
let mut author_parts: Vec<String> = vec![];
let mut author_parts_len = 3;
for author_caption in author_captions {
// Telegram captions are capped at 1024 characters; add authors only while the caption still fits.
if caption_title.len() + author_parts_len + author_caption.len() + 1 <= 1024 {
author_parts_len += author_caption.len() + 1;
author_parts.push(author_caption);
} else {
break;
}
}
let caption_authors = author_parts.join("\n");
format!("{caption_title}\n\n{caption_authors}")
}
}
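A quick illustrative check of the caption format built above, with made-up author data (not part of the commit), assuming the BookAuthor fields shown in this file:

#[test]
fn author_caption_format() {
    let author = BookAuthor {
        id: 1,
        first_name: "Лев".to_string(),
        last_name: "Толстой".to_string(),
        middle_name: "Николаевич".to_string(),
    };
    // Name parts are joined as "last first middle" and prefixed with the person emoji.
    assert_eq!(author.get_caption(), "👤 Толстой Лев Николаевич");
}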

View File

@@ -0,0 +1,19 @@
use futures::TryStreamExt;
use reqwest::Response;
use tokio::io::AsyncRead;
use tokio_util::compat::FuturesAsyncReadCompatExt;
pub struct DownloadResult {
pub response: Response,
pub filename: String,
pub filename_ascii: String,
pub caption: String,
}
pub fn get_response_async_read(it: Response) -> impl AsyncRead {
it.bytes_stream()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
.into_async_read()
.compat()
}

View File

@@ -0,0 +1,57 @@
use reqwest::Response;
use serde::Deserialize;
use crate::config::CONFIG;
#[derive(Deserialize)]
pub struct FilenameData {
pub filename: String,
pub filename_ascii: String
}
pub async fn download_from_downloader(
remote_id: u32,
object_id: i32,
object_type: String
) -> Result<Response, Box<dyn std::error::Error + Send + Sync>> {
let url = format!(
"{}/download/{remote_id}/{object_id}/{object_type}",
CONFIG.downloader_url
);
let response = reqwest::Client::new()
.get(url)
.header("Authorization", &CONFIG.downloader_api_key)
.send()
.await?
.error_for_status()?;
Ok(response)
}
pub async fn get_filename(
object_id: i32,
object_type: String
) -> Result<FilenameData, Box<dyn std::error::Error + Send + Sync>> {
let url = format!(
"{}/filename/{object_id}/{object_type}",
CONFIG.downloader_url
);
let response = reqwest::Client::new()
.get(url)
.header("Authorization", &CONFIG.downloader_api_key)
.send()
.await?
.error_for_status()?;
match response.json::<FilenameData>().await {
Ok(v) => Ok(v),
Err(err) => {
Err(Box::new(err))
},
}
}

124 src/services/mod.rs Normal file
View File

@@ -0,0 +1,124 @@
pub mod book_library;
pub mod download_utils;
pub mod telegram_files;
pub mod downloader;
use tracing::log;
use crate::{prisma::cached_file, views::Database};
use self::{download_utils::DownloadResult, telegram_files::{download_from_telegram_files, UploadData, upload_to_telegram_files}, downloader::{get_filename, FilenameData, download_from_downloader}, book_library::get_book};
pub async fn get_cached_file_or_cache(
object_id: i32,
object_type: String,
db: Database
) -> Option<cached_file::Data> {
let cached_file = db.cached_file()
.find_unique(cached_file::object_id_object_type(object_id, object_type.clone()))
.exec()
.await
.unwrap();
match cached_file {
Some(cached_file) => Some(cached_file),
None => cache_file(object_id, object_type, db).await,
}
}
pub async fn cache_file(
object_id: i32,
object_type: String,
db: Database
) -> Option<cached_file::Data> {
let book = match get_book(object_id).await {
Ok(v) => v,
Err(err) => {
log::error!("{:?}", err);
return None;
},
};
let downloader_result = match download_from_downloader(
book.remote_id,
object_id,
object_type.clone()
).await {
Ok(v) => v,
Err(err) => {
log::error!("{:?}", err);
return None;
},
};
let UploadData { chat_id, message_id } = match upload_to_telegram_files(
downloader_result,
book.get_caption()
).await {
Ok(v) => v,
Err(err) => {
log::error!("{:?}", err);
return None;
},
};
Some(
db
.cached_file()
.create(
object_id,
object_type,
message_id,
chat_id,
vec![]
)
.exec()
.await
.unwrap()
)
}
pub async fn download_from_cache(
cached_data: cached_file::Data,
) -> Option<DownloadResult> {
let response_task = tokio::task::spawn(download_from_telegram_files(cached_data.message_id, cached_data.chat_id));
let filename_task = tokio::task::spawn(get_filename(cached_data.object_id, cached_data.object_type.clone()));
let book_task = tokio::task::spawn(get_book(cached_data.object_id));
let response = match response_task.await.unwrap() {
Ok(v) => v,
Err(err) => {
log::error!("{:?}", err);
return None;
},
};
let filename_data = match filename_task.await.unwrap() {
Ok(v) => v,
Err(err) => {
log::error!("{:?}", err);
return None;
}
};
let book = match book_task.await.unwrap() {
Ok(v) => v,
Err(err) => {
log::error!("{:?}", err);
return None;
}
};
let FilenameData {filename, filename_ascii} = filename_data;
let caption = book.get_caption();
Some(DownloadResult {
response,
filename,
filename_ascii,
caption
})
}

View File

@@ -0,0 +1,87 @@
use reqwest::{Response, multipart::{Form, Part}, header};
use serde::Deserialize;
use crate::config::CONFIG;
#[derive(Deserialize)]
pub struct UploadData {
pub chat_id: i64,
pub message_id: i64
}
#[derive(Deserialize)]
pub struct UploadResult {
pub backend: String,
pub data: UploadData
}
pub async fn download_from_telegram_files(
message_id: i64,
chat_id: i64
) -> Result<Response, Box<dyn std::error::Error + Send + Sync>> {
let url = format!(
"{}/api/v1/files/download_by_message/{chat_id}/{message_id}",
CONFIG.files_url
);
let response = reqwest::Client::new()
.get(url)
.header("Authorization", CONFIG.library_api_key.clone())
.send()
.await?
.error_for_status()?;
Ok(response)
}
pub async fn upload_to_telegram_files(
data_response: Response,
caption: String
) -> Result<UploadData, Box<dyn std::error::Error + Send + Sync>> {
let url = format!(
"{}/api/v1/files/upload/",
CONFIG.files_url
);
let headers = data_response.headers();
let file_size = headers
.get(header::CONTENT_LENGTH)
.unwrap()
.to_str()
.unwrap()
.to_string();
let filename = headers
.get("x-filename-b64-ascii")
.unwrap()
.to_str()
.unwrap()
.to_string();
let part = Part::stream(data_response)
.file_name(filename);
let form = Form::new()
.text("caption", caption)
.text("file_size", file_size)
.part("file", part);
let response = reqwest::Client::new()
.post(url)
.header("Authorization", CONFIG.files_api_key.clone())
.multipart(form)
.send()
.await?
.error_for_status()?;
match response.json::<UploadResult>().await {
Ok(v) => Ok(v.data),
Err(err) => {
Err(Box::new(err))
},
}
}

158 src/views.rs Normal file
View File

@@ -0,0 +1,158 @@
use axum::{Router, response::{Response, IntoResponse, AppendHeaders}, http::{StatusCode, self, Request, header}, middleware::{Next, self}, Extension, routing::{get, delete, post}, extract::Path, Json, body::StreamBody};
use axum_prometheus::PrometheusMetricLayer;
use tokio_util::io::ReaderStream;
use tower_http::trace::{TraceLayer, self};
use tracing::Level;
use std::sync::Arc;
use base64::{engine::general_purpose, Engine};
use crate::{config::CONFIG, db::get_prisma_client, prisma::{PrismaClient, cached_file::{self}}, services::{get_cached_file_or_cache, download_from_cache, download_utils::get_response_async_read}};
pub type Database = Arc<PrismaClient>;
//
async fn get_cached_file(
Path((object_id, object_type)): Path<(i32, String)>,
Extension(Ext { db, .. }): Extension<Ext>
) -> impl IntoResponse {
match get_cached_file_or_cache(object_id, object_type, db).await {
Some(cached_file) => Json(cached_file).into_response(),
None => StatusCode::NOT_FOUND.into_response(),
}
}
async fn download_cached_file(
Path((object_id, object_type)): Path<(i32, String)>,
Extension(Ext { db }): Extension<Ext>
) -> impl IntoResponse {
let cached_file = match get_cached_file_or_cache(object_id, object_type, db).await {
Some(cached_file) => cached_file,
None => return StatusCode::NO_CONTENT.into_response(),
};
let data = match download_from_cache(cached_file).await {
Some(v) => v,
None => {
return StatusCode::NO_CONTENT.into_response();
}
};
let filename = data.filename.clone();
let filename_ascii = data.filename_ascii.clone();
let caption = data.caption.clone();
let encoder = general_purpose::STANDARD;
let reader = get_response_async_read(data.response);
let stream = ReaderStream::new(reader);
let body = StreamBody::new(stream);
let headers = AppendHeaders([
(
header::CONTENT_DISPOSITION,
format!("attachment; filename={filename_ascii}"),
),
(
header::HeaderName::from_static("x-filename-b64"),
encoder.encode(filename),
),
(
header::HeaderName::from_static("x-caption-b64"),
encoder.encode(caption)
)
]);
(headers, body).into_response()
}
async fn delete_cached_file(
Path((object_id, object_type)): Path<(i32, String)>,
Extension(Ext { db, .. }): Extension<Ext>
) -> impl IntoResponse {
let cached_file = db.cached_file()
.find_unique(cached_file::object_id_object_type(object_id, object_type.clone()))
.exec()
.await
.unwrap();
match cached_file {
Some(v) => {
db.cached_file()
.delete(cached_file::object_id_object_type(object_id, object_type))
.exec()
.await
.unwrap();
Json(v).into_response()
},
None => StatusCode::NOT_FOUND.into_response(),
}
}
async fn update_cache(
_ext: Extension<Ext>
) -> impl IntoResponse {
StatusCode::OK.into_response() // TODO
}
//
async fn auth<B>(req: Request<B>, next: Next<B>) -> Result<Response, StatusCode> {
let auth_header = req.headers()
.get(http::header::AUTHORIZATION)
.and_then(|header| header.to_str().ok());
let auth_header = if let Some(auth_header) = auth_header {
auth_header
} else {
return Err(StatusCode::UNAUTHORIZED);
};
if auth_header != CONFIG.api_key {
return Err(StatusCode::UNAUTHORIZED);
}
Ok(next.run(req).await)
}
#[derive(Clone)]
struct Ext {
pub db: Arc<PrismaClient>,
}
pub async fn get_router() -> Router {
let db = Arc::new(get_prisma_client().await);
let ext = Ext { db };
let (prometheus_layer, metric_handle) = PrometheusMetricLayer::pair();
let app_router = Router::new()
.route("/:object_id/:object_type/", get(get_cached_file))
.route("/download/:object_id/:object_type/", get(download_cached_file))
.route("/:object_id/:object_type/", delete(delete_cached_file))
.route("/update_cache", post(update_cache))
.layer(middleware::from_fn(auth))
.layer(Extension(ext))
.layer(prometheus_layer);
let metric_router = Router::new()
.route("/metrics", get(|| async move { metric_handle.render() }));
Router::new()
.nest("/api/v1/", app_router)
.nest("/", metric_router)
.layer(
TraceLayer::new_for_http()
.make_span_with(trace::DefaultMakeSpan::new()
.level(Level::INFO))
.on_response(trace::DefaultOnResponse::new()
.level(Level::INFO)),
)
}
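A minimal client sketch (hypothetical, not part of the diff) showing how the routes and the auth middleware above are meant to be called, assuming the server is bound to 0.0.0.0:8080 as in main.rs and that API_KEY matches CONFIG.api_key; object id 123 and type fb2 are placeholder values:

async fn fetch_cached_file() -> Result<(), reqwest::Error> {
    let api_key = std::env::var("API_KEY").expect("API_KEY must be set");
    let body = reqwest::Client::new()
        // GET /api/v1/:object_id/:object_type/ returns the cached_file row as JSON,
        // or 404 if the file could not be downloaded and cached.
        .get("http://localhost:8080/api/v1/123/fb2/")
        .header("Authorization", api_key)
        .send()
        .await?
        .error_for_status()?
        .text()
        .await?;
    println!("{body}");
    Ok(())
}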