|
11 | 11 | import logging |
12 | 12 | from collections.abc import Awaitable, Generator |
13 | 13 | from dataclasses import dataclass |
14 | | -from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable |
| 14 | +from typing import ( |
| 15 | + TYPE_CHECKING, |
| 16 | + Any, |
| 17 | + Generic, |
| 18 | + Protocol, |
| 19 | + TypeVar, |
| 20 | + get_type_hints, |
| 21 | + runtime_checkable, |
| 22 | +) |
15 | 23 |
|
16 | 24 | from ..interrupt import Interrupt, InterruptException |
17 | 25 |
|
@@ -157,28 +165,99 @@ def __init__(self) -> None: |
157 | 165 | """Initialize an empty hook registry.""" |
158 | 166 | self._registered_callbacks: dict[type, list[HookCallback]] = {} |
159 | 167 |
|
160 | | - def add_callback(self, event_type: type[TEvent], callback: HookCallback[TEvent]) -> None: |
| 168 | + def add_callback( |
| 169 | + self, |
| 170 | + event_type: type[TEvent] | None, |
| 171 | + callback: HookCallback[TEvent], |
| 172 | + ) -> None: |
161 | 173 | """Register a callback function for a specific event type. |
162 | 174 |
|
| 175 | + If ``event_type`` is None, then this will check the callback handler type hint |
| 176 | + for the lifecycle event type. |
| 177 | +
|
163 | 178 | Args: |
164 | 179 | event_type: The class type of events this callback should handle. |
165 | 180 | callback: The callback function to invoke when events of this type occur. |
166 | 181 |
|
| 182 | + Raises: |
| 183 | + ValueError: If event_type is not provided and cannot be inferred from |
| 184 | + the callback's type hints, or if AgentInitializedEvent is registered |
| 185 | + with an async callback. |
| 186 | +
|
167 | 187 | Example: |
168 | 188 | ```python |
169 | 189 | def my_handler(event: StartRequestEvent): |
170 | 190 | print("Request started") |
171 | 191 |
|
| 192 | + # With explicit event type |
172 | 193 | registry.add_callback(StartRequestEvent, my_handler) |
| 194 | +
|
| 195 | + # With event type inferred from type hint |
| 196 | + registry.add_callback(None, my_handler) |
173 | 197 | ``` |
174 | 198 | """ |
| 199 | + resolved_event_type: type[TEvent] |
| 200 | + |
| 201 | + # Support both add_callback(None, callback) and add_callback(event_type, callback) |
| 202 | + if event_type is None: |
| 203 | + # callback provided but event_type is None - infer it |
| 204 | + resolved_event_type = self._infer_event_type(callback) |
| 205 | + else: |
| 206 | + resolved_event_type = event_type |
| 207 | + |
175 | 208 | # Related issue: https://github.com/strands-agents/sdk-python/issues/330 |
176 | | - if event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback): |
| 209 | + if resolved_event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback): |
177 | 210 | raise ValueError("AgentInitializedEvent can only be registered with a synchronous callback") |
178 | 211 |
|
179 | | - callbacks = self._registered_callbacks.setdefault(event_type, []) |
| 212 | + callbacks = self._registered_callbacks.setdefault(resolved_event_type, []) |
180 | 213 | callbacks.append(callback) |
181 | 214 |
|
| 215 | + def _infer_event_type(self, callback: HookCallback[TEvent]) -> type[TEvent]: |
| 216 | + """Infer the event type from a callback's type hints. |
| 217 | +
|
| 218 | + Args: |
| 219 | + callback: The callback function to inspect. |
| 220 | +
|
| 221 | + Returns: |
| 222 | + The event type inferred from the callback's first parameter type hint. |
| 223 | +
|
| 224 | + Raises: |
| 225 | + ValueError: If the event type cannot be inferred from the callback's type hints. |
| 226 | + """ |
| 227 | + try: |
| 228 | + hints = get_type_hints(callback) |
| 229 | + except Exception as e: |
| 230 | + logger.debug("callback=<%s>, error=<%s> | failed to get type hints", callback, e) |
| 231 | + raise ValueError( |
| 232 | + "failed to get type hints for callback | cannot infer event type, please provide event_type explicitly" |
| 233 | + ) from e |
| 234 | + |
| 235 | + # Get the first parameter's type hint |
| 236 | + sig = inspect.signature(callback) |
| 237 | + params = list(sig.parameters.values()) |
| 238 | + |
| 239 | + if not params: |
| 240 | + raise ValueError( |
| 241 | + "callback has no parameters | cannot infer event type, please provide event_type explicitly" |
| 242 | + ) |
| 243 | + |
| 244 | + first_param = params[0] |
| 245 | + type_hint = hints.get(first_param.name) |
| 246 | + |
| 247 | + if type_hint is None: |
| 248 | + raise ValueError( |
| 249 | + f"parameter=<{first_param.name}> has no type hint | " |
| 250 | + "cannot infer event type, please provide event_type explicitly" |
| 251 | + ) |
| 252 | + |
| 253 | + # Handle single type |
| 254 | + if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent): |
| 255 | + return type_hint # type: ignore[return-value] |
| 256 | + |
| 257 | + raise ValueError( |
| 258 | + f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent" |
| 259 | + ) |
| 260 | + |
182 | 261 | def add_hook(self, hook: HookProvider) -> None: |
183 | 262 | """Register all callbacks from a hook provider. |
184 | 263 |
|
|
0 commit comments