This commit is contained in:
2023-05-21 01:54:57 +02:00
parent 4fc6f53fee
commit 4d3aca532f

View File

@@ -1,35 +1,35 @@
import logging
from inspect import signature
from typing import Any
from taskiq import SimpleRetryMiddleware
from taskiq.message import TaskiqMessage
from taskiq.result import TaskiqResult
logger = logging.getLogger("taskiq_middleware")
from taskiq_dependencies.dependency import Dependency
class FastAPIREtryMiddleware(SimpleRetryMiddleware):
@staticmethod
def _is_need_to_remove(to_remove: list[Any], value: Any) -> bool:
return type(value) in to_remove
def _remove_depends(
task_func: Any, message_kwargs: dict[str, Any]
) -> dict[str, Any]:
sig = signature(task_func)
for key in message_kwargs.keys():
param = sig.parameters.get(key, None)
if param is None:
continue
if isinstance(param.default, Dependency):
message_kwargs.pop(key)
return message_kwargs
async def on_error(
self, message: TaskiqMessage, result: TaskiqResult[Any], exception: Exception
) -> None:
types_to_remove = list(self.broker.custom_dependency_context.keys())
task_func = self.broker.available_tasks[message.task_name].original_func
message.args = [
arg
for arg in message.args
if not self._is_need_to_remove(types_to_remove, arg)
]
message.kwargs = {
key: value
for key, value in message.kwargs.items()
if not self._is_need_to_remove(types_to_remove, value)
}
raise Exception(f"{self.broker.custom_dependency_context=} {message=}")
message.kwargs = self._remove_depends(task_func, message.kwargs)
return await super().on_error(message, result, exception)