Skip to content
129 changes: 99 additions & 30 deletions a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,19 @@
import io.a2a.server.agentexecution.RequestContext;
import io.a2a.server.events.EventQueue;
import io.a2a.server.tasks.TaskUpdater;
import io.a2a.spec.Artifact;
import io.a2a.spec.InvalidAgentResponseError;
import io.a2a.spec.Message;
import io.a2a.spec.Part;
import io.a2a.spec.TaskArtifactUpdateEvent;
import io.a2a.spec.TaskState;
import io.a2a.spec.TaskStatus;
import io.a2a.spec.TaskStatusUpdateEvent;
import io.a2a.spec.TextPart;
import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.core.Maybe;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.disposables.CompositeDisposable;
import io.reactivex.rxjava3.disposables.Disposable;
import java.util.HashMap;
Expand All @@ -43,10 +51,8 @@
* use in production code.
*/
public class AgentExecutor implements io.a2a.server.agentexecution.AgentExecutor {

private static final Logger logger = LoggerFactory.getLogger(AgentExecutor.class);
private static final String USER_ID_PREFIX = "A2A_USER_";

private final Map<String, Disposable> activeTasks = new ConcurrentHashMap<>();
private final Runner.Builder runnerBuilder;
private final AgentExecutorConfig agentExecutorConfig;
Expand Down Expand Up @@ -137,7 +143,6 @@ public Builder plugins(List<? extends Plugin> plugins) {
return this;
}

@CanIgnoreReturnValue
public AgentExecutor build() {
return new AgentExecutor(
app,
Expand Down Expand Up @@ -165,46 +170,88 @@ public void execute(RequestContext ctx, EventQueue eventQueue) {
if (message == null) {
throw new IllegalArgumentException("Message cannot be null");
}

// Submits a new task if there is no active task.
if (ctx.getTask() == null) {
updater.submit();
}

// Group all reactive work for this task into one container
CompositeDisposable taskDisposables = new CompositeDisposable();
// Check if the task with the task id is already running, put if absent.
if (activeTasks.putIfAbsent(ctx.getTaskId(), taskDisposables) != null) {
throw new IllegalStateException(String.format("Task %s already running", ctx.getTaskId()));
}

EventProcessor p = new EventProcessor(agentExecutorConfig.outputMode());
Content content = PartConverter.messageToContent(message);
Runner runner = runnerBuilder.build();
Single<Boolean> skipExecution =
agentExecutorConfig.beforeExecuteCallback() != null
? agentExecutorConfig.beforeExecuteCallback().call(ctx)
: Single.just(false);

Runner runner = runnerBuilder.build();
taskDisposables.add(
prepareSession(ctx, runner.appName(), runner.sessionService())
skipExecution
.flatMapPublisher(
session -> {
updater.startWork();
return runner.runAsync(
getUserId(ctx), session.id(), content, agentExecutorConfig.runConfig());
skip -> {
if (skip) {
cancel(ctx, eventQueue);
return Flowable.empty();
}
return Maybe.defer(
() -> {
return prepareSession(ctx, runner.appName(), runner.sessionService());
})
.flatMapPublisher(
session -> {
updater.startWork();
return runner.runAsync(
getUserId(ctx),
session.id(),
content,
agentExecutorConfig.runConfig());
});
})
.subscribe(
.concatMap(
event -> {
p.process(event, updater);
},
return p.process(event, ctx, agentExecutorConfig.afterEventCallback(), eventQueue)
.toFlowable();
})
// Ignore all events from the runner, since they are already processed.
.ignoreElements()
.materialize()
.flatMapCompletable(
notification -> {
Throwable error = notification.getError();
if (error != null) {
logger.error("Runner failed to execute", error);
}
return handleExecutionEnd(ctx, error, eventQueue);
})
.doFinally(() -> cleanupTask(ctx.getTaskId()))
.subscribe(
() -> {},
error -> {
logger.error("Runner failed with {}", error);
updater.fail(failedMessage(ctx, error));
cleanupTask(ctx.getTaskId());
},
() -> {
updater.complete();
cleanupTask(ctx.getTaskId());
logger.error("Failed to handle execution end", error);
}));
}

private Completable handleExecutionEnd(
RequestContext ctx, Throwable error, EventQueue eventQueue) {
TaskState state = error != null ? TaskState.FAILED : TaskState.COMPLETED;
Message message = error != null ? failedMessage(ctx, error) : null;
TaskStatusUpdateEvent initialEvent =
new TaskStatusUpdateEvent.Builder()
.taskId(ctx.getTaskId())
.contextId(ctx.getContextId())
.isFinal(true)
.status(new TaskStatus(state, message, null))
.build();
Maybe<TaskStatusUpdateEvent> afterExecute =
agentExecutorConfig.afterExecuteCallback() != null
? agentExecutorConfig.afterExecuteCallback().call(ctx, initialEvent)
: Maybe.just(initialEvent);
return afterExecute.doOnSuccess(event -> eventQueue.enqueueEvent(event)).ignoreElement();
}

private void cleanupTask(String taskId) {
Disposable d = activeTasks.remove(taskId);
if (d != null) {
Expand Down Expand Up @@ -249,16 +296,19 @@ private EventProcessor(AgentExecutorConfig.OutputMode outputMode) {
this.outputMode = outputMode;
}

private void process(Event event, TaskUpdater updater) {
private Maybe<TaskArtifactUpdateEvent> process(
Event event,
RequestContext ctx,
Callbacks.AfterEventCallback callback,
EventQueue eventQueue) {
if (event.errorCode().isPresent()) {
throw new InvalidAgentResponseError(
null, // Uses default code -32006
"Agent returned an error: " + event.errorCode().get(),
null);
return Maybe.error(
new InvalidAgentResponseError(
null, // Uses default code -32006
"Agent returned an error: " + event.errorCode().get(),
null));
}

ImmutableList<Part<?>> parts = EventConverter.contentToParts(event.content());

// Mark all parts as partial if the event is partial.
if (event.partial().orElse(false)) {
parts.forEach(
Expand Down Expand Up @@ -302,7 +352,26 @@ private void process(Event event, TaskUpdater updater) {
}
}

updater.addArtifact(parts, artifactId, null, metadata, append, lastChunk);
TaskArtifactUpdateEvent initialEvent =
new TaskArtifactUpdateEvent.Builder()
.taskId(ctx.getTaskId())
.contextId(ctx.getContextId())
.lastChunk(lastChunk)
.append(append)
.artifact(
new Artifact.Builder()
.artifactId(artifactId)
.parts(parts)
.metadata(metadata)
.build())
.build();

Maybe<TaskArtifactUpdateEvent> afterEvent =
callback != null ? callback.call(ctx, initialEvent, event) : Maybe.just(initialEvent);
return afterEvent.doOnSuccess(
finalEvent -> {
eventQueue.enqueueEvent(finalEvent);
});
}
}
}
Loading
Loading