diff --git a/src/openai/lib/streaming/_assistants.py b/src/openai/lib/streaming/_assistants.py
index 6efb3ca3f1..4ab97c8cf0 100644
--- a/src/openai/lib/streaming/_assistants.py
+++ b/src/openai/lib/streaming/_assistants.py
@@ -980,11 +980,19 @@ def accumulate_event(
 def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> dict[object, object]:
     for key, delta_value in delta.items():
         if key not in acc:
+            if is_list(delta_value):
+                acc[key] = _accumulate_list_delta([], delta_value)
+                continue
+
             acc[key] = delta_value
             continue
 
         acc_value = acc[key]
         if acc_value is None:
+            if is_list(delta_value):
+                acc[key] = _accumulate_list_delta([], delta_value)
+                continue
+
             acc[key] = delta_value
             continue
 
@@ -1005,34 +1013,44 @@ def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) ->
         elif is_dict(acc_value) and is_dict(delta_value):
             acc_value = accumulate_delta(acc_value, delta_value)
         elif is_list(acc_value) and is_list(delta_value):
-            # for lists of non-dictionary items we'll only ever get new entries
-            # in the array, existing entries will never be changed
-            if all(isinstance(x, (str, int, float)) for x in acc_value):
-                acc_value.extend(delta_value)
-                continue
+            acc_value = _accumulate_list_delta(acc_value, delta_value)
 
-            for delta_entry in delta_value:
-                if not is_dict(delta_entry):
-                    raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}")
+        acc[key] = acc_value
 
-                try:
-                    index = delta_entry["index"]
-                except KeyError as exc:
-                    raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc
+    return acc
 
-                if not isinstance(index, int):
-                    raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}")
 
-                try:
-                    acc_entry = acc_value[index]
-                except IndexError:
-                    acc_value.insert(index, delta_entry)
-                else:
-                    if not is_dict(acc_entry):
-                        raise TypeError("not handled yet")
+def _accumulate_list_delta(acc: list[object], delta: list[object]) -> list[object]:
+    # for lists of non-dictionary items we'll only ever get new entries
+    # in the array, existing entries will never be changed
+    if all(isinstance(x, (str, int, float)) for x in [*acc, *delta]):
+        acc.extend(delta)
+        return acc
 
-                    acc_value[index] = accumulate_delta(acc_entry, delta_entry)
+    for delta_entry in delta:
+        if not is_dict(delta_entry):
+            raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}")
 
-        acc[key] = acc_value
+        try:
+            index = delta_entry["index"]
+        except KeyError as exc:
+            raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc
+
+        if not isinstance(index, int):
+            raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}")
+
+        acc_entry = None
+        acc_index = None
+        for i, entry in enumerate(acc):
+            if is_dict(entry) and entry.get("index") == index:
+                acc_entry = entry
+                acc_index = i
+                break
+
+        if acc_entry is None:
+            acc.insert(index, delta_entry)
+        else:
+            assert acc_index is not None
+            acc[acc_index] = accumulate_delta(acc_entry, delta_entry)
 
     return acc
diff --git a/src/openai/lib/streaming/_deltas.py b/src/openai/lib/streaming/_deltas.py
index a5e1317612..35a97a0b55 100644
--- a/src/openai/lib/streaming/_deltas.py
+++ b/src/openai/lib/streaming/_deltas.py
@@ -6,11 +6,19 @@
 def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> dict[object, object]:
     for key, delta_value in delta.items():
         if key not in acc:
+            if is_list(delta_value):
+                acc[key] = _accumulate_list_delta([], delta_value)
+                continue
+
             acc[key] = delta_value
             continue
 
         acc_value = acc[key]
         if acc_value is None:
+            if is_list(delta_value):
+                acc[key] = _accumulate_list_delta([], delta_value)
+                continue
+
             acc[key] = delta_value
             continue
 
@@ -31,34 +39,44 @@ def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) ->
         elif is_dict(acc_value) and is_dict(delta_value):
             acc_value = accumulate_delta(acc_value, delta_value)
         elif is_list(acc_value) and is_list(delta_value):
-            # for lists of non-dictionary items we'll only ever get new entries
-            # in the array, existing entries will never be changed
-            if all(isinstance(x, (str, int, float)) for x in acc_value):
-                acc_value.extend(delta_value)
-                continue
+            acc_value = _accumulate_list_delta(acc_value, delta_value)
 
-            for delta_entry in delta_value:
-                if not is_dict(delta_entry):
-                    raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}")
+        acc[key] = acc_value
 
-                try:
-                    index = delta_entry["index"]
-                except KeyError as exc:
-                    raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc
+    return acc
 
-                if not isinstance(index, int):
-                    raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}")
 
-                try:
-                    acc_entry = acc_value[index]
-                except IndexError:
-                    acc_value.insert(index, delta_entry)
-                else:
-                    if not is_dict(acc_entry):
-                        raise TypeError("not handled yet")
+def _accumulate_list_delta(acc: list[object], delta: list[object]) -> list[object]:
+    # for lists of non-dictionary items we'll only ever get new entries
+    # in the array, existing entries will never be changed
+    if all(isinstance(x, (str, int, float)) for x in [*acc, *delta]):
+        acc.extend(delta)
+        return acc
 
-                    acc_value[index] = accumulate_delta(acc_entry, delta_entry)
+    for delta_entry in delta:
+        if not is_dict(delta_entry):
+            raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}")
 
-        acc[key] = acc_value
+        try:
+            index = delta_entry["index"]
+        except KeyError as exc:
+            raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc
+
+        if not isinstance(index, int):
+            raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}")
+
+        acc_entry = None
+        acc_index = None
+        for i, entry in enumerate(acc):
+            if is_dict(entry) and entry.get("index") == index:
+                acc_entry = entry
+                acc_index = i
+                break
+
+        if acc_entry is None:
+            acc.insert(index, delta_entry)
+        else:
+            assert acc_index is not None
+            acc[acc_index] = accumulate_delta(acc_entry, delta_entry)
 
     return acc
diff --git a/tests/lib/test_streaming_deltas.py b/tests/lib/test_streaming_deltas.py
new file mode 100644
index 0000000000..5defed1235
--- /dev/null
+++ b/tests/lib/test_streaming_deltas.py
@@ -0,0 +1,37 @@
+from __future__ import annotations
+
+from collections.abc import Callable
+
+import pytest
+
+from openai.lib.streaming._deltas import accumulate_delta as accumulate_chat_delta
+from openai.lib.streaming._assistants import accumulate_delta as accumulate_assistants_delta
+
+
+@pytest.mark.parametrize("accumulate_delta", [accumulate_chat_delta, accumulate_assistants_delta])
+def test_accumulate_delta_merges_duplicate_index_entries_in_initial_list(
+    accumulate_delta: Callable[[dict[object, object], dict[object, object]], dict[object, object]],
+) -> None:
+    acc: dict[object, object] = {}
+
+    accumulate_delta(
+        acc,
+        {
+            "tool_calls": [
+                {"index": 0, "id": "call_abc", "function": {"name": "list_files"}, "type": "function"},
+                {"index": 0, "function": {"arguments": '{"path"'}},
+            ]
+        },
+    )
+    accumulate_delta(acc, {"tool_calls": [{"index": 0, "function": {"arguments": ': "."}'}}]})
+
+    assert acc == {
+        "tool_calls": [
+            {
+                "index": 0,
+                "id": "call_abc",
+                "function": {"name": "list_files", "arguments": '{"path": "."}'},
+                "type": "function",
+            }
+        ]
+    }