diff --git a/src/openai/lib/streaming/_assistants.py b/src/openai/lib/streaming/_assistants.py index 6efb3ca3f1..ccde728f2f 100644 --- a/src/openai/lib/streaming/_assistants.py +++ b/src/openai/lib/streaming/_assistants.py @@ -980,13 +980,20 @@ def accumulate_event( def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> dict[object, object]: for key, delta_value in delta.items(): if key not in acc: - acc[key] = delta_value - continue + if is_list(delta_value): + acc_value = [] + else: + acc[key] = delta_value + continue - acc_value = acc[key] - if acc_value is None: - acc[key] = delta_value - continue + else: + acc_value = acc[key] + if acc_value is None: + if is_list(delta_value): + acc_value = [] + else: + acc[key] = delta_value + continue # the `index` property is used in arrays of objects so it should # not be accumulated like other values e.g. @@ -1007,8 +1014,11 @@ def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> elif is_list(acc_value) and is_list(delta_value): # for lists of non-dictionary items we'll only ever get new entries # in the array, existing entries will never be changed - if all(isinstance(x, (str, int, float)) for x in acc_value): + if all(isinstance(x, (str, int, float)) for x in acc_value) and all( + isinstance(x, (str, int, float)) for x in delta_value + ): acc_value.extend(delta_value) + acc[key] = acc_value continue for delta_entry in delta_value: diff --git a/src/openai/lib/streaming/_deltas.py b/src/openai/lib/streaming/_deltas.py index a5e1317612..27c22b6abd 100644 --- a/src/openai/lib/streaming/_deltas.py +++ b/src/openai/lib/streaming/_deltas.py @@ -6,13 +6,20 @@ def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> dict[object, object]: for key, delta_value in delta.items(): if key not in acc: - acc[key] = delta_value - continue + if is_list(delta_value): + acc_value = [] + else: + acc[key] = delta_value + continue - acc_value = acc[key] - if acc_value is None: - acc[key] = delta_value - continue + else: + acc_value = acc[key] + if acc_value is None: + if is_list(delta_value): + acc_value = [] + else: + acc[key] = delta_value + continue # the `index` property is used in arrays of objects so it should # not be accumulated like other values e.g. @@ -33,8 +40,11 @@ def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> elif is_list(acc_value) and is_list(delta_value): # for lists of non-dictionary items we'll only ever get new entries # in the array, existing entries will never be changed - if all(isinstance(x, (str, int, float)) for x in acc_value): + if all(isinstance(x, (str, int, float)) for x in acc_value) and all( + isinstance(x, (str, int, float)) for x in delta_value + ): acc_value.extend(delta_value) + acc[key] = acc_value continue for delta_entry in delta_value: diff --git a/src/openai/lib/streaming/chat/_completions.py b/src/openai/lib/streaming/chat/_completions.py index 5f072cafbd..c1a581a7fe 100644 --- a/src/openai/lib/streaming/chat/_completions.py +++ b/src/openai/lib/streaming/chat/_completions.py @@ -744,7 +744,7 @@ def _convert_initial_chunk_into_snapshot(chunk: ChatCompletionChunk) -> ParsedCh for choice in chunk.choices: choices[choice.index] = { **choice.model_dump(exclude_unset=True, exclude={"delta"}), - "message": choice.delta.to_dict(), + "message": accumulate_delta({}, choice.delta.to_dict()), } return cast( diff --git a/tests/lib/test_streaming_deltas.py b/tests/lib/test_streaming_deltas.py new file mode 100644 index 0000000000..d081629357 --- /dev/null +++ b/tests/lib/test_streaming_deltas.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +import json +from copy import deepcopy +from collections.abc import Callable + +import pytest + +from openai.types.chat import ChatCompletionChunk +from openai.lib.streaming.chat import ChatCompletionStreamState +from openai.lib.streaming._deltas import accumulate_delta as accumulate_chat_delta +from openai.lib.streaming._assistants import accumulate_delta as accumulate_assistant_delta + +AccumulateDelta = Callable[[dict[object, object], dict[object, object]], dict[object, object]] + + +@pytest.mark.parametrize("accumulate_delta", [accumulate_chat_delta, accumulate_assistant_delta]) +@pytest.mark.parametrize("initial_acc", [{}, {"tool_calls": None}]) +def test_accumulate_delta_merges_duplicate_index_entries_in_initial_list( + accumulate_delta: AccumulateDelta, + initial_acc: dict[object, object], +) -> None: + acc = deepcopy(initial_acc) + + accumulate_delta( + acc, + { + "tool_calls": [ + { + "index": 0, + "id": "call_abc", + "function": {"name": "get_weather"}, + "type": "function", + }, + { + "index": 0, + "function": {"arguments": '{"city"'}, + }, + ] + }, + ) + accumulate_delta( + acc, + { + "tool_calls": [ + { + "index": 0, + "function": {"arguments": ': "London"}'}, + }, + ] + }, + ) + + tool_calls = acc["tool_calls"] + assert isinstance(tool_calls, list) + assert len(tool_calls) == 1 + + arguments = tool_calls[0]["function"]["arguments"] + assert arguments == '{"city": "London"}' + assert json.loads(arguments) == {"city": "London"} + + +@pytest.mark.parametrize("accumulate_delta", [accumulate_chat_delta, accumulate_assistant_delta]) +@pytest.mark.parametrize("initial_acc", [{}, {"content": None}]) +def test_accumulate_delta_preserves_initial_primitive_lists( + accumulate_delta: AccumulateDelta, + initial_acc: dict[object, object], +) -> None: + acc = deepcopy(initial_acc) + + accumulate_delta(acc, {"content": ["hello", " ", "world"]}) + + assert acc["content"] == ["hello", " ", "world"] + + +def test_chat_stream_state_merges_duplicate_tool_call_indexes_in_first_chunk() -> None: + state = ChatCompletionStreamState[object]() + + state.handle_chunk( + ChatCompletionChunk( + id="chatcmpl_123", + created=1, + model="gpt-4o", + object="chat.completion.chunk", + choices=[ + { + "index": 0, + "finish_reason": None, + "delta": { + "role": "assistant", + "tool_calls": [ + { + "index": 0, + "id": "call_abc", + "function": {"name": "get_weather"}, + "type": "function", + }, + { + "index": 0, + "function": {"arguments": '{"city"'}, + }, + ], + }, + } + ], + ) + ) + state.handle_chunk( + ChatCompletionChunk( + id="chatcmpl_123", + created=1, + model="gpt-4o", + object="chat.completion.chunk", + choices=[ + { + "index": 0, + "finish_reason": None, + "delta": { + "tool_calls": [ + { + "index": 0, + "function": {"arguments": ': "London"}'}, + }, + ], + }, + } + ], + ) + ) + + tool_calls = state.current_completion_snapshot.choices[0].message.tool_calls + assert tool_calls is not None + assert len(tool_calls) == 1 + assert tool_calls[0].function.arguments == '{"city": "London"}'