diff --git a/.agents/skills/sdk-integrations/SKILL.md b/.agents/skills/sdk-integrations/SKILL.md index 754b61f7..7a9ca12e 100644 --- a/.agents/skills/sdk-integrations/SKILL.md +++ b/.agents/skills/sdk-integrations/SKILL.md @@ -13,12 +13,19 @@ If the provider already has a real implementation under `py/src/braintrust/wrapp Start from one structural reference and one patching reference instead of designing from scratch: -- ADK (`py/src/braintrust/integrations/adk/`) for direct method patching, `target_module`, `CompositeFunctionWrapperPatcher`, manual `wrap_*()` helpers, and priority-based context propagation. +- ADK (`py/src/braintrust/integrations/adk/`) for direct method patching, `target_module`, `CompositeFunctionWrapperPatcher`, manual `wrap_*()` helpers, priority-based context propagation, and input-side `inline_data` to `Attachment` conversion. - Agno (`py/src/braintrust/integrations/agno/`) for multi-target patching, version-conditional fallbacks with `superseded_by`, and providers that need several related patchers. - Anthropic (`py/src/braintrust/integrations/anthropic/`) for constructor patching and a compact provider package with a small public surface. +- Google GenAI (`py/src/braintrust/integrations/google_genai/`) for multimodal serialization, generated media outputs, and output-side `Attachment` handling. Match an existing repo pattern unless the target provider forces a different shape. +Choose the example based on the hardest part of the task, not just provider similarity: + +- If the task is mostly about patcher topology, copy the closest patcher layout first. +- If the task is mostly about traced payload shaping, copy the closest tracing implementation first. +- If the task involves generated media or multimodal payloads, start from ADK or Google GenAI before looking at simpler text-only integrations. 
+ ## Read First Always read: @@ -42,6 +49,25 @@ Read when relevant: - `py/src/braintrust/conftest.py` for VCR behavior - `py/src/braintrust/integrations/auto_test_scripts/` for subprocess auto-instrument tests - `py/src/braintrust/integrations/adk/test_adk.py` and `py/src/braintrust/integrations/anthropic/test_anthropic.py` for test layout patterns +- `py/src/braintrust/integrations/adk/tracing.py` and `py/src/braintrust/integrations/google_genai/tracing.py` when the provider accepts binary inputs, emits generated files, or otherwise needs `Attachment` objects in traced input/output + +## Working Sequence + +Use this order unless the task is obviously narrower: + +1. Read the nearest provider package and the shared integration primitives. +2. Decide which public surface is being patched: constructor, top-level function, client method, stream method, or manual `wrap_*()` helper. +3. Decide what the span should look like before writing patchers: + - what belongs in `input` + - what belongs in `output` + - what belongs in `metadata` + - what belongs in `metrics` +4. Implement or update patchers. +5. Implement or update tracing helpers. +6. Add or update focused tests in the provider package. +7. Run the narrowest nox session first, then expand only if shared code changed. + +Do not start by wiring patchers and only later asking what the logged span should contain. The traced shape should drive the tracing helper design from the start. ## Route The Task @@ -66,6 +92,7 @@ Read when relevant: 2. Change only the affected patchers, tracing helpers, exports, tests, and cassettes. 3. Preserve the provider's public setup and `wrap_*()` surface unless the task explicitly changes it. 4. Keep repo-level changes narrow; do not touch `auto.py`, `integrations/__init__.py`, or `py/noxfile.py` unless the task actually requires it. +5. Preserve existing span shape conventions unless the task is intentionally improving or correcting them. 
### `auto_instrument()` only @@ -102,6 +129,77 @@ Keep span creation, metadata extraction, stream aggregation, error logging, and Preserve provider behavior. Do not let tracing-only code change provider return values, control flow, or error behavior except where the task explicitly requires it. +Generate structured spans. Do not pass raw `args` and `kwargs` straight into traced spans unless the provider API already exposes the exact stable schema you want to log. Instead: + +- Build a provider-shaped `input` object that names the important request fields explicitly, for example model, messages/contents, prompt, config, tools, or options. +- Build an `output` object that captures the useful response payload in normalized form instead of logging opaque SDK objects. +- Put secondary facts in `metadata`, such as provider ids, finish reasons, model versions, safety attributes, or normalized request/response annotations that are useful but not the primary payload. +- Put timings and token/accounting values in `metrics`, such as `start`, `end`, `duration`, `time_to_first_token`, `prompt_tokens`, `completion_tokens`, and `tokens`. +- Drop noisy transport-level or duplicate fields rather than mirroring the full raw call surface. +- Add small provider-local helpers in `tracing.py` to extract `input`, `output`, `metadata`, and `metrics` before opening or closing spans. + +Aim for spans that are readable in the UI without requiring someone to reverse-engineer the provider SDK's calling convention from positional arguments. + +Shape spans by semantics, not by the provider SDK object model: + +- `input` is the meaningful request a human would describe, not the raw Python call signature. +- `output` is the meaningful result, not a provider response class dumped wholesale. +- `metadata` is for supporting context that helps interpretation but is not the main payload. +- `metrics` is for timings, token counts, and similar numeric accounting. 
+ +Good span shaping usually means: + +- flattening positional arguments into named fields +- omitting duplicate values that appear in both request and response objects +- normalizing provider-specific classes into dicts/lists/scalars +- aggregating streaming chunks into one final `output` plus stream-specific `metrics` +- preserving useful provider identifiers without leaking transport noise + +Use provider-local helper functions instead of building spans inline inside wrappers. A good pattern is: + +```python +def _prepare_traced_call(args: list[Any], kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: + ... + + +def _process_result(result: Any, start: float) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: + ... +``` + +Keep wrapper bodies thin: prepare input, open the span, call the provider, normalize the result, then log `output`, `metadata`, and `metrics`. + +When deciding where a field belongs: + +- Put it in `input` if the caller intentionally supplied it. +- Put it in `output` if it is the core result a user would care about. +- Put it in `metadata` if it explains the result but is not the result itself. +- Put it in `metrics` if it is numeric operational accounting or timing. +- Put it in `error` if the call failed and you want the span to record the exception or failure message instead of pretending the failure is ordinary output. + +Distinguish span payload fields from span setup fields: + +- Treat `input`, `output`, `metadata`, `metrics`, and `error` as the main logged payload fields. +- Treat `name` plus `type` or `span_attributes` as span identity/classification, not as payload. +- Use `parent` only when you need to attach the span to an explicit exported parent instead of relying on current-span context. +- Use `start_time` when the true start happened before the wrapper got control and you need accurate duration or time-to-first-token accounting. 
+ +Examples: + +- prompt/model/tools/config belong in `input` +- generated text, tool calls, embeddings summary, generated images summary, or normalized message content belong in `output` +- provider request ids, finish reasons, safety annotations, cached-hit indicators, or model revision identifiers belong in `metadata` +- token counts, elapsed time, time-to-first-token, retry counts, or billable character counts belong in `metrics` +- exceptions, provider errors, and wrapper failures belong in `error` + +Across the current integrations, `input`/`output`/`metadata`/`metrics` are the common structured logging fields, and `error` is the main additional event field used during failures. Other values should usually live inside one of those containers unless they are truly span-level controls like `name`, `type`, `span_attributes`, `parent`, or `start_time`. + +Treat provider-owned binary payloads as attachments, not raw logged bytes. When traced input or output contains inline media, generated files, or other uploadable content: + +- Convert raw `bytes` into `braintrust.logger.Attachment` objects in provider-local tracing helpers instead of logging raw bytes or large base64 blobs. +- Use the repo's existing message/content shapes when embedding attachments in traced payloads. For multimodal content this is often `{"image_url": {"url": attachment}}`, even when the MIME type is not literally an image. +- Preserve ordinary remote URLs as strings. Only convert provider-owned binary content or data-URL style payloads that Braintrust should upload to object storage. +- Keep structured metadata alongside the attachment, such as MIME type, size, safety attributes, or provider ids, so spans stay inspectable without reading the blob. + Prefer feature detection first and version checks second. Use: - `detect_module_version(...)` @@ -110,6 +208,8 @@ Prefer feature detection first and version checks second. 
Use: Let `BaseIntegration.resolve_patchers()` reject duplicate patcher ids; do not silently paper over duplicates. +If a provider surface has both sync and async variants, try to keep the traced schema aligned across both paths. Differences in implementation are fine; differences in logged shape should be intentional. + ## Patcher Rules Create one patcher per coherent patch target. Split unrelated targets into separate patchers. @@ -154,6 +254,8 @@ Use `@pytest.mark.vcr` for real provider network behavior. Prefer recorded provi - local version-routing logic - patcher existence checks +Write tests against the emitted span shape, not just the provider return value. A tracing change is incomplete if the provider call still works but the logged span becomes noisy, incomplete, or inconsistent. + Cover the surfaces that changed: - direct `wrap_*()` behavior @@ -164,9 +266,26 @@ Cover the surfaces that changed: - idempotence - failure and error logging - patcher resolution and duplicate detection +- attachment conversion for binary inputs or generated media, including assertions that traced payloads contain `Attachment` objects rather than raw bytes +- span structure, including assertions that the traced span exposes meaningful `input`, `output`, `metadata`, and `metrics` rather than opaque raw call arguments + +For span assertions, prefer checking the specific normalized fields that matter: + +- the `input` contains the expected model/messages/prompt/config fields +- the `output` contains normalized provider results rather than opaque SDK instances +- the `metadata` contains finish reasons, ids, or annotations in the expected place +- the `metrics` contain the expected timing or token fields when the provider returns them +- binary payloads are represented as `Attachment` objects where applicable + +If a change affects streaming, verify both: + +- intermediate behavior still returns the provider's expected iterator or async iterator +- the final logged span contains the 
aggregated `output` and stream-specific `metrics` Keep VCR cassettes in `py/src/braintrust/integrations/<provider>/cassettes/`. Re-record only when the behavior change is intentional. +When the provider returns binary HTTP responses or generated media, make cassette sanitization part of the change if needed so recorded fixtures do not store raw file bytes. + When choosing commands, confirm the real session name in `py/noxfile.py` instead of assuming it matches the provider folder. Examples in this repo include `test_agno`, `test_anthropic`, and `test_google_adk`. ## Commands diff --git a/py/src/braintrust/integrations/google_genai/cassettes/test_generate_images.yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_generate_images.yaml new file mode 100644 index 00000000..7e9a3c7f --- /dev/null +++ b/py/src/braintrust/integrations/google_genai/cassettes/test_generate_images.yaml @@ -0,0 +1,57 @@ +interactions: +- request: + body: '{"instances": [{"prompt": "A watercolor fox in a forest"}], "parameters": + {"sampleCount": 1, "aspectRatio": "1:1", "safetySetting": "BLOCK_LOW_AND_ABOVE", + "includeRaiReason": true}}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Length: + - '181' + Content-Type: + - application/json + Host: + - generativelanguage.googleapis.com + user-agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/models/imagen-4.0-fast-generate-001:predict + response: + body: + string: '{"predictions": [{"bytesBase64Encoded": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==", + "mimeType": "image/png"}]}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - application/json; charset=UTF-8 + Date: + - Fri, 27 Mar 2026 01:18:45 GMT + Server: + - scaffolding on HTTPServer2 + 
Server-Timing: + - gfet4t7; dur=4159 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '2622509' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/google_genai/cassettes/test_generate_images_async.yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_generate_images_async.yaml new file mode 100644 index 00000000..14d47429 --- /dev/null +++ b/py/src/braintrust/integrations/google_genai/cassettes/test_generate_images_async.yaml @@ -0,0 +1,47 @@ +interactions: +- request: + body: '{"instances": [{"prompt": "A watercolor fox in a forest"}], "parameters": + {"sampleCount": 1, "aspectRatio": "1:1", "safetySetting": "BLOCK_LOW_AND_ABOVE", + "includeRaiReason": true}}' + headers: + Content-Type: + - application/json + user-agent: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + x-goog-api-client: + - google-genai-sdk/1.66.0 gl-python/3.13.3 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/models/imagen-4.0-fast-generate-001:predict + response: + body: + string: '{"predictions": [{"bytesBase64Encoded": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==", + "mimeType": "image/png"}]}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Type: + - application/json; charset=UTF-8 + Date: + - Fri, 27 Mar 2026 01:18:49 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=3224 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '2750553' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/google_genai/integration.py b/py/src/braintrust/integrations/google_genai/integration.py 
index eaa5a201..48faf088 100644 --- a/py/src/braintrust/integrations/google_genai/integration.py +++ b/py/src/braintrust/integrations/google_genai/integration.py @@ -8,9 +8,11 @@ AsyncModelsEmbedContentPatcher, AsyncModelsGenerateContentPatcher, AsyncModelsGenerateContentStreamPatcher, + AsyncModelsGenerateImagesPatcher, ModelsEmbedContentPatcher, ModelsGenerateContentPatcher, ModelsGenerateContentStreamPatcher, + ModelsGenerateImagesPatcher, ) @@ -26,7 +28,9 @@ class GoogleGenAIIntegration(BaseIntegration): ModelsGenerateContentPatcher, ModelsGenerateContentStreamPatcher, ModelsEmbedContentPatcher, + ModelsGenerateImagesPatcher, AsyncModelsGenerateContentPatcher, AsyncModelsGenerateContentStreamPatcher, AsyncModelsEmbedContentPatcher, + AsyncModelsGenerateImagesPatcher, ) diff --git a/py/src/braintrust/integrations/google_genai/patchers.py b/py/src/braintrust/integrations/google_genai/patchers.py index 600a604b..7bc895f1 100644 --- a/py/src/braintrust/integrations/google_genai/patchers.py +++ b/py/src/braintrust/integrations/google_genai/patchers.py @@ -6,9 +6,11 @@ _async_embed_content_wrapper, _async_generate_content_stream_wrapper, _async_generate_content_wrapper, + _async_generate_images_wrapper, _embed_content_wrapper, _generate_content_stream_wrapper, _generate_content_wrapper, + _generate_images_wrapper, ) @@ -44,6 +46,15 @@ class ModelsEmbedContentPatcher(FunctionWrapperPatcher): wrapper = _embed_content_wrapper +class ModelsGenerateImagesPatcher(FunctionWrapperPatcher): + """Patch ``Models.generate_images`` for tracing.""" + + name = "google_genai.models.generate_images" + target_module = "google.genai.models" + target_path = "Models.generate_images" + wrapper = _generate_images_wrapper + + # --------------------------------------------------------------------------- # Async Models patchers # --------------------------------------------------------------------------- @@ -74,3 +85,12 @@ class AsyncModelsEmbedContentPatcher(FunctionWrapperPatcher): 
target_module = "google.genai.models" target_path = "AsyncModels.embed_content" wrapper = _async_embed_content_wrapper + + +class AsyncModelsGenerateImagesPatcher(FunctionWrapperPatcher): + """Patch ``AsyncModels.generate_images`` for tracing.""" + + name = "google_genai.async_models.generate_images" + target_module = "google.genai.models" + target_path = "AsyncModels.generate_images" + wrapper = _async_generate_images_wrapper diff --git a/py/src/braintrust/integrations/google_genai/test_google_genai.py b/py/src/braintrust/integrations/google_genai/test_google_genai.py index 9834fe30..f9a869f1 100644 --- a/py/src/braintrust/integrations/google_genai/test_google_genai.py +++ b/py/src/braintrust/integrations/google_genai/test_google_genai.py @@ -1,3 +1,5 @@ +import gzip +import json import os import time from pathlib import Path @@ -5,6 +7,7 @@ import pytest from braintrust import logger from braintrust.integrations.google_genai import setup_genai +from braintrust.logger import Attachment from braintrust.test_helpers import init_test_logger from braintrust.wrappers.test_utils import verify_autoinstrument_script from google.genai import types @@ -14,7 +17,59 @@ PROJECT_NAME = "test-genai-app" MODEL = "gemini-2.0-flash-001" EMBEDDING_MODEL = "gemini-embedding-001" +IMAGE_MODEL = "imagen-4.0-fast-generate-001" FIXTURES_DIR = Path(__file__).parent.parent.parent.parent.parent / "internal/golden/fixtures" +TINY_PNG_BASE64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==" + + +def _sanitize_generate_images_body(value): + if isinstance(value, dict): + return { + key: ( + TINY_PNG_BASE64 + if key == "bytesBase64Encoded" and isinstance(val, str) + else _sanitize_generate_images_body(val) + ) + for key, val in value.items() + } + if isinstance(value, list): + return [_sanitize_generate_images_body(item) for item in value] + return value + + +def _sanitize_generate_images_response(response): + body = response.get("body", {}) + 
payload = body.get("string") + if not payload: + return response + + is_bytes = isinstance(payload, bytes) + is_gzipped = False + + if is_bytes: + raw_payload = payload + if raw_payload[:2] == b"\x1f\x8b": + raw_payload = gzip.decompress(raw_payload) + is_gzipped = True + payload = raw_payload.decode("utf-8") + + try: + parsed = json.loads(payload) + except Exception: + return response + + sanitized = _sanitize_generate_images_body(parsed) + if sanitized == parsed: + return response + + sanitized_payload = json.dumps(sanitized) + if is_bytes: + body["string"] = ( + gzip.compress(sanitized_payload.encode("utf-8")) if is_gzipped else sanitized_payload.encode("utf-8") + ) + else: + body["string"] = sanitized_payload + return response @pytest.fixture(scope="module") @@ -27,14 +82,19 @@ def before_record_request(request): request.method = request.method.upper() return request + def before_record_response(response): + return _sanitize_generate_images_response(response) + return { "record_mode": record_mode, + "decode_compressed_response": True, "filter_headers": [ "authorization", "x-api-key", "x-goog-api-key", ], "before_record_request": before_record_request, + "before_record_response": before_record_response, } @@ -669,6 +729,105 @@ def test_attachment_in_config(memory_logger): assert copied["temperature"] == 0.5 +@pytest.mark.vcr +def test_generate_images(memory_logger): + assert not memory_logger.pop() + + client = Client() + start = time.time() + + response = client.models.generate_images( + model=IMAGE_MODEL, + prompt="A watercolor fox in a forest", + config=types.GenerateImagesConfig( + number_of_images=1, + aspect_ratio="1:1", + safety_filter_level="BLOCK_LOW_AND_ABOVE", + include_rai_reason=True, + ), + ) + end = time.time() + + assert len(response.generated_images) == 1 + assert response.generated_images[0].image + assert response.generated_images[0].image.image_bytes + + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + assert 
span["metadata"]["model"] == IMAGE_MODEL + assert span["input"]["prompt"] == "A watercolor fox in a forest" + assert span["input"]["config"]["number_of_images"] == 1 + assert span["input"]["config"]["aspect_ratio"] == "1:1" + assert span["input"]["config"]["safety_filter_level"] == "BLOCK_LOW_AND_ABOVE" + assert span["input"]["config"]["include_rai_reason"] is True + assert span["output"]["generated_images_count"] == 1 + generated_image = span["output"]["generated_images"][0] + assert generated_image["image_size_bytes"] > 0 + assert generated_image["mime_type"] in {"image/png", "image/jpeg", "image/webp"} + + # Verify the image bytes are stored as an Attachment for upload to object storage + assert "image_url" in generated_image + attachment = generated_image["image_url"]["url"] + assert isinstance(attachment, Attachment) + assert attachment.reference["type"] == "braintrust_attachment" + assert attachment.reference["content_type"] == generated_image["mime_type"] + assert attachment.reference["filename"].startswith("generated_image_") + assert attachment.reference["key"] + + _assert_timing_metrics_are_valid(span["metrics"], start, end) + + +@pytest.mark.vcr +@pytest.mark.asyncio +async def test_generate_images_async(memory_logger): + assert not memory_logger.pop() + + client = Client() + start = time.time() + + response = await client.aio.models.generate_images( + model=IMAGE_MODEL, + prompt="A watercolor fox in a forest", + config=types.GenerateImagesConfig( + number_of_images=1, + aspect_ratio="1:1", + safety_filter_level="BLOCK_LOW_AND_ABOVE", + include_rai_reason=True, + ), + ) + end = time.time() + + assert len(response.generated_images) == 1 + assert response.generated_images[0].image + assert response.generated_images[0].image.image_bytes + + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + assert span["metadata"]["model"] == IMAGE_MODEL + assert span["input"]["prompt"] == "A watercolor fox in a forest" + assert 
span["input"]["config"]["number_of_images"] == 1 + assert span["input"]["config"]["aspect_ratio"] == "1:1" + assert span["input"]["config"]["safety_filter_level"] == "BLOCK_LOW_AND_ABOVE" + assert span["input"]["config"]["include_rai_reason"] is True + assert span["output"]["generated_images_count"] == 1 + generated_image = span["output"]["generated_images"][0] + assert generated_image["image_size_bytes"] > 0 + assert generated_image["mime_type"] in {"image/png", "image/jpeg", "image/webp"} + + # Verify the image bytes are stored as an Attachment for upload to object storage + assert "image_url" in generated_image + attachment = generated_image["image_url"]["url"] + assert isinstance(attachment, Attachment) + assert attachment.reference["type"] == "braintrust_attachment" + assert attachment.reference["content_type"] == generated_image["mime_type"] + assert attachment.reference["filename"].startswith("generated_image_") + assert attachment.reference["key"] + + _assert_timing_metrics_are_valid(span["metrics"], start, end) + + def test_nested_attachments_in_contents(memory_logger): """Test that nested attachments in contents are preserved.""" from braintrust.bt_json import bt_safe_deep_copy diff --git a/py/src/braintrust/integrations/google_genai/tracing.py b/py/src/braintrust/integrations/google_genai/tracing.py index ed7f572e..beb265fe 100644 --- a/py/src/braintrust/integrations/google_genai/tracing.py +++ b/py/src/braintrust/integrations/google_genai/tracing.py @@ -145,6 +145,16 @@ def _prepare_traced_call( return _serialize_input(api_client, input), clean_kwargs +def _prepare_generate_images_traced_call( + api_client: Any, args: list[Any], kwargs: dict[str, Any] +) -> tuple[dict[str, Any], dict[str, Any]]: + del api_client + input, clean_kwargs = _get_args_kwargs(args, kwargs, ["model", "prompt", "config"], ["prompt", "config"]) + if input.get("config") is not None: + input["config"] = bt_safe_deep_copy(input["config"]) + return _clean(input), clean_kwargs + + # 
--------------------------------------------------------------------------- # Metric extraction helpers # --------------------------------------------------------------------------- @@ -222,6 +232,71 @@ def _extract_embed_content_metrics(response: "EmbedContentResponse", start: floa return _clean(metrics) +def _extract_generate_images_output(response: Any) -> dict[str, Any]: + generated_images = getattr(response, "generated_images", None) or [] + serialized_images = [] + + for i, generated_image in enumerate(generated_images): + image = getattr(generated_image, "image", None) + image_bytes = getattr(image, "image_bytes", None) + mime_type = getattr(image, "mime_type", None) + safety_attributes = getattr(generated_image, "safety_attributes", None) + + image_entry: dict[str, Any] = _clean( + { + "mime_type": mime_type, + "gcs_uri": getattr(image, "gcs_uri", None), + "image_size_bytes": len(image_bytes) if image_bytes is not None else None, + "rai_filtered_reason": getattr(generated_image, "rai_filtered_reason", None), + "enhanced_prompt": getattr(generated_image, "enhanced_prompt", None), + "safety_categories": getattr(safety_attributes, "categories", None), + "safety_scores": getattr(safety_attributes, "scores", None), + "safety_content_type": getattr(safety_attributes, "content_type", None), + } + ) + + # Convert image bytes to an Attachment so the SDK uploads them to + # object storage and the Braintrust UI can render the image. 
+ if isinstance(image_bytes, bytes) and mime_type: + extension = mime_type.split("/")[1] if "/" in mime_type else "bin" + filename = f"generated_image_{i}.{extension}" + attachment = Attachment(data=image_bytes, filename=filename, content_type=mime_type) + image_entry["image_url"] = {"url": attachment} + + serialized_images.append(image_entry) + + positive_prompt_safety_attributes = getattr(response, "positive_prompt_safety_attributes", None) + positive_prompt_summary = None + if positive_prompt_safety_attributes is not None: + positive_prompt_summary = _clean( + { + "categories": getattr(positive_prompt_safety_attributes, "categories", None), + "scores": getattr(positive_prompt_safety_attributes, "scores", None), + "content_type": getattr(positive_prompt_safety_attributes, "content_type", None), + } + ) + + return _clean( + { + "generated_images_count": len(generated_images), + "generated_images": serialized_images, + "has_positive_prompt_safety_attributes": positive_prompt_safety_attributes is not None, + "positive_prompt_safety_attributes": positive_prompt_summary, + } + ) + + +def _extract_generate_images_metrics(start: float) -> dict[str, Any]: + end_time = time.time() + return _clean( + dict( + start=start, + end=end_time, + duration=end_time - start, + ) + ) + + # --------------------------------------------------------------------------- # Result processing helpers # --------------------------------------------------------------------------- @@ -235,6 +310,10 @@ def _embed_process_result(result: "EmbedContentResponse", start: float) -> tuple return _extract_embed_content_output(result), _extract_embed_content_metrics(result, start) +def _generate_images_process_result(result: Any, start: float) -> tuple[Any, dict[str, Any]]: + return _extract_generate_images_output(result), _extract_generate_images_metrics(start) + + # --------------------------------------------------------------------------- # Stream aggregation # 
--------------------------------------------------------------------------- @@ -346,8 +425,11 @@ def _run_traced_call( name: str, invoke: Callable[[], Any], process_result: Callable[[Any, float], tuple[Any, dict[str, Any]]], + prepare_call: Callable[ + [Any, list[Any], dict[str, Any]], tuple[dict[str, Any], dict[str, Any]] + ] = _prepare_traced_call, ) -> Any: - input, clean_kwargs = _prepare_traced_call(api_client, args, kwargs) + input, clean_kwargs = prepare_call(api_client, args, kwargs) start = time.time() with start_span(name=name, type=SpanTypeAttribute.LLM, input=input, metadata=clean_kwargs) as span: @@ -365,8 +447,11 @@ async def _run_async_traced_call( name: str, invoke: Callable[[], Awaitable[Any]], process_result: Callable[[Any, float], tuple[Any, dict[str, Any]]], + prepare_call: Callable[ + [Any, list[Any], dict[str, Any]], tuple[dict[str, Any], dict[str, Any]] + ] = _prepare_traced_call, ) -> Any: - input, clean_kwargs = _prepare_traced_call(api_client, args, kwargs) + input, clean_kwargs = prepare_call(api_client, args, kwargs) start = time.time() with start_span(name=name, type=SpanTypeAttribute.LLM, input=input, metadata=clean_kwargs) as span: @@ -468,6 +553,18 @@ def _embed_content_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) ) +def _generate_images_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + return _run_traced_call( + instance._api_client, + args, + kwargs, + name="generate_images", + invoke=lambda: wrapped(*args, **kwargs), + process_result=_generate_images_process_result, + prepare_call=_prepare_generate_images_traced_call, + ) + + async def _async_generate_content_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: return await _run_async_traced_call( instance._api_client, @@ -499,3 +596,15 @@ async def _async_embed_content_wrapper(wrapped: Any, instance: Any, args: Any, k invoke=lambda: wrapped(*args, **kwargs), process_result=_embed_process_result, ) + + +async def 
_async_generate_images_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + return await _run_async_traced_call( + instance._api_client, + args, + kwargs, + name="generate_images", + invoke=lambda: wrapped(*args, **kwargs), + process_result=_generate_images_process_result, + prepare_call=_prepare_generate_images_traced_call, + )