Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 30 additions & 17 deletions src/google/adk/flows/llm_flows/contents.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,35 @@ def _is_timestamp_compacted(ts: float) -> bool:
return [event for _, _, event in processed_items]


def _filter_rewound_events(events: list[Event]) -> list[Event]:
"""Returns events with those annulled by a rewind removed.

Iterates backward; when a rewind marker is found, skips all events
back to the rewind_before_invocation_id.

Args:
events: The full event list from the session.

Returns:
A new list with rewound events removed, in the original order.
"""
filtered = []
i = len(events) - 1
while i >= 0:
event = events[i]
if event.actions and event.actions.rewind_before_invocation_id:
rewind_id = event.actions.rewind_before_invocation_id
for j in range(0, i):
if events[j].invocation_id == rewind_id:
i = j
break
else:
filtered.append(event)
i -= 1
filtered.reverse()
return filtered
Comment thread
rahulmansharamani14 marked this conversation as resolved.
Outdated


def _get_contents(
current_branch: Optional[str],
events: list[Event],
Expand All @@ -430,23 +459,7 @@ def _get_contents(
accumulated_output_transcription = ''

# Filter out events that are annulled by a rewind.
# By iterating backward, when a rewind event is found, we skip all events
# from that point back to the `rewind_before_invocation_id`, thus removing
# them from the history used for the LLM request.
rewind_filtered_events = []
i = len(events) - 1
while i >= 0:
event = events[i]
if event.actions and event.actions.rewind_before_invocation_id:
rewind_invocation_id = event.actions.rewind_before_invocation_id
for j in range(0, i, 1):
if events[j].invocation_id == rewind_invocation_id:
i = j
break
else:
rewind_filtered_events.append(event)
i -= 1
rewind_filtered_events.reverse()
rewind_filtered_events = _filter_rewound_events(events)
Comment thread
rahulmansharamani14 marked this conversation as resolved.
Outdated

# Parse the events, leaving the contents and the function calls and
# responses from the current agent.
Expand Down
5 changes: 3 additions & 2 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,7 +1055,8 @@ def _find_agent_to_run(
# the agent that returned the corresponding function call regardless the
# type of the agent. e.g. a remote a2a agent may surface a credential
# request as a special long-running function tool call.
event = find_matching_function_call(session.events)
filtered_events = contents._filter_rewound_events(session.events)
Comment thread
rahulmansharamani14 marked this conversation as resolved.
Outdated
event = find_matching_function_call(filtered_events)
if event and event.author:
return root_agent.find_agent(event.author)

Expand All @@ -1067,7 +1068,7 @@ def _event_filter(event: Event) -> bool:
return False
return True

for event in filter(_event_filter, reversed(session.events)):
for event in filter(_event_filter, reversed(filtered_events)):
if event.author == root_agent.name:
# Found root agent.
return root_agent
Expand Down
91 changes: 91 additions & 0 deletions tests/unittests/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
from google.adk.cli.utils.agent_loader import AgentLoader
from google.adk.events.event import Event
from google.adk.events.event import EventActions
from google.adk.flows.llm_flows.contents import _filter_rewound_events
Comment thread
rahulmansharamani14 marked this conversation as resolved.
Outdated
from google.adk.plugins.base_plugin import BasePlugin
from google.adk.runners import Runner
from google.adk.sessions.in_memory_session_service import InMemorySessionService
Expand Down Expand Up @@ -641,6 +643,95 @@ def test_is_transferable_across_agent_tree_with_non_llm_agent(self):
assert result is False


def test_find_agent_to_run_ignores_rewound_sub_agent_event():
"""After a rewind, events from the rewound invocation are ignored."""
root_agent = MockLlmAgent("root_agent")
sub_agent1 = MockLlmAgent("sub_agent1", parent_agent=root_agent)
root_agent.sub_agents = [sub_agent1]

runner = Runner(
app_name="test_app",
agent=root_agent,
session_service=InMemorySessionService(),
artifact_service=InMemoryArtifactService(),
)

# sub_agent1 was the last active agent during inv1
sub_agent_event = Event(
invocation_id="inv1",
author="sub_agent1",
content=types.Content(
role="model", parts=[types.Part(text="Sub agent response")]
),
)
# Rewind event that annuls inv1 and everything after it
rewind_event = Event(
invocation_id="inv2",
author="user",
actions=EventActions(rewind_before_invocation_id="inv1"),
)
session = Session(
id="test_session",
user_id="test_user",
app_name="test_app",
events=[sub_agent_event, rewind_event],
)

result = runner._find_agent_to_run(session, root_agent)
assert result == root_agent


def test_find_agent_to_run_ignores_rewound_function_call():
"""After a rewind, a function call from the rewound invocation is not matched."""
root_agent = MockLlmAgent("root_agent")
sub_agent2 = MockLlmAgent("sub_agent2", parent_agent=root_agent)
root_agent.sub_agents = [sub_agent2]

runner = Runner(
app_name="test_app",
agent=root_agent,
session_service=InMemorySessionService(),
artifact_service=InMemoryArtifactService(),
)

function_call = types.FunctionCall(id="func_789", name="test_func", args={})
function_response = types.FunctionResponse(
id="func_789", name="test_func", response={}
)

# sub_agent2 issued a function call in inv1
call_event = Event(
invocation_id="inv1",
author="sub_agent2",
content=types.Content(
role="model", parts=[types.Part(function_call=function_call)]
),
)
# User provides the function response, also in inv1
response_event = Event(
invocation_id="inv1",
author="user",
content=types.Content(
role="user", parts=[types.Part(function_response=function_response)]
),
)
# Rewind event that annuls inv1
rewind_event = Event(
invocation_id="inv2",
author="user",
actions=EventActions(rewind_before_invocation_id="inv1"),
)
session = Session(
id="test_session",
user_id="test_user",
app_name="test_app",
events=[call_event, response_event, rewind_event],
)

# The rewound function call should not be matched; root_agent is returned
result = runner._find_agent_to_run(session, root_agent)
assert result == root_agent

@pytest.mark.asyncio
async def test_run_config_custom_metadata_propagates_to_events():
session_service = InMemorySessionService()
Expand Down