Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions core/src/main/java/com/google/adk/runner/Runner.java
Original file line number Diff line number Diff line change
Expand Up @@ -312,27 +312,27 @@ private Single<Event> appendNewMessageToSession(
throw new IllegalArgumentException("No parts in the new_message.");
}

List<Completable> artifactSaves = new ArrayList<>();
if (this.artifactService != null && saveInputBlobsAsArtifacts) {
// The runner directly saves the artifacts (if applicable) in the user message and replaces
// the artifact data with a file name placeholder.
List<Part> newParts = new ArrayList<>();
for (int i = 0; i < newMessage.parts().get().size(); i++) {
Part part = newMessage.parts().get().get(i);
if (part.inlineData().isEmpty()) {
if (part.inlineData().isEmpty() && part.fileData().isEmpty()) {
newParts.add(part);
continue;
}
String fileName = "artifact_" + invocationContext.invocationId() + "_" + i;
var unused =
this.artifactService.saveArtifact(
this.appName, session.userId(), session.id(), fileName, part);

newMessage
.parts()
.get()
.set(
i,
Part.fromText(
"Uploaded file: " + fileName + ". It has been saved to the artifacts"));
artifactSaves.add(
this.artifactService
.saveArtifact(this.appName, session.userId(), session.id(), fileName, part)
.ignoreElement());

newParts.add(
Part.fromText("Uploaded file: " + fileName + ". It has been saved to the artifacts"));
}
newMessage = newMessage.toBuilder().parts(newParts).build();
}
// Appends only. We do not yield the event because it's not from the model.
Event.Builder eventBuilder =
Expand All @@ -348,7 +348,12 @@ private Single<Event> appendNewMessageToSession(
EventActions.builder().stateDelta(new ConcurrentHashMap<>(stateDelta)).build());
}

return this.sessionService.appendEvent(session, eventBuilder.build());
Single<Event> appendEventSingle =
this.sessionService.appendEvent(session, eventBuilder.build());
if (artifactSaves.isEmpty()) {
return appendEventSingle;
}
return Completable.merge(artifactSaves).andThen(appendEventSingle);
}

/** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */
Expand Down
118 changes: 118 additions & 0 deletions core/src/test/java/com/google/adk/runner/RunnerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import com.google.adk.agents.LlmAgent;
import com.google.adk.agents.RunConfig;
import com.google.adk.apps.App;
import com.google.adk.artifacts.BaseArtifactService;
import com.google.adk.events.Event;
import com.google.adk.flows.llmflows.Functions;
import com.google.adk.models.LlmResponse;
Expand All @@ -52,6 +53,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.genai.types.Blob;
import com.google.genai.types.Content;
import com.google.genai.types.FunctionCall;
import com.google.genai.types.FunctionResponse;
Expand All @@ -62,6 +64,7 @@
import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.core.Maybe;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.subscribers.TestSubscriber;
import java.util.List;
import java.util.Objects;
Expand Down Expand Up @@ -640,6 +643,121 @@ public void runAsync_withSessionKeyAndStateDelta_mergesStateIntoSession() {
assertThat(finalSession.state()).containsAtLeastEntriesIn(stateDelta);
}

@Test
public void runAsync_withSaveInputBlobsAsArtifactsTrue_savesBlobsAndReplacesContent() {
BaseArtifactService mockArtifactService = mock(BaseArtifactService.class);
when(mockArtifactService.saveArtifact(any(), any(), any(), any(), any()))
.thenReturn(Single.just(1));

Runner runnerWithMockService =
Runner.builder()
.app(
App.builder()
.name("test")
.rootAgent(agent)
.plugins(ImmutableList.of(plugin))
.build())
.artifactService(mockArtifactService)
.build();
Session localSession =
runnerWithMockService.sessionService().createSession("test", "user").blockingGet();

Content userContent =
Content.builder()
.role("user")
.parts(
Part.fromText("text part"),
Part.builder()
.inlineData(
Blob.builder().mimeType("image/png").data(new byte[] {1, 2, 3}).build())
.build())
.build();

var events =
runnerWithMockService
.runAsync(
"user",
localSession.id(),
userContent,
RunConfig.builder().setSaveInputBlobsAsArtifacts(true).build())
.toList()
.blockingGet();

assertThat(simplifyEvents(events)).containsExactly("test agent: from llm");

Session finalSession =
runnerWithMockService
.sessionService()
.getSession("test", "user", localSession.id(), Optional.empty())
.blockingGet();

Event savedEvent = finalSession.events().get(0);
assertThat(savedEvent.author()).isEqualTo("user");
List<Part> parts = savedEvent.content().get().parts().get();
assertThat(parts).hasSize(2);
assertThat(parts.get(0).text()).hasValue("text part");
assertThat(parts.get(1).text().get()).startsWith("Uploaded file: artifact_");
assertThat(parts.get(1).inlineData()).isEmpty();

verify(mockArtifactService).saveArtifact(any(), any(), any(), any(), any());
}

@Test
public void runAsync_withSaveInputBlobsAsArtifactsFalse_doesNotModifyContent() {
BaseArtifactService mockArtifactService = mock(BaseArtifactService.class);

Runner runnerWithMockService =
Runner.builder()
.app(
App.builder()
.name("test")
.rootAgent(agent)
.plugins(ImmutableList.of(plugin))
.build())
.artifactService(mockArtifactService)
.build();
Session localSession =
runnerWithMockService.sessionService().createSession("test", "user").blockingGet();

Content userContent =
Content.builder()
.role("user")
.parts(
Part.fromText("text part"),
Part.builder()
.inlineData(
Blob.builder().mimeType("image/png").data(new byte[] {1, 2, 3}).build())
.build())
.build();

var events =
runnerWithMockService
.runAsync(
"user",
localSession.id(),
userContent,
RunConfig.builder().setSaveInputBlobsAsArtifacts(false).build())
.toList()
.blockingGet();

assertThat(simplifyEvents(events)).containsExactly("test agent: from llm");

Session finalSession =
runnerWithMockService
.sessionService()
.getSession("test", "user", localSession.id(), Optional.empty())
.blockingGet();

Event savedEvent = finalSession.events().get(0);
assertThat(savedEvent.author()).isEqualTo("user");
List<Part> parts = savedEvent.content().get().parts().get();
assertThat(parts).hasSize(2);
assertThat(parts.get(0).text()).hasValue("text part");
assertThat(parts.get(1).inlineData().get().mimeType()).hasValue("image/png");

verify(mockArtifactService, never()).saveArtifact(any(), any(), any(), any(), any());
}

@Test
public void runAsync_withEmptyStateDelta_doesNotModifySession() {
ImmutableMap<String, Object> emptyStateDelta = ImmutableMap.of();
Expand Down