Skip to content

Commit 9ec38c0

Browse files
xuanyang15copybara-github
authored andcommitted
feat: Add on_model_error_callback in LlmAgent
Co-authored-by: Xuan Yang <xygoogle@google.com> PiperOrigin-RevId: 828560608
1 parent d6b928b commit 9ec38c0

4 files changed

Lines changed: 137 additions & 7 deletions

File tree

src/google/adk/agents/llm_agent.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,16 @@
8585
list[_SingleAfterModelCallback],
8686
]
8787

88+
_SingleOnModelErrorCallback: TypeAlias = Callable[
89+
[CallbackContext, LlmRequest, Exception],
90+
Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
91+
]
92+
93+
OnModelErrorCallback: TypeAlias = Union[
94+
_SingleOnModelErrorCallback,
95+
list[_SingleOnModelErrorCallback],
96+
]
97+
8898
_SingleBeforeToolCallback: TypeAlias = Callable[
8999
[BaseTool, dict[str, Any], ToolContext],
90100
Union[Awaitable[Optional[dict]], Optional[dict]],
@@ -364,6 +374,21 @@ class LlmAgent(BaseAgent):
364374
The content to return to the user. When present, the actual model response
365375
will be ignored and the provided content will be returned to user.
366376
"""
377+
on_model_error_callback: Optional[OnModelErrorCallback] = None
378+
"""Callback or list of callbacks to be called when a model call encounters an error.
379+
380+
When a list of callbacks is provided, the callbacks will be called in the
381+
order they are listed until a callback does not return None.
382+
383+
Args:
384+
callback_context: CallbackContext,
385+
llm_request: LlmRequest, The raw model request.
386+
error: The error from the model call.
387+
388+
Returns:
389+
The content to return to the user. When present, the error will be
390+
ignored and the provided content will be returned to user.
391+
"""
367392
before_tool_callback: Optional[BeforeToolCallback] = None
368393
"""Callback or list of callbacks to be called before calling the tool.
369394
@@ -587,6 +612,20 @@ def canonical_after_model_callbacks(self) -> list[_SingleAfterModelCallback]:
587612
return self.after_model_callback
588613
return [self.after_model_callback]
589614

615+
@property
616+
def canonical_on_model_error_callbacks(
617+
self,
618+
) -> list[_SingleOnModelErrorCallback]:
619+
"""The resolved self.on_model_error_callback field as a list of _SingleOnModelErrorCallback.
620+
621+
This method is only for use by Agent Development Kit.
622+
"""
623+
if not self.on_model_error_callback:
624+
return []
625+
if isinstance(self.on_model_error_callback, list):
626+
return self.on_model_error_callback
627+
return [self.on_model_error_callback]
628+
590629
@property
591630
def canonical_before_tool_callbacks(
592631
self,

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -977,6 +977,44 @@ async def _run_and_handle_error(
977977
Yields:
978978
A generator of LlmResponse.
979979
"""
980+
981+
from ...agents.llm_agent import LlmAgent
982+
983+
agent = invocation_context.agent
984+
if not isinstance(agent, LlmAgent):
985+
raise TypeError(
986+
f'Expected agent to be an LlmAgent, but got {type(agent)}'
987+
)
988+
989+
async def _run_on_model_error_callbacks(
990+
*,
991+
callback_context: CallbackContext,
992+
llm_request: LlmRequest,
993+
error: Exception,
994+
) -> Optional[LlmResponse]:
995+
error_response = (
996+
await invocation_context.plugin_manager.run_on_model_error_callback(
997+
callback_context=callback_context,
998+
llm_request=llm_request,
999+
error=error,
1000+
)
1001+
)
1002+
if error_response is not None:
1003+
return error_response
1004+
1005+
for callback in agent.canonical_on_model_error_callbacks:
1006+
error_response = callback(
1007+
callback_context=callback_context,
1008+
llm_request=llm_request,
1009+
error=error,
1010+
)
1011+
if inspect.isawaitable(error_response):
1012+
error_response = await error_response
1013+
if error_response is not None:
1014+
return error_response
1015+
1016+
return None
1017+
9801018
try:
9811019
async with Aclosing(response_generator) as agen:
9821020
async for response in agen:
@@ -985,12 +1023,10 @@ async def _run_and_handle_error(
9851023
callback_context = CallbackContext(
9861024
invocation_context, event_actions=model_response_event.actions
9871025
)
988-
error_response = (
989-
await invocation_context.plugin_manager.run_on_model_error_callback(
990-
callback_context=callback_context,
991-
llm_request=llm_request,
992-
error=model_error,
993-
)
1026+
error_response = await _run_on_model_error_callbacks(
1027+
callback_context=callback_context,
1028+
llm_request=llm_request,
1029+
error=model_error,
9941030
)
9951031
if error_response is not None:
9961032
yield error_response

tests/unittests/flows/llm_flows/test_model_callbacks.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,22 @@ def __call__(
5656
)
5757

5858

59+
class MockOnModelCallback(BaseModel):
60+
mock_response: str
61+
62+
def __call__(
63+
self,
64+
callback_context: CallbackContext,
65+
llm_request: LlmRequest,
66+
error: Exception,
67+
) -> LlmResponse:
68+
return LlmResponse(
69+
content=testing_utils.ModelContent(
70+
[types.Part.from_text(text=self.mock_response)]
71+
)
72+
)
73+
74+
5975
def noop_callback(**kwargs) -> Optional[LlmResponse]:
6076
pass
6177

@@ -140,3 +156,40 @@ async def test_after_model_callback_noop():
140156
assert testing_utils.simplify_events(
141157
await runner.run_async_with_new_session('test')
142158
) == [('root_agent', 'model_response')]
159+
160+
161+
@pytest.mark.asyncio
162+
async def test_on_model_callback_model_error_noop():
163+
"""Test that the on_model_error_callback is a no-op when the model returns an error."""
164+
mock_model = testing_utils.MockModel.create(
165+
responses=[], error=SystemError('error')
166+
)
167+
agent = Agent(
168+
name='root_agent',
169+
model=mock_model,
170+
on_model_error_callback=noop_callback,
171+
)
172+
173+
runner = testing_utils.TestInMemoryRunner(agent)
174+
with pytest.raises(SystemError):
175+
await runner.run_async_with_new_session('test')
176+
177+
178+
@pytest.mark.asyncio
179+
async def test_on_model_callback_model_error_modify_model_response():
180+
"""Test that the on_model_error_callback can modify the model response."""
181+
mock_model = testing_utils.MockModel.create(
182+
responses=[], error=SystemError('error')
183+
)
184+
agent = Agent(
185+
name='root_agent',
186+
model=mock_model,
187+
on_model_error_callback=MockOnModelCallback(
188+
mock_response='on_model_error_callback_response'
189+
),
190+
)
191+
192+
runner = testing_utils.TestInMemoryRunner(agent)
193+
assert testing_utils.simplify_events(
194+
await runner.run_async_with_new_session('test')
195+
) == [('root_agent', 'on_model_error_callback_response')]

tests/unittests/testing_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ def supported_models(cls) -> list[str]:
363363
def generate_content(
364364
self, llm_request: LlmRequest, stream: bool = False
365365
) -> Generator[LlmResponse, None, None]:
366-
if self.error:
366+
if self.error is not None:
367367
raise self.error
368368
# Increasement of the index has to happen before the yield.
369369
self.response_index += 1
@@ -375,6 +375,8 @@ def generate_content(
375375
async def generate_content_async(
376376
self, llm_request: LlmRequest, stream: bool = False
377377
) -> AsyncGenerator[LlmResponse, None]:
378+
if self.error is not None:
379+
raise self.error
378380
# Increasement of the index has to happen before the yield.
379381
self.response_index += 1
380382
self.requests.append(llm_request)

0 commit comments

Comments
 (0)