Skip to content

Commit 2bd3aa6

Browse files
committed
test: add deterministic mock tests for astream incremental logic
The existing test_astream_incremental.py tests were marked qualitative because they rely on a live LLM backend and token streaming timing which is non-deterministic, leading to flaky assertions in CI. To maintain test coverage of the internal ModelOutputThunk async queue logic, this commit introduces test_astream_mock.py. This new file manually structures chunks into the queue identically to test the system reliably in all environments.
1 parent 9c78504 commit 2bd3aa6

1 file changed

Lines changed: 164 additions & 0 deletions

File tree

test/core/test_astream_mock.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
"""Deterministic Mock Tests for ModelOutputThunk.astream() incremental return behavior.
2+
3+
Tests that astream() returns only new content added since the beginning of
4+
each astream() call, not the entire accumulated value. Uses manual queue
5+
injection to bypass LLM calls and network operations, guaranteeing determinism.
6+
"""
7+
8+
import asyncio
9+
from typing import Any
10+
11+
import pytest
12+
13+
from mellea.core.base import CBlock, GenerateType, ModelOutputThunk
14+
15+
16+
async def mock_process(mot: ModelOutputThunk, chunk: Any) -> None:
17+
"""Mock process function that simply appends the chunk to the underlying value."""
18+
if mot._underlying_value is None:
19+
mot._underlying_value = ""
20+
if chunk is not None:
21+
mot._underlying_value += chunk
22+
23+
24+
async def mock_post_process(mot: ModelOutputThunk) -> None:
25+
"""Mock post-process function (does nothing)."""
26+
27+
28+
def create_manual_mock_thunk() -> ModelOutputThunk:
29+
"""Helper to create a mock ModelOutputThunk where we manually populate the queue."""
30+
mot = ModelOutputThunk(value=None)
31+
mot._action = CBlock("mock_action")
32+
mot._generate_type = GenerateType.ASYNC
33+
mot._process = mock_process
34+
mot._post_process = mock_post_process
35+
mot._chunk_size = 0 # Read exactly what is available
36+
return mot
37+
38+
39+
@pytest.mark.asyncio
40+
async def test_astream_returns_incremental_chunks():
41+
"""Test that astream() returns only new content, not accumulated content."""
42+
mot = create_manual_mock_thunk()
43+
44+
# Drop the first chunk and pull it
45+
mot._async_queue.put_nowait("chunk1 ")
46+
chunk1 = await mot.astream()
47+
assert chunk1 == "chunk1 "
48+
49+
# Drop the second chunk and pull it
50+
mot._async_queue.put_nowait("chunk2 ")
51+
chunk2 = await mot.astream()
52+
assert chunk2 == "chunk2 "
53+
54+
# Drop the third chunk and pull it
55+
mot._async_queue.put_nowait("chunk3 ")
56+
chunk3 = await mot.astream()
57+
assert chunk3 == "chunk3 "
58+
59+
# Send completion sentinel
60+
mot._async_queue.put_nowait(None)
61+
62+
# Wait until fully consumed
63+
while not mot.is_computed():
64+
await mot.astream()
65+
66+
final_val = await mot.avalue()
67+
assert final_val == "chunk1 chunk2 chunk3 "
68+
69+
70+
@pytest.mark.asyncio
71+
async def test_astream_multiple_calls_accumulate_correctly():
72+
"""Test that multiple astream() calls accumulate to the final value."""
73+
# Simulating a scenario where queue chunks outpace the reading loop
74+
mot = create_manual_mock_thunk()
75+
76+
# Drop multiple items at once to simulate fast network
77+
mot._async_queue.put_nowait("c")
78+
mot._async_queue.put_nowait("h")
79+
mot._async_queue.put_nowait("u")
80+
81+
# Calling astream should drain all currently queued items ("chu")
82+
chunk1 = await mot.astream()
83+
assert chunk1 == "chu"
84+
85+
mot._async_queue.put_nowait("n")
86+
mot._async_queue.put_nowait("k")
87+
mot._async_queue.put_nowait(None)
88+
89+
chunk2 = await mot.astream()
90+
assert chunk2 == "chunk"
91+
92+
final_val = await mot.avalue()
93+
94+
assert mot.is_computed()
95+
assert final_val == "chunk"
96+
97+
98+
@pytest.mark.asyncio
99+
async def test_astream_beginning_length_tracking():
100+
"""Test that beginning_length is correctly tracked across astream calls."""
101+
mot = create_manual_mock_thunk()
102+
103+
mot._async_queue.put_nowait("AAA")
104+
chunk1 = await mot.astream()
105+
assert chunk1 == "AAA"
106+
107+
mot._async_queue.put_nowait("BBB")
108+
chunk2 = await mot.astream()
109+
# verify incremental length tracking works
110+
assert not chunk2.startswith(chunk1)
111+
assert chunk2 == "BBB"
112+
113+
114+
@pytest.mark.asyncio
115+
async def test_astream_empty_beginning():
116+
"""Test astream when _underlying_value starts as None."""
117+
mot = create_manual_mock_thunk()
118+
119+
mot._async_queue.put_nowait("First")
120+
# At the start, _underlying_value is None, beginning_length is 0
121+
chunk = await mot.astream()
122+
123+
# Because beginning length was 0, astream returns the full chunk
124+
assert chunk == "First"
125+
assert mot._underlying_value == "First"
126+
127+
128+
@pytest.mark.asyncio
129+
async def test_astream_computed_returns_full_value():
130+
"""Test that astream returns full value when already computed."""
131+
# Precomputed thunk skips queue checking completely
132+
mot = ModelOutputThunk(value="Hello, world!")
133+
134+
# For a precomputed thunk, astream directly returns value
135+
result = await mot.astream()
136+
assert result == "Hello, world!"
137+
138+
139+
@pytest.mark.asyncio
140+
async def test_astream_final_call_returns_full_value():
141+
"""Test that the final astream call returns the full value when computed."""
142+
mot = create_manual_mock_thunk()
143+
144+
mot._async_queue.put_nowait("part1")
145+
chunk1 = await mot.astream()
146+
assert chunk1 == "part1"
147+
148+
mot._async_queue.put_nowait("part2")
149+
chunk2 = await mot.astream()
150+
assert chunk2 == "part2"
151+
152+
mot._async_queue.put_nowait("part3")
153+
mot._async_queue.put_nowait(None)
154+
155+
# Calling astream here processes "part3" and `None`, flagging it as done
156+
chunk3 = await mot.astream()
157+
158+
final_val = await mot.avalue()
159+
160+
# The last chunk received by an `astream` call that resolves the thunk
161+
# is actually designed in Mellea base.py (line 368) to return the FULL value,
162+
# bypassing incremental boundaries. This test verifies that somewhat awkward Mellea idiom.
163+
assert chunk3 == "part1part2part3"
164+
assert chunk3 == final_val

0 commit comments

Comments
 (0)