diff --git a/a2a/src/main/java/com/google/adk/a2a/RemoteA2AAgent.java b/a2a/src/main/java/com/google/adk/a2a/RemoteA2AAgent.java index b8ff39808..6c89d0979 100644 --- a/a2a/src/main/java/com/google/adk/a2a/RemoteA2AAgent.java +++ b/a2a/src/main/java/com/google/adk/a2a/RemoteA2AAgent.java @@ -171,6 +171,25 @@ public RemoteA2AAgent build() { } } + private Message.Builder newA2AMessage(Message.Role role, List> parts) { + return new Message.Builder().messageId(UUID.randomUUID().toString()).role(role).parts(parts); + } + + private Message prepareMessage(InvocationContext invocationContext) { + Event userCall = EventConverter.findUserFunctionCall(invocationContext.session().events()); + if (userCall != null) { + ImmutableList> parts = + EventConverter.contentToParts(userCall.content(), userCall.partial().orElse(false)); + return newA2AMessage(Message.Role.USER, parts) + .taskId(EventConverter.taskId(userCall)) + .contextId(EventConverter.contextId(userCall)) + .build(); + } + return newA2AMessage( + Message.Role.USER, EventConverter.messagePartsFromContext(invocationContext)) + .build(); + } + @Override protected Flowable runAsyncImpl(InvocationContext invocationContext) { // Construct A2A Message from the last ADK event @@ -181,14 +200,7 @@ protected Flowable runAsyncImpl(InvocationContext invocationContext) { return Flowable.empty(); } - Optional a2aMessageOpt = EventConverter.convertEventsToA2AMessage(invocationContext); - - if (a2aMessageOpt.isEmpty()) { - logger.warn("Failed to convert event to A2A message."); - return Flowable.empty(); - } - - Message originalMessage = a2aMessageOpt.get(); + Message originalMessage = prepareMessage(invocationContext); String requestJson = serializeMessageToJson(originalMessage); return Flowable.create( diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java index 1a49b0070..406426046 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java @@ -2,16 +2,19 @@ import static com.google.common.collect.ImmutableList.toImmutableList; +import com.google.adk.a2a.common.GenAiFieldMissingException; import com.google.adk.agents.InvocationContext; +import com.google.adk.events.Event; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; import com.google.genai.types.Content; -import io.a2a.spec.Message; +import com.google.genai.types.FunctionResponse; import io.a2a.spec.Part; import java.util.Collection; +import java.util.List; import java.util.Optional; import java.util.UUID; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import org.jspecify.annotations.Nullable; /** * Converter for ADK Events to A2A Messages. @@ -20,54 +23,106 @@ * use in production code. */ public final class EventConverter { - private static final Logger logger = LoggerFactory.getLogger(EventConverter.class); + public static final String ADK_TASK_ID_KEY = "adk_task_id"; + public static final String ADK_CONTEXT_ID_KEY = "adk_context_id"; private EventConverter() {} + private static String metadataValue(Event event, String key) { + if (event.customMetadata().isEmpty()) { + return ""; + } + return event.customMetadata().get().stream() + .filter(m -> m.key().orElse("").equals(key)) + .findFirst() + .map(m -> m.stringValue().orElse("")) + .orElse(""); + } + /** - * Converts an ADK InvocationContext to an A2A Message. + * Returns the task ID from the event. * - *

It combines all the events in the session, plus the user content, converted into A2A Parts, - * into a single A2A Message. + *

Task ID is stored in the event's custom metadata with the key {@link #ADK_TASK_ID_KEY}. * - *

If the context has no events, or no suitable content to build the message, an empty optional - * is returned. - * - * @param context The ADK InvocationContext to convert. - * @return The converted A2A Message. + * @param event The event to get the task ID from. + * @return The task ID, or an empty string if not found. */ - public static Optional convertEventsToA2AMessage(InvocationContext context) { - if (context.session().events().isEmpty()) { - logger.warn("No events in session, cannot convert to A2A message."); - return Optional.empty(); - } - - ImmutableList.Builder> partsBuilder = ImmutableList.builder(); + public static String taskId(Event event) { + return metadataValue(event, ADK_TASK_ID_KEY); + } - context - .session() - .events() - .forEach( - event -> - partsBuilder.addAll( - contentToParts(event.content(), event.partial().orElse(false)))); - partsBuilder.addAll(contentToParts(context.userContent(), false)); + /** + * Returns the context ID from the event. + * + *

Context ID is stored in the event's custom metadata with the key {@link + * #ADK_CONTEXT_ID_KEY}. + * + * @param event The event to get the context ID from. + * @return The context ID, or an empty string if not found. + */ + public static String contextId(Event event) { + return metadataValue(event, ADK_CONTEXT_ID_KEY); + } - ImmutableList> parts = partsBuilder.build(); + /** + * Returns the last user function call event from the list of events. + * + * @param events The list of events to find the user function call event from. + * @return The user function call event, or null if not found. + */ + public static @Nullable Event findUserFunctionCall(List events) { + Event candidate = Iterables.getLast(events); + if (!candidate.author().equals("user")) { + return null; + } + FunctionResponse functionResponse; + try { + functionResponse = findUserFunctionResponse(candidate); + } catch (GenAiFieldMissingException e) { + return null; + } + if (functionResponse == null || functionResponse.id().isEmpty()) { + return null; + } + for (int i = events.size() - 2; i >= 0; i--) { + Event event = events.get(i); + if (isUserFunctionCall(event, functionResponse.id().get())) { + return event; + } + } + return null; + } - if (parts.isEmpty()) { - logger.warn("No suitable content found to build A2A request message."); - return Optional.empty(); + private static FunctionResponse findUserFunctionResponse(Event candidate) { + if (candidate.content().isEmpty() || candidate.content().get().parts().isEmpty()) { + throw new GenAiFieldMissingException("Event has no content or parts."); } + return candidate.content().get().parts().get().stream() + .filter(part -> part.functionResponse().isPresent()) + .findFirst() + .orElseThrow(() -> new GenAiFieldMissingException("Event has no function response.")) + .functionResponse() + .get(); + } - return Optional.of( - new Message.Builder() - .messageId(UUID.randomUUID().toString()) - .parts(parts) - .role(Message.Role.USER) - .build()); + private static boolean isUserFunctionCall(Event event, String functionResponseId) { + if (event.content().isEmpty()) { + return false; + } + return event.content().get().parts().get().stream() + .anyMatch( + part -> + part.functionCall().isPresent() + && part.functionCall().get().id().orElse("").equals(functionResponseId)); } + /** + * Converts a GenAI Content object to a list of A2A Parts. + * + * @param content The GenAI Content object to convert. + * @param isPartial Whether the content is partial. + * @return A list of A2A Parts. + */ public static ImmutableList> contentToParts( Optional content, boolean isPartial) { return content.flatMap(Content::parts).stream() @@ -75,4 +130,69 @@ public static ImmutableList> contentToParts( .map(part -> PartConverter.fromGenaiPart(part, isPartial)) .collect(toImmutableList()); } + + /** + * Returns the parts from the context events that should be sent to the agent. + * + *

All session events from the previous remote agent response (or the beginning of the session + * in case of the first agent invocation) are included into the A2A message. Events from other + * agents are presented as user messages and rephased as if a user was telling what happened in + * the session up to the point. + * + * @param context The invocation context to get the parts from. + * @return A list of A2A Parts. + */ + public static ImmutableList> messagePartsFromContext(InvocationContext context) { + if (context.session().events().isEmpty()) { + return ImmutableList.of(); + } + List events = context.session().events(); + int lastResponseIndex = -1; + String contextId = ""; + for (int i = events.size() - 1; i >= 0; i--) { + Event event = events.get(i); + if (event.author().equals(context.agent().name())) { + lastResponseIndex = i; + contextId = contextId(event); + break; + } + } + ImmutableList.Builder> partsBuilder = ImmutableList.builder(); + for (int i = lastResponseIndex + 1; i < events.size(); i++) { + Event event = events.get(i); + if (!event.author().equals("user") && !event.author().equals(context.agent().name())) { + event = presentAsUserMessage(event, contextId); + } + contentToParts(event.content(), event.partial().orElse(false)).forEach(partsBuilder::add); + } + return partsBuilder.build(); + } + + private static Event presentAsUserMessage(Event event, String contextId) { + Event.Builder userEvent = + new Event.Builder().id(UUID.randomUUID().toString()).invocationId(contextId).author("user"); + ImmutableList parts = + event.content().flatMap(Content::parts).stream() + .flatMap(Collection::stream) + // convert only non-thought parts to user message parts, skip thought parts as they are + // not meant to be shown to the user + .filter(part -> !part.thought().orElse(false)) + .map(part -> PartConverter.remoteCallAsUserPart(event.author(), part)) + .collect(toImmutableList()); + if (parts.isEmpty()) { + return userEvent.build(); + } + com.google.genai.types.Part forContext = + com.google.genai.types.Part.builder().text("For context:").build(); + return userEvent + .content( + Content.builder() + .parts( + ImmutableList.builder() + .add(forContext) + .addAll(parts) + .build()) + .build()) + .build(); + } } diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java index 05125d170..92069a772 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java @@ -374,6 +374,50 @@ private static FilePart filePartToA2A(Part part, ImmutableMap.BuilderEvents are rephrased as if a user was telling what happened in the session up to the point. + * E.g. + * + *

{@code
+   * For context:
+   * User said: Now help me with Z
+   * Agent A said: Agent B can help you with it!
+   * Agent B said: Agent C might know better.*
+   * }
+ * + * @param author The author of the part. + * @param part The part to convert. + * @return The converted part. + */ + public static Part remoteCallAsUserPart(String author, Part part) { + if (part.text().isPresent()) { + String partText = String.format("[%s] said: %s", author, part.text().get()); + return Part.builder().text(partText).build(); + } else if (part.functionCall().isPresent()) { + FunctionCall functionCall = part.functionCall().get(); + String partText = + String.format( + "[%s] called tool %s with parameters: %s", + author, + functionCall.name().orElse(""), + functionCall.args().orElse(ImmutableMap.of())); + return Part.builder().text(partText).build(); + } else if (part.functionResponse().isPresent()) { + FunctionResponse functionResponse = part.functionResponse().get(); + String partText = + String.format( + "[%s] %s tool returned result: %s", + author, + functionResponse.name().orElse(""), + functionResponse.response().orElse(ImmutableMap.of())); + return Part.builder().text(partText).build(); + } else { + return part; + } + } + @SuppressWarnings("unchecked") // safe conversion from objectMapper.readValue private static Map coerceToMap(Object value) { if (value == null) { diff --git a/a2a/src/test/java/com/google/adk/a2a/RemoteA2AAgentTest.java b/a2a/src/test/java/com/google/adk/a2a/RemoteA2AAgentTest.java index 87eaa2321..0a0ef24ac 100644 --- a/a2a/src/test/java/com/google/adk/a2a/RemoteA2AAgentTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/RemoteA2AAgentTest.java @@ -412,10 +412,11 @@ public void runAsync_constructsRequestWithHistory() { .sendMessage(messageCaptor.capture(), any(List.class), any(Consumer.class), any()); Message message = messageCaptor.getValue(); assertThat(message.getRole()).isEqualTo(Message.Role.USER); - assertThat(message.getParts()).hasSize(3); + assertThat(message.getParts()).hasSize(4); assertThat(((TextPart) message.getParts().get(0)).getText()).isEqualTo("hello"); - assertThat(((TextPart) message.getParts().get(1)).getText()).isEqualTo("hi"); - assertThat(((TextPart) message.getParts().get(2)).getText()).isEqualTo("how are you?"); + assertThat(((TextPart) message.getParts().get(1)).getText()).isEqualTo("For context:"); + assertThat(((TextPart) message.getParts().get(2)).getText()).isEqualTo("[model] said: hi"); + assertThat(((TextPart) message.getParts().get(3)).getText()).isEqualTo("how are you?"); } @Test diff --git a/a2a/src/test/java/com/google/adk/a2a/converters/EventConverterTest.java b/a2a/src/test/java/com/google/adk/a2a/converters/EventConverterTest.java index 8d460c457..76a527622 100644 --- a/a2a/src/test/java/com/google/adk/a2a/converters/EventConverterTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/converters/EventConverterTest.java @@ -4,23 +4,17 @@ import com.google.adk.agents.BaseAgent; import com.google.adk.agents.InvocationContext; -import com.google.adk.artifacts.InMemoryArtifactService; import com.google.adk.events.Event; -import com.google.adk.plugins.PluginManager; import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; +import com.google.genai.types.CustomMetadata; import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; -import io.a2a.spec.DataPart; -import io.a2a.spec.Message; import io.a2a.spec.TextPart; import io.reactivex.rxjava3.core.Flowable; -import java.util.ArrayList; -import java.util.List; import java.util.Optional; import org.junit.Test; import org.junit.runner.RunWith; @@ -29,117 +23,174 @@ @RunWith(JUnit4.class) public final class EventConverterTest { + private static final class TestAgent extends BaseAgent { + TestAgent() { + super("test_agent", "test", ImmutableList.of(), null, null); + } + + @Override + protected Flowable runAsyncImpl(InvocationContext invocationContext) { + return Flowable.empty(); + } + + @Override + protected Flowable runLiveImpl(InvocationContext invocationContext) { + return Flowable.empty(); + } + } + @Test - public void convertEventsToA2AMessage_preservesFunctionCallAndResponseParts() { - // Arrange session events: user text, function call, function response. - Part userTextPart = Part.builder().text("Roll a die").build(); - Event userEvent = + public void testTaskId() { + Event e = Event.builder() - .id("event-user") - .author("user") - .content(Content.builder().role("user").parts(ImmutableList.of(userTextPart)).build()) + .customMetadata( + ImmutableList.of( + CustomMetadata.builder() + .key(EventConverter.ADK_TASK_ID_KEY) + .stringValue("task-123") + .build())) .build(); + assertThat(EventConverter.taskId(e)).isEqualTo("task-123"); + } - Part functionCallPart = - Part.builder() - .functionCall( - FunctionCall.builder() - .name("roll_die") - .id("adk-call-1") - .args(ImmutableMap.of("sides", 6)) - .build()) + @Test + public void testTaskId_empty() { + Event e = Event.builder().build(); + assertThat(EventConverter.taskId(e)).isEmpty(); + } + + @Test + public void testContextId() { + Event e = + Event.builder() + .customMetadata( + ImmutableList.of( + CustomMetadata.builder() + .key(EventConverter.ADK_CONTEXT_ID_KEY) + .stringValue("context-456") + .build())) .build(); - Event callEvent = + assertThat(EventConverter.contextId(e)).isEqualTo("context-456"); + } + + @Test + public void testContextId_empty() { + Event e = Event.builder().build(); + assertThat(EventConverter.contextId(e)).isEmpty(); + } + + @Test + public void testFindUserFunctionCall_success() { + Event agentEvent = Event.builder().author("agent").build(); + FunctionCall fc = FunctionCall.builder().name("my-func").id("fc-id").build(); + Event userEventWithCall = Event.builder() - .id("event-call") - .author("root_agent") + .author("user") .content( Content.builder() - .role("assistant") - .parts(ImmutableList.of(functionCallPart)) + .parts(ImmutableList.of(Part.builder().functionCall(fc).build())) .build()) .build(); - Part functionResponsePart = - Part.builder() - .functionResponse( - FunctionResponse.builder() - .name("roll_die") - .id("adk-call-1") - .response(ImmutableMap.of("result", 3)) + FunctionResponse fr = FunctionResponse.builder().name("my-func").id("fc-id").build(); + Event userEventWithResponse = + Event.builder() + .author("user") + .content( + Content.builder() + .parts(ImmutableList.of(Part.builder().functionResponse(fr).build())) .build()) .build(); - Event responseEvent = + + ImmutableList events = + ImmutableList.of(userEventWithCall, agentEvent, userEventWithResponse); + assertThat(EventConverter.findUserFunctionCall(events)).isEqualTo(userEventWithCall); + } + + @Test + public void testFindUserFunctionCall_noMatchingCall() { + Event agentEvent = Event.builder().author("agent").build(); + FunctionCall fc = FunctionCall.builder().name("my-func").id("other-id").build(); + Event userEventWithCall = Event.builder() - .id("event-response") - .author("roll_agent") + .author("user") .content( Content.builder() - .role("tool") - .parts(ImmutableList.of(functionResponsePart)) + .parts(ImmutableList.of(Part.builder().functionCall(fc).build())) .build()) .build(); - List events = new ArrayList<>(ImmutableList.of(userEvent, callEvent, responseEvent)); - Session session = - Session.builder("session-1").appName("demo").userId("user").events(events).build(); - - InvocationContext context = - InvocationContext.builder() - .sessionService(new InMemorySessionService()) - .artifactService(new InMemoryArtifactService()) - .pluginManager(new PluginManager()) - .invocationId("invocation-1") - .agent(new TestAgent()) - .session(session) - .userContent( - Content.builder().role("user").parts(ImmutableList.of(userTextPart)).build()) - .endInvocation(false) + FunctionResponse fr = FunctionResponse.builder().name("my-func").id("fc-id").build(); + Event userEventWithResponse = + Event.builder() + .author("user") + .content( + Content.builder() + .parts(ImmutableList.of(Part.builder().functionResponse(fr).build())) + .build()) .build(); - // Act - Optional maybeMessage = EventConverter.convertEventsToA2AMessage(context); - - // Assert - assertThat(maybeMessage).isPresent(); - Message message = maybeMessage.get(); - assertThat(message.getParts()).hasSize(4); - assertThat(message.getParts().get(0)).isInstanceOf(TextPart.class); - assertThat(message.getParts().get(1)).isInstanceOf(DataPart.class); - assertThat(message.getParts().get(2)).isInstanceOf(DataPart.class); - assertThat(message.getParts().get(3)).isInstanceOf(TextPart.class); - - DataPart callDataPart = (DataPart) message.getParts().get(1); - assertThat(callDataPart.getMetadata().get(PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY)) - .isEqualTo(A2ADataPartMetadataType.FUNCTION_CALL.getType()); - assertThat(callDataPart.getData()).containsEntry("name", "roll_die"); - assertThat(callDataPart.getData()).containsEntry("id", "adk-call-1"); - assertThat(callDataPart.getData()).containsEntry("args", ImmutableMap.of("sides", 6)); - - DataPart responseDataPart = (DataPart) message.getParts().get(2); - assertThat(responseDataPart.getMetadata().get(PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY)) - .isEqualTo(A2ADataPartMetadataType.FUNCTION_RESPONSE.getType()); - assertThat(responseDataPart.getData()).containsEntry("name", "roll_die"); - assertThat(responseDataPart.getData()).containsEntry("id", "adk-call-1"); - assertThat(responseDataPart.getData()).containsEntry("response", ImmutableMap.of("result", 3)); - - TextPart lastTextPart = (TextPart) message.getParts().get(3); - assertThat(lastTextPart.getText()).isEqualTo("Roll a die"); + ImmutableList events = + ImmutableList.of(userEventWithCall, agentEvent, userEventWithResponse); + assertThat(EventConverter.findUserFunctionCall(events)).isNull(); } - private static final class TestAgent extends BaseAgent { - TestAgent() { - super("test_agent", "test", ImmutableList.of(), null, null); - } + @Test + public void testFindUserFunctionCall_lastEventNotUser() { + Event agentEvent = Event.builder().author("agent").build(); + ImmutableList events = ImmutableList.of(agentEvent); + assertThat(EventConverter.findUserFunctionCall(events)).isNull(); + } - @Override - protected Flowable runAsyncImpl(InvocationContext invocationContext) { - return Flowable.empty(); - } + @Test + public void testContentToParts() { + Part textPart = Part.builder().text("hello").build(); + Content content = Content.builder().parts(ImmutableList.of(textPart)).build(); + ImmutableList> list = + EventConverter.contentToParts(Optional.of(content), false); + assertThat(list).hasSize(1); + assertThat(((TextPart) list.get(0)).getText()).isEqualTo("hello"); + } - @Override - protected Flowable runLiveImpl(InvocationContext invocationContext) { - return Flowable.empty(); - } + @Test + public void testMessagePartsFromContext() { + Session session = + Session.builder("session1") + .events( + ImmutableList.of( + Event.builder() + .author("user") + .content( + Content.builder() + .parts(ImmutableList.of(Part.builder().text("hello").build())) + .build()) + .build(), + Event.builder() + .author("test_agent") + .content( + Content.builder() + .parts(ImmutableList.of(Part.builder().text("hi").build())) + .build()) + .build(), + Event.builder() + .author("other_agent") + .content( + Content.builder() + .parts(ImmutableList.of(Part.builder().text("hey").build())) + .build()) + .build())) + .build(); + BaseAgent agent = new TestAgent(); + InvocationContext ctx = + InvocationContext.builder() + .session(session) + .sessionService(new InMemorySessionService()) + .agent(agent) + .build(); + ImmutableList> parts = EventConverter.messagePartsFromContext(ctx); + + assertThat(parts).hasSize(2); + assertThat(((TextPart) parts.get(0)).getText()).isEqualTo("For context:"); + assertThat(((TextPart) parts.get(1)).getText()).isEqualTo("[other_agent] said: hey"); } }