From a80da51130b37a0920fb5369743926072f9cd2d8 Mon Sep 17 00:00:00 2001 From: Roland Kakonyi Date: Mon, 4 May 2026 10:35:57 +0200 Subject: [PATCH] Fix Vertex session user IDs --- .../adk/sessions/VertexAiSessionService.java | 13 +++++- .../sessions/VertexAiSessionServiceTest.java | 42 ++++++++++++++++++- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java index 99e7e3479..4cfea51dd 100644 --- a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java +++ b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java @@ -35,6 +35,7 @@ import java.util.Comparator; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -148,7 +149,7 @@ private ListSessionsResponse parseListSessionsResponse( Session session = Session.builder(sessionId) .appName(appName) - .userId(userId) + .userId((String) apiSession.get("userId")) .state( apiSession.get("sessionState") == null ? new ConcurrentHashMap<>() @@ -195,6 +196,16 @@ public Maybe getSession( .getSession(reasoningEngineId, sessionId) .flatMap( getSessionResponseMap -> { + String responseUserId = + Optional.ofNullable(getSessionResponseMap.get("userId")) + .map(JsonNode::asText) + .orElse(null); + if (!Objects.equals(responseUserId, userId)) { + return Maybe.error( + new IllegalArgumentException( + "Session " + sessionId + " does not belong to user " + userId + ".")); + } + String sessId = Optional.ofNullable(getSessionResponseMap.get("name")) .map(name -> Iterables.getLast(Splitter.on('/').splitToList(name.asText()))) diff --git a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java index dd62263d7..3c9e8c3e2 100644 --- a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java @@ -274,6 +274,33 @@ public void listSessions_success() { assertThat(sessionsList).hasSize(2); ImmutableList ids = sessionsList.stream().map(Session::id).collect(toImmutableList()); assertThat(ids).containsExactly("1", "2"); + ImmutableList userIds = + sessionsList.stream().map(Session::userId).collect(toImmutableList()); + assertThat(userIds).containsExactly("user", "user"); + } + + @Test + public void listSessions_usesResponseUserId() throws Exception { + when(mockApiClient.request("GET", "reasoningEngines/123/sessions?filter=user_id=user1", "")) + .thenAnswer( + new MockApiAnswer( + """ + { + "sessions": [ + { + "name": "projects/test-project/locations/test-location/reasoningEngines/123/sessions/3", + "userId": "user2", + "updateTime": "2024-12-14T12:12:12.123456Z" + } + ] + }\ + """)); + + ListSessionsResponse sessions = + vertexAiSessionService.listSessions("123", "user1").blockingGet(); + + assertThat(sessions.sessions()).hasSize(1); + assertThat(sessions.sessions().get(0).userId()).isEqualTo("user2"); } @Test @@ -346,12 +373,25 @@ public void listEvents_empty() { public void listEmptySession_success() { assertThat( vertexAiSessionService - .getSession("789", "user1", "3", Optional.empty()) + .getSession("789", "user2", "3", Optional.empty()) .blockingGet() .events()) .isEmpty(); } + @Test + public void getSession_whenResponseUserIdDiffers_throws() { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + vertexAiSessionService + .getSession("789", "user1", "3", Optional.empty()) + .blockingGet()); + + assertThat(exception).hasMessageThat().contains("Session 3 does not belong to user user1."); + } + @Test public void appendEvent_withStateRemoved_updatesSessionState() { String userId = "userB";