diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index 226e61abe..2f46b0bd4 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -29,6 +29,7 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.DoNotCall; import com.google.genai.types.Content; +import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; @@ -312,38 +313,57 @@ public Flowable runAsync(InvocationContext parentContext) { private Flowable run( InvocationContext parentContext, Function> runImplementation) { + Context otelParentContext = Context.current(); + InvocationContext invocationContext = createInvocationContext(parentContext); + return Flowable.defer( - () -> { - InvocationContext invocationContext = createInvocationContext(parentContext); - - return callCallback( - beforeCallbacksToFunctions( - invocationContext.pluginManager(), beforeAgentCallback), - invocationContext) - .flatMapPublisher( - beforeEventOpt -> { - if (invocationContext.endInvocation()) { - return Flowable.fromOptional(beforeEventOpt); - } - - Flowable beforeEvents = Flowable.fromOptional(beforeEventOpt); - Flowable mainEvents = - Flowable.defer(() -> runImplementation.apply(invocationContext)); - Flowable afterEvents = - Flowable.defer( - () -> - callCallback( - afterCallbacksToFunctions( - invocationContext.pluginManager(), afterAgentCallback), - invocationContext) - .flatMapPublisher(Flowable::fromOptional)); - - return Flowable.concat(beforeEvents, mainEvents, afterEvents); - }) - .compose( - Tracing.traceAgent( - "invoke_agent " + name(), name(), description(), invocationContext)); - }); + () -> { + return callCallback( + beforeCallbacksToFunctions( + invocationContext.pluginManager(), beforeAgentCallback), + invocationContext, + otelParentContext) + .flatMapPublisher( + beforeEvent -> { + if (invocationContext.endInvocation()) { + return Flowable.just(beforeEvent); + } + + return Flowable.just(beforeEvent) + .concatWith( + runMainAndAfter( + invocationContext, otelParentContext, runImplementation)); + }) + .switchIfEmpty( + Flowable.defer( + () -> + runMainAndAfter( + invocationContext, otelParentContext, runImplementation))); + }) + .compose( + Tracing.traceAgent( + otelParentContext, + "invoke_agent " + name(), + name(), + description(), + invocationContext)); + } + + private Flowable runMainAndAfter( + InvocationContext invocationContext, + Context otelParentContext, + Function> runImplementation) { + Flowable mainEvents = + runImplementation.apply(invocationContext).compose(Tracing.withContext(otelParentContext)); + Flowable afterEvents = + callCallback( + afterCallbacksToFunctions(invocationContext.pluginManager(), afterAgentCallback), + invocationContext, + otelParentContext) + .flatMapPublisher(Flowable::just) + .compose(Tracing.withContext(otelParentContext)); + + return Flowable.concat(mainEvents, afterEvents); } /** @@ -383,13 +403,14 @@ private ImmutableList>> callbacksTo * * @param agentCallbacks Callback functions. * @param invocationContext Current invocation context. - * @return single emitting first event, or empty if none. + * @return Maybe emitting first event, or empty if none. */ - private Single> callCallback( + private Maybe callCallback( List>> agentCallbacks, - InvocationContext invocationContext) { + InvocationContext invocationContext, + Context otelParentContext) { if (agentCallbacks.isEmpty()) { - return Single.just(Optional.empty()); + return Maybe.empty(); } CallbackContext callbackContext = @@ -397,28 +418,26 @@ private Single> callCallback( return Flowable.fromIterable(agentCallbacks) .concatMap( - callback -> { - Maybe maybeContent = callback.apply(callbackContext); - - return maybeContent - .map( - content -> { - invocationContext.setEndInvocation(true); - return Optional.of( - Event.builder() - .id(Event.generateEventId()) - .invocationId(invocationContext.invocationId()) - .author(name()) - .branch(invocationContext.branch()) - .actions(callbackContext.eventActions()) - .content(content) - .build()); - }) - .toFlowable(); - }) + callback -> + callback + .apply(callbackContext) + .map( + content -> { + invocationContext.setEndInvocation(true); + return Event.builder() + .id(Event.generateEventId()) + .invocationId(invocationContext.invocationId()) + .author(name()) + .branch(invocationContext.branch()) + .actions(callbackContext.eventActions()) + .content(content) + .build(); + }) + .toFlowable() + .compose(Tracing.withContext(otelParentContext))) .firstElement() .switchIfEmpty( - Single.defer( + Maybe.defer( () -> { if (callbackContext.state().hasDelta()) { Event.Builder eventBuilder = @@ -429,9 +448,9 @@ private Single> callCallback( .branch(invocationContext.branch()) .actions(callbackContext.eventActions()); - return Single.just(Optional.of(eventBuilder.build())); + return Maybe.just(eventBuilder.build()); } else { - return Single.just(Optional.empty()); + return Maybe.empty(); } })); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 1249728d8..746016c1a 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -92,7 +92,9 @@ public BaseLlmFlow( * events generated by them. */ protected Flowable preprocess( - InvocationContext context, AtomicReference llmRequestRef) { + InvocationContext context, + AtomicReference llmRequestRef, + Context otelParentContext) { LlmAgent agent = (LlmAgent) context.agent(); RequestProcessor toolsProcessor = @@ -104,7 +106,8 @@ protected Flowable preprocess( tool -> tool.processLlmRequest(builder, ToolContext.builder(ctx).build())) .andThen( Single.fromCallable( - () -> RequestProcessingResult.create(builder.build(), ImmutableList.of()))); + () -> RequestProcessingResult.create(builder.build(), ImmutableList.of()))) + .compose(Tracing.withContext(otelParentContext)); }; Iterable allProcessors = @@ -113,10 +116,12 @@ protected Flowable preprocess( return Flowable.fromIterable(allProcessors) .concatMap( processor -> - Single.defer(() -> processor.processRequest(context, llmRequestRef.get())) + processor + .processRequest(context, llmRequestRef.get()) .doOnSuccess(result -> llmRequestRef.set(result.updatedRequest())) .flattenAsFlowable( - result -> result.events() != null ? result.events() : ImmutableList.of())); + result -> result.events() != null ? result.events() : ImmutableList.of()) + .compose(Tracing.withContext(otelParentContext))); } /** @@ -147,12 +152,10 @@ protected Flowable postprocess( Context parentContext = Context.current(); return currentLlmResponse.flatMapPublisher( - updatedResponse -> { - try (Scope scope = parentContext.makeCurrent()) { - return buildPostprocessingEvents( - updatedResponse, eventIterables, context, baseEventForLlmResponse, llmRequest); - } - }); + updatedResponse -> + buildPostprocessingEvents( + updatedResponse, eventIterables, context, baseEventForLlmResponse, llmRequest) + .compose(Tracing.withContext(parentContext))); } /** @@ -164,84 +167,99 @@ protected Flowable postprocess( * callbacks. Callbacks should not rely on its ID if they create their own separate events. */ private Flowable callLlm( - InvocationContext context, LlmRequest llmRequest, Event eventForCallbackUsage) { + InvocationContext context, + LlmRequest llmRequest, + Event eventForCallbackUsage, + Context otelParentContext) { LlmAgent agent = (LlmAgent) context.agent(); LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder(); - return handleBeforeModelCallback(context, llmRequestBuilder, eventForCallbackUsage) - .flatMapPublisher( - beforeResponse -> { - if (beforeResponse.isPresent()) { - return Flowable.just(beforeResponse.get()); - } - BaseLlm llm = - agent.resolvedModel().model().isPresent() - ? agent.resolvedModel().model().get() - : LlmRegistry.getLlm(agent.resolvedModel().modelName().get()); - return llm.generateContent( - llmRequestBuilder.build(), - context.runConfig().streamingMode() == StreamingMode.SSE) - .onErrorResumeNext( - exception -> - handleOnModelErrorCallback( - context, llmRequestBuilder, eventForCallbackUsage, exception) - .switchIfEmpty(Single.error(exception)) - .toFlowable()) - .doOnNext( - llmResp -> - Tracing.traceCallLlm( - context, - eventForCallbackUsage.id(), - llmRequestBuilder.build(), - llmResp)) - .doOnError( - error -> { - Span span = Span.current(); - span.setStatus(StatusCode.ERROR, error.getMessage()); - span.recordException(error); - }) - .compose(Tracing.trace("call_llm")) - .concatMap( - llmResp -> - handleAfterModelCallback(context, llmResp, eventForCallbackUsage) - .toFlowable()); - }); + return handleBeforeModelCallback( + context, llmRequestBuilder, eventForCallbackUsage, otelParentContext) + .flatMapPublisher(Flowable::just) + .switchIfEmpty( + Flowable.defer( + () -> { + BaseLlm llm = + agent.resolvedModel().model().isPresent() + ? agent.resolvedModel().model().get() + : LlmRegistry.getLlm(agent.resolvedModel().modelName().get()); + return llm.generateContent( + llmRequestBuilder.build(), + context.runConfig().streamingMode() == StreamingMode.SSE) + .onErrorResumeNext( + exception -> + handleOnModelErrorCallback( + context, + llmRequestBuilder, + eventForCallbackUsage, + exception, + otelParentContext) + .switchIfEmpty(Single.error(exception)) + .toFlowable()) + .doOnNext( + llmResp -> + Tracing.traceCallLlm( + context, + eventForCallbackUsage.id(), + llmRequestBuilder.build(), + llmResp)) + .doOnError( + error -> { + Span span = Span.current(); + span.setStatus(StatusCode.ERROR, error.getMessage()); + span.recordException(error); + }) + .compose(Tracing.trace("call_llm", otelParentContext)) + .concatMap( + llmResp -> + handleAfterModelCallback( + context, llmResp, eventForCallbackUsage, otelParentContext) + .toFlowable()); + })); } /** * Invokes {@link BeforeModelCallback}s. If any returns a response, it's used instead of calling * the LLM. * - * @return A {@link Single} with the callback result or {@link Optional#empty()}. + * @return A {@link Maybe} with the callback result. */ - private Single> handleBeforeModelCallback( - InvocationContext context, LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent) { + private Maybe handleBeforeModelCallback( + InvocationContext context, + LlmRequest.Builder llmRequestBuilder, + Event modelResponseEvent, + Context otelParentContext) { Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); - Maybe pluginResult = - context.pluginManager().beforeModelCallback(callbackContext, llmRequestBuilder); + Maybe pluginResult; + try (Scope scope = otelParentContext.makeCurrent()) { + pluginResult = + context.pluginManager().beforeModelCallback(callbackContext, llmRequestBuilder); + } LlmAgent agent = (LlmAgent) context.agent(); List callbacks = agent.canonicalBeforeModelCallbacks(); if (callbacks.isEmpty()) { - return pluginResult.map(Optional::of).defaultIfEmpty(Optional.empty()); + return pluginResult; } Maybe callbackResult = Maybe.defer( () -> Flowable.fromIterable(callbacks) - .concatMapMaybe(callback -> callback.call(callbackContext, llmRequestBuilder)) + .concatMapMaybe( + callback -> + callback + .call(callbackContext, llmRequestBuilder) + .compose(Tracing.withContext(otelParentContext))) .firstElement()); - return pluginResult - .switchIfEmpty(callbackResult) - .map(Optional::of) - .defaultIfEmpty(Optional.empty()); + return pluginResult.switchIfEmpty(callbackResult); } /** @@ -254,14 +272,20 @@ private Maybe handleOnModelErrorCallback( InvocationContext context, LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent, - Throwable throwable) { + Throwable throwable, + Context otelParentContext) { Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); Exception ex = throwable instanceof Exception e ? e : new Exception(throwable); - Maybe pluginResult = - context.pluginManager().onModelErrorCallback(callbackContext, llmRequestBuilder, throwable); + Maybe pluginResult; + try (Scope scope = otelParentContext.makeCurrent()) { + pluginResult = + context + .pluginManager() + .onModelErrorCallback(callbackContext, llmRequestBuilder, throwable); + } LlmAgent agent = (LlmAgent) context.agent(); List callbacks = agent.canonicalOnModelErrorCallbacks(); @@ -275,7 +299,11 @@ private Maybe handleOnModelErrorCallback( () -> { LlmRequest llmRequest = llmRequestBuilder.build(); return Flowable.fromIterable(callbacks) - .concatMapMaybe(callback -> callback.call(callbackContext, llmRequest, ex)) + .concatMapMaybe( + callback -> + callback + .call(callbackContext, llmRequest, ex) + .compose(Tracing.withContext(otelParentContext))) .firstElement(); }); @@ -289,13 +317,18 @@ private Maybe handleOnModelErrorCallback( * @return A {@link Single} with the final {@link LlmResponse}. */ private Single handleAfterModelCallback( - InvocationContext context, LlmResponse llmResponse, Event modelResponseEvent) { + InvocationContext context, + LlmResponse llmResponse, + Event modelResponseEvent, + Context otelParentContext) { Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); - Maybe pluginResult = - context.pluginManager().afterModelCallback(callbackContext, llmResponse); + Maybe pluginResult; + try (Scope scope = otelParentContext.makeCurrent()) { + pluginResult = context.pluginManager().afterModelCallback(callbackContext, llmResponse); + } LlmAgent agent = (LlmAgent) context.agent(); List callbacks = agent.canonicalAfterModelCallbacks(); @@ -308,7 +341,11 @@ private Single handleAfterModelCallback( Maybe.defer( () -> Flowable.fromIterable(callbacks) - .concatMapMaybe(callback -> callback.call(callbackContext, llmResponse)) + .concatMapMaybe( + callback -> + callback + .call(callbackContext, llmResponse) + .compose(Tracing.withContext(otelParentContext))) .firstElement()); return pluginResult.switchIfEmpty(callbackResult).defaultIfEmpty(llmResponse); @@ -323,13 +360,12 @@ private Single handleAfterModelCallback( * @throws LlmCallsLimitExceededException if the agent exceeds allowed LLM invocations. * @throws IllegalStateException if a transfer agent is specified but not found. */ - private Flowable runOneStep(InvocationContext context) { + private Flowable runOneStep(InvocationContext context, Context otelParentContext) { AtomicReference llmRequestRef = new AtomicReference<>(LlmRequest.builder().build()); return Flowable.defer( () -> { - Context currentContext = Context.current(); - return preprocess(context, llmRequestRef) + return preprocess(context, llmRequestRef, otelParentContext) .concatWith( Flowable.defer( () -> { @@ -355,11 +391,14 @@ private Flowable runOneStep(InvocationContext context) { .build(); mutableEventTemplate.setTimestamp(0L); - return callLlm(context, llmRequestAfterPreprocess, mutableEventTemplate) + return callLlm( + context, + llmRequestAfterPreprocess, + mutableEventTemplate, + otelParentContext) .concatMap( - llmResponse -> { - try (Scope postScope = currentContext.makeCurrent()) { - return postprocess( + llmResponse -> + postprocess( context, mutableEventTemplate, llmRequestAfterPreprocess, @@ -371,9 +410,8 @@ private Flowable runOneStep(InvocationContext context) { logger.debug( "Resetting event ID from {} to {}", oldId, newId); mutableEventTemplate.setId(newId); - }); - } - }) + }) + .compose(Tracing.withContext(otelParentContext))) .concatMap( event -> { Flowable postProcessedEvents = Flowable.just(event); @@ -407,11 +445,12 @@ private Flowable runOneStep(InvocationContext context) { */ @Override public Flowable run(InvocationContext invocationContext) { - return run(invocationContext, 0); + return run(invocationContext, Context.current(), 0); } - private Flowable run(InvocationContext invocationContext, int stepsCompleted) { - Flowable currentStepEvents = runOneStep(invocationContext).cache(); + private Flowable run( + InvocationContext invocationContext, Context otelParentContext, int stepsCompleted) { + Flowable currentStepEvents = runOneStep(invocationContext, otelParentContext).cache(); if (stepsCompleted + 1 >= maxSteps) { logger.debug("Ending flow execution because max steps reached."); return currentStepEvents; @@ -431,7 +470,7 @@ private Flowable run(InvocationContext invocationContext, int stepsComple return Flowable.empty(); } else { logger.debug("Continuing to next step of the flow."); - return run(invocationContext, stepsCompleted + 1); + return run(invocationContext, otelParentContext, stepsCompleted + 1); } })); } @@ -446,8 +485,10 @@ private Flowable run(InvocationContext invocationContext, int stepsComple */ @Override public Flowable runLive(InvocationContext invocationContext) { + Context otelParentContext = Context.current(); AtomicReference llmRequestRef = new AtomicReference<>(LlmRequest.builder().build()); - Flowable preprocessEvents = preprocess(invocationContext, llmRequestRef); + Flowable preprocessEvents = + preprocess(invocationContext, llmRequestRef, otelParentContext); return preprocessEvents.concatWith( Flowable.defer( @@ -485,7 +526,7 @@ public Flowable runLive(InvocationContext invocationContext) { eventIdForSendData, llmRequestAfterPreprocess.contents()); }) - .compose(Tracing.trace("send_data")); + .compose(Tracing.trace("send_data", otelParentContext)); Flowable liveRequests = invocationContext @@ -542,13 +583,15 @@ public void onError(Throwable e) { .receive() .flatMap( llmResponse -> { - Event baseEventForThisLlmResponse = - liveEventBuilderTemplate.id(Event.generateEventId()).build(); - return postprocess( - invocationContext, - baseEventForThisLlmResponse, - llmRequestAfterPreprocess, - llmResponse); + try (Scope scope = otelParentContext.makeCurrent()) { + Event baseEventForThisLlmResponse = + liveEventBuilderTemplate.id(Event.generateEventId()).build(); + return postprocess( + invocationContext, + baseEventForThisLlmResponse, + llmRequestAfterPreprocess, + llmResponse); + } }) .flatMap( event -> { diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 269764046..056987fc2 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -257,7 +257,8 @@ private static Function> getFunctionCallMapper( functionCall.args().map(HashMap::new).orElse(new HashMap<>()); Maybe> maybeFunctionResult = - maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) + maybeInvokeBeforeToolCall( + invocationContext, tool, functionArgs, toolContext, parentContext) .switchIfEmpty( Maybe.defer( () -> { @@ -395,48 +396,49 @@ private static Maybe postProcessFunctionResult( .defaultIfEmpty(Optional.empty()) .onErrorResumeNext( t -> { - Maybe> errorCallbackResult = - handleOnToolErrorCallback(invocationContext, tool, functionArgs, toolContext, t); - Maybe>> mappedResult; - if (isLive) { - // In live mode, handle null results from the error callback gracefully. - mappedResult = errorCallbackResult.map(Optional::ofNullable); - } else { - // In non-live mode, a null result from the error callback will cause an NPE - // when wrapped with Optional.of(), potentially matching prior behavior. - mappedResult = errorCallbackResult.map(Optional::of); + try (Scope scope = parentContext.makeCurrent()) { + Maybe> errorCallbackResult = + handleOnToolErrorCallback( + invocationContext, tool, functionArgs, toolContext, t, parentContext); + Maybe>> mappedResult; + if (isLive) { + // In live mode, handle null results from the error callback gracefully. + mappedResult = errorCallbackResult.map(Optional::ofNullable); + } else { + // In non-live mode, a null result from the error callback will cause an NPE + // when wrapped with Optional.of(), potentially matching prior behavior. + mappedResult = errorCallbackResult.map(Optional::of); + } + return mappedResult.switchIfEmpty(Single.error(t)); } - return mappedResult.switchIfEmpty(Single.error(t)); }) .flatMapMaybe( optionalInitialResult -> { - try (Scope scope = parentContext.makeCurrent()) { - Map initialFunctionResult = optionalInitialResult.orElse(null); - - return maybeInvokeAfterToolCall( - invocationContext, tool, functionArgs, toolContext, initialFunctionResult) - .map(Optional::of) - .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) - .flatMapMaybe( - finalOptionalResult -> { - Map finalFunctionResult = - finalOptionalResult.orElse(null); - if (tool.longRunning() && finalFunctionResult == null) { - return Maybe.empty(); - } - return Maybe.fromCallable( - () -> - buildResponseEvent( - tool, - finalFunctionResult, - toolContext, - invocationContext)) - .compose( - Tracing.trace( - "tool_response [" + tool.name() + "]", parentContext)) - .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)); - }); - } + Map initialFunctionResult = optionalInitialResult.orElse(null); + + return maybeInvokeAfterToolCall( + invocationContext, + tool, + functionArgs, + toolContext, + initialFunctionResult, + parentContext) + .map(Optional::of) + .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) + .flatMapMaybe( + finalOptionalResult -> { + Map finalFunctionResult = finalOptionalResult.orElse(null); + if (tool.longRunning() && finalFunctionResult == null) { + return Maybe.empty(); + } + return Maybe.fromCallable( + () -> + buildResponseEvent( + tool, finalFunctionResult, toolContext, invocationContext)) + .compose( + Tracing.trace("tool_response [" + tool.name() + "]", parentContext)) + .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)); + }); }); } @@ -479,28 +481,32 @@ private static Maybe> maybeInvokeBeforeToolCall( InvocationContext invocationContext, BaseTool tool, Map functionArgs, - ToolContext toolContext) { - if (invocationContext.agent() instanceof LlmAgent) { - LlmAgent agent = (LlmAgent) invocationContext.agent(); + ToolContext toolContext, + Context parentContext) { + if (invocationContext.agent() instanceof LlmAgent agent) { + try (Scope scope = parentContext.makeCurrent()) { - Maybe> pluginResult = - invocationContext.pluginManager().beforeToolCallback(tool, functionArgs, toolContext); + Maybe> pluginResult = + invocationContext.pluginManager().beforeToolCallback(tool, functionArgs, toolContext); - List callbacks = agent.canonicalBeforeToolCallbacks(); - if (callbacks.isEmpty()) { - return pluginResult; - } - - Maybe> callbackResult = - Maybe.defer( - () -> - Flowable.fromIterable(callbacks) - .concatMapMaybe( - callback -> - callback.call(invocationContext, tool, functionArgs, toolContext)) - .firstElement()); + List callbacks = agent.canonicalBeforeToolCallbacks(); + if (callbacks.isEmpty()) { + return pluginResult; + } - return pluginResult.switchIfEmpty(callbackResult); + Maybe> callbackResult = + Maybe.defer( + () -> + Flowable.fromIterable(callbacks) + .concatMapMaybe( + callback -> + callback + .call(invocationContext, tool, functionArgs, toolContext) + .compose(Tracing.withContext(parentContext))) + .firstElement()); + + return pluginResult.switchIfEmpty(callbackResult); + } } return Maybe.empty(); } @@ -516,34 +522,39 @@ private static Maybe> handleOnToolErrorCallback( BaseTool tool, Map functionArgs, ToolContext toolContext, - Throwable throwable) { + Throwable throwable, + Context parentContext) { Exception ex = throwable instanceof Exception exception ? exception : new Exception(throwable); - Maybe> pluginResult = - invocationContext - .pluginManager() - .onToolErrorCallback(tool, functionArgs, toolContext, throwable); - - if (invocationContext.agent() instanceof LlmAgent) { - LlmAgent agent = (LlmAgent) invocationContext.agent(); + try (Scope scope = parentContext.makeCurrent()) { + Maybe> pluginResult = + invocationContext + .pluginManager() + .onToolErrorCallback(tool, functionArgs, toolContext, throwable); - List callbacks = agent.canonicalOnToolErrorCallbacks(); - if (callbacks.isEmpty()) { - return pluginResult; - } + if (invocationContext.agent() instanceof LlmAgent) { + LlmAgent agent = (LlmAgent) invocationContext.agent(); - Maybe> callbackResult = - Maybe.defer( - () -> - Flowable.fromIterable(callbacks) - .concatMapMaybe( - callback -> - callback.call(invocationContext, tool, functionArgs, toolContext, ex)) - .firstElement()); + List callbacks = agent.canonicalOnToolErrorCallbacks(); + if (callbacks.isEmpty()) { + return pluginResult; + } - return pluginResult.switchIfEmpty(callbackResult); + Maybe> callbackResult = + Maybe.defer( + () -> + Flowable.fromIterable(callbacks) + .concatMapMaybe( + callback -> + callback + .call(invocationContext, tool, functionArgs, toolContext, ex) + .compose(Tracing.withContext(parentContext))) + .firstElement()); + + return pluginResult.switchIfEmpty(callbackResult); + } + return pluginResult; } - return pluginResult; } private static Maybe> maybeInvokeAfterToolCall( @@ -551,35 +562,39 @@ private static Maybe> maybeInvokeAfterToolCall( BaseTool tool, Map functionArgs, ToolContext toolContext, - Map functionResult) { - if (invocationContext.agent() instanceof LlmAgent) { - LlmAgent agent = (LlmAgent) invocationContext.agent(); + Map functionResult, + Context parentContext) { + if (invocationContext.agent() instanceof LlmAgent agent) { - Maybe> pluginResult = - invocationContext - .pluginManager() - .afterToolCallback(tool, functionArgs, toolContext, functionResult); + try (Scope scope = parentContext.makeCurrent()) { + Maybe> pluginResult = + invocationContext + .pluginManager() + .afterToolCallback(tool, functionArgs, toolContext, functionResult); - List callbacks = agent.canonicalAfterToolCallbacks(); - if (callbacks.isEmpty()) { - return pluginResult; - } + List callbacks = agent.canonicalAfterToolCallbacks(); + if (callbacks.isEmpty()) { + return pluginResult; + } - Maybe> callbackResult = - Maybe.defer( - () -> - Flowable.fromIterable(callbacks) - .concatMapMaybe( - callback -> - callback.call( - invocationContext, - tool, - functionArgs, - toolContext, - functionResult)) - .firstElement()); - - return pluginResult.switchIfEmpty(callbackResult); + Maybe> callbackResult = + Maybe.defer( + () -> + Flowable.fromIterable(callbacks) + .concatMapMaybe( + callback -> + callback + .call( + invocationContext, + tool, + functionArgs, + toolContext, + functionResult) + .compose(Tracing.withContext(parentContext))) + .firstElement()); + + return pluginResult.switchIfEmpty(callbackResult); + } } return Maybe.empty(); } diff --git a/core/src/main/java/com/google/adk/plugins/PluginManager.java b/core/src/main/java/com/google/adk/plugins/PluginManager.java index 56dea936a..74300c9fa 100644 --- a/core/src/main/java/com/google/adk/plugins/PluginManager.java +++ b/core/src/main/java/com/google/adk/plugins/PluginManager.java @@ -21,11 +21,14 @@ import com.google.adk.events.Event; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; +import com.google.adk.telemetry.Tracing; import com.google.adk.tools.BaseTool; import com.google.adk.tools.ToolContext; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; @@ -129,6 +132,7 @@ public Completable runAfterRunCallback(InvocationContext invocationContext) { @Override public Completable afterRunCallback(InvocationContext invocationContext) { + Context capturedContext = Context.current(); return Flowable.fromIterable(plugins) .concatMapCompletable( plugin -> @@ -139,11 +143,13 @@ public Completable afterRunCallback(InvocationContext invocationContext) { logger.error( "[{}] Error during callback 'afterRunCallback'", plugin.getName(), - e))); + e)) + .compose(Tracing.withContext(capturedContext))); } @Override public Completable close() { + Context capturedContext = Context.current(); return Flowable.fromIterable(plugins) .concatMapCompletableDelayError( plugin -> @@ -151,8 +157,8 @@ public Completable close() { .close() .doOnError( e -> - logger.error( - "[{}] Error during callback 'close'", plugin.getName(), e))); + logger.error("[{}] Error during callback 'close'", plugin.getName(), e)) + .compose(Tracing.withContext(capturedContext))); } public Maybe runOnEventCallback(InvocationContext invocationContext, Event event) { @@ -275,7 +281,7 @@ public Maybe> onToolErrorCallback( */ private Maybe runMaybeCallbacks( Function> callbackExecutor, String callbackName) { - + Context capturedContext = Context.current(); return Flowable.fromIterable(this.plugins) .concatMapMaybe( plugin -> @@ -294,7 +300,8 @@ private Maybe runMaybeCallbacks( "[{}] Error during callback '{}'", plugin.getName(), callbackName, - e))) + e)) + .compose(Tracing.withContext(capturedContext))) .firstElement(); } } diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index f2cb5b9d5..95ebf5c07 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -375,20 +375,25 @@ public Flowable runAsync( Content newMessage, RunConfig runConfig, @Nullable Map stateDelta) { - Maybe maybeSession = - this.sessionService.getSession(appName, userId, sessionId, Optional.empty()); - return maybeSession - .switchIfEmpty( - Single.defer( - () -> { - if (runConfig.autoCreateSession()) { - return this.sessionService.createSession(appName, userId, null, sessionId); - } - return Single.error( - new IllegalArgumentException( - String.format("Session not found: %s for user %s", sessionId, userId))); - })) - .flatMapPublisher(session -> this.runAsyncImpl(session, newMessage, runConfig, stateDelta)); + return Flowable.defer( + () -> + this.sessionService + .getSession(appName, userId, sessionId, Optional.empty()) + .switchIfEmpty( + Single.defer( + () -> { + if (runConfig.autoCreateSession()) { + return this.sessionService.createSession( + appName, userId, (Map) null, sessionId); + } + return Single.error( + new IllegalArgumentException( + String.format( + "Session not found: %s for user %s", sessionId, userId))); + })) + .flatMapPublisher( + session -> this.runAsyncImpl(session, newMessage, runConfig, stateDelta))) + .compose(Tracing.trace("invocation")); } /** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */ @@ -441,7 +446,8 @@ public Flowable runAsync( Content newMessage, RunConfig runConfig, @Nullable Map stateDelta) { - return runAsyncImpl(session, newMessage, runConfig, stateDelta); + return runAsyncImpl(session, newMessage, runConfig, stateDelta) + .compose(Tracing.trace("invocation")); } /** @@ -510,8 +516,7 @@ protected Flowable runAsyncImpl( Span span = Span.current(); span.setStatus(StatusCode.ERROR, "Error in runAsync Flowable execution"); span.recordException(throwable); - }) - .compose(Tracing.trace("invocation")); + }); } private Flowable runAgentWithFreshSession( @@ -568,7 +573,7 @@ private Flowable runAgentWithFreshSession( .toFlowable() .switchIfEmpty(agentEvents) .concatWith( - Completable.defer(() -> pluginManager.runAfterRunCallback(contextWithUpdatedSession))) + Completable.defer(() -> pluginManager.afterRunCallback(contextWithUpdatedSession))) .concatWith(Completable.defer(() -> compactEvents(updatedSession))); } @@ -641,39 +646,48 @@ private InvocationContext.Builder newInvocationContextBuilder(Session session) { */ public Flowable runLive( Session session, LiveRequestQueue liveRequestQueue, RunConfig runConfig) { + return runLiveImpl(session, liveRequestQueue, runConfig).compose(Tracing.trace("invocation")); + } + + /** + * Runs the agent in live mode, appending generated events to the session. + * + * @return stream of events from the agent. + */ + protected Flowable runLiveImpl( + Session session, @Nullable LiveRequestQueue liveRequestQueue, RunConfig runConfig) { return Flowable.defer( - () -> { - InvocationContext invocationContext = - newInvocationContextForLive(session, liveRequestQueue, runConfig); - - Single invocationContextSingle; - if (invocationContext.agent() instanceof LlmAgent agent) { - invocationContextSingle = - agent - .tools() - .map( - tools -> { - this.addActiveStreamingTools(invocationContext, tools); - return invocationContext; - }); - } else { - invocationContextSingle = Single.just(invocationContext); - } - return invocationContextSingle - .flatMapPublisher( - updatedInvocationContext -> - updatedInvocationContext - .agent() - .runLive(updatedInvocationContext) - .doOnNext(event -> this.sessionService.appendEvent(session, event))) - .doOnError( - throwable -> { - Span span = Span.current(); - span.setStatus(StatusCode.ERROR, "Error in runLive Flowable execution"); - span.recordException(throwable); - }); - }) - .compose(Tracing.trace("invocation")); + () -> { + InvocationContext invocationContext = + newInvocationContextForLive(session, liveRequestQueue, runConfig); + + Single invocationContextSingle; + if (invocationContext.agent() instanceof LlmAgent agent) { + invocationContextSingle = + agent + .tools() + .map( + tools -> { + this.addActiveStreamingTools(invocationContext, tools); + return invocationContext; + }); + } else { + invocationContextSingle = Single.just(invocationContext); + } + return invocationContextSingle + .flatMapPublisher( + updatedInvocationContext -> + updatedInvocationContext + .agent() + .runLive(updatedInvocationContext) + .doOnNext(event -> this.sessionService.appendEvent(session, event))) + .doOnError( + throwable -> { + Span span = Span.current(); + span.setStatus(StatusCode.ERROR, "Error in runLive Flowable execution"); + span.recordException(throwable); + }); + }); } /** @@ -684,19 +698,25 @@ public Flowable runLive( */ public Flowable runLive( String userId, String sessionId, LiveRequestQueue liveRequestQueue, RunConfig runConfig) { - return this.sessionService - .getSession(appName, userId, sessionId, Optional.empty()) - .switchIfEmpty( - Single.defer( - () -> { - if (runConfig.autoCreateSession()) { - return this.sessionService.createSession(appName, userId, null, sessionId); - } - return Single.error( - new IllegalArgumentException( - String.format("Session not found: %s for user %s", sessionId, userId))); - })) - .flatMapPublisher(session -> this.runLive(session, liveRequestQueue, runConfig)); + return Flowable.defer( + () -> + this.sessionService + .getSession(appName, userId, sessionId, Optional.empty()) + .switchIfEmpty( + Single.defer( + () -> { + if (runConfig.autoCreateSession()) { + return this.sessionService.createSession( + appName, userId, (Map) null, sessionId); + } + return Single.error( + new IllegalArgumentException( + String.format( + "Session not found: %s for user %s", sessionId, userId))); + })) + .flatMapPublisher( + session -> this.runLiveImpl(session, liveRequestQueue, runConfig))) + .compose(Tracing.trace("invocation")); } /** diff --git a/core/src/main/java/com/google/adk/telemetry/Tracing.java b/core/src/main/java/com/google/adk/telemetry/Tracing.java index 07a640c37..a7808c8cd 100644 --- a/core/src/main/java/com/google/adk/telemetry/Tracing.java +++ b/core/src/main/java/com/google/adk/telemetry/Tracing.java @@ -35,6 +35,7 @@ import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; +import io.opentelemetry.context.ContextKey; import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.CompletableSource; @@ -47,6 +48,7 @@ import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.core.SingleSource; import io.reactivex.rxjava3.core.SingleTransformer; +import java.lang.reflect.Method; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -56,6 +58,7 @@ import java.util.Optional; import java.util.function.Consumer; import java.util.function.Supplier; +import javax.annotation.Nullable; import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -101,6 +104,9 @@ public class Tracing { private static final AttributeKey GEN_AI_USAGE_OUTPUT_TOKENS = AttributeKey.longKey("gen_ai.usage.output_tokens"); + private static final ContextKey AGENT_NAME_CONTEXT_KEY = + ContextKey.named("gen_ai.agent.name"); + private static final AttributeKey ADK_TOOL_CALL_ARGS = AttributeKey.stringKey("gcp.vertex.agent.tool_call_args"); private static final AttributeKey ADK_LLM_REQUEST = @@ -372,6 +378,33 @@ public static void traceSendData( }); } + private static String getAgentName(Context context) { + String agentName = context.get(AGENT_NAME_CONTEXT_KEY); + if (agentName != null) { + return agentName; + } + + return getAgentNameFromSpan(Span.fromContext(context)); + } + + @Nullable + private static String getAgentNameFromSpan(Span span) { + if (span == null || !span.getSpanContext().isValid()) { + return null; + } + try { + // Use reflection to try to get the attribute from a ReadableSpan (SDK implementation) + Method getAttributeMethod = span.getClass().getMethod("getAttribute", AttributeKey.class); + Object value = getAttributeMethod.invoke(span, GEN_AI_AGENT_NAME); + if (value instanceof String string) { + return string; + } + } catch (Exception ignored) { + // Not a ReadableSpan or other reflection issue + } + return null; + } + /** * Gets the tracer. * @@ -450,15 +483,87 @@ public static TracerProvider trace(String spanName, Context parentContext * @return A TracerProvider configured for agent invocation. */ public static TracerProvider traceAgent( + Context parent, String spanName, String agentName, String agentDescription, InvocationContext invocationContext) { return new TracerProvider(spanName) + .setParent(parent) + .setAgentName(agentName) .configure( span -> traceAgentInvocation(span, agentName, agentDescription, invocationContext)); } + /** + * Returns a transformer that re-activates a given context for the duration of the stream's + * subscription. + * + * @param context The context to re-activate. + * @param The type of the stream. + * @return A transformer that re-activates the context. + */ + public static ContextTransformer withContext(Context context) { + return new ContextTransformer<>(context); + } + + /** + * A transformer that re-activates a given context for the duration of the stream's subscription. + * + * @param The type of the stream. + */ + public static final class ContextTransformer + implements FlowableTransformer, + SingleTransformer, + MaybeTransformer, + CompletableTransformer { + private final Context context; + + private ContextTransformer(Context context) { + this.context = context; + } + + @Override + @SuppressWarnings("MustBeClosedChecker") + public Publisher apply(Flowable upstream) { + return Flowable.defer( + () -> { + Scope scope = context.makeCurrent(); + return upstream.doFinally(scope::close); + }); + } + + @Override + @SuppressWarnings("MustBeClosedChecker") + public SingleSource apply(Single upstream) { + return Single.defer( + () -> { + Scope scope = context.makeCurrent(); + return upstream.doFinally(scope::close); + }); + } + + @Override + @SuppressWarnings("MustBeClosedChecker") + public MaybeSource apply(Maybe upstream) { + return Maybe.defer( + () -> { + Scope scope = context.makeCurrent(); + return upstream.doFinally(scope::close); + }); + } + + @Override + @SuppressWarnings("MustBeClosedChecker") + public CompletableSource apply(Completable upstream) { + return Completable.defer( + () -> { + Scope scope = context.makeCurrent(); + return upstream.doFinally(scope::close); + }); + } + } + /** * A transformer that manages an OpenTelemetry span and scope for RxJava streams. * @@ -471,6 +576,7 @@ public static final class TracerProvider CompletableTransformer { private final String spanName; private Context explicitParentContext; + private String agentName; private final List> spanConfigurers = new ArrayList<>(); private TracerProvider(String spanName) { @@ -491,6 +597,12 @@ public TracerProvider setParent(Context parentContext) { return this; } + @CanIgnoreReturnValue + private TracerProvider setAgentName(String agentName) { + this.agentName = agentName; + return this; + } + private Context getParentContext() { return explicitParentContext != null ? explicitParentContext : Context.current(); } @@ -501,9 +613,33 @@ private final class TracingLifecycle { @SuppressWarnings("MustBeClosedChecker") void start() { - span = tracer.spanBuilder(spanName).setParent(getParentContext()).startSpan(); + Context parentContext = getParentContext(); + + // Propagate agent name from parent if possible + String agentNameToPropagate = TracerProvider.this.agentName; + if (agentNameToPropagate == null) { + agentNameToPropagate = getAgentName(parentContext); + } + + span = tracer.spanBuilder(spanName).setParent(parentContext).startSpan(); + + if (agentNameToPropagate != null) { + span.setAttribute(GEN_AI_AGENT_NAME, agentNameToPropagate); + } + spanConfigurers.forEach(c -> c.accept(span)); - scope = span.makeCurrent(); + + // Try to capture the final agent name (might have been set by a configurer) + String finalAgentName = getAgentNameFromSpan(span); + if (finalAgentName == null) { + finalAgentName = agentNameToPropagate; + } + + Context nextContext = parentContext.with(span); + if (finalAgentName != null) { + nextContext = nextContext.with(AGENT_NAME_CONTEXT_KEY, finalAgentName); + } + scope = nextContext.makeCurrent(); } void end() { @@ -521,7 +657,8 @@ public Publisher apply(Flowable upstream) { return Flowable.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + lifecycle.start(); + return upstream.doFinally(lifecycle::end); }); } @@ -530,7 +667,8 @@ public SingleSource apply(Single upstream) { return Single.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + lifecycle.start(); + return upstream.doFinally(lifecycle::end); }); } @@ -539,7 +677,8 @@ public MaybeSource apply(Maybe upstream) { return Maybe.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + lifecycle.start(); + return upstream.doFinally(lifecycle::end); }); } @@ -548,7 +687,8 @@ public CompletableSource apply(Completable upstream) { return Completable.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + lifecycle.start(); + return upstream.doFinally(lifecycle::end); }); } }