Skip to content

Commit 8fff014

Browse files
committed
fix async iterator issue
1 parent 4916a18 commit 8fff014

File tree

2 files changed

+121
-31
lines changed

2 files changed

+121
-31
lines changed

py/src/braintrust/integrations/agentscope/test_agentscope.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,85 @@ async def _stream():
219219
assert llm_span["metrics"]["tokens"] == 32
220220

221221

222+
@pytest.mark.asyncio
async def test_model_call_wrapper_stream_span_covers_full_stream_duration(memory_logger):
    """The LLM span must stay open until the wrapped stream is fully consumed.

    If the wrapper finalized the span as soon as it returned the stream (before
    the caller drained it), the recorded duration would be near zero instead of
    roughly the 300ms the fake stream takes to produce its chunks.
    """
    import asyncio

    from braintrust.integrations.agentscope.tracing import _model_call_wrapper

    # Sanity: no spans logged yet.
    assert not memory_logger.pop()

    class FakeModel:
        model_name = "gpt-4o-mini"

    async def fake_call(*_args, **_kwargs):
        async def chunk_source():
            for i in range(3):
                await asyncio.sleep(0.1)
                yield {"content": [{"type": "text", "text": f"chunk-{i}"}]}

        return chunk_source()

    result_stream = await _model_call_wrapper(
        fake_call,
        FakeModel(),
        args=([{"role": "user", "content": "hi"}],),
        kwargs={},
    )
    # Drain the stream; only after this should the span be finalized.
    async for _chunk in result_stream:
        pass

    logged = memory_logger.pop()
    assert len(logged) == 1
    m = logged[0].get("metrics", {})
    duration_ms = (m["end"] - m["start"]) * 1000
    # Three chunks at 100ms each → ~300ms total; the 200ms bound leaves slack
    # for scheduler jitter while still catching a span closed up front.
    assert duration_ms >= 200, f"Span duration {duration_ms:.0f}ms is too short; span ended before stream was consumed"
258+
259+
260+
@pytest.mark.asyncio
async def test_toolkit_call_tool_function_wrapper_stream_span_covers_full_stream_duration(memory_logger):
    """The tool span must stay open until the wrapped stream is fully consumed.

    A span finalized when the wrapper returns (before the caller drains the
    stream) would report a near-zero duration instead of roughly the 300ms the
    fake stream takes to yield its chunks.
    """
    import asyncio

    from braintrust.integrations.agentscope.tracing import _toolkit_call_tool_function_wrapper

    # Sanity: no spans logged yet.
    assert not memory_logger.pop()

    class FakeToolkit:
        pass

    class FakeToolCall:
        name = "my_tool"

    async def fake_tool(*_args, **_kwargs):
        async def chunk_source():
            for i in range(3):
                await asyncio.sleep(0.1)
                yield f"chunk-{i}"

        return chunk_source()

    result_stream = await _toolkit_call_tool_function_wrapper(
        fake_tool,
        FakeToolkit(),
        args=(FakeToolCall(),),
        kwargs={},
    )
    # Drain the stream; only after this should the span be finalized.
    async for _chunk in result_stream:
        pass

    logged = memory_logger.pop()
    assert len(logged) == 1
    m = logged[0].get("metrics", {})
    duration_ms = (m["end"] - m["start"]) * 1000
    # Three chunks at 100ms each → ~300ms total; the 200ms bound leaves slack
    # for scheduler jitter while still catching a span closed up front.
    assert duration_ms >= 200, f"Span duration {duration_ms:.0f}ms is too short; span ended before stream was consumed"
299+
300+
222301
class TestAutoInstrumentAgentScope:
223302
def test_auto_instrument_agentscope(self):
224303
verify_autoinstrument_script("test_auto_agentscope.py")

py/src/braintrust/integrations/agentscope/tracing.py

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""AgentScope-specific span creation and stream aggregation."""
22

3+
import contextlib
34
from contextlib import aclosing
45
from typing import Any
56

@@ -200,29 +201,34 @@ async def _fanout_pipeline_wrapper(wrapped: Any, instance: Any, args: Any, kwarg
200201
async def _toolkit_call_tool_function_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: dict[str, Any]) -> Any:
201202
tool_call = args[0] if args else kwargs.get("tool_call")
202203
tool_name = _tool_name(tool_call)
203-
with start_span(
204-
name=f"{tool_name}.execute",
205-
type=SpanTypeAttribute.TOOL,
206-
input=_clean(
207-
{
208-
"tool_name": tool_name,
209-
"tool_call": tool_call,
210-
}
211-
),
212-
metadata=_clean({"toolkit_class": instance.__class__.__name__}),
213-
) as span:
204+
with contextlib.ExitStack() as stack:
205+
span = stack.enter_context(
206+
start_span(
207+
name=f"{tool_name}.execute",
208+
type=SpanTypeAttribute.TOOL,
209+
input=_clean(
210+
{
211+
"tool_name": tool_name,
212+
"tool_call": tool_call,
213+
}
214+
),
215+
metadata=_clean({"toolkit_class": instance.__class__.__name__}),
216+
)
217+
)
214218
try:
215219
result = await wrapped(*args, **kwargs)
216220
if _is_async_iterator(result):
221+
deferred = stack.pop_all()
217222

218223
async def _trace():
219-
last_chunk = None
220-
async with aclosing(result) as agen:
221-
async for chunk in agen:
222-
last_chunk = chunk
223-
yield chunk
224-
if last_chunk is not None:
225-
span.log(output=last_chunk)
224+
with deferred:
225+
last_chunk = None
226+
async with aclosing(result) as agen:
227+
async for chunk in agen:
228+
last_chunk = chunk
229+
yield chunk
230+
if last_chunk is not None:
231+
span.log(output=last_chunk)
226232

227233
return _trace()
228234

@@ -241,24 +247,29 @@ def _is_async_iterator(value: Any) -> bool:
241247

242248

243249
async def _model_call_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: dict[str, Any]) -> Any:
244-
with start_span(
245-
name=f"{_model_provider_name(instance)}.call",
246-
type=SpanTypeAttribute.LLM,
247-
input=_model_call_input(args, kwargs),
248-
metadata=_model_call_metadata(instance, kwargs),
249-
) as span:
250+
with contextlib.ExitStack() as stack:
251+
span = stack.enter_context(
252+
start_span(
253+
name=f"{_model_provider_name(instance)}.call",
254+
type=SpanTypeAttribute.LLM,
255+
input=_model_call_input(args, kwargs),
256+
metadata=_model_call_metadata(instance, kwargs),
257+
)
258+
)
250259
try:
251260
result = await wrapped(*args, **kwargs)
252261
if _is_async_iterator(result):
262+
deferred = stack.pop_all()
253263

254264
async def _trace():
255-
last_chunk = None
256-
async with aclosing(result) as agen:
257-
async for chunk in agen:
258-
last_chunk = chunk
259-
yield chunk
260-
if last_chunk is not None:
261-
span.log(output=_model_call_output(last_chunk), metrics=_extract_metrics(last_chunk))
265+
with deferred:
266+
last_chunk = None
267+
async with aclosing(result) as agen:
268+
async for chunk in agen:
269+
last_chunk = chunk
270+
yield chunk
271+
if last_chunk is not None:
272+
span.log(output=_model_call_output(last_chunk), metrics=_extract_metrics(last_chunk))
262273

263274
return _trace()
264275

0 commit comments

Comments
 (0)