From d526ef4703282366bacdcea1f525609322080ed3 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Sat, 7 Feb 2026 07:30:31 +0900 Subject: [PATCH] fix: persist streamed run-again tool items to session --- src/agents/run_internal/run_loop.py | 6 +++++ tests/test_agent_runner_streamed.py | 36 +++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index e807c0cb11..00d1f0facb 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -946,6 +946,12 @@ async def _save_stream_items_without_count( if streamed_result._state is not None: streamed_result._state._current_step = NextStepRunAgain() + await _save_stream_items_with_count( + turn_session_items, + turn_result.model_response.response_id, + store_setting, + ) + if streamed_result._cancel_mode == "after_turn": # type: ignore[comparison-overlap] streamed_result.is_complete = True streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) diff --git a/tests/test_agent_runner_streamed.py b/tests/test_agent_runner_streamed.py index 29377db912..10004e88fd 100644 --- a/tests/test_agent_runner_streamed.py +++ b/tests/test_agent_runner_streamed.py @@ -151,6 +151,42 @@ async def test_tool_call_runs(): ) +@pytest.mark.asyncio +async def test_streamed_run_again_persists_tool_items_to_session(): + model = FakeModel() + call_id = "call-session-run-again" + agent = Agent( + name="test", + model=model, + tools=[get_function_tool("foo", "tool_result")], + ) + session = SimpleListSession() + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("foo", json.dumps({"a": "b"}), call_id=call_id)], + [get_text_message("done")], + ] + ) + + result = Runner.run_streamed(agent, input="user_message", session=session) + await consume_stream(result) + + saved_items = await session.get_items() + assert any( + isinstance(item, dict) + and item.get("type") == "function_call" + and item.get("call_id") == call_id + for item in saved_items + ) + assert any( + isinstance(item, dict) + and item.get("type") == "function_call_output" + and item.get("call_id") == call_id + for item in saved_items + ) + + @pytest.mark.asyncio async def test_handoffs(): model = FakeModel()