diff --git a/py/noxfile.py b/py/noxfile.py index 7a79ab64..ed554ce9 100644 --- a/py/noxfile.py +++ b/py/noxfile.py @@ -243,7 +243,7 @@ def test_litellm(session, version): # Install fastapi and orjson as they're required by litellm for proxy/responses operations session.install("openai<=1.99.9", "--force-reinstall", "fastapi", "orjson") _install(session, "litellm", version) - _run_tests(session, f"{WRAPPER_DIR}/test_litellm.py") + _run_tests(session, f"{INTEGRATION_DIR}/litellm/test_litellm.py") _run_core_tests(session) diff --git a/py/src/braintrust/__init__.py b/py/src/braintrust/__init__.py index c961ac72..a43ad0fb 100644 --- a/py/src/braintrust/__init__.py +++ b/py/src/braintrust/__init__.py @@ -73,6 +73,9 @@ def is_equal(expected, output): from .integrations.anthropic import ( wrap_anthropic, # noqa: F401 # type: ignore[reportUnusedImport] ) +from .integrations.litellm import ( + wrap_litellm, # noqa: F401 # type: ignore[reportUnusedImport] +) from .logger import * from .logger import ( _internal_get_global_state, # noqa: F401 # type: ignore[reportUnusedImport] @@ -92,9 +95,6 @@ def is_equal(expected, output): BT_IS_ASYNC_ATTRIBUTE, # noqa: F401 # type: ignore[reportUnusedImport] MarkAsyncWrapper, # noqa: F401 # type: ignore[reportUnusedImport] ) -from .wrappers.litellm import ( - wrap_litellm, # noqa: F401 # type: ignore[reportUnusedImport] -) from .wrappers.pydantic_ai import ( setup_pydantic_ai, # noqa: F401 # type: ignore[reportUnusedImport] ) diff --git a/py/src/braintrust/auto.py b/py/src/braintrust/auto.py index c71dd140..eb4951f4 100644 --- a/py/src/braintrust/auto.py +++ b/py/src/braintrust/auto.py @@ -14,6 +14,7 @@ ClaudeAgentSDKIntegration, DSPyIntegration, GoogleGenAIIntegration, + LiteLLMIntegration, ) @@ -116,7 +117,7 @@ def auto_instrument( if anthropic: results["anthropic"] = _instrument_integration(AnthropicIntegration) if litellm: - results["litellm"] = _instrument_litellm() + results["litellm"] = _instrument_integration(LiteLLMIntegration) if 
pydantic_ai: results["pydantic_ai"] = _instrument_pydantic_ai() if google_genai: @@ -147,14 +148,6 @@ def _instrument_integration(integration) -> bool: return False -def _instrument_litellm() -> bool: - with _try_patch(): - from braintrust.wrappers.litellm import patch_litellm - - return patch_litellm() - return False - - def _instrument_pydantic_ai() -> bool: with _try_patch(): from braintrust.wrappers.pydantic_ai import setup_pydantic_ai diff --git a/py/src/braintrust/integrations/__init__.py b/py/src/braintrust/integrations/__init__.py index 4d02c323..6fcec6fc 100644 --- a/py/src/braintrust/integrations/__init__.py +++ b/py/src/braintrust/integrations/__init__.py @@ -4,6 +4,7 @@ from .claude_agent_sdk import ClaudeAgentSDKIntegration from .dspy import DSPyIntegration from .google_genai import GoogleGenAIIntegration +from .litellm import LiteLLMIntegration __all__ = [ @@ -13,4 +14,5 @@ "ClaudeAgentSDKIntegration", "DSPyIntegration", "GoogleGenAIIntegration", + "LiteLLMIntegration", ] diff --git a/py/src/braintrust/integrations/auto_test_scripts/test_auto_litellm.py b/py/src/braintrust/integrations/auto_test_scripts/test_auto_litellm.py index 0d8db254..5663c982 100644 --- a/py/src/braintrust/integrations/auto_test_scripts/test_auto_litellm.py +++ b/py/src/braintrust/integrations/auto_test_scripts/test_auto_litellm.py @@ -1,24 +1,29 @@ """Test auto_instrument for LiteLLM.""" +from pathlib import Path + import litellm from braintrust.auto import auto_instrument +from braintrust.integrations.litellm import LiteLLMIntegration from braintrust.wrappers.test_utils import autoinstrument_test_context +_CASSETTES_DIR = Path(__file__).resolve().parent.parent / "litellm" / "cassettes" + # 1. Verify not patched initially -assert not hasattr(litellm, "_braintrust_wrapped") +assert not LiteLLMIntegration.patchers[0].is_patched(litellm, None) # 2. 
Instrument results = auto_instrument() assert results.get("litellm") == True -assert hasattr(litellm, "_braintrust_wrapped") +assert LiteLLMIntegration.patchers[0].is_patched(litellm, None) # 3. Idempotent results2 = auto_instrument() assert results2.get("litellm") == True # 4. Make API call and verify span -with autoinstrument_test_context("test_auto_litellm") as memory_logger: +with autoinstrument_test_context("test_auto_litellm", cassettes_dir=_CASSETTES_DIR) as memory_logger: response = litellm.completion( model="gpt-4o-mini", messages=[{"role": "user", "content": "Say hi"}], diff --git a/py/src/braintrust/integrations/auto_test_scripts/test_patch_litellm_aresponses.py b/py/src/braintrust/integrations/auto_test_scripts/test_patch_litellm_aresponses.py index 42191de0..bf882504 100644 --- a/py/src/braintrust/integrations/auto_test_scripts/test_patch_litellm_aresponses.py +++ b/py/src/braintrust/integrations/auto_test_scripts/test_patch_litellm_aresponses.py @@ -1,17 +1,20 @@ """Test that patch_litellm() patches aresponses.""" import asyncio +from pathlib import Path import litellm -from braintrust.wrappers.litellm import patch_litellm +from braintrust.integrations.litellm import patch_litellm from braintrust.wrappers.test_utils import autoinstrument_test_context +_CASSETTES_DIR = Path(__file__).resolve().parent.parent / "litellm" / "cassettes" + patch_litellm() async def main(): - with autoinstrument_test_context("test_patch_litellm_aresponses") as memory_logger: + with autoinstrument_test_context("test_patch_litellm_aresponses", cassettes_dir=_CASSETTES_DIR) as memory_logger: response = await litellm.aresponses( model="gpt-4o-mini", input="What's 12 + 12?", diff --git a/py/src/braintrust/integrations/auto_test_scripts/test_patch_litellm_responses.py b/py/src/braintrust/integrations/auto_test_scripts/test_patch_litellm_responses.py index 2b2eac38..96105fef 100644 --- a/py/src/braintrust/integrations/auto_test_scripts/test_patch_litellm_responses.py +++ 
b/py/src/braintrust/integrations/auto_test_scripts/test_patch_litellm_responses.py @@ -1,13 +1,17 @@ """Test that patch_litellm() patches responses.""" +from pathlib import Path + import litellm -from braintrust.wrappers.litellm import patch_litellm +from braintrust.integrations.litellm import patch_litellm from braintrust.wrappers.test_utils import autoinstrument_test_context +_CASSETTES_DIR = Path(__file__).resolve().parent.parent / "litellm" / "cassettes" + patch_litellm() -with autoinstrument_test_context("test_patch_litellm_responses") as memory_logger: +with autoinstrument_test_context("test_patch_litellm_responses", cassettes_dir=_CASSETTES_DIR) as memory_logger: response = litellm.responses( model="gpt-4o-mini", input="What's 12 + 12?", diff --git a/py/src/braintrust/integrations/dspy/tracing.py b/py/src/braintrust/integrations/dspy/tracing.py index dd5131b0..e771edb9 100644 --- a/py/src/braintrust/integrations/dspy/tracing.py +++ b/py/src/braintrust/integrations/dspy/tracing.py @@ -50,7 +50,7 @@ class BraintrustDSpyCallback(BaseCallback): and disable DSPy's disk cache: ```python - from braintrust.wrappers.litellm import patch_litellm + from braintrust.integrations.litellm import patch_litellm patch_litellm() import dspy diff --git a/py/src/braintrust/integrations/litellm/__init__.py b/py/src/braintrust/integrations/litellm/__init__.py new file mode 100644 index 00000000..fb5b2ee8 --- /dev/null +++ b/py/src/braintrust/integrations/litellm/__init__.py @@ -0,0 +1,40 @@ +"""Braintrust LiteLLM integration.""" + +from .integration import LiteLLMIntegration +from .patchers import wrap_litellm + + +def patch_litellm() -> bool: + """Patch LiteLLM to add Braintrust tracing. + + This wraps litellm.completion, litellm.acompletion, litellm.responses, + litellm.aresponses, litellm.embedding, and litellm.moderation to + automatically create Braintrust spans with detailed token metrics, + timing, and costs. 
+ + Returns: + True if LiteLLM was patched (or already patched), False if LiteLLM is not installed. + + Example: + ```python + import braintrust + braintrust.patch_litellm() + + import litellm + from braintrust import init_logger + + logger = init_logger(project="my-project") + response = litellm.completion( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "Hello"}] + ) + ``` + """ + return LiteLLMIntegration.setup() + + +__all__ = [ + "LiteLLMIntegration", + "patch_litellm", + "wrap_litellm", +] diff --git a/py/src/braintrust/wrappers/cassettes/test_auto_litellm.yaml b/py/src/braintrust/integrations/litellm/cassettes/test_auto_litellm.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_auto_litellm.yaml rename to py/src/braintrust/integrations/litellm/cassettes/test_auto_litellm.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_litellm_acompletion_metrics.yaml b/py/src/braintrust/integrations/litellm/cassettes/test_litellm_acompletion_metrics.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_litellm_acompletion_metrics.yaml rename to py/src/braintrust/integrations/litellm/cassettes/test_litellm_acompletion_metrics.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_litellm_acompletion_streaming_async.yaml b/py/src/braintrust/integrations/litellm/cassettes/test_litellm_acompletion_streaming_async.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_litellm_acompletion_streaming_async.yaml rename to py/src/braintrust/integrations/litellm/cassettes/test_litellm_acompletion_streaming_async.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_litellm_acompletion_with_system_prompt.yaml b/py/src/braintrust/integrations/litellm/cassettes/test_litellm_acompletion_with_system_prompt.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_litellm_acompletion_with_system_prompt.yaml rename to 
py/src/braintrust/integrations/litellm/cassettes/test_litellm_acompletion_with_system_prompt.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_litellm_aresponses_metrics.yaml b/py/src/braintrust/integrations/litellm/cassettes/test_litellm_aresponses_metrics.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_litellm_aresponses_metrics.yaml rename to py/src/braintrust/integrations/litellm/cassettes/test_litellm_aresponses_metrics.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_litellm_aresponses_streaming_async.yaml b/py/src/braintrust/integrations/litellm/cassettes/test_litellm_aresponses_streaming_async.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_litellm_aresponses_streaming_async.yaml rename to py/src/braintrust/integrations/litellm/cassettes/test_litellm_aresponses_streaming_async.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_litellm_async_parallel_requests.yaml b/py/src/braintrust/integrations/litellm/cassettes/test_litellm_async_parallel_requests.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_litellm_async_parallel_requests.yaml rename to py/src/braintrust/integrations/litellm/cassettes/test_litellm_async_parallel_requests.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_litellm_async_streaming_with_break.yaml b/py/src/braintrust/integrations/litellm/cassettes/test_litellm_async_streaming_with_break.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_litellm_async_streaming_with_break.yaml rename to py/src/braintrust/integrations/litellm/cassettes/test_litellm_async_streaming_with_break.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_litellm_completion_metrics.yaml b/py/src/braintrust/integrations/litellm/cassettes/test_litellm_completion_metrics.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_litellm_completion_metrics.yaml rename to 
py/src/braintrust/integrations/litellm/cassettes/test_litellm_completion_metrics.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_litellm_completion_streaming_sync.yaml b/py/src/braintrust/integrations/litellm/cassettes/test_litellm_completion_streaming_sync.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_litellm_completion_streaming_sync.yaml rename to py/src/braintrust/integrations/litellm/cassettes/test_litellm_completion_streaming_sync.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_litellm_completion_with_system_prompt.yaml b/py/src/braintrust/integrations/litellm/cassettes/test_litellm_completion_with_system_prompt.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_litellm_completion_with_system_prompt.yaml rename to py/src/braintrust/integrations/litellm/cassettes/test_litellm_completion_with_system_prompt.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_litellm_embeddings.yaml b/py/src/braintrust/integrations/litellm/cassettes/test_litellm_embeddings.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_litellm_embeddings.yaml rename to py/src/braintrust/integrations/litellm/cassettes/test_litellm_embeddings.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_litellm_moderation.yaml b/py/src/braintrust/integrations/litellm/cassettes/test_litellm_moderation.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_litellm_moderation.yaml rename to py/src/braintrust/integrations/litellm/cassettes/test_litellm_moderation.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_litellm_openrouter_no_booleans_in_metrics.yaml b/py/src/braintrust/integrations/litellm/cassettes/test_litellm_openrouter_no_booleans_in_metrics.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_litellm_openrouter_no_booleans_in_metrics.yaml rename to 
py/src/braintrust/integrations/litellm/cassettes/test_litellm_openrouter_no_booleans_in_metrics.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_litellm_responses_metrics.yaml b/py/src/braintrust/integrations/litellm/cassettes/test_litellm_responses_metrics.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_litellm_responses_metrics.yaml rename to py/src/braintrust/integrations/litellm/cassettes/test_litellm_responses_metrics.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_litellm_responses_streaming_sync.yaml b/py/src/braintrust/integrations/litellm/cassettes/test_litellm_responses_streaming_sync.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_litellm_responses_streaming_sync.yaml rename to py/src/braintrust/integrations/litellm/cassettes/test_litellm_responses_streaming_sync.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_litellm_tool_calls.yaml b/py/src/braintrust/integrations/litellm/cassettes/test_litellm_tool_calls.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_litellm_tool_calls.yaml rename to py/src/braintrust/integrations/litellm/cassettes/test_litellm_tool_calls.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_patch_litellm_aresponses.yaml b/py/src/braintrust/integrations/litellm/cassettes/test_patch_litellm_aresponses.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_patch_litellm_aresponses.yaml rename to py/src/braintrust/integrations/litellm/cassettes/test_patch_litellm_aresponses.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_patch_litellm_responses.yaml b/py/src/braintrust/integrations/litellm/cassettes/test_patch_litellm_responses.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_patch_litellm_responses.yaml rename to py/src/braintrust/integrations/litellm/cassettes/test_patch_litellm_responses.yaml diff --git 
a/py/src/braintrust/integrations/litellm/integration.py b/py/src/braintrust/integrations/litellm/integration.py new file mode 100644 index 00000000..b4faac1e --- /dev/null +++ b/py/src/braintrust/integrations/litellm/integration.py @@ -0,0 +1,13 @@ +"""LiteLLM integration definition.""" + +from braintrust.integrations.base import BaseIntegration + +from .patchers import _ALL_LITELLM_PATCHERS + + +class LiteLLMIntegration(BaseIntegration): + """Braintrust instrumentation for the LiteLLM Python SDK.""" + + name = "litellm" + import_names = ("litellm",) + patchers = _ALL_LITELLM_PATCHERS diff --git a/py/src/braintrust/integrations/litellm/patchers.py b/py/src/braintrust/integrations/litellm/patchers.py new file mode 100644 index 00000000..5019dc00 --- /dev/null +++ b/py/src/braintrust/integrations/litellm/patchers.py @@ -0,0 +1,109 @@ +"""LiteLLM patchers — FunctionWrapperPatcher subclasses for each patch target.""" + +from typing import Any + +from braintrust.integrations.base import FunctionWrapperPatcher + +from .tracing import ( + _acompletion_wrapper_async, + _aresponses_wrapper_async, + _completion_wrapper, + _embedding_wrapper, + _moderation_wrapper, + _responses_wrapper, +) + + +# --------------------------------------------------------------------------- +# Individual patchers +# --------------------------------------------------------------------------- + + +class LiteLLMCompletionPatcher(FunctionWrapperPatcher): + name = "litellm.completion" + target_path = "completion" + wrapper = _completion_wrapper + + +class LiteLLMAcompletionPatcher(FunctionWrapperPatcher): + name = "litellm.acompletion" + target_path = "acompletion" + wrapper = _acompletion_wrapper_async + + +class LiteLLMResponsesPatcher(FunctionWrapperPatcher): + name = "litellm.responses" + target_path = "responses" + wrapper = _responses_wrapper + + +class LiteLLMAresponsesPatcher(FunctionWrapperPatcher): + name = "litellm.aresponses" + target_path = "aresponses" + wrapper = 
_aresponses_wrapper_async + + +class LiteLLMEmbeddingPatcher(FunctionWrapperPatcher): + name = "litellm.embedding" + target_path = "embedding" + wrapper = _embedding_wrapper + + +class LiteLLMModerationPatcher(FunctionWrapperPatcher): + name = "litellm.moderation" + target_path = "moderation" + wrapper = _moderation_wrapper + + +# --------------------------------------------------------------------------- +# All patchers, in declaration order +# --------------------------------------------------------------------------- + +_ALL_LITELLM_PATCHERS = ( + LiteLLMCompletionPatcher, + LiteLLMAcompletionPatcher, + LiteLLMResponsesPatcher, + LiteLLMAresponsesPatcher, + LiteLLMEmbeddingPatcher, + LiteLLMModerationPatcher, +) + + +# --------------------------------------------------------------------------- +# Manual wrapping helper +# --------------------------------------------------------------------------- + + +def wrap_litellm(litellm: Any) -> Any: + """Wrap a LiteLLM module to add Braintrust tracing. + + Unlike :func:`patch_litellm`, which patches the globally-imported ``litellm`` + module, this function instruments a specific module object (or any object + that exposes the same top-level callables such as ``completion``, + ``acompletion``, ``responses``, ``aresponses``, ``embedding``, and + ``moderation``). Each patcher is applied idempotently — calling + ``wrap_litellm`` twice on the same object is safe. + + Args: + litellm: The ``litellm`` module or a module-like object that exposes + the standard LiteLLM top-level functions. + + Returns: + The same *litellm* object, with tracing wrappers applied in-place. + + Example:: + + import litellm + from braintrust.integrations.litellm import wrap_litellm + + wrap_litellm(litellm) + + # All subsequent calls are automatically traced. 
+ response = litellm.completion( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "Hello"}], + ) + """ + for patcher in _ALL_LITELLM_PATCHERS: + patcher.wrap_target(litellm) + return litellm diff --git a/py/src/braintrust/wrappers/test_litellm.py b/py/src/braintrust/integrations/litellm/test_litellm.py similarity index 76% rename from py/src/braintrust/wrappers/test_litellm.py rename to py/src/braintrust/integrations/litellm/test_litellm.py index 6020634c..ac26f1fd 100644 --- a/py/src/braintrust/wrappers/test_litellm.py +++ b/py/src/braintrust/integrations/litellm/test_litellm.py @@ -4,8 +4,8 @@ import litellm import pytest from braintrust import logger +from braintrust.integrations.litellm import patch_litellm from braintrust.test_helpers import assert_dict_matches, init_test_logger -from braintrust.wrappers.litellm import wrap_litellm from braintrust.wrappers.test_utils import assert_metrics_are_valid, verify_autoinstrument_script @@ -16,6 +16,11 @@ TEST_SYSTEM_PROMPT = "You are a helpful assistant that only responds with numbers." 
+@pytest.fixture(autouse=True) +def _patch(): + patch_litellm() + + @pytest.fixture def memory_logger(): init_test_logger(PROJECT_NAME) @@ -27,27 +32,14 @@ def memory_logger(): def test_litellm_completion_metrics(memory_logger) -> None: assert not memory_logger.pop() - # Test unwrapped client first - response = litellm.completion(model=TEST_MODEL, messages=[{"role": "user", "content": TEST_PROMPT}]) - assert response - assert response.choices[0].message.content - assert "24" in response.choices[0].message.content or "twenty-four" in response.choices[0].message.content.lower() - - # No spans should be generated with unwrapped client - assert not memory_logger.pop() - - # Now test with wrapped client - wrapped_litellm = wrap_litellm(litellm) - start = time.time() - response = wrapped_litellm.completion(model=TEST_MODEL, messages=[{"role": "user", "content": TEST_PROMPT}]) + response = litellm.completion(model=TEST_MODEL, messages=[{"role": "user", "content": TEST_PROMPT}]) end = time.time() assert response assert response.choices[0].message.content assert "24" in response.choices[0].message.content or "twenty-four" in response.choices[0].message.content.lower() - # Verify spans were created with wrapped client spans = memory_logger.pop() assert len(spans) == 1 span = spans[0] @@ -64,17 +56,14 @@ def test_litellm_completion_metrics(memory_logger) -> None: async def test_litellm_acompletion_metrics(memory_logger): assert not memory_logger.pop() - wrapped_litellm = wrap_litellm(litellm) - start = time.time() - response = await wrapped_litellm.acompletion(model=TEST_MODEL, messages=[{"role": "user", "content": TEST_PROMPT}]) + response = await litellm.acompletion(model=TEST_MODEL, messages=[{"role": "user", "content": TEST_PROMPT}]) end = time.time() assert response assert response.choices[0].message.content assert "24" in response.choices[0].message.content or "twenty-four" in response.choices[0].message.content.lower() - # Verify spans were created with wrapped client 
spans = memory_logger.pop() assert len(spans) == 1 span = spans[0] @@ -90,38 +79,8 @@ async def test_litellm_acompletion_metrics(memory_logger): def test_litellm_completion_streaming_sync(memory_logger): assert not memory_logger.pop() - # Test unwrapped client first - stream = litellm.completion( - model=TEST_MODEL, - messages=[{"role": "user", "content": TEST_PROMPT}], - stream=True, - ) - - chunks = [] - for chunk in stream: - chunks.append(chunk) - - # Verify streaming works - assert chunks - assert len(chunks) > 1 - - # Concatenate content from chunks to verify - content = "" - for chunk in chunks: - if chunk.choices and chunk.choices[0].delta.content: - content += chunk.choices[0].delta.content - - # Make sure we got a valid answer in the content - assert "24" in content or "twenty-four" in content.lower() - - # No spans should be generated with unwrapped client - assert not memory_logger.pop() - - # Now test with wrapped client - wrapped_litellm = wrap_litellm(litellm) - start = time.time() - stream = wrapped_litellm.completion( + stream = litellm.completion( model=TEST_MODEL, messages=[{"role": "user", "content": TEST_PROMPT}], stream=True, @@ -145,7 +104,7 @@ def test_litellm_completion_streaming_sync(memory_logger): # Make sure we got a valid answer in the content assert "24" in content or "twenty-four" in content.lower() - # Verify spans were created with wrapped client + # Verify spans were created spans = memory_logger.pop() assert len(spans) == 1 span = spans[0] @@ -163,10 +122,8 @@ def test_litellm_completion_streaming_sync(memory_logger): async def test_litellm_acompletion_streaming_async(memory_logger): assert not memory_logger.pop() - wrapped_litellm = wrap_litellm(litellm) - start = time.time() - stream = await wrapped_litellm.acompletion( + stream = await litellm.acompletion( model=TEST_MODEL, messages=[{"role": "user", "content": TEST_PROMPT}], stream=True, @@ -181,7 +138,7 @@ async def test_litellm_acompletion_streaming_async(memory_logger): 
assert chunks assert len(chunks) > 1 - # Verify spans were created with wrapped client + # Verify spans were created spans = memory_logger.pop() assert len(spans) == 1 span = spans[0] @@ -197,25 +154,8 @@ async def test_litellm_acompletion_streaming_async(memory_logger): def test_litellm_responses_metrics(memory_logger): assert not memory_logger.pop() - # Test unwrapped client first - response = litellm.responses( - model=TEST_MODEL, - input=TEST_PROMPT, - instructions="Just the number please", - ) - assert response - assert response.output - assert len(response.output) > 0 - unwrapped_content = response.output[0].content[0].text - - # No spans should be generated with unwrapped client - assert not memory_logger.pop() - - # Now test with wrapped client - wrapped_litellm = wrap_litellm(litellm) - start = time.time() - response = wrapped_litellm.responses( + response = litellm.responses( model=TEST_MODEL, input=TEST_PROMPT, instructions="Just the number please", @@ -225,13 +165,10 @@ def test_litellm_responses_metrics(memory_logger): assert response assert response.output assert len(response.output) > 0 - wrapped_content = response.output[0].content[0].text - - # Both should contain a numeric response for the math question - assert "24" in unwrapped_content or "twenty-four" in unwrapped_content.lower() - assert "24" in wrapped_content or "twenty-four" in wrapped_content.lower() + content = response.output[0].content[0].text + assert "24" in content or "twenty-four" in content.lower() - # Verify spans were created with wrapped client + # Verify spans were created spans = memory_logger.pop() assert len(spans) == 1 span = spans[0] @@ -248,10 +185,8 @@ def test_litellm_responses_metrics(memory_logger): async def test_litellm_aresponses_metrics(memory_logger): assert not memory_logger.pop() - wrapped_litellm = wrap_litellm(litellm) - start = time.time() - response = await wrapped_litellm.aresponses( + response = await litellm.aresponses( model=TEST_MODEL, 
input=TEST_PROMPT, instructions="Just the number please", @@ -261,11 +196,10 @@ async def test_litellm_aresponses_metrics(memory_logger): assert response assert response.output assert len(response.output) > 0 - wrapped_content = response.output[0].content[0].text - - assert "24" in wrapped_content or "twenty-four" in wrapped_content.lower() + content = response.output[0].content[0].text + assert "24" in content or "twenty-four" in content.lower() - # Verify spans were created with wrapped client + # Verify spans were created spans = memory_logger.pop() assert len(spans) == 1 span = spans[0] @@ -281,27 +215,15 @@ async def test_litellm_aresponses_metrics(memory_logger): def test_litellm_embeddings(memory_logger): assert not memory_logger.pop() - # Test unwrapped client first - response = litellm.embedding(model="text-embedding-ada-002", input="This is a test") - assert response - assert response.data - assert response.data[0]["embedding"] - - # No spans should be generated with unwrapped client - assert not memory_logger.pop() - - # Now test with wrapped client - wrapped_litellm = wrap_litellm(litellm) - start = time.time() - response = wrapped_litellm.embedding(model="text-embedding-ada-002", input="This is a test") + response = litellm.embedding(model="text-embedding-ada-002", input="This is a test") end = time.time() assert response assert response.data assert response.data[0]["embedding"] - # Verify spans were created with wrapped client + # Verify spans were created spans = memory_logger.pop() assert len(spans) == 1 span = spans[0] @@ -315,30 +237,18 @@ def test_litellm_embeddings(memory_logger): def test_litellm_moderation(memory_logger): assert not memory_logger.pop() - # Test unwrapped client first - response = litellm.moderation(model="text-moderation-latest", input="This is a test message") - assert response - assert response.results - - # No spans should be generated with unwrapped client - assert not memory_logger.pop() - - # Now test with wrapped client 
- wrapped_litellm = wrap_litellm(litellm) - start = time.time() - response = wrapped_litellm.moderation(model="text-moderation-latest", input="This is a test message") + response = litellm.moderation(model="text-moderation-latest", input="This is a test message") end = time.time() assert response assert response.results - # Verify spans were created with wrapped client + # Verify spans were created spans = memory_logger.pop() assert len(spans) == 1 span = spans[0] assert span - metrics = span["metrics"] assert span["metadata"]["model"] == "text-moderation-latest" assert span["metadata"]["provider"] == "litellm" assert "This is a test message" in str(span["input"]) @@ -348,9 +258,7 @@ def test_litellm_moderation(memory_logger): def test_litellm_completion_with_system_prompt(memory_logger): assert not memory_logger.pop() - wrapped_litellm = wrap_litellm(litellm) - - response = wrapped_litellm.completion( + response = litellm.completion( model=TEST_MODEL, messages=[{"role": "system", "content": TEST_SYSTEM_PROMPT}, {"role": "user", "content": TEST_PROMPT}], ) @@ -375,9 +283,7 @@ def test_litellm_completion_with_system_prompt(memory_logger): async def test_litellm_acompletion_with_system_prompt(memory_logger): assert not memory_logger.pop() - wrapped_litellm = wrap_litellm(litellm) - - response = await wrapped_litellm.acompletion( + response = await litellm.acompletion( model=TEST_MODEL, messages=[{"role": "system", "content": TEST_SYSTEM_PROMPT}, {"role": "user", "content": TEST_PROMPT}], ) @@ -401,13 +307,11 @@ async def test_litellm_acompletion_with_system_prompt(memory_logger): def test_litellm_completion_error(memory_logger): assert not memory_logger.pop() - wrapped_litellm = wrap_litellm(litellm) - # Use a non-existent model to force an error fake_model = "non-existent-model" try: - wrapped_litellm.completion(model=fake_model, messages=[{"role": "user", "content": TEST_PROMPT}]) + litellm.completion(model=fake_model, messages=[{"role": "user", "content": 
TEST_PROMPT}]) pytest.fail("Expected an exception but none was raised") except Exception: # We expect an error here @@ -426,13 +330,11 @@ def test_litellm_completion_error(memory_logger): async def test_litellm_acompletion_error(memory_logger): assert not memory_logger.pop() - wrapped_litellm = wrap_litellm(litellm) - # Use a non-existent model to force an error fake_model = "non-existent-model" try: - await wrapped_litellm.acompletion(model=fake_model, messages=[{"role": "user", "content": TEST_PROMPT}]) + await litellm.acompletion(model=fake_model, messages=[{"role": "user", "content": TEST_PROMPT}]) pytest.fail("Expected an exception but none was raised") except Exception: # We expect an error here @@ -449,18 +351,15 @@ async def test_litellm_acompletion_error(memory_logger): @pytest.mark.vcr @pytest.mark.asyncio async def test_litellm_async_parallel_requests(memory_logger): - """Test multiple parallel async requests with the wrapped client.""" + """Test multiple parallel async requests.""" assert not memory_logger.pop() - wrapped_litellm = wrap_litellm(litellm) - # Create multiple prompts prompts = [f"What is {i} + {i}?" 
for i in range(3, 6)] # Run requests in parallel tasks = [ - wrapped_litellm.acompletion(model=TEST_MODEL, messages=[{"role": "user", "content": prompt}]) - for prompt in prompts + litellm.acompletion(model=TEST_MODEL, messages=[{"role": "user", "content": prompt}]) for prompt in prompts ] # Wait for all to complete @@ -504,10 +403,8 @@ def test_litellm_tool_calls(memory_logger): }, ] - wrapped_litellm = wrap_litellm(litellm) - start = time.time() - response = wrapped_litellm.completion( + response = litellm.completion( model=TEST_MODEL, messages=[{"role": "user", "content": "What's the weather in New York?"}], tools=tools, @@ -546,10 +443,8 @@ def test_litellm_responses_streaming_sync(memory_logger): """Test the responses API with streaming.""" assert not memory_logger.pop() - wrapped_litellm = wrap_litellm(litellm) - start = time.time() - stream = wrapped_litellm.responses(model=TEST_MODEL, input="What's 12 + 12?", stream=True) + stream = litellm.responses(model=TEST_MODEL, input="What's 12 + 12?", stream=True) chunks = [] for chunk in stream: @@ -568,7 +463,7 @@ def test_litellm_responses_streaming_sync(memory_logger): span = spans[0] metrics = span["metrics"] assert_metrics_are_valid(metrics, start, end) - assert span["metadata"]["stream"] == True + assert span["metadata"]["stream"] is True assert "What's 12 + 12?" 
in str(span["input"]) assert "24" in str(span["output"]) @@ -579,10 +474,8 @@ async def test_litellm_aresponses_streaming_async(memory_logger): """Test the async responses API with streaming.""" assert not memory_logger.pop() - wrapped_litellm = wrap_litellm(litellm) - start = time.time() - stream = await wrapped_litellm.aresponses(model=TEST_MODEL, input="What's 12 + 12?", stream=True) + stream = await litellm.aresponses(model=TEST_MODEL, input="What's 12 + 12?", stream=True) chunks = [] async for chunk in stream: @@ -598,7 +491,7 @@ async def test_litellm_aresponses_streaming_async(memory_logger): span = spans[0] metrics = span["metrics"] assert_metrics_are_valid(metrics, start, end) - assert span["metadata"]["stream"] == True + assert span["metadata"]["stream"] is True assert "What's 12 + 12?" in str(span["input"]) @@ -608,10 +501,8 @@ async def test_litellm_async_streaming_with_break(memory_logger): """Test breaking out of the async streaming loop early.""" assert not memory_logger.pop() - wrapped_litellm = wrap_litellm(litellm) - start = time.time() - stream = await wrapped_litellm.acompletion( + stream = await litellm.acompletion( model=TEST_MODEL, messages=[{"role": "user", "content": TEST_PROMPT}], stream=True ) @@ -665,7 +556,7 @@ def test_litellm_parse_metrics_excludes_booleans(): When OpenRouter returns usage data with `is_byok: true`, the metrics parser should filter it out rather than passing it through to the API. """ - from braintrust.wrappers.litellm import _parse_metrics_from_usage + from braintrust.integrations.litellm.tracing import _parse_metrics_from_usage usage = { "prompt_tokens": 10, @@ -695,10 +586,8 @@ def test_litellm_openrouter_no_booleans_in_metrics(memory_logger): assert not memory_logger.pop() - wrapped_litellm = wrap_litellm(litellm) - start = time.time() - response = wrapped_litellm.completion( + response = litellm.completion( model="openrouter/openai/gpt-4o-mini", messages=[{"role": "user", "content": "What is 2+2? 
Reply with just the number."}], max_tokens=10, diff --git a/py/src/braintrust/integrations/litellm/tracing.py b/py/src/braintrust/integrations/litellm/tracing.py new file mode 100644 index 00000000..7fcecb60 --- /dev/null +++ b/py/src/braintrust/integrations/litellm/tracing.py @@ -0,0 +1,531 @@ +"""LiteLLM tracing helpers — spans, metadata extraction, stream handling.""" + +import time +from collections.abc import AsyncGenerator, Generator +from types import TracebackType +from typing import Any + +from braintrust.logger import Span, start_span +from braintrust.span_types import SpanTypeAttribute +from braintrust.util import is_numeric, merge_dicts + + +# LiteLLM's representation to Braintrust's representation +TOKEN_NAME_MAP: dict[str, str] = { + # chat API + "total_tokens": "tokens", + "prompt_tokens": "prompt_tokens", + "completion_tokens": "completion_tokens", + # responses API + "tokens": "tokens", + "input_tokens": "prompt_tokens", + "output_tokens": "completion_tokens", +} + +TOKEN_PREFIX_MAP: dict[str, str] = { + "input": "prompt", + "output": "completion", +} + + +# --------------------------------------------------------------------------- +# Async response wrapper (preserves async context manager / iterator behavior) +# --------------------------------------------------------------------------- + + +class AsyncResponseWrapper: + """Wrapper that properly preserves async context manager behavior for LiteLLM responses.""" + + def __init__(self, response: Any) -> None: + self._response = response + + async def __aenter__(self) -> Any: + if hasattr(self._response, "__aenter__"): + return await self._response.__aenter__() + return self._response + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None + ) -> bool | None: + if hasattr(self._response, "__aexit__"): + return await self._response.__aexit__(exc_type, exc_val, exc_tb) + return None + + def __aiter__(self) -> 
AsyncGenerator[Any, None]: + if hasattr(self._response, "__aiter__"): + return self._response.__aiter__() + raise TypeError("Response object is not an async iterator") + + async def __anext__(self) -> Any: + if hasattr(self._response, "__anext__"): + return await self._response.__anext__() + raise StopAsyncIteration + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + +# --------------------------------------------------------------------------- +# Streaming helpers +# --------------------------------------------------------------------------- + + +def _handle_completion_streaming( + raw_response: Any, span: Span, start_time: float, is_async: bool = False +) -> AsyncResponseWrapper | Generator[Any, None, None]: + """Handle streaming response for completion (sync and async).""" + if is_async: + + async def async_gen() -> AsyncGenerator[Any, None]: + try: + first = True + all_results: list[dict[str, Any]] = [] + async for item in raw_response: + if first: + span.log(metrics={"time_to_first_token": time.time() - start_time}) + first = False + all_results.append(_try_to_dict(item)) + yield item + + span.log(**_postprocess_completion_streaming_results(all_results)) + finally: + span.end() + + return AsyncResponseWrapper(async_gen()) + else: + + def sync_gen() -> Generator[Any, None, None]: + try: + first = True + all_results: list[dict[str, Any]] = [] + for item in raw_response: + if first: + span.log(metrics={"time_to_first_token": time.time() - start_time}) + first = False + all_results.append(_try_to_dict(item)) + yield item + + span.log(**_postprocess_completion_streaming_results(all_results)) + finally: + span.end() + + return sync_gen() + + +def _handle_responses_streaming( + raw_response: Any, span: Span, start_time: float, is_async: bool = False +) -> AsyncResponseWrapper | Generator[Any, None, None]: + """Handle streaming response for responses API (sync and async).""" + if is_async: + + async def async_gen() -> 
AsyncGenerator[Any, None]: + try: + first = True + all_results: list[Any] = [] + async for item in raw_response: + if first: + span.log(metrics={"time_to_first_token": time.time() - start_time}) + first = False + all_results.append(item) + yield item + + span.log(**_postprocess_responses_streaming_results(all_results)) + finally: + span.end() + + return AsyncResponseWrapper(async_gen()) + else: + + def sync_gen() -> Generator[Any, None, None]: + try: + first = True + all_results: list[Any] = [] + for item in raw_response: + if first: + span.log(metrics={"time_to_first_token": time.time() - start_time}) + first = False + all_results.append(item) + yield item + + span.log(**_postprocess_responses_streaming_results(all_results)) + finally: + span.end() + + return sync_gen() + + +# --------------------------------------------------------------------------- +# wrapt-style wrapper functions (used by FunctionWrapperPatcher) +# --------------------------------------------------------------------------- + + +def _completion_wrapper(wrapped, instance, args, kwargs): + """wrapt wrapper for litellm.completion.""" + updated_span_payload = _update_span_payload_from_params(kwargs, input_key="messages") + is_streaming = kwargs.get("stream", False) + + span = start_span( + **merge_dicts(dict(name="Completion", span_attributes={"type": SpanTypeAttribute.LLM}), updated_span_payload) + ) + should_end = True + + try: + start = time.time() + completion_response = wrapped(*args, **kwargs) + + if is_streaming: + should_end = False + return _handle_completion_streaming(completion_response, span, start, is_async=False) + else: + log_response = _try_to_dict(completion_response) + metrics = _parse_metrics_from_usage(log_response.get("usage", {})) + metrics["time_to_first_token"] = time.time() - start + span.log(metrics=metrics, output=log_response["choices"]) + return completion_response + finally: + if should_end: + span.end() + + +async def _acompletion_wrapper_async(wrapped, instance, 
args, kwargs): + """wrapt wrapper for litellm.acompletion.""" + updated_span_payload = _update_span_payload_from_params(kwargs, input_key="messages") + is_streaming = kwargs.get("stream", False) + + span = start_span( + **merge_dicts(dict(name="Completion", span_attributes={"type": SpanTypeAttribute.LLM}), updated_span_payload) + ) + should_end = True + + try: + start = time.time() + completion_response = await wrapped(*args, **kwargs) + + if is_streaming: + should_end = False + return _handle_completion_streaming(completion_response, span, start, is_async=True) + else: + log_response = _try_to_dict(completion_response) + metrics = _parse_metrics_from_usage(log_response.get("usage", {})) + metrics["time_to_first_token"] = time.time() - start + span.log(metrics=metrics, output=log_response["choices"]) + return completion_response + finally: + if should_end: + span.end() + + +def _responses_wrapper(wrapped, instance, args, kwargs): + """wrapt wrapper for litellm.responses.""" + updated_span_payload = _update_span_payload_from_params(kwargs, input_key="input") + is_streaming = kwargs.get("stream", False) + + span = start_span( + **merge_dicts(dict(name="Response", span_attributes={"type": SpanTypeAttribute.LLM}), updated_span_payload) + ) + should_end = True + + try: + start = time.time() + response = wrapped(*args, **kwargs) + + if is_streaming: + should_end = False + return _handle_responses_streaming(response, span, start, is_async=False) + else: + log_response = _try_to_dict(response) + metrics = _parse_metrics_from_usage(log_response.get("usage", {})) + metrics["time_to_first_token"] = time.time() - start + span.log(metrics=metrics, output=log_response["output"]) + return response + finally: + if should_end: + span.end() + + +async def _aresponses_wrapper_async(wrapped, instance, args, kwargs): + """wrapt wrapper for litellm.aresponses.""" + updated_span_payload = _update_span_payload_from_params(kwargs, input_key="input") + is_streaming = kwargs.get("stream", 
False) + + span = start_span( + **merge_dicts(dict(name="Response", span_attributes={"type": SpanTypeAttribute.LLM}), updated_span_payload) + ) + should_end = True + + try: + start = time.time() + response = await wrapped(*args, **kwargs) + + if is_streaming: + should_end = False + return _handle_responses_streaming(response, span, start, is_async=True) + else: + log_response = _try_to_dict(response) + metrics = _parse_metrics_from_usage(log_response.get("usage", {})) + metrics["time_to_first_token"] = time.time() - start + span.log(metrics=metrics, output=log_response["output"]) + return response + finally: + if should_end: + span.end() + + +def _embedding_wrapper(wrapped, instance, args, kwargs): + """wrapt wrapper for litellm.embedding.""" + updated_span_payload = _update_span_payload_from_params(kwargs, input_key="input") + + with start_span( + **merge_dicts(dict(name="Embedding", span_attributes={"type": SpanTypeAttribute.LLM}), updated_span_payload) + ) as span: + embedding_response = wrapped(*args, **kwargs) + log_response = _try_to_dict(embedding_response) + usage = log_response.get("usage") + metrics = _parse_metrics_from_usage(usage) + span.log( + metrics=metrics, + output={"embedding_length": len(log_response["data"][0]["embedding"])}, + ) + return embedding_response + + +def _moderation_wrapper(wrapped, instance, args, kwargs): + """wrapt wrapper for litellm.moderation.""" + updated_span_payload = _update_span_payload_from_params(kwargs, input_key="input") + + with start_span( + **merge_dicts(dict(name="Moderation", span_attributes={"type": SpanTypeAttribute.LLM}), updated_span_payload) + ) as span: + moderation_response = wrapped(*args, **kwargs) + log_response = _try_to_dict(moderation_response) + usage = log_response.get("usage") + metrics = _parse_metrics_from_usage(usage) + span.log( + metrics=metrics, + output=log_response["results"], + ) + return moderation_response + + +# 
--------------------------------------------------------------------------- +# Streaming post-processing +# --------------------------------------------------------------------------- + + +def _postprocess_completion_streaming_results(all_results: list[dict[str, Any]]) -> dict[str, Any]: + """Process streaming results to extract final response.""" + role = None + content = None + tool_calls: list[Any] | None = None + finish_reason = None + metrics: dict[str, float] = {} + + for result in all_results: + usage = result.get("usage") + if usage: + metrics.update(_parse_metrics_from_usage(usage)) + + choices = result["choices"] + if not choices: + continue + delta = choices[0]["delta"] + if not delta: + continue + + if role is None and delta.get("role") is not None: + role = delta.get("role") + + if delta.get("finish_reason") is not None: + finish_reason = delta.get("finish_reason") + + if delta.get("content") is not None: + content = (content or "") + delta.get("content") + + if delta.get("tool_calls") is not None: + delta_tool_calls = delta.get("tool_calls") + if not delta_tool_calls: + continue + tool_delta = delta_tool_calls[0] + + # pylint: disable=unsubscriptable-object + if not tool_calls or (tool_delta.get("id") and tool_calls[-1]["id"] != tool_delta.get("id")): + tool_calls = (tool_calls or []) + [ + { + "id": tool_delta.get("id"), + "type": tool_delta.get("type"), + "function": tool_delta.get("function"), + } + ] + else: + # pylint: disable=unsubscriptable-object + tool_calls[-1]["function"]["arguments"] += delta["tool_calls"][0]["function"]["arguments"] + + return { + "metrics": metrics, + "output": [ + { + "index": 0, + "message": {"role": role, "content": content, "tool_calls": tool_calls}, + "logprobs": None, + "finish_reason": finish_reason, + } + ], + } + + +def _postprocess_responses_streaming_results(all_results: list[Any]) -> dict[str, Any]: + """Process responses API streaming results.""" + metrics: dict[str, Any] = {} + output: list[dict[str, Any]] 
= [] + + for result in all_results: + usage = None + if hasattr(result, "usage"): + usage = getattr(result, "usage") + elif result.type == "response.completed" and hasattr(result, "response"): + usage = getattr(result.response, "usage") + + if usage: + parsed_metrics = _parse_metrics_from_usage(usage) + metrics.update(parsed_metrics) + + if result.type == "response.output_item.added": + output.append({"id": result.item.get("id"), "type": result.item.get("type")}) + continue + + if not hasattr(result, "output_index"): + continue + + output_index = result.output_index + current_output = output[output_index] + if result.type == "response.output_item.done": + current_output["status"] = result.item.get("status") + continue + + if result.type == "response.output_item.delta": + current_output["delta"] = result.delta + continue + + if hasattr(result, "content_index"): + if "content" not in current_output: + current_output["content"] = [] + content_index = result.content_index + if content_index == len(current_output["content"]): + current_output["content"].append({}) + current_content = current_output["content"][content_index] + if hasattr(result, "delta") and result.delta: + current_content["text"] = (current_content.get("text") or "") + result.delta + + if result.type == "response.output_text.annotation.added": + annotation_index = result.annotation_index + if "annotations" not in current_content: + current_content["annotations"] = [] + if annotation_index == len(current_content["annotations"]): + current_content["annotations"].append({}) + current_content["annotations"][annotation_index] = _try_to_dict(result.annotation) + + return { + "metrics": metrics, + "output": output, + } + + +# --------------------------------------------------------------------------- +# Utility helpers +# --------------------------------------------------------------------------- + + +def _update_span_payload_from_params(params: dict[str, Any], input_key: str = "input") -> dict[str, Any]: + 
"""Updates the span payload with the parameters into LiteLLM's completion/acompletion methods. + + Works on a shallow copy so the caller's kwargs dict is never mutated. + """ + params = params.copy() + span_info_d = params.pop("span_info", {}) + + params = prettify_params(params) + input_data = params.pop(input_key, None) + model = params.pop("model", None) + + return merge_dicts( + span_info_d, + {"input": input_data, "metadata": {**params, "provider": "litellm", "model": model}}, + ) + + +def _parse_metrics_from_usage(usage: Any) -> dict[str, Any]: + """Parse usage metrics from API response.""" + metrics: dict[str, Any] = {} + + if not usage: + return metrics + + usage = _try_to_dict(usage) + if not isinstance(usage, dict): + return metrics + + for oai_name, value in usage.items(): + if oai_name.endswith("_tokens_details"): + if not isinstance(value, dict): + continue + raw_prefix = oai_name[: -len("_tokens_details")] + prefix = TOKEN_PREFIX_MAP.get(raw_prefix, raw_prefix) + for k, v in value.items(): + if is_numeric(v): + metrics[f"{prefix}_{k}"] = v + elif is_numeric(value): + name = TOKEN_NAME_MAP.get(oai_name, oai_name) + metrics[name] = value + + return metrics + + +def prettify_params(params: dict[str, Any]) -> dict[str, Any]: + """Return a shallow copy of *params* with response_format serialized for logging.""" + + if "response_format" in params: + ret = params.copy() + ret["response_format"] = serialize_response_format(ret["response_format"]) + return ret + + return params + + +def _try_to_dict(obj: Any) -> dict[str, Any] | Any: + """Try to convert an object to a dictionary.""" + if isinstance(obj, dict): + return obj + if hasattr(obj, "model_dump") and callable(obj.model_dump): + try: + result = obj.model_dump() + if isinstance(result, dict): + return result + except Exception: + pass + if hasattr(obj, "dict") and callable(obj.dict): + try: + result = obj.dict() + if isinstance(result, dict): + return result + except Exception: + pass + return obj + + 
+def serialize_response_format(response_format: Any) -> Any: + """Serialize response format for logging.""" + try: + from pydantic import BaseModel + except ImportError: + return response_format + + if isinstance(response_format, type) and issubclass(response_format, BaseModel): + return dict( + type="json_schema", + json_schema=dict( + name=response_format.__name__, + schema=response_format.model_json_schema(), + ), + ) + else: + return response_format diff --git a/py/src/braintrust/wrappers/litellm.py b/py/src/braintrust/wrappers/litellm.py index 236df998..ddaf1ef1 100644 --- a/py/src/braintrust/wrappers/litellm.py +++ b/py/src/braintrust/wrappers/litellm.py @@ -1,667 +1,12 @@ -from __future__ import annotations +"""Compatibility re-exports — implementation lives in braintrust.integrations.litellm.""" -import time -from collections.abc import AsyncGenerator, Callable, Generator -from types import TracebackType -from typing import Any +from braintrust.integrations.litellm import ( + patch_litellm, # noqa: F401 + wrap_litellm, # noqa: F401 +) -from braintrust.logger import Span, start_span -from braintrust.span_types import SpanTypeAttribute -from braintrust.util import is_numeric, merge_dicts - -X_LEGACY_CACHED_HEADER = "x-cached" -X_CACHED_HEADER = "x-bt-cached" - - -# LiteLLM's representation to Braintrust's representation -TOKEN_NAME_MAP: dict[str, str] = { - # chat API - "total_tokens": "tokens", - "prompt_tokens": "prompt_tokens", - "completion_tokens": "completion_tokens", - # responses API - "tokens": "tokens", - "input_tokens": "prompt_tokens", - "output_tokens": "completion_tokens", -} - -TOKEN_PREFIX_MAP: dict[str, str] = { - "input": "prompt", - "output": "completion", -} - - -class NamedWrapper: - """Wrapper that preserves access to the original wrapped object's attributes.""" - - def __init__(self, wrapped: Any) -> None: - self.__wrapped = wrapped - - def __getattr__(self, name: str) -> Any: - return getattr(self.__wrapped, name) - - -class 
AsyncResponseWrapper: - """Wrapper that properly preserves async context manager behavior for LiteLLM responses.""" - - def __init__(self, response: Any) -> None: - self._response = response - - async def __aenter__(self) -> Any: - if hasattr(self._response, "__aenter__"): - return await self._response.__aenter__() - return self._response - - async def __aexit__( - self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None - ) -> bool | None: - if hasattr(self._response, "__aexit__"): - return await self._response.__aexit__(exc_type, exc_val, exc_tb) - return None - - def __aiter__(self) -> AsyncGenerator[Any, None]: - if hasattr(self._response, "__aiter__"): - return self._response.__aiter__() - raise TypeError("Response object is not an async iterator") - - async def __anext__(self) -> Any: - if hasattr(self._response, "__anext__"): - return await self._response.__anext__() - raise StopAsyncIteration - - def __getattr__(self, name: str) -> Any: - return getattr(self._response, name) - - -class CompletionWrapper: - """Wrapper for LiteLLM completion functions with tracing support.""" - - def __init__(self, completion_fn: Callable[..., Any] | None, acompletion_fn: Callable[..., Any] | None) -> None: - self.completion_fn = completion_fn - self.acompletion_fn = acompletion_fn - - def _handle_streaming_response( - self, raw_response: Any, span: Span, start_time: float, is_async: bool = False - ) -> AsyncResponseWrapper | Generator[Any, None, None]: - """Handle streaming response for both sync and async cases.""" - if is_async: - - async def async_gen() -> AsyncGenerator[Any, None]: - try: - first = True - all_results: list[dict[str, Any]] = [] - async for item in raw_response: - if first: - span.log(metrics={"time_to_first_token": time.time() - start_time}) - first = False - all_results.append(_try_to_dict(item)) - yield item - - span.log(**self._postprocess_streaming_results(all_results)) - finally: - span.end() - - 
streamer = async_gen() - return AsyncResponseWrapper(streamer) - else: - - def sync_gen() -> Generator[Any, None, None]: - try: - first = True - all_results: list[dict[str, Any]] = [] - for item in raw_response: - if first: - span.log(metrics={"time_to_first_token": time.time() - start_time}) - first = False - all_results.append(_try_to_dict(item)) - yield item - - span.log(**self._postprocess_streaming_results(all_results)) - finally: - span.end() - - return sync_gen() - - def _handle_non_streaming_response(self, raw_response: Any, span: Span, start_time: float) -> Any: - """Handle non-streaming response.""" - log_response = _try_to_dict(raw_response) - metrics = _parse_metrics_from_usage(log_response.get("usage", {})) - metrics["time_to_first_token"] = time.time() - start_time - span.log(metrics=metrics, output=log_response["choices"]) - return raw_response - - def completion(self, *args: Any, **kwargs: Any) -> Any: - """Sync completion with tracing.""" - updated_span_payload = _update_span_payload_from_params(kwargs, input_key="messages") - is_streaming = kwargs.get("stream", False) - - span = start_span( - **merge_dicts( - dict(name="Completion", span_attributes={"type": SpanTypeAttribute.LLM}), updated_span_payload - ) - ) - should_end = True - - try: - start = time.time() - completion_response = self.completion_fn(*args, **kwargs) - # if hasattr(completion_response, "parse"): - # raw_response = completion_response.parse() - # log_headers(completion_response, span) - # else: - # raw_response = completion_response - - if is_streaming: - should_end = False - return self._handle_streaming_response(completion_response, span, start, is_async=False) - else: - return self._handle_non_streaming_response(completion_response, span, start) - finally: - if should_end: - span.end() - - async def acompletion(self, *args: Any, **kwargs: Any) -> Any: - """Async completion with tracing.""" - updated_span_payload = _update_span_payload_from_params(kwargs, input_key="messages") 
- is_streaming = kwargs.get("stream", False) - - span = start_span( - **merge_dicts( - dict(name="Completion", span_attributes={"type": SpanTypeAttribute.LLM}), updated_span_payload - ) - ) - should_end = True - - try: - start = time.time() - completion_response = await self.acompletion_fn(*args, **kwargs) - - # if hasattr(completion_response, "parse"): - # raw_response = completion_response.parse() - # log_headers(completion_response, span) - # else: - # raw_response = completion_response - - if is_streaming: - should_end = False - return self._handle_streaming_response(completion_response, span, start, is_async=True) - else: - return self._handle_non_streaming_response(completion_response, span, start) - finally: - if should_end: - span.end() - - @classmethod - def _postprocess_streaming_results(cls, all_results: list[dict[str, Any]]) -> dict[str, Any]: - """Process streaming results to extract final response.""" - role = None - content = None - tool_calls: list[Any] | None = None - finish_reason = None - metrics: dict[str, float] = {} - - for result in all_results: - usage = result.get("usage") - if usage: - metrics.update(_parse_metrics_from_usage(usage)) - - choices = result["choices"] - if not choices: - continue - delta = choices[0]["delta"] - if not delta: - continue - - if role is None and delta.get("role") is not None: - role = delta.get("role") - - if delta.get("finish_reason") is not None: - finish_reason = delta.get("finish_reason") - - if delta.get("content") is not None: - content = (content or "") + delta.get("content") - - if delta.get("tool_calls") is not None: - delta_tool_calls = delta.get("tool_calls") - if not delta_tool_calls: - continue - tool_delta = delta_tool_calls[0] - - # pylint: disable=unsubscriptable-object - if not tool_calls or (tool_delta.get("id") and tool_calls[-1]["id"] != tool_delta.get("id")): - tool_calls = (tool_calls or []) + [ - { - "id": tool_delta.get("id"), - "type": tool_delta.get("type"), - "function": 
tool_delta.get("function"), - } - ] - else: - # pylint: disable=unsubscriptable-object - tool_calls[-1]["function"]["arguments"] += delta["tool_calls"][0]["function"]["arguments"] - - return { - "metrics": metrics, - "output": [ - { - "index": 0, - "message": {"role": role, "content": content, "tool_calls": tool_calls}, - "logprobs": None, - "finish_reason": finish_reason, - } - ], - } - - -class ResponsesWrapper: - """Wrapper for LiteLLM responses functions with tracing support.""" - - def __init__(self, responses_fn: Callable[..., Any] | None, aresponses_fn: Callable[..., Any] | None) -> None: - self.responses_fn = responses_fn - self.aresponses_fn = aresponses_fn - - def _handle_streaming_response( - self, raw_response: Any, span: Span, start_time: float, is_async: bool = False - ) -> AsyncResponseWrapper | Generator[Any, None, None]: - """Handle streaming response for both sync and async cases.""" - if is_async: - - async def async_gen() -> AsyncGenerator[Any, None]: - try: - first = True - all_results: list[dict[str, Any]] = [] - async for item in raw_response: - if first: - span.log(metrics={"time_to_first_token": time.time() - start_time}) - first = False - all_results.append(item) - yield item - - span.log(**self._postprocess_streaming_results(all_results)) - finally: - span.end() - - streamer = async_gen() - return AsyncResponseWrapper(streamer) - else: - - def sync_gen() -> Generator[Any, None, None]: - try: - first = True - all_results: list[dict[str, Any]] = [] - for item in raw_response: - if first: - span.log(metrics={"time_to_first_token": time.time() - start_time}) - first = False - all_results.append(item) - yield item - - span.log(**self._postprocess_streaming_results(all_results)) - finally: - span.end() - - return sync_gen() - - def _handle_non_streaming_response(self, raw_response: Any, span: Span, start_time: float) -> Any: - """Handle non-streaming response.""" - log_response = _try_to_dict(raw_response) - metrics = 
_parse_metrics_from_usage(log_response.get("usage", {})) - metrics["time_to_first_token"] = time.time() - start_time - span.log(metrics=metrics, output=log_response["output"]) - return raw_response - - def responses(self, *args: Any, **kwargs: Any) -> Any: - """Sync responses with tracing.""" - updated_span_payload = _update_span_payload_from_params(kwargs, input_key="input") - is_streaming = kwargs.get("stream", False) - - span = start_span( - **merge_dicts(dict(name="Response", span_attributes={"type": SpanTypeAttribute.LLM}), updated_span_payload) - ) - should_end = True - - try: - start = time.time() - response = self.responses_fn(*args, **kwargs) - - if is_streaming: - should_end = False - return self._handle_streaming_response(response, span, start, is_async=False) - else: - return self._handle_non_streaming_response(response, span, start) - finally: - if should_end: - span.end() - - async def aresponses(self, *args: Any, **kwargs: Any) -> Any: - """Async completion with tracing.""" - updated_span_payload = _update_span_payload_from_params(kwargs, input_key="input") - is_streaming = kwargs.get("stream", False) - - span = start_span( - **merge_dicts(dict(name="Response", span_attributes={"type": SpanTypeAttribute.LLM}), updated_span_payload) - ) - should_end = True - - try: - start = time.time() - response = await self.aresponses_fn(*args, **kwargs) - - if is_streaming: - should_end = False - return self._handle_streaming_response(response, span, start, is_async=True) - else: - return self._handle_non_streaming_response(response, span, start) - finally: - if should_end: - span.end() - - @classmethod - def _postprocess_streaming_results(cls, all_results: list[Any]) -> dict[str, Any]: - role = None - content = None - tool_calls = None - finish_reason = None - metrics = {} - output = [] - for result in all_results: - usage = None - if hasattr(result, "usage"): - usage = getattr(result, "usage") - elif result.type == "response.completed" and hasattr(result, 
"response"): - usage = getattr(result.response, "usage") - - if usage: - parsed_metrics = _parse_metrics_from_usage(usage) - metrics.update(parsed_metrics) - - if result.type == "response.output_item.added": - output.append({"id": result.item.get("id"), "type": result.item.get("type")}) - continue - - if not hasattr(result, "output_index"): - continue - - output_index = result.output_index - current_output = output[output_index] - if result.type == "response.output_item.done": - current_output["status"] = result.item.get("status") - continue - - if result.type == "response.output_item.delta": - current_output["delta"] = result.delta - continue - - if hasattr(result, "content_index"): - if "content" not in current_output: - current_output["content"] = [] - content_index = result.content_index - if content_index == len(current_output["content"]): - current_output["content"].append({}) - current_content = current_output["content"][content_index] - if hasattr(result, "delta") and result.delta: - current_content["text"] = (current_content.get("text") or "") + result.delta - - if result.type == "response.output_text.annotation.added": - annotation_index = result.annotation_index - if "annotations" not in current_content: - current_content["annotations"] = [] - if annotation_index == len(current_content["annotations"]): - current_content["annotations"].append({}) - current_content["annotations"][annotation_index] = _try_to_dict(result.annotation) - - return { - "metrics": metrics, - "output": output, - } - - -class EmbeddingWrapper: - """Wrapper for LiteLLM embedding functions.""" - - def __init__(self, embedding_fn: Callable[..., Any] | None) -> None: - self.embedding_fn = embedding_fn - - def embedding(self, *args: Any, **kwargs: Any) -> Any: - """Sync embedding with tracing.""" - updated_span_payload = _update_span_payload_from_params(kwargs, input_key="input") - - with start_span( - **merge_dicts( - dict(name="Embedding", span_attributes={"type": 
SpanTypeAttribute.LLM}), updated_span_payload - ) - ) as span: - embedding_response = self.embedding_fn(*args, **kwargs) - log_response = _try_to_dict(embedding_response) - self._process_output(log_response, span) - return embedding_response - - def _process_output(self, response: dict[str, Any], span: Span) -> None: - """Process embedding response and log metrics.""" - usage = response.get("usage") - metrics = _parse_metrics_from_usage(usage) - span.log( - metrics=metrics, - # TODO: Add a flag to control whether to log the full embedding vector, - # possibly w/ JSON compression. - output={"embedding_length": len(response["data"][0]["embedding"])}, - ) - - -class ModerationWrapper: - """Wrapper for LiteLLM moderation functions.""" - - def __init__(self, moderation_fn: Callable[..., Any] | None) -> None: - self.moderation_fn = moderation_fn - - def moderation(self, *args: Any, **kwargs: Any) -> Any: - """Sync moderation with tracing.""" - updated_span_payload = _update_span_payload_from_params(kwargs, input_key="input") - - with start_span( - **merge_dicts( - dict(name="Moderation", span_attributes={"type": SpanTypeAttribute.LLM}), updated_span_payload - ) - ) as span: - moderation_response = self.moderation_fn(*args, **kwargs) - log_response = _try_to_dict(moderation_response) - self._process_output(log_response, span) - return moderation_response - - def _process_output(self, response: dict[str, Any], span: Span) -> None: - """Process moderation response and log metrics.""" - usage = response.get("usage") - metrics = _parse_metrics_from_usage(usage) - span.log( - metrics=metrics, - # TODO: Add a flag to control whether to log the full embedding vector, - # possibly w/ JSON compression. 
- output=response["results"], - ) - - -class LiteLLMWrapper(NamedWrapper): - """Main wrapper for the LiteLLM module.""" - - def __init__(self, litellm_module: Any) -> None: - super().__init__(litellm_module) - self._completion_wrapper = CompletionWrapper(litellm_module.completion, None) - self._acompletion_wrapper = CompletionWrapper(None, litellm_module.acompletion) - self._responses_wrapper = ResponsesWrapper(litellm_module.responses, None) - self._aresponses_wrapper = ResponsesWrapper(None, litellm_module.aresponses) - self._embedding_wrapper = EmbeddingWrapper(litellm_module.embedding) - self._moderation_wrapper = ModerationWrapper(litellm_module.moderation) - - def completion(self, *args: Any, **kwargs: Any) -> Any: - """Sync completion with tracing.""" - return self._completion_wrapper.completion(*args, **kwargs) - - async def acompletion(self, *args: Any, **kwargs: Any) -> Any: - """Async completion with tracing.""" - return await self._acompletion_wrapper.acompletion(*args, **kwargs) - - def responses(self, *args: Any, **kwargs: Any) -> Any: - """Sync responses with tracing.""" - return self._responses_wrapper.responses(*args, **kwargs) - - async def aresponses(self, *args: Any, **kwargs: Any) -> Any: - """Async responses with tracing.""" - return await self._aresponses_wrapper.aresponses(*args, **kwargs) - - def embedding(self, *args: Any, **kwargs: Any) -> Any: - """Sync embedding with tracing.""" - return self._embedding_wrapper.embedding(*args, **kwargs) - - def moderation(self, *args: Any, **kwargs: Any) -> Any: - """Sync moderation with tracing.""" - return self._moderation_wrapper.moderation(*args, **kwargs) - - -def wrap_litellm(litellm_module: Any) -> LiteLLMWrapper: - """ - Wrap the litellm module to add tracing. - If Braintrust is not configured, nothing will be traced. 
- - :param litellm_module: The litellm module - :return: Wrapped litellm module with tracing - """ - return LiteLLMWrapper(litellm_module) - - -def _update_span_payload_from_params(params: dict[str, Any], input_key: str = "input") -> dict[str, Any]: - """Updates the span payload with the parameters into LiteLLM's completion/acompletion methods.""" - span_info_d = params.pop("span_info", {}) - - params = prettify_params(params) - input_data = params.pop(input_key, None) - model = params.pop("model", None) - - return merge_dicts( - span_info_d, - {"input": input_data, "metadata": {**params, "provider": "litellm", "model": model}}, - ) - - -def _parse_metrics_from_usage(usage: Any) -> dict[str, Any]: - """Parse usage metrics from API response.""" - # For simplicity, this function handles all the different APIs - metrics: dict[str, Any] = {} - - if not usage: - return metrics - - # This might be a dict or a Usage object that can be cast to a dict - usage = _try_to_dict(usage) - if not isinstance(usage, dict): - return metrics # unexpected - - for oai_name, value in usage.items(): - if oai_name.endswith("_tokens_details"): - # handle `_tokens_detail` dicts - if not isinstance(value, dict): - continue # unexpected - raw_prefix = oai_name[: -len("_tokens_details")] - prefix = TOKEN_PREFIX_MAP.get(raw_prefix, raw_prefix) - for k, v in value.items(): - if is_numeric(v): - metrics[f"{prefix}_{k}"] = v - elif is_numeric(value): - name = TOKEN_NAME_MAP.get(oai_name, oai_name) - metrics[name] = value - - return metrics - - -def prettify_params(params: dict[str, Any]) -> dict[str, Any]: - """Clean up parameters by filtering out NOT_GIVEN values and serializing response_format.""" - # Filter out NOT_GIVEN parameters - # https://linear.app/braintrustdata/issue/BRA-2467 - # ret = {k: v for k, v in params.items() if not _is_not_given(v)} - ret = {k: v for k, v in params.items()} - - if "response_format" in ret: - ret["response_format"] = 
serialize_response_format(ret["response_format"]) - return ret - - -def _try_to_dict(obj: Any) -> dict[str, Any] | Any: - """Try to convert an object to a dictionary.""" - if isinstance(obj, dict): - return obj - # convert a pydantic object to a dict - if hasattr(obj, "model_dump") and callable(obj.model_dump): - try: - result = obj.model_dump() - if isinstance(result, dict): - return result - except Exception: - pass - # deprecated pydantic method, try model_dump first. - if hasattr(obj, "dict") and callable(obj.dict): - try: - result = obj.dict() - if isinstance(result, dict): - return result - except Exception: - pass - return obj - - -def serialize_response_format(response_format: Any) -> Any: - """Serialize response format for logging.""" - try: - from pydantic import BaseModel - except ImportError: - return response_format - - if isinstance(response_format, type) and issubclass(response_format, BaseModel): - return dict( - type="json_schema", - json_schema=dict( - name=response_format.__name__, - schema=response_format.model_json_schema(), - ), - ) - else: - return response_format - - -def patch_litellm() -> bool: - """ - Patch LiteLLM to add Braintrust tracing. - - This wraps litellm.completion and litellm.acompletion to automatically - create Braintrust spans with detailed token metrics, timing, and costs. - - Returns: - True if LiteLLM was patched (or already patched), False if LiteLLM is not installed. 
-
-    Example:
-    ```python
-    import braintrust
-    braintrust.patch_litellm()
-
-    import litellm
-    from braintrust import init_logger
-
-    logger = init_logger(project="my-project")
-    response = litellm.completion(
-        model="gpt-4o-mini",
-        messages=[{"role": "user", "content": "Hello"}]
-    )
-    ```
-    """
-    try:
-        import litellm
-
-        if not hasattr(litellm, "_braintrust_wrapped"):
-            wrapped = wrap_litellm(litellm)
-            litellm.completion = wrapped.completion
-            litellm.acompletion = wrapped.acompletion
-            litellm.responses = wrapped.responses
-            litellm.aresponses = wrapped.aresponses
-            litellm._braintrust_wrapped = True
-        return True
-    except ImportError:
-        return False
+__all__ = [
+    "patch_litellm",
+    "wrap_litellm",
+]
diff --git a/py/src/braintrust/wrappers/test_utils.py b/py/src/braintrust/wrappers/test_utils.py
index 5158404a..4a38debc 100644
--- a/py/src/braintrust/wrappers/test_utils.py
+++ b/py/src/braintrust/wrappers/test_utils.py
@@ -71,7 +71,7 @@ def assert_metrics_are_valid(metrics, start=None, end=None):
 
 
 @contextmanager
-def autoinstrument_test_context(cassette_name: str, *, use_vcr: bool = True):
+def autoinstrument_test_context(cassette_name: str, *, use_vcr: bool = True, cassettes_dir: Path | None = None):
     """Context manager for auto_instrument tests.
 
     Sets up the shared memory_logger context and, by default, VCR.
@@ -80,12 +80,16 @@ def autoinstrument_test_context(cassette_name: str, *, use_vcr: bool = True):
     non-VCR mechanism, such as the Claude Agent SDK subprocess cassette
     transport.
 
+    Use ``cassettes_dir`` to override the cassette directory (e.g. when
+    cassettes live next to an integration package rather than in
+    ``wrappers/cassettes/``).
+
     Usage:
         with autoinstrument_test_context("test_auto_openai") as memory_logger:
             # make API call
             spans = memory_logger.pop()
     """
-    cassette_path = CASSETTES_DIR / f"{cassette_name}.yaml"
+    cassette_path = (cassettes_dir or CASSETTES_DIR) / f"{cassette_name}.yaml"
     init_test_logger("test-auto-instrument")