Skip to content

Commit fca43fb

Browse files
google-genai-botcopybara-github
authored andcommitted
fix: prevent ConcurrentModificationException when session events are modified by another thread during iteration
PiperOrigin-RevId: 884587639
1 parent d8ca834 commit fca43fb

File tree

2 files changed

+67
-5
lines changed

2 files changed

+67
-5
lines changed

core/src/main/java/com/google/adk/flows/llmflows/Contents.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,19 @@ public Single<RequestProcessor.RequestProcessingResult> processRequest(
6464
modelName = "";
6565
}
6666

67+
ImmutableList<Event> sessionEvents;
68+
synchronized (context.session().events()) {
69+
sessionEvents = ImmutableList.copyOf(context.session().events());
70+
}
71+
6772
if (llmAgent.includeContents() == LlmAgent.IncludeContents.NONE) {
6873
return Single.just(
6974
RequestProcessor.RequestProcessingResult.create(
7075
request.toBuilder()
7176
.contents(
7277
getCurrentTurnContents(
7378
context.branch().orElse(null),
74-
context.session().events(),
79+
sessionEvents,
7580
context.agent().name(),
7681
modelName))
7782
.build(),
@@ -80,10 +85,7 @@ public Single<RequestProcessor.RequestProcessingResult> processRequest(
8085

8186
ImmutableList<Content> contents =
8287
getContents(
83-
context.branch().orElse(null),
84-
context.session().events(),
85-
context.agent().name(),
86-
modelName);
88+
context.branch().orElse(null), sessionEvents, context.agent().name(), modelName);
8789

8890
return Single.just(
8991
RequestProcessor.RequestProcessingResult.create(

core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,13 @@
3636
import com.google.genai.types.FunctionCall;
3737
import com.google.genai.types.FunctionResponse;
3838
import com.google.genai.types.Part;
39+
import java.util.HashMap;
3940
import java.util.List;
4041
import java.util.Map;
4142
import java.util.Objects;
4243
import java.util.Optional;
44+
import java.util.concurrent.CountDownLatch;
45+
import java.util.concurrent.atomic.AtomicReference;
4346
import org.junit.Test;
4447
import org.junit.runner.RunWith;
4548
import org.junit.runners.JUnit4;
@@ -780,6 +783,63 @@ public void processRequest_notEmptyContent() {
780783
assertThat(contents).containsExactly(e.content().get());
781784
}
782785

786+
@Test
787+
public void processRequest_concurrentReadAndWrite_noException() throws Exception {
788+
LlmAgent agent =
789+
LlmAgent.builder().name(AGENT).includeContents(LlmAgent.IncludeContents.DEFAULT).build();
790+
Session session =
791+
sessionService
792+
.createSession("test-app", "test-user", new HashMap<>(), "test-session")
793+
.blockingGet();
794+
795+
// Seed with dummy events to widen the race capability
796+
for (int i = 0; i < 5000; i++) {
797+
session.events().add(createUserEvent("dummy" + i, "dummy"));
798+
}
799+
800+
InvocationContext context =
801+
InvocationContext.builder()
802+
.invocationId("test-invocation")
803+
.agent(agent)
804+
.session(session)
805+
.sessionService(sessionService)
806+
.build();
807+
808+
LlmRequest initialRequest = LlmRequest.builder().build();
809+
810+
AtomicReference<Throwable> writerError = new AtomicReference<>();
811+
CountDownLatch startLatch = new CountDownLatch(1);
812+
813+
Thread writerThread =
814+
new Thread(
815+
() -> {
816+
startLatch.countDown();
817+
try {
818+
for (int i = 0; i < 2000; i++) {
819+
session.events().add(createUserEvent("writer" + i, "new data"));
820+
}
821+
} catch (Throwable t) {
822+
writerError.set(t);
823+
}
824+
});
825+
826+
writerThread.start();
827+
startLatch.await(); // wait for writer to be ready
828+
829+
// Process (read) requests concurrently to trigger race conditions
830+
for (int i = 0; i < 200; i++) {
831+
var unused = contentsProcessor.processRequest(context, initialRequest).blockingGet();
832+
if (writerError.get() != null) {
833+
throw new RuntimeException("Writer failed", writerError.get());
834+
}
835+
}
836+
837+
writerThread.join();
838+
if (writerError.get() != null) {
839+
throw new RuntimeException("Writer failed", writerError.get());
840+
}
841+
}
842+
783843
private static Event createUserEvent(String id, String text) {
784844
return Event.builder()
785845
.id(id)

0 commit comments

Comments
 (0)