Skip to content

Commit 738e743

Browse files
Revert "enable to will_continue-like function response in BaseLlmFlow.run_async method."
This reverts commit e683005.
1 parent cfa1dbf commit 738e743

2 files changed

Lines changed: 35 additions & 55 deletions

File tree

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -466,18 +466,16 @@ async def _postprocess_handle_function_calls_async(
466466
function_call_event: Event,
467467
llm_request: LlmRequest,
468468
) -> AsyncGenerator[Event, None]:
469-
if function_response_event_generator := await functions.handle_function_calls_async(
469+
if function_response_event := await functions.handle_function_calls_async(
470470
invocation_context, function_call_event, llm_request.tools_dict
471471
):
472-
async for function_response_event in function_response_event_generator:
473-
auth_event = functions.generate_auth_event(
474-
invocation_context, function_response_event
475-
)
476-
if auth_event:
477-
yield auth_event
478-
479-
yield function_response_event
472+
auth_event = functions.generate_auth_event(
473+
invocation_context, function_response_event
474+
)
475+
if auth_event:
476+
yield auth_event
480477

478+
yield function_response_event
481479
transfer_to_agent = function_response_event.actions.transfer_to_agent
482480
if transfer_to_agent:
483481
agent_to_run = self._get_agent_to_run(

src/google/adk/flows/llm_flows/functions.py

Lines changed: 28 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -128,17 +128,17 @@ async def handle_function_calls_async(
128128
function_call_event: Event,
129129
tools_dict: dict[str, BaseTool],
130130
filters: Optional[set[str]] = None,
131-
) -> AsyncGenerator[Opeional[Event], None]:
131+
) -> Optional[Event]:
132132
"""Calls the functions and returns the function response event."""
133133
from ...agents.llm_agent import LlmAgent
134134

135135
agent = invocation_context.agent
136136
if not isinstance(agent, LlmAgent):
137-
yield None
138137
return
139138

140139
function_calls = function_call_event.get_function_calls()
141140

141+
function_response_events: list[Event] = []
142142
for function_call in function_calls:
143143
if filters and function_call.id not in filters:
144144
continue
@@ -217,51 +217,33 @@ async def handle_function_calls_async(
217217
if not function_response:
218218
continue
219219

220-
function_response_events: list[Event] = []
221-
async for function_response_one in _ensure_async_generator_function_response(
222-
function_response
223-
):
224-
# Builds the function response event.
225-
function_response_event = __build_response_event(
226-
tool, function_response_one, tool_context, invocation_context
227-
)
228-
trace_tool_call(
229-
tool=tool,
230-
args=function_args,
231-
function_response_event=function_response_event,
232-
)
233-
function_response_events.append(function_response_event)
220+
# Builds the function response event.
221+
function_response_event = __build_response_event(
222+
tool, function_response, tool_context, invocation_context
223+
)
224+
trace_tool_call(
225+
tool=tool,
226+
args=function_args,
227+
function_response_event=function_response_event,
228+
)
229+
function_response_events.append(function_response_event)
234230

235-
if not function_response_events:
236-
yield None
237-
return
238-
merged_event = merge_parallel_function_response_events(
239-
function_response_events
240-
)
241-
if len(function_response_events) > 1:
242-
# this is needed for debug traces of parallel calls
243-
# individual response with tool.name is traced in __build_response_event
244-
# (we drop tool.name from span name here as this is merged event)
245-
with tracer.start_as_current_span('execute_tool (merged)'):
246-
trace_merged_tool_calls(
247-
response_event_id=merged_event.id,
248-
function_response_event=merged_event,
249-
)
250-
yield merged_event
251-
252-
253-
254-
async def _ensure_async_generator_function_response(
255-
function_response: Union[str, dict, AsyncGenerator]
256-
) -> AsyncGenerator[Opeional[Event], None]:
257-
if inspect.isasyncgen(function_response):
258-
async for response in function_response:
259-
yield response
260-
elif inspect.isgenerator(function_response):
261-
for response in function_response:
262-
yield response
263-
else:
264-
yield function_response
231+
if not function_response_events:
232+
return None
233+
merged_event = merge_parallel_function_response_events(
234+
function_response_events
235+
)
236+
237+
if len(function_response_events) > 1:
238+
# this is needed for debug traces of parallel calls
239+
# individual response with tool.name is traced in __build_response_event
240+
# (we drop tool.name from span name here as this is merged event)
241+
with tracer.start_as_current_span('execute_tool (merged)'):
242+
trace_merged_tool_calls(
243+
response_event_id=merged_event.id,
244+
function_response_event=merged_event,
245+
)
246+
return merged_event
265247

266248

267249
async def handle_function_calls_live(

0 commit comments

Comments
 (0)