From 00eb0232afc245b22f3468a0bc6b28c1ca16f93d Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 5 Mar 2026 22:35:14 -0800 Subject: [PATCH] feat: Trigger traceCallLlm to set call_llm attributes before span ends PiperOrigin-RevId: 879435432 --- .../adk/flows/llmflows/BaseLlmFlow.java | 18 +-- .../com/google/adk/telemetry/Tracing.java | 109 ++++++++++-------- .../adk/telemetry/ContextPropagationTest.java | 2 +- 3 files changed, 71 insertions(+), 58 deletions(-) diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 1249728d8..52045a02e 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -188,20 +188,22 @@ private Flowable callLlm( context, llmRequestBuilder, eventForCallbackUsage, exception) .switchIfEmpty(Single.error(exception)) .toFlowable()) - .doOnNext( - llmResp -> - Tracing.traceCallLlm( - context, - eventForCallbackUsage.id(), - llmRequestBuilder.build(), - llmResp)) .doOnError( error -> { Span span = Span.current(); span.setStatus(StatusCode.ERROR, error.getMessage()); span.recordException(error); }) - .compose(Tracing.trace("call_llm")) + .compose( + Tracing.trace("call_llm") + .onSuccess( + (span, llmResp) -> + Tracing.traceCallLlm( + span, + context, + eventForCallbackUsage.id(), + llmRequestBuilder.build(), + llmResp))) .concatMap( llmResp -> handleAfterModelCallback(context, llmResp, eventForCallbackUsage) diff --git a/core/src/main/java/com/google/adk/telemetry/Tracing.java b/core/src/main/java/com/google/adk/telemetry/Tracing.java index 07a640c37..07c277467 100644 --- a/core/src/main/java/com/google/adk/telemetry/Tracing.java +++ b/core/src/main/java/com/google/adk/telemetry/Tracing.java @@ -54,6 +54,7 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Supplier; import org.reactivestreams.Publisher; @@ -292,62 +293,49 @@ private static Map buildLlmRequestForTrace(LlmRequest llmRequest * @param llmResponse The LLM response object. */ public static void traceCallLlm( + Span span, InvocationContext invocationContext, String eventId, LlmRequest llmRequest, LlmResponse llmResponse) { - getValidCurrentSpan("traceCallLlm") - .ifPresent( - span -> { - span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); - llmRequest - .model() - .ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); + span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); + llmRequest.model().ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); - setInvocationAttributes(span, invocationContext, eventId); + setInvocationAttributes(span, invocationContext, eventId); - setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest)); - setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse); + setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest)); + setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse); - llmRequest - .config() - .ifPresent( - config -> { - config - .topP() - .ifPresent( - topP -> - span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue())); - config - .maxOutputTokens() - .ifPresent( - maxTokens -> - span.setAttribute( - GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue())); - }); - llmResponse - .usageMetadata() + llmRequest + .config() + .ifPresent( + config -> { + config + .topP() + .ifPresent(topP -> span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue())); + config + .maxOutputTokens() .ifPresent( - usage -> { - usage - .promptTokenCount() - .ifPresent( - tokens -> - span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens)); - usage - .candidatesTokenCount() - .ifPresent( - tokens -> - span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens)); - }); - llmResponse - .finishReason() - .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT)) + maxTokens -> + span.setAttribute(GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue())); + }); + llmResponse + .usageMetadata() + .ifPresent( + usage -> { + usage + .promptTokenCount() + .ifPresent(tokens -> span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens)); + usage + .candidatesTokenCount() .ifPresent( - reason -> - span.setAttribute( - GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason))); + tokens -> span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens)); }); + llmResponse + .finishReason() + .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT)) + .ifPresent( + reason -> span.setAttribute(GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason))); } /** @@ -472,6 +460,7 @@ public static final class TracerProvider private final String spanName; private Context explicitParentContext; private final List> spanConfigurers = new ArrayList<>(); + private BiConsumer onSuccessConsumer; private TracerProvider(String spanName) { this.spanName = spanName; @@ -491,6 +480,16 @@ public TracerProvider setParent(Context parentContext) { return this; } + /** + * Registers a callback to be executed with the span and the result item when the stream emits a + * success value. + */ + @CanIgnoreReturnValue + public TracerProvider onSuccess(BiConsumer consumer) { + this.onSuccessConsumer = consumer; + return this; + } + private Context getParentContext() { return explicitParentContext != null ? explicitParentContext : Context.current(); } @@ -521,7 +520,11 @@ public Publisher apply(Flowable upstream) { return Flowable.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + Flowable pipeline = upstream.doOnSubscribe(s -> lifecycle.start()); + if (onSuccessConsumer != null) { + pipeline = pipeline.doOnNext(t -> onSuccessConsumer.accept(lifecycle.span, t)); + } + return pipeline.doFinally(lifecycle::end); }); } @@ -530,7 +533,11 @@ public SingleSource apply(Single upstream) { return Single.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + Single pipeline = upstream.doOnSubscribe(s -> lifecycle.start()); + if (onSuccessConsumer != null) { + pipeline = pipeline.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t)); + } + return pipeline.doFinally(lifecycle::end); }); } @@ -539,7 +546,11 @@ public MaybeSource apply(Maybe upstream) { return Maybe.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + Maybe pipeline = upstream.doOnSubscribe(s -> lifecycle.start()); + if (onSuccessConsumer != null) { + pipeline = pipeline.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t)); + } + return pipeline.doFinally(lifecycle::end); }); } diff --git a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java index 9439fe718..f809193cf 100644 --- a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java +++ b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java @@ -503,7 +503,7 @@ public void testTraceCallLlm() { .totalTokenCount(30) .build()) .build(); - Tracing.traceCallLlm(buildInvocationContext(), "event-1", llmRequest, llmResponse); + Tracing.traceCallLlm(span, buildInvocationContext(), "event-1", llmRequest, llmResponse); } finally { span.end(); }