Skip to content

Commit 8928640

Browse files
authored
Dataset pipeline follow ups (#470)
1 parent 8fbf37a commit 8928640

2 files changed

Lines changed: 56 additions & 8 deletions

File tree

py/src/braintrust/dataset_pipeline.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from typing_extensions import NotRequired, TypedDict
66

7-
from .generated_types import ObjectReference
87
from .logger import Metadata
98
from .trace import Trace
109

@@ -19,48 +18,77 @@
1918
"DatasetPipelineTransform",
2019
"DatasetPipelineTransformArgs",
2120
"DatasetPipelineTransformResult",
22-
"get_registered_dataset_pipelines",
2321
]
2422

2523

2624
# Granularity of the data handed to a pipeline: individual spans ("span") or whole traces ("trace").
DatasetPipelineScope: TypeAlias = Literal["span", "trace"]
2725

2826

2927
class DatasetPipelineSource(TypedDict, total=False):
    """Describe which spans or traces are fed into a dataset pipeline.

    The class is declared with ``total=False``, so every key is optional.
    """

    project_id: str
    """Project ID to take spans or traces from. Takes precedence over project_name."""

    project_name: str
    """Project name to take spans or traces from."""

    org_name: str
    """Organization name to take spans or traces from."""

    filter: str
    """Optional BTQL filter. When omitted, all spans or traces are eligible."""

    scope: DatasetPipelineScope
    """Whether to pass spans or entire traces to the pipeline. Defaults to "span"."""
3540

3641

3742
class DatasetPipelineTarget(TypedDict):
    """Describe the dataset that a pipeline writes its rows into.

    ``dataset_name`` is the only required key; every other key is marked
    ``NotRequired``.
    """

    dataset_name: str
    """Dataset name. This can be an existing dataset name or a name to create."""

    project_id: NotRequired[str]
    """Project ID where the dataset lives or should be created."""

    project_name: NotRequired[str]
    """Project name where the dataset lives or should be created."""

    org_name: NotRequired[str]
    """Organization name where the dataset lives or should be created."""

    description: NotRequired[str]
    """Dataset description to use when creating the dataset."""

    metadata: NotRequired[Metadata]
    """Dataset metadata to use when creating the dataset."""
4457

4558

4659
class DatasetPipelineRow(TypedDict, total=False):
    """One row produced by a dataset pipeline transform for the target dataset.

    The class is declared with ``total=False``, so every key is optional.
    """

    id: str
    """Stable row ID for the target dataset. Defaults to the source span or trace ID."""

    input: Any | None
    """Input value for the target dataset row."""

    expected: Any | None
    """Expected value for the target dataset row."""

    tags: Sequence[str] | None
    """Tags for the target dataset row."""

    metadata: Metadata | None
    """Metadata for the target dataset row."""
5372

5473

5574
# Type variable for the row type a transform produces; bounded by DatasetPipelineRow and covariant.
Row = TypeVar("Row", bound=DatasetPipelineRow, covariant=True)
5675

5776

5877
class DatasetPipelineTransformArgs(TypedDict, total=False):
    """Keyword arguments handed to a dataset pipeline transform.

    The class is declared with ``total=False``. The span-level keys are
    documented as applying to span-scoped transforms, while ``trace`` is
    documented as always available.
    """

    id: str
    """Source span row ID for span-scoped transforms."""

    input: Any | None
    """Source span input for span-scoped transforms."""

    output: Any | None
    """Source span output for span-scoped transforms."""

    metadata: Metadata | None
    """Source span metadata for span-scoped transforms."""

    expected: Any | None
    """Source span expected value for span-scoped transforms."""

    trace: Trace
    """Source trace. This is always available."""
6492

6593

6694
# A transform may return a single row, a sequence of rows, or None.
DatasetPipelineTransformResult: TypeAlias = Row | Sequence[Row] | None
@@ -69,6 +97,7 @@ class DatasetPipelineTransformArgs(TypedDict, total=False):
6997
class DatasetPipelineTransform(Protocol[Row]):
7098
def __call__(
7199
self,
100+
id: str | None = None,
72101
input: Any | None = None,
73102
output: Any | None = None,
74103
metadata: Metadata | None = None,
@@ -79,6 +108,8 @@ def __call__(
79108

80109
@dataclass(frozen=True)
81110
class DatasetPipelineDefinition(Generic[Row]):
111+
"""A registered dataset pipeline definition consumed by the bt CLI."""
112+
82113
source: DatasetPipelineSource
83114
transform: DatasetPipelineTransform[Row]
84115
target: DatasetPipelineTarget
@@ -88,20 +119,32 @@ class DatasetPipelineDefinition(Generic[Row]):
88119
# Module-level list of pipeline definitions.
# NOTE(review): presumably populated by DatasetPipeline(); the registration code is outside this view — confirm.
_DATASET_PIPELINES: list[DatasetPipelineDefinition[Any]] = []
89120

90121

91-
def get_registered_dataset_pipelines() -> list[DatasetPipelineDefinition[Any]]:
92-
return list(_DATASET_PIPELINES)
93-
94-
95122
def DatasetPipeline(
96123
name: str | None = None,
97124
*,
98125
source: DatasetPipelineSource,
99126
transform: DatasetPipelineTransform[DatasetPipelineRow],
100127
target: DatasetPipelineTarget,
101128
) -> DatasetPipelineDefinition[DatasetPipelineRow]:
129+
"""Create a runnable dataset pipeline.
130+
131+
Dataset pipelines take trace data stored in Braintrust, filter and transform it,
132+
and feed it back into a Braintrust dataset.
133+
134+
Run a dataset pipeline with the bt CLI:
135+
136+
bt datasets pipeline run path/to/pipeline.py --limit 100
137+
138+
The limit controls how many spans or traces, depending on source["scope"], are
139+
discovered for the pipeline.
140+
141+
This API is experimental and may change or be removed across non-major versions.
142+
"""
143+
stored_source = source.copy()
144+
stored_source["scope"] = stored_source.get("scope", "span")
102145
definition = DatasetPipelineDefinition(
103146
name=name,
104-
source=source.copy(),
147+
source=stored_source,
105148
transform=transform,
106149
target=target.copy(),
107150
)

py/src/braintrust/test_context.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ def _threadpool_scenario(test_logger, with_memory_logger):
4848
# ThreadPoolExecutor.submit globally. Without this flag the background
4949
# logger's atexit handler tries to flush via the patched executor during
5050
# Python shutdown, which crashes the subprocess (SIGABRT / 0xC0000409).
51+
# The memory logger override is thread-local, so worker threads also need a
52+
# process-wide fallback to avoid the real HTTP background logger in this
53+
# isolated test process.
5154
_SCENARIO_TEMPLATE = """\
5255
import os, inspect, asyncio
5356
os.environ["BRAINTRUST_APP_URL"] = "https://www.braintrust.dev"
@@ -58,9 +61,11 @@ def _threadpool_scenario(test_logger, with_memory_logger):
5861
os.environ.setdefault("GOOGLE_API_KEY", os.environ.get("GEMINI_API_KEY", "your_google_api_key_here"))
5962
from braintrust import logger as _logger
6063
from braintrust.test_helpers import init_test_logger
64+
from braintrust.util import LazyValue
6165
from braintrust.test_context import {fn_name} as _fn
6266
_logger._state.reset_parent_state()
6367
with _logger._internal_with_memory_background_logger() as _bgl:
68+
_logger._state._global_bg_logger = LazyValue(lambda: _bgl, use_mutex=False)
6469
_tl = init_test_logger("test-context-project")
6570
if {instrument}:
6671
from braintrust.wrappers.threads import setup_threads

0 commit comments

Comments
 (0)