diff --git a/py/noxfile.py b/py/noxfile.py index 5e51312b..7a79ab64 100644 --- a/py/noxfile.py +++ b/py/noxfile.py @@ -255,7 +255,7 @@ def test_dspy(session, version): session.skip("dspy latest requires Python >= 3.10 (litellm dependency)") _install_test_deps(session) _install(session, "dspy", version) - _run_tests(session, f"{WRAPPER_DIR}/test_dspy.py") + _run_tests(session, f"{INTEGRATION_DIR}/dspy/test_dspy.py") @nox.session() diff --git a/py/src/braintrust/auto.py b/py/src/braintrust/auto.py index 48276ea3..c71dd140 100644 --- a/py/src/braintrust/auto.py +++ b/py/src/braintrust/auto.py @@ -12,6 +12,7 @@ AgnoIntegration, AnthropicIntegration, ClaudeAgentSDKIntegration, + DSPyIntegration, GoogleGenAIIntegration, ) @@ -125,7 +126,7 @@ def auto_instrument( if claude_agent_sdk: results["claude_agent_sdk"] = _instrument_integration(ClaudeAgentSDKIntegration) if dspy: - results["dspy"] = _instrument_dspy() + results["dspy"] = _instrument_integration(DSPyIntegration) if adk: results["adk"] = _instrument_integration(ADKIntegration) @@ -160,11 +161,3 @@ def _instrument_pydantic_ai() -> bool: return setup_pydantic_ai() return False - - -def _instrument_dspy() -> bool: - with _try_patch(): - from braintrust.wrappers.dspy import patch_dspy - - return patch_dspy() - return False diff --git a/py/src/braintrust/integrations/__init__.py b/py/src/braintrust/integrations/__init__.py index db4f048f..4d02c323 100644 --- a/py/src/braintrust/integrations/__init__.py +++ b/py/src/braintrust/integrations/__init__.py @@ -2,6 +2,7 @@ from .agno import AgnoIntegration from .anthropic import AnthropicIntegration from .claude_agent_sdk import ClaudeAgentSDKIntegration +from .dspy import DSPyIntegration from .google_genai import GoogleGenAIIntegration @@ -10,5 +11,6 @@ "AgnoIntegration", "AnthropicIntegration", "ClaudeAgentSDKIntegration", + "DSPyIntegration", "GoogleGenAIIntegration", ] diff --git a/py/src/braintrust/integrations/auto_test_scripts/test_auto_dspy.py 
b/py/src/braintrust/integrations/auto_test_scripts/test_auto_dspy.py index 924ceb46..bb6de8c0 100644 --- a/py/src/braintrust/integrations/auto_test_scripts/test_auto_dspy.py +++ b/py/src/braintrust/integrations/auto_test_scripts/test_auto_dspy.py @@ -7,16 +7,15 @@ import dspy from braintrust.auto import auto_instrument -from braintrust.wrappers.dspy import BraintrustDSpyCallback +from braintrust.integrations.dspy import BraintrustDSpyCallback # 1. Verify not patched initially -assert not getattr(dspy, "__braintrust_wrapped__", False) +assert not getattr(dspy.configure, "__braintrust_patched_dspy_configure__", False) # 2. Instrument results = auto_instrument() assert results.get("dspy") == True -assert getattr(dspy, "__braintrust_wrapped__", False) # 3. Idempotent results2 = auto_instrument() diff --git a/py/src/braintrust/integrations/base.py b/py/src/braintrust/integrations/base.py index 690e6c22..8abd3320 100644 --- a/py/src/braintrust/integrations/base.py +++ b/py/src/braintrust/integrations/base.py @@ -114,13 +114,27 @@ def patch_marker_attr(cls) -> str: @classmethod def mark_patched(cls, obj: Any) -> None: """Mark a wrapped target so future patch attempts are idempotent.""" - setattr(obj, cls.patch_marker_attr(), True) + try: + setattr(obj, cls.patch_marker_attr(), True) + except AttributeError: + # Some objects (e.g. bound methods) don't support setattr. + # Callers that need a fallback location (like ``patch()``) re-check + # the marker afterwards and, if it is missing, store it elsewhere.
+ pass @classmethod def is_patched(cls, module: Any | None, version: str | None, *, target: Any | None = None) -> bool: """Return whether this patcher's target has already been instrumented.""" + marker = cls.patch_marker_attr() resolved_target = cls.resolve_target(module, version, target=target) - return bool(resolved_target is not None and getattr(resolved_target, cls.patch_marker_attr(), False)) + if resolved_target is not None and getattr(resolved_target, marker, False): + return True + # Fall back to checking the root — the marker may live there when the + # resolved target does not support setattr (e.g. bound methods). + root = cls.resolve_root(module, version, target=target) + if root is not None and root is not resolved_target and getattr(root, marker, False): + return True + return False @classmethod def patch(cls, module: Any | None, version: str | None, *, target: Any | None = None) -> bool: @@ -130,11 +144,16 @@ def patch(cls, module: Any | None, version: str | None, *, target: Any | None = return False wrap_function_wrapper(root, cls.target_path, cls.wrapper) - resolved_target = cls.resolve_target(module, version, target=target) + resolved_target = _resolve_attr_path(root, cls.target_path) if resolved_target is None: return False + marker = cls.patch_marker_attr() cls.mark_patched(resolved_target) + # If mark_patched could not store the marker on the target (e.g. bound + # methods), store it on the root so is_patched() can still find it. 
+ if not getattr(resolved_target, marker, False): + setattr(root, marker, True) return True @classmethod diff --git a/py/src/braintrust/integrations/dspy/__init__.py b/py/src/braintrust/integrations/dspy/__init__.py new file mode 100644 index 00000000..9d650198 --- /dev/null +++ b/py/src/braintrust/integrations/dspy/__init__.py @@ -0,0 +1,12 @@ +"""Braintrust integration for DSPy.""" + +from .integration import DSPyIntegration +from .patchers import patch_dspy +from .tracing import BraintrustDSpyCallback + + +__all__ = [ + "BraintrustDSpyCallback", + "DSPyIntegration", + "patch_dspy", +] diff --git a/py/src/braintrust/wrappers/cassettes/test_dspy_callback.yaml b/py/src/braintrust/integrations/dspy/cassettes/test_dspy_callback.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_dspy_callback.yaml rename to py/src/braintrust/integrations/dspy/cassettes/test_dspy_callback.yaml diff --git a/py/src/braintrust/integrations/dspy/integration.py b/py/src/braintrust/integrations/dspy/integration.py new file mode 100644 index 00000000..8129c86c --- /dev/null +++ b/py/src/braintrust/integrations/dspy/integration.py @@ -0,0 +1,13 @@ +"""DSPy integration — orchestration class and setup entry-point.""" + +from braintrust.integrations.base import BaseIntegration + +from .patchers import DSPyConfigurePatcher + + +class DSPyIntegration(BaseIntegration): + """Braintrust instrumentation for DSPy.""" + + name = "dspy" + import_names = ("dspy",) + patchers = (DSPyConfigurePatcher,) diff --git a/py/src/braintrust/integrations/dspy/patchers.py b/py/src/braintrust/integrations/dspy/patchers.py new file mode 100644 index 00000000..16fcbb66 --- /dev/null +++ b/py/src/braintrust/integrations/dspy/patchers.py @@ -0,0 +1,43 @@ +"""DSPy patchers — one patcher per coherent patch target.""" + +from braintrust.integrations.base import FunctionWrapperPatcher + +from .tracing import _configure_wrapper + + +class DSPyConfigurePatcher(FunctionWrapperPatcher): + """Patch 
``dspy.configure`` to auto-add ``BraintrustDSpyCallback``.""" + + name = "dspy.configure" + target_path = "configure" + wrapper = _configure_wrapper + + +# --------------------------------------------------------------------------- +# Public helper +# --------------------------------------------------------------------------- + + +def patch_dspy() -> bool: + """ + Patch DSPy to automatically add Braintrust tracing callback. + + After calling this, all calls to dspy.configure() will automatically + include the BraintrustDSpyCallback. + + Returns: + True if DSPy was patched (or already patched), False if DSPy is not installed. + + Example: + ```python + import braintrust + braintrust.patch_dspy() + + import dspy + lm = dspy.LM("openai/gpt-4o-mini") + dspy.configure(lm=lm) # BraintrustDSpyCallback auto-added! + ``` + """ + from .integration import DSPyIntegration + + return DSPyIntegration.setup() diff --git a/py/src/braintrust/wrappers/test_dspy.py b/py/src/braintrust/integrations/dspy/test_dspy.py similarity index 82% rename from py/src/braintrust/wrappers/test_dspy.py rename to py/src/braintrust/integrations/dspy/test_dspy.py index edbc6334..f6a0f1da 100644 --- a/py/src/braintrust/wrappers/test_dspy.py +++ b/py/src/braintrust/integrations/dspy/test_dspy.py @@ -5,8 +5,8 @@ import dspy import pytest from braintrust import logger +from braintrust.integrations.dspy import BraintrustDSpyCallback from braintrust.test_helpers import init_test_logger -from braintrust.wrappers.dspy import BraintrustDSpyCallback from braintrust.wrappers.test_utils import run_in_subprocess, verify_autoinstrument_script @@ -63,17 +63,14 @@ def test_dspy_callback(memory_logger): class TestPatchDSPy: - """Tests for patch_dspy() / unpatch_dspy().""" + """Tests for patch_dspy().""" - def test_patch_dspy_sets_wrapped_flag(self): - """patch_dspy() should set __braintrust_wrapped__ on dspy module.""" + def test_patch_dspy_patches_configure(self): + """patch_dspy() should patch dspy.configure via the 
integration patcher.""" result = run_in_subprocess(""" - dspy = __import__("dspy") - from braintrust.wrappers.dspy import patch_dspy - - assert not hasattr(dspy, "__braintrust_wrapped__") - patch_dspy() - assert hasattr(dspy, "__braintrust_wrapped__") + from braintrust.integrations.dspy import patch_dspy + result = patch_dspy() + assert result, "patch_dspy() should return True" print("SUCCESS") """) assert result.returncode == 0, f"Failed: {result.stderr}" @@ -82,7 +79,7 @@ def test_patch_dspy_sets_wrapped_flag(self): def test_patch_dspy_wraps_configure(self): """After patch_dspy(), dspy.configure() should auto-add BraintrustDSpyCallback.""" result = run_in_subprocess(""" - from braintrust.wrappers.dspy import patch_dspy, BraintrustDSpyCallback + from braintrust.integrations.dspy import patch_dspy, BraintrustDSpyCallback patch_dspy() import dspy @@ -103,7 +100,7 @@ def test_patch_dspy_wraps_configure(self): def test_patch_dspy_preserves_existing_callbacks(self): """patch_dspy() should preserve user-provided callbacks.""" result = run_in_subprocess(""" - from braintrust.wrappers.dspy import patch_dspy, BraintrustDSpyCallback + from braintrust.integrations.dspy import patch_dspy, BraintrustDSpyCallback patch_dspy() import dspy @@ -132,7 +129,7 @@ class MyCallback(BaseCallback): def test_patch_dspy_does_not_duplicate_callback(self): """patch_dspy() should not add duplicate BraintrustDSpyCallback.""" result = run_in_subprocess(""" - from braintrust.wrappers.dspy import patch_dspy, BraintrustDSpyCallback + from braintrust.integrations.dspy import patch_dspy, BraintrustDSpyCallback patch_dspy() import dspy @@ -155,7 +152,7 @@ def test_patch_dspy_does_not_duplicate_callback(self): def test_patch_dspy_idempotent(self): """Multiple patch_dspy() calls should be safe.""" result = run_in_subprocess(""" - from braintrust.wrappers.dspy import patch_dspy + from braintrust.integrations.dspy import patch_dspy import dspy patch_dspy() @@ -169,6 +166,17 @@ def 
test_patch_dspy_idempotent(self): assert result.returncode == 0, f"Failed: {result.stderr}" assert "SUCCESS" in result.stdout + def test_legacy_wrapper_import_still_works(self): + """The old braintrust.wrappers.dspy import path should still work.""" + result = run_in_subprocess(""" + from braintrust.wrappers.dspy import BraintrustDSpyCallback, patch_dspy + assert BraintrustDSpyCallback is not None + assert callable(patch_dspy) + print("SUCCESS") + """) + assert result.returncode == 0, f"Failed: {result.stderr}" + assert "SUCCESS" in result.stdout + class TestAutoInstrumentDSPy: """Tests for auto_instrument() with DSPy.""" diff --git a/py/src/braintrust/integrations/dspy/tracing.py b/py/src/braintrust/integrations/dspy/tracing.py new file mode 100644 index 00000000..dd5131b0 --- /dev/null +++ b/py/src/braintrust/integrations/dspy/tracing.py @@ -0,0 +1,367 @@ +"""DSPy-specific callback, span creation, and metadata extraction.""" + +from typing import Any + +from braintrust.logger import current_span, start_span +from braintrust.span_types import SpanTypeAttribute + + +try: + from dspy.utils.callback import BaseCallback + + _HAS_DSPY = True +except ImportError: + _HAS_DSPY = False + BaseCallback = object # type: ignore[assignment,misc] + + +class BraintrustDSpyCallback(BaseCallback): + """Callback handler that logs DSPy execution traces to Braintrust. + + This callback integrates DSPy with Braintrust's observability platform, automatically + logging language model calls, module executions, tool invocations, and evaluations. 
+ + Logged information includes: + - Input parameters and output results + - Execution latency + - Error information when exceptions occur + - Hierarchical span relationships for nested operations + + Basic Example: + ```python + import dspy + from braintrust import init_logger + from braintrust.integrations.dspy import BraintrustDSpyCallback + + # Initialize Braintrust + init_logger(project="dspy-example") + + # Configure DSPy with callback + lm = dspy.LM("openai/gpt-4o-mini") + dspy.configure(lm=lm, callbacks=[BraintrustDSpyCallback()]) + + # Use DSPy - execution is automatically logged + predictor = dspy.Predict("question -> answer") + result = predictor(question="What is 2+2?") + ``` + + Advanced Example with LiteLLM Patching: + For additional detailed token metrics from LiteLLM's wrapper, patch before importing DSPy + and disable DSPy's disk cache: + + ```python + from braintrust.wrappers.litellm import patch_litellm + patch_litellm() + + import dspy + from braintrust import init_logger + from braintrust.integrations.dspy import BraintrustDSpyCallback + + init_logger(project="dspy-example") + + # Disable disk cache to ensure LiteLLM calls are traced + dspy.configure_cache(enable_disk_cache=False, enable_memory_cache=True) + + lm = dspy.LM("openai/gpt-4o-mini") + dspy.configure(lm=lm, callbacks=[BraintrustDSpyCallback()]) + ``` + + The callback creates Braintrust spans for: + - DSPy module executions (Predict, ChainOfThought, ReAct, etc.) + - LLM calls with latency metrics + - Tool calls + - Evaluation runs + + For detailed token usage and cost metrics, use LiteLLM patching (see Advanced Example above). + The patched LiteLLM wrapper will create additional "Completion" spans with comprehensive metrics. + + Spans are automatically nested based on the execution hierarchy. + """ + + def __init__(self): + """Initialize the Braintrust DSPy callback handler.""" + if not _HAS_DSPY: + raise ImportError("DSPy is not installed. 
Please install it with: pip install dspy") + super().__init__() + # Map call_id to span objects for proper nesting + self._spans: dict[str, Any] = {} + + def on_lm_start( + self, + call_id: str, + instance: Any, + inputs: dict[str, Any], + ): + """Log the start of a language model call. + + Args: + call_id: Unique identifier for this call + instance: The LM instance being called + inputs: Input parameters to the LM + """ + metadata = {} + if hasattr(instance, "model"): + metadata["model"] = instance.model + if hasattr(instance, "provider"): + metadata["provider"] = str(instance.provider) + + for key in ["temperature", "max_tokens", "top_p", "top_k", "stop"]: + if key in inputs: + metadata[key] = inputs[key] + + parent = current_span() + parent_export = parent.export() if parent else None + + span = start_span( + name="dspy.lm", + input=inputs, + metadata=metadata, + parent=parent_export, + ) + span.set_current() + self._spans[call_id] = span + + def on_lm_end( + self, + call_id: str, + outputs: dict[str, Any] | None, + exception: Exception | None = None, + ): + """Log the end of a language model call. + + Args: + call_id: Unique identifier for this call + outputs: Output from the LM, or None if there was an exception + exception: Exception raised during execution, if any + """ + span = self._spans.pop(call_id, None) + if not span: + return + + try: + log_data = {} + if exception: + log_data["error"] = exception + if outputs is not None: + log_data["output"] = outputs + + if log_data: + span.log(**log_data) + finally: + span.unset_current() + span.end() + + def on_module_start( + self, + call_id: str, + instance: Any, + inputs: dict[str, Any], + ): + """Log the start of a DSPy module execution. 
+ + Args: + call_id: Unique identifier for this call + instance: The Module instance being called + inputs: Input parameters to the module's forward() method + """ + cls = instance.__class__ + cls_name = cls.__name__ + module_name = f"{cls.__module__}.{cls_name}" + + parent = current_span() + parent_export = parent.export() if parent else None + + span = start_span( + name=f"dspy.module.{cls_name}", + input=inputs, + metadata={"module_class": module_name}, + parent=parent_export, + ) + span.set_current() + self._spans[call_id] = span + + def on_module_end( + self, + call_id: str, + outputs: Any | None, + exception: Exception | None = None, + ): + """Log the end of a DSPy module execution. + + Args: + call_id: Unique identifier for this call + outputs: Output from the module, or None if there was an exception + exception: Exception raised during execution, if any + """ + span = self._spans.pop(call_id, None) + if not span: + return + + try: + log_data = {} + if exception: + log_data["error"] = exception + if outputs is not None: + if hasattr(outputs, "toDict"): + output_dict = outputs.toDict() + elif hasattr(outputs, "__dict__"): + output_dict = outputs.__dict__ + else: + output_dict = outputs + log_data["output"] = output_dict + + if log_data: + span.log(**log_data) + finally: + span.unset_current() + span.end() + + def on_tool_start( + self, + call_id: str, + instance: Any, + inputs: dict[str, Any], + ): + """Log the start of a tool invocation. 
+ + Args: + call_id: Unique identifier for this call + instance: The Tool instance being called + inputs: Input parameters to the tool + """ + tool_name = "unknown" + if hasattr(instance, "name"): + tool_name = instance.name + elif hasattr(instance, "__name__"): + tool_name = instance.__name__ + elif hasattr(instance, "func") and hasattr(instance.func, "__name__"): + tool_name = instance.func.__name__ + + parent = current_span() + parent_export = parent.export() if parent else None + + span = start_span( + name=tool_name, + span_attributes={"type": SpanTypeAttribute.TOOL}, + input=inputs, + parent=parent_export, + ) + span.set_current() + self._spans[call_id] = span + + def on_tool_end( + self, + call_id: str, + outputs: dict[str, Any] | None, + exception: Exception | None = None, + ): + """Log the end of a tool invocation. + + Args: + call_id: Unique identifier for this call + outputs: Output from the tool, or None if there was an exception + exception: Exception raised during execution, if any + """ + span = self._spans.pop(call_id, None) + if not span: + return + + try: + log_data = {} + if exception: + log_data["error"] = exception + if outputs is not None: + log_data["output"] = outputs + + if log_data: + span.log(**log_data) + finally: + span.unset_current() + span.end() + + def on_evaluate_start( + self, + call_id: str, + instance: Any, + inputs: dict[str, Any], + ): + """Log the start of an evaluation run. 
+ + Args: + call_id: Unique identifier for this call + instance: The Evaluate instance + inputs: Input parameters to the evaluation + """ + metadata = {} + if hasattr(instance, "metric") and instance.metric: + if hasattr(instance.metric, "__name__"): + metadata["metric"] = instance.metric.__name__ + if hasattr(instance, "num_threads"): + metadata["num_threads"] = instance.num_threads + + parent = current_span() + parent_export = parent.export() if parent else None + + span = start_span( + name="dspy.evaluate", + input=inputs, + metadata=metadata, + parent=parent_export, + ) + span.set_current() + self._spans[call_id] = span + + def on_evaluate_end( + self, + call_id: str, + outputs: Any | None, + exception: Exception | None = None, + ): + """Log the end of an evaluation run. + + Args: + call_id: Unique identifier for this call + outputs: Output from the evaluation, or None if there was an exception + exception: Exception raised during execution, if any + """ + span = self._spans.pop(call_id, None) + if not span: + return + + try: + log_data = {} + if exception: + log_data["error"] = exception + if outputs is not None: + log_data["output"] = outputs + if isinstance(outputs, dict): + metrics = {} + for key in ["accuracy", "score", "total", "correct"]: + if key in outputs: + try: + metrics[key] = float(outputs[key]) + except (ValueError, TypeError): + pass + if metrics: + log_data["metrics"] = metrics + + if log_data: + span.log(**log_data) + finally: + span.unset_current() + span.end() + + +def _configure_wrapper(wrapped, instance, args, kwargs): + """Wrapper for dspy.configure that auto-adds BraintrustDSpyCallback.""" + callbacks = kwargs.get("callbacks") + if callbacks is None: + callbacks = [] + else: + callbacks = list(callbacks) + + if not any(isinstance(cb, BraintrustDSpyCallback) for cb in callbacks): + callbacks.append(BraintrustDSpyCallback()) + + kwargs["callbacks"] = callbacks + return wrapped(*args, **kwargs) diff --git 
a/py/src/braintrust/wrappers/dspy.py b/py/src/braintrust/wrappers/dspy.py index 713b3cfe..bb0f5e2a 100644 --- a/py/src/braintrust/wrappers/dspy.py +++ b/py/src/braintrust/wrappers/dspy.py @@ -1,467 +1,6 @@ -""" -Braintrust integration for DSPy. +"""Backward-compatible re-exports — implementation lives in braintrust.integrations.dspy.""" -This module provides the BraintrustDSpyCallback class for logging DSPy execution traces to Braintrust. +from braintrust.integrations.dspy import BraintrustDSpyCallback, patch_dspy -Basic Usage: - ```python - import dspy - from braintrust import init_logger - from braintrust.wrappers.dspy import BraintrustDSpyCallback - - # Initialize Braintrust logger - init_logger(project="my-dspy-project") - - # Configure DSPy with Braintrust callback - lm = dspy.LM("openai/gpt-4o-mini") - dspy.configure(lm=lm, callbacks=[BraintrustDSpyCallback()]) - - # Use DSPy as normal - all execution will be logged to Braintrust - cot = dspy.ChainOfThought("question -> answer") - result = cot(question="What is the capital of France?") - ``` - -Advanced Usage with LiteLLM Patching: - For more detailed token metrics and tracing, you can patch LiteLLM before importing DSPy. - Note: You must disable DSPy's disk cache to ensure all LLM calls are traced. 
- - ```python - # IMPORTANT: Patch LiteLLM BEFORE importing DSPy - from braintrust.wrappers.litellm import patch_litellm - patch_litellm() - - import dspy - from braintrust import init_logger - from braintrust.wrappers.dspy import BraintrustDSpyCallback - - logger = init_logger(project="my-project") - - # Disable disk cache to ensure LiteLLM wrapper is always called - dspy.configure_cache( - enable_disk_cache=False, - enable_memory_cache=True, # Keep memory cache for performance - ) - - lm = dspy.LM("openai/gpt-4o-mini") - dspy.configure(lm=lm, callbacks=[BraintrustDSpyCallback()]) - ``` -""" - -from typing import Any - -from braintrust.logger import current_span, start_span -from braintrust.span_types import SpanTypeAttribute -from wrapt import wrap_function_wrapper - - -# Note: For detailed token and cost metrics, use patch_litellm() before importing DSPy. -# The DSPy callback focuses on execution flow and span hierarchy. - -try: - from dspy.utils.callback import BaseCallback -except ImportError: - raise ImportError("DSPy is not installed. Please install it with: pip install dspy") __all__ = ["BraintrustDSpyCallback", "patch_dspy"] - - -class BraintrustDSpyCallback(BaseCallback): - """Callback handler that logs DSPy execution traces to Braintrust. - - This callback integrates DSPy with Braintrust's observability platform, automatically - logging language model calls, module executions, tool invocations, and evaluations. 
- - Logged information includes: - - Input parameters and output results - - Execution latency - - Error information when exceptions occur - - Hierarchical span relationships for nested operations - - Basic Example: - ```python - import dspy - from braintrust import init_logger - from braintrust.wrappers.dspy import BraintrustDSpyCallback - - # Initialize Braintrust - init_logger(project="dspy-example") - - # Configure DSPy with callback - lm = dspy.LM("openai/gpt-4o-mini") - dspy.configure(lm=lm, callbacks=[BraintrustDSpyCallback()]) - - # Use DSPy - execution is automatically logged - predictor = dspy.Predict("question -> answer") - result = predictor(question="What is 2+2?") - ``` - - Advanced Example with LiteLLM Patching: - For additional detailed token metrics from LiteLLM's wrapper, patch before importing DSPy - and disable DSPy's disk cache: - - ```python - from braintrust.wrappers.litellm import patch_litellm - patch_litellm() - - import dspy - from braintrust import init_logger - from braintrust.wrappers.dspy import BraintrustDSpyCallback - - init_logger(project="dspy-example") - - # Disable disk cache to ensure LiteLLM calls are traced - dspy.configure_cache(enable_disk_cache=False, enable_memory_cache=True) - - lm = dspy.LM("openai/gpt-4o-mini") - dspy.configure(lm=lm, callbacks=[BraintrustDSpyCallback()]) - ``` - - The callback creates Braintrust spans for: - - DSPy module executions (Predict, ChainOfThought, ReAct, etc.) - - LLM calls with latency metrics - - Tool calls - - Evaluation runs - - For detailed token usage and cost metrics, use LiteLLM patching (see Advanced Example above). - The patched LiteLLM wrapper will create additional "Completion" spans with comprehensive metrics. - - Spans are automatically nested based on the execution hierarchy. 
- """ - - def __init__(self): - """Initialize the Braintrust DSPy callback handler.""" - super().__init__() - # Map call_id to span objects for proper nesting - self._spans: dict[str, Any] = {} - - def on_lm_start( - self, - call_id: str, - instance: Any, - inputs: dict[str, Any], - ): - """Log the start of a language model call. - - Args: - call_id: Unique identifier for this call - instance: The LM instance being called - inputs: Input parameters to the LM - """ - # Extract metadata from the LM instance and inputs - metadata = {} - if hasattr(instance, "model"): - metadata["model"] = instance.model - if hasattr(instance, "provider"): - metadata["provider"] = str(instance.provider) - - # Extract common LM parameters from inputs - for key in ["temperature", "max_tokens", "top_p", "top_k", "stop"]: - if key in inputs: - metadata[key] = inputs[key] - - # Get the current active span to establish parent-child relationship - parent = current_span() - parent_export = parent.export() if parent else None - - span = start_span( - name="dspy.lm", - input=inputs, - metadata=metadata, - parent=parent_export, - ) - # Manually set as current span so children can find it - span.set_current() - self._spans[call_id] = span - - def on_lm_end( - self, - call_id: str, - outputs: dict[str, Any] | None, - exception: Exception | None = None, - ): - """Log the end of a language model call. - - Args: - call_id: Unique identifier for this call - outputs: Output from the LM, or None if there was an exception - exception: Exception raised during execution, if any - """ - span = self._spans.pop(call_id, None) - if not span: - return - - try: - log_data = {} - if exception: - log_data["error"] = exception - if outputs: - log_data["output"] = outputs - - if log_data: - span.log(**log_data) - finally: - span.unset_current() - span.end() - - def on_module_start( - self, - call_id: str, - instance: Any, - inputs: dict[str, Any], - ): - """Log the start of a DSPy module execution. 
- - Args: - call_id: Unique identifier for this call - instance: The Module instance being called - inputs: Input parameters to the module's forward() method - """ - # Get module name - module_name = instance.__class__.__name__ - if hasattr(instance, "__class__") and hasattr(instance.__class__, "__module__"): - module_name = f"{instance.__class__.__module__}.{instance.__class__.__name__}" - - # Get the current active span to establish parent-child relationship - parent = current_span() - parent_export = parent.export() if parent else None - - span = start_span( - name=f"dspy.module.{instance.__class__.__name__}", - input=inputs, - metadata={"module_class": module_name}, - parent=parent_export, - ) - # Manually set as current span so children can find it - span.set_current() - self._spans[call_id] = span - - def on_module_end( - self, - call_id: str, - outputs: Any | None, - exception: Exception | None = None, - ): - """Log the end of a DSPy module execution. - - Args: - call_id: Unique identifier for this call - outputs: Output from the module, or None if there was an exception - exception: Exception raised during execution, if any - """ - span = self._spans.pop(call_id, None) - if not span: - return - - try: - log_data = {} - if exception: - log_data["error"] = exception - if outputs is not None: - # Convert DSPy Prediction objects to dict for logging - if hasattr(outputs, "toDict"): - output_dict = outputs.toDict() - elif hasattr(outputs, "__dict__"): - output_dict = outputs.__dict__ - else: - output_dict = outputs - log_data["output"] = output_dict - - if log_data: - span.log(**log_data) - finally: - span.unset_current() - span.end() - - def on_tool_start( - self, - call_id: str, - instance: Any, - inputs: dict[str, Any], - ): - """Log the start of a tool invocation. 
- - Args: - call_id: Unique identifier for this call - instance: The Tool instance being called - inputs: Input parameters to the tool - """ - # Get tool name - tool_name = "unknown" - if hasattr(instance, "name"): - tool_name = instance.name - elif hasattr(instance, "__name__"): - tool_name = instance.__name__ - elif hasattr(instance, "func") and hasattr(instance.func, "__name__"): - tool_name = instance.func.__name__ - - # Get the current active span to establish parent-child relationship - parent = current_span() - parent_export = parent.export() if parent else None - - span = start_span( - name=tool_name, - span_attributes={"type": SpanTypeAttribute.TOOL}, - input=inputs, - parent=parent_export, - ) - # Manually set as current span so children can find it - span.set_current() - self._spans[call_id] = span - - def on_tool_end( - self, - call_id: str, - outputs: dict[str, Any] | None, - exception: Exception | None = None, - ): - """Log the end of a tool invocation. - - Args: - call_id: Unique identifier for this call - outputs: Output from the tool, or None if there was an exception - exception: Exception raised during execution, if any - """ - span = self._spans.pop(call_id, None) - if not span: - return - - try: - log_data = {} - if exception: - log_data["error"] = exception - if outputs is not None: - log_data["output"] = outputs - - if log_data: - span.log(**log_data) - finally: - span.unset_current() - span.end() - - def on_evaluate_start( - self, - call_id: str, - instance: Any, - inputs: dict[str, Any], - ): - """Log the start of an evaluation run. 
- - Args: - call_id: Unique identifier for this call - instance: The Evaluate instance - inputs: Input parameters to the evaluation - """ - metadata = {} - # Extract evaluation metadata - if hasattr(instance, "metric") and instance.metric: - if hasattr(instance.metric, "__name__"): - metadata["metric"] = instance.metric.__name__ - if hasattr(instance, "num_threads"): - metadata["num_threads"] = instance.num_threads - - # Get the current active span to establish parent-child relationship - parent = current_span() - parent_export = parent.export() if parent else None - - span = start_span( - name="dspy.evaluate", - input=inputs, - metadata=metadata, - parent=parent_export, - ) - # Manually set as current span so children can find it - span.set_current() - self._spans[call_id] = span - - def on_evaluate_end( - self, - call_id: str, - outputs: Any | None, - exception: Exception | None = None, - ): - """Log the end of an evaluation run. - - Args: - call_id: Unique identifier for this call - outputs: Output from the evaluation, or None if there was an exception - exception: Exception raised during execution, if any - """ - span = self._spans.pop(call_id, None) - if not span: - return - - try: - log_data = {} - if exception: - log_data["error"] = exception - if outputs is not None: - log_data["output"] = outputs - # Extract metrics from evaluation results - if isinstance(outputs, dict): - metrics = {} - # Common evaluation metrics - for key in ["accuracy", "score", "total", "correct"]: - if key in outputs: - try: - metrics[key] = float(outputs[key]) - except (ValueError, TypeError): - pass - if metrics: - log_data["metrics"] = metrics - - if log_data: - span.log(**log_data) - finally: - span.unset_current() - span.end() - - -def _configure_wrapper(wrapped, instance, args, kwargs): - """Wrapper for dspy.configure that auto-adds BraintrustDSpyCallback.""" - callbacks = kwargs.get("callbacks") - if callbacks is None: - callbacks = [] - else: - callbacks = list(callbacks) - - 
# Check if already has Braintrust callback - has_bt_callback = any(isinstance(cb, BraintrustDSpyCallback) for cb in callbacks) - if not has_bt_callback: - callbacks.append(BraintrustDSpyCallback()) - - kwargs["callbacks"] = callbacks - return wrapped(*args, **kwargs) - - -def patch_dspy() -> bool: - """ - Patch DSPy to automatically add Braintrust tracing callback. - - After calling this, all calls to dspy.configure() will automatically - include the BraintrustDSpyCallback. - - Returns: - True if DSPy was patched (or already patched), False if DSPy is not installed. - - Example: - ```python - import braintrust - braintrust.patch_dspy() - - import dspy - lm = dspy.LM("openai/gpt-4o-mini") - dspy.configure(lm=lm) # BraintrustDSpyCallback auto-added! - ``` - """ - try: - import dspy - - if getattr(dspy, "__braintrust_wrapped__", False): - return True # Already patched - - wrap_function_wrapper("dspy", "configure", _configure_wrapper) - dspy.__braintrust_wrapped__ = True - return True - - except ImportError: - return False