diff --git a/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java b/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java index 94a54aa67..3d66a4e07 100644 --- a/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java +++ b/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java @@ -20,11 +20,19 @@ import io.a2a.server.agentexecution.RequestContext; import io.a2a.server.events.EventQueue; import io.a2a.server.tasks.TaskUpdater; +import io.a2a.spec.Artifact; import io.a2a.spec.InvalidAgentResponseError; import io.a2a.spec.Message; import io.a2a.spec.Part; +import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskState; +import io.a2a.spec.TaskStatus; +import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.TextPart; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.disposables.CompositeDisposable; import io.reactivex.rxjava3.disposables.Disposable; import java.util.HashMap; @@ -43,10 +51,8 @@ * use in production code. */ public class AgentExecutor implements io.a2a.server.agentexecution.AgentExecutor { - private static final Logger logger = LoggerFactory.getLogger(AgentExecutor.class); private static final String USER_ID_PREFIX = "A2A_USER_"; - private final Map activeTasks = new ConcurrentHashMap<>(); private final Runner.Builder runnerBuilder; private final AgentExecutorConfig agentExecutorConfig; @@ -137,7 +143,6 @@ public Builder plugins(List plugins) { return this; } - @CanIgnoreReturnValue public AgentExecutor build() { return new AgentExecutor( app, @@ -165,46 +170,88 @@ public void execute(RequestContext ctx, EventQueue eventQueue) { if (message == null) { throw new IllegalArgumentException("Message cannot be null"); } - // Submits a new task if there is no active task. if (ctx.getTask() == null) { updater.submit(); } - // Group all reactive work for this task into one container CompositeDisposable taskDisposables = new CompositeDisposable(); // Check if the task with the task id is already running, put if absent. if (activeTasks.putIfAbsent(ctx.getTaskId(), taskDisposables) != null) { throw new IllegalStateException(String.format("Task %s already running", ctx.getTaskId())); } - EventProcessor p = new EventProcessor(agentExecutorConfig.outputMode()); Content content = PartConverter.messageToContent(message); - Runner runner = runnerBuilder.build(); + Single skipExecution = + agentExecutorConfig.beforeExecuteCallback() != null + ? agentExecutorConfig.beforeExecuteCallback().call(ctx) + : Single.just(false); + Runner runner = runnerBuilder.build(); taskDisposables.add( - prepareSession(ctx, runner.appName(), runner.sessionService()) + skipExecution .flatMapPublisher( - session -> { - updater.startWork(); - return runner.runAsync( - getUserId(ctx), session.id(), content, agentExecutorConfig.runConfig()); + skip -> { + if (skip) { + cancel(ctx, eventQueue); + return Flowable.empty(); + } + return Maybe.defer( + () -> { + return prepareSession(ctx, runner.appName(), runner.sessionService()); + }) + .flatMapPublisher( + session -> { + updater.startWork(); + return runner.runAsync( + getUserId(ctx), + session.id(), + content, + agentExecutorConfig.runConfig()); + }); }) - .subscribe( + .concatMap( event -> { - p.process(event, updater); - }, + return p.process(event, ctx, agentExecutorConfig.afterEventCallback(), eventQueue) + .toFlowable(); + }) + // Ignore all events from the runner, since they are already processed. + .ignoreElements() + .materialize() + .flatMapCompletable( + notification -> { + Throwable error = notification.getError(); + if (error != null) { + logger.error("Runner failed to execute", error); + } + return handleExecutionEnd(ctx, error, eventQueue); + }) + .doFinally(() -> cleanupTask(ctx.getTaskId())) + .subscribe( + () -> {}, error -> { - logger.error("Runner failed with {}", error); - updater.fail(failedMessage(ctx, error)); - cleanupTask(ctx.getTaskId()); - }, - () -> { - updater.complete(); - cleanupTask(ctx.getTaskId()); + logger.error("Failed to handle execution end", error); })); } + private Completable handleExecutionEnd( + RequestContext ctx, Throwable error, EventQueue eventQueue) { + TaskState state = error != null ? TaskState.FAILED : TaskState.COMPLETED; + Message message = error != null ? failedMessage(ctx, error) : null; + TaskStatusUpdateEvent initialEvent = + new TaskStatusUpdateEvent.Builder() + .taskId(ctx.getTaskId()) + .contextId(ctx.getContextId()) + .isFinal(true) + .status(new TaskStatus(state, message, null)) + .build(); + Maybe afterExecute = + agentExecutorConfig.afterExecuteCallback() != null + ? agentExecutorConfig.afterExecuteCallback().call(ctx, initialEvent) + : Maybe.just(initialEvent); + return afterExecute.doOnSuccess(event -> eventQueue.enqueueEvent(event)).ignoreElement(); + } + private void cleanupTask(String taskId) { Disposable d = activeTasks.remove(taskId); if (d != null) { @@ -249,16 +296,19 @@ private EventProcessor(AgentExecutorConfig.OutputMode outputMode) { this.outputMode = outputMode; } - private void process(Event event, TaskUpdater updater) { + private Maybe process( + Event event, + RequestContext ctx, + Callbacks.AfterEventCallback callback, + EventQueue eventQueue) { if (event.errorCode().isPresent()) { - throw new InvalidAgentResponseError( - null, // Uses default code -32006 - "Agent returned an error: " + event.errorCode().get(), - null); + return Maybe.error( + new InvalidAgentResponseError( + null, // Uses default code -32006 + "Agent returned an error: " + event.errorCode().get(), + null)); } - ImmutableList> parts = EventConverter.contentToParts(event.content()); - // Mark all parts as partial if the event is partial. if (event.partial().orElse(false)) { parts.forEach( @@ -302,7 +352,26 @@ private void process(Event event, TaskUpdater updater) { } } - updater.addArtifact(parts, artifactId, null, metadata, append, lastChunk); + TaskArtifactUpdateEvent initialEvent = + new TaskArtifactUpdateEvent.Builder() + .taskId(ctx.getTaskId()) + .contextId(ctx.getContextId()) + .lastChunk(lastChunk) + .append(append) + .artifact( + new Artifact.Builder() + .artifactId(artifactId) + .parts(parts) + .metadata(metadata) + .build()) + .build(); + + Maybe afterEvent = + callback != null ? callback.call(ctx, initialEvent, event) : Maybe.just(initialEvent); + return afterEvent.doOnSuccess( + finalEvent -> { + eventQueue.enqueueEvent(finalEvent); + }); } } } diff --git a/a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java b/a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java index d9c7c25ab..5570f40d0 100644 --- a/a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java @@ -3,7 +3,9 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -15,15 +17,25 @@ import com.google.adk.events.Event; import com.google.adk.sessions.InMemorySessionService; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; import com.google.genai.types.Content; import com.google.genai.types.Part; import io.a2a.server.agentexecution.RequestContext; import io.a2a.server.events.EventQueue; import io.a2a.spec.Message; import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskState; +import io.a2a.spec.TaskStatus; +import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.TextPart; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; +import java.util.ArrayList; +import java.util.List; import java.util.Optional; +import java.util.UUID; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -33,10 +45,21 @@ @RunWith(JUnit4.class) public final class AgentExecutorTest { + private EventQueue eventQueue; + private List enqueuedEvents; private TestAgent testAgent; @Before public void setUp() { + enqueuedEvents = new ArrayList<>(); + eventQueue = mock(EventQueue.class); + doAnswer( + invocation -> { + enqueuedEvents.add(invocation.getArgument(0)); + return null; + }) + .when(eventQueue) + .enqueueEvent(any()); testAgent = new TestAgent(); } @@ -92,6 +115,248 @@ public void createAgentExecutor_noAgentExecutorConfig_throwsException() { }); } + @Test + public void execute_withBeforeExecuteCallback_cancelsExecutionOnError() { + // If callback returns error, execution should stop/fail. + Callbacks.BeforeExecuteCallback callback = + ctx -> Single.error(new RuntimeException("Cancelled")); + + AgentExecutorConfig config = + AgentExecutorConfig.builder().beforeExecuteCallback(callback).build(); + + AgentExecutor executor = + new AgentExecutor.Builder() + .agentExecutorConfig(config) + .app(App.builder().name("test_app").rootAgent(testAgent).build()) + .sessionService(new InMemorySessionService()) + .artifactService(new InMemoryArtifactService()) + .build(); + + RequestContext ctx = createRequestContext(); + executor.execute(ctx, eventQueue); + + // Verify error handling triggered cleanup and fail event + // The executor catches the error and emits failed event. + assertThat(enqueuedEvents).isNotEmpty(); + Object lastEvent = Iterables.getLast(enqueuedEvents); + assertThat(lastEvent).isInstanceOf(TaskStatusUpdateEvent.class); + TaskStatusUpdateEvent statusEvent = (TaskStatusUpdateEvent) lastEvent; + assertThat(statusEvent.getStatus().state().toString()).isEqualTo("FAILED"); + assertThat(statusEvent.getStatus().message().getParts().get(0)).isInstanceOf(TextPart.class); + TextPart textPart = (TextPart) statusEvent.getStatus().message().getParts().get(0); + assertThat(textPart.getText()).contains("Cancelled"); + } + + @Test + public void execute_withBeforeExecuteCallback_skipsExecutionIfTrue() { + Callbacks.BeforeExecuteCallback callback = ctx -> Single.just(true); + + AgentExecutorConfig config = + AgentExecutorConfig.builder().beforeExecuteCallback(callback).build(); + + AgentExecutor executor = + new AgentExecutor.Builder() + .agentExecutorConfig(config) + .app(App.builder().name("test_app").rootAgent(testAgent).build()) + .sessionService(new InMemorySessionService()) + .artifactService(new InMemoryArtifactService()) + .build(); + + RequestContext ctx = createRequestContext(); + executor.execute(ctx, eventQueue); + + // Filter for artifact events + Optional artifactEvent = + enqueuedEvents.stream() + .filter(e -> e instanceof TaskArtifactUpdateEvent) + .map(e -> (TaskArtifactUpdateEvent) e) + .findFirst(); + + assertThat(artifactEvent).isEmpty(); + } + + @Test + public void execute_withAfterEventCallback_modifiesEvent() { + // Agent emits an event. Callback intercepts and modifies it. + Part textPart = Part.builder().text("Hello world").build(); + Event agentEvent = + Event.builder() + .id("event-1") + .author("agent") + .content(Content.builder().role("model").parts(ImmutableList.of(textPart)).build()) + .build(); + testAgent.setEventsToEmit(Flowable.just(agentEvent)); + + Callbacks.AfterEventCallback callback = + (ctx, event, sourceEvent) -> { + // Modify event by adding metadata + return Maybe.just( + new TaskArtifactUpdateEvent.Builder(event) + .metadata(ImmutableMap.of("modified", true)) + .build()); + }; + + AgentExecutorConfig config = AgentExecutorConfig.builder().afterEventCallback(callback).build(); + + AgentExecutor executor = + new AgentExecutor.Builder() + .agentExecutorConfig(config) + .app(App.builder().name("test_app").rootAgent(testAgent).build()) + .sessionService(new InMemorySessionService()) + .artifactService(new InMemoryArtifactService()) + .build(); + + RequestContext ctx = createRequestContext(); + executor.execute(ctx, eventQueue); + + // Filter for artifact events + Optional artifactEvent = + enqueuedEvents.stream() + .filter(e -> e instanceof TaskArtifactUpdateEvent) + .map(e -> (TaskArtifactUpdateEvent) e) + .findFirst(); + + assertThat(artifactEvent).isPresent(); + assertThat(artifactEvent.get().getMetadata()).containsEntry("modified", true); + } + + @Test + public void execute_withAfterExecuteCallback_modifiesStatus() { + testAgent.setEventsToEmit(Flowable.empty()); // Just complete + + Callbacks.AfterExecuteCallback callback = + (ctx, event) -> { + // Modify status to have different message + Message newMessage = + new Message.Builder() + .messageId(UUID.randomUUID().toString()) + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("Modified completion"))) + .build(); + + return Maybe.just( + new TaskStatusUpdateEvent.Builder(event) + .status(new TaskStatus(event.getStatus().state(), newMessage, null)) + .build()); + }; + + AgentExecutorConfig config = + AgentExecutorConfig.builder().afterExecuteCallback(callback).build(); + + AgentExecutor executor = + new AgentExecutor.Builder() + .agentExecutorConfig(config) + .app(App.builder().name("test_app").rootAgent(testAgent).build()) + .sessionService(new InMemorySessionService()) + .artifactService(new InMemoryArtifactService()) + .build(); + + RequestContext ctx = createRequestContext(); + executor.execute(ctx, eventQueue); + + // Verify status event + Optional statusEvent = + enqueuedEvents.stream() + .filter(e -> e instanceof TaskStatusUpdateEvent) + .map(e -> (TaskStatusUpdateEvent) e) + .filter(TaskStatusUpdateEvent::isFinal) + .findFirst(); + + assertThat(statusEvent).isPresent(); + assertThat(statusEvent.get().getStatus().message().getParts().get(0)) + .isInstanceOf(TextPart.class); + TextPart textPart = (TextPart) statusEvent.get().getStatus().message().getParts().get(0); + assertThat(textPart.getText()).isEqualTo("Modified completion"); + } + + @Test + public void execute_runnerFails_registersFailedEvent() { + testAgent.setEventsToEmit(Flowable.error(new RuntimeException("Runner error"))); + AgentExecutor executor = + new AgentExecutor.Builder() + .agentExecutorConfig(AgentExecutorConfig.builder().build()) + .app(App.builder().name("test_app").rootAgent(testAgent).build()) + .sessionService(new InMemorySessionService()) + .artifactService(new InMemoryArtifactService()) + .build(); + + RequestContext ctx = createRequestContext(); + executor.execute(ctx, eventQueue); + + ImmutableList finalEvents = + enqueuedEvents.stream() + .filter(e -> e instanceof TaskStatusUpdateEvent) + .map(e -> (TaskStatusUpdateEvent) e) + // final events could be COMPLETED, FAILED, CANCELED, REJECTED or UNKNOWN + // as per io.a2a.spec.TaskState + .filter(TaskStatusUpdateEvent::isFinal) + .collect(toImmutableList()); + + assertThat(finalEvents).hasSize(1); + + TaskStatusUpdateEvent statusEvent = finalEvents.get(0); + assertThat(statusEvent.getStatus().state()).isEqualTo(TaskState.FAILED); + assertThat(statusEvent.getStatus().message().getParts().get(0)).isInstanceOf(TextPart.class); + TextPart textPart = (TextPart) statusEvent.getStatus().message().getParts().get(0); + assertThat(textPart.getText()).isEqualTo("Runner error"); + } + + @Test + public void execute_runnerSucceeds_registerCompletedTaskFails_noFailedTaskRegistered() { + testAgent.setEventsToEmit(Flowable.empty()); + + // Configure eventQueue to throw exception when TaskStatusUpdateEvent is enqueued + doAnswer( + invocation -> { + Object event = invocation.getArgument(0); + if (event instanceof TaskStatusUpdateEvent statusUpdate) { + if (statusUpdate.getStatus().state() == TaskState.COMPLETED) { + throw new RuntimeException("Enqueue failed"); + } + } + return null; + }) + .when(eventQueue) + .enqueueEvent(any()); + + AgentExecutor executor = + new AgentExecutor.Builder() + .agentExecutorConfig(AgentExecutorConfig.builder().build()) + .app(App.builder().name("test_app").rootAgent(testAgent).build()) + .sessionService(new InMemorySessionService()) + .artifactService(new InMemoryArtifactService()) + .build(); + + RequestContext ctx = createRequestContext(); + executor.execute(ctx, eventQueue); + + // Verify status events in the tracked enqueuedEvents + ImmutableList statusEvents = + enqueuedEvents.stream() + .filter(e -> e instanceof TaskStatusUpdateEvent) + .map(e -> (TaskStatusUpdateEvent) e) + .filter(TaskStatusUpdateEvent::isFinal) + .collect(toImmutableList()); + + // There should be no final status events. + assertThat(statusEvents).isEmpty(); + } + + private RequestContext createRequestContext() { + Message message = + new Message.Builder() + .messageId("msg-1") + .role(Message.Role.USER) + .parts(ImmutableList.of(new TextPart("trigger"))) + .build(); + + RequestContext ctx = mock(RequestContext.class); + when(ctx.getMessage()).thenReturn(message); + when(ctx.getTaskId()).thenReturn("task-" + UUID.randomUUID()); + when(ctx.getContextId()).thenReturn("ctx-" + UUID.randomUUID()); + return ctx; + } + @Test public void process_statefulAggregation_tracksArtifactIdAndAppendForAuthor() { Event partial1 = @@ -175,7 +440,7 @@ public void process_statefulAggregation_tracksArtifactIdAndAppendForAuthor() { } private static final class TestAgent extends BaseAgent { - private final Flowable eventsToEmit; + private Flowable eventsToEmit; TestAgent() { this(Flowable.empty()); @@ -187,6 +452,10 @@ private static final class TestAgent extends BaseAgent { this.eventsToEmit = eventsToEmit; } + void setEventsToEmit(Flowable events) { + this.eventsToEmit = events; + } + @Override protected Flowable runAsyncImpl(InvocationContext invocationContext) { return eventsToEmit; diff --git a/core/pom.xml b/core/pom.xml index eefbcda79..93c72e745 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -74,6 +74,10 @@ com.squareup.okhttp3 okhttp + + com.squareup.okhttp3 + okhttp-jvm + com.google.auto.value auto-value-annotations diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index ee4e6ab4c..bbed217f4 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -318,7 +318,7 @@ public Builder beforeModelCallback(BeforeModelCallback beforeModelCallback) { @CanIgnoreReturnValue public Builder beforeModelCallback( - @Nullable List beforeModelCallbacks) { + @Nullable List beforeModelCallbacks) { this.beforeModelCallback = convertCallbacks( beforeModelCallbacks, @@ -355,7 +355,8 @@ public Builder afterModelCallback(AfterModelCallback afterModelCallback) { } @CanIgnoreReturnValue - public Builder afterModelCallback(@Nullable List afterModelCallbacks) { + public Builder afterModelCallback( + @Nullable List afterModelCallbacks) { this.afterModelCallback = convertCallbacks( afterModelCallbacks, @@ -392,7 +393,7 @@ public Builder onModelErrorCallback(OnModelErrorCallback onModelErrorCallback) { @CanIgnoreReturnValue public Builder onModelErrorCallback( - @Nullable List onModelErrorCallbacks) { + @Nullable List onModelErrorCallbacks) { this.onModelErrorCallback = convertCallbacks( onModelErrorCallbacks, @@ -488,7 +489,8 @@ public Builder afterToolCallback(AfterToolCallback afterToolCallback) { } @CanIgnoreReturnValue - public Builder afterToolCallback(@Nullable List afterToolCallbacks) { + public Builder afterToolCallback( + @Nullable List afterToolCallbacks) { this.afterToolCallback = convertCallbacks( afterToolCallbacks, @@ -528,7 +530,7 @@ public Builder onToolErrorCallback(OnToolErrorCallback onToolErrorCallback) { @CanIgnoreReturnValue public Builder onToolErrorCallback( - @Nullable List onToolErrorCallbacks) { + @Nullable List onToolErrorCallbacks) { this.onToolErrorCallback = convertCallbacks( onToolErrorCallbacks, diff --git a/core/src/main/java/com/google/adk/flows/llmflows/AgentTransfer.java b/core/src/main/java/com/google/adk/flows/llmflows/AgentTransfer.java index 402f71c9d..0a0da8761 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/AgentTransfer.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/AgentTransfer.java @@ -55,22 +55,17 @@ public Single processRequest( .appendInstructions( ImmutableList.of(buildTargetAgentsInstructions(agent, transferTargets))); - // Note: this tool is not exposed to the LLM in GenerateContent request. It is there only to - // serve as a backwards-compatible instance for users who depend on the exact name of - // "transferToAgent". - builder.appendTools(ImmutableList.of(createTransferToAgentTool("legacyTransferToAgent"))); - - FunctionTool agentTransferTool = createTransferToAgentTool("transferToAgent"); + FunctionTool agentTransferTool = createTransferToAgentTool(); agentTransferTool.processLlmRequest(builder, ToolContext.builder(context).build()); return Single.just( RequestProcessor.RequestProcessingResult.create(builder.build(), ImmutableList.of())); } - private FunctionTool createTransferToAgentTool(String methodName) { + private FunctionTool createTransferToAgentTool() { Method transferToAgentMethod; try { transferToAgentMethod = - AgentTransfer.class.getMethod(methodName, String.class, ToolContext.class); + AgentTransfer.class.getMethod("transferToAgent", String.class, ToolContext.class); } catch (NoSuchMethodException e) { throw new IllegalStateException(e); } @@ -169,18 +164,4 @@ public static void transferToAgent( EventActions eventActions = toolContext.eventActions(); toolContext.setActions(eventActions.toBuilder().transferToAgent(agentName).build()); } - - /** - * Backwards compatible transferToAgent that uses camel-case naming instead of the ADK's - * snake_case convention. - * - *

It exists only to support users who already use literal "transferToAgent" function call to - * instruct ADK to transfer the question to another agent. - */ - @Schema(name = "transferToAgent") - public static void legacyTransferToAgent( - @Schema(name = "agentName") String agentName, - @Schema(optional = true) ToolContext toolContext) { - transferToAgent(agentName, toolContext); - } } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 549652e86..1249728d8 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -577,13 +577,7 @@ public void onError(Throwable e) { .get() .content(event.content().get()); } - if (functionResponses.stream() - .anyMatch( - functionResponse -> - functionResponse - .name() - .orElse("") - .equals("transferToAgent")) + if (event.actions().transferToAgent().isPresent() || event.actions().endInvocation().orElse(false)) { sendTask.dispose(); connection.close(); diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java index 040b14c05..f98a35f0b 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java @@ -169,7 +169,36 @@ private boolean isEmptyContent(Event event) { || content.role().get().isEmpty() || content.parts().isEmpty() || content.parts().get().isEmpty() - || content.parts().get().get(0).text().map(String::isEmpty).orElse(false)); + || content.parts().get().stream().allMatch(this::isPartInvisible)); + } + + /** + * Returns whether a part is invisible for LLM context. + * + *

A part is invisible if: + * + *

    + *
  • It has no meaningful content (text, inline_data, file_data, function_call, + * function_response, executable_code, or code_execution_result), OR + *
  • It is marked as a thought AND does not contain function_call or function_response + *
+ * + *

Function calls and responses are never invisible, even if marked as thought, because they + * represent actions that need to be executed or results that need to be processed. + * + * @param part the part to check. + * @return {@code true} if the part is invisible, {@code false} otherwise. + */ + private boolean isPartInvisible(Part part) { + if (part.functionCall().isPresent() || part.functionResponse().isPresent()) { + return false; + } + return part.thought().orElse(false) + || !(part.text().isPresent() + || part.inlineData().isPresent() + || part.fileData().isPresent() + || part.codeExecutionResult().isPresent() + || part.executableCode().isPresent()); } /** diff --git a/core/src/main/java/com/google/adk/summarizer/EventsCompactionConfig.java b/core/src/main/java/com/google/adk/summarizer/EventsCompactionConfig.java index b61cd2008..39698c3db 100644 --- a/core/src/main/java/com/google/adk/summarizer/EventsCompactionConfig.java +++ b/core/src/main/java/com/google/adk/summarizer/EventsCompactionConfig.java @@ -16,6 +16,8 @@ package com.google.adk.summarizer; +import com.google.auto.value.AutoBuilder; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import javax.annotation.Nullable; /** @@ -39,6 +41,35 @@ public record EventsCompactionConfig( @Nullable Integer tokenThreshold, @Nullable Integer eventRetentionSize) { + public static Builder builder() { + return new AutoBuilder_EventsCompactionConfig_Builder(); + } + + public Builder toBuilder() { + return new AutoBuilder_EventsCompactionConfig_Builder(this); + } + + /** Builder for {@link EventsCompactionConfig}. */ + @AutoBuilder + public abstract static class Builder { + @CanIgnoreReturnValue + public abstract Builder compactionInterval(@Nullable Integer compactionInterval); + + @CanIgnoreReturnValue + public abstract Builder overlapSize(@Nullable Integer overlapSize); + + @CanIgnoreReturnValue + public abstract Builder summarizer(@Nullable BaseEventSummarizer summarizer); + + @CanIgnoreReturnValue + public abstract Builder tokenThreshold(@Nullable Integer tokenThreshold); + + @CanIgnoreReturnValue + public abstract Builder eventRetentionSize(@Nullable Integer eventRetentionSize); + + public abstract EventsCompactionConfig build(); + } + public EventsCompactionConfig(int compactionInterval, int overlapSize) { this(compactionInterval, overlapSize, null, null, null); } diff --git a/core/src/test/java/com/google/adk/flows/llmflows/AgentTransferTest.java b/core/src/test/java/com/google/adk/flows/llmflows/AgentTransferTest.java index 6e6e99640..79552520b 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/AgentTransferTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/AgentTransferTest.java @@ -24,12 +24,13 @@ import static com.google.common.truth.Truth.assertThat; import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LiveRequest; +import com.google.adk.agents.LiveRequestQueue; import com.google.adk.agents.LlmAgent; import com.google.adk.agents.LoopAgent; import com.google.adk.agents.RunConfig; import com.google.adk.agents.SequentialAgent; import com.google.adk.events.Event; -import com.google.adk.models.LlmRequest; import com.google.adk.runner.InMemoryRunner; import com.google.adk.runner.Runner; import com.google.adk.sessions.Session; @@ -44,6 +45,7 @@ import com.google.genai.types.Schema; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.subscribers.TestSubscriber; import java.util.List; import java.util.Map; import java.util.Optional; @@ -97,6 +99,50 @@ public void exitLoopTool_exitsLoop() { // TODO: b/413488103 - complete when LoopAgent is implemented. } + @Test + public void runLive_transferToAgent_closesConnection() throws Exception { + // Arrange + Content transferCallContent = Content.fromParts(createTransferCallPart("sub_agent_1")); + Content response1 = Content.fromParts(Part.fromText("response1")); + + TestLlm testLlm = + createTestLlm( + Flowable.just(createLlmResponse(transferCallContent)), + Flowable.just(createLlmResponse(response1))); + + LlmAgent subAgent1 = createTestAgentBuilder(testLlm).name("sub_agent_1").build(); + LlmAgent rootAgent = + createTestAgentBuilder(testLlm) + .name("root_agent") + .subAgents(ImmutableList.of(subAgent1)) + .build(); + InvocationContext invocationContext = createInvocationContext(rootAgent); + + Runner runner = getRunnerAndCreateSession(rootAgent, invocationContext.session()); + LiveRequestQueue liveRequestQueue = new LiveRequestQueue(); + + // Act + TestSubscriber testSubscriber = + runner + .runLive(invocationContext.session(), liveRequestQueue, RunConfig.builder().build()) + .test(); + liveRequestQueue.content(Content.fromParts(Part.fromText("hi"))); + testSubscriber.await(); + + // Assert + testSubscriber.assertComplete(); + assertThat(simplifyEvents(testSubscriber.values())) + .containsExactly( + "root_agent: FunctionCall(name=transfer_to_agent, args={agent_name=sub_agent_1})", + "root_agent: FunctionResponse(name=transfer_to_agent, response={})", + "sub_agent_1: response1") + .inOrder(); + + long closedConnectionsCount = + testLlm.getLiveRequestHistory().stream().filter(LiveRequest::shouldClose).count(); + assertThat(closedConnectionsCount).isEqualTo(1); + } + @Test public void testAutoToAuto() { Content transferCallContent = Content.fromParts(createTransferCallPart("sub_agent_1")); @@ -412,85 +458,6 @@ public void testAutoToLoop() { assertThat(simplifyEvents(actualEvents)).containsExactly("root_agent: response5"); } - @Test - public void testLegacyTransferToAgent() { - Content transferCallContent = - Content.fromParts( - Part.fromFunctionCall("transferToAgent", ImmutableMap.of("agentName", "sub_agent_1"))); - Content response1 = Content.fromParts(Part.fromText("response1")); - Content response2 = Content.fromParts(Part.fromText("response2")); - - TestLlm testLlm = - createTestLlm( - Flowable.just(createLlmResponse(transferCallContent)), - Flowable.just(createLlmResponse(response1)), - Flowable.just(createLlmResponse(response2))); - - LlmAgent subAgent1 = createTestAgentBuilder(testLlm).name("sub_agent_1").build(); - LlmAgent rootAgent = - createTestAgentBuilder(testLlm) - .name("root_agent") - .subAgents(ImmutableList.of(subAgent1)) - .build(); - InvocationContext invocationContext = createInvocationContext(rootAgent); - - Runner runner = getRunnerAndCreateSession(rootAgent, invocationContext.session()); - List actualEvents = runRunner(runner, invocationContext); - - assertThat(simplifyEvents(actualEvents)) - .containsExactly( - "root_agent: FunctionCall(name=transferToAgent, args={agentName=sub_agent_1})", - "root_agent: FunctionResponse(name=transferToAgent, response={})", - "sub_agent_1: response1") - .inOrder(); - - actualEvents = runRunner(runner, invocationContext); - - assertThat(simplifyEvents(actualEvents)).containsExactly("sub_agent_1: response2"); - } - - @Test - public void testAgentTransferDoesNotExposeLegacyTransferToAgent() { - Content transferCallContent = - Content.fromParts( - Part.fromFunctionCall("transferToAgent", ImmutableMap.of("agentName", "sub_agent_1"))); - Content response1 = Content.fromParts(Part.fromText("response1")); - Content response2 = Content.fromParts(Part.fromText("response2")); - TestLlm testLlm = - createTestLlm( - Flowable.just(createLlmResponse(transferCallContent)), - Flowable.just(createLlmResponse(response1)), - Flowable.just(createLlmResponse(response2))); - LlmAgent subAgent1 = createTestAgentBuilder(testLlm).name("sub_agent_1").build(); - LlmAgent rootAgent = - createTestAgentBuilder(testLlm) - .name("root_agent") - .subAgents(ImmutableList.of(subAgent1)) - .build(); - InvocationContext invocationContext = createInvocationContext(rootAgent); - AgentTransfer processor = new AgentTransfer(); - LlmRequest request = LlmRequest.builder().build(); - - var processed = processor.processRequest(invocationContext, request); - - assertThat(processed.blockingGet().updatedRequest().config().get().tools()).isPresent(); - assertThat(processed.blockingGet().updatedRequest().config().get().tools().get()).hasSize(1); - assertThat( - processed - .blockingGet() - .updatedRequest() - .config() - .get() - .tools() - .get() - .get(0) - .functionDeclarations() - .get() - .get(0) - .name()) - .hasValue("transfer_to_agent"); - } - private Runner getRunnerAndCreateSession(LlmAgent agent, Session session) { Runner runner = new InMemoryRunner(agent, session.appName()); // Ensure the session exists before running the agent. diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java index d555525f4..3041a855b 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java @@ -754,6 +754,32 @@ public void processRequest_slidingWindow_preservesOverlappingCompactions() { .containsExactly("C1", "C2", "E4", "E5"); } + @Test + public void processRequest_notEmptyContent() { + Event e = + Event.builder() + .id("e1") + .author(AGENT) + .content( + Content.builder() + .role("model") + .parts( + ImmutableList.of( + Part.builder().text("").thought(true).build(), + Part.builder() + .functionCall( + FunctionCall.builder() + .name("test-tool") + .id("test-call-id") + .build()) + .thought(false) + .build())) + .build()) + .build(); + List contents = runContentsProcessor(ImmutableList.of(e)); + assertThat(contents).containsExactly(e.content().get()); + } + private static Event createUserEvent(String id, String text) { return Event.builder() .id(id) diff --git a/core/src/test/java/com/google/adk/summarizer/EventsCompactionConfigTest.java b/core/src/test/java/com/google/adk/summarizer/EventsCompactionConfigTest.java new file mode 100644 index 000000000..01f59d37a --- /dev/null +++ b/core/src/test/java/com/google/adk/summarizer/EventsCompactionConfigTest.java @@ -0,0 +1,55 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.summarizer; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class EventsCompactionConfigTest { + + @Test + public void builder_buildsConfig() { + EventsCompactionConfig config = + EventsCompactionConfig.builder() + .compactionInterval(10) + .overlapSize(2) + .tokenThreshold(100) + .eventRetentionSize(5) + .build(); + + assertThat(config.compactionInterval()).isEqualTo(10); + assertThat(config.overlapSize()).isEqualTo(2); + assertThat(config.tokenThreshold()).isEqualTo(100); + assertThat(config.eventRetentionSize()).isEqualTo(5); + assertThat(config.summarizer()).isNull(); + } + + @Test + public void toBuilder_rebuildsConfig() { + EventsCompactionConfig config = + EventsCompactionConfig.builder().compactionInterval(10).overlapSize(2).build(); + + EventsCompactionConfig rebuilt = config.toBuilder().compactionInterval(20).build(); + + assertThat(rebuilt.compactionInterval()).isEqualTo(20); + assertThat(rebuilt.overlapSize()).isEqualTo(2); + } +} diff --git a/pom.xml b/pom.xml index d3f2ba432..af8a1f2b1 100644 --- a/pom.xml +++ b/pom.xml @@ -41,6 +41,7 @@ 17 ${java.version} UTF-8 + 3.6.0 1.11.1 3.4.1 @@ -56,8 +57,8 @@ 5.20.0 1.6.0 2.19.0 - 4.12.0 - 3.3.6 + 5.3.2 + 3.7.0 0.18.1 3.41.0 3.9.0 @@ -69,7 +70,7 @@ 3.7.0 2.35.1 3.27.7 - 1.4.0 + 2.15.0 3.9.0 5.6 @@ -112,6 +113,13 @@ pom import + + com.squareup.okhttp3 + okhttp-bom + ${okhttp.version} + pom + import + @@ -144,11 +152,6 @@ google-genai ${google.genai.version} - - com.squareup.okhttp3 - okhttp - ${okhttp.version} - com.google.auto.value auto-value-annotations @@ -287,6 +290,11 @@ + + org.apache.maven.plugins + maven-checkstyle-plugin + ${maven.checkstyle.plugin.version} + maven-clean-plugin 3.1.0 @@ -462,6 +470,40 @@ + + illegal-optional-check + + + + org.apache.maven.plugins + maven-checkstyle-plugin + + + illegal-optional-check + + check + + compile + + + + + + + + + + + + + + + + + + + + release-sonatype