|
1 | 1 | import importlib.metadata |
2 | 2 | import json |
| 3 | +import logging |
3 | 4 | from unittest.mock import Mock, patch |
4 | 5 |
|
5 | 6 | import httpx |
@@ -120,68 +121,6 @@ async def test_success(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, |
120 | 121 | "stream": False, |
121 | 122 | } |
122 | 123 |
|
123 | | - async def test_stream(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]): |
124 | | - httpx_mock.add_response( |
125 | | - stream=IteratorStream( |
126 | | - [ |
127 | | - b'data: {"id":"1","task_output":{"message":""}}\n\n', |
128 | | - b'data: {"id":"1","task_output":{"message":"hel"}}\n\ndata: {"id":"1","task_output":{"message":"hello"}}\n\n', # noqa: E501 |
129 | | - b'data: {"id":"1","task_output":{"message":"hello"},"version":{"properties":{"model":"gpt-4o","temperature":0.5}},"cost_usd":0.01,"duration_seconds":10.1}\n\n', # noqa: E501 |
130 | | - ], |
131 | | - ), |
132 | | - ) |
133 | | - |
134 | | - chunks = [chunk async for chunk in agent.stream(HelloTaskInput(name="Alice"))] |
135 | | - |
136 | | - outputs = [chunk.output for chunk in chunks] |
137 | | - assert outputs == [ |
138 | | - HelloTaskOutput(message=""), |
139 | | - HelloTaskOutput(message="hel"), |
140 | | - HelloTaskOutput(message="hello"), |
141 | | - HelloTaskOutput(message="hello"), |
142 | | - ] |
143 | | - last_message = chunks[-1] |
144 | | - assert isinstance(last_message, Run) |
145 | | - assert last_message.version |
146 | | - assert last_message.version.properties.model == "gpt-4o" |
147 | | - assert last_message.version.properties.temperature == 0.5 |
148 | | - assert last_message.cost_usd == 0.01 |
149 | | - assert last_message.duration_seconds == 10.1 |
150 | | - |
151 | | - async def test_stream_not_optional( |
152 | | - self, |
153 | | - httpx_mock: HTTPXMock, |
154 | | - agent_not_optional: Agent[HelloTaskInput, HelloTaskOutputNotOptional], |
155 | | - ): |
156 | | - # Checking that streaming works even with non optional fields |
157 | | - # The first two chunks are missing a required key but the last one has it |
158 | | - httpx_mock.add_response( |
159 | | - stream=IteratorStream( |
160 | | - [ |
161 | | - b'data: {"id":"1","task_output":{"message":""}}\n\n', |
162 | | - b'data: {"id":"1","task_output":{"message":"hel"}}\n\ndata: {"id":"1","task_output":{"message":"hello"}}\n\n', # noqa: E501 |
163 | | - b'data: {"id":"1","task_output":{"message":"hello", "another_field": "test"},"version":{"properties":{"model":"gpt-4o","temperature":0.5}},"cost_usd":0.01,"duration_seconds":10.1}\n\n', # noqa: E501 |
164 | | - ], |
165 | | - ), |
166 | | - ) |
167 | | - |
168 | | - chunks = [chunk async for chunk in agent_not_optional.stream(HelloTaskInput(name="Alice"))] |
169 | | - |
170 | | - messages = [chunk.output.message for chunk in chunks] |
171 | | - assert messages == ["", "hel", "hello", "hello"] |
172 | | - |
173 | | - for chunk in chunks[:-1]: |
174 | | - assert chunk.output.another_field == "" |
175 | | - assert chunks[-1].output.another_field == "test" |
176 | | - |
177 | | - last_message = chunks[-1] |
178 | | - assert isinstance(last_message, Run) |
179 | | - assert last_message.version |
180 | | - assert last_message.version.properties.model == "gpt-4o" |
181 | | - assert last_message.version.properties.temperature == 0.5 |
182 | | - assert last_message.cost_usd == 0.01 |
183 | | - assert last_message.duration_seconds == 10.1 |
184 | | - |
185 | 124 | async def test_run_with_env(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]): |
186 | 125 | httpx_mock.add_response(json=fixtures_json("task_run.json")) |
187 | 126 |
|
@@ -1024,3 +963,118 @@ async def test_fetch_completions(self, agent: Agent[HelloTaskInput, HelloTaskOut |
1024 | 963 | ), |
1025 | 964 | ), |
1026 | 965 | ] |
| 966 | + |
| 967 | + |
| 968 | +class TestStream: |
| 969 | + async def test_stream(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]): |
| 970 | + httpx_mock.add_response( |
| 971 | + stream=IteratorStream( |
| 972 | + [ |
| 973 | + b'data: {"id":"1","task_output":{"message":""}}\n\n', |
| 974 | + b'data: {"id":"1","task_output":{"message":"hel"}}\n\ndata: {"id":"1","task_output":{"message":"hello"}}\n\n', # noqa: E501 |
| 975 | + b'data: {"id":"1","task_output":{"message":"hello"},"version":{"properties":{"model":"gpt-4o","temperature":0.5}},"cost_usd":0.01,"duration_seconds":10.1}\n\n', # noqa: E501 |
| 976 | + ], |
| 977 | + ), |
| 978 | + ) |
| 979 | + |
| 980 | + chunks = [chunk async for chunk in agent.stream(HelloTaskInput(name="Alice"))] |
| 981 | + |
| 982 | + outputs = [chunk.output for chunk in chunks] |
| 983 | + assert outputs == [ |
| 984 | + HelloTaskOutput(message=""), |
| 985 | + HelloTaskOutput(message="hel"), |
| 986 | + HelloTaskOutput(message="hello"), |
| 987 | + HelloTaskOutput(message="hello"), |
| 988 | + ] |
| 989 | + last_message = chunks[-1] |
| 990 | + assert isinstance(last_message, Run) |
| 991 | + assert last_message.version |
| 992 | + assert last_message.version.properties.model == "gpt-4o" |
| 993 | + assert last_message.version.properties.temperature == 0.5 |
| 994 | + assert last_message.cost_usd == 0.01 |
| 995 | + assert last_message.duration_seconds == 10.1 |
| 996 | + |
| 997 | + async def test_stream_not_optional( |
| 998 | + self, |
| 999 | + httpx_mock: HTTPXMock, |
| 1000 | + agent_not_optional: Agent[HelloTaskInput, HelloTaskOutputNotOptional], |
| 1001 | + ): |
| 1002 | + # Checking that streaming works even with non optional fields |
| 1003 | + # The first two chunks are missing a required key but the last one has it |
| 1004 | + httpx_mock.add_response( |
| 1005 | + stream=IteratorStream( |
| 1006 | + [ |
| 1007 | + b'data: {"id":"1","task_output":{"message":""}}\n\n', |
| 1008 | + b'data: {"id":"1","task_output":{"message":"hel"}}\n\ndata: {"id":"1","task_output":{"message":"hello"}}\n\n', # noqa: E501 |
| 1009 | + b'data: {"id":"1","task_output":{"message":"hello", "another_field": "test"},"version":{"properties":{"model":"gpt-4o","temperature":0.5}},"cost_usd":0.01,"duration_seconds":10.1}\n\n', # noqa: E501 |
| 1010 | + ], |
| 1011 | + ), |
| 1012 | + ) |
| 1013 | + |
| 1014 | + chunks = [chunk async for chunk in agent_not_optional.stream(HelloTaskInput(name="Alice"))] |
| 1015 | + |
| 1016 | + messages = [chunk.output.message for chunk in chunks] |
| 1017 | + assert messages == ["", "hel", "hello", "hello"] |
| 1018 | + |
| 1019 | + for chunk in chunks[:-1]: |
| 1020 | + assert chunk.output.another_field == "" |
| 1021 | + assert chunks[-1].output.another_field == "test" |
| 1022 | + |
| 1023 | + last_message = chunks[-1] |
| 1024 | + assert isinstance(last_message, Run) |
| 1025 | + assert last_message.version |
| 1026 | + assert last_message.version.properties.model == "gpt-4o" |
| 1027 | + assert last_message.version.properties.temperature == 0.5 |
| 1028 | + assert last_message.cost_usd == 0.01 |
| 1029 | + assert last_message.duration_seconds == 10.1 |
| 1030 | + |
| 1031 | + async def test_stream_validation_errors( |
| 1032 | + self, |
| 1033 | + agent: Agent[HelloTaskInput, HelloTaskOutput], |
| 1034 | + httpx_mock: HTTPXMock, |
| 1035 | + caplog: pytest.LogCaptureFixture, |
| 1036 | + ): |
| 1037 | + """Test that validation errors are properly skipped and logged during streaming""" |
| 1038 | + httpx_mock.add_response( |
| 1039 | + stream=IteratorStream( |
| 1040 | + [ |
| 1041 | + b'data: {"id":"1","task_output":{"message":""}}\n\n', |
| 1042 | + # Middle chunk is passed |
| 1043 | + b'data: {"id":"1","task_output":{"message":1}}\n\n', |
| 1044 | + b'data: {"id":"1","task_output":{"message":"hello"}}\n\n', |
| 1045 | + ], |
| 1046 | + ), |
| 1047 | + ) |
| 1048 | + |
| 1049 | + with caplog.at_level(logging.DEBUG): |
| 1050 | + chunks = [chunk async for chunk in agent.stream(HelloTaskInput(name="Alice"))] |
| 1051 | + |
| 1052 | + assert len(chunks) == 2 |
| 1053 | + assert chunks[0].output.message == "" |
| 1054 | + assert chunks[1].output.message == "hello" |
| 1055 | + logs = [record for record in caplog.records if "Client side validation error in stream" in record.message] |
| 1056 | + |
| 1057 | + assert len(logs) == 1 |
| 1058 | + assert logs[0].levelname == "DEBUG" |
| 1059 | + assert logs[0].exc_info is not None |
| 1060 | + |
| 1061 | + async def test_stream_validation_final_error( |
| 1062 | + self, |
| 1063 | + agent: Agent[HelloTaskInput, HelloTaskOutput], |
| 1064 | + httpx_mock: HTTPXMock, |
| 1065 | + ): |
| 1066 | + """Check that we properly raise an error if the final payload fails to validate.""" |
| 1067 | + httpx_mock.add_response( |
| 1068 | + stream=IteratorStream( |
| 1069 | + [ |
| 1070 | + # Stream a single chunk that fails to validate |
| 1071 | + b'data: {"id":"1","task_output":{"message":1}}\n\n', |
| 1072 | + ], |
| 1073 | + ), |
| 1074 | + ) |
| 1075 | + |
| 1076 | + with pytest.raises(WorkflowAIError) as e: |
| 1077 | + _ = [c async for c in agent.stream(HelloTaskInput(name="Alice"))] |
| 1078 | + |
| 1079 | + assert e.value.partial_output == {"message": 1} |
| 1080 | + assert e.value.run_id == "1" |
0 commit comments