This commit is contained in:
2023-05-21 01:54:57 +02:00
parent 4fc6f53fee
commit 4d3aca532f

View File

@@ -1,35 +1,35 @@
import logging from inspect import signature
from typing import Any from typing import Any
from taskiq import SimpleRetryMiddleware from taskiq import SimpleRetryMiddleware
from taskiq.message import TaskiqMessage from taskiq.message import TaskiqMessage
from taskiq.result import TaskiqResult from taskiq.result import TaskiqResult
from taskiq_dependencies.dependency import Dependency
logger = logging.getLogger("taskiq_middleware")
class FastAPIREtryMiddleware(SimpleRetryMiddleware): class FastAPIREtryMiddleware(SimpleRetryMiddleware):
@staticmethod @staticmethod
def _is_need_to_remove(to_remove: list[Any], value: Any) -> bool: def _remove_depends(
return type(value) in to_remove task_func: Any, message_kwargs: dict[str, Any]
) -> dict[str, Any]:
sig = signature(task_func)
for key in message_kwargs.keys():
param = sig.parameters.get(key, None)
if param is None:
continue
if isinstance(param.default, Dependency):
message_kwargs.pop(key)
return message_kwargs
async def on_error( async def on_error(
self, message: TaskiqMessage, result: TaskiqResult[Any], exception: Exception self, message: TaskiqMessage, result: TaskiqResult[Any], exception: Exception
) -> None: ) -> None:
types_to_remove = list(self.broker.custom_dependency_context.keys()) task_func = self.broker.available_tasks[message.task_name].original_func
message.args = [ message.kwargs = self._remove_depends(task_func, message.kwargs)
arg
for arg in message.args
if not self._is_need_to_remove(types_to_remove, arg)
]
message.kwargs = {
key: value
for key, value in message.kwargs.items()
if not self._is_need_to_remove(types_to_remove, value)
}
raise Exception(f"{self.broker.custom_dependency_context=} {message=}")
return await super().on_error(message, result, exception) return await super().on_error(message, result, exception)