From 792c65222d85eb0bde55a4a2acb95088799e8057 Mon Sep 17 00:00:00 2001 From: crazybolillo Date: Thu, 23 Apr 2026 23:38:06 -0600 Subject: [PATCH] fix(agent): ignore cancel() when idle (#2156) Previously calling cancel() on an idle agent would make it so the next invocation was immediately cancelled. This is unexpected behavior, cancelling an idle agent should be a noop and have no effect on future invocations. Active agent loops are now tracked and cancel() won't have any effect if there are no active loops. This method works for both concurrent modes, as UNSAFE_REENTRANT does not use the existing invocation lock that could be used to tell if there was a running agent. --- src/strands/agent/agent.py | 17 ++++- .../strands/agent/test_agent_cancellation.py | 69 ++++++++++--------- 2 files changed, 52 insertions(+), 34 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 965969961..06747e481 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -268,6 +268,9 @@ def __init__( # Create internal cancel signal for graceful cancellation using threading.Event self._cancel_signal = threading.Event() + self._invocations = 0 + self._invocations_lock = threading.Lock() + self.tool_registry = ToolRegistry() # Process tool list if provided @@ -368,6 +371,14 @@ def __init__( self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self)) + def _track_invocation(self) -> None: + with self._invocations_lock: + self._invocations += 1 + + def _untrack_invocation(self) -> None: + with self._invocations_lock: + self._invocations -= 1 + def cancel(self) -> None: """Cancel the currently running agent invocation. @@ -397,7 +408,9 @@ def cancel(self) -> None: Note: Multiple calls to cancel() are safe and idempotent. """ - self._cancel_signal.set() + with self._invocations_lock: + if self._invocations > 0: + self._cancel_signal.set() @property def system_prompt(self) -> str | None: @@ -982,6 +995,7 @@ async def _execute_event_loop_cycle( structured_output_context.register_tool(self.tool_registry) try: + self._track_invocation() events = event_loop_cycle( agent=self, invocation_state=invocation_state, @@ -1003,6 +1017,7 @@ async def _execute_event_loop_cycle( yield event finally: + self._untrack_invocation() if structured_output_context: structured_output_context.cleanup(self.tool_registry) diff --git a/tests/strands/agent/test_agent_cancellation.py b/tests/strands/agent/test_agent_cancellation.py index 756e96485..7ac917c4d 100644 --- a/tests/strands/agent/test_agent_cancellation.py +++ b/tests/strands/agent/test_agent_cancellation.py @@ -17,12 +17,26 @@ } +class DelayedModelProvider(MockedModelProvider): + """Model provider that blocks streaming until signaled, for cancel timing tests.""" + + def __init__(self, responses: list) -> None: + super().__init__(responses) + self.streaming_started = asyncio.Event() + self.cancel_ready = asyncio.Event() + + async def stream(self, *args, **kwargs): + self.streaming_started.set() + await self.cancel_ready.wait() + async for event in super().stream(*args, **kwargs): + yield event + + @pytest.mark.asyncio async def test_agent_cancel_before_invocation(): """Test agent.cancel() before invocation starts. - Verifies that calling cancel() before invoke_async() results in - immediate cancellation without any model calls. + Verifies that calling cancel() before invoke_async() (agent runs) has no effect. """ agent = Agent(model=MockedModelProvider([DEFAULT_RESPONSE])) @@ -31,8 +45,8 @@ async def test_agent_cancel_before_invocation(): result = await agent.invoke_async("Hello") - assert result.stop_reason == "cancelled" - assert result.message == {"role": "assistant", "content": [{"text": "Cancelled by user"}], "metadata": ANY} + assert result.stop_reason == "end_turn" + assert result.message == {**DEFAULT_RESPONSE, "metadata": ANY} @pytest.mark.asyncio @@ -42,23 +56,13 @@ async def test_agent_cancel_during_execution(): Verifies that calling cancel() while the agent is running stops execution at the next checkpoint. """ - streaming_started = asyncio.Event() - cancel_ready = asyncio.Event() - - class DelayedModelProvider(MockedModelProvider): - async def stream(self, *args, **kwargs): - streaming_started.set() - # Block until cancel has been called - await cancel_ready.wait() - async for event in super().stream(*args, **kwargs): - yield event - - agent = Agent(model=DelayedModelProvider([DEFAULT_RESPONSE])) + model = DelayedModelProvider([DEFAULT_RESPONSE]) + agent = Agent(model=model) async def cancel_when_ready(): - await streaming_started.wait() + await model.streaming_started.wait() agent.cancel() - cancel_ready.set() + model.cancel_ready.set() cancel_task = asyncio.create_task(cancel_when_ready()) result = await agent.invoke_async("Hello") @@ -128,7 +132,7 @@ async def test_agent_cancel_idempotent(): result = await agent.invoke_async("Hello") - assert result.stop_reason == "cancelled" + assert result.stop_reason == "end_turn" @pytest.mark.asyncio @@ -138,24 +142,16 @@ async def test_agent_cancel_from_thread(): Verifies thread-safety of the cancel() method when called from a background thread. """ - streaming_started = asyncio.Event() - cancel_ready = asyncio.Event() loop = asyncio.get_running_loop() - class DelayedModelProvider(MockedModelProvider): - async def stream(self, *args, **kwargs): - streaming_started.set() - await cancel_ready.wait() - async for event in super().stream(*args, **kwargs): - yield event - - agent = Agent(model=DelayedModelProvider([DEFAULT_RESPONSE])) + model = DelayedModelProvider([DEFAULT_RESPONSE]) + agent = Agent(model=model) def cancel_from_thread(): # Wait for streaming to start before cancelling - asyncio.run_coroutine_threadsafe(streaming_started.wait(), loop).result() + asyncio.run_coroutine_threadsafe(model.streaming_started.wait(), loop).result() agent.cancel() - loop.call_soon_threadsafe(cancel_ready.set) + loop.call_soon_threadsafe(model.cancel_ready.set) thread = threading.Thread(target=cancel_from_thread) thread.start() @@ -279,10 +275,17 @@ async def test_agent_cancel_continue_after(): Verifies that the cancel signal is cleared after an invocation completes, allowing subsequent invocations to run normally. """ - agent = Agent(model=MockedModelProvider([DEFAULT_RESPONSE, DEFAULT_RESPONSE])) + model = DelayedModelProvider([DEFAULT_RESPONSE, DEFAULT_RESPONSE]) + agent = Agent(model=model) - agent.cancel() + async def cancel_when_ready(): + await model.streaming_started.wait() + agent.cancel() + model.cancel_ready.set() + + cancel_task = asyncio.create_task(cancel_when_ready()) result1 = await agent.invoke_async("Hello") + await cancel_task assert result1.stop_reason == "cancelled" # Second invocation should work normally