diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py
index b41b7f4eff..23745e8d4c 100644
--- a/src/google/adk/agents/llm_agent.py
+++ b/src/google/adk/agents/llm_agent.py
@@ -350,6 +350,14 @@ class LlmAgent(BaseAgent):
   - Extracts agent reply for later use, such as in tools, callbacks, etc.
   - Connects agents to coordinate with each other.
   """
+  accumulate_output_key: bool = True
+  """Whether to accumulate streamed output fragments into `output_key`.
+
+  When True (default) streamed fragments received before tool calls are
+  appended into the session state under `output_key` so the final saved value
+  includes all streamed text. When False, preserves legacy behavior where
+  only the final response's text from the last event is saved.
+  """
   # Controlled input/output configurations - End
 
   # Advance features - Start
@@ -486,7 +494,7 @@ async def _run_async_impl(
     should_pause = False
     async with Aclosing(self._llm_flow.run_async(ctx)) as agen:
       async for event in agen:
-        self.__maybe_save_output_to_state(event)
+        self.__maybe_save_output_to_state(event, ctx)
         yield event
         if ctx.should_pause_invocation(event):
           # Do not pause immediately, wait until the long-running tool call is
@@ -510,7 +518,7 @@ async def _run_live_impl(
   ) -> AsyncGenerator[Event, None]:
     async with Aclosing(self._llm_flow.run_live(ctx)) as agen:
       async for event in agen:
-        self.__maybe_save_output_to_state(event)
+        self.__maybe_save_output_to_state(event, ctx)
         yield event
         if ctx.end_invocation:
           return
@@ -827,8 +835,22 @@ def __get_transfer_to_agent_or_none(
       return self.__get_agent_to_run(event.actions.transfer_to_agent)
     return None
 
-  def __maybe_save_output_to_state(self, event: Event):
-    """Saves the model output to state if needed."""
+  def __maybe_save_output_to_state(
+      self, event: Event, ctx: Optional[InvocationContext] = None
+  ):
+    """Saves the model output to state if needed.
+
+    Only non-thought text parts of events authored by this agent are saved.
+    Non-final events are ignored unless `ctx` is given and
+    `accumulate_output_key` is True; their text is then appended to the text
+    stored under `output_key` in the session state, and the final response is
+    saved together with that text.
+
+    Args:
+      event: The event to take the output text from.
+      ctx: The invocation context whose session state holds the text saved so
+        far. When None, only the text of a final response is saved.
+    """
     # skip if the event was authored by some other agent (e.g. current agent
     # transferred to another agent)
     if event.author != self.name:
@@ -842,33 +864,54 @@ def __maybe_save_output_to_state(self, event: Event):
     if not self.output_key:
       return
 
-    # Handle text responses
-    if event.is_final_response() and event.content and event.content.parts:
-
-      # Skip if no text parts at all to avoid overwriting state_delta values
-      # already set (e.g. after_tool_callback with skip_summarization
-      # on function_response-only events).
-      has_text_part = any(
-          part.text is not None and not part.thought
-          for part in event.content.parts
-      )
-
-      if not has_text_part:
-        return
-
-      result = ''.join(
-          part.text
-          for part in event.content.parts
-          if part.text and not part.thought
-      )
-      if self.output_schema:
-        # If the result from the final chunk is just whitespace or empty,
-        # it means this is an empty final chunk of a stream.
-        # Do not attempt to parse it as JSON.
-        if not result.strip():
-          return
-        result = validate_schema(self.output_schema, result)
-      event.actions.state_delta[self.output_key] = result
+    if not (event.content and event.content.parts):
+      return
+
+    # Skip if no text parts at all to avoid overwriting state_delta values
+    # already set (e.g. after_tool_callback with skip_summarization
+    # on function_response-only events).
+    has_text_part = any(
+        part.text is not None and not part.thought
+        for part in event.content.parts
+    )
+    if not has_text_part:
+      return
+
+    result = ''.join(
+        part.text
+        for part in event.content.parts
+        if part.text and not part.thought
+    )
+
+    # Accumulation reads the session state, so it needs the invocation
+    # context; without it only the final response is saved.
+    accumulate = ctx is not None and self.accumulate_output_key
+    previous = ''
+    if accumulate:
+      stored = ctx.session.state.get(self.output_key)
+      # Only text can be extended. The key may hold a non-string value,
+      # e.g. one set by a callback.
+      if isinstance(stored, str):
+        previous = stored
+      # NOTE(review): the stored text may come from an earlier invocation of
+      # this session and would then be prepended too; confirm whether the
+      # key must be reset when a new invocation starts.
+
+    if not event.is_final_response():
+      # Keep non-empty fragments so the final response can be combined with
+      # them. Without accumulation non-final events are ignored.
+      if accumulate and result:
+        event.actions.state_delta[self.output_key] = previous + result
+      return
+
+    result = previous + result
+    if self.output_schema:
+      # If the result is just whitespace or empty, it means this is an empty
+      # final chunk of a stream. Do not attempt to parse it as JSON.
+      if not result.strip():
+        return
+      result = validate_schema(self.output_schema, result)
+    event.actions.state_delta[self.output_key] = result
 
   @model_validator(mode='after')
   def __model_validator_after(self) -> LlmAgent:
@@ -1000,6 +1043,8 @@ def _parse_config(
       kwargs['output_schema'] = resolve_code_reference(config.output_schema)
     if config.output_key:
       kwargs['output_key'] = config.output_key
+    if config.accumulate_output_key is not None:
+      kwargs['accumulate_output_key'] = config.accumulate_output_key
     if config.tools:
       kwargs['tools'] = cls._resolve_tools(config.tools, config_abs_path)
     if config.before_model_callbacks:
diff --git a/src/google/adk/agents/llm_agent_config.py b/src/google/adk/agents/llm_agent_config.py
index 93ca718094..e8d248b42b 100644
--- a/src/google/adk/agents/llm_agent_config.py
+++ b/src/google/adk/agents/llm_agent_config.py
@@ -139,6 +139,15 @@ def _validate_model_sources(self) -> LlmAgentConfig:
       default=None, description='Optional. LlmAgent.output_key.'
   )
 
+  accumulate_output_key: Optional[bool] = Field(
+      default=None,
+      description=(
+          'Optional. When true, streamed fragments are accumulated into the'
+          ' `output_key` across tool calls. When false, only the final'
+          ' response is saved to `output_key`.'
+      ),
+  )
+
   include_contents: Literal['default', 'none'] = Field(
       default='default', description='Optional. LlmAgent.include_contents.'
   )
diff --git a/tests/unittests/agents/test_llm_agent_output_save.py b/tests/unittests/agents/test_llm_agent_output_save.py
index cb3206b210..3830bcdfa5 100644
--- a/tests/unittests/agents/test_llm_agent_output_save.py
+++ b/tests/unittests/agents/test_llm_agent_output_save.py
@@ -331,3 +331,72 @@ def test_maybe_save_output_to_state_saves_empty_string_when_text_is_empty(
     # Assert key exists and value is empty string
     assert "result" in event.actions.state_delta
     assert not event.actions.state_delta["result"]
+
+  def test_accumulate_output_key_toggle(self):
+    """Test that `accumulate_output_key` controls accumulation behavior.
+
+    Simulate two streamed fragments separated by a tool call by manually
+    updating the session state between calls.
+    """
+
+    class FakeSession:
+
+      def __init__(self):
+        self.state = {}
+
+    class FakeContext:
+
+      def __init__(self):
+        self.session = FakeSession()
+
+    # Case 1: accumulation enabled (default)
+    ctx = FakeContext()
+    agent = LlmAgent(name="test_agent", output_key="result")
+
+    # First (partial) fragment
+    event1 = create_test_event(
+        author="test_agent", content_text="Intro: ", is_final=False
+    )
+    agent._LlmAgent__maybe_save_output_to_state(event1, ctx)
+    assert event1.actions.state_delta["result"] == "Intro: "
+    # Simulate session update that runner would do
+    ctx.session.state.update(event1.actions.state_delta)
+
+    # Final fragment
+    event2 = create_test_event(
+        author="test_agent", content_text="Conclusion", is_final=True
+    )
+    agent._LlmAgent__maybe_save_output_to_state(event2, ctx)
+
+    # With accumulation enabled, final saved value should include both parts
+    assert event2.actions.state_delta["result"] == "Intro: Conclusion"
+
+    # Case 2: accumulation disabled
+    ctx2 = FakeContext()
+    agent2 = LlmAgent(
+        name="test_agent", output_key="result", accumulate_output_key=False
+    )
+
+    event1b = create_test_event(
+        author="test_agent", content_text="Intro: ", is_final=False
+    )
+    agent2._LlmAgent__maybe_save_output_to_state(event1b, ctx2)
+    # The partial fragment must not be written when accumulation is disabled
+    assert "result" not in event1b.actions.state_delta
+
+    event2b = create_test_event(
+        author="test_agent", content_text="Conclusion", is_final=True
+    )
+    agent2._LlmAgent__maybe_save_output_to_state(event2b, ctx2)
+
+    # With accumulation disabled, only the final fragment should be saved
+    assert event2b.actions.state_delta["result"] == "Conclusion"
+
+    # Case 3: a non-string value already stored under the key is not extended
+    ctx3 = FakeContext()
+    ctx3.session.state["result"] = [1, 2, 3]
+    event3 = create_test_event(
+        author="test_agent", content_text="Conclusion", is_final=True
+    )
+    agent._LlmAgent__maybe_save_output_to_state(event3, ctx3)
+    assert event3.actions.state_delta["result"] == "Conclusion"