Skip to content

Commit 8d99592

Browse files
Pass parent session ID to child sessions in AgentTool
When AgentTool creates a child session for a sub-agent, it now passes the parent session ID to maintain session continuity across agent boundaries. This ensures that child agents can properly track and reference their parent session context. Changes: - Added session_id parameter to create_session call in agent_tool.py - Added test_agent_tool_passes_parent_session_id to verify the behavior Testing: - All 16 unit tests pass - New test specifically validates session ID propagation
1 parent 322dd18 commit 8d99592

3 files changed

Lines changed: 248 additions & 101 deletions

File tree

src/google/adk/runners.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,91 @@ async def rewind_async(
519519
if not session:
520520
raise ValueError(f'Session not found: {session_id}')
521521

522+
await self._rewind_session(
523+
session=session,
524+
rewind_before_invocation_id=rewind_before_invocation_id,
525+
)
526+
527+
async def rewind_last_invocation_async(
528+
self,
529+
*,
530+
user_id: str,
531+
session_id: str,
532+
) -> None:
533+
"""Rewinds the session to before the last invocation.
534+
535+
This is a convenience method that finds the most recent invocation
536+
and rewinds before it. Useful when you want to undo the last agent
537+
interaction without needing to track invocation IDs.
538+
539+
Args:
540+
user_id: The user ID of the session.
541+
session_id: The session ID of the session.
542+
543+
Raises:
544+
ValueError: If the session is not found or has no invocations.
545+
"""
546+
session = await self.session_service.get_session(
547+
app_name=self.app_name, user_id=user_id, session_id=session_id
548+
)
549+
if not session:
550+
raise ValueError(f'Session not found: {session_id}')
551+
552+
if not session.events:
553+
raise ValueError(f'Session {session_id} has no events to rewind.')
554+
555+
# Find the most recent invocation ID by finding the event with the latest timestamp
556+
# Skip rewind events (events that have rewind_before_invocation_id set)
557+
# AND skip invocations that have already been rewound
558+
559+
# First, collect all invocation IDs that have been rewound
560+
rewound_invocation_ids = set()
561+
for event in session.events:
562+
if event.actions and event.actions.rewind_before_invocation_id:
563+
rewound_invocation_ids.add(event.actions.rewind_before_invocation_id)
564+
565+
# Now find the most recent invocation that hasn't been rewound
566+
last_event_with_invocation = None
567+
for event in session.events:
568+
if event.invocation_id:
569+
# Skip rewind events themselves
570+
if event.actions and event.actions.rewind_before_invocation_id:
571+
continue
572+
# Skip invocations that have already been rewound
573+
if event.invocation_id in rewound_invocation_ids:
574+
continue
575+
if (last_event_with_invocation is None or
576+
event.timestamp > last_event_with_invocation.timestamp):
577+
last_event_with_invocation = event
578+
579+
if not last_event_with_invocation:
580+
raise ValueError(
581+
f'No invocation found in session {session_id} to rewind.'
582+
)
583+
584+
last_invocation_id = last_event_with_invocation.invocation_id
585+
586+
# Reuse the core rewind logic
587+
await self._rewind_session(
588+
session=session,
589+
rewind_before_invocation_id=last_invocation_id,
590+
)
591+
592+
async def _rewind_session(
593+
self,
594+
*,
595+
session: Session,
596+
rewind_before_invocation_id: str,
597+
) -> None:
598+
"""Core rewind logic that performs the actual rewind operation.
599+
600+
Args:
601+
session: The session to rewind.
602+
rewind_before_invocation_id: The invocation ID to rewind before.
603+
604+
Raises:
605+
ValueError: If the invocation ID is not found in the session.
606+
"""
522607
rewind_event_index = -1
523608
for i, event in enumerate(session.events):
524609
if event.invocation_id == rewind_before_invocation_id:

src/google/adk/tools/agent_tool.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ async def run_async(
167167
app_name=child_app_name,
168168
user_id=tool_context._invocation_context.user_id,
169169
state=state_dict,
170+
session_id=tool_context._invocation_context.session.id,
170171
)
171172

172173
last_content = None

tests/unittests/tools/test_agent_tool.py

Lines changed: 162 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from google.genai import types
3434
from google.genai.types import Part
3535
from pydantic import BaseModel
36-
from pytest import mark
36+
from pytest import mark, fixture
3737

3838
from .. import testing_utils
3939

@@ -59,124 +59,185 @@ def change_state_callback(callback_context: CallbackContext):
5959
print('change_state_callback: ', callback_context.state)
6060

6161

62-
@mark.asyncio
63-
async def test_agent_tool_inherits_parent_app_name(monkeypatch):
64-
parent_app_name = 'parent_app'
65-
captured: dict[str, str] = {}
66-
67-
class RecordingSessionService(InMemorySessionService):
68-
69-
async def create_session(
70-
self,
71-
*,
72-
app_name: str,
73-
user_id: str,
74-
state: Optional[dict[str, Any]] = None,
75-
session_id: Optional[str] = None,
76-
):
77-
captured['session_app_name'] = app_name
78-
return await super().create_session(
79-
app_name=app_name,
80-
user_id=user_id,
81-
state=state,
82-
session_id=session_id,
83-
)
84-
85-
monkeypatch.setattr(
86-
'google.adk.sessions.in_memory_session_service.InMemorySessionService',
87-
RecordingSessionService,
88-
)
89-
62+
@fixture
63+
def agent_tool_setup_factory(monkeypatch):
9064
async def _empty_async_generator():
9165
if False:
9266
yield None
9367

94-
class StubRunner:
68+
async def _create_setup(
69+
*,
70+
parent_app_name: str,
71+
parent_session_id: Optional[str] | None = None,
72+
capture_runner_app_name: bool = False,
73+
capture_session_app_name: bool = False,
74+
capture_child_session_id: bool = False,
75+
):
76+
captured: dict[str, Any] = {}
77+
78+
class RecordingSessionService(InMemorySessionService):
79+
80+
async def create_session(
81+
self,
82+
*,
83+
app_name: str,
84+
user_id: str,
85+
state: Optional[dict[str, Any]] = None,
86+
session_id: Optional[str] = None,
87+
):
88+
if capture_session_app_name:
89+
captured['session_app_name'] = app_name
90+
if capture_child_session_id:
91+
captured['child_session_id'] = session_id
92+
return await super().create_session(
93+
app_name=app_name,
94+
user_id=user_id,
95+
state=state,
96+
session_id=session_id,
97+
)
9598

96-
def __init__(
97-
self,
98-
*,
99-
app_name: str,
100-
agent: Agent,
101-
artifact_service,
102-
session_service,
103-
memory_service,
104-
credential_service,
105-
plugins,
106-
):
107-
del artifact_service, memory_service, credential_service
108-
captured['runner_app_name'] = app_name
109-
self.agent = agent
110-
self.session_service = session_service
111-
self.plugin_manager = PluginManager(plugins=plugins)
112-
self.app_name = app_name
113-
114-
def run_async(
115-
self,
116-
*,
117-
user_id: str,
118-
session_id: str,
119-
invocation_id: Optional[str] = None,
120-
new_message: Optional[types.Content] = None,
121-
state_delta: Optional[dict[str, Any]] = None,
122-
run_config: Optional[RunConfig] = None,
123-
):
124-
del (
125-
user_id,
126-
session_id,
127-
invocation_id,
128-
new_message,
129-
state_delta,
130-
run_config,
131-
)
132-
return _empty_async_generator()
133-
134-
async def close(self):
135-
"""Mock close method."""
136-
pass
137-
138-
monkeypatch.setattr('google.adk.runners.Runner', StubRunner)
99+
monkeypatch.setattr(
100+
'google.adk.sessions.in_memory_session_service.InMemorySessionService',
101+
RecordingSessionService,
102+
)
139103

140-
tool_agent = Agent(
141-
name='tool_agent',
142-
model='test-model',
143-
)
144-
agent_tool = AgentTool(agent=tool_agent)
145-
root_agent = Agent(
146-
name='root_agent',
147-
model='test-model',
148-
tools=[agent_tool],
149-
)
104+
class StubRunner:
105+
106+
def __init__(
107+
self,
108+
*,
109+
app_name: str,
110+
agent: Agent,
111+
artifact_service,
112+
session_service,
113+
memory_service,
114+
credential_service,
115+
plugins,
116+
):
117+
del artifact_service, memory_service, credential_service
118+
if capture_runner_app_name:
119+
captured['runner_app_name'] = app_name
120+
self.agent = agent
121+
self.session_service = session_service
122+
self.plugin_manager = PluginManager(plugins=plugins)
123+
self.app_name = app_name
124+
125+
def run_async(
126+
self,
127+
*,
128+
user_id: str,
129+
session_id: str,
130+
invocation_id: Optional[str] = None,
131+
new_message: Optional[types.Content] = None,
132+
state_delta: Optional[dict[str, Any]] = None,
133+
run_config: Optional[RunConfig] = None,
134+
):
135+
del (
136+
user_id,
137+
session_id,
138+
invocation_id,
139+
new_message,
140+
state_delta,
141+
run_config,
142+
)
143+
return _empty_async_generator()
150144

151-
artifact_service = InMemoryArtifactService()
152-
parent_session_service = InMemorySessionService()
153-
parent_session = await parent_session_service.create_session(
154-
app_name=parent_app_name,
155-
user_id='user',
156-
)
157-
invocation_context = InvocationContext(
158-
artifact_service=artifact_service,
159-
session_service=parent_session_service,
160-
memory_service=InMemoryMemoryService(),
161-
plugin_manager=PluginManager(),
162-
invocation_id='invocation-id',
163-
agent=root_agent,
164-
session=parent_session,
165-
run_config=RunConfig(),
145+
async def close(self):
146+
"""Mock close method."""
147+
pass
148+
149+
monkeypatch.setattr('google.adk.runners.Runner', StubRunner)
150+
151+
tool_agent = Agent(
152+
name='tool_agent',
153+
model='test-model',
154+
)
155+
agent_tool = AgentTool(agent=tool_agent)
156+
root_agent = Agent(
157+
name='root_agent',
158+
model='test-model',
159+
tools=[agent_tool],
160+
)
161+
162+
artifact_service = InMemoryArtifactService()
163+
parent_session_service = InMemorySessionService()
164+
parent_session = await parent_session_service.create_session(
165+
app_name=parent_app_name,
166+
user_id='user',
167+
session_id=parent_session_id,
168+
)
169+
invocation_context = InvocationContext(
170+
artifact_service=artifact_service,
171+
session_service=parent_session_service,
172+
memory_service=InMemoryMemoryService(),
173+
plugin_manager=PluginManager(),
174+
invocation_id='invocation-id',
175+
agent=root_agent,
176+
session=parent_session,
177+
run_config=RunConfig(),
178+
)
179+
tool_context = ToolContext(invocation_context)
180+
181+
return {
182+
'agent_tool': agent_tool,
183+
'tool_context': tool_context,
184+
'captured': captured,
185+
}
186+
187+
return _create_setup
188+
189+
190+
@mark.asyncio
191+
async def test_agent_tool_inherits_parent_app_name(agent_tool_setup_factory):
192+
parent_app_name = 'parent_app'
193+
194+
setup = await agent_tool_setup_factory(
195+
parent_app_name=parent_app_name,
196+
capture_runner_app_name=True,
197+
capture_session_app_name=True,
166198
)
167-
tool_context = ToolContext(invocation_context)
199+
200+
agent_tool = setup['agent_tool']
201+
tool_context = setup['tool_context']
202+
captured = setup['captured']
168203

169204
assert tool_context._invocation_context.app_name == parent_app_name
170205

171206
await agent_tool.run_async(
172-
args={'request': 'hello'},
173-
tool_context=tool_context,
207+
args={'request': 'hello'},
208+
tool_context=tool_context,
174209
)
175210

176211
assert captured['runner_app_name'] == parent_app_name
177212
assert captured['session_app_name'] == parent_app_name
178213

179214

215+
@mark.asyncio
216+
async def test_agent_tool_passes_parent_session_id(agent_tool_setup_factory):
217+
"""Test that the parent session ID is passed to the child session."""
218+
parent_app_name = 'parent_app'
219+
parent_session_id = 'parent-session-123'
220+
setup = await agent_tool_setup_factory(
221+
parent_app_name=parent_app_name,
222+
parent_session_id=parent_session_id,
223+
capture_child_session_id=True,
224+
)
225+
226+
agent_tool = setup['agent_tool']
227+
tool_context = setup['tool_context']
228+
captured = setup['captured']
229+
230+
assert tool_context._invocation_context.session.id == parent_session_id
231+
232+
await agent_tool.run_async(
233+
args={'request': 'hello'},
234+
tool_context=tool_context,
235+
)
236+
237+
# Verify that the parent session ID was passed to the child session
238+
assert captured['child_session_id'] == parent_session_id
239+
240+
180241
def test_no_schema():
181242
mock_model = testing_utils.MockModel.create(
182243
responses=[

0 commit comments

Comments
 (0)