From 25c51e46a3482f5f273cd299890c98263544f8be Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Fri, 23 Jan 2026 12:46:50 -0800 Subject: [PATCH 1/2] core: Add pick_first weighted shuffle The prior uniform shuffle in pick_first will send uniform load across clients. When endpoints have weights, we'd desire for endpoints to be selected proportionally to their weight. The server weight attribute has to move out of xDS to be seen by pick-first, but it is kept as internal for now. Since xDS is the only thing that sets weights, the behavior change is only visible to xDS. See gRFC A113 --- .../java/io/grpc/EquivalentAddressGroup.java | 9 + .../grpc/InternalEquivalentAddressGroup.java | 29 +++ .../internal/PickFirstLeafLoadBalancer.java | 50 ++++- .../grpc/internal/PickFirstLoadBalancer.java | 7 +- .../PickFirstLeafLoadBalancerTest.java | 146 +++++++++++- .../internal/PickFirstLoadBalancerTest.java | 207 ++++++++++++++++++ .../main/java/io/grpc/xds/XdsAttributes.java | 4 +- 7 files changed, 435 insertions(+), 17 deletions(-) create mode 100644 api/src/main/java/io/grpc/InternalEquivalentAddressGroup.java diff --git a/api/src/main/java/io/grpc/EquivalentAddressGroup.java b/api/src/main/java/io/grpc/EquivalentAddressGroup.java index bf8a864902c..d40e47c16f6 100644 --- a/api/src/main/java/io/grpc/EquivalentAddressGroup.java +++ b/api/src/main/java/io/grpc/EquivalentAddressGroup.java @@ -55,6 +55,15 @@ public final class EquivalentAddressGroup { */ public static final Attributes.Key ATTR_LOCALITY_NAME = Attributes.Key.create("io.grpc.EquivalentAddressGroup.LOCALITY"); + /** + * Endpoint weight for load balancing purposes. While the type is Long, it must be a valid uint32. + * Must not be zero. The weight is proportional to the other endpoints; if an endpoint's weight is + * twice that of another endpoint, it is intended to receive twice the load. 
+ */ + @Attr + static final Attributes.Key ATTR_WEIGHT = + Attributes.Key.create("io.grpc.EquivalentAddressGroup.ATTR_WEIGHT"); + private final List addrs; private final Attributes attrs; diff --git a/api/src/main/java/io/grpc/InternalEquivalentAddressGroup.java b/api/src/main/java/io/grpc/InternalEquivalentAddressGroup.java new file mode 100644 index 00000000000..d4bed4d81bc --- /dev/null +++ b/api/src/main/java/io/grpc/InternalEquivalentAddressGroup.java @@ -0,0 +1,29 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +@Internal +public final class InternalEquivalentAddressGroup { + private InternalEquivalentAddressGroup() {} + + /** + * Endpoint weight for load balancing purposes. While the type is Long, it must be a valid uint32. + * Must not be zero. The weight is proportional to the other endpoints; if an endpoint's weight is + * twice that of another endpoint, it is intended to receive twice the load. 
+ */ + public static final Attributes.Key ATTR_WEIGHT = EquivalentAddressGroup.ATTR_WEIGHT; +} diff --git a/core/src/main/java/io/grpc/internal/PickFirstLeafLoadBalancer.java b/core/src/main/java/io/grpc/internal/PickFirstLeafLoadBalancer.java index 935214a94fd..cd9ff9ab58c 100644 --- a/core/src/main/java/io/grpc/internal/PickFirstLeafLoadBalancer.java +++ b/core/src/main/java/io/grpc/internal/PickFirstLeafLoadBalancer.java @@ -26,10 +26,12 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.Attributes; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; +import io.grpc.InternalEquivalentAddressGroup; import io.grpc.LoadBalancer; import io.grpc.Status; import io.grpc.SynchronizationContext.ScheduledHandle; @@ -61,6 +63,8 @@ final class PickFirstLeafLoadBalancer extends LoadBalancer { static final int CONNECTION_DELAY_INTERVAL_MS = 250; private final boolean enableHappyEyeballs = !isSerializingRetries() && PickFirstLoadBalancerProvider.isEnabledHappyEyeballs(); + static boolean weightedShuffling = + GrpcUtil.getFlag("GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING", true); private final Helper helper; private final Map subchannels = new HashMap<>(); private final Index addressIndex = new Index(ImmutableList.of(), this.enableHappyEyeballs); @@ -128,13 +132,13 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { PickFirstLeafLoadBalancerConfig config = (PickFirstLeafLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); if (config.shuffleAddressList != null && config.shuffleAddressList) { - Collections.shuffle(cleanServers, - config.randomSeed != null ? new Random(config.randomSeed) : new Random()); + cleanServers = shuffle( + cleanServers, config.randomSeed != null ? 
new Random(config.randomSeed) : new Random()); } } final ImmutableList newImmutableAddressGroups = - ImmutableList.builder().addAll(cleanServers).build(); + ImmutableList.copyOf(cleanServers); if (rawConnectivityState == READY || (rawConnectivityState == CONNECTING @@ -224,6 +228,46 @@ private static List deDupAddresses(List shuffle(List eags, Random random) { + if (weightedShuffling) { + List weightedEntries = new ArrayList<>(eags.size()); + for (EquivalentAddressGroup eag : eags) { + weightedEntries.add(new WeightEntry(eag, eagToWeight(eag, random))); + } + Collections.sort(weightedEntries, Collections.reverseOrder() /* descending */); + return Lists.transform(weightedEntries, entry -> entry.eag); + } else { + List eagsCopy = new ArrayList<>(eags); + Collections.shuffle(eagsCopy, random); + return eagsCopy; + } + } + + private static double eagToWeight(EquivalentAddressGroup eag, Random random) { + Long weight = eag.getAttributes().get(InternalEquivalentAddressGroup.ATTR_WEIGHT); + if (weight == null) { + weight = 1L; + } + return Math.pow(random.nextDouble(), 1.0 / weight); + } + + private static final class WeightEntry implements Comparable { + final EquivalentAddressGroup eag; + final double weight; + + public WeightEntry(EquivalentAddressGroup eag, double weight) { + this.eag = eag; + this.weight = weight; + } + + @Override + public int compareTo(WeightEntry entry) { + return Double.compare(this.weight, entry.weight); + } + } + @Override public void handleNameResolutionError(Status error) { if (rawConnectivityState == SHUTDOWN) { diff --git a/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java b/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java index aa8b5a7e9a9..cf4b4c94e04 100644 --- a/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java +++ b/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java @@ -27,8 +27,6 @@ import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; import io.grpc.Status; -import 
java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Random; import java.util.concurrent.atomic.AtomicBoolean; @@ -65,9 +63,8 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { PickFirstLoadBalancerConfig config = (PickFirstLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); if (config.shuffleAddressList != null && config.shuffleAddressList) { - servers = new ArrayList(servers); - Collections.shuffle(servers, - config.randomSeed != null ? new Random(config.randomSeed) : new Random()); + servers = PickFirstLeafLoadBalancer.shuffle( + servers, config.randomSeed != null ? new Random(config.randomSeed) : new Random()); } } diff --git a/core/src/test/java/io/grpc/internal/PickFirstLeafLoadBalancerTest.java b/core/src/test/java/io/grpc/internal/PickFirstLeafLoadBalancerTest.java index 8b09fce2aa2..cb73d17d682 100644 --- a/core/src/test/java/io/grpc/internal/PickFirstLeafLoadBalancerTest.java +++ b/core/src/test/java/io/grpc/internal/PickFirstLeafLoadBalancerTest.java @@ -23,6 +23,7 @@ import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import static io.grpc.InternalEquivalentAddressGroup.ATTR_WEIGHT; import static io.grpc.LoadBalancer.HAS_HEALTH_PRODUCER_LISTENER_KEY; import static io.grpc.LoadBalancer.HEALTH_CONSUMER_LISTENER_ARG_KEY; import static io.grpc.LoadBalancer.IS_PETIOLE_POLICY; @@ -70,10 +71,13 @@ import io.grpc.internal.PickFirstLeafLoadBalancer.PickFirstLeafLoadBalancerConfig; import java.net.InetSocketAddress; import java.net.SocketAddress; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Queue; +import java.util.Random; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import org.junit.After; @@ -149,6 +153,7 @@ public void 
uncaughtException(Thread t, Throwable e) { private String originalHappyEyeballsEnabledValue; private String originalSerializeRetriesValue; + private boolean originalWeightedShuffling; private long backoffMillis; @@ -165,6 +170,8 @@ public void setUp() { System.setProperty(PickFirstLoadBalancerProvider.GRPC_PF_USE_HAPPY_EYEBALLS, Boolean.toString(enableHappyEyeballs)); + originalWeightedShuffling = PickFirstLeafLoadBalancer.weightedShuffling; + for (int i = 1; i <= 5; i++) { SocketAddress addr = new FakeSocketAddress("server" + i); servers.add(new EquivalentAddressGroup(addr)); @@ -207,6 +214,7 @@ public void tearDown() { System.setProperty(PickFirstLoadBalancerProvider.GRPC_PF_USE_HAPPY_EYEBALLS, originalHappyEyeballsEnabledValue); } + PickFirstLeafLoadBalancer.weightedShuffling = originalWeightedShuffling; loadBalancer.shutdown(); verifyNoMoreInteractions(mockArgs); @@ -242,6 +250,12 @@ public void pickAfterResolved() { verifyNoMoreInteractions(mockHelper); } + @Test + public void pickAfterResolved_shuffle_oppositeWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling; + pickAfterResolved_shuffle(); + } + @Test public void pickAfterResolved_shuffle() { servers.remove(4); @@ -305,6 +319,103 @@ public void pickAfterResolved_noShuffle() { assertNotNull(pickerCaptor.getValue().pickSubchannel(mockArgs)); } + @Test + public void pickAfterResolved_shuffleImplicitUniform_oppositeWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling; + pickAfterResolved_shuffleImplicitUniform(); + } + + @Test + public void pickAfterResolved_shuffleImplicitUniform() { + EquivalentAddressGroup eag1 = new EquivalentAddressGroup(new FakeSocketAddress("server1")); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup(new FakeSocketAddress("server2")); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup(new FakeSocketAddress("server3")); + + int[] counts = 
countAddressSelections(99, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(33); + assertThat(counts[1]).isWithin(7).of(33); + assertThat(counts[2]).isWithin(7).of(33); + } + + @Test + public void pickAfterResolved_shuffleExplicitUniform_oppositeWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling; + pickAfterResolved_shuffleExplicitUniform(); + } + + @Test + public void pickAfterResolved_shuffleExplicitUniform() { + EquivalentAddressGroup eag1 = new EquivalentAddressGroup( + new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build()); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup( + new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build()); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup( + new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build()); + + int[] counts = countAddressSelections(99, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(33); + assertThat(counts[1]).isWithin(7).of(33); + assertThat(counts[2]).isWithin(7).of(33); + } + + @Test + public void pickAfterResolved_shuffleWeighted_noWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = false; + EquivalentAddressGroup eag1 = new EquivalentAddressGroup( + new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 12L).build()); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup( + new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 3L).build()); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup( + new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 1L).build()); + + int[] counts = countAddressSelections(100, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(33); + assertThat(counts[1]).isWithin(7).of(33); + assertThat(counts[2]).isWithin(7).of(33); + } + + @Test + 
public void pickAfterResolved_shuffleWeighted_weightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = true; + EquivalentAddressGroup eag1 = new EquivalentAddressGroup( + new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 12L).build()); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup( + new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 3L).build()); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup( + new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 1L).build()); + + int[] counts = countAddressSelections(100, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(75); // 100*12/16 + assertThat(counts[1]).isWithin(7).of(19); // 100*3/16 + assertThat(counts[2]).isWithin(7).of(6); // 100*1/16 + } + + /** Returns int[index_of_eag] array with number of times each eag was selected. */ + private int[] countAddressSelections(int trials, List eags) { + int[] counts = new int[eags.size()]; + Random random = new Random(1); + for (int i = 0; i < trials; i++) { + RecordingHelper helper = new RecordingHelper(); + LoadBalancer lb = new PickFirstLeafLoadBalancer(helper); + assertThat(lb.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(eags) + .setAttributes(affinity) + .setLoadBalancingPolicyConfig( + new PickFirstLeafLoadBalancerConfig(true, random.nextLong())) + .build())) + .isSameInstanceAs(Status.OK); + helper.subchannels.remove().listener.onSubchannelState( + ConnectivityStateInfo.forNonError(READY)); + + assertThat(helper.state).isEqualTo(READY); + Subchannel subchannel = helper.picker.pickSubchannel(mockArgs).getSubchannel(); + counts[eags.indexOf(subchannel.getAddresses())]++; + + lb.shutdown(); + } + return counts; + } + @Test public void requestConnectionPicker() { // Set up @@ -2945,13 +3056,7 @@ public String toString() { } } - private class MockHelperImpl extends LoadBalancer.Helper { - private final List subchannels; - - 
public MockHelperImpl(List subchannels) { - this.subchannels = new ArrayList(subchannels); - } - + private class BaseHelper extends LoadBalancer.Helper { @Override public ManagedChannel createOobChannel(EquivalentAddressGroup eag, String authority) { return null; @@ -2981,6 +3086,14 @@ public ScheduledExecutorService getScheduledExecutorService() { public void refreshNameResolution() { // noop } + } + + private class MockHelperImpl extends BaseHelper { + private final List subchannels; + + public MockHelperImpl(List subchannels) { + this.subchannels = new ArrayList(subchannels); + } @Override public Subchannel createSubchannel(CreateSubchannelArgs args) { @@ -2997,4 +3110,23 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { throw new IllegalArgumentException("Unexpected addresses: " + args.getAddresses()); } } + + class RecordingHelper extends BaseHelper { + ConnectivityState state; + SubchannelPicker picker; + final Queue subchannels = new ArrayDeque<>(); + + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + this.state = newState; + this.picker = newPicker; + } + + @Override + public Subchannel createSubchannel(CreateSubchannelArgs args) { + FakeSubchannel subchannel = new FakeSubchannel(args.getAddresses(), args.getAttributes()); + subchannels.add(subchannel); + return subchannel; + } + } } diff --git a/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerTest.java b/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerTest.java index 819293e070b..1e130423a45 100644 --- a/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerTest.java +++ b/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerTest.java @@ -21,6 +21,7 @@ import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import static io.grpc.InternalEquivalentAddressGroup.ATTR_WEIGHT; import static 
org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; @@ -49,12 +50,18 @@ import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancer.SubchannelStateListener; +import io.grpc.ManagedChannel; import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.SynchronizationContext; import io.grpc.internal.PickFirstLoadBalancer.PickFirstLoadBalancerConfig; import java.net.SocketAddress; +import java.util.ArrayDeque; +import java.util.Arrays; +import java.util.Collections; import java.util.List; +import java.util.Queue; +import java.util.Random; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -103,8 +110,12 @@ public void uncaughtException(Thread t, Throwable e) { @Mock // This LoadBalancer doesn't use any of the arg fields, as verified in tearDown(). private PickSubchannelArgs mockArgs; + private boolean originalWeightedShuffling; + @Before public void setUp() { + originalWeightedShuffling = PickFirstLeafLoadBalancer.weightedShuffling; + for (int i = 0; i < 3; i++) { SocketAddress addr = new FakeSocketAddress("server" + i); servers.add(new EquivalentAddressGroup(addr)); @@ -120,6 +131,7 @@ public void setUp() { @After public void tearDown() throws Exception { + PickFirstLeafLoadBalancer.weightedShuffling = originalWeightedShuffling; verifyNoMoreInteractions(mockArgs); } @@ -141,6 +153,12 @@ public void pickAfterResolved() throws Exception { verifyNoMoreInteractions(mockHelper); } + @Test + public void pickAfterResolved_shuffle_oppositeWeightedShuffling() throws Exception { + PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling; + pickAfterResolved_shuffle(); + } + @Test public void pickAfterResolved_shuffle() throws Exception { loadBalancer.acceptResolvedAddresses( @@ -184,6 +202,103 @@ public void pickAfterResolved_noShuffle() throws Exception { verifyNoMoreInteractions(mockHelper); } + 
@Test + public void pickAfterResolved_shuffleImplicitUniform_oppositeWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling; + pickAfterResolved_shuffleImplicitUniform(); + } + + @Test + public void pickAfterResolved_shuffleImplicitUniform() { + EquivalentAddressGroup eag1 = new EquivalentAddressGroup(new FakeSocketAddress("server1")); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup(new FakeSocketAddress("server2")); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup(new FakeSocketAddress("server3")); + + int[] counts = countAddressSelections(99, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(33); + assertThat(counts[1]).isWithin(7).of(33); + assertThat(counts[2]).isWithin(7).of(33); + } + + @Test + public void pickAfterResolved_shuffleExplicitUniform_oppositeWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling; + pickAfterResolved_shuffleExplicitUniform(); + } + + @Test + public void pickAfterResolved_shuffleExplicitUniform() { + EquivalentAddressGroup eag1 = new EquivalentAddressGroup( + new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build()); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup( + new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build()); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup( + new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build()); + + int[] counts = countAddressSelections(99, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(33); + assertThat(counts[1]).isWithin(7).of(33); + assertThat(counts[2]).isWithin(7).of(33); + } + + @Test + public void pickAfterResolved_shuffleWeighted_noWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = false; + EquivalentAddressGroup eag1 = new EquivalentAddressGroup( + new 
FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 12L).build()); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup( + new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 3L).build()); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup( + new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 1L).build()); + + int[] counts = countAddressSelections(100, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(33); + assertThat(counts[1]).isWithin(7).of(33); + assertThat(counts[2]).isWithin(7).of(33); + } + + @Test + public void pickAfterResolved_shuffleWeighted_weightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = true; + EquivalentAddressGroup eag1 = new EquivalentAddressGroup( + new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 12L).build()); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup( + new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 3L).build()); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup( + new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 1L).build()); + + int[] counts = countAddressSelections(100, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(75); // 100*12/16 + assertThat(counts[1]).isWithin(7).of(19); // 100*3/16 + assertThat(counts[2]).isWithin(7).of(6); // 100*1/16 + } + + /** Returns int[index_of_eag] array with number of times each eag was selected. 
*/ + private int[] countAddressSelections(int trials, List eags) { + int[] counts = new int[eags.size()]; + Random random = new Random(1); + for (int i = 0; i < trials; i++) { + RecordingHelper helper = new RecordingHelper(); + PickFirstLoadBalancer lb = new PickFirstLoadBalancer(helper); + assertThat(lb.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(eags) + .setAttributes(affinity) + .setLoadBalancingPolicyConfig( + new PickFirstLoadBalancerConfig(true, random.nextLong())) + .build())) + .isSameInstanceAs(Status.OK); + helper.subchannels.remove().listener.onSubchannelState( + ConnectivityStateInfo.forNonError(READY)); + + assertThat(helper.state).isEqualTo(READY); + Subchannel subchannel = helper.picker.pickSubchannel(mockArgs).getSubchannel(); + counts[eags.indexOf(subchannel.getAllAddresses().get(0))]++; + + lb.shutdown(); + } + return counts; + } + @Test public void requestConnectionPicker() throws Exception { loadBalancer.acceptResolvedAddresses( @@ -486,4 +601,96 @@ public String toString() { return "FakeSocketAddress-" + name; } } + + private static class FakeSubchannel extends Subchannel { + private final Attributes attributes; + private List eags; + private SubchannelStateListener listener; + + public FakeSubchannel(List eags, Attributes attributes) { + this.eags = Collections.unmodifiableList(eags); + this.attributes = attributes; + } + + @Override + public List getAllAddresses() { + return eags; + } + + @Override + public Attributes getAttributes() { + return attributes; + } + + @Override + public void start(SubchannelStateListener listener) { + this.listener = listener; + } + + @Override + public void updateAddresses(List addrs) { + this.eags = Collections.unmodifiableList(addrs); + } + + @Override + public void shutdown() { + listener.onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.SHUTDOWN)); + } + + @Override + public void requestConnection() { + } + + @Override + public String toString() { + return 
"FakeSubchannel@" + hashCode() + "(" + eags + ")"; + } + } + + private class BaseHelper extends Helper { + @Override + public ManagedChannel createOobChannel(EquivalentAddressGroup eag, String authority) { + return null; + } + + @Override + public String getAuthority() { + return null; + } + + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + // ignore + } + + @Override + public SynchronizationContext getSynchronizationContext() { + return syncContext; + } + + @Override + public void refreshNameResolution() { + // noop + } + } + + class RecordingHelper extends BaseHelper { + ConnectivityState state; + SubchannelPicker picker; + final Queue subchannels = new ArrayDeque<>(); + + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + this.state = newState; + this.picker = newPicker; + } + + @Override + public Subchannel createSubchannel(CreateSubchannelArgs args) { + FakeSubchannel subchannel = new FakeSubchannel(args.getAddresses(), args.getAttributes()); + subchannels.add(subchannel); + return subchannel; + } + } + } diff --git a/xds/src/main/java/io/grpc/xds/XdsAttributes.java b/xds/src/main/java/io/grpc/xds/XdsAttributes.java index 5647b25e418..d3fe8d4619c 100644 --- a/xds/src/main/java/io/grpc/xds/XdsAttributes.java +++ b/xds/src/main/java/io/grpc/xds/XdsAttributes.java @@ -19,6 +19,7 @@ import io.grpc.Attributes; import io.grpc.EquivalentAddressGroup; import io.grpc.Grpc; +import io.grpc.InternalEquivalentAddressGroup; import io.grpc.NameResolver; import io.grpc.xds.XdsNameResolverProvider.CallCounterProvider; import io.grpc.xds.client.Locality; @@ -84,8 +85,7 @@ final class XdsAttributes { * Endpoint weight for load balancing purposes. 
*/ @EquivalentAddressGroup.Attr - static final Attributes.Key ATTR_SERVER_WEIGHT = - Attributes.Key.create("io.grpc.xds.XdsAttributes.serverWeight"); + static final Attributes.Key ATTR_SERVER_WEIGHT = InternalEquivalentAddressGroup.ATTR_WEIGHT; /** * Filter chain match for network filters. From 43057868bfd553695fe0e759fa6f8cc07b56e5dd Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Fri, 23 Jan 2026 12:50:09 -0800 Subject: [PATCH 2/2] xds: Normalize weights before combining endpoint and locality weights Previously, the number of endpoints in a locality would skew how much traffic was sent to that locality. Also, if endpoints in localities had wildly different weights, that would impact cross-locality weighting. For example, consider: LocalityA weight=1 endpointWeights=[100, 100, 100, 100] LocalityB weight=1 endpointWeights=[1] The endpoint in LocalityB should have an endpoint weight that is half the total sum of endpoint weights, in order to receive half the traffic. But the multiple endpoints in LocalityA would cause it to get 4x the traffic and the endpoint weights in LocalityA cause them to get 100x the traffic. 
See gRFC A113 --- .../java/io/grpc/xds/CdsLoadBalancer2.java | 75 ++++++++++++++++++- xds/src/main/java/io/grpc/xds/Endpoints.java | 5 +- .../xds/ClusterResolverLoadBalancerTest.java | 18 ++++- 3 files changed, 91 insertions(+), 7 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java index 67c5626aff5..5a59b47c529 100644 --- a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java +++ b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java @@ -22,6 +22,7 @@ import static io.grpc.xds.XdsLbPolicies.PRIORITY_POLICY_NAME; import com.google.common.collect.ImmutableMap; +import com.google.common.primitives.UnsignedInts; import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.Attributes; import io.grpc.EquivalentAddressGroup; @@ -33,6 +34,7 @@ import io.grpc.NameResolver; import io.grpc.Status; import io.grpc.StatusOr; +import io.grpc.internal.GrpcUtil; import io.grpc.util.GracefulSwitchLoadBalancer; import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig; import io.grpc.xds.CdsLoadBalancerProvider.CdsConfig; @@ -74,6 +76,9 @@ * by a group of sub-clusters in a tree hierarchy. */ final class CdsLoadBalancer2 extends LoadBalancer { + static boolean pickFirstWeightedShuffling = + GrpcUtil.getFlag("GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING", true); + private final XdsLogger logger; private final Helper helper; private final LoadBalancerRegistry lbRegistry; @@ -222,6 +227,26 @@ private String errorPrefix() { return "CdsLb for " + clusterName + ": "; } + /** + * The number of bits assigned to the fractional part of fixed-point values. We normalize weights + * to a fixed-point number between 0 and 1, representing that item's proportion of traffic (1 == + * 100% of traffic). 
We reserve at least one bit for the whole number so that we don't need to + * special case a single item, and so that we can round up very low values without risking uint32 + * overflow of the sum of weights. + */ + private static final int FIXED_POINT_FRACTIONAL_BITS = 31; + + /** Divide two uint32s and produce a fixed-point uint32 result. */ + private static long fractionToFixedPoint(long numerator, long denominator) { + long one = 1L << FIXED_POINT_FRACTIONAL_BITS; + return numerator * one / denominator; + } + + /** Multiply two uint32 fixed-point numbers, returning a uint32 fixed-point. */ + private static long fixedPointMultiply(long a, long b) { + return (a * b) >> FIXED_POINT_FRACTIONAL_BITS; + } + private static StatusOr getEdsUpdate(XdsConfig xdsConfig, String cluster) { StatusOr clusterConfig = xdsConfig.getClusters().get(cluster); if (clusterConfig == null) { @@ -286,17 +311,61 @@ StatusOr edsUpdateToResult( Map> prioritizedLocalityWeights = new HashMap<>(); List sortedPriorityNames = generatePriorityNames(clusterName, localityLbEndpoints); + Map priorityLocalityWeightSums; + if (pickFirstWeightedShuffling) { + priorityLocalityWeightSums = new HashMap<>(sortedPriorityNames.size() * 2); + for (Locality locality : localityLbEndpoints.keySet()) { + LocalityLbEndpoints localityLbInfo = localityLbEndpoints.get(locality); + String priorityName = localityPriorityNames.get(locality); + Long sum = priorityLocalityWeightSums.get(priorityName); + if (sum == null) { + sum = 0L; + } + long weight = UnsignedInts.toLong(localityLbInfo.localityWeight()); + priorityLocalityWeightSums.put(priorityName, sum + weight); + } + } else { + priorityLocalityWeightSums = null; + } + for (Locality locality : localityLbEndpoints.keySet()) { LocalityLbEndpoints localityLbInfo = localityLbEndpoints.get(locality); String priorityName = localityPriorityNames.get(locality); boolean discard = true; + // These sums _should_ fit in uint32, but XdsEndpointResource isn't actually verifying 
that + // is true today. Since we are using long to avoid signedness trouble, the math happens to + // still work if it turns out the sums exceed uint32. + long localityWeightSum = 0; + long endpointWeightSum = 0; + if (pickFirstWeightedShuffling) { + localityWeightSum = priorityLocalityWeightSums.get(priorityName); + for (LbEndpoint endpoint : localityLbInfo.endpoints()) { + if (endpoint.isHealthy()) { + endpointWeightSum += UnsignedInts.toLong(endpoint.loadBalancingWeight()); + } + } + } for (LbEndpoint endpoint : localityLbInfo.endpoints()) { if (endpoint.isHealthy()) { discard = false; - long weight = localityLbInfo.localityWeight(); - if (endpoint.loadBalancingWeight() != 0) { - weight *= endpoint.loadBalancingWeight(); + long weight; + if (pickFirstWeightedShuffling) { + // Combine locality and endpoint weights as defined by gRFC A113 + long localityWeight = fractionToFixedPoint( + UnsignedInts.toLong(localityLbInfo.localityWeight()), localityWeightSum); + long endpointWeight = fractionToFixedPoint( + UnsignedInts.toLong(endpoint.loadBalancingWeight()), endpointWeightSum); + weight = fixedPointMultiply(localityWeight, endpointWeight); + if (weight == 0) { + weight = 1; + } + } else { + weight = localityLbInfo.localityWeight(); + if (endpoint.loadBalancingWeight() != 0) { + weight *= endpoint.loadBalancingWeight(); + } } + String localityName = localityName(locality); Attributes attr = endpoint.eag().getAttributes().toBuilder() diff --git a/xds/src/main/java/io/grpc/xds/Endpoints.java b/xds/src/main/java/io/grpc/xds/Endpoints.java index dcb72f3e90d..558e3932ddc 100644 --- a/xds/src/main/java/io/grpc/xds/Endpoints.java +++ b/xds/src/main/java/io/grpc/xds/Endpoints.java @@ -59,7 +59,7 @@ abstract static class LbEndpoint { // The endpoint address to be connected to. abstract EquivalentAddressGroup eag(); - // Endpoint's weight for load balancing. If unspecified, value of 0 is returned. + // Endpoint's weight for load balancing. Guaranteed not to be 0. 
abstract int loadBalancingWeight(); // Whether the endpoint is healthy. @@ -71,6 +71,9 @@ abstract static class LbEndpoint { static LbEndpoint create(EquivalentAddressGroup eag, int loadBalancingWeight, boolean isHealthy, String hostname, ImmutableMap endpointMetadata) { + if (loadBalancingWeight == 0) { + loadBalancingWeight = 1; + } return new AutoValue_Endpoints_LbEndpoint( eag, loadBalancingWeight, isHealthy, hostname, endpointMetadata); } diff --git a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java index 5e524f79596..c6e5db08526 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java @@ -260,6 +260,18 @@ public void tearDown() throws Exception { assertThat(fakeClock.getPendingTasks()).isEmpty(); } + @Test + public void edsClustersWithRingHashEndpointLbPolicy_oppositePickFirstWeightedShuffling() + throws Exception { + boolean original = CdsLoadBalancer2.pickFirstWeightedShuffling; + CdsLoadBalancer2.pickFirstWeightedShuffling = !CdsLoadBalancer2.pickFirstWeightedShuffling; + try { + edsClustersWithRingHashEndpointLbPolicy(); + } finally { + CdsLoadBalancer2.pickFirstWeightedShuffling = original; + } + } + @Test public void edsClustersWithRingHashEndpointLbPolicy() throws Exception { boolean originalVal = LoadStatsManager2.isEnabledOrcaLrsPropagation; @@ -306,15 +318,15 @@ public void edsClustersWithRingHashEndpointLbPolicy() throws Exception { assertThat(addr1.getAddresses()) .isEqualTo(Arrays.asList(newInetSocketAddress("127.0.0.1", 8080))); assertThat(addr1.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_SERVER_WEIGHT)) - .isEqualTo(10); + .isEqualTo(CdsLoadBalancer2.pickFirstWeightedShuffling ? 
0x0AAAAAAA /* 1/12 */ : 10); assertThat(addr2.getAddresses()) .isEqualTo(Arrays.asList(newInetSocketAddress("127.0.0.2", 8080))); assertThat(addr2.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_SERVER_WEIGHT)) - .isEqualTo(10); + .isEqualTo(CdsLoadBalancer2.pickFirstWeightedShuffling ? 0x0AAAAAAA /* 1/12 */ : 10); assertThat(addr3.getAddresses()) .isEqualTo(Arrays.asList(newInetSocketAddress("127.0.1.1", 8080))); assertThat(addr3.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_SERVER_WEIGHT)) - .isEqualTo(50 * 60); + .isEqualTo(CdsLoadBalancer2.pickFirstWeightedShuffling ? 0x6AAAAAAA /* 5/6 */ : 50 * 60); assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); PriorityLbConfig priorityLbConfig = (PriorityLbConfig) childBalancer.config; assertThat(priorityLbConfig.priorities).containsExactly(CLUSTER + "[child1]");