This commit is contained in:
2023-08-06 15:04:22 +02:00
parent c472d5038d
commit d1ccc71e4b
3 changed files with 25 additions and 9 deletions

View File

@@ -35,8 +35,11 @@ class FileUploader:
def user_storages(cls) -> list[UserStorage]:
return StoragesContainer.USER_STORAGES
def __init__(self, file: UploadFile, caption: Optional[str] = None) -> None:
def __init__(
self, file: UploadFile, file_size: int, caption: Optional[str] = None
) -> None:
self.file = file
self.file_size = file_size
self.caption = caption
self.upload_data: Optional[Data] = None
@@ -61,7 +64,9 @@ class FileUploader:
wrapped = Wrapper(self.file.file, self.file.filename)
data = await storage.upload(wrapped, caption=self.caption)
data = await storage.upload(
wrapped, file_size=self.file_size, caption=self.caption
)
if not data:
return False
@@ -101,9 +106,9 @@ class FileUploader:
@classmethod
async def upload(
cls, file: UploadFile, caption: Optional[str] = None
cls, file: UploadFile, file_size: int, caption: Optional[str] = None
) -> Optional[UploadedFile]:
uploader = cls(file, caption)
uploader = cls(file, file_size, caption)
upload_result = await uploader._upload()
if not upload_result:

View File

@@ -22,15 +22,22 @@ class BaseStorage:
...
async def upload(
self, file: telethon.hints.FileLike, caption: Optional[str] = None
self,
file: telethon.hints.FileLike,
file_size: int,
caption: Optional[str] = None,
) -> Optional[tuple[Union[str, int], int]]:
try:
uploaded_file = await self.client.upload_file(file, file_size=file_size)
if caption:
message = await self.client.send_file(
entity=self.channel_id, file=file, caption=caption
entity=self.channel_id, file=uploaded_file, caption=caption
)
else:
message = await self.client.send_file(entity=self.channel_id, file=file)
message = await self.client.send_file(
entity=self.channel_id, file=uploaded_file
)
except telethon.errors.FilePartInvalidError:
return None
except telethon.errors.PhotoInvalidError:

View File

@@ -15,8 +15,12 @@ router = APIRouter(
@router.post("/upload/", response_model=UploadedFile)
async def upload_file(file: UploadFile = File({}), caption: Optional[str] = Form({})):
return await FileUploader.upload(file, caption=caption)
async def upload_file(
file: UploadFile = File({}),
file_size: int = Form({}),
caption: Optional[str] = Form({}),
):
return await FileUploader.upload(file, file_size, caption=caption)
@router.get("/download_by_message/{chat_id}/{message_id}")