Skip to content

Commit a0cd676

Browse files
author
Lin-Nikaido
committed
feat: enable to will_continue-like function response in BaseLlmFlow.run_async method with streaming mode.
1 parent 67f23df commit a0cd676

8 files changed

Lines changed: 713 additions & 89 deletions

File tree

src/google/adk/agents/readonly_context.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
if TYPE_CHECKING:
2323
from google.genai import types
2424

25+
from ..sessions.session import Session
2526
from .invocation_context import InvocationContext
2627

2728

@@ -52,3 +53,8 @@ def agent_name(self) -> str:
5253
def state(self) -> MappingProxyType[str, Any]:
5354
"""The state of the current session. READONLY field."""
5455
return MappingProxyType(self._invocation_context.session.state)
56+
57+
@property
58+
def session(self) -> Session:
59+
"""The current session. READONLY field."""
60+
return self._invocation_context.session

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@
2424
from typing import Optional
2525
from typing import TYPE_CHECKING
2626

27-
from google.genai import types
2827
from websockets.exceptions import ConnectionClosed
2928
from websockets.exceptions import ConnectionClosedOK
3029

30+
from google.genai import types
31+
3132
from . import _output_schema_processor
3233
from . import functions
3334
from ...agents.base_agent import BaseAgent
@@ -629,43 +630,51 @@ async def _postprocess_handle_function_calls_async(
629630
function_call_event: Event,
630631
llm_request: LlmRequest,
631632
) -> AsyncGenerator[Event, None]:
632-
if function_response_event := await functions.handle_function_calls_async(
633+
# if invocation_context.run_config.streaming_mode == StreamingMode.SSE:
634+
#
635+
# else:
636+
if function_response_event_agen := functions.handle_function_calls_async_gen(
633637
invocation_context, function_call_event, llm_request.tools_dict
634638
):
635-
auth_event = functions.generate_auth_event(
636-
invocation_context, function_response_event
637-
)
638-
if auth_event:
639-
yield auth_event
640-
641-
tool_confirmation_event = functions.generate_request_confirmation_event(
642-
invocation_context, function_call_event, function_response_event
643-
)
644-
if tool_confirmation_event:
645-
yield tool_confirmation_event
646-
647-
# Always yield the function response event first
648-
yield function_response_event
649-
650-
# Check if this is a set_model_response function response
651-
if json_response := _output_schema_processor.get_structured_model_response(
652-
function_response_event
653-
):
654-
# Create and yield a final model response event
655-
final_event = (
656-
_output_schema_processor.create_final_model_response_event(
657-
invocation_context, json_response
658-
)
639+
function_response_event = None
640+
async for function_response_event in function_response_event_agen:
641+
auth_event = functions.generate_auth_event(
642+
invocation_context, function_response_event
659643
)
660-
yield final_event
661-
transfer_to_agent = function_response_event.actions.transfer_to_agent
662-
if transfer_to_agent:
663-
agent_to_run = self._get_agent_to_run(
664-
invocation_context, transfer_to_agent
644+
if auth_event:
645+
yield auth_event
646+
647+
tool_confirmation_event = functions.generate_request_confirmation_event(
648+
invocation_context, function_call_event, function_response_event
665649
)
666-
async with Aclosing(agent_to_run.run_async(invocation_context)) as agen:
667-
async for event in agen:
668-
yield event
650+
if tool_confirmation_event:
651+
yield tool_confirmation_event
652+
653+
# Always yield the function response event first
654+
yield function_response_event
655+
656+
# Check if this is a set_model_response function response
657+
if json_response := _output_schema_processor.get_structured_model_response(
658+
function_response_event
659+
):
660+
# Create and yield a final model response event
661+
final_event = (
662+
_output_schema_processor.create_final_model_response_event(
663+
invocation_context, json_response
664+
)
665+
)
666+
yield final_event
667+
if function_response_event:
668+
transfer_to_agent = function_response_event.actions.transfer_to_agent
669+
if transfer_to_agent:
670+
agent_to_run = self._get_agent_to_run(
671+
invocation_context, transfer_to_agent
672+
)
673+
async with Aclosing(
674+
agent_to_run.run_async(invocation_context)
675+
) as agen:
676+
async for event in agen:
677+
yield event
669678

670679
def _get_agent_to_run(
671680
self, invocation_context: InvocationContext, agent_name: str

src/google/adk/flows/llm_flows/functions.py

Lines changed: 142 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,24 @@
2020
import copy
2121
import inspect
2222
import logging
23-
import threading
2423
from typing import Any
2524
from typing import AsyncGenerator
25+
from typing import AsyncIterator
2626
from typing import cast
27+
from typing import Iterator
28+
from typing import List
2729
from typing import Optional
30+
from typing import Tuple
2831
from typing import TYPE_CHECKING
2932
import uuid
3033

34+
from nltk.sem.chat80 import continent
35+
3136
from google.genai import types
3237

3338
from ...agents.active_streaming_tool import ActiveStreamingTool
3439
from ...agents.invocation_context import InvocationContext
40+
from ...agents.run_config import StreamingMode
3541
from ...auth.auth_tool import AuthToolArguments
3642
from ...events.event import Event
3743
from ...events.event_actions import EventActions
@@ -184,70 +190,85 @@ async def handle_function_calls_async(
184190
tools_dict: dict[str, BaseTool],
185191
filters: Optional[set[str]] = None,
186192
tool_confirmation_dict: Optional[dict[str, ToolConfirmation]] = None,
193+
187194
) -> Optional[Event]:
188195
"""Calls the functions and returns the function response event."""
189-
function_calls = function_call_event.get_function_calls()
190-
return await handle_function_call_list_async(
191-
invocation_context,
192-
function_calls,
193-
tools_dict,
194-
filters,
195-
tool_confirmation_dict,
196-
)
196+
async with Aclosing(
197+
handle_function_calls_async_gen(
198+
invocation_context,
199+
function_call_event,
200+
tools_dict,
201+
filters,
202+
tool_confirmation_dict,
203+
)
204+
) as agen:
205+
last_event = None
206+
async for event in agen:
207+
last_event = event
208+
return last_event
197209

198210

199-
async def handle_function_call_list_async(
211+
async def handle_function_calls_async_gen(
200212
invocation_context: InvocationContext,
201-
function_calls: list[types.FunctionCall],
213+
function_call_event: Event,
202214
tools_dict: dict[str, BaseTool],
203215
filters: Optional[set[str]] = None,
204216
tool_confirmation_dict: Optional[dict[str, ToolConfirmation]] = None,
205-
) -> Optional[Event]:
206-
"""Calls the functions and returns the function response event."""
217+
) -> AsyncGenerator[Optional[Event]]:
218+
"""Calls the functions and returns the function response event as generator."""
207219
from ...agents.llm_agent import LlmAgent
208220

209221
agent = invocation_context.agent
210222
if not isinstance(agent, LlmAgent):
211-
return None
223+
yield None
224+
return
225+
226+
function_calls = function_call_event.get_function_calls()
212227

213228
# Filter function calls
214229
filtered_calls = [
215230
fc for fc in function_calls if not filters or fc.id in filters
216231
]
217232

218233
if not filtered_calls:
219-
return None
234+
yield None
235+
return
220236

221-
# Create tasks for parallel execution
222-
tasks = [
223-
asyncio.create_task(
224-
_execute_single_function_call_async(
225-
invocation_context,
226-
function_call,
227-
tools_dict,
228-
agent,
229-
tool_confirmation_dict[function_call.id]
230-
if tool_confirmation_dict
231-
else None,
232-
)
237+
function_call_async_gens = [
238+
_execute_single_function_call_async_gen(
239+
invocation_context,
240+
function_call,
241+
tools_dict,
242+
agent,
243+
tool_confirmation_dict[function_call.id]
244+
if tool_confirmation_dict
245+
else None,
233246
)
234247
for function_call in filtered_calls
235248
]
236249

237-
# Wait for all tasks to complete
238-
function_response_events = await asyncio.gather(*tasks)
239-
240-
# Filter out None results
241-
function_response_events = [
242-
event for event in function_response_events if event is not None
243-
]
250+
merged_event = None
251+
result_events: List[Optional[Event]] = [None] * len(function_call_async_gens)
252+
function_response_events = []
253+
async for idx, event in _concat_function_call_generators(
254+
function_call_async_gens
255+
):
256+
result_events[idx] = event
257+
function_response_events = [
258+
event for event in result_events if event is not None
259+
]
260+
if function_response_events:
261+
merged_event = merge_parallel_function_response_events(
262+
function_response_events
263+
)
264+
if invocation_context.run_config.streaming_mode == StreamingMode.SSE:
265+
yield merged_event
266+
if invocation_context.run_config.streaming_mode != StreamingMode.SSE:
267+
yield merged_event
244268

245269
if not function_response_events:
246-
return None
247-
248-
merged_event = merge_parallel_function_response_events(
249-
function_response_events
250-
)
270+
yield None
271+
return
251272

252273
if len(function_response_events) > 1:
253274
# this is needed for debug traces of parallel calls
@@ -258,16 +279,61 @@ async def handle_function_call_list_async(
258279
response_event_id=merged_event.id,
259280
function_response_event=merged_event,
260281
)
261-
return merged_event
262282

263283

264-
async def _execute_single_function_call_async(
284+
async def _concat_function_call_generators(
285+
gens: List[AsyncGenerator[Any]],
286+
) -> AsyncIterator[tuple[int, Any]]:
287+
_SENTINEL = object()
288+
q: asyncio.Queue[tuple[str, int, Any]] = asyncio.Queue()
289+
gens = list(gens)
290+
n = len(gens)
291+
292+
async def __pump(idx: int, agen_: AsyncIterator[Any]):
293+
try:
294+
async for x in agen_:
295+
await q.put(('ITEM', idx, x))
296+
except Exception as e:
297+
await q.put(('EXC', idx, e))
298+
finally:
299+
aclose = getattr(agen_, 'aclose', None)
300+
if callable(aclose):
301+
try:
302+
await aclose()
303+
except Exception: # noqa: ignore exception when task canceled.
304+
pass
305+
306+
await q.put(('END', idx, _SENTINEL))
307+
308+
tasks = [asyncio.create_task(__pump(i, agen)) for i, agen in enumerate(gens)]
309+
finished = 0
310+
try:
311+
while finished < n:
312+
kind, i, payload = await q.get()
313+
if kind == 'ITEM':
314+
yield i, payload
315+
316+
elif kind == 'EXC':
317+
for t in tasks:
318+
t.cancel()
319+
await asyncio.gather(*tasks, return_exceptions=True)
320+
raise payload
321+
322+
elif kind == 'END':
323+
finished += 1
324+
finally:
325+
for t in tasks:
326+
t.cancel()
327+
await asyncio.gather(*tasks, return_exceptions=True)
328+
329+
330+
async def _execute_single_function_call_async_gen(
265331
invocation_context: InvocationContext,
266332
function_call: types.FunctionCall,
267333
tools_dict: dict[str, BaseTool],
268334
agent: LlmAgent,
269335
tool_confirmation: Optional[ToolConfirmation] = None,
270-
) -> Optional[Event]:
336+
) -> AsyncGenerator[Optional[Event]]:
271337
"""Execute a single function call with thread safety for state modifications."""
272338
tool, tool_context = _get_tool_and_context(
273339
invocation_context,
@@ -310,6 +376,37 @@ async def _execute_single_function_call_async(
310376
function_response = await __call_tool_async(
311377
tool, args=function_args, tool_context=tool_context
312378
)
379+
if inspect.isasyncgen(function_response) or isinstance(
380+
function_response, AsyncIterator
381+
):
382+
res = None
383+
async for res in function_response:
384+
if inspect.isawaitable(res):
385+
res = await res
386+
if (
387+
invocation_context.run_config.streaming_mode
388+
== StreamingMode.SSE
389+
):
390+
yield __build_response_event(
391+
tool, res, tool_context, invocation_context
392+
)
393+
function_response = res
394+
elif inspect.isgenerator(function_response) or isinstance(
395+
function_response, Iterator
396+
):
397+
res = None
398+
for res in function_response:
399+
if inspect.isawaitable(res):
400+
res = await res
401+
if (
402+
invocation_context.run_config.streaming_mode
403+
== StreamingMode.SSE
404+
):
405+
yield __build_response_event(
406+
tool, res, tool_context, invocation_context
407+
)
408+
function_response = res
409+
313410
except Exception as tool_error:
314411
error_response = (
315412
await invocation_context.plugin_manager.run_on_tool_error_callback(
@@ -359,7 +456,8 @@ async def _execute_single_function_call_async(
359456
# Allow long running function to return None to not provide function
360457
# response.
361458
if not function_response:
362-
return None
459+
yield None
460+
return
363461

364462
# Note: State deltas are not applied here - they are collected in
365463
# tool_context.actions.state_delta and applied later when the session
@@ -374,7 +472,7 @@ async def _execute_single_function_call_async(
374472
args=function_args,
375473
function_response_event=function_response_event,
376474
)
377-
return function_response_event
475+
yield function_response_event
378476

379477

380478
async def handle_function_calls_live(

0 commit comments

Comments
 (0)