Skip to content

Commit 2e838c9

Browse files
committed
test: add test for validation errors during streaming
1 parent 6624df1 commit 2e838c9

File tree

2 files changed

+121
-63
lines changed

2 files changed

+121
-63
lines changed

workflowai/core/client/agent.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,11 @@ async def stream(
444444
)
445445
final_error = None
446446
except ValidationError as e:
447-
logger.debug("Validation error in stream", exc_info=e)
447+
logger.debug(
448+
"Client side validation error in stream. There is likely an "
449+
"issue with the validator or the model.",
450+
exc_info=e,
451+
)
448452
final_error = e
449453
continue
450454
if final_error:

workflowai/core/client/agent_test.py

Lines changed: 116 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import importlib.metadata
22
import json
3+
import logging
34
from unittest.mock import Mock, patch
45

56
import httpx
@@ -120,68 +121,6 @@ async def test_success(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput,
120121
"stream": False,
121122
}
122123

123-
async def test_stream(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]):
124-
httpx_mock.add_response(
125-
stream=IteratorStream(
126-
[
127-
b'data: {"id":"1","task_output":{"message":""}}\n\n',
128-
b'data: {"id":"1","task_output":{"message":"hel"}}\n\ndata: {"id":"1","task_output":{"message":"hello"}}\n\n', # noqa: E501
129-
b'data: {"id":"1","task_output":{"message":"hello"},"version":{"properties":{"model":"gpt-4o","temperature":0.5}},"cost_usd":0.01,"duration_seconds":10.1}\n\n', # noqa: E501
130-
],
131-
),
132-
)
133-
134-
chunks = [chunk async for chunk in agent.stream(HelloTaskInput(name="Alice"))]
135-
136-
outputs = [chunk.output for chunk in chunks]
137-
assert outputs == [
138-
HelloTaskOutput(message=""),
139-
HelloTaskOutput(message="hel"),
140-
HelloTaskOutput(message="hello"),
141-
HelloTaskOutput(message="hello"),
142-
]
143-
last_message = chunks[-1]
144-
assert isinstance(last_message, Run)
145-
assert last_message.version
146-
assert last_message.version.properties.model == "gpt-4o"
147-
assert last_message.version.properties.temperature == 0.5
148-
assert last_message.cost_usd == 0.01
149-
assert last_message.duration_seconds == 10.1
150-
151-
async def test_stream_not_optional(
152-
self,
153-
httpx_mock: HTTPXMock,
154-
agent_not_optional: Agent[HelloTaskInput, HelloTaskOutputNotOptional],
155-
):
156-
# Checking that streaming works even with non optional fields
157-
# The first two chunks are missing a required key but the last one has it
158-
httpx_mock.add_response(
159-
stream=IteratorStream(
160-
[
161-
b'data: {"id":"1","task_output":{"message":""}}\n\n',
162-
b'data: {"id":"1","task_output":{"message":"hel"}}\n\ndata: {"id":"1","task_output":{"message":"hello"}}\n\n', # noqa: E501
163-
b'data: {"id":"1","task_output":{"message":"hello", "another_field": "test"},"version":{"properties":{"model":"gpt-4o","temperature":0.5}},"cost_usd":0.01,"duration_seconds":10.1}\n\n', # noqa: E501
164-
],
165-
),
166-
)
167-
168-
chunks = [chunk async for chunk in agent_not_optional.stream(HelloTaskInput(name="Alice"))]
169-
170-
messages = [chunk.output.message for chunk in chunks]
171-
assert messages == ["", "hel", "hello", "hello"]
172-
173-
for chunk in chunks[:-1]:
174-
assert chunk.output.another_field == ""
175-
assert chunks[-1].output.another_field == "test"
176-
177-
last_message = chunks[-1]
178-
assert isinstance(last_message, Run)
179-
assert last_message.version
180-
assert last_message.version.properties.model == "gpt-4o"
181-
assert last_message.version.properties.temperature == 0.5
182-
assert last_message.cost_usd == 0.01
183-
assert last_message.duration_seconds == 10.1
184-
185124
async def test_run_with_env(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]):
186125
httpx_mock.add_response(json=fixtures_json("task_run.json"))
187126

@@ -1024,3 +963,118 @@ async def test_fetch_completions(self, agent: Agent[HelloTaskInput, HelloTaskOut
1024963
),
1025964
),
1026965
]
966+
967+
968+
class TestStream:
969+
async def test_stream(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]):
970+
httpx_mock.add_response(
971+
stream=IteratorStream(
972+
[
973+
b'data: {"id":"1","task_output":{"message":""}}\n\n',
974+
b'data: {"id":"1","task_output":{"message":"hel"}}\n\ndata: {"id":"1","task_output":{"message":"hello"}}\n\n', # noqa: E501
975+
b'data: {"id":"1","task_output":{"message":"hello"},"version":{"properties":{"model":"gpt-4o","temperature":0.5}},"cost_usd":0.01,"duration_seconds":10.1}\n\n', # noqa: E501
976+
],
977+
),
978+
)
979+
980+
chunks = [chunk async for chunk in agent.stream(HelloTaskInput(name="Alice"))]
981+
982+
outputs = [chunk.output for chunk in chunks]
983+
assert outputs == [
984+
HelloTaskOutput(message=""),
985+
HelloTaskOutput(message="hel"),
986+
HelloTaskOutput(message="hello"),
987+
HelloTaskOutput(message="hello"),
988+
]
989+
last_message = chunks[-1]
990+
assert isinstance(last_message, Run)
991+
assert last_message.version
992+
assert last_message.version.properties.model == "gpt-4o"
993+
assert last_message.version.properties.temperature == 0.5
994+
assert last_message.cost_usd == 0.01
995+
assert last_message.duration_seconds == 10.1
996+
997+
async def test_stream_not_optional(
998+
self,
999+
httpx_mock: HTTPXMock,
1000+
agent_not_optional: Agent[HelloTaskInput, HelloTaskOutputNotOptional],
1001+
):
1002+
# Checking that streaming works even with non optional fields
1003+
# Every event except the last is missing the required key (another_field); only the last one has it
1004+
httpx_mock.add_response(
1005+
stream=IteratorStream(
1006+
[
1007+
b'data: {"id":"1","task_output":{"message":""}}\n\n',
1008+
b'data: {"id":"1","task_output":{"message":"hel"}}\n\ndata: {"id":"1","task_output":{"message":"hello"}}\n\n', # noqa: E501
1009+
b'data: {"id":"1","task_output":{"message":"hello", "another_field": "test"},"version":{"properties":{"model":"gpt-4o","temperature":0.5}},"cost_usd":0.01,"duration_seconds":10.1}\n\n', # noqa: E501
1010+
],
1011+
),
1012+
)
1013+
1014+
chunks = [chunk async for chunk in agent_not_optional.stream(HelloTaskInput(name="Alice"))]
1015+
1016+
messages = [chunk.output.message for chunk in chunks]
1017+
assert messages == ["", "hel", "hello", "hello"]
1018+
1019+
for chunk in chunks[:-1]:
1020+
assert chunk.output.another_field == ""
1021+
assert chunks[-1].output.another_field == "test"
1022+
1023+
last_message = chunks[-1]
1024+
assert isinstance(last_message, Run)
1025+
assert last_message.version
1026+
assert last_message.version.properties.model == "gpt-4o"
1027+
assert last_message.version.properties.temperature == 0.5
1028+
assert last_message.cost_usd == 0.01
1029+
assert last_message.duration_seconds == 10.1
1030+
1031+
async def test_stream_validation_errors(
1032+
self,
1033+
agent: Agent[HelloTaskInput, HelloTaskOutput],
1034+
httpx_mock: HTTPXMock,
1035+
caplog: pytest.LogCaptureFixture,
1036+
):
1037+
"""Test that validation errors are properly skipped and logged during streaming"""
1038+
httpx_mock.add_response(
1039+
stream=IteratorStream(
1040+
[
1041+
b'data: {"id":"1","task_output":{"message":""}}\n\n',
1042+
# Middle chunk fails validation (message is not a string) and is skipped
1043+
b'data: {"id":"1","task_output":{"message":1}}\n\n',
1044+
b'data: {"id":"1","task_output":{"message":"hello"}}\n\n',
1045+
],
1046+
),
1047+
)
1048+
1049+
with caplog.at_level(logging.DEBUG):
1050+
chunks = [chunk async for chunk in agent.stream(HelloTaskInput(name="Alice"))]
1051+
1052+
assert len(chunks) == 2
1053+
assert chunks[0].output.message == ""
1054+
assert chunks[1].output.message == "hello"
1055+
logs = [record for record in caplog.records if "Client side validation error in stream" in record.message]
1056+
1057+
assert len(logs) == 1
1058+
assert logs[0].levelname == "DEBUG"
1059+
assert logs[0].exc_info is not None
1060+
1061+
async def test_stream_validation_final_error(
1062+
self,
1063+
agent: Agent[HelloTaskInput, HelloTaskOutput],
1064+
httpx_mock: HTTPXMock,
1065+
):
1066+
"""Check that we properly raise an error if the final payload fails to validate."""
1067+
httpx_mock.add_response(
1068+
stream=IteratorStream(
1069+
[
1070+
# Stream a single chunk that fails to validate
1071+
b'data: {"id":"1","task_output":{"message":1}}\n\n',
1072+
],
1073+
),
1074+
)
1075+
1076+
with pytest.raises(WorkflowAIError) as e:
1077+
_ = [c async for c in agent.stream(HelloTaskInput(name="Alice"))]
1078+
1079+
assert e.value.partial_output == {"message": 1}
1080+
assert e.value.run_id == "1"

0 commit comments

Comments
 (0)