3333from google .genai import types
3434from google .genai .types import Part
3535from pydantic import BaseModel
36- from pytest import mark
36+ from pytest import mark , fixture
3737
3838from .. import testing_utils
3939
@@ -59,124 +59,185 @@ def change_state_callback(callback_context: CallbackContext):
5959 print ('change_state_callback: ' , callback_context .state )
6060
6161
62- @mark .asyncio
63- async def test_agent_tool_inherits_parent_app_name (monkeypatch ):
64- parent_app_name = 'parent_app'
65- captured : dict [str , str ] = {}
66-
67- class RecordingSessionService (InMemorySessionService ):
68-
69- async def create_session (
70- self ,
71- * ,
72- app_name : str ,
73- user_id : str ,
74- state : Optional [dict [str , Any ]] = None ,
75- session_id : Optional [str ] = None ,
76- ):
77- captured ['session_app_name' ] = app_name
78- return await super ().create_session (
79- app_name = app_name ,
80- user_id = user_id ,
81- state = state ,
82- session_id = session_id ,
83- )
84-
85- monkeypatch .setattr (
86- 'google.adk.sessions.in_memory_session_service.InMemorySessionService' ,
87- RecordingSessionService ,
88- )
89-
62+ @fixture
63+ def agent_tool_setup_factory (monkeypatch ):
9064 async def _empty_async_generator ():
9165 if False :
9266 yield None
9367
94- class StubRunner :
68+ async def _create_setup (
69+ * ,
70+ parent_app_name : str ,
71+ parent_session_id : Optional [str ] | None = None ,
72+ capture_runner_app_name : bool = False ,
73+ capture_session_app_name : bool = False ,
74+ capture_child_session_id : bool = False ,
75+ ):
76+ captured : dict [str , Any ] = {}
77+
78+ class RecordingSessionService (InMemorySessionService ):
79+
80+ async def create_session (
81+ self ,
82+ * ,
83+ app_name : str ,
84+ user_id : str ,
85+ state : Optional [dict [str , Any ]] = None ,
86+ session_id : Optional [str ] = None ,
87+ ):
88+ if capture_session_app_name :
89+ captured ['session_app_name' ] = app_name
90+ if capture_child_session_id :
91+ captured ['child_session_id' ] = session_id
92+ return await super ().create_session (
93+ app_name = app_name ,
94+ user_id = user_id ,
95+ state = state ,
96+ session_id = session_id ,
97+ )
9598
96- def __init__ (
97- self ,
98- * ,
99- app_name : str ,
100- agent : Agent ,
101- artifact_service ,
102- session_service ,
103- memory_service ,
104- credential_service ,
105- plugins ,
106- ):
107- del artifact_service , memory_service , credential_service
108- captured ['runner_app_name' ] = app_name
109- self .agent = agent
110- self .session_service = session_service
111- self .plugin_manager = PluginManager (plugins = plugins )
112- self .app_name = app_name
113-
114- def run_async (
115- self ,
116- * ,
117- user_id : str ,
118- session_id : str ,
119- invocation_id : Optional [str ] = None ,
120- new_message : Optional [types .Content ] = None ,
121- state_delta : Optional [dict [str , Any ]] = None ,
122- run_config : Optional [RunConfig ] = None ,
123- ):
124- del (
125- user_id ,
126- session_id ,
127- invocation_id ,
128- new_message ,
129- state_delta ,
130- run_config ,
131- )
132- return _empty_async_generator ()
133-
134- async def close (self ):
135- """Mock close method."""
136- pass
137-
138- monkeypatch .setattr ('google.adk.runners.Runner' , StubRunner )
99+ monkeypatch .setattr (
100+ 'google.adk.sessions.in_memory_session_service.InMemorySessionService' ,
101+ RecordingSessionService ,
102+ )
139103
140- tool_agent = Agent (
141- name = 'tool_agent' ,
142- model = 'test-model' ,
143- )
144- agent_tool = AgentTool (agent = tool_agent )
145- root_agent = Agent (
146- name = 'root_agent' ,
147- model = 'test-model' ,
148- tools = [agent_tool ],
149- )
104+ class StubRunner :
105+
106+ def __init__ (
107+ self ,
108+ * ,
109+ app_name : str ,
110+ agent : Agent ,
111+ artifact_service ,
112+ session_service ,
113+ memory_service ,
114+ credential_service ,
115+ plugins ,
116+ ):
117+ del artifact_service , memory_service , credential_service
118+ if capture_runner_app_name :
119+ captured ['runner_app_name' ] = app_name
120+ self .agent = agent
121+ self .session_service = session_service
122+ self .plugin_manager = PluginManager (plugins = plugins )
123+ self .app_name = app_name
124+
125+ def run_async (
126+ self ,
127+ * ,
128+ user_id : str ,
129+ session_id : str ,
130+ invocation_id : Optional [str ] = None ,
131+ new_message : Optional [types .Content ] = None ,
132+ state_delta : Optional [dict [str , Any ]] = None ,
133+ run_config : Optional [RunConfig ] = None ,
134+ ):
135+ del (
136+ user_id ,
137+ session_id ,
138+ invocation_id ,
139+ new_message ,
140+ state_delta ,
141+ run_config ,
142+ )
143+ return _empty_async_generator ()
150144
151- artifact_service = InMemoryArtifactService ()
152- parent_session_service = InMemorySessionService ()
153- parent_session = await parent_session_service .create_session (
154- app_name = parent_app_name ,
155- user_id = 'user' ,
156- )
157- invocation_context = InvocationContext (
158- artifact_service = artifact_service ,
159- session_service = parent_session_service ,
160- memory_service = InMemoryMemoryService (),
161- plugin_manager = PluginManager (),
162- invocation_id = 'invocation-id' ,
163- agent = root_agent ,
164- session = parent_session ,
165- run_config = RunConfig (),
145+ async def close (self ):
146+ """Mock close method."""
147+ pass
148+
149+ monkeypatch .setattr ('google.adk.runners.Runner' , StubRunner )
150+
151+ tool_agent = Agent (
152+ name = 'tool_agent' ,
153+ model = 'test-model' ,
154+ )
155+ agent_tool = AgentTool (agent = tool_agent )
156+ root_agent = Agent (
157+ name = 'root_agent' ,
158+ model = 'test-model' ,
159+ tools = [agent_tool ],
160+ )
161+
162+ artifact_service = InMemoryArtifactService ()
163+ parent_session_service = InMemorySessionService ()
164+ parent_session = await parent_session_service .create_session (
165+ app_name = parent_app_name ,
166+ user_id = 'user' ,
167+ session_id = parent_session_id ,
168+ )
169+ invocation_context = InvocationContext (
170+ artifact_service = artifact_service ,
171+ session_service = parent_session_service ,
172+ memory_service = InMemoryMemoryService (),
173+ plugin_manager = PluginManager (),
174+ invocation_id = 'invocation-id' ,
175+ agent = root_agent ,
176+ session = parent_session ,
177+ run_config = RunConfig (),
178+ )
179+ tool_context = ToolContext (invocation_context )
180+
181+ return {
182+ 'agent_tool' : agent_tool ,
183+ 'tool_context' : tool_context ,
184+ 'captured' : captured ,
185+ }
186+
187+ return _create_setup
188+
189+
190+ @mark .asyncio
191+ async def test_agent_tool_inherits_parent_app_name (agent_tool_setup_factory ):
192+ parent_app_name = 'parent_app'
193+
194+ setup = await agent_tool_setup_factory (
195+ parent_app_name = parent_app_name ,
196+ capture_runner_app_name = True ,
197+ capture_session_app_name = True ,
166198 )
167- tool_context = ToolContext (invocation_context )
199+
200+ agent_tool = setup ['agent_tool' ]
201+ tool_context = setup ['tool_context' ]
202+ captured = setup ['captured' ]
168203
169204 assert tool_context ._invocation_context .app_name == parent_app_name
170205
171206 await agent_tool .run_async (
172- args = {'request' : 'hello' },
173- tool_context = tool_context ,
207+ args = {'request' : 'hello' },
208+ tool_context = tool_context ,
174209 )
175210
176211 assert captured ['runner_app_name' ] == parent_app_name
177212 assert captured ['session_app_name' ] == parent_app_name
178213
179214
215+ @mark .asyncio
216+ async def test_agent_tool_passes_parent_session_id (agent_tool_setup_factory ):
217+ """Test that the parent session ID is passed to the child session."""
218+ parent_app_name = 'parent_app'
219+ parent_session_id = 'parent-session-123'
220+ setup = await agent_tool_setup_factory (
221+ parent_app_name = parent_app_name ,
222+ parent_session_id = parent_session_id ,
223+ capture_child_session_id = True ,
224+ )
225+
226+ agent_tool = setup ['agent_tool' ]
227+ tool_context = setup ['tool_context' ]
228+ captured = setup ['captured' ]
229+
230+ assert tool_context ._invocation_context .session .id == parent_session_id
231+
232+ await agent_tool .run_async (
233+ args = {'request' : 'hello' },
234+ tool_context = tool_context ,
235+ )
236+
237+ # Verify that the parent session ID was passed to the child session
238+ assert captured ['child_session_id' ] == parent_session_id
239+
240+
180241def test_no_schema ():
181242 mock_model = testing_utils .MockModel .create (
182243 responses = [
0 commit comments