Skip to content
1 change: 0 additions & 1 deletion contributing/samples/gepa/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from tau_bench.types import EnvRunResult
from tau_bench.types import RunConfig
import tau_bench_agent as tau_bench_agent_lib

import utils


Expand Down
1 change: 0 additions & 1 deletion contributing/samples/gepa/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from absl import flags
import experiment
from google.genai import types

import utils

_OUTPUT_DIR = flags.DEFINE_string(
Expand Down
44 changes: 36 additions & 8 deletions src/google/adk/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,10 @@ class MyAgent(BaseAgent):

Returns:
Optional[types.Content]: The content to return to the user.
When the content is present, an additional event with the provided content
will be appended to event history as an additional agent response.
When the content is present, it will replace the agent's original output.
The callback's content will be returned as the final agent response instead
of the original response. When None is returned, the original agent output
is used.
"""

def _load_agent_state(
Expand Down Expand Up @@ -287,15 +289,28 @@ async def run_async(
if ctx.end_invocation:
return

has_after_callback = bool(self.canonical_after_agent_callbacks)

final_response_events = []
async with Aclosing(self._run_async_impl(ctx)) as agen:
async for event in agen:
yield event
if event.is_final_response() and has_after_callback:
modified_event = event.model_copy(update={'partial': True})
final_response_events.append(event)
yield modified_event
else:
yield event

if ctx.end_invocation:
return

if event := await self._handle_after_agent_callback(ctx):
yield event
callback_event = await self._handle_after_agent_callback(ctx)

if callback_event:
yield callback_event
elif final_response_events:
for event in final_response_events:
yield event
Comment thread
dylan-apex marked this conversation as resolved.
Outdated

@final
async def run_live(
Expand All @@ -320,13 +335,26 @@ async def run_live(
if ctx.end_invocation:
return

has_after_callback = bool(self.canonical_after_agent_callbacks)

final_response_events = []
async with Aclosing(self._run_live_impl(ctx)) as agen:
async for event in agen:
if event.is_final_response() and has_after_callback:
modified_event = event.model_copy(update={'partial': True})
final_response_events.append(event)
yield modified_event
else:
yield event

callback_event = await self._handle_after_agent_callback(ctx)

if callback_event:
yield callback_event
elif final_response_events:
for event in final_response_events:
yield event
Comment thread
dylan-apex marked this conversation as resolved.
Outdated

if event := await self._handle_after_agent_callback(ctx):
yield event

async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
Expand Down
39 changes: 21 additions & 18 deletions tests/unittests/agents/test_base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def mock_sync_agent_cb_side_effect(
('callback_3_response', CallbackType.SYNC),
(None, CallbackType.ASYNC),
],
['Hello, world!', 'callback_2_response'],
['callback_2_response'],
[1, 1, 0, 0],
id='middle_async_callback_returns',
),
Expand All @@ -424,7 +424,7 @@ def mock_sync_agent_cb_side_effect(
('callback_1_response', CallbackType.SYNC),
('callback_2_response', CallbackType.ASYNC),
],
['Hello, world!', 'callback_1_response'],
['callback_1_response'],
[1, 0],
id='first_sync_callback_returns',
),
Expand Down Expand Up @@ -467,7 +467,8 @@ async def test_before_agent_callbacks_chain(
request.function.__name__, agent
)
result = [e async for e in agent.run_async(parent_ctx)]
assert testing_utils.simplify_events(result) == [
final_events = [e for e in result if e.is_final_response()]
Comment thread
dylan-apex marked this conversation as resolved.
Outdated
assert testing_utils.simplify_events(final_events) == [
(f'{request.function.__name__}_test_agent', response)
for response in expected_responses
]
Expand Down Expand Up @@ -528,7 +529,8 @@ async def test_after_agent_callbacks_chain(
request.function.__name__, agent
)
result = [e async for e in agent.run_async(parent_ctx)]
assert testing_utils.simplify_events(result) == [
final_events = [e for e in result if e.is_final_response()]
assert testing_utils.simplify_events(final_events) == [
(f'{request.function.__name__}_test_agent', response)
for response in expected_responses
]
Expand Down Expand Up @@ -575,10 +577,9 @@ async def test_run_async_after_agent_callback_use_plugin(

# Assert
spy_after_agent_callback.assert_not_called()
# The first event is regular model response, the second event is
# after_agent_callback response.
assert len(events) == 2
assert events[1].content.parts[0].text == mock_plugin.after_agent_text
final_events = [e for e in events if e.is_final_response()]
assert len(final_events) == 1
assert final_events[0].content.parts[0].text == mock_plugin.after_agent_text


@pytest.mark.asyncio
Expand All @@ -604,7 +605,8 @@ async def test_run_async_after_agent_callback_noop(
_, kwargs = spy_after_agent_callback.call_args
assert 'callback_context' in kwargs
assert isinstance(kwargs['callback_context'], CallbackContext)
assert len(events) == 1
final_events = [e for e in events if e.is_final_response()]
assert len(final_events) == 1


@pytest.mark.asyncio
Expand All @@ -630,7 +632,8 @@ async def test_run_async_with_async_after_agent_callback_noop(
_, kwargs = spy_after_agent_callback.call_args
assert 'callback_context' in kwargs
assert isinstance(kwargs['callback_context'], CallbackContext)
assert len(events) == 1
final_events = [e for e in events if e.is_final_response()]
assert len(final_events) == 1


@pytest.mark.asyncio
Expand All @@ -649,11 +652,11 @@ async def test_run_async_after_agent_callback_append_reply(
# Act
events = [e async for e in agent.run_async(parent_ctx)]

# Assert
assert len(events) == 2
assert events[1].author == agent.name
final_events = [e for e in events if e.is_final_response()]
assert len(final_events) == 1
assert final_events[0].author == agent.name
assert (
events[1].content.parts[0].text
final_events[0].content.parts[0].text
== 'Agent reply from after agent callback.'
)

Expand All @@ -674,11 +677,11 @@ async def test_run_async_with_async_after_agent_callback_append_reply(
# Act
events = [e async for e in agent.run_async(parent_ctx)]

# Assert
assert len(events) == 2
assert events[1].author == agent.name
final_events = [e for e in events if e.is_final_response()]
assert len(final_events) == 1
assert final_events[0].author == agent.name
assert (
events[1].content.parts[0].text
final_events[0].content.parts[0].text
== 'Agent reply from after agent callback.'
)

Expand Down