Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions a2a/src/main/java/com/google/adk/a2a/RemoteA2AAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,25 @@ public RemoteA2AAgent build() {
}
}

private Message.Builder newA2AMessage(Message.Role role, List<io.a2a.spec.Part<?>> parts) {
return new Message.Builder().messageId(UUID.randomUUID().toString()).role(role).parts(parts);
}

private Message prepareMessage(InvocationContext invocationContext) {
Event userCall = EventConverter.findUserFunctionCall(invocationContext.session().events());
if (userCall != null) {
ImmutableList<io.a2a.spec.Part<?>> parts =
EventConverter.contentToParts(userCall.content(), userCall.partial().orElse(false));
return newA2AMessage(Message.Role.USER, parts)
.taskId(EventConverter.taskId(userCall))
.contextId(EventConverter.contextId(userCall))
.build();
}
return newA2AMessage(
Message.Role.USER, EventConverter.messagePartsFromContext(invocationContext))
.build();
}

@Override
protected Flowable<Event> runAsyncImpl(InvocationContext invocationContext) {
// Construct A2A Message from the last ADK event
Expand All @@ -181,14 +200,7 @@ protected Flowable<Event> runAsyncImpl(InvocationContext invocationContext) {
return Flowable.empty();
}

Optional<Message> a2aMessageOpt = EventConverter.convertEventsToA2AMessage(invocationContext);

if (a2aMessageOpt.isEmpty()) {
logger.warn("Failed to convert event to A2A message.");
return Flowable.empty();
}

Message originalMessage = a2aMessageOpt.get();
Message originalMessage = prepareMessage(invocationContext);
String requestJson = serializeMessageToJson(originalMessage);

return Flowable.create(
Expand Down
194 changes: 157 additions & 37 deletions a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@

import static com.google.common.collect.ImmutableList.toImmutableList;

import com.google.adk.a2a.common.GenAiFieldMissingException;
import com.google.adk.agents.InvocationContext;
import com.google.adk.events.Event;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.genai.types.Content;
import io.a2a.spec.Message;
import com.google.genai.types.FunctionResponse;
import io.a2a.spec.Part;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.jspecify.annotations.Nullable;

/**
* Converter for ADK Events to A2A Messages.
Expand All @@ -20,59 +23,176 @@
* use in production code.
*/
public final class EventConverter {
private static final Logger logger = LoggerFactory.getLogger(EventConverter.class);
public static final String ADK_TASK_ID_KEY = "adk_task_id";
public static final String ADK_CONTEXT_ID_KEY = "adk_context_id";

private EventConverter() {}

private static String metadataValue(Event event, String key) {
if (event.customMetadata().isEmpty()) {
return "";
}
return event.customMetadata().get().stream()
.filter(m -> m.key().orElse("").equals(key))
.findFirst()
.map(m -> m.stringValue().orElse(""))
.orElse("");
}

/**
* Converts an ADK InvocationContext to an A2A Message.
* Returns the task ID from the event.
*
* <p>It combines all the events in the session, plus the user content, converted into A2A Parts,
* into a single A2A Message.
* <p>Task ID is stored in the event's custom metadata with the key {@link #ADK_TASK_ID_KEY}.
*
* <p>If the context has no events, or no suitable content to build the message, an empty optional
* is returned.
*
* @param context The ADK InvocationContext to convert.
* @return The converted A2A Message.
* @param event The event to get the task ID from.
* @return The task ID, or an empty string if not found.
*/
public static Optional<Message> convertEventsToA2AMessage(InvocationContext context) {
if (context.session().events().isEmpty()) {
logger.warn("No events in session, cannot convert to A2A message.");
return Optional.empty();
}

ImmutableList.Builder<Part<?>> partsBuilder = ImmutableList.builder();
public static String taskId(Event event) {
return metadataValue(event, ADK_TASK_ID_KEY);
}

context
.session()
.events()
.forEach(
event ->
partsBuilder.addAll(
contentToParts(event.content(), event.partial().orElse(false))));
partsBuilder.addAll(contentToParts(context.userContent(), false));
/**
* Returns the context ID from the event.
*
* <p>Context ID is stored in the event's custom metadata with the key {@link
* #ADK_CONTEXT_ID_KEY}.
*
* @param event The event to get the context ID from.
* @return The context ID, or an empty string if not found.
*/
public static String contextId(Event event) {
return metadataValue(event, ADK_CONTEXT_ID_KEY);
}

ImmutableList<Part<?>> parts = partsBuilder.build();
/**
* Returns the last user function call event from the list of events.
*
* @param events The list of events to find the user function call event from.
* @return The user function call event, or null if not found.
*/
public static @Nullable Event findUserFunctionCall(List<Event> events) {
Event candidate = Iterables.getLast(events);
if (!candidate.author().equals("user")) {
return null;
}
FunctionResponse functionResponse;
try {
functionResponse = findUserFunctionResponse(candidate);
} catch (GenAiFieldMissingException e) {
return null;
}
if (functionResponse == null || functionResponse.id().isEmpty()) {
return null;
}
for (int i = events.size() - 2; i >= 0; i--) {
Event event = events.get(i);
if (isUserFunctionCall(event, functionResponse.id().get())) {
return event;
}
}
return null;
}

if (parts.isEmpty()) {
logger.warn("No suitable content found to build A2A request message.");
return Optional.empty();
private static FunctionResponse findUserFunctionResponse(Event candidate) {
if (candidate.content().isEmpty() || candidate.content().get().parts().isEmpty()) {
throw new GenAiFieldMissingException("Event has no content or parts.");
}
return candidate.content().get().parts().get().stream()
.filter(part -> part.functionResponse().isPresent())
.findFirst()
.orElseThrow(() -> new GenAiFieldMissingException("Event has no function response."))
.functionResponse()
.get();
}

return Optional.of(
new Message.Builder()
.messageId(UUID.randomUUID().toString())
.parts(parts)
.role(Message.Role.USER)
.build());
private static boolean isUserFunctionCall(Event event, String functionResponseId) {
if (event.content().isEmpty()) {
return false;
}
return event.content().get().parts().get().stream()
.anyMatch(
part ->
part.functionCall().isPresent()
&& part.functionCall().get().id().orElse("").equals(functionResponseId));
}

/**
* Converts a GenAI Content object to a list of A2A Parts.
*
* @param content The GenAI Content object to convert.
* @param isPartial Whether the content is partial.
* @return A list of A2A Parts.
*/
public static ImmutableList<Part<?>> contentToParts(
Optional<Content> content, boolean isPartial) {
return content.flatMap(Content::parts).stream()
.flatMap(Collection::stream)
.map(part -> PartConverter.fromGenaiPart(part, isPartial))
.collect(toImmutableList());
}

/**
* Returns the parts from the context events that should be sent to the agent.
*
* <p>All session events from the previous remote agent response (or the beginning of the session
* in case of the first agent invocation) are included into the A2A message. Events from other
* agents are presented as user messages and rephased as if a user was telling what happened in
* the session up to the point.
*
* @param context The invocation context to get the parts from.
* @return A list of A2A Parts.
*/
public static ImmutableList<Part<?>> messagePartsFromContext(InvocationContext context) {
if (context.session().events().isEmpty()) {
return ImmutableList.of();
}
List<Event> events = context.session().events();
int lastResponseIndex = -1;
String contextId = "";
for (int i = events.size() - 1; i >= 0; i--) {
Event event = events.get(i);
if (event.author().equals(context.agent().name())) {
lastResponseIndex = i;
contextId = contextId(event);
break;
}
}
ImmutableList.Builder<Part<?>> partsBuilder = ImmutableList.builder();
for (int i = lastResponseIndex + 1; i < events.size(); i++) {
Event event = events.get(i);
if (!event.author().equals("user") && !event.author().equals(context.agent().name())) {
event = presentAsUserMessage(event, contextId);
}
contentToParts(event.content(), event.partial().orElse(false)).forEach(partsBuilder::add);
}
return partsBuilder.build();
}

private static Event presentAsUserMessage(Event event, String contextId) {
Event.Builder userEvent =
new Event.Builder().id(UUID.randomUUID().toString()).invocationId(contextId).author("user");
ImmutableList<com.google.genai.types.Part> parts =
event.content().flatMap(Content::parts).stream()
.flatMap(Collection::stream)
// convert only non-thought parts to user message parts, skip thought parts as they are
// not meant to be shown to the user
.filter(part -> !part.thought().orElse(false))
.map(part -> PartConverter.remoteCallAsUserPart(event.author(), part))
.collect(toImmutableList());
if (parts.isEmpty()) {
return userEvent.build();
}
com.google.genai.types.Part forContext =
com.google.genai.types.Part.builder().text("For context:").build();
return userEvent
.content(
Content.builder()
.parts(
ImmutableList.<com.google.genai.types.Part>builder()
.add(forContext)
.addAll(parts)
.build())
.build())
.build();
}
}
44 changes: 44 additions & 0 deletions a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,50 @@ private static FilePart filePartToA2A(Part part, ImmutableMap.Builder<String, Ob
metadata.buildOrThrow());
}

/**
* Converts a remote call part to a user part.
*
* <p>Events are rephrased as if a user was telling what happened in the session up to the point.
* E.g.
*
* <pre>{@code
* For context:
* User said: Now help me with Z
* Agent A said: Agent B can help you with it!
* Agent B said: Agent C might know better.*
* }</pre>
*
* @param author The author of the part.
* @param part The part to convert.
* @return The converted part.
*/
public static Part remoteCallAsUserPart(String author, Part part) {
if (part.text().isPresent()) {
String partText = String.format("[%s] said: %s", author, part.text().get());
return Part.builder().text(partText).build();
} else if (part.functionCall().isPresent()) {
FunctionCall functionCall = part.functionCall().get();
String partText =
String.format(
"[%s] called tool %s with parameters: %s",
author,
functionCall.name().orElse("<unknown>"),
functionCall.args().orElse(ImmutableMap.of()));
return Part.builder().text(partText).build();
} else if (part.functionResponse().isPresent()) {
FunctionResponse functionResponse = part.functionResponse().get();
String partText =
String.format(
"[%s] %s tool returned result: %s",
author,
functionResponse.name().orElse("<unknown>"),
functionResponse.response().orElse(ImmutableMap.of()));
return Part.builder().text(partText).build();
} else {
return part;
}
}

@SuppressWarnings("unchecked") // safe conversion from objectMapper.readValue
private static Map<String, Object> coerceToMap(Object value) {
if (value == null) {
Expand Down
7 changes: 4 additions & 3 deletions a2a/src/test/java/com/google/adk/a2a/RemoteA2AAgentTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -412,10 +412,11 @@ public void runAsync_constructsRequestWithHistory() {
.sendMessage(messageCaptor.capture(), any(List.class), any(Consumer.class), any());
Message message = messageCaptor.getValue();
assertThat(message.getRole()).isEqualTo(Message.Role.USER);
assertThat(message.getParts()).hasSize(3);
assertThat(message.getParts()).hasSize(4);
assertThat(((TextPart) message.getParts().get(0)).getText()).isEqualTo("hello");
assertThat(((TextPart) message.getParts().get(1)).getText()).isEqualTo("hi");
assertThat(((TextPart) message.getParts().get(2)).getText()).isEqualTo("how are you?");
assertThat(((TextPart) message.getParts().get(1)).getText()).isEqualTo("For context:");
assertThat(((TextPart) message.getParts().get(2)).getText()).isEqualTo("[model] said: hi");
assertThat(((TextPart) message.getParts().get(3)).getText()).isEqualTo("how are you?");
}

@Test
Expand Down
Loading