From 5ad15a08a9c7f9645f646a5187b322de953de926 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 5 Mar 2026 18:41:27 -0800 Subject: [PATCH] feat: Fixing the spans produced by agent calls to have the right parent spans This change fixes an issue where OpenTelemetry spans generated during agent executions were not correctly associated with their parent spans, leading to fragmented traces. ### Core Changes * **Explicit Context Propagation**: Instead of relying on `Context.current()`, which can be unreliable in asynchronous RxJava streams, the OTel `Context` is now explicitly captured at entry points (like `Runner.runAsync`) and passed down through method signatures in `BaseAgent`, `BaseLlmFlow`, and `Functions`. * **RxJava Context Support**: Introduced `Tracing.withContext(Context)`, a new set of transformers (`FlowableTransformer`, etc.) that re-activates a captured OTel context during the subscription of reactive streams. This ensures that any work done inside `flatMap` or `concatMap` remains within the correct trace. * **Synchronous Scope Management**: Wrapped direct calls to plugins and tool callbacks in `try-with-resources` blocks using `context.makeCurrent()` to ensure the tracing context is active during synchronous execution. * **Tracing Enhancements**: * Updated `TracerProvider` to propagate the agent name via a new `AGENT_NAME_CONTEXT_KEY`. * Improved agent name retrieval from spans using reflection (supporting `ReadableSpan`) when it's not explicitly available in the context. * Modified span lifecycle management to start the span immediately upon subscription setup (via `Flowable.defer`) rather than waiting for `doOnSubscribe`. ### Impacted Areas * **`BaseAgent` & `BaseLlmFlow`**: Now strictly pass the parent context to all internal stages (preprocessing, model calls, postprocessing, and callbacks). * **`Runner`**: Entry points for `runAsync` and `runLive` are now consistently wrapped in an `"invocation"` span that serves as the root for the agent's work. * **`PluginManager`**: Ensures that plugin-provided callbacks are executed within the trace context captured when the callback was triggered. * **`Functions`**: Tool execution, including before/after callbacks and response event building, is now correctly parented. PiperOrigin-RevId: 879355958 --- .../java/com/google/adk/agents/BaseAgent.java | 133 +++++----- .../adk/flows/llmflows/BaseLlmFlow.java | 225 ++++++++++------- .../google/adk/flows/llmflows/Functions.java | 227 ++++++++++-------- .../com/google/adk/plugins/PluginManager.java | 17 +- .../java/com/google/adk/runner/Runner.java | 146 ++++++----- .../com/google/adk/telemetry/Tracing.java | 152 +++++++++++- 6 files changed, 572 insertions(+), 328 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index 226e61abe..2f46b0bd4 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -29,6 +29,7 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.DoNotCall; import com.google.genai.types.Content; +import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; @@ -312,38 +313,57 @@ public Flowable runAsync(InvocationContext parentContext) { private Flowable run( InvocationContext parentContext, Function> runImplementation) { + Context otelParentContext = Context.current(); + InvocationContext invocationContext = createInvocationContext(parentContext); + return Flowable.defer( - () -> { - InvocationContext invocationContext = createInvocationContext(parentContext); - - return callCallback( - beforeCallbacksToFunctions( - invocationContext.pluginManager(), beforeAgentCallback), - invocationContext) - .flatMapPublisher( - beforeEventOpt -> { - if (invocationContext.endInvocation()) { - return Flowable.fromOptional(beforeEventOpt); - } - - Flowable beforeEvents = Flowable.fromOptional(beforeEventOpt); - Flowable mainEvents = - Flowable.defer(() -> runImplementation.apply(invocationContext)); - Flowable afterEvents = - Flowable.defer( - () -> - callCallback( - afterCallbacksToFunctions( - invocationContext.pluginManager(), afterAgentCallback), - invocationContext) - .flatMapPublisher(Flowable::fromOptional)); - - return Flowable.concat(beforeEvents, mainEvents, afterEvents); - }) - .compose( - Tracing.traceAgent( - "invoke_agent " + name(), name(), description(), invocationContext)); - }); + () -> { + return callCallback( + beforeCallbacksToFunctions( + invocationContext.pluginManager(), beforeAgentCallback), + invocationContext, + otelParentContext) + .flatMapPublisher( + beforeEvent -> { + if (invocationContext.endInvocation()) { + return Flowable.just(beforeEvent); + } + + return Flowable.just(beforeEvent) + .concatWith( + runMainAndAfter( + invocationContext, otelParentContext, runImplementation)); + }) + .switchIfEmpty( + Flowable.defer( + () -> + runMainAndAfter( + invocationContext, otelParentContext, runImplementation))); + }) + .compose( + Tracing.traceAgent( + otelParentContext, + "invoke_agent " + name(), + name(), + description(), + invocationContext)); + } + + private Flowable runMainAndAfter( + InvocationContext invocationContext, + Context otelParentContext, + Function> runImplementation) { + Flowable mainEvents = + runImplementation.apply(invocationContext).compose(Tracing.withContext(otelParentContext)); + Flowable afterEvents = + callCallback( + afterCallbacksToFunctions(invocationContext.pluginManager(), afterAgentCallback), + invocationContext, + otelParentContext) + .flatMapPublisher(Flowable::just) + .compose(Tracing.withContext(otelParentContext)); + + return Flowable.concat(mainEvents, afterEvents); } /** @@ -383,13 +403,14 @@ private ImmutableList>> callbacksTo * * @param agentCallbacks Callback functions. * @param invocationContext Current invocation context. - * @return single emitting first event, or empty if none. + * @return Maybe emitting first event, or empty if none. */ - private Single> callCallback( + private Maybe callCallback( List>> agentCallbacks, - InvocationContext invocationContext) { + InvocationContext invocationContext, + Context otelParentContext) { if (agentCallbacks.isEmpty()) { - return Single.just(Optional.empty()); + return Maybe.empty(); } CallbackContext callbackContext = @@ -397,28 +418,26 @@ private Single> callCallback( return Flowable.fromIterable(agentCallbacks) .concatMap( - callback -> { - Maybe maybeContent = callback.apply(callbackContext); - - return maybeContent - .map( - content -> { - invocationContext.setEndInvocation(true); - return Optional.of( - Event.builder() - .id(Event.generateEventId()) - .invocationId(invocationContext.invocationId()) - .author(name()) - .branch(invocationContext.branch()) - .actions(callbackContext.eventActions()) - .content(content) - .build()); - }) - .toFlowable(); - }) + callback -> + callback + .apply(callbackContext) + .map( + content -> { + invocationContext.setEndInvocation(true); + return Event.builder() + .id(Event.generateEventId()) + .invocationId(invocationContext.invocationId()) + .author(name()) + .branch(invocationContext.branch()) + .actions(callbackContext.eventActions()) + .content(content) + .build(); + }) + .toFlowable() + .compose(Tracing.withContext(otelParentContext))) .firstElement() .switchIfEmpty( - Single.defer( + Maybe.defer( () -> { if (callbackContext.state().hasDelta()) { Event.Builder eventBuilder = @@ -429,9 +448,9 @@ private Single> callCallback( .branch(invocationContext.branch()) .actions(callbackContext.eventActions()); - return Single.just(Optional.of(eventBuilder.build())); + return Maybe.just(eventBuilder.build()); } else { - return Single.just(Optional.empty()); + return Maybe.empty(); } })); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 1249728d8..746016c1a 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -92,7 +92,9 @@ public BaseLlmFlow( * events generated by them. */ protected Flowable preprocess( - InvocationContext context, AtomicReference llmRequestRef) { + InvocationContext context, + AtomicReference llmRequestRef, + Context otelParentContext) { LlmAgent agent = (LlmAgent) context.agent(); RequestProcessor toolsProcessor = @@ -104,7 +106,8 @@ protected Flowable preprocess( tool -> tool.processLlmRequest(builder, ToolContext.builder(ctx).build())) .andThen( Single.fromCallable( - () -> RequestProcessingResult.create(builder.build(), ImmutableList.of()))); + () -> RequestProcessingResult.create(builder.build(), ImmutableList.of()))) + .compose(Tracing.withContext(otelParentContext)); }; Iterable allProcessors = @@ -113,10 +116,12 @@ protected Flowable preprocess( return Flowable.fromIterable(allProcessors) .concatMap( processor -> - Single.defer(() -> processor.processRequest(context, llmRequestRef.get())) + processor + .processRequest(context, llmRequestRef.get()) .doOnSuccess(result -> llmRequestRef.set(result.updatedRequest())) .flattenAsFlowable( - result -> result.events() != null ? result.events() : ImmutableList.of())); + result -> result.events() != null ? result.events() : ImmutableList.of()) + .compose(Tracing.withContext(otelParentContext))); } /** @@ -147,12 +152,10 @@ protected Flowable postprocess( Context parentContext = Context.current(); return currentLlmResponse.flatMapPublisher( - updatedResponse -> { - try (Scope scope = parentContext.makeCurrent()) { - return buildPostprocessingEvents( - updatedResponse, eventIterables, context, baseEventForLlmResponse, llmRequest); - } - }); + updatedResponse -> + buildPostprocessingEvents( + updatedResponse, eventIterables, context, baseEventForLlmResponse, llmRequest) + .compose(Tracing.withContext(parentContext))); } /** @@ -164,84 +167,99 @@ protected Flowable postprocess( * callbacks. Callbacks should not rely on its ID if they create their own separate events. */ private Flowable callLlm( - InvocationContext context, LlmRequest llmRequest, Event eventForCallbackUsage) { + InvocationContext context, + LlmRequest llmRequest, + Event eventForCallbackUsage, + Context otelParentContext) { LlmAgent agent = (LlmAgent) context.agent(); LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder(); - return handleBeforeModelCallback(context, llmRequestBuilder, eventForCallbackUsage) - .flatMapPublisher( - beforeResponse -> { - if (beforeResponse.isPresent()) { - return Flowable.just(beforeResponse.get()); - } - BaseLlm llm = - agent.resolvedModel().model().isPresent() - ? agent.resolvedModel().model().get() - : LlmRegistry.getLlm(agent.resolvedModel().modelName().get()); - return llm.generateContent( - llmRequestBuilder.build(), - context.runConfig().streamingMode() == StreamingMode.SSE) - .onErrorResumeNext( - exception -> - handleOnModelErrorCallback( - context, llmRequestBuilder, eventForCallbackUsage, exception) - .switchIfEmpty(Single.error(exception)) - .toFlowable()) - .doOnNext( - llmResp -> - Tracing.traceCallLlm( - context, - eventForCallbackUsage.id(), - llmRequestBuilder.build(), - llmResp)) - .doOnError( - error -> { - Span span = Span.current(); - span.setStatus(StatusCode.ERROR, error.getMessage()); - span.recordException(error); - }) - .compose(Tracing.trace("call_llm")) - .concatMap( - llmResp -> - handleAfterModelCallback(context, llmResp, eventForCallbackUsage) - .toFlowable()); - }); + return handleBeforeModelCallback( + context, llmRequestBuilder, eventForCallbackUsage, otelParentContext) + .flatMapPublisher(Flowable::just) + .switchIfEmpty( + Flowable.defer( + () -> { + BaseLlm llm = + agent.resolvedModel().model().isPresent() + ? agent.resolvedModel().model().get() + : LlmRegistry.getLlm(agent.resolvedModel().modelName().get()); + return llm.generateContent( + llmRequestBuilder.build(), + context.runConfig().streamingMode() == StreamingMode.SSE) + .onErrorResumeNext( + exception -> + handleOnModelErrorCallback( + context, + llmRequestBuilder, + eventForCallbackUsage, + exception, + otelParentContext) + .switchIfEmpty(Single.error(exception)) + .toFlowable()) + .doOnNext( + llmResp -> + Tracing.traceCallLlm( + context, + eventForCallbackUsage.id(), + llmRequestBuilder.build(), + llmResp)) + .doOnError( + error -> { + Span span = Span.current(); + span.setStatus(StatusCode.ERROR, error.getMessage()); + span.recordException(error); + }) + .compose(Tracing.trace("call_llm", otelParentContext)) + .concatMap( + llmResp -> + handleAfterModelCallback( + context, llmResp, eventForCallbackUsage, otelParentContext) + .toFlowable()); + })); } /** * Invokes {@link BeforeModelCallback}s. If any returns a response, it's used instead of calling * the LLM. * - * @return A {@link Single} with the callback result or {@link Optional#empty()}. + * @return A {@link Maybe} with the callback result. */ - private Single> handleBeforeModelCallback( - InvocationContext context, LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent) { + private Maybe handleBeforeModelCallback( + InvocationContext context, + LlmRequest.Builder llmRequestBuilder, + Event modelResponseEvent, + Context otelParentContext) { Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); - Maybe pluginResult = - context.pluginManager().beforeModelCallback(callbackContext, llmRequestBuilder); + Maybe pluginResult; + try (Scope scope = otelParentContext.makeCurrent()) { + pluginResult = + context.pluginManager().beforeModelCallback(callbackContext, llmRequestBuilder); + } LlmAgent agent = (LlmAgent) context.agent(); List callbacks = agent.canonicalBeforeModelCallbacks(); if (callbacks.isEmpty()) { - return pluginResult.map(Optional::of).defaultIfEmpty(Optional.empty()); + return pluginResult; } Maybe callbackResult = Maybe.defer( () -> Flowable.fromIterable(callbacks) - .concatMapMaybe(callback -> callback.call(callbackContext, llmRequestBuilder)) + .concatMapMaybe( + callback -> + callback + .call(callbackContext, llmRequestBuilder) + .compose(Tracing.withContext(otelParentContext))) .firstElement()); - return pluginResult - .switchIfEmpty(callbackResult) - .map(Optional::of) - .defaultIfEmpty(Optional.empty()); + return pluginResult.switchIfEmpty(callbackResult); } /** @@ -254,14 +272,20 @@ private Maybe handleOnModelErrorCallback( InvocationContext context, LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent, - Throwable throwable) { + Throwable throwable, + Context otelParentContext) { Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); Exception ex = throwable instanceof Exception e ? e : new Exception(throwable); - Maybe pluginResult = - context.pluginManager().onModelErrorCallback(callbackContext, llmRequestBuilder, throwable); + Maybe pluginResult; + try (Scope scope = otelParentContext.makeCurrent()) { + pluginResult = + context + .pluginManager() + .onModelErrorCallback(callbackContext, llmRequestBuilder, throwable); + } LlmAgent agent = (LlmAgent) context.agent(); List callbacks = agent.canonicalOnModelErrorCallbacks(); @@ -275,7 +299,11 @@ private Maybe handleOnModelErrorCallback( () -> { LlmRequest llmRequest = llmRequestBuilder.build(); return Flowable.fromIterable(callbacks) - .concatMapMaybe(callback -> callback.call(callbackContext, llmRequest, ex)) + .concatMapMaybe( + callback -> + callback + .call(callbackContext, llmRequest, ex) + .compose(Tracing.withContext(otelParentContext))) .firstElement(); }); @@ -289,13 +317,18 @@ private Maybe handleOnModelErrorCallback( * @return A {@link Single} with the final {@link LlmResponse}. */ private Single handleAfterModelCallback( - InvocationContext context, LlmResponse llmResponse, Event modelResponseEvent) { + InvocationContext context, + LlmResponse llmResponse, + Event modelResponseEvent, + Context otelParentContext) { Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); - Maybe pluginResult = - context.pluginManager().afterModelCallback(callbackContext, llmResponse); + Maybe pluginResult; + try (Scope scope = otelParentContext.makeCurrent()) { + pluginResult = context.pluginManager().afterModelCallback(callbackContext, llmResponse); + } LlmAgent agent = (LlmAgent) context.agent(); List callbacks = agent.canonicalAfterModelCallbacks(); @@ -308,7 +341,11 @@ private Single handleAfterModelCallback( Maybe.defer( () -> Flowable.fromIterable(callbacks) - .concatMapMaybe(callback -> callback.call(callbackContext, llmResponse)) + .concatMapMaybe( + callback -> + callback + .call(callbackContext, llmResponse) + .compose(Tracing.withContext(otelParentContext))) .firstElement()); return pluginResult.switchIfEmpty(callbackResult).defaultIfEmpty(llmResponse); @@ -323,13 +360,12 @@ private Single handleAfterModelCallback( * @throws LlmCallsLimitExceededException if the agent exceeds allowed LLM invocations. * @throws IllegalStateException if a transfer agent is specified but not found. */ - private Flowable runOneStep(InvocationContext context) { + private Flowable runOneStep(InvocationContext context, Context otelParentContext) { AtomicReference llmRequestRef = new AtomicReference<>(LlmRequest.builder().build()); return Flowable.defer( () -> { - Context currentContext = Context.current(); - return preprocess(context, llmRequestRef) + return preprocess(context, llmRequestRef, otelParentContext) .concatWith( Flowable.defer( () -> { @@ -355,11 +391,14 @@ private Flowable runOneStep(InvocationContext context) { .build(); mutableEventTemplate.setTimestamp(0L); - return callLlm(context, llmRequestAfterPreprocess, mutableEventTemplate) + return callLlm( + context, + llmRequestAfterPreprocess, + mutableEventTemplate, + otelParentContext) .concatMap( - llmResponse -> { - try (Scope postScope = currentContext.makeCurrent()) { - return postprocess( + llmResponse -> + postprocess( context, mutableEventTemplate, llmRequestAfterPreprocess, @@ -371,9 +410,8 @@ private Flowable runOneStep(InvocationContext context) { logger.debug( "Resetting event ID from {} to {}", oldId, newId); mutableEventTemplate.setId(newId); - }); - } - }) + }) + .compose(Tracing.withContext(otelParentContext))) .concatMap( event -> { Flowable postProcessedEvents = Flowable.just(event); @@ -407,11 +445,12 @@ private Flowable runOneStep(InvocationContext context) { */ @Override public Flowable run(InvocationContext invocationContext) { - return run(invocationContext, 0); + return run(invocationContext, Context.current(), 0); } - private Flowable run(InvocationContext invocationContext, int stepsCompleted) { - Flowable currentStepEvents = runOneStep(invocationContext).cache(); + private Flowable run( + InvocationContext invocationContext, Context otelParentContext, int stepsCompleted) { + Flowable currentStepEvents = runOneStep(invocationContext, otelParentContext).cache(); if (stepsCompleted + 1 >= maxSteps) { logger.debug("Ending flow execution because max steps reached."); return currentStepEvents; @@ -431,7 +470,7 @@ private Flowable run(InvocationContext invocationContext, int stepsComple return Flowable.empty(); } else { logger.debug("Continuing to next step of the flow."); - return run(invocationContext, stepsCompleted + 1); + return run(invocationContext, otelParentContext, stepsCompleted + 1); } })); } @@ -446,8 +485,10 @@ private Flowable run(InvocationContext invocationContext, int stepsComple */ @Override public Flowable runLive(InvocationContext invocationContext) { + Context otelParentContext = Context.current(); AtomicReference llmRequestRef = new AtomicReference<>(LlmRequest.builder().build()); - Flowable preprocessEvents = preprocess(invocationContext, llmRequestRef); + Flowable preprocessEvents = + preprocess(invocationContext, llmRequestRef, otelParentContext); return preprocessEvents.concatWith( Flowable.defer( @@ -485,7 +526,7 @@ public Flowable runLive(InvocationContext invocationContext) { eventIdForSendData, llmRequestAfterPreprocess.contents()); }) - .compose(Tracing.trace("send_data")); + .compose(Tracing.trace("send_data", otelParentContext)); Flowable liveRequests = invocationContext @@ -542,13 +583,15 @@ public void onError(Throwable e) { .receive() .flatMap( llmResponse -> { - Event baseEventForThisLlmResponse = - liveEventBuilderTemplate.id(Event.generateEventId()).build(); - return postprocess( - invocationContext, - baseEventForThisLlmResponse, - llmRequestAfterPreprocess, - llmResponse); + try (Scope scope = otelParentContext.makeCurrent()) { + Event baseEventForThisLlmResponse = + liveEventBuilderTemplate.id(Event.generateEventId()).build(); + return postprocess( + invocationContext, + baseEventForThisLlmResponse, + llmRequestAfterPreprocess, + llmResponse); + } }) .flatMap( event -> { diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 269764046..056987fc2 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -257,7 +257,8 @@ private static Function> getFunctionCallMapper( functionCall.args().map(HashMap::new).orElse(new HashMap<>()); Maybe> maybeFunctionResult = - maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) + maybeInvokeBeforeToolCall( + invocationContext, tool, functionArgs, toolContext, parentContext) .switchIfEmpty( Maybe.defer( () -> { @@ -395,48 +396,49 @@ private static Maybe postProcessFunctionResult( .defaultIfEmpty(Optional.empty()) .onErrorResumeNext( t -> { - Maybe> errorCallbackResult = - handleOnToolErrorCallback(invocationContext, tool, functionArgs, toolContext, t); - Maybe>> mappedResult; - if (isLive) { - // In live mode, handle null results from the error callback gracefully. - mappedResult = errorCallbackResult.map(Optional::ofNullable); - } else { - // In non-live mode, a null result from the error callback will cause an NPE - // when wrapped with Optional.of(), potentially matching prior behavior. - mappedResult = errorCallbackResult.map(Optional::of); + try (Scope scope = parentContext.makeCurrent()) { + Maybe> errorCallbackResult = + handleOnToolErrorCallback( + invocationContext, tool, functionArgs, toolContext, t, parentContext); + Maybe>> mappedResult; + if (isLive) { + // In live mode, handle null results from the error callback gracefully. + mappedResult = errorCallbackResult.map(Optional::ofNullable); + } else { + // In non-live mode, a null result from the error callback will cause an NPE + // when wrapped with Optional.of(), potentially matching prior behavior. + mappedResult = errorCallbackResult.map(Optional::of); + } + return mappedResult.switchIfEmpty(Single.error(t)); } - return mappedResult.switchIfEmpty(Single.error(t)); }) .flatMapMaybe( optionalInitialResult -> { - try (Scope scope = parentContext.makeCurrent()) { - Map initialFunctionResult = optionalInitialResult.orElse(null); - - return maybeInvokeAfterToolCall( - invocationContext, tool, functionArgs, toolContext, initialFunctionResult) - .map(Optional::of) - .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) - .flatMapMaybe( - finalOptionalResult -> { - Map finalFunctionResult = - finalOptionalResult.orElse(null); - if (tool.longRunning() && finalFunctionResult == null) { - return Maybe.empty(); - } - return Maybe.fromCallable( - () -> - buildResponseEvent( - tool, - finalFunctionResult, - toolContext, - invocationContext)) - .compose( - Tracing.trace( - "tool_response [" + tool.name() + "]", parentContext)) - .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)); - }); - } + Map initialFunctionResult = optionalInitialResult.orElse(null); + + return maybeInvokeAfterToolCall( + invocationContext, + tool, + functionArgs, + toolContext, + initialFunctionResult, + parentContext) + .map(Optional::of) + .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) + .flatMapMaybe( + finalOptionalResult -> { + Map finalFunctionResult = finalOptionalResult.orElse(null); + if (tool.longRunning() && finalFunctionResult == null) { + return Maybe.empty(); + } + return Maybe.fromCallable( + () -> + buildResponseEvent( + tool, finalFunctionResult, toolContext, invocationContext)) + .compose( + Tracing.trace("tool_response [" + tool.name() + "]", parentContext)) + .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)); + }); }); } @@ -479,28 +481,32 @@ private static Maybe> maybeInvokeBeforeToolCall( InvocationContext invocationContext, BaseTool tool, Map functionArgs, - ToolContext toolContext) { - if (invocationContext.agent() instanceof LlmAgent) { - LlmAgent agent = (LlmAgent) invocationContext.agent(); + ToolContext toolContext, + Context parentContext) { + if (invocationContext.agent() instanceof LlmAgent agent) { + try (Scope scope = parentContext.makeCurrent()) { - Maybe> pluginResult = - invocationContext.pluginManager().beforeToolCallback(tool, functionArgs, toolContext); + Maybe> pluginResult = + invocationContext.pluginManager().beforeToolCallback(tool, functionArgs, toolContext); - List callbacks = agent.canonicalBeforeToolCallbacks(); - if (callbacks.isEmpty()) { - return pluginResult; - } - - Maybe> callbackResult = - Maybe.defer( - () -> - Flowable.fromIterable(callbacks) - .concatMapMaybe( - callback -> - callback.call(invocationContext, tool, functionArgs, toolContext)) - .firstElement()); + List callbacks = agent.canonicalBeforeToolCallbacks(); + if (callbacks.isEmpty()) { + return pluginResult; + } - return pluginResult.switchIfEmpty(callbackResult); + Maybe> callbackResult = + Maybe.defer( + () -> + Flowable.fromIterable(callbacks) + .concatMapMaybe( + callback -> + callback + .call(invocationContext, tool, functionArgs, toolContext) + .compose(Tracing.withContext(parentContext))) + .firstElement()); + + return pluginResult.switchIfEmpty(callbackResult); + } } return Maybe.empty(); } @@ -516,34 +522,39 @@ private static Maybe> handleOnToolErrorCallback( BaseTool tool, Map functionArgs, ToolContext toolContext, - Throwable throwable) { + Throwable throwable, + Context parentContext) { Exception ex = throwable instanceof Exception exception ? exception : new Exception(throwable); - Maybe> pluginResult = - invocationContext - .pluginManager() - .onToolErrorCallback(tool, functionArgs, toolContext, throwable); - - if (invocationContext.agent() instanceof LlmAgent) { - LlmAgent agent = (LlmAgent) invocationContext.agent(); + try (Scope scope = parentContext.makeCurrent()) { + Maybe> pluginResult = + invocationContext + .pluginManager() + .onToolErrorCallback(tool, functionArgs, toolContext, throwable); - List callbacks = agent.canonicalOnToolErrorCallbacks(); - if (callbacks.isEmpty()) { - return pluginResult; - } + if (invocationContext.agent() instanceof LlmAgent) { + LlmAgent agent = (LlmAgent) invocationContext.agent(); - Maybe> callbackResult = - Maybe.defer( - () -> - Flowable.fromIterable(callbacks) - .concatMapMaybe( - callback -> - callback.call(invocationContext, tool, functionArgs, toolContext, ex)) - .firstElement()); + List callbacks = agent.canonicalOnToolErrorCallbacks(); + if (callbacks.isEmpty()) { + return pluginResult; + } - return pluginResult.switchIfEmpty(callbackResult); + Maybe> callbackResult = + Maybe.defer( + () -> + Flowable.fromIterable(callbacks) + .concatMapMaybe( + callback -> + callback + .call(invocationContext, tool, functionArgs, toolContext, ex) + .compose(Tracing.withContext(parentContext))) + .firstElement()); + + return pluginResult.switchIfEmpty(callbackResult); + } + return pluginResult; } - return pluginResult; } private static Maybe> maybeInvokeAfterToolCall( @@ -551,35 +562,39 @@ private static Maybe> maybeInvokeAfterToolCall( BaseTool tool, Map functionArgs, ToolContext toolContext, - Map functionResult) { - if (invocationContext.agent() instanceof LlmAgent) { - LlmAgent agent = (LlmAgent) invocationContext.agent(); + Map functionResult, + Context parentContext) { + if (invocationContext.agent() instanceof LlmAgent agent) { - Maybe> pluginResult = - invocationContext - .pluginManager() - .afterToolCallback(tool, functionArgs, toolContext, functionResult); + try (Scope scope = parentContext.makeCurrent()) { + Maybe> pluginResult = + invocationContext + .pluginManager() + .afterToolCallback(tool, functionArgs, toolContext, functionResult); - List callbacks = agent.canonicalAfterToolCallbacks(); - if (callbacks.isEmpty()) { - return pluginResult; - } + List callbacks = agent.canonicalAfterToolCallbacks(); + if (callbacks.isEmpty()) { + return pluginResult; + } - Maybe> callbackResult = - Maybe.defer( - () -> - Flowable.fromIterable(callbacks) - .concatMapMaybe( - callback -> - callback.call( - invocationContext, - tool, - functionArgs, - toolContext, - functionResult)) - .firstElement()); - - return pluginResult.switchIfEmpty(callbackResult); + Maybe> callbackResult = + Maybe.defer( + () -> + Flowable.fromIterable(callbacks) + .concatMapMaybe( + callback -> + callback + .call( + invocationContext, + tool, + functionArgs, + toolContext, + functionResult) + .compose(Tracing.withContext(parentContext))) + .firstElement()); + + return pluginResult.switchIfEmpty(callbackResult); + } } return Maybe.empty(); } diff --git a/core/src/main/java/com/google/adk/plugins/PluginManager.java b/core/src/main/java/com/google/adk/plugins/PluginManager.java index 56dea936a..74300c9fa 100644 --- a/core/src/main/java/com/google/adk/plugins/PluginManager.java +++ b/core/src/main/java/com/google/adk/plugins/PluginManager.java @@ -21,11 +21,14 @@ import com.google.adk.events.Event; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; +import com.google.adk.telemetry.Tracing; import com.google.adk.tools.BaseTool; import com.google.adk.tools.ToolContext; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; @@ -129,6 +132,7 @@ public Completable runAfterRunCallback(InvocationContext invocationContext) { @Override public Completable afterRunCallback(InvocationContext invocationContext) { + Context capturedContext = Context.current(); return Flowable.fromIterable(plugins) .concatMapCompletable( plugin -> @@ -139,11 +143,13 @@ public Completable afterRunCallback(InvocationContext invocationContext) { logger.error( "[{}] Error during callback 'afterRunCallback'", plugin.getName(), - e))); + e)) + .compose(Tracing.withContext(capturedContext))); } @Override public Completable close() { + Context capturedContext = Context.current(); return Flowable.fromIterable(plugins) .concatMapCompletableDelayError( plugin -> @@ -151,8 +157,8 @@ public Completable close() { .close() .doOnError( e -> - logger.error( - "[{}] Error during callback 'close'", plugin.getName(), e))); + logger.error("[{}] Error during callback 'close'", plugin.getName(), e)) + .compose(Tracing.withContext(capturedContext))); } public Maybe runOnEventCallback(InvocationContext invocationContext, Event event) { @@ -275,7 +281,7 @@ public Maybe> onToolErrorCallback( */ private Maybe runMaybeCallbacks( Function> callbackExecutor, String callbackName) { - + Context capturedContext = Context.current(); return Flowable.fromIterable(this.plugins) .concatMapMaybe( plugin -> @@ -294,7 +300,8 @@ private Maybe runMaybeCallbacks( "[{}] Error during callback '{}'", plugin.getName(), callbackName, - e))) + e)) + .compose(Tracing.withContext(capturedContext))) .firstElement(); } } diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index f2cb5b9d5..95ebf5c07 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -375,20 +375,25 @@ public Flowable runAsync( Content newMessage, RunConfig runConfig, @Nullable Map stateDelta) { - Maybe maybeSession = - this.sessionService.getSession(appName, userId, sessionId, Optional.empty()); - return maybeSession - .switchIfEmpty( - Single.defer( - () -> { - if (runConfig.autoCreateSession()) { - return this.sessionService.createSession(appName, userId, null, sessionId); - } - return Single.error( - new IllegalArgumentException( - String.format("Session not found: %s for user %s", sessionId, userId))); - })) - .flatMapPublisher(session -> this.runAsyncImpl(session, newMessage, runConfig, stateDelta)); + return Flowable.defer( + () -> + this.sessionService + .getSession(appName, userId, sessionId, Optional.empty()) + .switchIfEmpty( + Single.defer( + () -> { + if (runConfig.autoCreateSession()) { + return this.sessionService.createSession( + appName, userId, (Map) null, sessionId); + } + return Single.error( + new IllegalArgumentException( + String.format( + "Session not found: %s for user %s", sessionId, userId))); + })) + .flatMapPublisher( + session -> this.runAsyncImpl(session, newMessage, runConfig, stateDelta))) + .compose(Tracing.trace("invocation")); } /** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */ @@ -441,7 +446,8 @@ public Flowable runAsync( Content newMessage, RunConfig runConfig, @Nullable Map stateDelta) { - return runAsyncImpl(session, newMessage, runConfig, stateDelta); + return runAsyncImpl(session, newMessage, runConfig, stateDelta) + .compose(Tracing.trace("invocation")); } /** @@ -510,8 +516,7 @@ protected Flowable runAsyncImpl( Span span = Span.current(); span.setStatus(StatusCode.ERROR, "Error in runAsync Flowable execution"); span.recordException(throwable); - }) - .compose(Tracing.trace("invocation")); + }); } private Flowable runAgentWithFreshSession( @@ -568,7 +573,7 @@ private Flowable runAgentWithFreshSession( .toFlowable() .switchIfEmpty(agentEvents) .concatWith( - Completable.defer(() -> pluginManager.runAfterRunCallback(contextWithUpdatedSession))) + Completable.defer(() -> pluginManager.afterRunCallback(contextWithUpdatedSession))) .concatWith(Completable.defer(() -> compactEvents(updatedSession))); } @@ -641,39 +646,48 @@ private InvocationContext.Builder newInvocationContextBuilder(Session session) { */ public Flowable runLive( Session session, LiveRequestQueue liveRequestQueue, RunConfig runConfig) { + return runLiveImpl(session, liveRequestQueue, runConfig).compose(Tracing.trace("invocation")); + } + + /** + * Runs the agent in live mode, appending generated events to the session. + * + * @return stream of events from the agent. + */ + protected Flowable runLiveImpl( + Session session, @Nullable LiveRequestQueue liveRequestQueue, RunConfig runConfig) { return Flowable.defer( - () -> { - InvocationContext invocationContext = - newInvocationContextForLive(session, liveRequestQueue, runConfig); - - Single invocationContextSingle; - if (invocationContext.agent() instanceof LlmAgent agent) { - invocationContextSingle = - agent - .tools() - .map( - tools -> { - this.addActiveStreamingTools(invocationContext, tools); - return invocationContext; - }); - } else { - invocationContextSingle = Single.just(invocationContext); - } - return invocationContextSingle - .flatMapPublisher( - updatedInvocationContext -> - updatedInvocationContext - .agent() - .runLive(updatedInvocationContext) - .doOnNext(event -> this.sessionService.appendEvent(session, event))) - .doOnError( - throwable -> { - Span span = Span.current(); - span.setStatus(StatusCode.ERROR, "Error in runLive Flowable execution"); - span.recordException(throwable); - }); - }) - .compose(Tracing.trace("invocation")); + () -> { + InvocationContext invocationContext = + newInvocationContextForLive(session, liveRequestQueue, runConfig); + + Single invocationContextSingle; + if (invocationContext.agent() instanceof LlmAgent agent) { + invocationContextSingle = + agent + .tools() + .map( + tools -> { + this.addActiveStreamingTools(invocationContext, tools); + return invocationContext; + }); + } else { + invocationContextSingle = Single.just(invocationContext); + } + return invocationContextSingle + .flatMapPublisher( + updatedInvocationContext -> + updatedInvocationContext + .agent() + .runLive(updatedInvocationContext) + .doOnNext(event -> this.sessionService.appendEvent(session, event))) + .doOnError( + throwable -> { + Span span = Span.current(); + span.setStatus(StatusCode.ERROR, "Error in runLive Flowable execution"); + span.recordException(throwable); + }); + }); } /** @@ -684,19 +698,25 @@ public Flowable runLive( */ public Flowable runLive( String userId, String sessionId, LiveRequestQueue liveRequestQueue, RunConfig runConfig) { - return this.sessionService - .getSession(appName, userId, sessionId, Optional.empty()) - .switchIfEmpty( - Single.defer( - () -> { - if (runConfig.autoCreateSession()) { - return this.sessionService.createSession(appName, userId, null, sessionId); - } - return Single.error( - new IllegalArgumentException( - String.format("Session not found: %s for user %s", sessionId, userId))); - })) - .flatMapPublisher(session -> this.runLive(session, liveRequestQueue, runConfig)); + return Flowable.defer( + () -> + this.sessionService + .getSession(appName, userId, sessionId, Optional.empty()) + .switchIfEmpty( + Single.defer( + () -> { + if (runConfig.autoCreateSession()) { + return this.sessionService.createSession( + appName, userId, (Map) null, sessionId); + } + return Single.error( + new IllegalArgumentException( + String.format( + "Session not found: %s for user %s", sessionId, userId))); + })) + .flatMapPublisher( + session -> this.runLiveImpl(session, liveRequestQueue, runConfig))) + .compose(Tracing.trace("invocation")); } /** diff --git a/core/src/main/java/com/google/adk/telemetry/Tracing.java b/core/src/main/java/com/google/adk/telemetry/Tracing.java index 07a640c37..a7808c8cd 100644 --- a/core/src/main/java/com/google/adk/telemetry/Tracing.java +++ b/core/src/main/java/com/google/adk/telemetry/Tracing.java @@ -35,6 +35,7 @@ import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; +import io.opentelemetry.context.ContextKey; import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.CompletableSource; @@ -47,6 +48,7 @@ import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.core.SingleSource; import io.reactivex.rxjava3.core.SingleTransformer; +import java.lang.reflect.Method; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -56,6 +58,7 @@ import java.util.Optional; import java.util.function.Consumer; import java.util.function.Supplier; +import javax.annotation.Nullable; import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -101,6 +104,9 @@ public class Tracing { private static final AttributeKey GEN_AI_USAGE_OUTPUT_TOKENS = AttributeKey.longKey("gen_ai.usage.output_tokens"); + private static final ContextKey AGENT_NAME_CONTEXT_KEY = + ContextKey.named("gen_ai.agent.name"); + private static final AttributeKey ADK_TOOL_CALL_ARGS = AttributeKey.stringKey("gcp.vertex.agent.tool_call_args"); private static final AttributeKey ADK_LLM_REQUEST = @@ -372,6 +378,33 @@ public static void traceSendData( }); } + private static String getAgentName(Context context) { + String agentName = context.get(AGENT_NAME_CONTEXT_KEY); + if (agentName != null) { + return agentName; + } + + return getAgentNameFromSpan(Span.fromContext(context)); + } + + @Nullable + private static String getAgentNameFromSpan(Span span) { + if (span == null || !span.getSpanContext().isValid()) { + return null; + } + try { + // Use reflection to try to get the attribute from a ReadableSpan (SDK implementation) + Method getAttributeMethod = span.getClass().getMethod("getAttribute", AttributeKey.class); + Object value = getAttributeMethod.invoke(span, GEN_AI_AGENT_NAME); + if (value instanceof String string) { + return string; + } + } catch (Exception ignored) { + // Not a ReadableSpan or other reflection issue + } + return null; + } + /** * Gets the tracer. * @@ -450,15 +483,87 @@ public static TracerProvider trace(String spanName, Context parentContext * @return A TracerProvider configured for agent invocation. */ public static TracerProvider traceAgent( + Context parent, String spanName, String agentName, String agentDescription, InvocationContext invocationContext) { return new TracerProvider(spanName) + .setParent(parent) + .setAgentName(agentName) .configure( span -> traceAgentInvocation(span, agentName, agentDescription, invocationContext)); } + /** + * Returns a transformer that re-activates a given context for the duration of the stream's + * subscription. + * + * @param context The context to re-activate. + * @param The type of the stream. + * @return A transformer that re-activates the context. + */ + public static ContextTransformer withContext(Context context) { + return new ContextTransformer<>(context); + } + + /** + * A transformer that re-activates a given context for the duration of the stream's subscription. + * + * @param The type of the stream. + */ + public static final class ContextTransformer + implements FlowableTransformer, + SingleTransformer, + MaybeTransformer, + CompletableTransformer { + private final Context context; + + private ContextTransformer(Context context) { + this.context = context; + } + + @Override + @SuppressWarnings("MustBeClosedChecker") + public Publisher apply(Flowable upstream) { + return Flowable.defer( + () -> { + Scope scope = context.makeCurrent(); + return upstream.doFinally(scope::close); + }); + } + + @Override + @SuppressWarnings("MustBeClosedChecker") + public SingleSource apply(Single upstream) { + return Single.defer( + () -> { + Scope scope = context.makeCurrent(); + return upstream.doFinally(scope::close); + }); + } + + @Override + @SuppressWarnings("MustBeClosedChecker") + public MaybeSource apply(Maybe upstream) { + return Maybe.defer( + () -> { + Scope scope = context.makeCurrent(); + return upstream.doFinally(scope::close); + }); + } + + @Override + @SuppressWarnings("MustBeClosedChecker") + public CompletableSource apply(Completable upstream) { + return Completable.defer( + () -> { + Scope scope = context.makeCurrent(); + return upstream.doFinally(scope::close); + }); + } + } + /** * A transformer that manages an OpenTelemetry span and scope for RxJava streams. * @@ -471,6 +576,7 @@ public static final class TracerProvider CompletableTransformer { private final String spanName; private Context explicitParentContext; + private String agentName; private final List> spanConfigurers = new ArrayList<>(); private TracerProvider(String spanName) { @@ -491,6 +597,12 @@ public TracerProvider setParent(Context parentContext) { return this; } + @CanIgnoreReturnValue + private TracerProvider setAgentName(String agentName) { + this.agentName = agentName; + return this; + } + private Context getParentContext() { return explicitParentContext != null ? explicitParentContext : Context.current(); } @@ -501,9 +613,33 @@ private final class TracingLifecycle { @SuppressWarnings("MustBeClosedChecker") void start() { - span = tracer.spanBuilder(spanName).setParent(getParentContext()).startSpan(); + Context parentContext = getParentContext(); + + // Propagate agent name from parent if possible + String agentNameToPropagate = TracerProvider.this.agentName; + if (agentNameToPropagate == null) { + agentNameToPropagate = getAgentName(parentContext); + } + + span = tracer.spanBuilder(spanName).setParent(parentContext).startSpan(); + + if (agentNameToPropagate != null) { + span.setAttribute(GEN_AI_AGENT_NAME, agentNameToPropagate); + } + spanConfigurers.forEach(c -> c.accept(span)); - scope = span.makeCurrent(); + + // Try to capture the final agent name (might have been set by a configurer) + String finalAgentName = getAgentNameFromSpan(span); + if (finalAgentName == null) { + finalAgentName = agentNameToPropagate; + } + + Context nextContext = parentContext.with(span); + if (finalAgentName != null) { + nextContext = nextContext.with(AGENT_NAME_CONTEXT_KEY, finalAgentName); + } + scope = nextContext.makeCurrent(); } void end() { @@ -521,7 +657,8 @@ public Publisher apply(Flowable upstream) { return Flowable.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + lifecycle.start(); + return upstream.doFinally(lifecycle::end); }); } @@ -530,7 +667,8 @@ public SingleSource apply(Single upstream) { return Single.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + lifecycle.start(); + return upstream.doFinally(lifecycle::end); }); } @@ -539,7 +677,8 @@ public MaybeSource apply(Maybe upstream) { return Maybe.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + lifecycle.start(); + return upstream.doFinally(lifecycle::end); }); } @@ -548,7 +687,8 @@ public CompletableSource apply(Completable upstream) { return Completable.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + lifecycle.start(); + return upstream.doFinally(lifecycle::end); }); } }