Skip to content

Commit 0a31848

Browse files
Unshurestrands-agentclareliguori
authored
feat(agent): add add_hook convenience method for hook callback registration (#1706)
Co-authored-by: Strands Agent <217235299+strands-agent@users.noreply.github.com> Co-authored-by: Clare Liguori <liguori@amazon.com>
1 parent 0eae8a7 commit 0a31848

4 files changed

Lines changed: 277 additions & 5 deletions

File tree

src/strands/agent/agent.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,12 @@
3737
AfterInvocationEvent,
3838
AgentInitializedEvent,
3939
BeforeInvocationEvent,
40+
HookCallback,
4041
HookProvider,
4142
HookRegistry,
4243
MessageAddedEvent,
4344
)
45+
from ..hooks.registry import TEvent
4446
from ..interrupt import _InterruptState
4547
from ..models.bedrock import BedrockModel
4648
from ..models.model import Model
@@ -574,6 +576,47 @@ def cleanup(self) -> None:
574576
"""
575577
self.tool_registry.cleanup()
576578

579+
def add_hook(
580+
self, callback: HookCallback[TEvent], event_type: type[TEvent] | None = None, **kwargs: dict[str, Any]
581+
) -> None:
582+
"""Register a callback function for a specific event type.
583+
584+
This method supports two call patterns:
585+
1. ``add_hook(callback)`` - Event type inferred from callback's type hint
586+
2. ``add_hook(callback, event_type)`` - Event type specified explicitly
587+
588+
Callbacks can be either synchronous or asynchronous functions.
589+
590+
Args:
591+
callback: The callback function to invoke when events of this type occur.
592+
event_type: The class type of events this callback should handle.
593+
If not provided, the event type will be inferred from the callback's
594+
first parameter type hint.
595+
**kwargs: Additional arguments (ignored).
596+
597+
598+
Raises:
599+
ValueError: If event_type is not provided and cannot be inferred from
600+
the callback's type hints.
601+
602+
Example:
603+
```python
604+
def log_model_call(event: BeforeModelCallEvent) -> None:
605+
print(f"Calling model for agent: {event.agent.name}")
606+
607+
agent = Agent()
608+
609+
# With event type inferred from type hint
610+
agent.add_hook(log_model_call)
611+
612+
# With explicit event type
613+
agent.add_hook(log_model_call, BeforeModelCallEvent)
614+
```
615+
Docs:
616+
https://strandsagents.com/latest/documentation/docs/user-guide/concepts/agents/hooks/
617+
"""
618+
self.hooks.add_callback(event_type, callback)
619+
577620
def __del__(self) -> None:
578621
"""Clean up resources when agent is garbage collected."""
579622
# __del__ is called even when an exception is thrown in the constructor,

src/strands/hooks/registry.py

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,15 @@
1111
import logging
1212
from collections.abc import Awaitable, Generator
1313
from dataclasses import dataclass
14-
from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable
14+
from typing import (
15+
TYPE_CHECKING,
16+
Any,
17+
Generic,
18+
Protocol,
19+
TypeVar,
20+
get_type_hints,
21+
runtime_checkable,
22+
)
1523

1624
from ..interrupt import Interrupt, InterruptException
1725

@@ -157,28 +165,99 @@ def __init__(self) -> None:
157165
"""Initialize an empty hook registry."""
158166
self._registered_callbacks: dict[type, list[HookCallback]] = {}
159167

160-
def add_callback(self, event_type: type[TEvent], callback: HookCallback[TEvent]) -> None:
168+
def add_callback(
169+
self,
170+
event_type: type[TEvent] | None,
171+
callback: HookCallback[TEvent],
172+
) -> None:
161173
"""Register a callback function for a specific event type.
162174
175+
If ``event_type`` is None, then this will check the callback handler type hint
176+
for the lifecycle event type.
177+
163178
Args:
164179
event_type: The class type of events this callback should handle.
165180
callback: The callback function to invoke when events of this type occur.
166181
182+
Raises:
183+
ValueError: If event_type is not provided and cannot be inferred from
184+
the callback's type hints, or if AgentInitializedEvent is registered
185+
with an async callback.
186+
167187
Example:
168188
```python
169189
def my_handler(event: StartRequestEvent):
170190
print("Request started")
171191
192+
# With explicit event type
172193
registry.add_callback(StartRequestEvent, my_handler)
194+
195+
# With event type inferred from type hint
196+
registry.add_callback(None, my_handler)
173197
```
174198
"""
199+
resolved_event_type: type[TEvent]
200+
201+
# Support both add_callback(None, callback) and add_callback(event_type, callback)
202+
if event_type is None:
203+
# callback provided but event_type is None - infer it
204+
resolved_event_type = self._infer_event_type(callback)
205+
else:
206+
resolved_event_type = event_type
207+
175208
# Related issue: https://github.com/strands-agents/sdk-python/issues/330
176-
if event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback):
209+
if resolved_event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback):
177210
raise ValueError("AgentInitializedEvent can only be registered with a synchronous callback")
178211

179-
callbacks = self._registered_callbacks.setdefault(event_type, [])
212+
callbacks = self._registered_callbacks.setdefault(resolved_event_type, [])
180213
callbacks.append(callback)
181214

215+
def _infer_event_type(self, callback: HookCallback[TEvent]) -> type[TEvent]:
216+
"""Infer the event type from a callback's type hints.
217+
218+
Args:
219+
callback: The callback function to inspect.
220+
221+
Returns:
222+
The event type inferred from the callback's first parameter type hint.
223+
224+
Raises:
225+
ValueError: If the event type cannot be inferred from the callback's type hints.
226+
"""
227+
try:
228+
hints = get_type_hints(callback)
229+
except Exception as e:
230+
logger.debug("callback=<%s>, error=<%s> | failed to get type hints", callback, e)
231+
raise ValueError(
232+
"failed to get type hints for callback | cannot infer event type, please provide event_type explicitly"
233+
) from e
234+
235+
# Get the first parameter's type hint
236+
sig = inspect.signature(callback)
237+
params = list(sig.parameters.values())
238+
239+
if not params:
240+
raise ValueError(
241+
"callback has no parameters | cannot infer event type, please provide event_type explicitly"
242+
)
243+
244+
first_param = params[0]
245+
type_hint = hints.get(first_param.name)
246+
247+
if type_hint is None:
248+
raise ValueError(
249+
f"parameter=<{first_param.name}> has no type hint | "
250+
"cannot infer event type, please provide event_type explicitly"
251+
)
252+
253+
# Handle single type
254+
if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent):
255+
return type_hint # type: ignore[return-value]
256+
257+
raise ValueError(
258+
f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent"
259+
)
260+
182261
def add_hook(self, hook: HookProvider) -> None:
183262
"""Register all callbacks from a hook provider.
184263

tests/strands/agent/test_agent.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager
2121
from strands.agent.state import AgentState
2222
from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
23-
from strands.hooks import BeforeToolCallEvent
23+
from strands.hooks import BeforeInvocationEvent, BeforeModelCallEvent, BeforeToolCallEvent
2424
from strands.interrupt import Interrupt
2525
from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel
2626
from strands.session.repository_session_manager import RepositorySessionManager
@@ -2550,3 +2550,72 @@ def agent_tool(tool_context: ToolContext) -> str:
25502550
],
25512551
"role": "user",
25522552
}
2553+
2554+
2555+
def test_agent_add_hook_registers_callback():
2556+
"""Test that add_hook registers a callback with the hooks registry."""
2557+
agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}]))
2558+
callback = unittest.mock.Mock()
2559+
2560+
agent.add_hook(callback, BeforeModelCallEvent)
2561+
2562+
# Verify callback was registered by checking it gets invoked
2563+
agent("test prompt")
2564+
callback.assert_called_once()
2565+
# Verify it was called with the correct event type
2566+
call_args = callback.call_args[0]
2567+
assert isinstance(call_args[0], BeforeModelCallEvent)
2568+
2569+
2570+
def test_agent_add_hook_delegates_to_hooks_add_callback():
2571+
"""Test that add_hook delegates to self.hooks.add_callback."""
2572+
agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}]))
2573+
callback = unittest.mock.Mock()
2574+
2575+
# Spy on the hooks.add_callback method
2576+
with unittest.mock.patch.object(agent.hooks, "add_callback") as mock_add_callback:
2577+
agent.add_hook(callback, BeforeInvocationEvent)
2578+
mock_add_callback.assert_called_once_with(BeforeInvocationEvent, callback)
2579+
2580+
2581+
@pytest.mark.asyncio
2582+
async def test_agent_add_hook_works_with_async_callback():
2583+
"""Test that add_hook works with async callbacks."""
2584+
2585+
agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}]))
2586+
async_callback = unittest.mock.AsyncMock()
2587+
2588+
agent.add_hook(async_callback, BeforeModelCallEvent)
2589+
2590+
# Use stream_async to invoke the agent with async support
2591+
_ = [event async for event in agent.stream_async("test prompt")]
2592+
async_callback.assert_called_once()
2593+
# Verify it was called with the correct event type
2594+
call_args = async_callback.call_args[0]
2595+
assert isinstance(call_args[0], BeforeModelCallEvent)
2596+
2597+
2598+
def test_agent_add_hook_infers_event_type_from_callback():
2599+
"""Test that add_hook infers event type from callback type hint."""
2600+
agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}]))
2601+
callback_invoked = []
2602+
2603+
def typed_callback(event: BeforeModelCallEvent) -> None:
2604+
callback_invoked.append(event)
2605+
2606+
agent.add_hook(typed_callback)
2607+
agent("test prompt")
2608+
2609+
assert len(callback_invoked) == 1
2610+
assert isinstance(callback_invoked[0], BeforeModelCallEvent)
2611+
2612+
2613+
def test_agent_add_hook_raises_error_when_no_type_hint():
2614+
"""Test that add_hook raises error when event type cannot be inferred."""
2615+
agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}]))
2616+
2617+
def untyped_callback(event):
2618+
pass
2619+
2620+
with pytest.raises(ValueError, match="cannot infer event type"):
2621+
agent.add_hook(untyped_callback)

tests/strands/hooks/test_registry.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,84 @@ def test_hook_registry_invoke_callbacks_coroutine(registry, agent):
8787

8888
with pytest.raises(RuntimeError, match=r"use invoke_callbacks_async to invoke async callback"):
8989
registry.invoke_callbacks(BeforeInvocationEvent(agent=agent))
90+
91+
92+
def test_hook_registry_add_callback_infers_event_type(registry):
93+
"""Test that add_callback infers event type from callback type hint."""
94+
95+
def typed_callback(event: BeforeInvocationEvent) -> None:
96+
pass
97+
98+
# Register without explicit event_type - should infer from type hint
99+
registry.add_callback(None, typed_callback)
100+
101+
# Verify callback was registered
102+
assert BeforeInvocationEvent in registry._registered_callbacks
103+
assert typed_callback in registry._registered_callbacks[BeforeInvocationEvent]
104+
105+
106+
def test_hook_registry_add_callback_raises_error_no_type_hint(registry):
107+
"""Test that add_callback raises error when type hint is missing."""
108+
109+
def untyped_callback(event):
110+
pass
111+
112+
with pytest.raises(ValueError, match="cannot infer event type"):
113+
registry.add_callback(None, untyped_callback)
114+
115+
116+
def test_hook_registry_add_callback_raises_error_invalid_type_hint(registry):
117+
"""Test that add_callback raises error when type hint is not a BaseHookEvent subclass."""
118+
119+
def invalid_callback(event: str) -> None:
120+
pass
121+
122+
with pytest.raises(ValueError, match="must be a subclass of BaseHookEvent"):
123+
registry.add_callback(None, invalid_callback)
124+
125+
126+
def test_hook_registry_add_callback_raises_error_no_parameters(registry):
127+
"""Test that add_callback raises error when callback has no parameters."""
128+
129+
def no_param_callback() -> None:
130+
pass
131+
132+
with pytest.raises(ValueError, match="callback has no parameters"):
133+
registry.add_callback(None, no_param_callback)
134+
135+
136+
def test_hook_registry_add_callback_infers_event_type_when_callback_provided_without_event_type(registry):
137+
"""Test that add_callback infers event type when callback is provided but event_type is None."""
138+
139+
def typed_callback(event: BeforeInvocationEvent) -> None:
140+
pass
141+
142+
registry.add_callback(None, typed_callback)
143+
144+
assert BeforeInvocationEvent in registry._registered_callbacks
145+
assert typed_callback in registry._registered_callbacks[BeforeInvocationEvent]
146+
147+
148+
def test_hook_registry_add_callback_with_explicit_event_type_and_callback(registry):
149+
"""Test that add_callback works with explicit event_type and callback."""
150+
151+
def callback(event: BeforeInvocationEvent) -> None:
152+
pass
153+
154+
registry.add_callback(BeforeInvocationEvent, callback)
155+
156+
assert BeforeInvocationEvent in registry._registered_callbacks
157+
assert callback in registry._registered_callbacks[BeforeInvocationEvent]
158+
159+
160+
def test_hook_registry_add_callback_raises_error_on_type_hints_failure(registry):
161+
"""Test that add_callback raises error when get_type_hints fails."""
162+
163+
class BadCallback:
164+
def __call__(self, event: "NonExistentType") -> None: # noqa: F821
165+
pass
166+
167+
callback = BadCallback()
168+
169+
with pytest.raises(ValueError, match="failed to get type hints for callback"):
170+
registry.add_callback(None, callback)

0 commit comments

Comments
 (0)