Mirror of https://github.com/flibusta-apps/telegram_files_cache_server.git (synced 2025-12-06 06:35:38 +01:00)
Rewrite to Rust
2  .cargo/config.toml  Normal file
@@ -0,0 +1,2 @@
[alias]
prisma = "run -p prisma-cli --"
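This alias routes `cargo prisma ...` through the prisma-cli workspace member added later in this diff, so client regeneration is presumably done with `cargo prisma generate` against prisma/schema.prisma, following the usual prisma-client-rust workflow.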
35  .github/workflows/codeql-analysis.yml  vendored
@@ -1,35 +0,0 @@
name: "CodeQL"

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]
  schedule:
    - cron: '0 12 * * *'

jobs:
  analyze:
    name: Analyze
    runs-on: ubuntu-latest
    permissions:
      actions: read
      contents: read
      security-events: write

    strategy:
      fail-fast: false
      matrix:
        language: [ 'python' ]

    steps:
      - name: Checkout repository
        uses: actions/checkout@v3

      - name: Initialize CodeQL
        uses: github/codeql-action/init@v2
        with:
          languages: ${{ matrix.language }}

      - name: Perform CodeQL Analysis
        uses: github/codeql-action/analyze@v2
35  .github/workflows/linters.yaml  vendored
@@ -1,35 +0,0 @@
name: Linters

on:
  push:
    branches:
      - main
  pull_request:
    types: [opened, synchronize, reopened]

jobs:
  Run-Pre-Commit:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
        with:
          fetch-depth: 32

      - uses: actions/setup-python@v4
        with:
          python-version: 3.11

      - name: Install pre-commit
        run: pip3 install pre-commit

      - name: Pre-commit (Push)
        env:
          SETUPTOOLS_USE_DISTUTILS: stdlib
        if: ${{ github.event_name == 'push' }}
        run: pre-commit run --source ${{ github.event.before }} --origin ${{ github.event.after }} --show-diff-on-failure

      - name: Pre-commit (Pull-Request)
        env:
          SETUPTOOLS_USE_DISTUTILS: stdlib
        if: ${{ github.event_name == 'pull_request' }}
        run: pre-commit run --source ${{ github.event.pull_request.base.sha }} --origin ${{ github.event.pull_request.head.sha }} --show-diff-on-failure
7  .gitignore  vendored
@@ -1,5 +1,4 @@
/target

.vscode

__pycache__

venv
.env
.pre-commit-config.yaml
@@ -1,18 +0,0 @@
exclude: 'docs|node_modules|migrations|.git|.tox'

repos:
  - repo: https://github.com/ambv/black
    rev: 23.3.0
    hooks:
      - id: black
        language_version: python3.11

  - repo: https://github.com/charliermarsh/ruff-pre-commit
    rev: 'v0.0.265'
    hooks:
      - id: ruff

  - repo: https://github.com/crate-ci/typos
    rev: typos-dict-v0.9.26
    hooks:
      - id: typos
5728  Cargo.lock  generated  Normal file
File diff suppressed because it is too large
40  Cargo.toml  Normal file
@@ -0,0 +1,40 @@
[package]
name = "telegram_files_cache_server"
version = "0.1.0"
edition = "2021"

[workspace]
members = [
    "prisma-cli"
]


# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
once_cell = "1.18.0"

prisma-client-rust = { git = "https://github.com/Brendonovich/prisma-client-rust", tag = "0.6.8", features = ["postgresql"] }
serde = { version = "1.0.163", features = ["derive"] }
serde_json = "1.0.104"
reqwest = { version = "0.11.18", features = ["json", "stream", "multipart"] }

tokio = { version = "1.28.2", features = ["full"] }
tokio-util = { version = "0.7.8", features = ["compat"] }
axum = { version = "0.6.18", features = ["json"] }
axum-prometheus = "0.4.0"
chrono = "0.4.26"
sentry = "0.31.5"

tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
tower-http = { version = "0.4.3", features = ["trace"] }

base64 = "0.21.2"

grammers-client = "0.4.0"
grammers-session = "0.4.0"

futures = "0.3.28"
futures-core = "0.3.28"
async-stream = "0.3.5"
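The dependency set maps the old Python stack across roughly one-for-one: axum replaces FastAPI, axum-prometheus replaces prometheus-fastapi-instrumentator, reqwest replaces httpx, the sentry crate replaces sentry-sdk, and prisma-client-rust (pinned to the same 0.6.8 tag as the prisma-cli workspace member) replaces ormar plus alembic. The taskiq queue has no direct counterpart; the Rust code visible further down uses plain tokio tasks instead. grammers-client and grammers-session suggest direct MTProto access to Telegram, though none of the hunks shown here use them (the telegram_files module is cut off at the end of this diff).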
Dockerfile
@@ -1,32 +1,21 @@
FROM ghcr.io/flibusta-apps/base_docker_images:3.11-postgres-asyncpg-poetry-buildtime as build-image

RUN apt-get update \
    && apt-get install git -y --no-install-recommends \
    && rm -rf /var/cache/*

WORKDIR /root/poetry
COPY pyproject.toml poetry.lock /root/poetry/

ENV VENV_PATH=/opt/venv

RUN poetry export --without-hashes > requirements.txt \
    && . /opt/venv/bin/activate \
    && pip install -r requirements.txt --no-cache-dir


FROM ghcr.io/flibusta-apps/base_docker_images:3.11-postgres-runtime as runtime-image
FROM rust:bullseye AS builder

WORKDIR /app

COPY ./src/ /app/
COPY . .

ENV VENV_PATH=/opt/venv
ENV PATH="$VENV_PATH/bin:$PATH"
RUN cargo build --release --bin telegram_files_cache_server

COPY --from=build-image $VENV_PATH $VENV_PATH
COPY ./scripts/start_production.sh /root/
COPY ./scripts/healthcheck.py /root/

EXPOSE 8080
FROM debian:bullseye-slim

CMD bash /root/start_production.sh
RUN apt-get update \
    && apt-get install -y openssl ca-certificates \
    && rm -rf /var/lib/apt/lists/*

RUN update-ca-certificates

WORKDIR /app

COPY --from=builder /app/target/release/telegram_files_cache_server /usr/local/bin
ENTRYPOINT ["/usr/local/bin/telegram_files_cache_server"]
1561  poetry.lock  generated
File diff suppressed because it is too large
3  prisma-cli/.gitignore  vendored  Normal file
@@ -0,0 +1,3 @@
node_modules
# Keep environment variables out of version control
.env
4622  prisma-cli/Cargo.lock  generated  Normal file
File diff suppressed because it is too large
9  prisma-cli/Cargo.toml  Normal file
@@ -0,0 +1,9 @@
[package]
name = "prisma-cli"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
prisma-client-rust-cli = { git = "https://github.com/Brendonovich/prisma-client-rust", tag = "0.6.8", features = ["postgresql"] }
3  prisma-cli/src/main.rs  Normal file
@@ -0,0 +1,3 @@
fn main() {
    prisma_client_rust_cli::run();
}
24  prisma/schema.prisma  Normal file
@@ -0,0 +1,24 @@
generator client {
    provider = "cargo prisma"
    output   = "../src/prisma.rs"
}

datasource db {
    provider = "postgresql"
    url      = env("DATABASE_URL")
}

model CachedFile {
    id          Int    @id @default(autoincrement())
    object_id   Int
    object_type String @db.VarChar(8)
    message_id  BigInt @unique(map: "ix_cached_files_message_id")
    chat_id     BigInt

    @@unique([message_id, chat_id], map: "uc_cached_files_message_id_chat_id")
    @@unique([object_id, object_type], map: "uc_cached_files_object_id_object_type")
    @@index([object_id], map: "ix_cached_files_object_id")
    @@index([object_type], map: "ix_cached_files_object_type")

    @@map("cached_files")
}
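For orientation, a minimal sketch of how this schema is consumed from Rust once `cargo prisma generate` has emitted src/prisma.rs. The `object_id_object_type` unique selector is derived by prisma-client-rust from the compound unique above, and it is the same accessor the new src/services/mod.rs uses; the function name and error handling here are illustrative only:

use crate::prisma::{cached_file, PrismaClient};

// Look up a cached file by its (object_id, object_type) compound unique key.
pub async fn find_cached(
    db: &PrismaClient,
    object_id: i32,
    object_type: String,
) -> Option<cached_file::Data> {
    db.cached_file()
        .find_unique(cached_file::object_id_object_type(object_id, object_type))
        .exec()
        .await
        .ok()       // for this sketch, collapse query errors into None
        .flatten()
}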
pyproject.toml
@@ -1,77 +0,0 @@
[tool.poetry]
name = "telegram_channel_files_manager"
version = "0.1.0"
description = ""
authors = ["Kurbanov Bulat <kurbanovbul@gmail.com>"]

[tool.poetry.dependencies]
python = "^3.11"
fastapi = "^0.101.0"
httpx = "^0.24.1"
alembic = "^1.11.2"
uvicorn = {extras = ["standard"], version = "^0.23.2"}
prometheus-fastapi-instrumentator = "^6.1.0"
uvloop = "^0.17.0"
orjson = "^3.9.4"
sentry-sdk = "^1.29.2"
ormar = {extras = ["postgresql"], version = "^0.12.2"}
pydantic = "^1.10.4"
redis = {extras = ["hiredis"], version = "^4.6.0"}
msgpack = "^1.0.5"
taskiq = "^0.8.6"
taskiq-redis = "^0.4.0"
taskiq-fastapi = "^0.3.0"

[tool.poetry.group.dev.dependencies]
pre-commit = "^2.21.0"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.black]
include = '\.pyi?$'
exclude = '''
/(
    \.git
  | \.vscode
  | \venv
  | alembic
)/
'''

[tool.ruff]
fix = true
target-version = "py311"
src = ["app"]
line-length = 88
ignore = []
select = ["B", "C", "E", "F", "W", "B9", "I001"]
exclude = [
    # No need to traverse our git directory
    ".git",
    # There's no value in checking cache directories
    "__pycache__",
    # The conf file is mostly autogenerated, ignore it
    "src/app/alembic",
]

[tool.ruff.flake8-bugbear]
extend-immutable-calls = ["fastapi.File", "fastapi.Form", "fastapi.Security", "taskiq.TaskiqDepends"]

[tool.ruff.mccabe]
max-complexity = 15

[tool.ruff.isort]
known-first-party = ["core", "app"]
force-sort-within-sections = true
force-wrap-aliases = true
section-order = ["future", "standard-library", "base_framework", "framework_ext", "third-party", "first-party", "local-folder"]
lines-after-imports = 2

[tool.ruff.isort.sections]
base_framework = ["fastapi",]
framework_ext = ["starlette"]

[tool.ruff.pyupgrade]
keep-runtime-typing = true
scripts/healthcheck.py
@@ -1,6 +0,0 @@
import httpx


response = httpx.get("http://localhost:8080/healthcheck")
print(f"HEALTHCHECK STATUS: {response.status_code}")
exit(0 if response.status_code == 200 else 1)
scripts/start_production.sh
@@ -1,7 +0,0 @@
cd /app
alembic -c ./app/alembic.ini upgrade head

rm -rf prometheus
mkdir prometheus

uvicorn main:app --host 0.0.0.0 --port 8080 --loop uvloop
BIN  src/.DS_Store  vendored  Normal file
Binary file not shown.
src/app/alembic.ini
@@ -1,98 +0,0 @@
# A generic, single database configuration.

[alembic]
# path to migration scripts
script_location = ./app/alembic

# template used to generate migration files
# file_template = %%(rev)s_%%(slug)s

# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .

# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python-dateutil library that can be
# installed by adding `alembic[tz]` to the pip requirements
# string value is passed to dateutil.tz.gettz()
# leave blank for localtime
# timezone =

# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40

# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false

# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false

# version location specification; This defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator"
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions

# version path separator; As mentioned above, this is the character used to split
# version_locations. Valid values are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os  # default: use os.pathsep

# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8


[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples

# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME

# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARN
handlers = console
qualname =

[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
src/app/alembic/README
@@ -1 +0,0 @@
Generic single-database configuration.
src/app/alembic/env.py
@@ -1,67 +0,0 @@
import os
import sys

from alembic import context
from sqlalchemy.engine import create_engine

from core.db import DATABASE_URL


myPath = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, myPath + "/../../")

config = context.config


from app.models import BaseMeta


target_metadata = BaseMeta.metadata


def run_migrations_offline():
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well. By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.

    """
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online():
    """Run migrations in 'online' mode.

    In this scenario we need to create an Engine
    and associate a connection with the context.

    """
    connectable = create_engine(DATABASE_URL)

    with connectable.connect() as connection:
        context.configure(
            connection=connection, target_metadata=target_metadata, compare_type=True
        )

        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
src/app/alembic/script.py.mako
@@ -1,24 +0,0 @@
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}


def upgrade():
    ${upgrades if upgrades else "pass"}


def downgrade():
    ${downgrades if downgrades else "pass"}
src/app/alembic/versions/… (revision 62d57916ec53)
@@ -1,32 +0,0 @@
"""empty message

Revision ID: 62d57916ec53
Revises: f77b0b14f9eb
Create Date: 2022-12-30 23:30:50.867163

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "62d57916ec53"
down_revision = "f77b0b14f9eb"
branch_labels = None
depends_on = None


def upgrade():
    op.drop_column("cached_files", "data")
    op.create_unique_constraint(
        "uc_cached_files_message_id_chat_id", "cached_files", ["message_id", "chat_id"]
    )
    op.create_index(
        op.f("ix_cached_files_message_id"), "cached_files", ["message_id"], unique=True
    )


def downgrade():
    op.add_column("cached_files", sa.Column("data", sa.JSON(), nullable=False))
    op.drop_constraint("uc_cached_files_message_id_chat_id", "cached_files")
    op.drop_index("ix_cached_files_message_id", "cached_files")
src/app/alembic/versions/… (revision 9b7cfb422191)
@@ -1,49 +0,0 @@
"""empty message

Revision ID: 9b7cfb422191
Revises:
Create Date: 2021-11-21 14:09:17.478532

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "9b7cfb422191"
down_revision = None
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "cached_files",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("object_id", sa.Integer(), nullable=False),
        sa.Column("object_type", sa.String(length=8), nullable=False),
        sa.Column("data", sa.JSON(), nullable=False),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint(
            "object_id", "object_type", name="uc_cached_files_object_id_object_type"
        ),
    )
    op.create_index(
        op.f("ix_cached_files_object_id"), "cached_files", ["object_id"], unique=False
    )
    op.create_index(
        op.f("ix_cached_files_object_type"),
        "cached_files",
        ["object_type"],
        unique=False,
    )
    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table("cached_files")
    op.drop_index("ix_cached_files_object_id")
    op.drop_index("ix_cached_files_object_type")
    # ### end Alembic commands ###
src/app/alembic/versions/… (revision f77b0b14f9eb)
@@ -1,28 +0,0 @@
"""empty message

Revision ID: f77b0b14f9eb
Revises: 9b7cfb422191
Create Date: 2022-12-30 22:53:41.951490

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "f77b0b14f9eb"
down_revision = "9b7cfb422191"
branch_labels = None
depends_on = None


def upgrade():
    op.add_column(
        "cached_files", sa.Column("message_id", sa.BigInteger(), nullable=True)
    )
    op.add_column("cached_files", sa.Column("chat_id", sa.BigInteger(), nullable=True))


def downgrade():
    op.drop_column("cached_files", "message_id")
    op.drop_column("cached_files", "chat_id")
src/app/depends.py
@@ -1,18 +0,0 @@
from fastapi import HTTPException, Request, Security, status

from redis.asyncio import ConnectionPool
from taskiq import TaskiqDepends

from core.auth import default_security
from core.config import env_config


async def check_token(api_key: str = Security(default_security)):
    if api_key != env_config.API_KEY:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Wrong api key!"
        )


def get_redis_pool(request: Request = TaskiqDepends()) -> ConnectionPool:
    return request.app.state.redis_pool
src/app/models.py
@@ -1,30 +0,0 @@
import ormar

from core.db import database, metadata


class BaseMeta(ormar.ModelMeta):
    metadata = metadata
    database = database


class CachedFile(ormar.Model):
    class Meta(BaseMeta):
        tablename = "cached_files"
        constraints = [
            ormar.UniqueColumns("object_id", "object_type"),
            ormar.UniqueColumns("message_id", "chat_id"),
        ]

    id: int = ormar.Integer(primary_key=True)  # type: ignore
    object_id: int = ormar.Integer(index=True)  # type: ignore
    object_type: str = ormar.String(
        max_length=8, index=True, unique=True
    )  # type: ignore

    message_id: int = ormar.BigInteger(index=True)  # type: ignore
    chat_id: int = ormar.BigInteger()  # type: ignore

    @ormar.property_field
    def data(self) -> dict:
        return {"message_id": self.message_id, "chat_id": self.chat_id}
src/app/serializers.py
@@ -1,14 +0,0 @@
from pydantic import BaseModel, constr


class CachedFile(BaseModel):
    id: int
    object_id: int
    object_type: str
    data: dict


class CreateCachedFile(BaseModel):
    object_id: int
    object_type: constr(max_length=8)  # type: ignore
    data: dict
src/app/services/cache_updater.py
@@ -1,167 +0,0 @@
import collections
from datetime import date, timedelta
from io import BytesIO
from typing import Optional

from fastapi import UploadFile

import httpx
from redis.asyncio import ConnectionPool, Redis
from taskiq import TaskiqDepends

from app.depends import get_redis_pool
from app.models import CachedFile
from app.services.caption_getter import get_caption
from app.services.downloader import download
from app.services.files_client import upload_file
from app.services.library_client import Book, get_book, get_books
from core.taskiq_worker import broker


PAGE_SIZE = 100


class Retry(Exception):
    pass


class FileTypeNotAllowed(Exception):
    def __init__(self, message: str) -> None:
        super().__init__(message)


@broker.task
async def check_books_page(
    page_number: int, uploaded_gte: str, uploaded_lte: str
) -> bool:
    page = await get_books(
        page_number,
        PAGE_SIZE,
        uploaded_gte=date.fromisoformat(uploaded_gte),
        uploaded_lte=date.fromisoformat(uploaded_lte),
    )

    object_ids = [book.id for book in page.items]

    cached_files = await CachedFile.objects.filter(object_id__in=object_ids).all()

    cached_files_map = collections.defaultdict(set)
    for cached_file in cached_files:
        cached_files_map[cached_file.object_id].add(cached_file.object_type)

    for book in page.items:
        for file_type in book.available_types:
            if file_type not in cached_files_map[book.id]:
                await cache_file_by_book_id.kiq(
                    book_id=book.id,
                    file_type=file_type,
                    by_request=False,
                )

    return True


@broker.task
async def check_books(*args, **kwargs) -> bool:
    uploaded_lte = date.today() + timedelta(days=1)
    uploaded_gte = date.today() - timedelta(days=1)

    books_page = await get_books(
        1, PAGE_SIZE, uploaded_gte=uploaded_gte, uploaded_lte=uploaded_lte
    )

    for page_number in range(1, books_page.pages + 1):
        await check_books_page.kiq(
            page_number,
            uploaded_gte=uploaded_gte.isoformat(),
            uploaded_lte=uploaded_lte.isoformat(),
        )

    return True


async def cache_file(book: Book, file_type: str) -> Optional[CachedFile]:
    if await CachedFile.objects.filter(
        object_id=book.id, object_type=file_type
    ).exists():
        return

    try:
        data = await download(book.source.id, book.remote_id, file_type)
    except httpx.HTTPError:
        data = None

    if data is None:
        raise Retry

    response, client, filename = data
    caption = get_caption(book)

    temp_file = UploadFile(BytesIO(), filename=filename)
    async for chunk in response.aiter_bytes(2048):
        await temp_file.write(chunk)

    file_size = temp_file.file.tell()
    await temp_file.seek(0)

    await response.aclose()
    await client.aclose()

    upload_data = await upload_file(temp_file.file, file_size, filename, caption)

    if upload_data is None:
        return None

    cached_file, created = await CachedFile.objects.get_or_create(
        {
            "message_id": upload_data.data["message_id"],
            "chat_id": upload_data.data["chat_id"],
        },
        object_id=book.id,
        object_type=file_type,
    )

    if created:
        return cached_file

    cached_file.message_id = upload_data.data["message_id"]
    cached_file.chat_id = upload_data.data["chat_id"]

    return await cached_file.update(["message_id", "chat_id"])


@broker.task(retry_on_error=True)
async def cache_file_by_book_id(
    book_id: int,
    file_type: str,
    by_request: bool = True,
    redis_pool: ConnectionPool = TaskiqDepends(get_redis_pool),
) -> Optional[CachedFile]:
    book = await get_book(book_id, 3)

    if book is None:
        if by_request:
            return None
        raise Retry

    if file_type not in book.available_types:
        return None

    async with Redis(connection_pool=redis_pool) as redis_client:
        lock = redis_client.lock(
            f"{book_id}_{file_type}", blocking_timeout=5, thread_local=False
        )

        if await lock.locked() and not by_request:
            raise Retry

        try:
            result = await cache_file(book, file_type)
        except Retry as e:
            if by_request:
                return None
            raise e

        if by_request:
            return result
        return None
src/app/services/caption_getter.py
@@ -1,41 +0,0 @@
from app.services.library_client import Book, BookAuthor


def get_author_string(author: BookAuthor) -> str:
    author_parts = []

    if author.last_name:
        author_parts.append(author.last_name)

    if author.first_name:
        author_parts.append(author.first_name)

    if author.middle_name:
        author_parts.append(author.middle_name)

    return " ".join(author_parts)


def get_caption(book: Book) -> str:
    caption_title = f"📖 {book.title}"
    caption_title_length = len(caption_title) + 3

    caption_authors_parts = []
    authors_caption_length = 0
    for author in book.authors:
        author_caption = f"👤 {get_author_string(author)}"

        if (
            caption_title_length + authors_caption_length + len(author_caption) + 3
        ) <= 1024:
            caption_authors_parts.append(author_caption)
            authors_caption_length += len(author_caption) + 3
        else:
            break

    if not caption_authors_parts:
        return caption_title

    caption_authors = "\n".join(caption_authors_parts)

    return caption_title + "\n\n" + caption_authors
src/app/services/downloader.py
@@ -1,57 +0,0 @@
from base64 import b64decode
from typing import Optional

import httpx
from sentry_sdk import capture_exception

from core.config import env_config


async def download(
    source_id: int, remote_id: int, file_type: str
) -> Optional[tuple[httpx.Response, httpx.AsyncClient, str]]:
    headers = {"Authorization": env_config.DOWNLOADER_API_KEY}

    client = httpx.AsyncClient(timeout=600)
    request = client.build_request(
        "GET",
        f"{env_config.DOWNLOADER_URL}/download/{source_id}/{remote_id}/{file_type}",
        headers=headers,
    )

    try:
        response = await client.send(request, stream=True)
    except httpx.ConnectError:
        await client.aclose()
        return None

    if response.status_code != 200:
        await response.aclose()
        await client.aclose()
        return None

    name = b64decode(response.headers["x-filename-b64"]).decode()

    return response, client, name


async def get_filename(book_id: int, file_type: str) -> Optional[tuple[str, str]]:
    headers = {"Authorization": env_config.DOWNLOADER_API_KEY}

    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{env_config.DOWNLOADER_URL}/filename/{book_id}/{file_type}",
                headers=headers,
                timeout=5 * 60,
            )

            if response.status_code != 200:
                return None

            data = response.json()

            return data["filename"], data["filename_ascii"]
    except httpx.HTTPError as e:
        capture_exception(e)
        return None
src/app/services/files_client.py
@@ -1,67 +0,0 @@
from typing import BinaryIO, Optional

import httpx
from pydantic import BaseModel
from typing_extensions import TypedDict

from core.config import env_config


class Data(TypedDict):
    chat_id: int
    message_id: int


class UploadedFile(BaseModel):
    backend: str
    data: Data


async def upload_file(
    content: BinaryIO, content_size: int, filename: str, caption: str
) -> Optional[UploadedFile]:
    headers = {"Authorization": env_config.FILES_SERVER_API_KEY}

    async with httpx.AsyncClient() as client:
        form = {"caption": caption, "file_size": content_size}
        files = {"file": (filename, content)}

        response = await client.post(
            f"{env_config.FILES_SERVER_URL}/api/v1/files/upload/",
            data=form,
            files=files,
            headers=headers,
            timeout=5 * 60,
        )

        if response.status_code != 200:
            return None

        return UploadedFile.parse_obj(response.json())


async def download_file(
    chat_id: int, message_id: int
) -> Optional[tuple[httpx.Response, httpx.AsyncClient]]:
    headers = {"Authorization": env_config.FILES_SERVER_API_KEY}

    client = httpx.AsyncClient(timeout=60)
    request = client.build_request(
        "GET",
        f"{env_config.FILES_SERVER_URL}"
        f"/api/v1/files/download_by_message/{chat_id}/{message_id}",
        headers=headers,
    )

    try:
        response = await client.send(request, stream=True)
    except httpx.ConnectError:
        await client.aclose()
        return None

    if response.status_code != 200:
        await response.aclose()
        await client.aclose()
        return None

    return response, client
src/app/services/library_client.py
@@ -1,123 +0,0 @@
from datetime import date
from typing import Generic, Optional, TypeVar
from urllib.parse import urlencode

import httpx
from pydantic import BaseModel
from sentry_sdk import capture_exception

from core.config import env_config


T = TypeVar("T")


class Page(BaseModel, Generic[T]):
    items: list[T]
    total: int

    size: int

    page: int
    pages: int


class BaseBookInfo(BaseModel):
    id: int
    available_types: list[str]


class BookAuthor(BaseModel):
    id: int
    first_name: str
    last_name: str
    middle_name: str


class BookSource(BaseModel):
    id: int
    name: str


class Book(BaseModel):
    id: int
    title: str
    file_type: str
    available_types: list[str]
    source: BookSource
    remote_id: int
    uploaded: date
    authors: list[BookAuthor]


class BookDetail(Book):
    is_deleted: bool


AUTH_HEADERS = {"Authorization": env_config.LIBRARY_API_KEY}


async def get_book(
    book_id: int, retry: int = 3, last_exp: Exception | None = None
) -> Optional[BookDetail]:
    if retry == 0:
        if last_exp:
            capture_exception(last_exp)

        return None

    try:
        async with httpx.AsyncClient(timeout=2 * 60) as client:
            response = await client.get(
                f"{env_config.LIBRARY_URL}/api/v1/books/{book_id}", headers=AUTH_HEADERS
            )

            if response.status_code != 200:
                return None

            return BookDetail.parse_obj(response.json())
    except httpx.HTTPError as e:
        return await get_book(book_id, retry=retry - 1, last_exp=e)


async def get_books(
    page: int,
    page_size: int,
    uploaded_gte: date | None = None,
    uploaded_lte: date | None = None,
) -> Page[BaseBookInfo]:
    params: dict[str, str] = {
        "page": str(page),
        "page_size": str(page_size),
        "is_deleted": "false",
    }

    if uploaded_gte:
        params["uploaded_gte"] = uploaded_gte.isoformat()

    if uploaded_lte:
        params["uploaded_lte"] = uploaded_lte.isoformat()

    params_string = urlencode(params)

    async with httpx.AsyncClient(timeout=5 * 60) as client:
        response = await client.get(
            f"{env_config.LIBRARY_URL}/api/v1/books/base/?{params_string}",
            headers=AUTH_HEADERS,
        )

        data = response.json()

        page_data = Page[BaseBookInfo].parse_obj(data)
        page_data.items = [BaseBookInfo.parse_obj(item) for item in page_data.items]

        return page_data


async def get_last_book_id() -> int:
    async with httpx.AsyncClient() as client:
        response = await client.get(
            f"{env_config.LIBRARY_URL}/api/v1/books/last", headers=AUTH_HEADERS
        )

        return int(response.text)
src/app/utils.py
@@ -1,24 +0,0 @@
from fastapi import HTTPException, status

from redis.asyncio import ConnectionPool

from app.models import CachedFile as CachedFileDB
from app.services.cache_updater import cache_file_by_book_id


async def get_cached_file_or_cache(
    object_id: int, object_type: str, connection_pool: ConnectionPool
) -> CachedFileDB:
    cached_file = await CachedFileDB.objects.get_or_none(
        object_id=object_id, object_type=object_type
    )

    if not cached_file:
        cached_file = await cache_file_by_book_id(
            object_id, object_type, redis_pool=connection_pool
        )

    if not cached_file:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)

    return cached_file
145  src/app/views.py
@@ -1,145 +0,0 @@
from base64 import b64encode
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import StreamingResponse

from starlette.background import BackgroundTask

from redis.asyncio import ConnectionPool

from app.depends import check_token, get_redis_pool
from app.models import CachedFile as CachedFileDB
from app.serializers import CachedFile, CreateCachedFile
from app.services.cache_updater import cache_file_by_book_id, check_books
from app.services.caption_getter import get_caption
from app.services.downloader import get_filename
from app.services.files_client import download_file as download_file_from_cache
from app.services.library_client import get_book
from app.utils import get_cached_file_or_cache


router = APIRouter(
    prefix="/api/v1", tags=["files"], dependencies=[Depends(check_token)]
)


@router.get("/{object_id}/{object_type}", response_model=CachedFile)
async def get_cached_file(
    redis_pool: Annotated[ConnectionPool, Depends(get_redis_pool)],
    object_id: int,
    object_type: str,
):
    cached_file = await CachedFileDB.objects.get_or_none(
        object_id=object_id, object_type=object_type
    )

    if not cached_file:
        cached_file = await cache_file_by_book_id(
            object_id, object_type, by_request=True, redis_pool=redis_pool
        )

    if not cached_file:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)

    return cached_file


@router.get("/download/{object_id}/{object_type}")
async def download_cached_file(request: Request, object_id: int, object_type: str):
    cached_file = await get_cached_file_or_cache(
        object_id, object_type, request.app.state.redis_pool
    )
    cache_data: dict = cached_file.data  # type: ignore

    data = await download_file_from_cache(
        cache_data["chat_id"], cache_data["message_id"]
    )
    if data is None:
        await CachedFileDB.objects.filter(id=cached_file.id).delete()

        cached_file = await get_cached_file_or_cache(
            object_id, object_type, request.app.state.redis_pool
        )
        cache_data: dict = cached_file.data  # type: ignore

        data = await download_file_from_cache(
            cache_data["chat_id"], cache_data["message_id"]
        )

    if data is None:
        raise HTTPException(status_code=status.HTTP_204_NO_CONTENT)

    if (filename_data := await get_filename(object_id, object_type)) is None:
        raise HTTPException(status_code=status.HTTP_204_NO_CONTENT)

    if (book := await get_book(object_id)) is None:
        raise HTTPException(status_code=status.HTTP_204_NO_CONTENT)

    response, client = data

    async def close():
        await response.aclose()
        await client.aclose()

    filename, filename_ascii = filename_data

    return StreamingResponse(
        response.aiter_bytes(),
        headers={
            "Content-Disposition": f"attachment; filename={filename_ascii}",
            "X-Caption-B64": b64encode(get_caption(book).encode("utf-8")).decode(),
            "X-Filename-B64": b64encode(filename.encode("utf-8")).decode(),
        },
        background=BackgroundTask(close),
    )


@router.delete("/{object_id}/{object_type}", response_model=CachedFile)
async def delete_cached_file(object_id: int, object_type: str):
    cached_file = await CachedFileDB.objects.get_or_none(
        object_id=object_id, object_type=object_type
    )

    if not cached_file:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)

    await cached_file.delete()

    return cached_file


@router.post("/", response_model=CachedFile)
async def create_or_update_cached_file(data: CreateCachedFile):
    cached_file = await CachedFileDB.objects.get_or_none(
        object_id=data.object_id, object_type=data.object_type
    )

    if cached_file is not None:
        cached_file.message_id = data.data["message_id"]
        cached_file.chat_id = data.data["chat_id"]
        return await cached_file.update()

    return await CachedFileDB.objects.create(
        object_id=data.object_id,
        object_type=data.object_type,
        message_id=data.data["message_id"],
        chat_id=data.data["chat_id"],
    )


@router.post("/update_cache")
async def update_cache(request: Request):
    await check_books.kiq()

    return "Ok!"


healthcheck_router = APIRouter(
    tags=["healthcheck"],
)


@healthcheck_router.get("/healthcheck")
async def healthcheck():
    return "Ok!"
66  src/config.rs  Normal file
@@ -0,0 +1,66 @@
use once_cell::sync::Lazy;

pub struct Config {
    pub api_key: String,

    pub postgres_user: String,
    pub postgres_password: String,
    pub postgres_host: String,
    pub postgres_port: u32,
    pub postgres_db: String,

    pub downloader_api_key: String,
    pub downloader_url: String,

    pub library_api_key: String,
    pub library_url: String,

    pub files_api_key: String,
    pub files_url: String,

    pub telegram_api_id: i32,
    pub telegram_api_hash: String,
    pub telegram_bot_tokens: Vec<String>,

    pub sentry_dsn: String,
}


fn get_env(env: &'static str) -> String {
    std::env::var(env).unwrap_or_else(|_| panic!("Cannot get the {} env variable", env))
}


impl Config {
    pub fn load() -> Config {
        Config {
            api_key: get_env("API_KEY"),

            postgres_user: get_env("POSTGRES_USER"),
            postgres_password: get_env("POSTGRES_PASSWORD"),
            postgres_host: get_env("POSTGRES_HOST"),
            postgres_port: get_env("POSTGRES_PORT").parse().unwrap(),
            postgres_db: get_env("POSTGRES_DB"),

            downloader_api_key: get_env("DOWNLOADER_API_KEY"),
            downloader_url: get_env("DOWNLOADER_URL"),

            library_api_key: get_env("LIBRARY_API_KEY"),
            library_url: get_env("LIBRARY_URL"),

            files_api_key: get_env("FILES_SERVER_API_KEY"),
            files_url: get_env("FILES_SERVER_URL"),

            telegram_api_id: get_env("TELEGRAM_API_ID").parse().unwrap(),
            telegram_api_hash: get_env("TELEGRAM_API_HASH"),
            telegram_bot_tokens: serde_json::from_str(&get_env("TELEGRAM_BOT_TOKENS")).unwrap(),

            sentry_dsn: get_env("SENTRY_DSN"),
        }
    }
}


pub static CONFIG: Lazy<Config> = Lazy::new(Config::load);
src/core/app.py
@@ -1,44 +0,0 @@
from fastapi import FastAPI
from fastapi.responses import ORJSONResponse

from prometheus_fastapi_instrumentator import Instrumentator
from redis.asyncio import ConnectionPool

from app.views import healthcheck_router, router
from core.config import REDIS_URL
from core.db import database
from core.taskiq_worker import broker


def start_app() -> FastAPI:
    app = FastAPI(default_response_class=ORJSONResponse)

    app.state.redis_pool = ConnectionPool.from_url(REDIS_URL)

    app.include_router(router)
    app.include_router(healthcheck_router)

    @app.on_event("startup")
    async def app_startup():
        if not database.is_connected:
            await database.connect()

        if not broker.is_worker_process:
            await broker.startup()

    @app.on_event("shutdown")
    async def app_shutdown():
        if database.is_connected:
            await database.disconnect()

        if not broker.is_worker_process:
            await broker.shutdown()

        await app.state.redis_pool.disconnect()

    Instrumentator(
        should_ignore_untemplated=True,
        excluded_handlers=["/docs", "/metrics", "/healthcheck"],
    ).instrument(app).expose(app, include_in_schema=True)

    return app
src/core/auth.py
@@ -1,4 +0,0 @@
from fastapi.security import APIKeyHeader


default_security = APIKeyHeader(name="Authorization")
src/core/config.py
@@ -1,33 +0,0 @@
from pydantic import BaseSettings


class EnvConfig(BaseSettings):
    API_KEY: str

    POSTGRES_USER: str
    POSTGRES_PASSWORD: str
    POSTGRES_HOST: str
    POSTGRES_PORT: int
    POSTGRES_DB: str

    DOWNLOADER_API_KEY: str
    DOWNLOADER_URL: str

    LIBRARY_API_KEY: str
    LIBRARY_URL: str

    FILES_SERVER_API_KEY: str
    FILES_SERVER_URL: str

    REDIS_HOST: str
    REDIS_PORT: int
    REDIS_DB: int

    SENTRY_DSN: str


env_config = EnvConfig()  # type: ignore

REDIS_URL = (
    f"redis://{env_config.REDIS_HOST}:{env_config.REDIS_PORT}/{env_config.REDIS_DB}"
)
src/core/db.py
@@ -1,15 +0,0 @@
from urllib.parse import quote

from databases import Database
from sqlalchemy import MetaData

from core.config import env_config


DATABASE_URL = (
    f"postgresql://{env_config.POSTGRES_USER}:{quote(env_config.POSTGRES_PASSWORD)}@"
    f"{env_config.POSTGRES_HOST}:{env_config.POSTGRES_PORT}/{env_config.POSTGRES_DB}"
)

metadata = MetaData()
database = Database(DATABASE_URL, min_size=1, max_size=10)
src/core/taskiq_middlewares.py
@@ -1,40 +0,0 @@
from inspect import signature
from typing import Any

from taskiq import SimpleRetryMiddleware
from taskiq.message import TaskiqMessage
from taskiq.result import TaskiqResult
from taskiq_dependencies.dependency import Dependency


class FastAPIRetryMiddleware(SimpleRetryMiddleware):
    @staticmethod
    def _remove_depends(
        task_func: Any, message_kwargs: dict[str, Any]
    ) -> dict[str, Any]:
        sig = signature(task_func)

        keys_to_remove = []

        for key in message_kwargs.keys():
            param = sig.parameters.get(key, None)

            if param is None:
                continue

            if isinstance(param.default, Dependency):
                keys_to_remove.append(key)

        for key in keys_to_remove:
            message_kwargs.pop(key)

        return message_kwargs

    async def on_error(
        self, message: TaskiqMessage, result: TaskiqResult[Any], exception: Exception
    ) -> None:
        task_func = self.broker.available_tasks[message.task_name].original_func

        message.kwargs = self._remove_depends(task_func, message.kwargs)

        return await super().on_error(message, result, exception)
src/core/taskiq_worker.py
@@ -1,17 +0,0 @@
import taskiq_fastapi
from taskiq_redis import ListQueueBroker, RedisAsyncResultBackend

from core.config import REDIS_URL
from core.taskiq_middlewares import FastAPIRetryMiddleware


broker = (
    ListQueueBroker(url=REDIS_URL)
    .with_result_backend(
        RedisAsyncResultBackend(redis_url=REDIS_URL, result_ex_time=5 * 60)
    )
    .with_middlewares(FastAPIRetryMiddleware())
)


taskiq_fastapi.init(broker, "main:app")
19  src/db.rs  Normal file
@@ -0,0 +1,19 @@
use crate::{prisma::PrismaClient, config::CONFIG};


pub async fn get_prisma_client() -> PrismaClient {
    let database_url: String = format!(
        "postgresql://{}:{}@{}:{}/{}?connection_limit=1",
        CONFIG.postgres_user,
        CONFIG.postgres_password,
        CONFIG.postgres_host,
        CONFIG.postgres_port,
        CONFIG.postgres_db
    );

    PrismaClient::_builder()
        .with_url(database_url)
        .build()
        .await
        .unwrap()
}
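The real views.rs is not shown in this diff, so as a minimal sketch only: one plausible way get_prisma_client() plugs into axum 0.6 is to build the client once and share it behind an Arc as router state. The `Database` alias mirrors the `views::Database` type that src/services/mod.rs imports below; the route paths and handler are hypothetical:

use std::sync::Arc;

use axum::{extract::State, routing::get, Router};

use crate::db::get_prisma_client;
use crate::prisma::PrismaClient;

pub type Database = Arc<PrismaClient>;

// Illustrative handler: count rows in cached_files through the shared client.
async fn cached_count(State(db): State<Database>) -> String {
    db.cached_file().count(vec![]).exec().await.unwrap_or(0).to_string()
}

pub async fn get_router() -> Router {
    let client: Database = Arc::new(get_prisma_client().await);

    Router::new()
        .route("/healthcheck", get(|| async { "Ok!" }))
        .route("/cached_count", get(cached_count))
        .with_state(client)
}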
11  src/main.py
@@ -1,11 +0,0 @@
import sentry_sdk

from app.services.cache_updater import Retry
from core.app import start_app
from core.config import env_config


if env_config.SENTRY_DSN:
    sentry_sdk.init(dsn=env_config.SENTRY_DSN, ignore_errors=[Retry])

app = start_app()
30  src/main.rs  Normal file
@@ -0,0 +1,30 @@
pub mod config;
pub mod db;
pub mod prisma;
pub mod views;
pub mod services;

use std::net::SocketAddr;
use tracing::info;

use crate::views::get_router;


#[tokio::main]
async fn main() {
    tracing_subscriber::fmt()
        .with_target(false)
        .compact()
        .init();

    let addr = SocketAddr::from(([0, 0, 0, 0], 8080));

    let app = get_router().await;

    info!("Start webserver...");
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .await
        .unwrap();
    info!("Webserver shutdown...")
}
1360  src/prisma.rs  Normal file
File diff suppressed because one or more lines are too long
BIN  src/services/.DS_Store  vendored  Normal file
Binary file not shown.
49  src/services/book_library/mod.rs  Normal file
@@ -0,0 +1,49 @@
pub mod types;

use serde::de::DeserializeOwned;

use crate::config::CONFIG;

async fn _make_request<T>(
    url: &str,
    params: Vec<(&str, String)>,
) -> Result<T, Box<dyn std::error::Error + Send + Sync>>
where
    T: DeserializeOwned,
{
    let client = reqwest::Client::new();

    let formatted_url = format!("{}{}", CONFIG.library_url, url);

    let response = client
        .get(formatted_url)
        .query(&params)
        .header("Authorization", CONFIG.library_api_key.clone())
        .send()
        .await;

    let response = match response {
        Ok(v) => v,
        Err(err) => return Err(Box::new(err)),
    };

    let response = match response.error_for_status() {
        Ok(v) => v,
        Err(err) => return Err(Box::new(err)),
    };

    match response.json::<T>().await {
        Ok(v) => Ok(v),
        Err(err) => Err(Box::new(err)),
    }
}

pub async fn get_sources() -> Result<types::Source, Box<dyn std::error::Error + Send + Sync>> {
    _make_request("/api/v1/sources", vec![]).await
}

pub async fn get_book(
    book_id: i32,
) -> Result<types::BookWithRemote, Box<dyn std::error::Error + Send + Sync>> {
    _make_request(format!("/api/v1/books/{book_id}").as_str(), vec![]).await
}
src/services/book_library/types.rs
Normal file
108
src/services/book_library/types.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
use serde::Deserialize;
|
||||
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
pub struct Source {
|
||||
// id: u32,
|
||||
// name: String
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
pub struct BookAuthor {
|
||||
pub id: u32,
|
||||
pub first_name: String,
|
||||
pub last_name: String,
|
||||
pub middle_name: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
pub struct Book {
|
||||
pub id: u32,
|
||||
pub title: String,
|
||||
pub lang: String,
|
||||
pub file_type: String,
|
||||
pub uploaded: String,
|
||||
pub authors: Vec<BookAuthor>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
pub struct BookWithRemote {
|
||||
pub id: u32,
|
||||
pub remote_id: u32,
|
||||
pub title: String,
|
||||
pub lang: String,
|
||||
pub file_type: String,
|
||||
pub uploaded: String,
|
||||
pub authors: Vec<BookAuthor>,
|
||||
}
|
||||
|
||||
impl BookWithRemote {
|
||||
pub fn from_book(book: Book, remote_id: u32) -> Self {
|
||||
Self {
|
||||
id: book.id,
|
||||
remote_id,
|
||||
title: book.title,
|
||||
lang: book.lang,
|
||||
file_type: book.file_type,
|
||||
uploaded: book.uploaded,
|
||||
authors: book.authors
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl BookAuthor {
|
||||
pub fn get_caption(self) -> String {
|
||||
let mut parts: Vec<String> = vec![];
|
||||
|
||||
if !self.last_name.is_empty() {
|
||||
parts.push(self.last_name);
|
||||
}
|
||||
|
||||
if !self.first_name.is_empty() {
|
||||
parts.push(self.first_name);
|
||||
}
|
||||
|
||||
if !self.middle_name.is_empty() {
|
||||
parts.push(self.middle_name);
|
||||
}
|
||||
|
||||
let joined_parts = parts.join(" ");
|
||||
|
||||
format!("👤 {joined_parts}")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl BookWithRemote {
|
||||
pub fn get_caption(self) -> String {
|
||||
let BookWithRemote {
|
||||
title,
|
||||
authors,
|
||||
..
|
||||
} = self;
|
||||
|
||||
let caption_title = format!("📖 {title}");
|
||||
|
||||
let author_captions: Vec<String> = authors
|
||||
.into_iter()
|
||||
.map(|a| a.get_caption())
|
||||
.collect();
|
||||
|
||||
let mut author_parts: Vec<String> = vec![];
|
||||
let mut author_parts_len = 3;
|
||||
|
||||
for author_caption in author_captions {
|
||||
if caption_title.len() + author_parts_len + author_caption.len() + 1 <= 1024 {
|
||||
author_parts_len = author_caption.len() + 1;
|
||||
author_parts.push(author_caption);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let caption_authors = author_parts.join("\n");
|
||||
|
||||
format!("{caption_title}\n\n{caption_authors}")
|
||||
}
|
||||
}
|
||||
19  src/services/download_utils.rs  Normal file
@@ -0,0 +1,19 @@
use futures::TryStreamExt;
use reqwest::Response;
use tokio::io::AsyncRead;
use tokio_util::compat::FuturesAsyncReadCompatExt;


pub struct DownloadResult {
    pub response: Response,
    pub filename: String,
    pub filename_ascii: String,
    pub caption: String,
}

pub fn get_response_async_read(it: Response) -> impl AsyncRead {
    it.bytes_stream()
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
        .into_async_read()
        .compat()
}
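get_response_async_read bridges reqwest's byte stream into a tokio AsyncRead: futures' into_async_read needs a stream of io::Error results (hence the map_err), and compat() converts the futures-io reader into a tokio-io one. A minimal consumption sketch, with the output path as a placeholder assumption:

use tokio::{fs::File, io};

// Sketch only: "downloaded.bin" is a placeholder path.
async fn save_response(resp: reqwest::Response) -> std::io::Result<u64> {
    // Box::pin because the returned `impl AsyncRead` is not declared Unpin,
    // and tokio::io::copy requires an Unpin reader.
    let mut reader = Box::pin(get_response_async_read(resp));
    let mut file = File::create("downloaded.bin").await?;
    // Drive the body to disk chunk by chunk, without buffering it all in memory.
    io::copy(&mut reader, &mut file).await
}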
57  src/services/downloader/mod.rs  Normal file
@@ -0,0 +1,57 @@
use reqwest::Response;
use serde::Deserialize;

use crate::config::CONFIG;


#[derive(Deserialize)]
pub struct FilenameData {
    pub filename: String,
    pub filename_ascii: String,
}


pub async fn download_from_downloader(
    remote_id: u32,
    object_id: i32,
    object_type: String,
) -> Result<Response, Box<dyn std::error::Error + Send + Sync>> {
    let url = format!(
        "{}/download/{remote_id}/{object_id}/{object_type}",
        CONFIG.downloader_url
    );

    let response = reqwest::Client::new()
        .get(url)
        .header("Authorization", &CONFIG.downloader_api_key)
        .send()
        .await?
        .error_for_status()?;

    Ok(response)
}


pub async fn get_filename(
    object_id: i32,
    object_type: String,
) -> Result<FilenameData, Box<dyn std::error::Error + Send + Sync>> {
    let url = format!(
        "{}/filename/{object_id}/{object_type}",
        CONFIG.downloader_url
    );

    let response = reqwest::Client::new()
        .get(url)
        .header("Authorization", &CONFIG.downloader_api_key)
        .send()
        .await?
        .error_for_status()?;

    match response.json::<FilenameData>().await {
        Ok(v) => Ok(v),
        Err(err) => Err(Box::new(err)),
    }
}
124
src/services/mod.rs
Normal file
124
src/services/mod.rs
Normal file
@@ -0,0 +1,124 @@
|
||||
pub mod book_library;
pub mod download_utils;
pub mod telegram_files;
pub mod downloader;

use tracing::log;

use crate::{prisma::cached_file, views::Database};

use self::{
    book_library::get_book,
    download_utils::DownloadResult,
    downloader::{download_from_downloader, get_filename, FilenameData},
    telegram_files::{download_from_telegram_files, upload_to_telegram_files, UploadData},
};


pub async fn get_cached_file_or_cache(
    object_id: i32,
    object_type: String,
    db: Database,
) -> Option<cached_file::Data> {
    let cached_file = db
        .cached_file()
        .find_unique(cached_file::object_id_object_type(object_id, object_type.clone()))
        .exec()
        .await
        .unwrap();

    match cached_file {
        Some(cached_file) => Some(cached_file),
        None => cache_file(object_id, object_type, db).await,
    }
}


pub async fn cache_file(
    object_id: i32,
    object_type: String,
    db: Database,
) -> Option<cached_file::Data> {
    let book = match get_book(object_id).await {
        Ok(v) => v,
        Err(err) => {
            log::error!("{:?}", err);
            return None;
        }
    };

    // Fetch the file from the downloader service...
    let downloader_result = match download_from_downloader(
        book.remote_id,
        object_id,
        object_type.clone(),
    )
    .await
    {
        Ok(v) => v,
        Err(err) => {
            log::error!("{:?}", err);
            return None;
        }
    };

    // ...re-upload it to the Telegram files service...
    let UploadData { chat_id, message_id } =
        match upload_to_telegram_files(downloader_result, book.get_caption()).await {
            Ok(v) => v,
            Err(err) => {
                log::error!("{:?}", err);
                return None;
            }
        };

    // ...and persist the resulting message coordinates in the cache.
    Some(
        db.cached_file()
            .create(object_id, object_type, message_id, chat_id, vec![])
            .exec()
            .await
            .unwrap(),
    )
}


pub async fn download_from_cache(
    cached_data: cached_file::Data,
) -> Option<DownloadResult> {
    // Run the three upstream requests concurrently.
    let response_task = tokio::task::spawn(download_from_telegram_files(
        cached_data.message_id,
        cached_data.chat_id,
    ));
    let filename_task = tokio::task::spawn(get_filename(
        cached_data.object_id,
        cached_data.object_type.clone(),
    ));
    let book_task = tokio::task::spawn(get_book(cached_data.object_id));

    let response = match response_task.await.unwrap() {
        Ok(v) => v,
        Err(err) => {
            log::error!("{:?}", err);
            return None;
        }
    };

    let filename_data = match filename_task.await.unwrap() {
        Ok(v) => v,
        Err(err) => {
            log::error!("{:?}", err);
            return None;
        }
    };

    let book = match book_task.await.unwrap() {
        Ok(v) => v,
        Err(err) => {
            log::error!("{:?}", err);
            return None;
        }
    };

    let FilenameData { filename, filename_ascii } = filename_data;
    let caption = book.get_caption();

    Some(DownloadResult {
        response,
        filename,
        filename_ascii,
        caption,
    })
}
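A minimal sketch of the cache-or-fetch flow from a caller's side (assumes this module's functions and the Database alias from views.rs are in scope; the id and "fb2" type are made up):

async fn warm_and_fetch(db: Database) {
    if let Some(cached) = get_cached_file_or_cache(42, "fb2".to_string(), db).await {
        if let Some(result) = download_from_cache(cached).await {
            println!("got {} with caption {:?}", result.filename, result.caption);
        }
    }
}

The three spawned tasks in download_from_cache could equally be driven with tokio::try_join! on the bare futures; spawning keeps them running on the runtime even if the enclosing future is dropped.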
87
src/services/telegram_files/mod.rs
Normal file
@@ -0,0 +1,87 @@
use reqwest::{Response, multipart::{Form, Part}, header};
use serde::Deserialize;

use crate::config::CONFIG;


#[derive(Deserialize)]
pub struct UploadData {
    pub chat_id: i64,
    pub message_id: i64,
}


#[derive(Deserialize)]
pub struct UploadResult {
    pub backend: String,
    pub data: UploadData,
}


pub async fn download_from_telegram_files(
    message_id: i64,
    chat_id: i64,
) -> Result<Response, Box<dyn std::error::Error + Send + Sync>> {
    let url = format!(
        "{}/api/v1/files/download_by_message/{chat_id}/{message_id}",
        CONFIG.files_url
    );

    let response = reqwest::Client::new()
        .get(url)
        .header("Authorization", CONFIG.library_api_key.clone())
        .send()
        .await?
        .error_for_status()?;

    Ok(response)
}


pub async fn upload_to_telegram_files(
    data_response: Response,
    caption: String,
) -> Result<UploadData, Box<dyn std::error::Error + Send + Sync>> {
    let url = format!("{}/api/v1/files/upload/", CONFIG.files_url);

    let headers = data_response.headers();

    // Note: these lookups panic if the source response lacks a
    // Content-Length or the ASCII filename header.
    let file_size = headers
        .get(header::CONTENT_LENGTH)
        .unwrap()
        .to_str()
        .unwrap()
        .to_string();

    let filename = headers
        .get("x-filename-b64-ascii")
        .unwrap()
        .to_str()
        .unwrap()
        .to_string();

    // Stream the downloaded response straight into the multipart body
    // instead of buffering the whole file in memory.
    let part = Part::stream(data_response).file_name(filename);

    let form = Form::new()
        .text("caption", caption)
        .text("file_size", file_size)
        .part("file", part);

    let response = reqwest::Client::new()
        .post(url)
        .multipart(form)
        .send()
        .await?
        .error_for_status()?;

    match response.json::<UploadResult>().await {
        Ok(v) => Ok(v.data),
        Err(err) => Err(Box::new(err)),
    }
}
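A hedged round-trip sketch of these two helpers (assumes both are in scope; the chat/message ids and caption are made up, and the source response must carry the Content-Length and filename headers the upload helper unwraps):

async fn recache_message(
    chat_id: i64,
    message_id: i64,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let response = download_from_telegram_files(message_id, chat_id).await?;
    let data = upload_to_telegram_files(response, "Example caption".to_string()).await?;
    println!("re-uploaded as message {} in chat {}", data.message_id, data.chat_id);
    Ok(())
}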
158
src/views.rs
Normal file
@@ -0,0 +1,158 @@
use axum::{
    body::StreamBody,
    extract::Path,
    http::{self, header, Request, StatusCode},
    middleware::{self, Next},
    response::{AppendHeaders, IntoResponse, Response},
    routing::{delete, get, post},
    Extension, Json, Router,
};
use axum_prometheus::PrometheusMetricLayer;
use base64::{engine::general_purpose, Engine};
use std::sync::Arc;
use tokio_util::io::ReaderStream;
use tower_http::trace::{self, TraceLayer};
use tracing::Level;

use crate::{
    config::CONFIG,
    db::get_prisma_client,
    prisma::{cached_file, PrismaClient},
    services::{download_from_cache, download_utils::get_response_async_read, get_cached_file_or_cache},
};


pub type Database = Arc<PrismaClient>;

// Handlers

async fn get_cached_file(
    Path((object_id, object_type)): Path<(i32, String)>,
    Extension(Ext { db, .. }): Extension<Ext>,
) -> impl IntoResponse {
    match get_cached_file_or_cache(object_id, object_type, db).await {
        Some(cached_file) => Json(cached_file).into_response(),
        None => StatusCode::NOT_FOUND.into_response(),
    }
}

async fn download_cached_file(
    Path((object_id, object_type)): Path<(i32, String)>,
    Extension(Ext { db }): Extension<Ext>,
) -> impl IntoResponse {
    let cached_file = match get_cached_file_or_cache(object_id, object_type, db).await {
        Some(cached_file) => cached_file,
        None => return StatusCode::NO_CONTENT.into_response(),
    };

    let data = match download_from_cache(cached_file).await {
        Some(v) => v,
        None => return StatusCode::NO_CONTENT.into_response(),
    };

    let filename = data.filename.clone();
    let filename_ascii = data.filename_ascii.clone();
    let caption = data.caption.clone();

    let encoder = general_purpose::STANDARD;

    // Stream the upstream response straight through to the client.
    let reader = get_response_async_read(data.response);
    let stream = ReaderStream::new(reader);
    let body = StreamBody::new(stream);

    // The ASCII filename goes into Content-Disposition; the original filename
    // and caption may contain non-ASCII, so they travel base64-encoded.
    let headers = AppendHeaders([
        (
            header::CONTENT_DISPOSITION,
            format!("attachment; filename={filename_ascii}"),
        ),
        (
            header::HeaderName::from_static("x-filename-b64"),
            encoder.encode(filename),
        ),
        (
            header::HeaderName::from_static("x-caption-b64"),
            encoder.encode(caption),
        ),
    ]);

    (headers, body).into_response()
}

async fn delete_cached_file(
    Path((object_id, object_type)): Path<(i32, String)>,
    Extension(Ext { db, .. }): Extension<Ext>,
) -> impl IntoResponse {
    let cached_file = db
        .cached_file()
        .find_unique(cached_file::object_id_object_type(object_id, object_type.clone()))
        .exec()
        .await
        .unwrap();

    match cached_file {
        Some(v) => {
            db.cached_file()
                .delete(cached_file::object_id_object_type(object_id, object_type))
                .exec()
                .await
                .unwrap();

            Json(v).into_response()
        }
        None => StatusCode::NOT_FOUND.into_response(),
    }
}

async fn update_cache(
    _ext: Extension<Ext>,
) -> impl IntoResponse {
    StatusCode::OK.into_response() // TODO
}

// Middleware

async fn auth<B>(req: Request<B>, next: Next<B>) -> Result<Response, StatusCode> {
    let auth_header = req
        .headers()
        .get(http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok());

    let auth_header = if let Some(auth_header) = auth_header {
        auth_header
    } else {
        return Err(StatusCode::UNAUTHORIZED);
    };

    if auth_header != CONFIG.api_key {
        return Err(StatusCode::UNAUTHORIZED);
    }

    Ok(next.run(req).await)
}


#[derive(Clone)]
struct Ext {
    pub db: Arc<PrismaClient>,
}


pub async fn get_router() -> Router {
    let db = Arc::new(get_prisma_client().await);

    let ext = Ext { db };

    let (prometheus_layer, metric_handle) = PrometheusMetricLayer::pair();

    let app_router = Router::new()
        .route("/:object_id/:object_type/", get(get_cached_file))
        .route("/download/:object_id/:object_type/", get(download_cached_file))
        .route("/:object_id/:object_type/", delete(delete_cached_file))
        .route("/update_cache", post(update_cache))
        .layer(middleware::from_fn(auth))
        .layer(Extension(ext))
        .layer(prometheus_layer);

    let metric_router = Router::new()
        .route("/metrics", get(|| async move { metric_handle.render() }));

    Router::new()
        .nest("/api/v1/", app_router)
        .nest("/", metric_router)
        .layer(
            TraceLayer::new_for_http()
                .make_span_with(trace::DefaultMakeSpan::new().level(Level::INFO))
                .on_response(trace::DefaultOnResponse::new().level(Level::INFO)),
        )
}
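A minimal sketch of serving this router (the axum 0.6-era Server API, matching the Next<B>/StreamBody usage above; the bind address is hypothetical):

#[tokio::main]
async fn main() {
    let app = get_router().await;

    axum::Server::bind(&"0.0.0.0:8080".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}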