From c4e4496a2ae3cedb01034dd25c051833ef0a5d4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Fri, 10 Apr 2026 15:53:51 +0200 Subject: [PATCH] chore(spanner): track latency when using KeyAwareChannel Track latency for streaming SQL and streaming reads that return only one PartialResultSet. The LatencyTracker decides which RPCs are eligible for tracking. This allows us to include/exclude more RPCs in the future. --- .../cloud/spanner/spi/v1/ChannelEndpoint.java | 9 ++ .../spanner/spi/v1/EwmaLatencyTracker.java | 22 ++++- .../spi/v1/GrpcChannelEndpointCache.java | 20 +++- .../cloud/spanner/spi/v1/KeyAwareChannel.java | 21 ++++- .../cloud/spanner/spi/v1/LatencyTracker.java | 16 +++- .../spanner/spi/v1/KeyAwareChannelTest.java | 93 +++++++++++++++++++ 6 files changed, 168 insertions(+), 13 deletions(-) diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelEndpoint.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelEndpoint.java index fc82c530fc6f..6624303407a4 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelEndpoint.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelEndpoint.java @@ -71,4 +71,13 @@ public interface ChannelEndpoint { * @return the managed channel for this server */ ManagedChannel getChannel(); + + /** + * Returns the latency tracker for this endpoint, or null if not supported. 
+ * + * @return the latency tracker or null + */ + default LatencyTracker getLatencyTracker() { + return null; + } } diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTracker.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTracker.java index 0cb2331660f9..213ed500d76e 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTracker.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTracker.java @@ -19,6 +19,8 @@ import com.google.api.core.BetaApi; import com.google.api.core.InternalApi; import com.google.common.base.Preconditions; +import com.google.spanner.v1.PartialResultSet; +import io.grpc.MethodDescriptor; import java.time.Duration; import java.util.concurrent.TimeUnit; import javax.annotation.concurrent.GuardedBy; @@ -67,8 +69,7 @@ public double getScore() { } } - @Override - public void update(Duration latency) { + void update(Duration latency) { long latencyMicros; try { latencyMicros = TimeUnit.MICROSECONDS.convert(latency.toNanos(), TimeUnit.NANOSECONDS); @@ -92,4 +93,21 @@ public void recordError(Duration penalty) { // Treat the error as a sample with high latency (penalty) update(penalty); } + + @Override + public boolean isEligible(MethodDescriptor methodDescriptor) { + String methodName = methodDescriptor.getFullMethodName(); + return KeyAwareChannel.STREAMING_READ_METHOD.equals(methodName) + || KeyAwareChannel.STREAMING_SQL_METHOD.equals(methodName); + } + + @Override + public void maybeUpdate(Object message, Duration latency) { + if (message instanceof PartialResultSet) { + PartialResultSet response = (PartialResultSet) message; + if (response.getLast()) { + update(latency); + } + } + } } diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCache.java 
b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCache.java index 98e7f83b094f..c3a1bf0b8518 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCache.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCache.java @@ -67,7 +67,8 @@ public GrpcChannelEndpointCache(InstantiatingGrpcChannelProvider channelProvider throws IOException { this.baseProvider = channelProvider; String defaultEndpoint = channelProvider.getEndpoint(); - this.defaultEndpoint = new GrpcChannelEndpoint(defaultEndpoint, channelProvider); + this.defaultEndpoint = + new GrpcChannelEndpoint(defaultEndpoint, channelProvider, new EwmaLatencyTracker()); this.defaultAuthority = this.defaultEndpoint.getChannel().authority(); this.servers.put(defaultEndpoint, this.defaultEndpoint); } @@ -92,7 +93,8 @@ public ChannelEndpoint get(String address) { // This is thread-safe as withEndpoint() returns a new provider instance. InstantiatingGrpcChannelProvider newProvider = createProviderWithAuthorityOverride(addr); - GrpcChannelEndpoint endpoint = new GrpcChannelEndpoint(addr, newProvider); + GrpcChannelEndpoint endpoint = + new GrpcChannelEndpoint(addr, newProvider, new EwmaLatencyTracker()); logger.log(Level.FINE, "Location-aware endpoint created for address: {0}", addr); return endpoint; } catch (IOException e) { @@ -178,10 +180,10 @@ private void shutdownChannel(GrpcChannelEndpoint server, boolean awaitTerminatio } } - /** gRPC implementation of {@link ChannelEndpoint}. */ static class GrpcChannelEndpoint implements ChannelEndpoint { private final String address; private final ManagedChannel channel; + private final LatencyTracker latencyTracker; /** * Creates a server from a channel provider. 
@@ -190,7 +192,8 @@ static class GrpcChannelEndpoint implements ChannelEndpoint { * @param provider the channel provider (must be a gRPC provider) * @throws IOException if the channel cannot be created */ - GrpcChannelEndpoint(String address, InstantiatingGrpcChannelProvider provider) + GrpcChannelEndpoint( + String address, InstantiatingGrpcChannelProvider provider, LatencyTracker latencyTracker) throws IOException { this.address = address; // Build a raw ManagedChannel directly instead of going through getTransportChannel(), @@ -203,6 +206,7 @@ static class GrpcChannelEndpoint implements ChannelEndpoint { provider.withHeaders(java.util.Collections.emptyMap()); } this.channel = readyProvider.createDecoratedChannelBuilder().build(); + this.latencyTracker = latencyTracker; } /** @@ -212,9 +216,10 @@ static class GrpcChannelEndpoint implements ChannelEndpoint { * @param channel the managed channel */ @VisibleForTesting - GrpcChannelEndpoint(String address, ManagedChannel channel) { + GrpcChannelEndpoint(String address, ManagedChannel channel, LatencyTracker latencyTracker) { this.address = address; this.channel = channel; + this.latencyTracker = latencyTracker; } @Override @@ -267,5 +272,10 @@ public boolean isTransientFailure() { public ManagedChannel getChannel() { return channel; } + + @Override + public LatencyTracker getLatencyTracker() { + return latencyTracker; + } } } diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java index d7b32f72bcd6..7201bf477a28 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java @@ -46,6 +46,7 @@ import java.io.IOException; import java.lang.ref.ReferenceQueue; import java.lang.ref.SoftReference; +import 
java.time.Duration; import java.util.HashSet; import java.util.Map; import java.util.Set; @@ -72,9 +73,8 @@ final class KeyAwareChannel extends ManagedChannel { private static final long MAX_TRACKED_READ_ONLY_TRANSACTIONS = 100_000L; private static final long MAX_TRACKED_EXCLUDED_LOGICAL_REQUESTS = 100_000L; private static final long EXCLUDED_LOGICAL_REQUEST_TTL_MINUTES = 10L; - private static final String STREAMING_READ_METHOD = "google.spanner.v1.Spanner/StreamingRead"; - private static final String STREAMING_SQL_METHOD = - "google.spanner.v1.Spanner/ExecuteStreamingSql"; + static final String STREAMING_READ_METHOD = "google.spanner.v1.Spanner/StreamingRead"; + static final String STREAMING_SQL_METHOD = "google.spanner.v1.Spanner/ExecuteStreamingSql"; private static final String UNARY_SQL_METHOD = "google.spanner.v1.Spanner/ExecuteSql"; private static final String BEGIN_TRANSACTION_METHOD = "google.spanner.v1.Spanner/BeginTransaction"; @@ -462,6 +462,7 @@ static final class KeyAwareClientCall private boolean isReadOnlyBegin; private boolean readOnlyIsStrong; private final Object lock = new Object(); + volatile long startTimeNanos; KeyAwareClientCall( KeyAwareChannel parentChannel, @@ -610,6 +611,7 @@ public void sendMessage(RequestT message) { } delegate.start(responseListener, headers); drainPendingRequests(); + startTimeNanos = System.nanoTime(); delegate.sendMessage(message); if (pendingHalfClose) { delegate.halfClose(); @@ -810,6 +812,7 @@ private RoutingDecision(@Nullable ChannelFinder finder, @Nullable ChannelEndpoin static final class KeyAwareClientCallListener extends SimpleForwardingClientCallListener { private final KeyAwareClientCall call; + private boolean firstMessageReceived = false; KeyAwareClientCallListener( ClientCall.Listener responseListener, KeyAwareClientCall call) { @@ -819,6 +822,18 @@ static final class KeyAwareClientCallListener @Override public void onMessage(ResponseT message) { + if (!firstMessageReceived) { + firstMessageReceived = 
true; + // call.selectedEndpoint will in real usage never be null when we reach this + // point. + if (call.selectedEndpoint != null) { + LatencyTracker tracker = call.selectedEndpoint.getLatencyTracker(); + if (tracker != null && tracker.isEligible(call.methodDescriptor)) { + Duration latency = Duration.ofNanos(System.nanoTime() - call.startTimeNanos); + tracker.maybeUpdate(message, latency); + } + } + } ByteString transactionId = null; if (message instanceof PartialResultSet) { PartialResultSet response = (PartialResultSet) message; diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/LatencyTracker.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/LatencyTracker.java index d7467853492d..c70bcc144eec 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/LatencyTracker.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/LatencyTracker.java @@ -18,6 +18,7 @@ import com.google.api.core.BetaApi; import com.google.api.core.InternalApi; +import io.grpc.MethodDescriptor; import java.time.Duration; /** @@ -38,11 +39,12 @@ public interface LatencyTracker { double getScore(); /** - * Updates the latency score with a new observation. + * Potentially updates the latency score based on the response message. * - * @param latency the observed latency. + * @param message the response message. + * @param latency the measured latency. */ - void update(Duration latency); + void maybeUpdate(Object message, Duration latency); /** * Records an error and applies a latency penalty. @@ -50,4 +52,12 @@ public interface LatencyTracker { * @param penalty the penalty to apply. */ void recordError(Duration penalty); + + /** + * Returns whether a call with the given method descriptor is eligible for latency measurement. + * + * @param methodDescriptor the method descriptor of the call. + * @return true if eligible, false otherwise. 
+ */ + boolean isEligible(MethodDescriptor methodDescriptor); } diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java index 1ad3888b4f9d..cad6b70b5673 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java @@ -286,6 +286,93 @@ public void resultSetCacheUpdateRoutesSubsequentRequest() throws Exception { assertThat(harness.endpointCache.callCountForAddress("routed:1234")).isEqualTo(1); } + @Test + public void callTracksLatencyOnMessage() throws Exception { + TestHarness harness = createHarness(); + ExecuteSqlRequest request = ExecuteSqlRequest.newBuilder().setSession(SESSION).build(); + + ClientCall call = + harness.channel.newCall(SpannerGrpc.getExecuteStreamingSqlMethod(), CallOptions.DEFAULT); + CapturingListener listener = new CapturingListener<>(); + call.start(listener, new Metadata()); + call.sendMessage(request); + + @SuppressWarnings("unchecked") + RecordingClientCall delegate = + (RecordingClientCall) + harness.defaultManagedChannel.latestCall(); + + FakeEndpoint defaultEndpoint = harness.endpointCache.defaultEndpoint; + LatencyTracker tracker = defaultEndpoint.getLatencyTracker(); + + double initialScore = tracker.getScore(); + + // Emit a message with last=true to trigger onMessage and latency update. + delegate.emitOnMessage(PartialResultSet.newBuilder().setLast(true).build()); + + // Verify that the score has been updated (it should not be equal to the initial score). 
+ double newScore = tracker.getScore(); + assertThat(newScore).isNotEqualTo(initialScore); + } + + @Test + public void callDoesNotTrackLatencyForNonEligibleRpc() throws Exception { + TestHarness harness = createHarness(); + ExecuteSqlRequest request = ExecuteSqlRequest.newBuilder().setSession(SESSION).build(); + + ClientCall call = + harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), CallOptions.DEFAULT); + CapturingListener listener = new CapturingListener<>(); + call.start(listener, new Metadata()); + call.sendMessage(request); + + @SuppressWarnings("unchecked") + RecordingClientCall delegate = + (RecordingClientCall) + harness.defaultManagedChannel.latestCall(); + + FakeEndpoint defaultEndpoint = harness.endpointCache.defaultEndpoint; + LatencyTracker tracker = defaultEndpoint.getLatencyTracker(); + + double initialScore = tracker.getScore(); + + // Emit a message. + delegate.emitOnMessage(ResultSet.newBuilder().build()); + + // Verify that the score has not been updated. + double newScore = tracker.getScore(); + assertThat(newScore).isEqualTo(initialScore); + } + + @Test + public void callDoesNotTrackLatencyForNonLastPartialResultSet() throws Exception { + TestHarness harness = createHarness(); + ExecuteSqlRequest request = ExecuteSqlRequest.newBuilder().setSession(SESSION).build(); + + ClientCall call = + harness.channel.newCall(SpannerGrpc.getExecuteStreamingSqlMethod(), CallOptions.DEFAULT); + CapturingListener listener = new CapturingListener<>(); + call.start(listener, new Metadata()); + call.sendMessage(request); + + @SuppressWarnings("unchecked") + RecordingClientCall delegate = + (RecordingClientCall) + harness.defaultManagedChannel.latestCall(); + + FakeEndpoint defaultEndpoint = harness.endpointCache.defaultEndpoint; + LatencyTracker tracker = defaultEndpoint.getLatencyTracker(); + + double initialScore = tracker.getScore(); + + // Emit a message with last=false. 
+ delegate.emitOnMessage(PartialResultSet.newBuilder().setLast(false).build()); + + // Verify that the score has not been updated. + double newScore = tracker.getScore(); + assertThat(newScore).isEqualTo(initialScore); + } + @Test public void beginTransactionWithMutationKeyAddsRoutingHint() throws Exception { TestHarness harness = createHarness(); @@ -1350,12 +1437,18 @@ int callCountForAddress(String address) { private static final class FakeEndpoint implements ChannelEndpoint { private final String address; private final FakeManagedChannel channel; + private final LatencyTracker latencyTracker = new EwmaLatencyTracker(); private FakeEndpoint(String address) { this.address = address; this.channel = new FakeManagedChannel(address); } + @Override + public LatencyTracker getLatencyTracker() { + return latencyTracker; + } + @Override public String getAddress() { return address;