From 8669fe0cee887a95b9fda256e43c9e576b165e62 Mon Sep 17 00:00:00 2001 From: Adam Seering Date: Mon, 12 Jan 2026 10:44:07 +0000 Subject: [PATCH 1/2] feat: Add ClientContext to Options and propagate to RPCs This change adds support for ClientContext in Options and ensures it is propagated to ExecuteSql, Read, Commit, and BeginTransaction requests. It aligns with go/spanner-client-scoped-session-state design. - Added RequestOptions.ClientContext to Options. - Refactored request option building to Options.toRequestOptionsProto. - Updated AbstractReadContext, TransactionRunnerImpl, and SessionImpl to use the shared logic. - Added tests. --- .../cloud/spanner/AbstractReadContext.java | 24 ++++--- .../com/google/cloud/spanner/Options.java | 62 +++++++++++++++++++ .../com/google/cloud/spanner/SessionImpl.java | 22 +++++-- .../google/cloud/spanner/SpannerOptions.java | 16 +++++ .../cloud/spanner/TransactionRunnerImpl.java | 12 +--- .../spanner/AbstractReadContextTest.java | 20 ++++++ .../com/google/cloud/spanner/OptionsTest.java | 34 +++++++--- .../google/cloud/spanner/SessionImplTest.java | 37 +++++++++++ .../spanner/TransactionContextImplTest.java | 8 +++ .../spanner/TransactionRunnerImplTest.java | 37 +++++++++++ 10 files changed, 239 insertions(+), 33 deletions(-) diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java index 289acb1a745..a89dbb2bb60 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java @@ -684,22 +684,20 @@ QueryOptions buildQueryOptions(QueryOptions requestOptions) { } RequestOptions buildRequestOptions(Options options) { - // Shortcut for the most common return value. - if (!(options.hasPriority() || options.hasTag() || getTransactionTag() != null)) { - return RequestOptions.getDefaultInstance(); - } - - RequestOptions.Builder builder = RequestOptions.newBuilder(); - if (options.hasPriority()) { - builder.setPriority(options.priority()); - } - if (options.hasTag()) { - builder.setRequestTag(options.tag()); + RequestOptions requestOptions = options.toRequestOptionsProto(false); + RequestOptions.ClientContext defaultClientContext = + session.getSpanner().getOptions().getClientContext(); + if (defaultClientContext != null) { + RequestOptions.ClientContext.Builder builder = defaultClientContext.toBuilder(); + if (requestOptions.hasClientContext()) { + builder.mergeFrom(requestOptions.getClientContext()); + } + requestOptions = requestOptions.toBuilder().setClientContext(builder.build()).build(); } if (getTransactionTag() != null) { - builder.setTransactionTag(getTransactionTag()); + return requestOptions.toBuilder().setTransactionTag(getTransactionTag()).build(); } - return builder.build(); + return requestOptions; } ExecuteSqlRequest.Builder getExecuteSqlRequestBuilder( diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java index 1e6ce34d672..116e1aa4fc5 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java @@ -20,6 +20,7 @@ import com.google.spanner.v1.DirectedReadOptions; import com.google.spanner.v1.ReadRequest.LockHint; import com.google.spanner.v1.ReadRequest.OrderBy; +import com.google.spanner.v1.RequestOptions; import com.google.spanner.v1.RequestOptions.Priority; import com.google.spanner.v1.TransactionOptions.IsolationLevel; import com.google.spanner.v1.TransactionOptions.ReadWrite.ReadLockMode; @@ -265,6 +266,37 @@ public static ReadQueryUpdateTransactionOption priority(RpcPriority priority) { return new PriorityOption(priority); } + /** + * Specifying this will add the given client context to the request. The client context is used to + * pass side-channel or configuration information to the backend, such as a user ID for a + * parameterized secure view. + */ + public static ReadQueryUpdateTransactionOption clientContext( + RequestOptions.ClientContext clientContext) { + return new ClientContextOption(clientContext); + } + + RequestOptions toRequestOptionsProto(boolean isTransactionOption) { + if (!hasPriority() && !hasTag() && !hasClientContext()) { + return RequestOptions.getDefaultInstance(); + } + RequestOptions.Builder builder = RequestOptions.newBuilder(); + if (hasPriority()) { + builder.setPriority(priority()); + } + if (hasTag()) { + if (isTransactionOption) { + builder.setTransactionTag(tag()); + } else { + builder.setRequestTag(tag()); + } + } + if (hasClientContext()) { + builder.setClientContext(clientContext()); + } + return builder.build(); + } + public static TransactionOption maxCommitDelay(Duration maxCommitDelay) { Preconditions.checkArgument(!maxCommitDelay.isNegative(), "maxCommitDelay should be positive"); return new MaxCommitDelayOption(maxCommitDelay); @@ -462,6 +494,20 @@ void appendToOptions(Options options) { } } + static final class ClientContextOption extends InternalOption + implements ReadQueryUpdateTransactionOption { + private final RequestOptions.ClientContext clientContext; + + ClientContextOption(RequestOptions.ClientContext clientContext) { + this.clientContext = clientContext; + } + + @Override + void appendToOptions(Options options) { + options.clientContext = clientContext; + } + } + static final class TagOption extends InternalOption implements ReadQueryUpdateTransactionOption { private final String tag; @@ -574,6 +620,7 @@ void appendToOptions(Options options) { private String filter; private RpcPriority priority; private String tag; + private RequestOptions.ClientContext clientContext; private String etag; private Boolean validateOnly; private Boolean withExcludeTxnFromChangeStreams; @@ -666,6 +713,14 @@ Priority priority() { return priority == null ? null : priority.proto; } + boolean hasClientContext() { + return clientContext != null; + } + + RequestOptions.ClientContext clientContext() { + return clientContext; + } + boolean hasTag() { return tag != null; } @@ -777,6 +832,9 @@ public String toString() { if (priority != null) { b.append("priority: ").append(priority).append(' '); } + if (clientContext != null) { + b.append("clientContext: ").append(clientContext).append(' '); + } if (tag != null) { b.append("tag: ").append(tag).append(' '); } @@ -850,6 +908,7 @@ public boolean equals(Object o) { && Objects.equals(pageToken(), that.pageToken()) && Objects.equals(filter(), that.filter()) && Objects.equals(priority(), that.priority()) + && Objects.equals(clientContext(), that.clientContext()) && Objects.equals(tag(), that.tag()) && Objects.equals(etag(), that.etag()) && Objects.equals(validateOnly(), that.validateOnly()) @@ -894,6 +953,9 @@ public int hashCode() { if (priority != null) { result = 31 * result + priority.hashCode(); } + if (clientContext != null) { + result = 31 * result + clientContext.hashCode(); + } if (tag != null) { result = 31 * result + tag.hashCode(); } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java index 94881bd5e08..a59314a8a3a 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java @@ -32,7 +32,6 @@ import com.google.cloud.spanner.SessionClient.SessionOption; import com.google.cloud.spanner.TransactionRunnerImpl.TransactionContextImpl; import com.google.cloud.spanner.spi.v1.SpannerRpc; -import com.google.common.base.Strings; import com.google.common.base.Ticker; import com.google.common.collect.Lists; import com.google.common.util.concurrent.MoreExecutors; @@ -182,6 +181,10 @@ ErrorHandler getErrorHandler() { return this.errorHandler; } + SpannerImpl getSpanner() { + return spanner; + } + void setCurrentSpan(ISpan span) { currentSpan = span; } @@ -486,9 +489,20 @@ ApiFuture beginTransactionAsync( if (sessionReference.getIsMultiplexed() && mutation != null) { requestBuilder.setMutationKey(mutation); } - if (sessionReference.getIsMultiplexed() && !Strings.isNullOrEmpty(transactionOptions.tag())) { - requestBuilder.setRequestOptions( - RequestOptions.newBuilder().setTransactionTag(transactionOptions.tag()).build()); + RequestOptions requestOptions = transactionOptions.toRequestOptionsProto(true); + RequestOptions.ClientContext defaultClientContext = spanner.getOptions().getClientContext(); + if (defaultClientContext != null) { + RequestOptions.ClientContext.Builder builder = defaultClientContext.toBuilder(); + if (requestOptions.hasClientContext()) { + builder.mergeFrom(requestOptions.getClientContext()); + } + requestOptions = requestOptions.toBuilder().setClientContext(builder.build()).build(); + } + if (!sessionReference.getIsMultiplexed()) { + requestOptions = requestOptions.toBuilder().clearTransactionTag().build(); + } + if (!requestOptions.equals(RequestOptions.getDefaultInstance())) { + requestBuilder.setRequestOptions(requestOptions); } final BeginTransactionRequest request = requestBuilder.build(); final ApiFuture requestFuture; diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java index 2e01e3d4ca3..bd8e4a9fec3 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java @@ -67,6 +67,7 @@ import com.google.spanner.v1.DirectedReadOptions; import com.google.spanner.v1.ExecuteSqlRequest; import com.google.spanner.v1.ExecuteSqlRequest.QueryOptions; +import com.google.spanner.v1.RequestOptions; import com.google.spanner.v1.SpannerGrpc; import com.google.spanner.v1.TransactionOptions; import com.google.spanner.v1.TransactionOptions.IsolationLevel; @@ -257,6 +258,7 @@ public static GcpChannelPoolOptions createDefaultDynamicChannelPoolOptions() { private final boolean enableEndToEndTracing; private final String monitoringHost; private final TransactionOptions defaultTransactionOptions; + private final RequestOptions.ClientContext clientContext; enum TracingFramework { OPEN_CENSUS, @@ -922,6 +924,7 @@ protected SpannerOptions(Builder builder) { enableEndToEndTracing = builder.enableEndToEndTracing; monitoringHost = builder.monitoringHost; defaultTransactionOptions = builder.defaultTransactionOptions; + clientContext = builder.clientContext; } private String getResolvedUniverseDomain() { @@ -929,6 +932,11 @@ private String getResolvedUniverseDomain() { return Strings.isNullOrEmpty(universeDomain) ? GOOGLE_DEFAULT_UNIVERSE : universeDomain; } + /** Returns the default {@link RequestOptions.ClientContext} for this {@link SpannerOptions}. */ + public RequestOptions.ClientContext getClientContext() { + return clientContext; + } + /** * The environment to read configuration values from. The default implementation uses environment * variables. @@ -1142,6 +1150,7 @@ public static class Builder private String experimentalHost = null; private boolean usePlainText = false; private TransactionOptions defaultTransactionOptions = TransactionOptions.getDefaultInstance(); + private RequestOptions.ClientContext clientContext; private static String createCustomClientLibToken(String token) { return token + " " + ServiceOptions.getGoogApiClientLibName(); @@ -1243,6 +1252,7 @@ protected Builder() { this.enableEndToEndTracing = options.enableEndToEndTracing; this.monitoringHost = options.monitoringHost; this.defaultTransactionOptions = options.defaultTransactionOptions; + this.clientContext = options.clientContext; } @Override @@ -1977,6 +1987,12 @@ public Builder setDefaultTransactionOptions( return this; } + /** Sets the default {@link RequestOptions.ClientContext} for all requests. */ + public Builder setClientContext(RequestOptions.ClientContext clientContext) { + this.clientContext = clientContext; + return this; + } + @SuppressWarnings("rawtypes") @Override public SpannerOptions build() { diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java index 7afccce194c..d28566cef89 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java @@ -464,15 +464,9 @@ public void run() { waitForTransactionTimeoutMillis, TimeUnit.MILLISECONDS) : transactionId); } - if (options.hasPriority() || getTransactionTag() != null) { - RequestOptions.Builder requestOptionsBuilder = RequestOptions.newBuilder(); - if (options.hasPriority()) { - requestOptionsBuilder.setPriority(options.priority()); - } - if (getTransactionTag() != null) { - requestOptionsBuilder.setTransactionTag(getTransactionTag()); - } - requestBuilder.setRequestOptions(requestOptionsBuilder.build()); + RequestOptions requestOptions = options.toRequestOptionsProto(true); + if (!requestOptions.equals(RequestOptions.getDefaultInstance())) { + requestBuilder.setRequestOptions(requestOptions); } if (session.getIsMultiplexed() && getLatestPrecommitToken() != null) { // Set the precommit token in the CommitRequest for multiplexed sessions. diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AbstractReadContextTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AbstractReadContextTest.java index eea6658d26d..b4bc7bf7bb6 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AbstractReadContextTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AbstractReadContextTest.java @@ -138,6 +138,10 @@ String getTransactionTag() { public void setup() { SessionImpl session = mock(SessionImpl.class); when(session.getName()).thenReturn("session-1"); + SpannerImpl spanner = mock(SpannerImpl.class); + SpannerOptions spannerOptions = mock(SpannerOptions.class); + when(spanner.getOptions()).thenReturn(spannerOptions); + when(session.getSpanner()).thenReturn(spanner); TestReadContextBuilder builder = new TestReadContextBuilder(); context = builder @@ -322,6 +326,10 @@ public void executeSqlRequestBuilderWithRequestOptions() { public void executeSqlRequestBuilderWithRequestOptionsWithTxnTag() { SessionImpl session = mock(SessionImpl.class); when(session.getName()).thenReturn("session-1"); + SpannerImpl spanner = mock(SpannerImpl.class); + SpannerOptions spannerOptions = mock(SpannerOptions.class); + when(spanner.getOptions()).thenReturn(spannerOptions); + when(session.getSpanner()).thenReturn(spanner); TestReadContextWithTagBuilder builder = new TestReadContextWithTagBuilder(); TestReadContextWithTag contextWithTag = builder @@ -345,6 +353,18 @@ public void executeSqlRequestBuilderWithRequestOptionsWithTxnTag() { assertThat(request.getRequestOptions().getTransactionTag()).isEqualTo("app=spanner,env=test"); } + @Test + public void testBuildRequestOptionsWithClientContext() { + RequestOptions.ClientContext clientContext = + RequestOptions.ClientContext.newBuilder() + .putSecureContext( + "key", com.google.protobuf.Value.newBuilder().setStringValue("value").build()) + .build(); + RequestOptions requestOptions = + context.buildRequestOptions(Options.fromQueryOptions(Options.clientContext(clientContext))); + assertEquals(clientContext, requestOptions.getClientContext()); + } + @Test public void testGetExecuteSqlRequestBuilderWithDirectedReadOptions() { ExecuteSqlRequest.Builder request = diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/OptionsTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/OptionsTest.java index 8571c42b3dd..3edf9a61e17 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/OptionsTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/OptionsTest.java @@ -34,17 +34,37 @@ import com.google.spanner.v1.DirectedReadOptions.ReplicaSelection; import com.google.spanner.v1.ReadRequest.LockHint; import com.google.spanner.v1.ReadRequest.OrderBy; -import com.google.spanner.v1.RequestOptions.Priority; -import com.google.spanner.v1.TransactionOptions.IsolationLevel; -import com.google.spanner.v1.TransactionOptions.ReadWrite; -import com.google.spanner.v1.TransactionOptions.ReadWrite.ReadLockMode; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import com.google.spanner.v1.RequestOptions; /** Unit tests for {@link Options}. */ @RunWith(JUnit4.class) public class OptionsTest { + @Test + public void testToRequestOptionsProto() { + RequestOptions.ClientContext clientContext = + RequestOptions.ClientContext.newBuilder() + .putSecureContext( + "key", com.google.protobuf.Value.newBuilder().setStringValue("value").build()) + .build(); + Options options = + Options.fromQueryOptions( + Options.priority(RpcPriority.HIGH), + Options.tag("tag"), + Options.clientContext(clientContext)); + + RequestOptions protoForStatement = options.toRequestOptionsProto(false); + assertEquals(RequestOptions.Priority.PRIORITY_HIGH, protoForStatement.getPriority()); + assertEquals("tag", protoForStatement.getRequestTag()); + assertEquals("", protoForStatement.getTransactionTag()); + assertEquals(clientContext, protoForStatement.getClientContext()); + + RequestOptions protoForTransaction = options.toRequestOptionsProto(true); + assertEquals(RequestOptions.Priority.PRIORITY_HIGH, protoForTransaction.getPriority()); + assertEquals("", protoForTransaction.getRequestTag()); + assertEquals("tag", protoForTransaction.getTransactionTag()); + assertEquals(clientContext, protoForTransaction.getClientContext()); + } + private static final DirectedReadOptions DIRECTED_READ_OPTIONS = DirectedReadOptions.newBuilder() .setIncludeReplicas( diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java index 1ac3b7beaf7..ea7b6d5306d 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java @@ -47,6 +47,7 @@ import com.google.spanner.v1.CommitResponse; import com.google.spanner.v1.Mutation.Write; import com.google.spanner.v1.PartialResultSet; +import com.google.spanner.v1.RequestOptions; import com.google.spanner.v1.ResultSetMetadata; import com.google.spanner.v1.RollbackRequest; import com.google.spanner.v1.Session; @@ -77,6 +78,42 @@ /** Unit tests for {@link com.google.cloud.spanner.SessionImpl}. */ @RunWith(JUnit4.class) public class SessionImplTest { + @Test + public void testBeginTransactionWithClientContext() { + RequestOptions.ClientContext clientContext = + RequestOptions.ClientContext.newBuilder() + .putSecureContext( + "key", com.google.protobuf.Value.newBuilder().setStringValue("value").build()) + .build(); + Mockito.when( + rpc.beginTransactionAsync( + Mockito.any(BeginTransactionRequest.class), anyMap(), eq(true))) + .thenReturn( + ApiFutures.immediateFuture( + Transaction.newBuilder().setId(ByteString.copyFromUtf8("tx")).build())); + + ((SessionImpl) session) + .beginTransactionAsync( + Options.fromTransactionOptions( + Options.priority(Options.RpcPriority.HIGH), + Options.tag("tag"), + Options.clientContext(clientContext)), + true, + Collections.emptyMap(), + null, + null); + + ArgumentCaptor requestCaptor = + ArgumentCaptor.forClass(BeginTransactionRequest.class); + Mockito.verify(rpc).beginTransactionAsync(requestCaptor.capture(), anyMap(), eq(true)); + BeginTransactionRequest request = requestCaptor.getValue(); + RequestOptions requestOptions = request.getRequestOptions(); + assertEquals(RequestOptions.Priority.PRIORITY_HIGH, requestOptions.getPriority()); + // TransactionTag should NOT be set because session is not multiplexed. + assertEquals("", requestOptions.getTransactionTag()); + assertEquals(clientContext, requestOptions.getClientContext()); + } + @Mock private SpannerRpc rpc; @Mock private SpannerOptions spannerOptions; private com.google.cloud.spanner.Session session; diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionContextImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionContextImplTest.java index bbfa7cbd0d5..49a47364a58 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionContextImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionContextImplTest.java @@ -70,6 +70,10 @@ public void setup() { when(rpc.getRequestIdCreator()).thenReturn(NoopRequestIdCreator.INSTANCE); when(session.getName()).thenReturn("test"); when(session.getRequestIdCreator()).thenReturn(NoopRequestIdCreator.INSTANCE); + SpannerImpl spanner = mock(SpannerImpl.class); + SpannerOptions spannerOptions = mock(SpannerOptions.class); + when(spanner.getOptions()).thenReturn(spannerOptions); + when(session.getSpanner()).thenReturn(spanner); doNothing().when(span).setStatus(any(Throwable.class)); doNothing().when(span).end(); doNothing().when(span).addAnnotation("Starting Commit"); @@ -214,6 +218,10 @@ private void batchDml(int status) { SessionImpl session = mock(SessionImpl.class); when(session.getName()).thenReturn("test"); when(session.getRequestIdCreator()).thenReturn(NoopRequestIdCreator.INSTANCE); + SpannerImpl spanner = mock(SpannerImpl.class); + SpannerOptions spannerOptions = mock(SpannerOptions.class); + when(spanner.getOptions()).thenReturn(spannerOptions); + when(session.getSpanner()).thenReturn(spanner); SpannerRpc rpc = mock(SpannerRpc.class); ExecuteBatchDmlResponse response = ExecuteBatchDmlResponse.newBuilder() diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java index 907f628329e..b99be489b61 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java @@ -53,6 +53,7 @@ import com.google.spanner.v1.ExecuteBatchDmlResponse; import com.google.spanner.v1.ExecuteSqlRequest; import com.google.spanner.v1.ExecuteSqlRequest.QueryOptions; +import com.google.spanner.v1.RequestOptions; import com.google.spanner.v1.ResultSet; import com.google.spanner.v1.ResultSetMetadata; import com.google.spanner.v1.ResultSetStats; @@ -79,6 +80,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.MockitoAnnotations; @@ -99,6 +101,37 @@ public void release(ScheduledExecutorService exec) { } } + @Test + public void testCommitWithClientContext() { + RequestOptions.ClientContext clientContext = + RequestOptions.ClientContext.newBuilder() + .putSecureContext( + "key", com.google.protobuf.Value.newBuilder().setStringValue("value").build()) + .build(); + Options options = + Options.fromTransactionOptions( + Options.priority(Options.RpcPriority.HIGH), + Options.tag("tag"), + Options.clientContext(clientContext)); + transactionRunner = new TransactionRunnerImpl(session, options); + when(session.getName()).thenReturn("projects/p/instances/i/databases/d/sessions/s"); + when(session.newTransaction(any(Options.class), any())).thenReturn(txn); + + transactionRunner.run( + transaction -> { + return null; + }); + + ArgumentCaptor commitRequestCaptor = + ArgumentCaptor.forClass(CommitRequest.class); + verify(rpc).commitAsync(commitRequestCaptor.capture(), anyMap()); + CommitRequest request = commitRequestCaptor.getValue(); + RequestOptions requestOptions = request.getRequestOptions(); + assertEquals(RequestOptions.Priority.PRIORITY_HIGH, requestOptions.getPriority()); + assertEquals("tag", requestOptions.getTransactionTag()); + assertEquals(clientContext, requestOptions.getClientContext()); + } + @Mock private SpannerRpc rpc; @Mock private SessionImpl session; @Mock private TransactionRunnerImpl.TransactionContextImpl txn; @@ -124,6 +157,10 @@ public void setUp() { when(session.getTracer()).thenReturn(tracer); when(session.getRequestIdCreator()).thenReturn(NoopRequestIdCreator.INSTANCE); when(rpc.getRequestIdCreator()).thenReturn(NoopRequestIdCreator.INSTANCE); + SpannerImpl spanner = mock(SpannerImpl.class); + SpannerOptions spannerOptions = mock(SpannerOptions.class); + when(spanner.getOptions()).thenReturn(spannerOptions); + when(session.getSpanner()).thenReturn(spanner); when(rpc.executeQuery(Mockito.any(ExecuteSqlRequest.class), Mockito.anyMap(), eq(true))) .thenAnswer( invocation -> { From 4556e819e442f72a2a37e06537a5107a30384695 Mon Sep 17 00:00:00 2001 From: Adam Seering Date: Wed, 14 Jan 2026 12:56:54 +0000 Subject: [PATCH 2/2] feat: Add ClientContext support to Connection API This change adds support for setting and propagating ClientContext in the Spanner Connection API. ClientContext allows propagating client-scoped session state (e.g., secure parameters) to Spanner RPCs. - Added setClientContext/getClientContext to Connection interface and implementation. - Implemented state propagation from Connection to UnitOfWork and its implementations (ReadWriteTransaction, SingleUseTransaction). - Fixed accidental import removal in OptionsTest.java. - Fixed TransactionRunnerImplTest to correctly verify ClientContext propagation. - Added ClientContextMockServerTest for end-to-end verification. --- .../connection/AbstractBaseUnitOfWork.java | 8 + .../cloud/spanner/connection/Connection.java | 19 + .../spanner/connection/ConnectionImpl.java | 31 ++ .../connection/ReadWriteTransaction.java | 6 + .../connection/SingleUseTransaction.java | 6 + .../com/google/cloud/spanner/OptionsTest.java | 7 + .../spanner/TransactionRunnerImplTest.java | 26 +- .../ClientContextMockServerTest.java | 330 ++++++++++++++++++ .../connection/ConnectionImplTest.java | 39 +++ 9 files changed, 459 insertions(+), 13 deletions(-) create mode 100644 google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ClientContextMockServerTest.java diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/AbstractBaseUnitOfWork.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/AbstractBaseUnitOfWork.java index 75a207043c2..1d71e062cbb 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/AbstractBaseUnitOfWork.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/AbstractBaseUnitOfWork.java @@ -80,6 +80,7 @@ abstract class AbstractBaseUnitOfWork implements UnitOfWork { protected final List transactionRetryListeners; protected final boolean excludeTxnFromChangeStreams; protected final RpcPriority rpcPriority; + protected final com.google.spanner.v1.RequestOptions.ClientContext clientContext; protected final Span span; /** Class for keeping track of the stacktrace of the caller of an async statement. */ @@ -117,6 +118,7 @@ abstract static class Builder, T extends AbstractBaseUni private boolean excludeTxnFromChangeStreams; private RpcPriority rpcPriority; + private com.google.spanner.v1.RequestOptions.ClientContext clientContext; private Span span; Builder() {} @@ -163,6 +165,11 @@ B setRpcPriority(@Nullable RpcPriority rpcPriority) { return self(); } + B setClientContext(@Nullable com.google.spanner.v1.RequestOptions.ClientContext clientContext) { + this.clientContext = clientContext; + return self(); + } + B setSpan(@Nullable Span span) { this.span = span; return self(); @@ -179,6 +186,7 @@ B setSpan(@Nullable Span span) { this.transactionRetryListeners = builder.transactionRetryListeners; this.excludeTxnFromChangeStreams = builder.excludeTxnFromChangeStreams; this.rpcPriority = builder.rpcPriority; + this.clientContext = builder.clientContext; this.span = Preconditions.checkNotNull(builder.span); } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/Connection.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/Connection.java index 533be8a047f..60d739a3c85 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/Connection.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/Connection.java @@ -449,6 +449,25 @@ default String getStatementTag() { throw new UnsupportedOperationException(); } + /** + * Sets the client context to use for the statements that are executed. The client context + * persists until it is changed or cleared. + * + * @param clientContext The client context to use with the statements that will be executed on + * this connection. + */ + default void setClientContext(com.google.spanner.v1.RequestOptions.ClientContext clientContext) { + throw new UnsupportedOperationException(); + } + + /** + * @return The client context that will be used with the statements that are executed on this + * connection. + */ + default com.google.spanner.v1.RequestOptions.ClientContext getClientContext() { + throw new UnsupportedOperationException(); + } + /** * Sets whether the next transaction should be excluded from all change streams with the DDL * option `allow_txn_exclusion=true` diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java index cfd63c89d49..cadd6375739 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java @@ -94,6 +94,7 @@ import com.google.common.util.concurrent.MoreExecutors; import com.google.spanner.v1.DirectedReadOptions; import com.google.spanner.v1.ExecuteSqlRequest.QueryOptions; +import com.google.spanner.v1.RequestOptions; import com.google.spanner.v1.ResultSetStats; import com.google.spanner.v1.TransactionOptions.IsolationLevel; import com.google.spanner.v1.TransactionOptions.ReadWrite.ReadLockMode; @@ -299,6 +300,7 @@ static UnitOfWorkType of(TransactionMode transactionMode) { private IsolationLevel transactionIsolationLevel; private String transactionTag; private String statementTag; + private RequestOptions.ClientContext clientContext; private boolean excludeTxnFromChangeStreams; private byte[] protoDescriptors; private String protoDescriptorsFilePath; @@ -536,6 +538,7 @@ private void reset(Context context, boolean inTransaction) { this.connectionState.resetValue(SAVEPOINT_SUPPORT, context, inTransaction); this.protoDescriptors = null; this.protoDescriptorsFilePath = null; + this.clientContext = null; if (!isTransactionStarted()) { setDefaultTransactionOptions(getDefaultIsolationLevel()); @@ -955,6 +958,18 @@ public String getTransactionTag() { return transactionTag; } + @Override + public void setClientContext(RequestOptions.ClientContext clientContext) { + ConnectionPreconditions.checkState(!isClosed(), CLOSED_ERROR_MSG); + this.clientContext = clientContext; + } + + @Override + public RequestOptions.ClientContext getClientContext() { + ConnectionPreconditions.checkState(!isClosed(), CLOSED_ERROR_MSG); + return clientContext; + } + @Override public void setTransactionTag(String tag) { ConnectionPreconditions.checkState(!isClosed(), CLOSED_ERROR_MSG); @@ -2026,6 +2041,9 @@ private QueryOption[] mergeQueryRequestOptions( options = appendQueryOption(options, Options.priority(getConnectionPropertyValue(RPC_PRIORITY))); } + if (clientContext != null) { + options = appendQueryOption(options, Options.clientContext(clientContext)); + } if (currentUnitOfWork != null && currentUnitOfWork.supportsDirectedReads(parsedStatement) && getConnectionPropertyValue(DIRECTED_READ) != null) { @@ -2070,6 +2088,14 @@ private UpdateOption[] mergeUpdateRequestOptions(UpdateOption... options) { options[options.length - 1] = Options.priority(getConnectionPropertyValue(RPC_PRIORITY)); } } + if (clientContext != null) { + if (options == null || options.length == 0) { + options = new UpdateOption[] {Options.clientContext(clientContext)}; + } else { + options = Arrays.copyOf(options, options.length + 1); + options[options.length - 1] = Options.clientContext(clientContext); + } + } return options; } @@ -2299,6 +2325,7 @@ UnitOfWork createNewUnitOfWork( createSpanForUnitOfWork( statementType == StatementType.DDL ? DDL_STATEMENT : SINGLE_USE_TRANSACTION)) .setProtoDescriptors(getProtoDescriptors()) + .setClientContext(clientContext) .build(); if (!isInternalMetadataQuery && !forceSingleUse) { // Reset the transaction options after starting a single-use transaction. @@ -2317,6 +2344,7 @@ UnitOfWork createNewUnitOfWork( .setTransactionTag(transactionTag) .setRpcPriority(getConnectionPropertyValue(RPC_PRIORITY)) .setSpan(createSpanForUnitOfWork(READ_ONLY_TRANSACTION)) + .setClientContext(clientContext) .build(); case READ_WRITE_TRANSACTION: return ReadWriteTransaction.newBuilder() @@ -2340,6 +2368,7 @@ UnitOfWork createNewUnitOfWork( .setExcludeTxnFromChangeStreams(excludeTxnFromChangeStreams) .setRpcPriority(getConnectionPropertyValue(RPC_PRIORITY)) .setSpan(createSpanForUnitOfWork(READ_WRITE_TRANSACTION)) + .setClientContext(clientContext) .build(); case DML_BATCH: // A DML batch can run inside the current transaction. It should therefore only @@ -2359,6 +2388,7 @@ UnitOfWork createNewUnitOfWork( .setRpcPriority(getConnectionPropertyValue(RPC_PRIORITY)) // Use the transaction Span for the DML batch. .setSpan(transactionStack.peek().getSpan()) + .setClientContext(clientContext) .build(); case DDL_BATCH: return DdlBatch.newBuilder() @@ -2369,6 +2399,7 @@ UnitOfWork createNewUnitOfWork( .setSpan(createSpanForUnitOfWork(DDL_BATCH)) .setProtoDescriptors(getProtoDescriptors()) .setConnectionState(connectionState) + .setClientContext(clientContext) .build(); default: } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java index c0e464ee5e6..ccb592e3f84 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java @@ -350,6 +350,9 @@ private TransactionOption[] extractOptions(Builder builder) { if (this.readLockMode != ReadLockMode.READ_LOCK_MODE_UNSPECIFIED) { numOptions++; } + if (this.clientContext != null) { + numOptions++; + } TransactionOption[] options = new TransactionOption[numOptions]; int index = 0; if (builder.returnCommitStats) { @@ -373,6 +376,9 @@ private TransactionOption[] extractOptions(Builder builder) { if (this.readLockMode != ReadLockMode.READ_LOCK_MODE_UNSPECIFIED) { options[index++] = Options.readLockMode(this.readLockMode); } + if (this.clientContext != null) { + options[index++] = Options.clientContext(this.clientContext); + } return options; } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SingleUseTransaction.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SingleUseTransaction.java index 370b579e6e2..cfb13cef966 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SingleUseTransaction.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SingleUseTransaction.java @@ -520,6 +520,9 @@ private TransactionRunner createWriteTransaction() { != ReadLockMode.READ_LOCK_MODE_UNSPECIFIED) { numOptions++; } + if (this.clientContext != null) { + numOptions++; + } if (numOptions == 0) { return dbClient.readWriteTransaction(); } @@ -547,6 +550,9 @@ private TransactionRunner createWriteTransaction() { != ReadLockMode.READ_LOCK_MODE_UNSPECIFIED) { options[index++] = Options.readLockMode(connectionState.getValue(READ_LOCK_MODE).getValue()); } + if (this.clientContext != null) { + options[index++] = Options.clientContext(this.clientContext); + } return dbClient.readWriteTransaction(options); } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/OptionsTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/OptionsTest.java index 3edf9a61e17..52cd2db7798 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/OptionsTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/OptionsTest.java @@ -35,6 +35,13 @@ import com.google.spanner.v1.ReadRequest.LockHint; import com.google.spanner.v1.ReadRequest.OrderBy; import com.google.spanner.v1.RequestOptions; +import com.google.spanner.v1.RequestOptions.Priority; +import com.google.spanner.v1.TransactionOptions.IsolationLevel; +import com.google.spanner.v1.TransactionOptions.ReadWrite; +import com.google.spanner.v1.TransactionOptions.ReadWrite.ReadLockMode; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; /** Unit tests for {@link Options}. */ @RunWith(JUnit4.class) diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java index b99be489b61..208225fcfb9 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java @@ -108,28 +108,28 @@ public void testCommitWithClientContext() { .putSecureContext( "key", com.google.protobuf.Value.newBuilder().setStringValue("value").build()) .build(); - Options options = - Options.fromTransactionOptions( + when(session.getName()).thenReturn("projects/p/instances/i/databases/d/sessions/s"); + when(session.newTransaction(any(Options.class), any())).thenReturn(txn); + Mockito.clearInvocations(session); + transactionRunner = + new TransactionRunnerImpl( + session, Options.priority(Options.RpcPriority.HIGH), Options.tag("tag"), Options.clientContext(clientContext)); - transactionRunner = new TransactionRunnerImpl(session, options); - when(session.getName()).thenReturn("projects/p/instances/i/databases/d/sessions/s"); - when(session.newTransaction(any(Options.class), any())).thenReturn(txn); + transactionRunner.setSpan(span); transactionRunner.run( transaction -> { return null; }); - ArgumentCaptor commitRequestCaptor = - ArgumentCaptor.forClass(CommitRequest.class); - verify(rpc).commitAsync(commitRequestCaptor.capture(), anyMap()); - CommitRequest request = commitRequestCaptor.getValue(); - RequestOptions requestOptions = request.getRequestOptions(); - assertEquals(RequestOptions.Priority.PRIORITY_HIGH, requestOptions.getPriority()); - assertEquals("tag", requestOptions.getTransactionTag()); - assertEquals(clientContext, requestOptions.getClientContext()); + ArgumentCaptor optionsCaptor = ArgumentCaptor.forClass(Options.class); + verify(session).newTransaction(optionsCaptor.capture(), any()); + Options capturedOptions = optionsCaptor.getValue(); + assertEquals(RequestOptions.Priority.PRIORITY_HIGH, capturedOptions.priority()); + assertEquals("tag", capturedOptions.tag()); + assertEquals(clientContext, capturedOptions.clientContext()); } @Mock private SpannerRpc rpc; diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ClientContextMockServerTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ClientContextMockServerTest.java new file mode 100644 index 00000000000..aecaf7136ea --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ClientContextMockServerTest.java @@ -0,0 +1,330 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner.connection; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +import com.google.cloud.spanner.DatabaseClient; +import com.google.cloud.spanner.DatabaseId; +import com.google.cloud.spanner.Dialect; +import com.google.cloud.spanner.MockSpannerServiceImpl; +import com.google.cloud.spanner.ResultSet; +import com.google.cloud.spanner.Spanner; +import com.google.cloud.spanner.SpannerOptions; +import com.google.protobuf.Value; +import com.google.spanner.v1.BeginTransactionRequest; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.ExecuteBatchDmlRequest; +import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.RequestOptions; +import java.util.Collections; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(Parameterized.class) +public class ClientContextMockServerTest extends AbstractMockServerTest { + + @Parameters(name = "dialect = {0}") + public static Object[] data() { + return Dialect.values(); + } + + @Parameter public Dialect dialect; + + private Dialect currentDialect; + + private static final RequestOptions.ClientContext CLIENT_CONTEXT = + RequestOptions.ClientContext.newBuilder() + .putSecureContext("test-key", Value.newBuilder().setStringValue("test-value").build()) + .build(); + + @Before + public void setupDialect() { + if (currentDialect != dialect) { + mockSpanner.putStatementResult( + MockSpannerServiceImpl.StatementResult.detectDialectResult(dialect)); + SpannerPool.closeSpannerPool(); + currentDialect = dialect; + } + } + + @After + public void clearRequests() { + mockSpanner.clearRequests(); + } + + @Test + public void testQuery_PropagatesClientContext() { + try (Connection connection = createConnection()) { + connection.setClientContext(CLIENT_CONTEXT); + try (ResultSet ignore = connection.executeQuery(SELECT_COUNT_STATEMENT)) {} + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + assertEquals( + CLIENT_CONTEXT, + mockSpanner + .getRequestsOfType(ExecuteSqlRequest.class) + .get(0) + .getRequestOptions() + .getClientContext()); + } + } + + @Test + public void testUpdate_PropagatesClientContext() { + try (Connection connection = createConnection()) { + connection.setClientContext(CLIENT_CONTEXT); + connection.executeUpdate(INSERT_STATEMENT); + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + assertEquals( + CLIENT_CONTEXT, + mockSpanner + .getRequestsOfType(ExecuteSqlRequest.class) + .get(0) + .getRequestOptions() + .getClientContext()); + } + } + + @Test + public void testBatchUpdate_PropagatesClientContext() { + try (Connection connection = createConnection()) { + connection.setClientContext(CLIENT_CONTEXT); + connection.executeBatchUpdate(Collections.singletonList(INSERT_STATEMENT)); + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class)); + assertEquals( + CLIENT_CONTEXT, + mockSpanner + .getRequestsOfType(ExecuteBatchDmlRequest.class) + .get(0) + .getRequestOptions() + .getClientContext()); + } + } + + @Test + public void testCommit_PropagatesClientContext() { + try (Connection connection = createConnection()) { + connection.setClientContext(CLIENT_CONTEXT); + connection.executeUpdate(INSERT_STATEMENT); + connection.commit(); + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + assertEquals( + CLIENT_CONTEXT, + mockSpanner + .getRequestsOfType(CommitRequest.class) + .get(0) + .getRequestOptions() + .getClientContext()); + } + } + + @Test + public void testBeginTransaction_PropagatesClientContext() { + // 1. Test lazy transaction start (default). + // The BeginTransaction option is inlined with the first statement. + try (Connection connection = createConnection()) { + connection.setClientContext(CLIENT_CONTEXT); + connection.beginTransaction(); + connection.executeUpdate(INSERT_STATEMENT); + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(0); + assertEquals(CLIENT_CONTEXT, request.getRequestOptions().getClientContext()); + assertEquals(0, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + } + + // 2. Test eager transaction start. + // We can force an explicit BeginTransaction RPC by failing the first statement with an ABORTED + // error. If the statement fails before returning a transaction ID, the retry will use an + // explicit BeginTransaction RPC. + // Note: This relies on triggering a retry logic which is the only way to force explicit + // BeginTransaction in the standard Connection API flow without additional configuration (like + // setting delayTransactionStartUntilFirstWrite=false which is not exposed publicly here). + try (Connection connection = createConnection()) { + // Abort the next statement. This will cause the ExecuteSql request (which carries the + // BeginTransaction option) to fail with an ABORTED error. + // Since the request fails, the client does not receive the transaction ID. + // The retry logic in TransactionRunnerImpl/ReadWriteTransaction will then force an + // explicit BeginTransaction RPC to ensure a transaction is started before retrying the + // statement. + mockSpanner.abortNextStatement(); + + connection.setClientContext(CLIENT_CONTEXT); + connection.beginTransaction(); + connection.executeUpdate(INSERT_STATEMENT); + + // We expect multiple ExecuteSqlRequests. + // 1. The first one fails with ABORTED. This request includes the BeginTransaction option. + // 2. The retry. + // Note: precise count depends on Gax retry logic vs Spanner retry logic interaction. + int executeSqlCount = mockSpanner.countRequestsOfType(ExecuteSqlRequest.class); + assertFalse(executeSqlCount < 2); + + for (ExecuteSqlRequest req : mockSpanner.getRequestsOfType(ExecuteSqlRequest.class)) { + assertEquals(CLIENT_CONTEXT, req.getRequestOptions().getClientContext()); + } + + // We also expect 1 BeginTransactionRequest because the retry used explicit BeginTransaction. + assertEquals(1, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + BeginTransactionRequest beginRequest = + mockSpanner.getRequestsOfType(BeginTransactionRequest.class).get(0); + assertEquals(CLIENT_CONTEXT, beginRequest.getRequestOptions().getClientContext()); + } + } + + @Test + public void testDatabaseClient_ClientContextMerging() { + String projectId = "test-project"; + String instanceId = "test-instance"; + String databaseId = "test-database"; + + // 1. Define the default ClientContext in SpannerOptions. + RequestOptions.ClientContext defaultContext = + RequestOptions.ClientContext.newBuilder() + .putSecureContext("key1", Value.newBuilder().setStringValue("default_value1").build()) + .putSecureContext("key2", Value.newBuilder().setStringValue("default_value2").build()) + .build(); + + SpannerOptions options = + SpannerOptions.newBuilder() + .setProjectId(projectId) + .setHost("http://localhost:" + getPort()) + .usePlainText() + .setClientContext(defaultContext) + .build(); + + try (Spanner spanner = options.getService()) { + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of(projectId, instanceId, databaseId)); + + // 2. Define the request-specific ClientContext that overrides one key and adds a new one. + RequestOptions.ClientContext requestContext = + RequestOptions.ClientContext.newBuilder() + .putSecureContext("key2", Value.newBuilder().setStringValue("request_value2").build()) + .putSecureContext("key3", Value.newBuilder().setStringValue("request_value3").build()) + .build(); + + // 3. Define the expected merged ClientContext (Union + Overwrite). + RequestOptions.ClientContext expectedContext = + RequestOptions.ClientContext.newBuilder() + .putSecureContext("key1", Value.newBuilder().setStringValue("default_value1").build()) + .putSecureContext("key2", Value.newBuilder().setStringValue("request_value2").build()) + .putSecureContext("key3", Value.newBuilder().setStringValue("request_value3").build()) + .build(); + + // Execute a query with the request context. + try (ResultSet rs = + client + .singleUse() + .executeQuery( + SELECT_COUNT_STATEMENT, + com.google.cloud.spanner.Options.clientContext(requestContext))) { + rs.next(); + } + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + RequestOptions.ClientContext actualContext = + mockSpanner + .getRequestsOfType(ExecuteSqlRequest.class) + .get(0) + .getRequestOptions() + .getClientContext(); + + assertEquals(expectedContext, actualContext); + + // Verify specifically that key2 was overwritten and key1 was preserved. + assertEquals( + "request_value2", actualContext.getSecureContextOrThrow("key2").getStringValue()); + assertEquals( + "default_value1", actualContext.getSecureContextOrThrow("key1").getStringValue()); + } + } + + @Test + public void testPersistence() { + try (Connection connection = createConnection()) { + connection.setClientContext(CLIENT_CONTEXT); + try (ResultSet ignore = connection.executeQuery(SELECT_COUNT_STATEMENT)) {} + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + assertEquals( + CLIENT_CONTEXT, + mockSpanner + .getRequestsOfType(ExecuteSqlRequest.class) + .get(0) + .getRequestOptions() + .getClientContext()); + + connection.executeUpdate(INSERT_STATEMENT); + assertEquals(2, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + assertEquals( + CLIENT_CONTEXT, + mockSpanner + .getRequestsOfType(ExecuteSqlRequest.class) + .get(1) + .getRequestOptions() + .getClientContext()); + + connection.commit(); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + assertEquals( + CLIENT_CONTEXT, + mockSpanner + .getRequestsOfType(CommitRequest.class) + .get(0) + .getRequestOptions() + .getClientContext()); + } + } + + @Test + public void testClearClientContext() { + try (Connection connection = createConnection()) { + connection.setClientContext(CLIENT_CONTEXT); + try (ResultSet ignore = connection.executeQuery(SELECT_COUNT_STATEMENT)) {} + + assertEquals( + CLIENT_CONTEXT, + mockSpanner + .getRequestsOfType(ExecuteSqlRequest.class) + .get(0) + .getRequestOptions() + .getClientContext()); + + connection.setClientContext(null); + try (ResultSet ignore = connection.executeQuery(SELECT_COUNT_STATEMENT)) {} + + assertEquals(2, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + assertFalse( + mockSpanner + .getRequestsOfType(ExecuteSqlRequest.class) + .get(1) + .getRequestOptions() + .hasClientContext()); + } + } +} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ConnectionImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ConnectionImplTest.java index ead2fd0f655..c1a5e2873de 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ConnectionImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ConnectionImplTest.java @@ -1948,6 +1948,45 @@ private void assertThrowResultNotAllowed( "Only statements that return a result of one of the following types are allowed")); } + @Test + public void testSetAndGetClientContext() { + try (Connection connection = + createConnection( + ConnectionOptions.newBuilder() + .setUri(URI) + .setCredentials(NoCredentials.getInstance()) + .build())) { + com.google.spanner.v1.RequestOptions.ClientContext context = + com.google.spanner.v1.RequestOptions.ClientContext.newBuilder() + .putSecureContext( + "key", com.google.protobuf.Value.newBuilder().setStringValue("test").build()) + .build(); + connection.setClientContext(context); + assertEquals(context, connection.getClientContext()); + } + } + + @Test + public void testResetClearsClientContext() { + try (Connection connection = + createConnection( + ConnectionOptions.newBuilder() + .setUri(URI) + .setCredentials(NoCredentials.getInstance()) + .build())) { + com.google.spanner.v1.RequestOptions.ClientContext context = + com.google.spanner.v1.RequestOptions.ClientContext.newBuilder() + .putSecureContext( + "key", com.google.protobuf.Value.newBuilder().setStringValue("test").build()) + .build(); + connection.setClientContext(context); + assertEquals(context, connection.getClientContext()); + + connection.reset(); + assertNull(connection.getClientContext()); + } + } + @Test public void testProtoDescriptorsAlwaysAllowed() { ConnectionOptions connectionOptions = mock(ConnectionOptions.class);