From bf2b8ae3f29ee5465f678aa083fc04fb2654f25b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Halber?= Date: Thu, 19 Mar 2026 17:01:11 -0700 Subject: [PATCH 01/11] fix: langchain python pkg got updated, some submodules are now accessible with langchain_classic module --- py/src/braintrust/wrappers/langchain.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/py/src/braintrust/wrappers/langchain.py b/py/src/braintrust/wrappers/langchain.py index 6beeb578..7a4c9f6d 100644 --- a/py/src/braintrust/wrappers/langchain.py +++ b/py/src/braintrust/wrappers/langchain.py @@ -9,11 +9,11 @@ _logger = logging.getLogger("braintrust.wrappers.langchain") try: - from langchain.callbacks.base import BaseCallbackHandler - from langchain.schema import Document - from langchain.schema.agent import AgentAction - from langchain.schema.messages import BaseMessage - from langchain.schema.output import LLMResult + from langchain_classic.callbacks.base import BaseCallbackHandler + from langchain_classic.schema import Document + from langchain_classic.schema.agent import AgentAction + from langchain_classic.schema.messages import BaseMessage + from langchain_classic.schema.output import LLMResult except ImportError: _logger.warning("Failed to import langchain, using stubs") BaseCallbackHandler = object From 0a173de2a59d8ef0394c3a5236c7d5eaca8ef8a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Halber?= Date: Fri, 20 Mar 2026 23:39:31 +0000 Subject: [PATCH 02/11] chore:remove dead code --- py/src/braintrust/framework.py | 4 ++-- .../integrations/agno/_test_agno_helpers.py | 14 +++++++------- py/src/braintrust/logger.py | 7 ++----- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/py/src/braintrust/framework.py b/py/src/braintrust/framework.py index 2eeb00de..be654c49 100644 --- a/py/src/braintrust/framework.py +++ b/py/src/braintrust/framework.py @@ -203,7 +203,7 @@ def tags(self) -> Sequence[str]: """ @abc.abstractmethod - def 
report_progress(self, progress: TaskProgressEvent) -> None: + def report_progress(self, _progress: TaskProgressEvent) -> None: """ Report progress that will show up in the playground. """ @@ -459,7 +459,7 @@ class EvalResultWithSummary(SerializableDataClass, Generic[Input, Output]): summary: ExperimentSummary results: list[EvalResult[Input, Output]] - def _repr_pretty_(self, p, cycle): + def _repr_pretty_(self, p, _cycle): p.text(f'EvalResultWithSummary(summary="...", results=[...])') diff --git a/py/src/braintrust/integrations/agno/_test_agno_helpers.py b/py/src/braintrust/integrations/agno/_test_agno_helpers.py index 2c7d4b65..e112d1ff 100644 --- a/py/src/braintrust/integrations/agno/_test_agno_helpers.py +++ b/py/src/braintrust/integrations/agno/_test_agno_helpers.py @@ -56,10 +56,10 @@ def __init__(self): self.name = name self.steps = ["first-step"] - async def _aexecute(self, session_id, user_id, execution_input, workflow_run_response, run_context=None): + async def _aexecute(self, session_id, user_id, execution_input, workflow_run_response, _run_context=None): return FakeWorkflowRunResponse(input=execution_input.input, content="workflow-async") - def _execute_stream(self, session, execution_input, workflow_run_response, run_context=None): + def _execute_stream(self, session, execution_input, workflow_run_response, _run_context=None): yield FakeEvent("WorkflowStarted", content=None) yield FakeEvent("StepStarted", content=None) yield FakeEvent("StepCompleted", content="hello ") @@ -74,7 +74,7 @@ def __init__(self): self.name = name self.steps = ["first-step"] - def _execute_stream(self, session, execution_input, workflow_run_response, run_context=None): + def _execute_stream(self, session, execution_input, workflow_run_response, _run_context=None): yield FakeEvent("StepCompleted", content="hello") yield FakeEvent("WorkflowCompleted", content="hello", metrics=FakeMetrics(), status="COMPLETED") @@ -87,7 +87,7 @@ def __init__(self): self.name = name self.steps = 
["first-step"] - def _execute_stream(self, session, execution_input, workflow_run_response, run_context=None): + def _execute_stream(self, session, execution_input, workflow_run_response, _run_context=None): yield FakeEvent("WorkflowStarted", content=None) yield FakeEvent("StepCompleted", content="hello ") workflow_run_response.content = "world" @@ -115,7 +115,7 @@ def __init__(self): self.steps = ["agent-step"] self.agent = WrappedAgent() - async def _aexecute(self, session_id, user_id, execution_input, workflow_run_response, run_context=None): + async def _aexecute(self, session_id, user_id, execution_input, workflow_run_response, _run_context=None): return await self.agent.arun(execution_input.input) return FakeWorkflow @@ -128,7 +128,7 @@ def __init__(self): self.id = "workflow-agent-123" self.steps = ["agent-step"] - def _execute_workflow_agent(self, user_input, session, execution_input, run_context, stream=False, **kwargs): + def _execute_workflow_agent(self, user_input, session, execution_input, _run_context, stream=False, **kwargs): if stream: def _stream(): @@ -143,7 +143,7 @@ def _stream(): return _stream() return FakeRunOutput(f"{user_input}-sync") - async def _aexecute_workflow_agent(self, user_input, run_context, execution_input, stream=False, **kwargs): + async def _aexecute_workflow_agent(self, user_input, _run_context, execution_input, stream=False, **kwargs): if stream: async def _astream(): diff --git a/py/src/braintrust/logger.py b/py/src/braintrust/logger.py index f7c56d2b..478ef324 100644 --- a/py/src/braintrust/logger.py +++ b/py/src/braintrust/logger.py @@ -1445,7 +1445,7 @@ def _register_dropped_item_count(self, num_items): self._queue_drop_logging_state["last_logged_timestamp"] = time_now @staticmethod - def _write_payload_to_dir(payload_dir, payload, debug_logging_adjective=None): + def _write_payload_to_dir(payload_dir, payload): payload_file = os.path.join(payload_dir, f"payload_{time.time()}_{str(uuid.uuid4())[:8]}.json") try: 
os.makedirs(payload_dir, exist_ok=True) @@ -2839,7 +2839,7 @@ def _validate_and_sanitize_experiment_log_partial_args(event: Mapping[str, Any]) # Note that this only checks properties that are expected of a complete event. # _validate_and_sanitize_experiment_log_partial_args should still be invoked # (after handling special fields like 'id'). -def _validate_and_sanitize_experiment_log_full_args(event: Mapping[str, Any], has_dataset: bool) -> Mapping[str, Any]: +def _validate_and_sanitize_experiment_log_full_args(event: Mapping[str, Any]) -> Mapping[str, Any]: input = event.get("input") inputs = event.get("inputs") if (input is not None and inputs is not None) or (input is None and inputs is None): @@ -3833,7 +3833,6 @@ def log( metadata: Metadata | None = None, metrics: Mapping[str, int | float] | None = None, id: str | None = None, - dataset_record_id: str | None = None, allow_concurrent_with_spans: bool = False, ) -> str: """ @@ -3849,7 +3848,6 @@ def log( :param metrics: (Optional) a dictionary of metrics to log. The following keys are populated automatically: "start", "end". :param id: (Optional) a unique identifier for the event. If you don't provide one, BrainTrust will generate one for you. :param allow_concurrent_with_spans: (Optional) in rare cases where you need to log at the top level separately from using spans on the experiment elsewhere, set this to True. - :param dataset_record_id: (Deprecated) the id of the dataset record that this event is associated with. This field is required if and only if the experiment is associated with a dataset. This field is unused and will be removed in a future version. :returns: The `id` of the logged event. 
""" if self._called_start_span and not allow_concurrent_with_spans: @@ -3869,7 +3867,6 @@ def log( metrics=metrics, id=id, ), - self.dataset is not None, ) span = self._start_span_impl(start_time=self.last_start_time, lookup_span_parent=False, **event) self.last_start_time = span.end() From cc33cac3d68a7446fb0f6a7a50108a27ac95bfc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Halber?= Date: Fri, 20 Mar 2026 23:39:31 +0000 Subject: [PATCH 03/11] chore:remove dead code --- .../integrations/claude_agent_sdk/_test_transport.py | 4 ++-- py/src/braintrust/wrappers/test_pydantic_ai_integration.py | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/py/src/braintrust/integrations/claude_agent_sdk/_test_transport.py b/py/src/braintrust/integrations/claude_agent_sdk/_test_transport.py index 217b2e62..27a94e96 100644 --- a/py/src/braintrust/integrations/claude_agent_sdk/_test_transport.py +++ b/py/src/braintrust/integrations/claude_agent_sdk/_test_transport.py @@ -68,8 +68,8 @@ def _normalize_write(data: str, *, sanitize: bool = False) -> dict[str, Any]: async def _empty_stream(): - return - yield {} # type: ignore[unreachable] + for _ in (): + yield {} def _normalize_for_match(value: Any) -> Any: diff --git a/py/src/braintrust/wrappers/test_pydantic_ai_integration.py b/py/src/braintrust/wrappers/test_pydantic_ai_integration.py index b794b18b..81de2ea4 100644 --- a/py/src/braintrust/wrappers/test_pydantic_ai_integration.py +++ b/py/src/braintrust/wrappers/test_pydantic_ai_integration.py @@ -184,13 +184,11 @@ async def fake_run_chat( *, stream, agent, - deps, - console, - code_theme, prog_name, message_history, model_settings=None, usage_limits=None, + **_, ): assert stream is True assert prog_name == "braintrust-cli" From d812b6fc88a4f3dd97b614ccb3e5b459401f96ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Halber?= Date: Mon, 23 Mar 2026 17:58:29 +0000 Subject: [PATCH 04/11] chore: add vulture to pyproject.toml --- pyproject.toml | 5 +++++ 
1 file changed, 5 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index b7d159c8..31230cc9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,3 +20,8 @@ split-on-trailing-comma = true asyncio_mode = "strict" asyncio_default_fixture_loop_scope = "function" addopts = "--durations=3 --durations-min=0.1" + +[tool.vulture] +paths = ["py/src"] +ignore_names = ["with_simulate_login", "reset_id_generator_state", "dataset_record_id"] # pytest fixtures and deprecated-but-public API parameters +min_confidence = 100 From 8ff4992d61c8150f586e2d093ca4d2b821c7ec81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Halber?= Date: Mon, 23 Mar 2026 20:53:19 +0000 Subject: [PATCH 05/11] chore: add vulture to pre-commit --- .pre-commit-config.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e9df688d..7ea78815 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,3 +32,8 @@ repos: args: - "-L" - "rouge,coo,couldn,unsecure,ontext,afterall,als" + - repo: https://github.com/jendrikseipp/vulture + rev: v2.15 + hooks: + - id: vulture + pass_filenames: false From 651fd7adcbf5ebf108165fa220498d13677ed788 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Halber?= Date: Tue, 24 Mar 2026 18:52:42 +0000 Subject: [PATCH 06/11] chore: remove probably unused code (need human review) --- py/src/braintrust/http_headers.py | 4 ---- 1 file changed, 4 deletions(-) delete mode 100644 py/src/braintrust/http_headers.py diff --git a/py/src/braintrust/http_headers.py b/py/src/braintrust/http_headers.py deleted file mode 100644 index 138a1f03..00000000 --- a/py/src/braintrust/http_headers.py +++ /dev/null @@ -1,4 +0,0 @@ -BT_FOUND_EXISTING_HEADER = "x-bt-found-existing" -BT_CURSOR_HEADER = "x-bt-cursor" -BT_IMPERSONATE_USER = "x-bt-impersonate-user" -BT_PARENT = "x-bt-parent" From 6e8c02f6d4a9ab95c044e9fb0e875c46eaada1b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Halber?= Date: Thu, 26 
Mar 2026 18:03:30 -0700 Subject: [PATCH 07/11] chore: remove probably unused code (need human review) --- py/src/braintrust/cli/eval.py | 2 +- py/src/braintrust/cli/install/logs.py | 1 - py/src/braintrust/db_fields.py | 9 - py/src/braintrust/framework.py | 27 +- py/src/braintrust/logger.py | 5 - py/src/braintrust/parameters.py | 7 - py/src/braintrust/queue.py | 2 - py/src/braintrust/wrappers/adk/__init__.py | 677 +++++++++++++ .../wrappers/claude_agent_sdk/_wrapper.py | 943 ++++++++++++++++++ py/src/braintrust/wrappers/langchain.py | 1 - py/src/braintrust/wrappers/pydantic_ai.py | 74 -- 11 files changed, 1629 insertions(+), 119 deletions(-) create mode 100644 py/src/braintrust/wrappers/adk/__init__.py create mode 100644 py/src/braintrust/wrappers/claude_agent_sdk/_wrapper.py diff --git a/py/src/braintrust/cli/eval.py b/py/src/braintrust/cli/eval.py index f0e5dc89..595c2c0a 100644 --- a/py/src/braintrust/cli/eval.py +++ b/py/src/braintrust/cli/eval.py @@ -246,7 +246,7 @@ def check_match(path_input, include_patterns, exclude_patterns): def collect_files(input_path): if os.path.isdir(input_path): - for root, dirs, files in os.walk(input_path): + for root, _, files in os.walk(input_path): for file in files: fname = os.path.join(root, file) if check_match(fname, INCLUDE, EXCLUDE): diff --git a/py/src/braintrust/cli/install/logs.py b/py/src/braintrust/cli/install/logs.py index 2b840aec..4d46ad87 100644 --- a/py/src/braintrust/cli/install/logs.py +++ b/py/src/braintrust/cli/install/logs.py @@ -88,7 +88,6 @@ def get_events(stream): with ThreadPoolExecutor(8) as executor: events = executor.map(get_events, all_streams) - last_ts = None for stream, log in zip(all_streams, events): print(f"---- LOG STREAM: {stream['logStreamName']}") for event in log["events"]: diff --git a/py/src/braintrust/db_fields.py b/py/src/braintrust/db_fields.py index a89b9710..6fd95df4 100644 --- a/py/src/braintrust/db_fields.py +++ b/py/src/braintrust/db_fields.py @@ -1,21 +1,12 @@ 
TRANSACTION_ID_FIELD = "_xact_id" OBJECT_DELETE_FIELD = "_object_delete" -CREATED_FIELD = "created" -ID_FIELD = "id" IS_MERGE_FIELD = "_is_merge" -MERGE_PATHS_FIELD = "_merge_paths" -ARRAY_DELETE_FIELD = "_array_delete" AUDIT_SOURCE_FIELD = "_audit_source" AUDIT_METADATA_FIELD = "_audit_metadata" VALID_SOURCES = ["app", "api", "external"] -PARENT_ID_FIELD = "_parent_id" - -ASYNC_SCORING_CONTROL_FIELD = "_async_scoring_control" -SKIP_ASYNC_SCORING_FIELD = "_skip_async_scoring" - # Keys that identify which object (experiment, dataset, project logs, etc.) a row belongs to. OBJECT_ID_KEYS = ( "experiment_id", diff --git a/py/src/braintrust/framework.py b/py/src/braintrust/framework.py index be654c49..a89710b0 100644 --- a/py/src/braintrust/framework.py +++ b/py/src/braintrust/framework.py @@ -62,15 +62,15 @@ # https://stackoverflow.com/questions/287871/how-do-i-print-colored-text-to-the-terminal class bcolors: - HEADER = "\033[95m" - OKBLUE = "\033[94m" - OKCYAN = "\033[96m" - OKGREEN = "\033[92m" +# HEADER = "\033[95m" +# OKBLUE = "\033[94m" +# OKCYAN = "\033[96m" +# OKGREEN = "\033[92m" WARNING = "\033[93m" FAIL = "\033[91m" ENDC = "\033[0m" - BOLD = "\033[1m" - UNDERLINE = "\033[4m" +# BOLD = "\033[1m" +# UNDERLINE = "\033[4m" @dataclasses.dataclass @@ -228,17 +228,6 @@ def parameters(self) -> ValidatedParameters | None: """ -class EvalScorerArgs(SerializableDataClass, Generic[Input, Output]): - """ - Arguments passed to an evaluator scorer. This includes the input, expected output, actual output, and metadata. - """ - - input: Input - output: Output - expected: Output | None = None - metadata: Metadata | None = None - - OneOrMoreScores = Union[float, int, bool, None, Score, list[Score]] @@ -850,7 +839,7 @@ async def EvalAsync( :param data: Returns an iterator over the evaluation dataset. Each element of the iterator should be a `EvalCase`. :param task: Runs the evaluation task on a single input. The `hooks` object can be used to add metadata to the evaluation. 
:param scores: A list of scorers to evaluate the results of the task. Each scorer can be a Scorer object or a function - that takes an `EvalScorerArgs` object and returns a `Score` object. + that takes `(input, output, expected)` arguments and returns a `Score` object. :param experiment_name: (Optional) Experiment name. If not specified, a name will be generated automatically. :param trial_count: The number of times to run the evaluator per input. This is useful for evaluating applications that have non-deterministic behavior and gives you both a stronger aggregate measure and a sense of the variance in the results. @@ -977,7 +966,7 @@ def Eval( :param data: Returns an iterator over the evaluation dataset. Each element of the iterator should be a `EvalCase`. :param task: Runs the evaluation task on a single input. The `hooks` object can be used to add metadata to the evaluation. :param scores: A list of scorers to evaluate the results of the task. Each scorer can be a Scorer object or a function - that takes an `EvalScorerArgs` object and returns a `Score` object. + that takes `(input, output, expected)` arguments and returns a `Score` object. :param experiment_name: (Optional) Experiment name. If not specified, a name will be generated automatically. :param trial_count: The number of times to run the evaluator per input. This is useful for evaluating applications that have non-deterministic behavior and gives you both a stronger aggregate measure and a sense of the variance in the results. 
diff --git a/py/src/braintrust/logger.py b/py/src/braintrust/logger.py index 478ef324..3912a05b 100644 --- a/py/src/braintrust/logger.py +++ b/py/src/braintrust/logger.py @@ -1068,9 +1068,6 @@ def __init__(self, api_conn: LazyValue[HTTPConnection]): self.logger = logging.getLogger("braintrust") self.queue: "LogQueue[LazyValue[Dict[str, Any]]]" = LogQueue(maxsize=self.queue_maxsize) - # Counter for tracking overflow uploads (useful for testing) - self._overflow_upload_count = 0 - if not disable_atexit_flush: atexit.register(self._finalize) @@ -1382,8 +1379,6 @@ def _submit_logs_request(self, items: Sequence[LogItemWithMeta], max_request_siz except Exception as e: error = e if error is None and resp is not None and resp.ok: - if overflow_rows: - self._overflow_upload_count += 1 return if error is None and resp is not None: resp_errmsg = f"{resp.status_code}: {resp.text}" diff --git a/py/src/braintrust/parameters.py b/py/src/braintrust/parameters.py index 595ba3ce..ac9d4a86 100644 --- a/py/src/braintrust/parameters.py +++ b/py/src/braintrust/parameters.py @@ -63,13 +63,6 @@ def from_function_row(cls, row: dict[str, Any]) -> "RemoteEvalParameters": data=function_data.get("data") or {}, ) - def validate(self, data: Any) -> bool: - try: - validate_json_schema(data, self.schema) - return True - except ValueError: - return False - def _pydantic_to_json_schema(model: Any) -> dict[str, Any]: """Convert a pydantic model to JSON schema.""" diff --git a/py/src/braintrust/queue.py b/py/src/braintrust/queue.py index ff6fc6cf..cfd5e834 100644 --- a/py/src/braintrust/queue.py +++ b/py/src/braintrust/queue.py @@ -32,7 +32,6 @@ def __init__(self, maxsize: int = 0): self._mutex = threading.Lock() self._queue: deque[T] = deque(maxlen=maxsize) self._has_items_event = threading.Event() - self._total_dropped = 0 self._enforce_size_limit = False def enforce_queue_size_limit(self, enforce: bool) -> None: @@ -68,7 +67,6 @@ def put(self, item: T) -> list[T]: while len(self._queue) >= 
self.maxsize: dropped_item = self._queue.popleft() dropped.append(dropped_item) - self._total_dropped += 1 self._queue.append(item) # Signal that items are available if queue was not empty before or item was added diff --git a/py/src/braintrust/wrappers/adk/__init__.py b/py/src/braintrust/wrappers/adk/__init__.py new file mode 100644 index 00000000..3f9036ab --- /dev/null +++ b/py/src/braintrust/wrappers/adk/__init__.py @@ -0,0 +1,677 @@ +import contextvars +import inspect +import logging +import time +from collections.abc import Iterable +from contextlib import aclosing +from typing import Any, cast + +from braintrust.bt_json import bt_safe_deep_copy +from braintrust.logger import NOOP_SPAN, Attachment, current_span, init_logger, start_span +from braintrust.span_types import SpanTypeAttribute +from wrapt import wrap_function_wrapper + + +logger = logging.getLogger(__name__) + +__all__ = ["setup_braintrust", "setup_adk", "wrap_agent", "wrap_runner", "wrap_flow", "wrap_mcp_tool"] + + +def setup_braintrust(*args, **kwargs): + logger.warning("setup_braintrust is deprecated, use setup_adk instead") + return setup_adk(*args, **kwargs) + + +def setup_adk( + api_key: str | None = None, + project_id: str | None = None, + project_name: str | None = None, + SpanProcessor: type | None = None, +) -> bool: + """ + Setup Braintrust integration with Google ADK. Will automatically patch Google ADK agents, runners, flows, and MCP tools for automatic tracing. + + If you prefer manual patching take a look at `wrap_agent`, `wrap_runner`, `wrap_flow`, and `wrap_mcp_tool`. + + Args: + api_key (Optional[str]): Braintrust API key. + project_id (Optional[str]): Braintrust project ID. + project_name (Optional[str]): Braintrust project name. + SpanProcessor (Optional[type]): Deprecated parameter. + + Returns: + bool: True if setup was successful, False otherwise. 
+ """ + if SpanProcessor is not None: + logging.warning("SpanProcessor parameter is deprecated and will be ignored") + + span = current_span() + if span == NOOP_SPAN: + init_logger(project=project_name, api_key=api_key, project_id=project_id) + + try: + from google.adk import agents, runners + from google.adk.flows.llm_flows import base_llm_flow + + agents.BaseAgent = wrap_agent(agents.BaseAgent) + runners.Runner = wrap_runner(runners.Runner) + base_llm_flow.BaseLlmFlow = wrap_flow(base_llm_flow.BaseLlmFlow) + + try: + from google.adk.platform import thread as adk_thread + + adk_thread.create_thread = _wrap_create_thread(adk_thread.create_thread) + runners.create_thread = _wrap_create_thread(runners.create_thread) + logger.debug("ADK thread bridge patching successful") + except Exception as e: + logger.warning(f"Failed to patch ADK thread bridge: {e}") + + # Try to patch McpTool if available (MCP is optional) + try: + from google.adk.tools.mcp_tool import mcp_tool + + mcp_tool.McpTool = wrap_mcp_tool(mcp_tool.McpTool) + logger.debug("McpTool patching successful") + except ImportError: + # MCP is optional - gracefully skip if not installed + logger.debug("McpTool not available, skipping MCP instrumentation") + except Exception as e: + # Log but don't fail - MCP patching is optional + logger.warning(f"Failed to patch McpTool: {e}") + + return True + except ImportError as e: + logger.error(f"Failed to import Google ADK agents: {e}") + logger.error("Google ADK is not installed. 
Please install it with: pip install google-adk") + return False + + +def _wrap_create_thread(create_thread): + if _is_patched(create_thread): + return create_thread + + def _wrapped_create_thread(target: Any, *args: Any, **kwargs: Any): + ctx = contextvars.copy_context() + + def _run_in_context(*target_args: Any, **target_kwargs: Any): + return ctx.run(target, *target_args, **target_kwargs) + + return create_thread(_run_in_context, *args, **kwargs) + + _wrapped_create_thread._braintrust_patched = True + return _wrapped_create_thread + + +def wrap_agent(Agent: Any) -> Any: + if _is_patched(Agent): + return Agent + + async def agent_run_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): + parent_context = args[0] if len(args) > 0 else kwargs.get("parent_context") + + async def _trace(): + with start_span( + name=f"agent_run [{instance.name}]", + type=SpanTypeAttribute.TASK, + metadata=bt_safe_deep_copy({"parent_context": parent_context, **_omit(kwargs, ["parent_context"])}), + ) as agent_span: + last_event = None + async with aclosing(wrapped(*args, **kwargs)) as agen: + async for event in agen: + if event.is_final_response(): + last_event = event + yield event + if last_event: + agent_span.log(output=last_event) + + async with aclosing(_trace()) as agen: + async for event in agen: + yield event + + wrap_function_wrapper(Agent, "run_async", agent_run_wrapper) + Agent._braintrust_patched = True + return Agent + + +def wrap_flow(Flow: Any): + if _is_patched(Flow): + return Flow + + async def trace_flow(wrapped: Any, instance: Any, args: Any, kwargs: Any): + invocation_context = args[0] if len(args) > 0 else kwargs.get("invocation_context") + + async def _trace(): + with start_span( + name=f"call_llm", + type=SpanTypeAttribute.TASK, + metadata=bt_safe_deep_copy( + { + "invocation_context": invocation_context, + **_omit(kwargs, ["invocation_context"]), + } + ), + ) as llm_span: + last_event = None + async with aclosing(wrapped(*args, **kwargs)) as agen: + 
async for event in agen: + last_event = event + yield event + if last_event: + llm_span.log(output=last_event) + + async with aclosing(_trace()) as agen: + async for event in agen: + yield event + + wrap_function_wrapper(Flow, "run_async", trace_flow) + + async def trace_run_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): + invocation_context = args[0] if len(args) > 0 else kwargs.get("invocation_context") + llm_request = args[1] if len(args) > 1 else kwargs.get("llm_request") + model_response_event = args[2] if len(args) > 2 else kwargs.get("model_response_event") + + async def _trace(): + # Extract and serialize contents BEFORE converting to dict + # This is critical because bt_safe_deep_copy converts bytes to string representations + serialized_contents = None + if llm_request and hasattr(llm_request, "contents"): + contents = llm_request.contents + if contents: + serialized_contents = ( + [_serialize_content(c) for c in contents] + if isinstance(contents, list) + else _serialize_content(contents) + ) + + # Now convert the whole request to dict + serialized_request = bt_safe_deep_copy(llm_request) + + # Replace contents with our serialized version that has Attachments + if serialized_contents is not None and isinstance(serialized_request, dict): + serialized_request["contents"] = serialized_contents + + # Handle config specifically to serialize Pydantic schema classes + if isinstance(serialized_request, dict) and "config" in serialized_request: + serialized_request["config"] = _serialize_config(serialized_request["config"]) + + # Extract model name from request or instance + model_name = _extract_model_name(None, llm_request, instance) + + # Create span BEFORE execution so child spans (like mcp_tool) have proper parent + # Start with generic name - we'll update it after we see the response + with start_span( + name="llm_call", + type=SpanTypeAttribute.LLM, + input=serialized_request, + metadata=bt_safe_deep_copy( + { + "invocation_context": 
invocation_context, + "model_response_event": model_response_event, + "flow_class": instance.__class__.__name__, + "model": model_name, + **_omit(kwargs, ["invocation_context", "model_response_event", "flow_class", "llm_call_type"]), + } + ), + ) as llm_span: + # Execute the LLM call and yield events while span is active + last_event = None + event_with_content = None + start_time = time.time() + first_token_time = None + + async with aclosing(wrapped(*args, **kwargs)) as agen: + async for event in agen: + # Record time to first token + if first_token_time is None: + first_token_time = time.time() + + last_event = event + if hasattr(event, "content") and event.content is not None: + event_with_content = event + yield event + + # After execution, update span with correct call type and output + if last_event: + # We need to check if we should merge content from an earlier event + # Convert to dict to inspect/modify, but let span.log() handle final serialization + output_dict = bt_safe_deep_copy(last_event) + if event_with_content and isinstance(output_dict, dict): + if "content" not in output_dict or output_dict.get("content") is None: + content = ( + bt_safe_deep_copy(event_with_content.content) + if hasattr(event_with_content, "content") + else None + ) + if content: + output_dict["content"] = content + + # Extract metrics from response + metrics = _extract_metrics(last_event) + + # Add time to first token if we captured it + if first_token_time is not None: + if metrics is None: + metrics = {} + metrics["time_to_first_token"] = first_token_time - start_time + + # Determine the actual call type based on the response + call_type = _determine_llm_call_type(llm_request, last_event) + + # Update span name with the specific call type now that we know it + llm_span.set_attributes( + name=f"llm_call [{call_type}]", + span_attributes={"llm_call_type": call_type}, + ) + + # Log output and metrics (span.log will handle serialization) + llm_span.log(output=output_dict, 
def wrap_runner(Runner: Any):
    """Patch ADK ``Runner.run`` / ``Runner.run_async`` to emit a TASK span per invocation.

    Idempotent: a marker attribute prevents double-patching. The wrappers are
    generator functions so events stream through unchanged while the span stays
    open for the whole run; the final response event is logged as the span output.
    """
    if _is_patched(Runner):
        return Runner

    def trace_run_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
        user_id = kwargs.get("user_id")
        session_id = kwargs.get("session_id")
        new_message = kwargs.get("new_message")

        # Serialize new_message before any dict conversion to handle binary data
        serialized_message = _serialize_content(new_message) if new_message else None

        def _trace():
            with start_span(
                name=f"invocation [{instance.app_name}]",
                type=SpanTypeAttribute.TASK,
                input={"new_message": serialized_message},
                metadata=bt_safe_deep_copy(
                    {
                        "user_id": user_id,
                        "session_id": session_id,
                        # Everything else the caller passed rides along as metadata.
                        **_omit(kwargs, ["user_id", "session_id", "new_message"]),
                    }
                ),
            ) as runner_span:
                last_event = None
                for event in wrapped(*args, **kwargs):
                    # Only the final response becomes the span output; intermediate
                    # events are streamed through untouched.
                    if event.is_final_response():
                        last_event = event
                    yield event
                if last_event:
                    runner_span.log(output=last_event)

        yield from _trace()

    wrap_function_wrapper(Runner, "run", trace_run_sync_wrapper)

    async def trace_run_async_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
        user_id = kwargs.get("user_id")
        session_id = kwargs.get("session_id")
        new_message = kwargs.get("new_message")
        # NOTE(review): state_delta is captured here but not in the sync wrapper —
        # presumably Runner.run does not accept it; confirm against the ADK signature.
        state_delta = kwargs.get("state_delta")

        # Serialize new_message before any dict conversion to handle binary data
        serialized_message = _serialize_content(new_message) if new_message else None

        async def _trace():
            with start_span(
                name=f"invocation [{instance.app_name}]",
                type=SpanTypeAttribute.TASK,
                input={"new_message": serialized_message},
                metadata=bt_safe_deep_copy(
                    {
                        "user_id": user_id,
                        "session_id": session_id,
                        "state_delta": state_delta,
                        **_omit(kwargs, ["user_id", "session_id", "new_message", "state_delta"]),
                    }
                ),
            ) as runner_span:
                last_event = None
                # aclosing guarantees the inner async generator is closed even if
                # the consumer abandons iteration early.
                async with aclosing(wrapped(*args, **kwargs)) as agen:
                    async for event in agen:
                        if event.is_final_response():
                            last_event = event
                        yield event
                if last_event:
                    runner_span.log(output=last_event)

        async with aclosing(_trace()) as agen:
            async for event in agen:
                yield event

    wrap_function_wrapper(Runner, "run_async", trace_run_async_wrapper)
    Runner._braintrust_patched = True
    return Runner


def wrap_mcp_tool(McpTool: Any) -> Any:
    """
    Wrap McpTool to trace MCP tool invocations.

    Creates Braintrust spans for each MCP tool call, capturing:
    - Tool name
    - Input arguments
    - Output results
    - Execution time
    - Errors if they occur

    Args:
        McpTool: The McpTool class to wrap

    Returns:
        The wrapped McpTool class
    """
    if _is_patched(McpTool):
        return McpTool

    async def tool_run_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
        # Extract tool information
        tool_name = instance.name
        tool_args = kwargs.get("args", {})

        with start_span(
            name=f"mcp_tool [{tool_name}]",
            type=SpanTypeAttribute.TOOL,
            input={"tool_name": tool_name, "arguments": tool_args},
            metadata=_omit(kwargs, ["args"]),
        ) as tool_span:
            try:
                result = await wrapped(*args, **kwargs)
                tool_span.log(output=result)
                return result
            except Exception as e:
                # Log error to span but re-raise for ADK to handle
                tool_span.log(error=str(e))
                raise

    wrap_function_wrapper(McpTool, "run_async", tool_run_wrapper)
    McpTool._braintrust_patched = True
    return McpTool


def _determine_llm_call_type(llm_request: Any, model_response: Any = None) -> str:
    """
    Determine the type of LLM call based on the request and response content.

    Returns:
        - "tool_selection" if the LLM selected a tool to call in its response
        - "response_generation" if the LLM is generating a response after tool execution
        - "direct_response" if there are no tools involved or tools available but not used
        - "unknown" if any error occurs while inspecting the request/response
          (classification is best-effort and must never break tracing)
    """
    try:
        # Convert to dict if it's a model object
        request_dict = cast(dict[str, Any], bt_safe_deep_copy(llm_request))

        # Check the conversation history for function responses
        contents = request_dict.get("contents", [])
        has_function_response = False

        for content in contents:
            if isinstance(content, dict):
                parts = content.get("parts", [])
                for part in parts:
                    if isinstance(part, dict):
                        if "function_response" in part and part["function_response"] is not None:
                            has_function_response = True

        # Check if the response contains function calls
        response_has_function_call = False
        if model_response:
            # Check if it's an Event object with get_function_calls method (ADK Event)
            if hasattr(model_response, "get_function_calls"):
                try:
                    function_calls = model_response.get_function_calls()
                    if function_calls and len(function_calls) > 0:
                        response_has_function_call = True
                except Exception:
                    pass

            # Fallback: Check the response dict structure
            if not response_has_function_call:
                response_dict = bt_safe_deep_copy(model_response)
                if isinstance(response_dict, dict):
                    # Try multiple possible response structures
                    # 1. Standard: response.content.parts
                    content = response_dict.get("content", {})
                    if isinstance(content, dict):
                        parts = content.get("parts", [])
                        if isinstance(parts, list):
                            for part in parts:
                                if isinstance(part, dict):
                                    # Both snake_case and camelCase keys appear in practice.
                                    if ("function_call" in part and part["function_call"] is not None) or (
                                        "functionCall" in part and part["functionCall"] is not None
                                    ):
                                        response_has_function_call = True
                                        break

                    # 2. Alternative: response has parts directly (for some event types)
                    if not response_has_function_call and "parts" in response_dict:
                        parts = response_dict.get("parts", [])
                        if isinstance(parts, list):
                            for part in parts:
                                if isinstance(part, dict):
                                    if ("function_call" in part and part["function_call"] is not None) or (
                                        "functionCall" in part and part["functionCall"] is not None
                                    ):
                                        response_has_function_call = True
                                        break

        # Determine the call type
        if has_function_response:
            return "response_generation"
        elif response_has_function_call:
            return "tool_selection"
        else:
            return "direct_response"

    except Exception:
        return "unknown"


def _is_patched(obj: Any):
    # Marker set by the wrap_* functions above to make patching idempotent.
    return getattr(obj, "_braintrust_patched", False)


def _serialize_content(content: Any) -> Any:
    """Serialize Google ADK Content/Part objects, converting binary data to Attachments."""
    if content is None:
        return None

    # Handle Content objects with parts
    if hasattr(content, "parts") and content.parts:
        serialized_parts = []
        for part in content.parts:
            serialized_parts.append(_serialize_part(part))

        result = {"parts": serialized_parts}
        if hasattr(content, "role"):
            result["role"] = content.role
        return result

    # Handle single Part
    return _serialize_part(content)


def _serialize_part(part: Any) -> Any:
    """Serialize a single Part object, handling binary data."""
    if part is None:
        return None

    # If it's already a dict, return as-is
    if isinstance(part, dict):
        return part

    # Handle Part objects with inline_data (binary data like images)
    if hasattr(part, "inline_data") and part.inline_data:
        inline_data = part.inline_data
        if hasattr(inline_data, "data") and hasattr(inline_data, "mime_type"):
            data = inline_data.data
            mime_type = inline_data.mime_type

            # Convert bytes to Attachment
            if isinstance(data, bytes):
                # Derive a filename extension from the MIME subtype; fall back to .bin.
                extension = mime_type.split("/")[1] if "/" in mime_type else "bin"
                filename = f"file.{extension}"
                attachment = Attachment(data=data, filename=filename, content_type=mime_type)

                # Return in image_url format - SDK will replace with AttachmentReference
                return {"image_url": {"url": attachment}}

    # Handle Part objects with file_data (file references)
    if hasattr(part, "file_data") and part.file_data:
        file_data = part.file_data
        result = {"file_data": {}}
        if hasattr(file_data, "file_uri"):
            result["file_data"]["file_uri"] = file_data.file_uri
        if hasattr(file_data, "mime_type"):
            result["file_data"]["mime_type"] = file_data.mime_type
        return result

    # Handle text parts
    if hasattr(part, "text") and part.text is not None:
        result = {"text": part.text}
        if hasattr(part, "thought") and part.thought:
            result["thought"] = part.thought
        return result

    # Try standard serialization methods
    return bt_safe_deep_copy(part)
Attachment(data=data, filename=filename, content_type=mime_type) + + # Return in image_url format - SDK will replace with AttachmentReference + return {"image_url": {"url": attachment}} + + # Handle Part objects with file_data (file references) + if hasattr(part, "file_data") and part.file_data: + file_data = part.file_data + result = {"file_data": {}} + if hasattr(file_data, "file_uri"): + result["file_data"]["file_uri"] = file_data.file_uri + if hasattr(file_data, "mime_type"): + result["file_data"]["mime_type"] = file_data.mime_type + return result + + # Handle text parts + if hasattr(part, "text") and part.text is not None: + result = {"text": part.text} + if hasattr(part, "thought") and part.thought: + result["thought"] = part.thought + return result + + # Try standard serialization methods + return bt_safe_deep_copy(part) + + +def _serialize_pydantic_schema(schema_class: Any) -> dict[str, Any]: + """ + Serialize a Pydantic model class to its full JSON schema. + + Returns the complete schema including descriptions, constraints, and nested definitions + so engineers can see exactly what structured output schema was used. + """ + try: + from pydantic import BaseModel + + if inspect.isclass(schema_class) and issubclass(schema_class, BaseModel): + # Return the full JSON schema - includes all field info, descriptions, constraints, etc. + return schema_class.model_json_schema() + except (ImportError, AttributeError, TypeError): + pass + # If not a Pydantic model, return class name + return {"__class__": schema_class.__name__ if inspect.isclass(schema_class) else str(type(schema_class).__name__)} + + +def _serialize_config(config: Any) -> dict[str, Any] | Any: + """ + Serialize a config object, specifically handling schema fields that may contain Pydantic classes. 
+ + Google ADK uses these fields for schemas: + - response_schema, response_json_schema (in GenerateContentConfig for LLM requests) + - input_schema, output_schema (in agent config) + """ + if config is None: + return None + if not config: + return config + + # Extract schema fields BEFORE calling bt_safe_deep_copy (which converts Pydantic classes to dicts) + schema_fields = ["response_schema", "response_json_schema", "input_schema", "output_schema"] + serialized_schemas: dict[str, Any] = {} + + for field in schema_fields: + schema_value = None + + # Try to get the field value + if hasattr(config, field): + schema_value = getattr(config, field) + elif isinstance(config, dict) and field in config: + schema_value = config[field] + + # If it's a Pydantic class, serialize it + if schema_value is not None and inspect.isclass(schema_value): + try: + from pydantic import BaseModel + + if issubclass(schema_value, BaseModel): + serialized_schemas[field] = _serialize_pydantic_schema(schema_value) + except (TypeError, ImportError): + pass + + # Serialize the config + config_dict = bt_safe_deep_copy(config) + if not isinstance(config_dict, dict): + return config_dict # type: ignore + + # Replace schema fields with serialized versions + config_dict.update(serialized_schemas) + + return config_dict + + +def _omit(obj: Any, keys: Iterable[str]): + return {k: v for k, v in obj.items() if k not in keys} + + +def _extract_metrics(response: Any) -> dict[str, float] | None: + """Extract token usage metrics from Google GenAI response.""" + if not response: + return None + + usage_metadata = getattr(response, "usage_metadata", None) + if not usage_metadata: + return None + + metrics: dict[str, float] = {} + + # Core token counts + if hasattr(usage_metadata, "prompt_token_count") and usage_metadata.prompt_token_count is not None: + metrics["prompt_tokens"] = float(usage_metadata.prompt_token_count) + + if hasattr(usage_metadata, "candidates_token_count") and 
usage_metadata.candidates_token_count is not None: + metrics["completion_tokens"] = float(usage_metadata.candidates_token_count) + + if hasattr(usage_metadata, "total_token_count") and usage_metadata.total_token_count is not None: + metrics["tokens"] = float(usage_metadata.total_token_count) + + # Cached token metrics + if hasattr(usage_metadata, "cached_content_token_count") and usage_metadata.cached_content_token_count is not None: + metrics["prompt_cached_tokens"] = float(usage_metadata.cached_content_token_count) + + # Reasoning token metrics (thoughts_token_count) + if hasattr(usage_metadata, "thoughts_token_count") and usage_metadata.thoughts_token_count is not None: + metrics["completion_reasoning_tokens"] = float(usage_metadata.thoughts_token_count) + + return metrics if metrics else None + + +def _extract_model_name(response: Any, llm_request: Any, instance: Any) -> str | None: + """Extract model name from Google GenAI response, request, or flow instance.""" + # Try to get from response first + if response: + model_version = getattr(response, "model_version", None) + if model_version: + return model_version + + # Try to get from llm_request + if llm_request: + if hasattr(llm_request, "model") and llm_request.model: + return str(llm_request.model) + + # Try to get from instance (flow's llm) + if instance: + if hasattr(instance, "llm"): + llm = instance.llm + if hasattr(llm, "model") and llm.model: + return str(llm.model) + + # Try to get model from instance directly + if hasattr(instance, "model") and instance.model: + return str(instance.model) + + return None diff --git a/py/src/braintrust/wrappers/claude_agent_sdk/_wrapper.py b/py/src/braintrust/wrappers/claude_agent_sdk/_wrapper.py new file mode 100644 index 00000000..71460302 --- /dev/null +++ b/py/src/braintrust/wrappers/claude_agent_sdk/_wrapper.py @@ -0,0 +1,943 @@ +import asyncio +import dataclasses +import logging +import threading +import time +from collections.abc import AsyncGenerator, 
import asyncio
import dataclasses
import logging
import threading
import time
from collections.abc import AsyncGenerator, AsyncIterable
from typing import Any

from braintrust.logger import start_span
from braintrust.span_types import SpanTypeAttribute
from braintrust.wrappers._anthropic_utils import Wrapper, extract_anthropic_usage, finalize_anthropic_tokens
from braintrust.wrappers.claude_agent_sdk._constants import (
    ANTHROPIC_MESSAGES_CREATE_SPAN_NAME,
    CLAUDE_AGENT_TASK_SPAN_NAME,
    DEFAULT_TOOL_NAME,
    MCP_TOOL_METADATA,
    MCP_TOOL_NAME_DELIMITER,
    MCP_TOOL_PREFIX,
    SERIALIZED_CONTENT_TYPE_BY_BLOCK_CLASS,
    SYSTEM_MESSAGE_TYPES,
    TOOL_METADATA,
    BlockClassName,
    MessageClassName,
    SerializedContentType,
)


log = logging.getLogger(__name__)
# Per-thread storage for the active ToolSpanTracker so in-process MCP tool
# handlers can find and re-enter the TOOL span opened by the message stream.
_thread_local = threading.local()


@dataclasses.dataclass(frozen=True)
class ParsedToolName:
    """A tool name decomposed into its raw form and human-readable display form.

    For MCP tools ("mcp__<server>__<tool>") the server and bare tool name are
    split out; for everything else raw_name == display_name.
    """

    raw_name: str
    display_name: str
    is_mcp: bool = False
    mcp_server: str | None = None


@dataclasses.dataclass
class _ActiveToolSpan:
    """An open TOOL span plus enough bookkeeping to match it to a handler call."""

    span: Any
    raw_name: str
    display_name: str
    input: Any
    # True while an in-process handler has claimed this span as "current".
    handler_active: bool = False

    @property
    def has_span(self) -> bool:
        return True

    def activate(self) -> None:
        # Claim the span for a handler and make it the current span so nested
        # spans created inside the handler parent correctly.
        self.handler_active = True
        self.span.set_current()

    def log_error(self, exc: Exception) -> None:
        self.span.log(error=str(exc))

    def release(self) -> None:
        # Undo activate(); safe to call when the span was never activated.
        if not self.handler_active:
            return

        self.handler_active = False
        self.span.unset_current()


class _NoopActiveToolSpan:
    """Null object used when no tracker/span is available; all operations are no-ops."""

    @property
    def has_span(self) -> bool:
        return False

    def log_error(self, exc: Exception) -> None:
        del exc

    def release(self) -> None:
        return


# Shared singleton — the class is stateless, so one instance suffices.
_NOOP_ACTIVE_TOOL_SPAN = _NoopActiveToolSpan()


def _parse_tool_name(tool_name: Any) -> ParsedToolName:
    """Parse a tool name, extracting MCP server/tool parts when prefixed.

    Falls back to raw_name == display_name whenever the name does not match
    the full "mcp__<server>__<tool>" shape.
    """
    raw_name = str(tool_name) if tool_name is not None else DEFAULT_TOOL_NAME

    if not raw_name.startswith(MCP_TOOL_PREFIX):
        return ParsedToolName(raw_name=raw_name, display_name=raw_name)

    remainder = raw_name[len(MCP_TOOL_PREFIX) :]
    if not remainder:
        return ParsedToolName(raw_name=raw_name, display_name=raw_name)

    # rsplit so a server name containing the delimiter still parses sensibly.
    server_and_tool = remainder.rsplit(MCP_TOOL_NAME_DELIMITER, 1)
    if len(server_and_tool) != 2:
        return ParsedToolName(raw_name=raw_name, display_name=raw_name)

    server_name, tool_display_name = server_and_tool
    if not server_name or not tool_display_name:
        return ParsedToolName(raw_name=raw_name, display_name=raw_name)

    return ParsedToolName(
        raw_name=raw_name,
        display_name=tool_display_name,
        is_mcp=True,
        mcp_server=server_name,
    )


def _serialize_tool_result_content(content: Any) -> Any:
    """Serialize tool-result content, unwrapping a single text block to a plain string."""
    if dataclasses.is_dataclass(content):
        serialized_content = _serialize_content_blocks([content])
        return serialized_content[0] if serialized_content else None

    if not isinstance(content, list):
        return content

    serialized_content = _serialize_content_blocks(content)
    # A lone text block collapses to just its text for readability.
    if (
        isinstance(serialized_content, list)
        and len(serialized_content) == 1
        and isinstance(serialized_content[0], dict)
        and serialized_content[0].get("type") == SerializedContentType.TEXT
        and SerializedContentType.TEXT in serialized_content[0]
    ):
        return serialized_content[0][SerializedContentType.TEXT]

    return serialized_content


def _serialize_tool_result_output(tool_result_block: Any) -> dict[str, Any]:
    """Build the span output payload for a tool_result block ({"content": ..., ["is_error"]})."""
    output = {"content": _serialize_tool_result_content(getattr(tool_result_block, "content", None))}

    if getattr(tool_result_block, "is_error", None) is True:
        output["is_error"] = True

    return output


def _serialize_system_message(message: Any) -> dict[str, Any]:
    """Serialize a system/task message into a dict of its non-None known fields.

    Falls back to the message's raw ``data`` attribute when none of the known
    fields are present (only "subtype" made it in).
    """
    serialized = {"subtype": getattr(message, "subtype", None)}

    for field_name in (
        "task_id",
        "description",
        "uuid",
        "session_id",
        "tool_use_id",
        "task_type",
        "status",
        "output_file",
        "summary",
        "last_tool_name",
        "usage",
    ):
        value = getattr(message, field_name, None)
        if value is not None:
            serialized[field_name] = value

    # Nothing beyond subtype: keep whatever raw payload the message carries.
    if len(serialized) == 1:
        data = getattr(message, "data", None)
        if data:
            serialized["data"] = data

    return serialized
def _create_tool_wrapper_class(original_tool_class: Any) -> Any:
    """Creates a wrapper class for SdkMcpTool that re-enters active TOOL spans."""

    class WrappedSdkMcpTool(original_tool_class):  # type: ignore[valid-type,misc]
        def __init__(
            self,
            name: Any,
            description: Any,
            input_schema: Any,
            handler: Any,
            **kwargs: Any,
        ):
            # Swap the user handler for one that attaches to the TOOL span the
            # message stream opened for this call.
            wrapped_handler = _wrap_tool_handler(handler, name)
            super().__init__(name, description, input_schema, wrapped_handler, **kwargs)  # type: ignore[call-arg]

        # Preserve subscript syntax (SdkMcpTool[...]) on the wrapper class.
        __class_getitem__ = classmethod(lambda cls, params: cls)  # type: ignore[assignment]

    return WrappedSdkMcpTool


def _wrap_tool_factory(tool_fn: Any) -> Any:
    """Wrap the tool() factory so decorated handlers inherit the active TOOL span."""

    def wrapped_tool(*args: Any, **kwargs: Any) -> Any:
        result = tool_fn(*args, **kwargs)
        if not callable(result):
            return result

        def wrapped_decorator(handler_fn: Any) -> Any:
            tool_def = result(handler_fn)
            if tool_def and hasattr(tool_def, "handler"):
                tool_name = getattr(tool_def, "name", DEFAULT_TOOL_NAME)
                tool_def.handler = _wrap_tool_handler(tool_def.handler, tool_name)
            return tool_def

        return wrapped_decorator

    return wrapped_tool


def _wrap_tool_handler(handler: Any, tool_name: Any) -> Any:
    """Wrap a tool handler so nested spans execute under the stream-based TOOL span."""
    # Idempotent: never double-wrap a handler.
    if hasattr(handler, "_braintrust_wrapped"):
        return handler

    async def wrapped_handler(args: Any) -> Any:
        active_tool_span = _activate_tool_span_for_handler(tool_name, args)
        if not active_tool_span.has_span:
            # No stream-opened span to attach to: create a standalone TOOL span.
            with start_span(
                name=str(tool_name),
                span_attributes={"type": SpanTypeAttribute.TOOL},
                input=args,
            ) as span:
                result = await handler(args)
                span.log(output=result)
                return result

        # Run under the stream's span; its output is logged by the tracker
        # when the tool_result block arrives, so only errors are logged here.
        try:
            return await handler(args)
        except Exception as exc:
            active_tool_span.log_error(exc)
            raise
        finally:
            active_tool_span.release()

    wrapped_handler._braintrust_wrapped = True  # type: ignore[attr-defined]
    return wrapped_handler


class ToolSpanTracker:
    """Tracks TOOL spans keyed by tool_use_id across the message stream.

    Spans open when an AssistantMessage contains tool_use blocks and close when
    the matching tool_result arrives (or on cleanup).
    """

    def __init__(self):
        self._active_spans: dict[str, _ActiveToolSpan] = {}
        # tool_use_ids of "Agent" tool calls awaiting their subagent task link.
        self._pending_task_link_tool_use_ids: set[str] = set()

    def start_tool_spans(self, message: Any, llm_span_export: str | None) -> None:
        """Open a TOOL span (parented to the LLM span) for each tool_use block."""
        if llm_span_export is None or not hasattr(message, "content"):
            return

        for block in message.content:
            if type(block).__name__ != BlockClassName.TOOL_USE:
                continue

            tool_use_id = getattr(block, "id", None)
            if not tool_use_id:
                continue

            tool_use_id = str(tool_use_id)
            # A duplicate id means the previous span never got its result; close it.
            if tool_use_id in self._active_spans:
                self._end_tool_span(tool_use_id)

            parsed_tool_name = _parse_tool_name(getattr(block, "name", None))
            metadata = {
                TOOL_METADATA.tool_name: parsed_tool_name.display_name,
                TOOL_METADATA.tool_call_id: tool_use_id,
            }
            if parsed_tool_name.raw_name != parsed_tool_name.display_name:
                metadata[TOOL_METADATA.raw_tool_name] = parsed_tool_name.raw_name
            if parsed_tool_name.is_mcp:
                metadata[TOOL_METADATA.operation_name] = MCP_TOOL_METADATA.operation_name
                metadata[TOOL_METADATA.mcp_method_name] = MCP_TOOL_METADATA.method_name
                if parsed_tool_name.mcp_server:
                    metadata[TOOL_METADATA.mcp_server] = parsed_tool_name.mcp_server

            tool_span = start_span(
                name=parsed_tool_name.display_name,
                span_attributes={"type": SpanTypeAttribute.TOOL},
                input=getattr(block, "input", None),
                metadata=metadata,
                parent=llm_span_export,
            )
            self._active_spans[tool_use_id] = _ActiveToolSpan(
                span=tool_span,
                raw_name=parsed_tool_name.raw_name,
                display_name=parsed_tool_name.display_name,
                input=getattr(block, "input", None),
            )
            # "Agent" spawns a subagent; remember it until its task appears.
            if parsed_tool_name.display_name == "Agent":
                self._pending_task_link_tool_use_ids.add(tool_use_id)

    def finish_tool_spans(self, message: Any) -> None:
        """Close spans for every tool_result block in a UserMessage."""
        if not hasattr(message, "content"):
            return

        for block in message.content:
            if type(block).__name__ != BlockClassName.TOOL_RESULT:
                continue

            tool_use_id = getattr(block, "tool_use_id", None)
            if tool_use_id is None:
                continue

            self._end_tool_span(str(tool_use_id), tool_result_block=block)

    def cleanup(self, end_time: float | None = None, exclude_tool_use_ids: frozenset[str] | None = None) -> None:
        """Close all remaining spans, optionally sparing ones tied to live subagents."""
        for tool_use_id in list(self._active_spans):
            if exclude_tool_use_ids and tool_use_id in exclude_tool_use_ids:
                continue
            self._end_tool_span(tool_use_id, end_time=end_time)

    @property
    def has_active_spans(self) -> bool:
        return bool(self._active_spans)

    @property
    def pending_task_link_tool_use_ids(self) -> frozenset[str]:
        return frozenset(self._pending_task_link_tool_use_ids)

    def mark_task_started(self, tool_use_id: Any) -> None:
        """Record that the subagent task for this tool_use_id has materialized."""
        if tool_use_id is None:
            return

        self._pending_task_link_tool_use_ids.discard(str(tool_use_id))

    def acquire_span_for_handler(self, tool_name: Any, args: Any) -> _ActiveToolSpan | None:
        """Find and activate an unclaimed span matching an in-process handler call."""
        parsed_tool_name = _parse_tool_name(tool_name)
        # De-duplicated, order-preserving candidate names to match against.
        candidate_names = list(
            dict.fromkeys((parsed_tool_name.raw_name, parsed_tool_name.display_name, str(tool_name)))
        )

        candidates = [
            active_tool_span
            for active_tool_span in self._active_spans.values()
            if not active_tool_span.handler_active
            and (active_tool_span.raw_name in candidate_names or active_tool_span.display_name in candidate_names)
        ]

        matched_span = _match_tool_span_for_handler(candidates, args)
        if matched_span is None:
            return None

        matched_span.activate()
        return matched_span

    def _end_tool_span(
        self, tool_use_id: str, tool_result_block: Any | None = None, end_time: float | None = None
    ) -> None:
        """Pop a span and close it, logging the tool_result output/error when given."""
        active_tool_span = self._active_spans.pop(tool_use_id, None)
        self._pending_task_link_tool_use_ids.discard(tool_use_id)
        if active_tool_span is None:
            return

        if tool_result_block is None:
            active_tool_span.span.end(end_time=end_time)
            return

        output = _serialize_tool_result_output(tool_result_block)
        log_event: dict[str, Any] = {"output": output}
        if getattr(tool_result_block, "is_error", None) is True:
            log_event["error"] = str(output["content"])
        active_tool_span.span.log(**log_event)
        active_tool_span.span.end(end_time=end_time)

    def get_span_export(self, tool_use_id: Any) -> str | None:
        """Return the exported span handle for a tool_use_id, if it is still open."""
        if tool_use_id is None:
            return None

        active_tool_span = self._active_spans.get(str(tool_use_id))
        if active_tool_span is None:
            return None

        return active_tool_span.span.export()
def _match_tool_span_for_handler(candidates: list[_ActiveToolSpan], args: Any) -> _ActiveToolSpan | None:
    """Pick the best candidate span for a handler call.

    Preference order: exact input match, then the only candidate, then a
    candidate with no recorded input, then the first candidate.
    """
    if not candidates:
        return None

    exact_input_matches = [candidate for candidate in candidates if candidate.input == args]
    if exact_input_matches:
        return exact_input_matches[0]

    if len(candidates) == 1:
        return candidates[0]

    for active_tool_span in candidates:
        if active_tool_span.input is None:
            return active_tool_span

    return candidates[0]


def _activate_tool_span_for_handler(tool_name: Any, args: Any) -> _ActiveToolSpan | _NoopActiveToolSpan:
    """Attach a handler call to its stream-opened TOOL span, or a no-op stand-in."""
    tool_span_tracker = getattr(_thread_local, "tool_span_tracker", None)
    if tool_span_tracker is None:
        return _NOOP_ACTIVE_TOOL_SPAN

    return tool_span_tracker.acquire_span_for_handler(tool_name, args) or _NOOP_ACTIVE_TOOL_SPAN


class LLMSpanTracker:
    """Manages LLM span lifecycle for Claude Agent SDK message streams.

    Message flow per turn:
    1. UserMessage (tool results) -> mark the time when next LLM will start
    2. AssistantMessage - LLM response arrives -> create span with the marked start time, ending previous span
    3. ResultMessage - usage metrics -> log to span

    We end the previous span when the next AssistantMessage arrives, using the marked
    start time to ensure sequential spans (no overlapping LLM spans).
    """

    def __init__(self, query_start_time: float | None = None):
        self.current_span: Any | None = None
        self.current_span_export: str | None = None
        self.current_parent_export: str | None = None
        self.current_output: list[dict[str, Any]] | None = None
        # When the upcoming LLM call is considered to have started.
        self.next_start_time: float | None = query_start_time

    def get_next_start_time(self) -> float:
        return self.next_start_time if self.next_start_time is not None else time.time()

    def start_llm_span(
        self,
        message: Any,
        prompt: Any,
        conversation_history: list[dict[str, Any]],
        parent_export: str | None = None,
        start_time: float | None = None,
    ) -> tuple[dict[str, Any] | None, bool]:
        """Start a new LLM span, ending the previous one if it exists.

        Returns (final_content, extended_existing_span): when a consecutive
        AssistantMessage belongs to the same LLM call (no tool results in
        between and same parent), the current span is extended and the merged
        message returned instead of opening a new span.
        """
        current_message = _serialize_assistant_message(message)

        # Same LLM call continuing (no marked restart, same parent): merge into
        # the open span rather than starting a new one.
        if (
            self.current_span
            and self.next_start_time is None
            and self.current_parent_export == parent_export
            and current_message is not None
        ):
            merged_message = _merge_assistant_messages(
                self.current_output[0] if self.current_output else None,
                current_message,
            )
            if merged_message is not None:
                self.current_output = [merged_message]
                self.current_span.log(output=self.current_output)
                return merged_message, True

        resolved_start_time = start_time if start_time is not None else self.get_next_start_time()
        first_token_time = time.time()

        # End the previous span exactly where the new one starts -> sequential spans.
        if self.current_span:
            self.current_span.end(end_time=resolved_start_time)

        final_content, span = _create_llm_span_for_messages(
            [message],
            prompt,
            conversation_history,
            parent=parent_export,
            start_time=resolved_start_time,
        )
        if span is not None:
            span.log(metrics={"time_to_first_token": max(0.0, first_token_time - resolved_start_time)})
        self.current_span = span
        self.current_span_export = span.export() if span else None
        self.current_parent_export = parent_export
        self.current_output = [final_content] if final_content is not None else None
        self.next_start_time = None
        return final_content, False

    def mark_next_llm_start(self) -> None:
        """Mark when the next LLM call will start (after tool results)."""
        self.next_start_time = time.time()

    def log_usage(self, usage_metrics: dict[str, float]) -> None:
        """Log usage metrics to the current LLM span."""
        if self.current_span and usage_metrics:
            self.current_span.log(metrics=usage_metrics)

    def cleanup(self) -> None:
        """End any unclosed spans."""
        if self.current_span:
            self.current_span.end()
        self.current_span = None
        self.current_span_export = None
        self.current_parent_export = None
        self.current_output = None
class TaskEventSpanTracker:
    """Tracks TASK spans for subagent task notification messages.

    Opens a span on the first message carrying a task_id, updates it on
    subsequent messages, and closes it on the terminating notification. Also
    maps tool_use_ids to task spans so nested LLM calls parent correctly.
    """

    def __init__(self, root_span_export: str, tool_tracker: ToolSpanTracker):
        self._root_span_export = root_span_export
        self._tool_tracker = tool_tracker
        self._active_spans: dict[str, Any] = {}
        self._task_span_by_tool_use_id: dict[str, Any] = {}
        # Insertion-ordered task_ids; the last entry is the most recent task.
        self._active_task_order: list[str] = []

    def process(self, message: Any) -> None:
        """Open/update/close the task span addressed by this message's task_id."""
        task_id = getattr(message, "task_id", None)
        if task_id is None:
            return

        task_id = str(task_id)
        message_type = type(message).__name__
        task_span = self._active_spans.get(task_id)

        if task_span is None:
            task_span = start_span(
                name=self._span_name(message, task_id),
                span_attributes={"type": SpanTypeAttribute.TASK},
                metadata=self._metadata(message),
                parent=self._parent_export(message),
            )
            self._active_spans[task_id] = task_span
            self._active_task_order.append(task_id)
            tool_use_id = getattr(message, "tool_use_id", None)
            if tool_use_id is not None:
                tool_use_id = str(tool_use_id)
                # Link the spawning tool call to this task and clear its pending state.
                self._task_span_by_tool_use_id[tool_use_id] = task_span
                self._tool_tracker.mark_task_started(tool_use_id)
        else:
            update: dict[str, Any] = {}
            metadata = self._metadata(message)
            if metadata:
                update["metadata"] = metadata

            output = self._output(message)
            if output is not None:
                update["output"] = output

            if update:
                task_span.log(**update)

        if self._should_end(message_type):
            tool_use_id = getattr(message, "tool_use_id", None)
            if tool_use_id is not None:
                self._task_span_by_tool_use_id.pop(str(tool_use_id), None)
            task_span.end()
            del self._active_spans[task_id]
            self._active_task_order = [
                active_task_id for active_task_id in self._active_task_order if active_task_id != task_id
            ]

    @property
    def active_tool_use_ids(self) -> frozenset[str]:
        return frozenset(self._task_span_by_tool_use_id.keys())

    def cleanup(self) -> None:
        """End all remaining task spans and drop the bookkeeping."""
        for task_id, span in list(self._active_spans.items()):
            span.end()
            del self._active_spans[task_id]
        self._task_span_by_tool_use_id.clear()
        self._active_task_order.clear()

    def parent_export_for_message(self, message: Any, fallback_export: str) -> str:
        """Pick the parent span export for an LLM call triggered by this message."""
        parent_tool_use_id = getattr(message, "parent_tool_use_id", None)
        if parent_tool_use_id is None:
            # A message that itself starts an "Agent" tool belongs to the root,
            # not to whatever task happens to be active.
            if _message_starts_subagent_tool(message):
                return fallback_export
            active_task_export = self._latest_active_task_export()
            return active_task_export or fallback_export

        task_span = self._task_span_by_tool_use_id.get(str(parent_tool_use_id))
        if task_span is not None:
            return task_span.export()

        active_task_export = self._latest_active_task_export()
        return active_task_export or fallback_export

    def _latest_active_task_export(self) -> str | None:
        # Most recently started task wins.
        for task_id in reversed(self._active_task_order):
            task_span = self._active_spans.get(task_id)
            if task_span is not None:
                return task_span.export()

        return None

    def _parent_export(self, message: Any) -> str:
        # Parent a new task under its spawning tool span when one exists.
        return self._tool_tracker.get_span_export(getattr(message, "tool_use_id", None)) or self._root_span_export

    def _span_name(self, message: Any, task_id: str) -> str:
        return getattr(message, "description", None) or getattr(message, "task_type", None) or f"Task {task_id}"

    def _metadata(self, message: Any) -> dict[str, Any]:
        # Only non-None fields make it into metadata.
        metadata = {
            k: v
            for k, v in {
                "task_id": getattr(message, "task_id", None),
                "session_id": getattr(message, "session_id", None),
                "tool_use_id": getattr(message, "tool_use_id", None),
                "task_type": getattr(message, "task_type", None),
                "status": getattr(message, "status", None),
                "last_tool_name": getattr(message, "last_tool_name", None),
                "usage": getattr(message, "usage", None),
            }.items()
            if v is not None
        }
        return metadata

    def _output(self, message: Any) -> dict[str, Any] | None:
        summary = getattr(message, "summary", None)
        output_file = getattr(message, "output_file", None)

        if summary is None and output_file is None:
            return None

        return {
            k: v
            for k, v in {
                "summary": summary,
                "output_file": output_file,
            }.items()
            if v is not None
        }

    def _should_end(self, message_type: str) -> bool:
        # Task notifications are terminal for their task span.
        return message_type == MessageClassName.TASK_NOTIFICATION
"tool_use_id": getattr(message, "tool_use_id", None), + "task_type": getattr(message, "task_type", None), + "status": getattr(message, "status", None), + "last_tool_name": getattr(message, "last_tool_name", None), + "usage": getattr(message, "usage", None), + }.items() + if v is not None + } + return metadata + + def _output(self, message: Any) -> dict[str, Any] | None: + summary = getattr(message, "summary", None) + output_file = getattr(message, "output_file", None) + + if summary is None and output_file is None: + return None + + return { + k: v + for k, v in { + "summary": summary, + "output_file": output_file, + }.items() + if v is not None + } + + def _should_end(self, message_type: str) -> bool: + return message_type == MessageClassName.TASK_NOTIFICATION + + +def _message_starts_subagent_tool(message: Any) -> bool: + if not hasattr(message, "content"): + return False + + for block in message.content: + if type(block).__name__ != BlockClassName.TOOL_USE: + continue + if getattr(block, "name", None) == "Agent": + return True + + return False + + +def _create_client_wrapper_class(original_client_class: Any) -> Any: + """Creates a wrapper class for ClaudeSDKClient that wraps query and receive_response.""" + + class WrappedClaudeSDKClient(Wrapper): + def __init__(self, *args: Any, **kwargs: Any): + # Create the original client instance + client = original_client_class(*args, **kwargs) + super().__init__(client) + self.__client = client + self.__last_prompt: str | None = None + self.__query_start_time: float | None = None + self.__captured_messages: list[dict[str, Any]] | None = None + + async def query(self, *args: Any, **kwargs: Any) -> Any: + """Wrap query to capture the prompt and start time for tracing.""" + # Capture the time when query is called (when LLM call starts) + self.__query_start_time = time.time() + self.__captured_messages = None + + # Capture the prompt for use in receive_response + prompt = args[0] if args else kwargs.get("prompt") + + if 
prompt is not None: + if isinstance(prompt, str): + self.__last_prompt = prompt + elif isinstance(prompt, AsyncIterable): + # AsyncIterable[dict] - wrap it to capture messages as they're yielded + captured: list[dict[str, Any]] = [] + self.__captured_messages = captured + self.__last_prompt = None # Will be set after messages are captured + + async def capturing_wrapper() -> AsyncGenerator[dict[str, Any], None]: + async for msg in prompt: + captured.append(msg) + yield msg + + # Replace the prompt with our capturing wrapper + if args: + args = (capturing_wrapper(),) + args[1:] + else: + kwargs["prompt"] = capturing_wrapper() + else: + self.__last_prompt = str(prompt) + + return await self.__client.query(*args, **kwargs) + + async def receive_response(self) -> AsyncGenerator[Any, None]: + """Wrap receive_response to add tracing. + + Uses start_span context manager which automatically: + - Handles exceptions and logs them as errors + - Sets the span as current so tool calls automatically nest under it + - Manages span lifecycle (start/end) + """ + generator = self.__client.receive_response() + + # Determine the initial input - may be updated later if using async generator + initial_input = self.__last_prompt if self.__last_prompt else None + + with start_span( + name=CLAUDE_AGENT_TASK_SPAN_NAME, + span_attributes={"type": SpanTypeAttribute.TASK}, + input=initial_input, + ) as span: + # If we're capturing async messages, we'll update input after they're consumed + input_needs_update = self.__captured_messages is not None + + final_results: list[dict[str, Any]] = [] + task_events: list[dict[str, Any]] = [] + llm_tracker = LLMSpanTracker(query_start_time=self.__query_start_time) + tool_tracker = ToolSpanTracker() + task_event_span_tracker = TaskEventSpanTracker(span.export(), tool_tracker) + _thread_local.tool_span_tracker = tool_tracker + + try: + async for message in generator: + # Update input from captured async messages (once, after they're consumed) + if 
input_needs_update: + captured_input = self.__captured_messages if self.__captured_messages else [] + if captured_input: + span.log(input=captured_input) + input_needs_update = False + + message_type = type(message).__name__ + + if message_type == MessageClassName.ASSISTANT: + if llm_tracker.current_span and tool_tracker.has_active_spans: + active_subagent_tool_use_ids = ( + task_event_span_tracker.active_tool_use_ids + | tool_tracker.pending_task_link_tool_use_ids + ) + tool_tracker.cleanup( + end_time=llm_tracker.get_next_start_time(), + exclude_tool_use_ids=active_subagent_tool_use_ids, + ) + llm_parent_export = task_event_span_tracker.parent_export_for_message( + message, + span.export(), + ) + final_content, extended_existing_span = llm_tracker.start_llm_span( + message, + self.__last_prompt, + final_results, + parent_export=llm_parent_export, + ) + tool_tracker.start_tool_spans(message, llm_tracker.current_span_export) + if final_content: + if ( + extended_existing_span + and final_results + and final_results[-1].get("role") == "assistant" + ): + final_results[-1] = final_content + else: + final_results.append(final_content) + elif message_type == MessageClassName.USER: + tool_tracker.finish_tool_spans(message) + has_tool_results = False + if hasattr(message, "content"): + has_tool_results = any( + type(block).__name__ == BlockClassName.TOOL_RESULT for block in message.content + ) + content = _serialize_content_blocks(message.content) + final_results.append({"content": content, "role": "user"}) + if has_tool_results: + llm_tracker.mark_next_llm_start() + elif message_type == MessageClassName.RESULT: + if hasattr(message, "usage"): + usage_metrics = _extract_usage_from_result_message(message) + llm_tracker.log_usage(usage_metrics) + + result_metadata = { + k: v + for k, v in { + "num_turns": getattr(message, "num_turns", None), + "session_id": getattr(message, "session_id", None), + }.items() + if v is not None + } + span.log(metadata=result_metadata) + elif 
message_type in SYSTEM_MESSAGE_TYPES: + task_event_span_tracker.process(message) + task_events.append(_serialize_system_message(message)) + + yield message + except asyncio.CancelledError: + # The CancelledError may come from the subprocess transport + # (e.g., anyio internal cleanup when subagents complete) rather + # than a genuine external cancellation. We suppress it here so + # the response stream ends cleanly. If the caller genuinely + # cancelled the task, they still have pending cancellation + # requests that will fire at their next await point. + if final_results: + span.log(output=final_results[-1]) + else: + if final_results: + span.log(output=final_results[-1]) + finally: + if task_events: + span.log(metadata={"task_events": task_events}) + task_event_span_tracker.cleanup() + tool_tracker.cleanup() + llm_tracker.cleanup() + if hasattr(_thread_local, "tool_span_tracker"): + delattr(_thread_local, "tool_span_tracker") + + async def __aenter__(self) -> "WrappedClaudeSDKClient": + await self.__client.__aenter__() + return self + + async def __aexit__(self, *args: Any) -> None: + await self.__client.__aexit__(*args) + + return WrappedClaudeSDKClient + + +def _create_llm_span_for_messages( + messages: list[Any], # List of AssistantMessage objects + prompt: Any, + conversation_history: list[dict[str, Any]], + parent: str | None = None, + start_time: float | None = None, +) -> tuple[dict[str, Any] | None, Any | None]: + """Creates an LLM span for a group of AssistantMessage objects. + + Returns a tuple of (final_content, span): + - final_content: The final message content to add to conversation history + - span: The LLM span object (for logging metrics later) + + Automatically nests under the current span (TASK span from receive_response). + + Note: This is called from within a catch_exceptions block, so errors won't break user code. 
+ """ + if not messages: + return None, None + + last_message = messages[-1] + if type(last_message).__name__ != MessageClassName.ASSISTANT: + return None, None + model = getattr(last_message, "model", None) + input_messages = _build_llm_input(prompt, conversation_history) + + outputs: list[dict[str, Any]] = [] + for msg in messages: + if hasattr(msg, "content"): + content = _serialize_content_blocks(msg.content) + outputs.append({"content": content, "role": "assistant"}) + + llm_span = start_span( + name=ANTHROPIC_MESSAGES_CREATE_SPAN_NAME, + span_attributes={"type": SpanTypeAttribute.LLM}, + input=input_messages, + output=outputs, + metadata={"model": model} if model else None, + parent=parent, + start_time=start_time, + ) + + # Return final message content for conversation history and the span + if hasattr(last_message, "content"): + content = _serialize_content_blocks(last_message.content) + return {"content": content, "role": "assistant"}, llm_span + + return None, llm_span + + +def _serialize_assistant_message(message: Any) -> dict[str, Any] | None: + if not hasattr(message, "content"): + return None + + return {"content": _serialize_content_blocks(message.content), "role": "assistant"} + + +def _merge_assistant_messages(existing_message: dict[str, Any] | None, new_message: dict[str, Any]) -> dict[str, Any]: + if existing_message is None: + return new_message + + existing_content = existing_message.get("content") + new_content = new_message.get("content") + if isinstance(existing_content, list) and isinstance(new_content, list): + return { + "role": "assistant", + "content": [*existing_content, *new_content], + } + + return new_message + + +def _serialize_content_blocks(content: Any) -> Any: + """Converts content blocks to a serializable format with proper type fields. + + Claude Agent SDK uses dataclasses for content blocks, so we use dataclasses.asdict() + for serialization and add the 'type' field based on the class name. 
+ """ + if isinstance(content, list): + result = [] + for block in content: + if dataclasses.is_dataclass(block): + serialized = dataclasses.asdict(block) + + block_type = type(block).__name__ + serialized_type = SERIALIZED_CONTENT_TYPE_BY_BLOCK_CLASS.get(block_type) + if serialized_type is not None: + serialized["type"] = serialized_type + + if block_type == BlockClassName.TOOL_RESULT: + content_value = serialized.get("content") + if isinstance(content_value, list) and len(content_value) == 1: + item = content_value[0] + if ( + isinstance(item, dict) + and item.get("type") == SerializedContentType.TEXT + and SerializedContentType.TEXT in item + ): + serialized["content"] = item[SerializedContentType.TEXT] + + if "is_error" in serialized and serialized["is_error"] is None: + del serialized["is_error"] + else: + serialized = block + + result.append(serialized) + return result + return content + + +def _extract_usage_from_result_message(result_message: Any) -> dict[str, float]: + """Extracts and normalizes usage metrics from a ResultMessage. + + Uses shared Anthropic utilities for consistent metric extraction. + """ + if not hasattr(result_message, "usage"): + return {} + + usage = result_message.usage + if not usage: + return {} + + metrics = extract_anthropic_usage(usage) + if metrics: + metrics = finalize_anthropic_tokens(metrics) + + return metrics + + +def _build_llm_input(prompt: Any, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]] | None: + """Builds the input array for an LLM span from the initial prompt and conversation history. + + Formats input to match Anthropic messages API format for proper UI rendering. 
+ """ + if isinstance(prompt, str): + if len(conversation_history) == 0: + return [{"content": prompt, "role": "user"}] + else: + return [{"content": prompt, "role": "user"}] + conversation_history + + return conversation_history if conversation_history else None diff --git a/py/src/braintrust/wrappers/langchain.py b/py/src/braintrust/wrappers/langchain.py index 7a4c9f6d..9ffb9714 100644 --- a/py/src/braintrust/wrappers/langchain.py +++ b/py/src/braintrust/wrappers/langchain.py @@ -18,7 +18,6 @@ _logger.warning("Failed to import langchain, using stubs") BaseCallbackHandler = object Document = object - AgentAction = object BaseMessage = object LLMResult = object diff --git a/py/src/braintrust/wrappers/pydantic_ai.py b/py/src/braintrust/wrappers/pydantic_ai.py index e3442b85..6dd7ca45 100644 --- a/py/src/braintrust/wrappers/pydantic_ai.py +++ b/py/src/braintrust/wrappers/pydantic_ai.py @@ -327,80 +327,6 @@ def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): return wrapper -def wrap_model_request(original_func: Any) -> Any: - async def wrapper(*args, **kwargs): - input_data, metadata = _build_direct_model_input_and_metadata(args, kwargs) - - with start_span( - name="model_request", - type=SpanTypeAttribute.LLM, - input=input_data, - metadata=metadata, - ) as span: - start_time = time.time() - result = await original_func(*args, **kwargs) - end_time = time.time() - - output = _serialize_model_response(result) - metrics = _extract_response_metrics(result, start_time, end_time) - - span.log(output=output, metrics=metrics) - return result - - return wrapper - - -def wrap_model_request_sync(original_func: Any) -> Any: - def wrapper(*args, **kwargs): - input_data, metadata = _build_direct_model_input_and_metadata(args, kwargs) - - with start_span( - name="model_request_sync", - type=SpanTypeAttribute.LLM, - input=input_data, - metadata=metadata, - ) as span: - start_time = time.time() - result = original_func(*args, **kwargs) - end_time = time.time() - - 
output = _serialize_model_response(result) - metrics = _extract_response_metrics(result, start_time, end_time) - - span.log(output=output, metrics=metrics) - return result - - return wrapper - - -def wrap_model_request_stream(original_func: Any) -> Any: - def wrapper(*args, **kwargs): - input_data, metadata = _build_direct_model_input_and_metadata(args, kwargs) - - return _DirectStreamWrapper( - original_func(*args, **kwargs), - "model_request_stream", - input_data, - metadata, - ) - - return wrapper - - -def wrap_model_request_stream_sync(original_func: Any) -> Any: - def wrapper(*args, **kwargs): - input_data, metadata = _build_direct_model_input_and_metadata(args, kwargs) - - return _DirectStreamWrapperSync( - original_func(*args, **kwargs), - "model_request_stream_sync", - input_data, - metadata, - ) - - return wrapper - - def wrap_model_classes(): """Wrap Model classes to capture internal model requests made by agents.""" try: From 1243d1bcac610b8b9366ea7cd75138baf271421d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Halber?= Date: Tue, 24 Mar 2026 23:22:06 +0000 Subject: [PATCH 08/11] chore: remove probably unused code (need human review) --- py/src/braintrust/framework.py | 10 ++++++---- py/src/braintrust/otel/test_distributed_tracing.py | 1 - py/src/braintrust/otel/test_otel_bt_integration.py | 1 - py/src/braintrust/test_http.py | 5 ----- py/src/braintrust/wrappers/google_genai/__init__.py | 12 ------------ 5 files changed, 6 insertions(+), 23 deletions(-) diff --git a/py/src/braintrust/framework.py b/py/src/braintrust/framework.py index a89710b0..2cb3c5e8 100644 --- a/py/src/braintrust/framework.py +++ b/py/src/braintrust/framework.py @@ -62,13 +62,15 @@ # https://stackoverflow.com/questions/287871/how-do-i-print-colored-text-to-the-terminal class bcolors: -# HEADER = "\033[95m" -# OKBLUE = "\033[94m" -# OKCYAN = "\033[96m" -# OKGREEN = "\033[92m" + # HEADER = "\033[95m" + # OKBLUE = "\033[94m" + # OKCYAN = "\033[96m" + # OKGREEN = "\033[92m" 
WARNING = "\033[93m" FAIL = "\033[91m" ENDC = "\033[0m" + + # BOLD = "\033[1m" # UNDERLINE = "\033[4m" diff --git a/py/src/braintrust/otel/test_distributed_tracing.py b/py/src/braintrust/otel/test_distributed_tracing.py index a2fab2a2..1d9b8c86 100644 --- a/py/src/braintrust/otel/test_distributed_tracing.py +++ b/py/src/braintrust/otel/test_distributed_tracing.py @@ -123,7 +123,6 @@ def test_bt_to_otel_simple_distributed_trace(otel_fixture): assert len(otel_spans) == 1, "Should have 1 OTEL span from Service B" # Get the spans - service_a_exported = bt_spans[0] service_b_exported = otel_spans[0] # Convert OTEL IDs to hex for comparison diff --git a/py/src/braintrust/otel/test_otel_bt_integration.py b/py/src/braintrust/otel/test_otel_bt_integration.py index 579082d9..6792982e 100644 --- a/py/src/braintrust/otel/test_otel_bt_integration.py +++ b/py/src/braintrust/otel/test_otel_bt_integration.py @@ -197,7 +197,6 @@ def test_mixed_otel_bt_tracing_with_otel_first(otel_fixture): s1_trace_id = format(s1.context.trace_id, "032x") s1_span_id = format(s1.context.span_id, "016x") s3_trace_id = format(s3.context.trace_id, "032x") - s3_span_id = format(s3.context.span_id, "016x") assert s1_trace_id == s2["root_span_id"] assert s1_trace_id == s3_trace_id diff --git a/py/src/braintrust/test_http.py b/py/src/braintrust/test_http.py index b9ede8d8..ba5ac282 100644 --- a/py/src/braintrust/test_http.py +++ b/py/src/braintrust/test_http.py @@ -404,17 +404,12 @@ def do_GET(self): session.mount("http://", adapter) errors = [] - success_count = 0 lock = threading.Lock() def make_request(i): - nonlocal success_count try: time.sleep(i * 0.005) # Stagger requests resp = session.get(f"{url}/test{i}") - if resp.status_code == 200: - with lock: - success_count += 1 return resp.status_code except Exception as e: with lock: diff --git a/py/src/braintrust/wrappers/google_genai/__init__.py b/py/src/braintrust/wrappers/google_genai/__init__.py index 3bdae565..dfa3737c 100644 --- 
a/py/src/braintrust/wrappers/google_genai/__init__.py +++ b/py/src/braintrust/wrappers/google_genai/__init__.py @@ -531,15 +531,3 @@ def _aggregate_generate_content_chunks( def clean(obj: dict[str, Any]) -> dict[str, Any]: return {k: v for k, v in obj.items() if v is not None} - - -def get_path(obj: dict[str, Any], path: str, default: Any = None) -> Any | None: - keys = path.split(".") - current = obj - - for key in keys: - if not (isinstance(current, dict) and key in current): - return default - current = current[key] - - return current From ed9c147d72b86dbedbab08ffe36c4350b9728fde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Halber?= Date: Wed, 25 Mar 2026 00:27:41 +0000 Subject: [PATCH 09/11] chore: remove probably unused code (need human review) --- .../braintrust/integrations/adk/test_adk_mcp_tool.py | 6 ------ py/src/braintrust/test_context.py | 8 -------- py/src/braintrust/test_logger.py | 3 --- .../wrappers/test_pydantic_ai_integration.py | 10 ++++------ 4 files changed, 4 insertions(+), 23 deletions(-) diff --git a/py/src/braintrust/integrations/adk/test_adk_mcp_tool.py b/py/src/braintrust/integrations/adk/test_adk_mcp_tool.py index c58ec190..ab7a4d27 100644 --- a/py/src/braintrust/integrations/adk/test_adk_mcp_tool.py +++ b/py/src/braintrust/integrations/adk/test_adk_mcp_tool.py @@ -146,9 +146,6 @@ async def run_async(self, *, args, tool_context): # Verify error was logged to span assert mock_span.log.called - # Check if error was logged - log_calls = [call for call in mock_span.log.call_args_list] - # Should have logged the error @pytest.mark.asyncio @@ -310,9 +307,6 @@ async def test_real_context_loss_with_braintrust_spans(): # Initialize a test logger logger = init_logger(project="test-context-loss") - # Track if we hit the context error - context_error_occurred = False - async def problematic_generator(): """Generator that creates a span and yields, simulating the Flow behavior.""" from braintrust import start_span diff --git 
a/py/src/braintrust/test_context.py b/py/src/braintrust/test_context.py index 313756cf..9c70a987 100644 --- a/py/src/braintrust/test_context.py +++ b/py/src/braintrust/test_context.py @@ -896,8 +896,6 @@ async def generator_with_finally() -> AsyncGenerator[int, None]: yield 1 yield 2 finally: - # What context do we have during cleanup? - cleanup_span = current_span() gen_span.end() # Consumer @@ -1152,14 +1150,11 @@ def test_nested_spans_same_thread(test_logger, with_memory_logger): # Child span with start_span(name="child") as child_span: - child_id = child_span.id - # Verify child is now current assert current_span().id == child_span.id # Grandchild span with start_span(name="grandchild") as grandchild_span: - grandchild_id = grandchild_span.id assert current_span().id == grandchild_span.id # After grandchild closes, child should be current @@ -1227,13 +1222,10 @@ def test_context_with_exception_propagation(test_logger, with_memory_logger): """ Test that context is properly maintained during exception propagation. 
""" - fail_span_id = None def failing_function(): - nonlocal fail_span_id # Use context manager for proper span lifecycle with start_span(name="failing_span") as fail_span: - fail_span_id = fail_span.id # During this context, fail_span should be current assert current_span().id == fail_span.id raise ValueError("Expected error") diff --git a/py/src/braintrust/test_logger.py b/py/src/braintrust/test_logger.py index 7662ad77..39513c1c 100644 --- a/py/src/braintrust/test_logger.py +++ b/py/src/braintrust/test_logger.py @@ -1437,9 +1437,6 @@ def test_span_set_current(with_memory_logger): """Test that span.set_current() makes the span accessible via current_span().""" init_test_logger(__name__) - # Store initial current span - initial_current = braintrust.current_span() - # Start a span that can be set as current (default behavior) span1 = logger.start_span(name="test-span-1") diff --git a/py/src/braintrust/wrappers/test_pydantic_ai_integration.py b/py/src/braintrust/wrappers/test_pydantic_ai_integration.py index 81de2ea4..8ed8b2e0 100644 --- a/py/src/braintrust/wrappers/test_pydantic_ai_integration.py +++ b/py/src/braintrust/wrappers/test_pydantic_ai_integration.py @@ -325,15 +325,13 @@ async def run_multiple_streams(): # First stream async with agent1.run_stream("Count from 1 to 3.") as result1: - full_text1 = "" - async for text in result1.stream_text(delta=True): - full_text1 += text + async for _ in result1.stream_text(delta=True): + pass # Second stream async with agent2.run_stream("Count from 1 to 3.") as result2: - full_text2 = "" - async for text in result2.stream_text(delta=True): - full_text2 += text + async for _ in result2.stream_text(delta=True): + pass return start From 3ae6bfef47424c841abdb7f6498c329efdaa47b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Halber?= Date: Wed, 25 Mar 2026 18:42:11 +0000 Subject: [PATCH 10/11] chore: wrongly removed code --- py/src/braintrust/db_fields.py | 9 +++++++++ py/src/braintrust/framework.py | 17 
+++++++++++++---- pyproject.toml | 2 +- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/py/src/braintrust/db_fields.py b/py/src/braintrust/db_fields.py index 6fd95df4..a89b9710 100644 --- a/py/src/braintrust/db_fields.py +++ b/py/src/braintrust/db_fields.py @@ -1,12 +1,21 @@ TRANSACTION_ID_FIELD = "_xact_id" OBJECT_DELETE_FIELD = "_object_delete" +CREATED_FIELD = "created" +ID_FIELD = "id" IS_MERGE_FIELD = "_is_merge" +MERGE_PATHS_FIELD = "_merge_paths" +ARRAY_DELETE_FIELD = "_array_delete" AUDIT_SOURCE_FIELD = "_audit_source" AUDIT_METADATA_FIELD = "_audit_metadata" VALID_SOURCES = ["app", "api", "external"] +PARENT_ID_FIELD = "_parent_id" + +ASYNC_SCORING_CONTROL_FIELD = "_async_scoring_control" +SKIP_ASYNC_SCORING_FIELD = "_skip_async_scoring" + # Keys that identify which object (experiment, dataset, project logs, etc.) a row belongs to. OBJECT_ID_KEYS = ( "experiment_id", diff --git a/py/src/braintrust/framework.py b/py/src/braintrust/framework.py index 2cb3c5e8..040b869a 100644 --- a/py/src/braintrust/framework.py +++ b/py/src/braintrust/framework.py @@ -69,10 +69,8 @@ class bcolors: WARNING = "\033[93m" FAIL = "\033[91m" ENDC = "\033[0m" - - -# BOLD = "\033[1m" -# UNDERLINE = "\033[4m" + # BOLD = "\033[1m" + # UNDERLINE = "\033[4m" @dataclasses.dataclass @@ -230,6 +228,17 @@ def parameters(self) -> ValidatedParameters | None: """ +class EvalScorerArgs(SerializableDataClass, Generic[Input, Output]): + """ + Arguments passed to an evaluator scorer. This includes the input, expected output, actual output, and metadata. 
+ """ + + input: Input + output: Output + expected: Output | None = None + metadata: Metadata | None = None + + OneOrMoreScores = Union[float, int, bool, None, Score, list[Score]] diff --git a/pyproject.toml b/pyproject.toml index 31230cc9..212d6046 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,5 +23,5 @@ addopts = "--durations=3 --durations-min=0.1" [tool.vulture] paths = ["py/src"] -ignore_names = ["with_simulate_login", "reset_id_generator_state", "dataset_record_id"] # pytest fixtures and deprecated-but-public API parameters +ignore_names = ["with_simulate_login", "reset_id_generator_state", "dataset_record_id", "EvalScorerArgs", "CREATED_FIELD", "ID_FIELD", "MERGE_PATHS_FIELD", "ARRAY_DELETE_FIELD", "PARENT_ID_FIELD", "ASYNC_SCORING_CONTROL_FIELD", "SKIP_ASYNC_SCORING_FIELD"] # pytest fixtures, deprecated-but-public API, and protocol field constants min_confidence = 100 From a6369c04ea412b532bcd209f166d46553064429a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Halber?= Date: Thu, 26 Mar 2026 18:15:16 -0700 Subject: [PATCH 11/11] chore:remove unused import --- py/src/braintrust/wrappers/langchain.py | 1 - 1 file changed, 1 deletion(-) diff --git a/py/src/braintrust/wrappers/langchain.py b/py/src/braintrust/wrappers/langchain.py index 9ffb9714..71835e77 100644 --- a/py/src/braintrust/wrappers/langchain.py +++ b/py/src/braintrust/wrappers/langchain.py @@ -11,7 +11,6 @@ try: from langchain_classic.callbacks.base import BaseCallbackHandler from langchain_classic.schema import Document - from langchain_classic.schema.agent import AgentAction from langchain_classic.schema.messages import BaseMessage from langchain_classic.schema.output import LLMResult except ImportError: