Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java
Original file line number Diff line number Diff line change
Expand Up @@ -188,20 +188,22 @@ private Flowable<LlmResponse> callLlm(
context, llmRequestBuilder, eventForCallbackUsage, exception)
.switchIfEmpty(Single.error(exception))
.toFlowable())
.doOnNext(
llmResp ->
Tracing.traceCallLlm(
context,
eventForCallbackUsage.id(),
llmRequestBuilder.build(),
llmResp))
.doOnError(
error -> {
Span span = Span.current();
span.setStatus(StatusCode.ERROR, error.getMessage());
span.recordException(error);
})
.compose(Tracing.trace("call_llm"))
.compose(
Tracing.<LlmResponse>trace("call_llm")
.onSuccess(
(span, llmResp) ->
Tracing.traceCallLlm(
span,
context,
eventForCallbackUsage.id(),
llmRequestBuilder.build(),
llmResp)))
.concatMap(
llmResp ->
handleAfterModelCallback(context, llmResp, eventForCallbackUsage)
Expand Down
109 changes: 60 additions & 49 deletions core/src/main/java/com/google/adk/telemetry/Tracing.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Supplier;
import org.reactivestreams.Publisher;
Expand Down Expand Up @@ -292,62 +293,49 @@ private static Map<String, Object> buildLlmRequestForTrace(LlmRequest llmRequest
* @param llmResponse The LLM response object.
*/
public static void traceCallLlm(
Span span,
InvocationContext invocationContext,
String eventId,
LlmRequest llmRequest,
LlmResponse llmResponse) {
getValidCurrentSpan("traceCallLlm")
.ifPresent(
span -> {
span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent");
llmRequest
.model()
.ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName));
span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent");
llmRequest.model().ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName));

setInvocationAttributes(span, invocationContext, eventId);
setInvocationAttributes(span, invocationContext, eventId);

setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest));
setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse);
setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest));
setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse);

llmRequest
.config()
.ifPresent(
config -> {
config
.topP()
.ifPresent(
topP ->
span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue()));
config
.maxOutputTokens()
.ifPresent(
maxTokens ->
span.setAttribute(
GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue()));
});
llmResponse
.usageMetadata()
llmRequest
.config()
.ifPresent(
config -> {
config
.topP()
.ifPresent(topP -> span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue()));
config
.maxOutputTokens()
.ifPresent(
usage -> {
usage
.promptTokenCount()
.ifPresent(
tokens ->
span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens));
usage
.candidatesTokenCount()
.ifPresent(
tokens ->
span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens));
});
llmResponse
.finishReason()
.map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT))
maxTokens ->
span.setAttribute(GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue()));
});
llmResponse
.usageMetadata()
.ifPresent(
usage -> {
usage
.promptTokenCount()
.ifPresent(tokens -> span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens));
usage
.candidatesTokenCount()
.ifPresent(
reason ->
span.setAttribute(
GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason)));
tokens -> span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens));
});
llmResponse
.finishReason()
.map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT))
.ifPresent(
reason -> span.setAttribute(GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason)));
}

/**
Expand Down Expand Up @@ -472,6 +460,7 @@ public static final class TracerProvider<T>
private final String spanName;
private Context explicitParentContext;
private final List<Consumer<Span>> spanConfigurers = new ArrayList<>();
private BiConsumer<Span, T> onSuccessConsumer;

private TracerProvider(String spanName) {
this.spanName = spanName;
Expand All @@ -491,6 +480,16 @@ public TracerProvider<T> setParent(Context parentContext) {
return this;
}

/**
* Registers a callback to be executed with the span and the result item when the stream emits a
* success value.
*/
@CanIgnoreReturnValue
public TracerProvider<T> onSuccess(BiConsumer<Span, T> consumer) {
this.onSuccessConsumer = consumer;
return this;
}

private Context getParentContext() {
return explicitParentContext != null ? explicitParentContext : Context.current();
}
Expand Down Expand Up @@ -521,7 +520,11 @@ public Publisher<T> apply(Flowable<T> upstream) {
return Flowable.defer(
() -> {
TracingLifecycle lifecycle = new TracingLifecycle();
return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end);
Flowable<T> pipeline = upstream.doOnSubscribe(s -> lifecycle.start());
if (onSuccessConsumer != null) {
pipeline = pipeline.doOnNext(t -> onSuccessConsumer.accept(lifecycle.span, t));
}
return pipeline.doFinally(lifecycle::end);
});
}

Expand All @@ -530,7 +533,11 @@ public SingleSource<T> apply(Single<T> upstream) {
return Single.defer(
() -> {
TracingLifecycle lifecycle = new TracingLifecycle();
return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end);
Single<T> pipeline = upstream.doOnSubscribe(s -> lifecycle.start());
if (onSuccessConsumer != null) {
pipeline = pipeline.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t));
}
return pipeline.doFinally(lifecycle::end);
});
}

Expand All @@ -539,7 +546,11 @@ public MaybeSource<T> apply(Maybe<T> upstream) {
return Maybe.defer(
() -> {
TracingLifecycle lifecycle = new TracingLifecycle();
return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end);
Maybe<T> pipeline = upstream.doOnSubscribe(s -> lifecycle.start());
if (onSuccessConsumer != null) {
pipeline = pipeline.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t));
}
return pipeline.doFinally(lifecycle::end);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ public void testTraceCallLlm() {
.totalTokenCount(30)
.build())
.build();
Tracing.traceCallLlm(buildInvocationContext(), "event-1", llmRequest, llmResponse);
Tracing.traceCallLlm(span, buildInvocationContext(), "event-1", llmRequest, llmResponse);
} finally {
span.end();
}
Expand Down