Skip to content

Commit 7c87c6a

Browse files
author
onefeng
committed
fix: run send middlewares when requeueing tasks
1 parent 9f8db96 commit 7c87c6a

2 files changed

Lines changed: 49 additions & 1 deletion

File tree

taskiq/context.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from typing import TYPE_CHECKING
22

33
from taskiq.abc.broker import AsyncBroker
4+
from taskiq.abc.middleware import TaskiqMiddleware
45
from taskiq.exceptions import NoResultError, TaskRejectedError
56
from taskiq.message import TaskiqMessage
7+
from taskiq.utils import maybe_awaitable
68

79
if TYPE_CHECKING: # pragma: no cover
810
from taskiq.state import TaskiqState
@@ -30,7 +32,17 @@ async def requeue(self) -> None:
3032
requeue_count = int(self.message.labels.get("X-Taskiq-requeue", 0))
3133
requeue_count += 1
3234
self.message.labels["X-Taskiq-requeue"] = str(requeue_count)
33-
await self.broker.kick(self.broker.formatter.dumps(self.message))
35+
message = self.message
36+
for middleware in self.broker.middlewares:
37+
if middleware.__class__.pre_send != TaskiqMiddleware.pre_send:
38+
message = await maybe_awaitable(middleware.pre_send(message))
39+
40+
await self.broker.kick(self.broker.formatter.dumps(message))
41+
42+
for middleware in reversed(self.broker.middlewares):
43+
if middleware.__class__.post_send != TaskiqMiddleware.post_send:
44+
await maybe_awaitable(middleware.post_send(message))
45+
3446
raise NoResultError
3547

3648
def reject(self) -> None:

tests/test_requeue.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from taskiq import Context, InMemoryBroker, TaskiqDepends
2+
from taskiq.abc.middleware import TaskiqMiddleware
3+
from taskiq.message import TaskiqMessage
24

35

46
async def test_requeue() -> None:
@@ -46,3 +48,37 @@ async def task(_: None = TaskiqDepends(dep_func)) -> None:
4648
)
4749

4850
assert runs_count == 2
51+
52+
53+
async def test_requeue_triggers_send_middlewares() -> None:
54+
broker = InMemoryBroker()
55+
runs_count = 0
56+
57+
class CountingMiddleware(TaskiqMiddleware):
58+
def __init__(self) -> None:
59+
super().__init__()
60+
self.pre_send_calls = 0
61+
self.post_send_calls = 0
62+
63+
def pre_send(self, message: TaskiqMessage) -> TaskiqMessage:
64+
self.pre_send_calls += 1
65+
return message
66+
67+
def post_send(self, message: TaskiqMessage) -> None:
68+
self.post_send_calls += 1
69+
70+
middleware = CountingMiddleware()
71+
broker.add_middlewares(middleware)
72+
73+
@broker.task
74+
async def task(context: Context = TaskiqDepends()) -> None:
75+
nonlocal runs_count
76+
runs_count += 1
77+
if runs_count < 2:
78+
await context.requeue()
79+
80+
kicked = await task.kiq()
81+
await kicked.wait_result()
82+
83+
assert middleware.pre_send_calls == 2
84+
assert middleware.post_send_calls == 2

0 commit comments

Comments
 (0)