Skip to content

Commit 4bd5469

Browse files
moonbox3 and Copilot authored
Python: Improve ag-ui tests and coverage (#4442)
* Improve ag-ui tests and coverage
* fix tests paths
* Fixes
* Improve AG-UI test robustness and correctness
  - Map toolName → tool_call_name in SSE helpers for TOOL_CALL_START events
  - Fail loudly on malformed SSE JSON in parse_sse_response() instead of silently dropping
  - Detect duplicate TOOL_CALL_START/TOOL_CALL_END in assert_tool_calls_balanced()
  - Remove fragile source line reference from test docstring
  - Add found guard in test_client_tool_sets_additional_properties to prevent vacuous pass

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 1ac68f6 commit 4bd5469

22 files changed

Lines changed: 4766 additions & 17 deletions

python/packages/ag-ui/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ packages = ["agent_framework_ag_ui", "agent_framework_ag_ui_examples"]
4444
[tool.pytest.ini_options]
4545
asyncio_mode = "auto"
4646
testpaths = ["tests/ag_ui"]
47-
pythonpath = ["."]
47+
pythonpath = [".", "tests/ag_ui"]
4848
markers = [
4949
"integration: marks tests as integration tests that require external services",
5050
]

python/packages/ag-ui/tests/ag_ui/conftest.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import sys
66
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Mapping, MutableSequence, Sequence
7+
from pathlib import Path
78
from types import SimpleNamespace
89
from typing import Any, Generic, Literal, cast, overload
910

@@ -36,6 +37,13 @@
3637
ResponseFn = Callable[..., Awaitable[ChatResponse]]
3738

3839

40+
def pytest_configure() -> None:
    """Make the directory containing this conftest importable.

    The resolved directory of this file is prepended to ``sys.path`` unless it
    is already listed there, so helper modules that live next to the tests can
    be imported by their bare module name.
    """
    helper_root = str(Path(__file__).resolve().parent)
    if helper_root in sys.path:
        # Already importable; avoid stacking duplicate entries.
        return
    sys.path.insert(0, helper_root)
45+
46+
3947
class StreamingChatClientStub(
4048
ChatMiddlewareLayer[OptionsCoT],
4149
FunctionInvocationLayer[OptionsCoT],
@@ -241,3 +249,83 @@ def stream_from_updates_fixture() -> Callable[[list[ChatResponseUpdate]], Stream
241249
def stub_agent() -> type[SupportsAgentRun]:
242250
"""Return the StubAgent class for creating test instances."""
243251
return StubAgent # type: ignore[return-value]
252+
253+
254+
# ── Fixtures for golden / integration tests ──
255+
256+
257+
@pytest.fixture
def collect_events() -> Callable[..., Any]:
    """Provide a coroutine function that drains an async iterable into a list.

    The returned helper must be awaited; it yields the events in the order the
    async iterable produced them.
    """

    async def _collect(async_gen: AsyncIterable[Any]) -> list[Any]:
        gathered: list[Any] = []
        async for item in async_gen:
            gathered.append(item)
        return gathered

    return _collect
265+
266+
267+
@pytest.fixture
def make_agent_wrapper() -> Callable[..., Any]:
    """Provide a factory that wraps a stream function in an AgentFrameworkAgent.

    The factory builds a ``StreamingChatClientStub`` from the stream function,
    hands it to a ``StubAgent``, and passes that agent plus the keyword options
    to ``AgentFrameworkAgent``.

    Usage::

        agent = make_agent_wrapper(
            stream_fn=stream_from_updates(updates),
            state_schema=...,
        )
        events = [e async for e in agent.run(payload)]
    """
    # Imported lazily so the package is only loaded when the fixture is used.
    from agent_framework_ag_ui import AgentFrameworkAgent

    def _factory(
        stream_fn: StreamFn,
        *,
        state_schema: Any | None = None,
        predict_state_config: dict[str, dict[str, str]] | None = None,
        require_confirmation: bool = True,
    ) -> Any:
        return AgentFrameworkAgent(
            agent=StubAgent(client=StreamingChatClientStub(stream_fn)),
            state_schema=state_schema,
            predict_state_config=predict_state_config,
            require_confirmation=require_confirmation,
        )

    return _factory
298+
299+
300+
@pytest.fixture
def make_app() -> Callable[..., Any]:
    """Provide a factory that creates a FastAPI app exposing an AG-UI endpoint.

    The factory instantiates a fresh ``FastAPI`` application, registers the
    given agent on it through ``add_agent_framework_fastapi_endpoint`` and
    returns the application.

    Usage::

        app = make_app(agent_or_wrapper, path="/test")
    """
    # Imported lazily so FastAPI is only required by tests that use this fixture.
    from fastapi import FastAPI

    from agent_framework_ag_ui import add_agent_framework_fastapi_endpoint

    def _factory(
        agent: Any,
        *,
        path: str = "/",
        state_schema: Any | None = None,
        predict_state_config: dict[str, dict[str, str]] | None = None,
        default_state: dict[str, Any] | None = None,
    ) -> FastAPI:
        endpoint_options: dict[str, Any] = {
            "path": path,
            "state_schema": state_schema,
            "predict_state_config": predict_state_config,
            "default_state": default_state,
        }
        application = FastAPI()
        add_agent_framework_fastapi_endpoint(application, agent, **endpoint_options)
        return application

    return _factory
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
3+
"""EventStream assertion helper for AG-UI regression tests."""
4+
5+
from __future__ import annotations
6+
7+
from typing import Any
8+
9+
10+
class EventStream:
    """Wraps a list of AG-UI events with structured assertion methods.

    Each event is identified by a type string derived from its ``type``
    attribute (see ``_type_str``). Lookup helpers raise ``ValueError`` when no
    event matches; the ``assert_*`` helpers raise ``AssertionError`` with a
    descriptive message when the stream violates the expected structure.

    Usage:
        events = [event async for event in agent.run(payload)]
        stream = EventStream(events)
        stream.assert_bookends()
        stream.assert_text_messages_balanced()
    """

    def __init__(self, events: list[Any]) -> None:
        """Store the event list; it is referenced, not copied."""
        self.events = events

    def __len__(self) -> int:
        """Return the number of wrapped events."""
        return len(self.events)

    def __iter__(self):
        """Iterate over the wrapped events in order."""
        return iter(self.events)

    def types(self) -> list[str]:
        """Return ordered list of event type strings."""
        return [self._type_str(e) for e in self.events]

    def get(self, event_type: str) -> list[Any]:
        """Filter events matching the given type string."""
        return [e for e in self.events if self._type_str(e) == event_type]

    def first(self, event_type: str) -> Any:
        """Return the first event matching the given type.

        Raises:
            ValueError: If no event of ``event_type`` exists in the stream.
        """
        matches = self.get(event_type)
        if not matches:
            raise ValueError(f"No event of type {event_type!r} found. Available: {self.types()}")
        return matches[0]

    def last(self, event_type: str) -> Any:
        """Return the last event matching the given type.

        Raises:
            ValueError: If no event of ``event_type`` exists in the stream.
        """
        matches = self.get(event_type)
        if not matches:
            raise ValueError(f"No event of type {event_type!r} found. Available: {self.types()}")
        return matches[-1]

    def snapshot(self) -> dict[str, Any]:
        """Return the ``snapshot`` attribute of the latest STATE_SNAPSHOT event."""
        return self.last("STATE_SNAPSHOT").snapshot

    def messages_snapshot(self) -> list[Any]:
        """Return the ``messages`` attribute of the latest MESSAGES_SNAPSHOT event."""
        return self.last("MESSAGES_SNAPSHOT").messages

    # ── Structural assertions ──

    def assert_bookends(self) -> None:
        """Assert first event is RUN_STARTED and last is RUN_FINISHED."""
        types = self.types()
        assert types, "Event stream is empty"
        assert types[0] == "RUN_STARTED", f"Expected RUN_STARTED first, got {types[0]}"
        assert types[-1] == "RUN_FINISHED", f"Expected RUN_FINISHED last, got {types[-1]}"

    def assert_has_run_lifecycle(self) -> None:
        """Assert RUN_STARTED is first and RUN_FINISHED exists (may not be last).

        Use this instead of assert_bookends() for workflow resume streams where
        _drain_open_message() can emit TEXT_MESSAGE_END after RUN_FINISHED.
        """
        types = self.types()
        assert types, "Event stream is empty"
        assert types[0] == "RUN_STARTED", f"Expected RUN_STARTED first, got {types[0]}"
        assert "RUN_FINISHED" in types, f"Expected RUN_FINISHED in stream. Types: {types}"

    def assert_strict_types(self, expected: list[str]) -> None:
        """Assert exact type sequence match."""
        actual = self.types()
        assert actual == expected, f"Event type mismatch.\nExpected: {expected}\nActual: {actual}"

    def assert_ordered_types(self, expected: list[str]) -> None:
        """Assert expected types appear as a subsequence (in order, not necessarily contiguous).

        Raises:
            AssertionError: If an expected type cannot be found at or after the
                position following the previous match. The message reports the
                index where the failed search started.
        """
        actual = self.types()
        cursor = 0  # index in ``actual`` where the search for the next type starts
        for expected_type in expected:
            try:
                cursor = actual.index(expected_type, cursor) + 1
            except ValueError:
                # ``cursor`` is untouched on failure, so it still names the
                # position the search began from rather than the list length.
                raise AssertionError(
                    f"Expected subsequence type {expected_type!r} not found at or after index {cursor}.\n"
                    f"Expected subsequence: {expected}\n"
                    f"Actual types: {actual}"
                ) from None

    def assert_text_messages_balanced(self) -> None:
        """Assert every TEXT_MESSAGE_START has a matching TEXT_MESSAGE_END with the same message_id."""
        self._assert_balanced("TEXT_MESSAGE_START", "TEXT_MESSAGE_END", "message_id", "text messages")

    def assert_tool_calls_balanced(self) -> None:
        """Assert every TOOL_CALL_START has a matching TOOL_CALL_END with the same tool_call_id."""
        self._assert_balanced("TOOL_CALL_START", "TOOL_CALL_END", "tool_call_id", "tool calls")

    def assert_no_run_error(self) -> None:
        """Assert no RUN_ERROR events exist."""
        errors = self.get("RUN_ERROR")
        if errors:
            messages = [getattr(e, "message", str(e)) for e in errors]
            raise AssertionError(f"Found {len(errors)} RUN_ERROR event(s): {messages}")

    def assert_has_type(self, event_type: str) -> None:
        """Assert at least one event of the given type exists."""
        available = self.types()
        assert event_type in available, f"Expected {event_type!r} in stream. Available: {available}"

    def assert_message_ids_consistent(self) -> None:
        """Assert TEXT_MESSAGE_CONTENT events reference valid, open message_ids."""
        open_messages: set[str] = set()
        for event in self.events:
            t = self._type_str(event)
            if t == "TEXT_MESSAGE_START":
                open_messages.add(event.message_id)
            elif t == "TEXT_MESSAGE_END":
                open_messages.discard(event.message_id)
            elif t == "TEXT_MESSAGE_CONTENT":
                mid = event.message_id
                assert mid in open_messages, f"TEXT_MESSAGE_CONTENT references message_id={mid} which is not open"

    # ── Internal ──

    def _assert_balanced(self, start_type: str, end_type: str, id_attr: str, label: str) -> None:
        """Assert each ``start_type`` event is closed by exactly one ``end_type`` event.

        Events are paired through the attribute named ``id_attr``. Fails on a
        repeated start id, an end without a preceding start, a repeated end id,
        or a start that is never ended. ``label`` is used only in the
        "Unclosed ..." failure message.
        """
        started: set[str] = set()
        ended: set[str] = set()
        for event in self.events:
            event_type = self._type_str(event)
            if event_type == start_type:
                ident = getattr(event, id_attr)
                assert ident not in started, f"Duplicate {start_type} for {id_attr}={ident}"
                started.add(ident)
            elif event_type == end_type:
                ident = getattr(event, id_attr)
                assert ident in started, f"{end_type} for unknown {id_attr}={ident}"
                assert ident not in ended, f"Duplicate {end_type} for {id_attr}={ident}"
                ended.add(ident)

        unclosed = started - ended
        assert not unclosed, f"Unclosed {label}: {unclosed}"

    @staticmethod
    def _type_str(event: Any) -> str:
        """Extract event type as a plain string.

        Falls back to the event's class name when it has no ``type`` attribute
        (or the attribute is ``None``). A string ``type`` is returned as-is;
        otherwise the ``value`` attribute (e.g. of an enum member) is used, and
        the result is coerced with ``str`` so the return is always a string.
        """
        t = getattr(event, "type", None)
        if t is None:
            return type(event).__name__
        if isinstance(t, str):
            return t
        return str(getattr(t, "value", t))
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Copyright (c) Microsoft. All rights reserved.
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
3+
"""Conftest for golden tests — ensures parent test dir is importable."""
4+
5+
import sys
6+
from pathlib import Path
7+
8+
9+
def pytest_configure() -> None:
    """Make the parent test directory importable for helper modules.

    The directory one level above this conftest's own directory is prepended
    to ``sys.path`` unless it is already listed there.
    """
    shared_helpers_dir = str(Path(__file__).resolve().parent.parent)
    if shared_helpers_dir in sys.path:
        # Already importable; avoid stacking duplicate entries.
        return
    sys.path.insert(0, shared_helpers_dir)

0 commit comments

Comments
 (0)