2020import copy
2121import inspect
2222import logging
23- import threading
2423from typing import Any
2524from typing import AsyncGenerator
25+ from typing import AsyncIterator
2626from typing import cast
27+ from typing import Iterator
28+ from typing import List
2729from typing import Optional
30+ from typing import Tuple
2831from typing import TYPE_CHECKING
2932import uuid
3033
34+ from nltk .sem .chat80 import continent
35+
3136from google .genai import types
3237
3338from ...agents .active_streaming_tool import ActiveStreamingTool
3439from ...agents .invocation_context import InvocationContext
40+ from ...agents .run_config import StreamingMode
3541from ...auth .auth_tool import AuthToolArguments
3642from ...events .event import Event
3743from ...events .event_actions import EventActions
@@ -184,70 +190,85 @@ async def handle_function_calls_async(
184190 tools_dict : dict [str , BaseTool ],
185191 filters : Optional [set [str ]] = None ,
186192 tool_confirmation_dict : Optional [dict [str , ToolConfirmation ]] = None ,
193+
187194) -> Optional [Event ]:
188195 """Calls the functions and returns the function response event."""
189- function_calls = function_call_event .get_function_calls ()
190- return await handle_function_call_list_async (
191- invocation_context ,
192- function_calls ,
193- tools_dict ,
194- filters ,
195- tool_confirmation_dict ,
196- )
196+ async with Aclosing (
197+ handle_function_calls_async_gen (
198+ invocation_context ,
199+ function_call_event ,
200+ tools_dict ,
201+ filters ,
202+ tool_confirmation_dict ,
203+ )
204+ ) as agen :
205+ last_event = None
206+ async for event in agen :
207+ last_event = event
208+ return last_event
197209
198210
199- async def handle_function_call_list_async (
211+ async def handle_function_calls_async_gen (
200212 invocation_context : InvocationContext ,
201- function_calls : list [ types . FunctionCall ] ,
213+ function_call_event : Event ,
202214 tools_dict : dict [str , BaseTool ],
203215 filters : Optional [set [str ]] = None ,
204216 tool_confirmation_dict : Optional [dict [str , ToolConfirmation ]] = None ,
205- ) -> Optional [Event ]:
206- """Calls the functions and returns the function response event."""
217+ ) -> AsyncGenerator [ Optional [Event ] ]:
218+ """Calls the functions and returns the function response event as generator ."""
207219 from ...agents .llm_agent import LlmAgent
208220
209221 agent = invocation_context .agent
210222 if not isinstance (agent , LlmAgent ):
211- return None
223+ yield None
224+ return
225+
226+ function_calls = function_call_event .get_function_calls ()
212227
213228 # Filter function calls
214229 filtered_calls = [
215230 fc for fc in function_calls if not filters or fc .id in filters
216231 ]
217232
218233 if not filtered_calls :
219- return None
234+ yield None
235+ return
220236
221- # Create tasks for parallel execution
222- tasks = [
223- asyncio .create_task (
224- _execute_single_function_call_async (
225- invocation_context ,
226- function_call ,
227- tools_dict ,
228- agent ,
229- tool_confirmation_dict [function_call .id ]
230- if tool_confirmation_dict
231- else None ,
232- )
237+ function_call_async_gens = [
238+ _execute_single_function_call_async_gen (
239+ invocation_context ,
240+ function_call ,
241+ tools_dict ,
242+ agent ,
243+ tool_confirmation_dict [function_call .id ]
244+ if tool_confirmation_dict
245+ else None ,
233246 )
234247 for function_call in filtered_calls
235248 ]
236249
237- # Wait for all tasks to complete
238- function_response_events = await asyncio .gather (* tasks )
239-
240- # Filter out None results
241- function_response_events = [
242- event for event in function_response_events if event is not None
243- ]
250+ merged_event = None
251+ result_events : List [Optional [Event ]] = [None ] * len (function_call_async_gens )
252+ function_response_events = []
253+ async for idx , event in _concat_function_call_generators (
254+ function_call_async_gens
255+ ):
256+ result_events [idx ] = event
257+ function_response_events = [
258+ event for event in result_events if event is not None
259+ ]
260+ if function_response_events :
261+ merged_event = merge_parallel_function_response_events (
262+ function_response_events
263+ )
264+ if invocation_context .run_config .streaming_mode == StreamingMode .SSE :
265+ yield merged_event
266+ if invocation_context .run_config .streaming_mode != StreamingMode .SSE :
267+ yield merged_event
244268
245269 if not function_response_events :
246- return None
247-
248- merged_event = merge_parallel_function_response_events (
249- function_response_events
250- )
270+ yield None
271+ return
251272
252273 if len (function_response_events ) > 1 :
253274 # this is needed for debug traces of parallel calls
@@ -258,16 +279,61 @@ async def handle_function_call_list_async(
258279 response_event_id = merged_event .id ,
259280 function_response_event = merged_event ,
260281 )
261- return merged_event
262282
263283
264- async def _execute_single_function_call_async (
284+ async def _concat_function_call_generators (
285+ gens : List [AsyncGenerator [Any ]],
286+ ) -> AsyncIterator [tuple [int , Any ]]:
287+ _SENTINEL = object ()
288+ q : asyncio .Queue [tuple [str , int , Any ]] = asyncio .Queue ()
289+ gens = list (gens )
290+ n = len (gens )
291+
292+ async def __pump (idx : int , agen_ : AsyncIterator [Any ]):
293+ try :
294+ async for x in agen_ :
295+ await q .put (('ITEM' , idx , x ))
296+ except Exception as e :
297+ await q .put (('EXC' , idx , e ))
298+ finally :
299+ aclose = getattr (agen_ , 'aclose' , None )
300+ if callable (aclose ):
301+ try :
302+ await aclose ()
303+ except Exception : # noqa: ignore exception when task canceled.
304+ pass
305+
306+ await q .put (('END' , idx , _SENTINEL ))
307+
308+ tasks = [asyncio .create_task (__pump (i , agen )) for i , agen in enumerate (gens )]
309+ finished = 0
310+ try :
311+ while finished < n :
312+ kind , i , payload = await q .get ()
313+ if kind == 'ITEM' :
314+ yield i , payload
315+
316+ elif kind == 'EXC' :
317+ for t in tasks :
318+ t .cancel ()
319+ await asyncio .gather (* tasks , return_exceptions = True )
320+ raise payload
321+
322+ elif kind == 'END' :
323+ finished += 1
324+ finally :
325+ for t in tasks :
326+ t .cancel ()
327+ await asyncio .gather (* tasks , return_exceptions = True )
328+
329+
330+ async def _execute_single_function_call_async_gen (
265331 invocation_context : InvocationContext ,
266332 function_call : types .FunctionCall ,
267333 tools_dict : dict [str , BaseTool ],
268334 agent : LlmAgent ,
269335 tool_confirmation : Optional [ToolConfirmation ] = None ,
270- ) -> Optional [Event ]:
336+ ) -> AsyncGenerator [ Optional [Event ] ]:
271337 """Execute a single function call with thread safety for state modifications."""
272338 tool , tool_context = _get_tool_and_context (
273339 invocation_context ,
@@ -310,6 +376,37 @@ async def _execute_single_function_call_async(
310376 function_response = await __call_tool_async (
311377 tool , args = function_args , tool_context = tool_context
312378 )
379+ if inspect .isasyncgen (function_response ) or isinstance (
380+ function_response , AsyncIterator
381+ ):
382+ res = None
383+ async for res in function_response :
384+ if inspect .isawaitable (res ):
385+ res = await res
386+ if (
387+ invocation_context .run_config .streaming_mode
388+ == StreamingMode .SSE
389+ ):
390+ yield __build_response_event (
391+ tool , res , tool_context , invocation_context
392+ )
393+ function_response = res
394+ elif inspect .isgenerator (function_response ) or isinstance (
395+ function_response , Iterator
396+ ):
397+ res = None
398+ for res in function_response :
399+ if inspect .isawaitable (res ):
400+ res = await res
401+ if (
402+ invocation_context .run_config .streaming_mode
403+ == StreamingMode .SSE
404+ ):
405+ yield __build_response_event (
406+ tool , res , tool_context , invocation_context
407+ )
408+ function_response = res
409+
313410 except Exception as tool_error :
314411 error_response = (
315412 await invocation_context .plugin_manager .run_on_tool_error_callback (
@@ -359,7 +456,8 @@ async def _execute_single_function_call_async(
359456 # Allow long running function to return None to not provide function
360457 # response.
361458 if not function_response :
362- return None
459+ yield None
460+ return
363461
364462 # Note: State deltas are not applied here - they are collected in
365463 # tool_context.actions.state_delta and applied later when the session
@@ -374,7 +472,7 @@ async def _execute_single_function_call_async(
374472 args = function_args ,
375473 function_response_event = function_response_event ,
376474 )
377- return function_response_event
475+ yield function_response_event
378476
379477
380478async def handle_function_calls_live (
0 commit comments