diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 965969961..06747e481 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -268,6 +268,9 @@ def __init__( # Create internal cancel signal for graceful cancellation using threading.Event self._cancel_signal = threading.Event() + self._invocations = 0 + self._invocations_lock = threading.Lock() + self.tool_registry = ToolRegistry() # Process tool list if provided @@ -368,6 +371,14 @@ def __init__( self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self)) + def _track_invocation(self) -> None: + with self._invocations_lock: + self._invocations += 1 + + def _untrack_invocation(self) -> None: + with self._invocations_lock: + self._invocations -= 1 + def cancel(self) -> None: """Cancel the currently running agent invocation. @@ -397,7 +408,9 @@ def cancel(self) -> None: Note: Multiple calls to cancel() are safe and idempotent. """ - self._cancel_signal.set() + with self._invocations_lock: + if self._invocations > 0: + self._cancel_signal.set() @property def system_prompt(self) -> str | None: @@ -982,6 +995,7 @@ async def _execute_event_loop_cycle( structured_output_context.register_tool(self.tool_registry) try: + self._track_invocation() events = event_loop_cycle( agent=self, invocation_state=invocation_state, @@ -1003,6 +1017,7 @@ async def _execute_event_loop_cycle( yield event finally: + self._untrack_invocation() if structured_output_context: structured_output_context.cleanup(self.tool_registry) diff --git a/tests/strands/agent/test_agent_cancellation.py b/tests/strands/agent/test_agent_cancellation.py index 756e96485..7ac917c4d 100644 --- a/tests/strands/agent/test_agent_cancellation.py +++ b/tests/strands/agent/test_agent_cancellation.py @@ -17,12 +17,26 @@ } +class DelayedModelProvider(MockedModelProvider): + """Model provider that blocks streaming until signaled, for cancel timing tests.""" + + def __init__(self, responses: list) -> None: + super().__init__(responses) + self.streaming_started = asyncio.Event() + self.cancel_ready = asyncio.Event() + + async def stream(self, *args, **kwargs): + self.streaming_started.set() + await self.cancel_ready.wait() + async for event in super().stream(*args, **kwargs): + yield event + + @pytest.mark.asyncio async def test_agent_cancel_before_invocation(): """Test agent.cancel() before invocation starts. - Verifies that calling cancel() before invoke_async() results in - immediate cancellation without any model calls. + Verifies that calling cancel() before invoke_async() (agent runs) has no effect. """ agent = Agent(model=MockedModelProvider([DEFAULT_RESPONSE])) @@ -31,8 +45,8 @@ async def test_agent_cancel_before_invocation(): result = await agent.invoke_async("Hello") - assert result.stop_reason == "cancelled" - assert result.message == {"role": "assistant", "content": [{"text": "Cancelled by user"}], "metadata": ANY} + assert result.stop_reason == "end_turn" + assert result.message == {**DEFAULT_RESPONSE, "metadata": ANY} @pytest.mark.asyncio @@ -42,23 +56,13 @@ async def test_agent_cancel_during_execution(): Verifies that calling cancel() while the agent is running stops execution at the next checkpoint. """ - streaming_started = asyncio.Event() - cancel_ready = asyncio.Event() - - class DelayedModelProvider(MockedModelProvider): - async def stream(self, *args, **kwargs): - streaming_started.set() - # Block until cancel has been called - await cancel_ready.wait() - async for event in super().stream(*args, **kwargs): - yield event - - agent = Agent(model=DelayedModelProvider([DEFAULT_RESPONSE])) + model = DelayedModelProvider([DEFAULT_RESPONSE]) + agent = Agent(model=model) async def cancel_when_ready(): - await streaming_started.wait() + await model.streaming_started.wait() agent.cancel() - cancel_ready.set() + model.cancel_ready.set() cancel_task = asyncio.create_task(cancel_when_ready()) result = await agent.invoke_async("Hello") @@ -128,7 +132,7 @@ async def test_agent_cancel_idempotent(): result = await agent.invoke_async("Hello") - assert result.stop_reason == "cancelled" + assert result.stop_reason == "end_turn" @pytest.mark.asyncio @@ -138,24 +142,16 @@ async def test_agent_cancel_from_thread(): Verifies thread-safety of the cancel() method when called from a background thread. """ - streaming_started = asyncio.Event() - cancel_ready = asyncio.Event() loop = asyncio.get_running_loop() - class DelayedModelProvider(MockedModelProvider): - async def stream(self, *args, **kwargs): - streaming_started.set() - await cancel_ready.wait() - async for event in super().stream(*args, **kwargs): - yield event - - agent = Agent(model=DelayedModelProvider([DEFAULT_RESPONSE])) + model = DelayedModelProvider([DEFAULT_RESPONSE]) + agent = Agent(model=model) def cancel_from_thread(): # Wait for streaming to start before cancelling - asyncio.run_coroutine_threadsafe(streaming_started.wait(), loop).result() + asyncio.run_coroutine_threadsafe(model.streaming_started.wait(), loop).result() agent.cancel() - loop.call_soon_threadsafe(cancel_ready.set) + loop.call_soon_threadsafe(model.cancel_ready.set) thread = threading.Thread(target=cancel_from_thread) thread.start() @@ -279,10 +275,17 @@ async def test_agent_cancel_continue_after(): Verifies that the cancel signal is cleared after an invocation completes, allowing subsequent invocations to run normally. """ - agent = Agent(model=MockedModelProvider([DEFAULT_RESPONSE, DEFAULT_RESPONSE])) + model = DelayedModelProvider([DEFAULT_RESPONSE, DEFAULT_RESPONSE]) + agent = Agent(model=model) - agent.cancel() + async def cancel_when_ready(): + await model.streaming_started.wait() + agent.cancel() + model.cancel_ready.set() + + cancel_task = asyncio.create_task(cancel_when_ready()) result1 = await agent.invoke_async("Hello") + await cancel_task assert result1.stop_reason == "cancelled" # Second invocation should work normally