diff --git a/api/src/main/java/io/grpc/EquivalentAddressGroup.java b/api/src/main/java/io/grpc/EquivalentAddressGroup.java index bf8a864902c..d40e47c16f6 100644 --- a/api/src/main/java/io/grpc/EquivalentAddressGroup.java +++ b/api/src/main/java/io/grpc/EquivalentAddressGroup.java @@ -55,6 +55,15 @@ public final class EquivalentAddressGroup { */ public static final Attributes.Key ATTR_LOCALITY_NAME = Attributes.Key.create("io.grpc.EquivalentAddressGroup.LOCALITY"); + /** + * Endpoint weight for load balancing purposes. While the type is Long, it must be a valid uint32. + * Must not be zero. The weight is proportional to the other endpoints; if an endpoint's weight is + * twice that of another endpoint, it is intended to receive twice the load. + */ + @Attr + static final Attributes.Key ATTR_WEIGHT = + Attributes.Key.create("io.grpc.EquivalentAddressGroup.ATTR_WEIGHT"); + private final List addrs; private final Attributes attrs; diff --git a/api/src/main/java/io/grpc/InternalEquivalentAddressGroup.java b/api/src/main/java/io/grpc/InternalEquivalentAddressGroup.java new file mode 100644 index 00000000000..d4bed4d81bc --- /dev/null +++ b/api/src/main/java/io/grpc/InternalEquivalentAddressGroup.java @@ -0,0 +1,29 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +@Internal +public final class InternalEquivalentAddressGroup { + private InternalEquivalentAddressGroup() {} + + /** + * Endpoint weight for load balancing purposes. While the type is Long, it must be a valid uint32. + * Must not be zero. The weight is proportional to the other endpoints; if an endpoint's weight is + * twice that of another endpoint, it is intended to receive twice the load. + */ + public static final Attributes.Key ATTR_WEIGHT = EquivalentAddressGroup.ATTR_WEIGHT; +} diff --git a/core/src/main/java/io/grpc/internal/PickFirstLeafLoadBalancer.java b/core/src/main/java/io/grpc/internal/PickFirstLeafLoadBalancer.java index 935214a94fd..cd9ff9ab58c 100644 --- a/core/src/main/java/io/grpc/internal/PickFirstLeafLoadBalancer.java +++ b/core/src/main/java/io/grpc/internal/PickFirstLeafLoadBalancer.java @@ -26,10 +26,12 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.Attributes; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; +import io.grpc.InternalEquivalentAddressGroup; import io.grpc.LoadBalancer; import io.grpc.Status; import io.grpc.SynchronizationContext.ScheduledHandle; @@ -61,6 +63,8 @@ final class PickFirstLeafLoadBalancer extends LoadBalancer { static final int CONNECTION_DELAY_INTERVAL_MS = 250; private final boolean enableHappyEyeballs = !isSerializingRetries() && PickFirstLoadBalancerProvider.isEnabledHappyEyeballs(); + static boolean weightedShuffling = + GrpcUtil.getFlag("GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING", true); private final Helper helper; private final Map subchannels = new HashMap<>(); private final Index addressIndex = new Index(ImmutableList.of(), this.enableHappyEyeballs); @@ -128,13 +132,13 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { PickFirstLeafLoadBalancerConfig config = (PickFirstLeafLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); if (config.shuffleAddressList != null && config.shuffleAddressList) { - Collections.shuffle(cleanServers, - config.randomSeed != null ? new Random(config.randomSeed) : new Random()); + cleanServers = shuffle( + cleanServers, config.randomSeed != null ? new Random(config.randomSeed) : new Random()); } } final ImmutableList newImmutableAddressGroups = - ImmutableList.builder().addAll(cleanServers).build(); + ImmutableList.copyOf(cleanServers); if (rawConnectivityState == READY || (rawConnectivityState == CONNECTING @@ -224,6 +228,46 @@ private static List deDupAddresses(List shuffle(List eags, Random random) { + if (weightedShuffling) { + List weightedEntries = new ArrayList<>(eags.size()); + for (EquivalentAddressGroup eag : eags) { + weightedEntries.add(new WeightEntry(eag, eagToWeight(eag, random))); + } + Collections.sort(weightedEntries, Collections.reverseOrder() /* descending */); + return Lists.transform(weightedEntries, entry -> entry.eag); + } else { + List eagsCopy = new ArrayList<>(eags); + Collections.shuffle(eagsCopy, random); + return eagsCopy; + } + } + + private static double eagToWeight(EquivalentAddressGroup eag, Random random) { + Long weight = eag.getAttributes().get(InternalEquivalentAddressGroup.ATTR_WEIGHT); + if (weight == null) { + weight = 1L; + } + return Math.pow(random.nextDouble(), 1.0 / weight); + } + + private static final class WeightEntry implements Comparable { + final EquivalentAddressGroup eag; + final double weight; + + public WeightEntry(EquivalentAddressGroup eag, double weight) { + this.eag = eag; + this.weight = weight; + } + + @Override + public int compareTo(WeightEntry entry) { + return Double.compare(this.weight, entry.weight); + } + } + @Override public void handleNameResolutionError(Status error) { if (rawConnectivityState == SHUTDOWN) { diff --git a/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java b/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java index aa8b5a7e9a9..cf4b4c94e04 100644 --- a/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java +++ b/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java @@ -27,8 +27,6 @@ import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; import io.grpc.Status; -import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Random; import java.util.concurrent.atomic.AtomicBoolean; @@ -65,9 +63,8 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { PickFirstLoadBalancerConfig config = (PickFirstLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); if (config.shuffleAddressList != null && config.shuffleAddressList) { - servers = new ArrayList(servers); - Collections.shuffle(servers, - config.randomSeed != null ? new Random(config.randomSeed) : new Random()); + servers = PickFirstLeafLoadBalancer.shuffle( + servers, config.randomSeed != null ? new Random(config.randomSeed) : new Random()); } } diff --git a/core/src/test/java/io/grpc/internal/PickFirstLeafLoadBalancerTest.java b/core/src/test/java/io/grpc/internal/PickFirstLeafLoadBalancerTest.java index 8b09fce2aa2..cb73d17d682 100644 --- a/core/src/test/java/io/grpc/internal/PickFirstLeafLoadBalancerTest.java +++ b/core/src/test/java/io/grpc/internal/PickFirstLeafLoadBalancerTest.java @@ -23,6 +23,7 @@ import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import static io.grpc.InternalEquivalentAddressGroup.ATTR_WEIGHT; import static io.grpc.LoadBalancer.HAS_HEALTH_PRODUCER_LISTENER_KEY; import static io.grpc.LoadBalancer.HEALTH_CONSUMER_LISTENER_ARG_KEY; import static io.grpc.LoadBalancer.IS_PETIOLE_POLICY; @@ -70,10 +71,13 @@ import io.grpc.internal.PickFirstLeafLoadBalancer.PickFirstLeafLoadBalancerConfig; import java.net.InetSocketAddress; import java.net.SocketAddress; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Queue; +import java.util.Random; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import org.junit.After; @@ -149,6 +153,7 @@ public void uncaughtException(Thread t, Throwable e) { private String originalHappyEyeballsEnabledValue; private String originalSerializeRetriesValue; + private boolean originalWeightedShuffling; private long backoffMillis; @@ -165,6 +170,8 @@ public void setUp() { System.setProperty(PickFirstLoadBalancerProvider.GRPC_PF_USE_HAPPY_EYEBALLS, Boolean.toString(enableHappyEyeballs)); + originalWeightedShuffling = PickFirstLeafLoadBalancer.weightedShuffling; + for (int i = 1; i <= 5; i++) { SocketAddress addr = new FakeSocketAddress("server" + i); servers.add(new EquivalentAddressGroup(addr)); @@ -207,6 +214,7 @@ public void tearDown() { System.setProperty(PickFirstLoadBalancerProvider.GRPC_PF_USE_HAPPY_EYEBALLS, originalHappyEyeballsEnabledValue); } + PickFirstLeafLoadBalancer.weightedShuffling = originalWeightedShuffling; loadBalancer.shutdown(); verifyNoMoreInteractions(mockArgs); @@ -242,6 +250,12 @@ public void pickAfterResolved() { verifyNoMoreInteractions(mockHelper); } + @Test + public void pickAfterResolved_shuffle_oppositeWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling; + pickAfterResolved_shuffle(); + } + @Test public void pickAfterResolved_shuffle() { servers.remove(4); @@ -305,6 +319,103 @@ public void pickAfterResolved_noShuffle() { assertNotNull(pickerCaptor.getValue().pickSubchannel(mockArgs)); } + @Test + public void pickAfterResolved_shuffleImplicitUniform_oppositeWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling; + pickAfterResolved_shuffleImplicitUniform(); + } + + @Test + public void pickAfterResolved_shuffleImplicitUniform() { + EquivalentAddressGroup eag1 = new EquivalentAddressGroup(new FakeSocketAddress("server1")); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup(new FakeSocketAddress("server2")); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup(new FakeSocketAddress("server3")); + + int[] counts = countAddressSelections(99, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(33); + assertThat(counts[1]).isWithin(7).of(33); + assertThat(counts[2]).isWithin(7).of(33); + } + + @Test + public void pickAfterResolved_shuffleExplicitUniform_oppositeWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling; + pickAfterResolved_shuffleExplicitUniform(); + } + + @Test + public void pickAfterResolved_shuffleExplicitUniform() { + EquivalentAddressGroup eag1 = new EquivalentAddressGroup( + new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build()); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup( + new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build()); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup( + new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build()); + + int[] counts = countAddressSelections(99, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(33); + assertThat(counts[1]).isWithin(7).of(33); + assertThat(counts[2]).isWithin(7).of(33); + } + + @Test + public void pickAfterResolved_shuffleWeighted_noWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = false; + EquivalentAddressGroup eag1 = new EquivalentAddressGroup( + new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 12L).build()); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup( + new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 3L).build()); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup( + new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 1L).build()); + + int[] counts = countAddressSelections(100, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(33); + assertThat(counts[1]).isWithin(7).of(33); + assertThat(counts[2]).isWithin(7).of(33); + } + + @Test + public void pickAfterResolved_shuffleWeighted_weightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = true; + EquivalentAddressGroup eag1 = new EquivalentAddressGroup( + new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 12L).build()); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup( + new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 3L).build()); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup( + new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 1L).build()); + + int[] counts = countAddressSelections(100, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(75); // 100*12/16 + assertThat(counts[1]).isWithin(7).of(19); // 100*3/16 + assertThat(counts[2]).isWithin(7).of(6); // 100*1/16 + } + + /** Returns int[index_of_eag] array with number of times each eag was selected. */ + private int[] countAddressSelections(int trials, List eags) { + int[] counts = new int[eags.size()]; + Random random = new Random(1); + for (int i = 0; i < trials; i++) { + RecordingHelper helper = new RecordingHelper(); + LoadBalancer lb = new PickFirstLeafLoadBalancer(helper); + assertThat(lb.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(eags) + .setAttributes(affinity) + .setLoadBalancingPolicyConfig( + new PickFirstLeafLoadBalancerConfig(true, random.nextLong())) + .build())) + .isSameInstanceAs(Status.OK); + helper.subchannels.remove().listener.onSubchannelState( + ConnectivityStateInfo.forNonError(READY)); + + assertThat(helper.state).isEqualTo(READY); + Subchannel subchannel = helper.picker.pickSubchannel(mockArgs).getSubchannel(); + counts[eags.indexOf(subchannel.getAddresses())]++; + + lb.shutdown(); + } + return counts; + } + @Test public void requestConnectionPicker() { // Set up @@ -2945,13 +3056,7 @@ public String toString() { } } - private class MockHelperImpl extends LoadBalancer.Helper { - private final List subchannels; - - public MockHelperImpl(List subchannels) { - this.subchannels = new ArrayList(subchannels); - } - + private class BaseHelper extends LoadBalancer.Helper { @Override public ManagedChannel createOobChannel(EquivalentAddressGroup eag, String authority) { return null; @@ -2981,6 +3086,14 @@ public ScheduledExecutorService getScheduledExecutorService() { public void refreshNameResolution() { // noop } + } + + private class MockHelperImpl extends BaseHelper { + private final List subchannels; + + public MockHelperImpl(List subchannels) { + this.subchannels = new ArrayList(subchannels); + } @Override public Subchannel createSubchannel(CreateSubchannelArgs args) { @@ -2997,4 +3110,23 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { throw new IllegalArgumentException("Unexpected addresses: " + args.getAddresses()); } } + + class RecordingHelper extends BaseHelper { + ConnectivityState state; + SubchannelPicker picker; + final Queue subchannels = new ArrayDeque<>(); + + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + this.state = newState; + this.picker = newPicker; + } + + @Override + public Subchannel createSubchannel(CreateSubchannelArgs args) { + FakeSubchannel subchannel = new FakeSubchannel(args.getAddresses(), args.getAttributes()); + subchannels.add(subchannel); + return subchannel; + } + } } diff --git a/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerTest.java b/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerTest.java index 819293e070b..1e130423a45 100644 --- a/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerTest.java +++ b/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerTest.java @@ -21,6 +21,7 @@ import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import static io.grpc.InternalEquivalentAddressGroup.ATTR_WEIGHT; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; @@ -49,12 +50,18 @@ import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancer.SubchannelStateListener; +import io.grpc.ManagedChannel; import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.SynchronizationContext; import io.grpc.internal.PickFirstLoadBalancer.PickFirstLoadBalancerConfig; import java.net.SocketAddress; +import java.util.ArrayDeque; +import java.util.Arrays; +import java.util.Collections; import java.util.List; +import java.util.Queue; +import java.util.Random; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -103,8 +110,12 @@ public void uncaughtException(Thread t, Throwable e) { @Mock // This LoadBalancer doesn't use any of the arg fields, as verified in tearDown(). private PickSubchannelArgs mockArgs; + private boolean originalWeightedShuffling; + @Before public void setUp() { + originalWeightedShuffling = PickFirstLeafLoadBalancer.weightedShuffling; + for (int i = 0; i < 3; i++) { SocketAddress addr = new FakeSocketAddress("server" + i); servers.add(new EquivalentAddressGroup(addr)); @@ -120,6 +131,7 @@ public void setUp() { @After public void tearDown() throws Exception { + PickFirstLeafLoadBalancer.weightedShuffling = originalWeightedShuffling; verifyNoMoreInteractions(mockArgs); } @@ -141,6 +153,12 @@ public void pickAfterResolved() throws Exception { verifyNoMoreInteractions(mockHelper); } + @Test + public void pickAfterResolved_shuffle_oppositeWeightedShuffling() throws Exception { + PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling; + pickAfterResolved_shuffle(); + } + @Test public void pickAfterResolved_shuffle() throws Exception { loadBalancer.acceptResolvedAddresses( @@ -184,6 +202,103 @@ public void pickAfterResolved_noShuffle() throws Exception { verifyNoMoreInteractions(mockHelper); } + @Test + public void pickAfterResolved_shuffleImplicitUniform_oppositeWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling; + pickAfterResolved_shuffleImplicitUniform(); + } + + @Test + public void pickAfterResolved_shuffleImplicitUniform() { + EquivalentAddressGroup eag1 = new EquivalentAddressGroup(new FakeSocketAddress("server1")); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup(new FakeSocketAddress("server2")); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup(new FakeSocketAddress("server3")); + + int[] counts = countAddressSelections(99, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(33); + assertThat(counts[1]).isWithin(7).of(33); + assertThat(counts[2]).isWithin(7).of(33); + } + + @Test + public void pickAfterResolved_shuffleExplicitUniform_oppositeWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling; + pickAfterResolved_shuffleExplicitUniform(); + } + + @Test + public void pickAfterResolved_shuffleExplicitUniform() { + EquivalentAddressGroup eag1 = new EquivalentAddressGroup( + new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build()); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup( + new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build()); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup( + new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build()); + + int[] counts = countAddressSelections(99, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(33); + assertThat(counts[1]).isWithin(7).of(33); + assertThat(counts[2]).isWithin(7).of(33); + } + + @Test + public void pickAfterResolved_shuffleWeighted_noWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = false; + EquivalentAddressGroup eag1 = new EquivalentAddressGroup( + new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 12L).build()); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup( + new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 3L).build()); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup( + new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 1L).build()); + + int[] counts = countAddressSelections(100, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(33); + assertThat(counts[1]).isWithin(7).of(33); + assertThat(counts[2]).isWithin(7).of(33); + } + + @Test + public void pickAfterResolved_shuffleWeighted_weightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = true; + EquivalentAddressGroup eag1 = new EquivalentAddressGroup( + new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 12L).build()); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup( + new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 3L).build()); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup( + new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 1L).build()); + + int[] counts = countAddressSelections(100, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(75); // 100*12/16 + assertThat(counts[1]).isWithin(7).of(19); // 100*3/16 + assertThat(counts[2]).isWithin(7).of(6); // 100*1/16 + } + + /** Returns int[index_of_eag] array with number of times each eag was selected. */ + private int[] countAddressSelections(int trials, List eags) { + int[] counts = new int[eags.size()]; + Random random = new Random(1); + for (int i = 0; i < trials; i++) { + RecordingHelper helper = new RecordingHelper(); + PickFirstLoadBalancer lb = new PickFirstLoadBalancer(helper); + assertThat(lb.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(eags) + .setAttributes(affinity) + .setLoadBalancingPolicyConfig( + new PickFirstLoadBalancerConfig(true, random.nextLong())) + .build())) + .isSameInstanceAs(Status.OK); + helper.subchannels.remove().listener.onSubchannelState( + ConnectivityStateInfo.forNonError(READY)); + + assertThat(helper.state).isEqualTo(READY); + Subchannel subchannel = helper.picker.pickSubchannel(mockArgs).getSubchannel(); + counts[eags.indexOf(subchannel.getAllAddresses().get(0))]++; + + lb.shutdown(); + } + return counts; + } + @Test public void requestConnectionPicker() throws Exception { loadBalancer.acceptResolvedAddresses( @@ -486,4 +601,96 @@ public String toString() { return "FakeSocketAddress-" + name; } } + + private static class FakeSubchannel extends Subchannel { + private final Attributes attributes; + private List eags; + private SubchannelStateListener listener; + + public FakeSubchannel(List eags, Attributes attributes) { + this.eags = Collections.unmodifiableList(eags); + this.attributes = attributes; + } + + @Override + public List getAllAddresses() { + return eags; + } + + @Override + public Attributes getAttributes() { + return attributes; + } + + @Override + public void start(SubchannelStateListener listener) { + this.listener = listener; + } + + @Override + public void updateAddresses(List addrs) { + this.eags = Collections.unmodifiableList(addrs); + } + + @Override + public void shutdown() { + listener.onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.SHUTDOWN)); + } + + @Override + public void requestConnection() { + } + + @Override + public String toString() { + return "FakeSubchannel@" + hashCode() + "(" + eags + ")"; + } + } + + private class BaseHelper extends Helper { + @Override + public ManagedChannel createOobChannel(EquivalentAddressGroup eag, String authority) { + return null; + } + + @Override + public String getAuthority() { + return null; + } + + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + // ignore + } + + @Override + public SynchronizationContext getSynchronizationContext() { + return syncContext; + } + + @Override + public void refreshNameResolution() { + // noop + } + } + + class RecordingHelper extends BaseHelper { + ConnectivityState state; + SubchannelPicker picker; + final Queue subchannels = new ArrayDeque<>(); + + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + this.state = newState; + this.picker = newPicker; + } + + @Override + public Subchannel createSubchannel(CreateSubchannelArgs args) { + FakeSubchannel subchannel = new FakeSubchannel(args.getAddresses(), args.getAttributes()); + subchannels.add(subchannel); + return subchannel; + } + } + } diff --git a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java index 67c5626aff5..5a59b47c529 100644 --- a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java +++ b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java @@ -22,6 +22,7 @@ import static io.grpc.xds.XdsLbPolicies.PRIORITY_POLICY_NAME; import com.google.common.collect.ImmutableMap; +import com.google.common.primitives.UnsignedInts; import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.Attributes; import io.grpc.EquivalentAddressGroup; @@ -33,6 +34,7 @@ import io.grpc.NameResolver; import io.grpc.Status; import io.grpc.StatusOr; +import io.grpc.internal.GrpcUtil; import io.grpc.util.GracefulSwitchLoadBalancer; import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig; import io.grpc.xds.CdsLoadBalancerProvider.CdsConfig; @@ -74,6 +76,9 @@ * by a group of sub-clusters in a tree hierarchy. */ final class CdsLoadBalancer2 extends LoadBalancer { + static boolean pickFirstWeightedShuffling = + GrpcUtil.getFlag("GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING", true); + private final XdsLogger logger; private final Helper helper; private final LoadBalancerRegistry lbRegistry; @@ -222,6 +227,26 @@ private String errorPrefix() { return "CdsLb for " + clusterName + ": "; } + /** + * The number of bits assigned to the fractional part of fixed-point values. We normalize weights + * to a fixed-point number between 0 and 1, representing that item's proportion of traffic (1 == + * 100% of traffic). We reserve at least one bit for the whole number so that we don't need to + * special case a single item, and so that we can round up very low values without risking uint32 + * overflow of the sum of weights. + */ + private static final int FIXED_POINT_FRACTIONAL_BITS = 31; + + /** Divide two uint32s and produce a fixed-point uint32 result. */ + private static long fractionToFixedPoint(long numerator, long denominator) { + long one = 1L << FIXED_POINT_FRACTIONAL_BITS; + return numerator * one / denominator; + } + + /** Multiply two uint32 fixed-point numbers, returning a uint32 fixed-point. */ + private static long fixedPointMultiply(long a, long b) { + return (a * b) >> FIXED_POINT_FRACTIONAL_BITS; + } + private static StatusOr getEdsUpdate(XdsConfig xdsConfig, String cluster) { StatusOr clusterConfig = xdsConfig.getClusters().get(cluster); if (clusterConfig == null) { @@ -286,17 +311,61 @@ StatusOr edsUpdateToResult( Map> prioritizedLocalityWeights = new HashMap<>(); List sortedPriorityNames = generatePriorityNames(clusterName, localityLbEndpoints); + Map priorityLocalityWeightSums; + if (pickFirstWeightedShuffling) { + priorityLocalityWeightSums = new HashMap<>(sortedPriorityNames.size() * 2); + for (Locality locality : localityLbEndpoints.keySet()) { + LocalityLbEndpoints localityLbInfo = localityLbEndpoints.get(locality); + String priorityName = localityPriorityNames.get(locality); + Long sum = priorityLocalityWeightSums.get(priorityName); + if (sum == null) { + sum = 0L; + } + long weight = UnsignedInts.toLong(localityLbInfo.localityWeight()); + priorityLocalityWeightSums.put(priorityName, sum + weight); + } + } else { + priorityLocalityWeightSums = null; + } + for (Locality locality : localityLbEndpoints.keySet()) { LocalityLbEndpoints localityLbInfo = localityLbEndpoints.get(locality); String priorityName = localityPriorityNames.get(locality); boolean discard = true; + // These sums _should_ fit in uint32, but XdsEndpointResource isn't actually verifying that + // is true today. Since we are using long to avoid signedness trouble, the math happens to + // still work if it turns out the sums exceed uint32. + long localityWeightSum = 0; + long endpointWeightSum = 0; + if (pickFirstWeightedShuffling) { + localityWeightSum = priorityLocalityWeightSums.get(priorityName); + for (LbEndpoint endpoint : localityLbInfo.endpoints()) { + if (endpoint.isHealthy()) { + endpointWeightSum += UnsignedInts.toLong(endpoint.loadBalancingWeight()); + } + } + } for (LbEndpoint endpoint : localityLbInfo.endpoints()) { if (endpoint.isHealthy()) { discard = false; - long weight = localityLbInfo.localityWeight(); - if (endpoint.loadBalancingWeight() != 0) { - weight *= endpoint.loadBalancingWeight(); + long weight; + if (pickFirstWeightedShuffling) { + // Combine locality and endpoint weights as defined by gRFC A113 + long localityWeight = fractionToFixedPoint( + UnsignedInts.toLong(localityLbInfo.localityWeight()), localityWeightSum); + long endpointWeight = fractionToFixedPoint( + UnsignedInts.toLong(endpoint.loadBalancingWeight()), endpointWeightSum); + weight = fixedPointMultiply(localityWeight, endpointWeight); + if (weight == 0) { + weight = 1; + } + } else { + weight = localityLbInfo.localityWeight(); + if (endpoint.loadBalancingWeight() != 0) { + weight *= endpoint.loadBalancingWeight(); + } } + String localityName = localityName(locality); Attributes attr = endpoint.eag().getAttributes().toBuilder() diff --git a/xds/src/main/java/io/grpc/xds/Endpoints.java b/xds/src/main/java/io/grpc/xds/Endpoints.java index dcb72f3e90d..558e3932ddc 100644 --- a/xds/src/main/java/io/grpc/xds/Endpoints.java +++ b/xds/src/main/java/io/grpc/xds/Endpoints.java @@ -59,7 +59,7 @@ abstract static class LbEndpoint { // The endpoint address to be connected to. abstract EquivalentAddressGroup eag(); - // Endpoint's weight for load balancing. If unspecified, value of 0 is returned. + // Endpoint's weight for load balancing. Guaranteed not to be 0. abstract int loadBalancingWeight(); // Whether the endpoint is healthy. @@ -71,6 +71,9 @@ abstract static class LbEndpoint { static LbEndpoint create(EquivalentAddressGroup eag, int loadBalancingWeight, boolean isHealthy, String hostname, ImmutableMap endpointMetadata) { + if (loadBalancingWeight == 0) { + loadBalancingWeight = 1; + } return new AutoValue_Endpoints_LbEndpoint( eag, loadBalancingWeight, isHealthy, hostname, endpointMetadata); } diff --git a/xds/src/main/java/io/grpc/xds/XdsAttributes.java b/xds/src/main/java/io/grpc/xds/XdsAttributes.java index 5647b25e418..d3fe8d4619c 100644 --- a/xds/src/main/java/io/grpc/xds/XdsAttributes.java +++ b/xds/src/main/java/io/grpc/xds/XdsAttributes.java @@ -19,6 +19,7 @@ import io.grpc.Attributes; import io.grpc.EquivalentAddressGroup; import io.grpc.Grpc; +import io.grpc.InternalEquivalentAddressGroup; import io.grpc.NameResolver; import io.grpc.xds.XdsNameResolverProvider.CallCounterProvider; import io.grpc.xds.client.Locality; @@ -84,8 +85,7 @@ final class XdsAttributes { * Endpoint weight for load balancing purposes. */ @EquivalentAddressGroup.Attr - static final Attributes.Key ATTR_SERVER_WEIGHT = - Attributes.Key.create("io.grpc.xds.XdsAttributes.serverWeight"); + static final Attributes.Key ATTR_SERVER_WEIGHT = InternalEquivalentAddressGroup.ATTR_WEIGHT; /** * Filter chain match for network filters. diff --git a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java index 5e524f79596..c6e5db08526 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java @@ -260,6 +260,18 @@ public void tearDown() throws Exception { assertThat(fakeClock.getPendingTasks()).isEmpty(); } + @Test + public void edsClustersWithRingHashEndpointLbPolicy_oppositePickFirstWeightedShuffling() + throws Exception { + boolean original = CdsLoadBalancer2.pickFirstWeightedShuffling; + CdsLoadBalancer2.pickFirstWeightedShuffling = !CdsLoadBalancer2.pickFirstWeightedShuffling; + try { + edsClustersWithRingHashEndpointLbPolicy(); + } finally { + CdsLoadBalancer2.pickFirstWeightedShuffling = original; + } + } + @Test public void edsClustersWithRingHashEndpointLbPolicy() throws Exception { boolean originalVal = LoadStatsManager2.isEnabledOrcaLrsPropagation; @@ -306,15 +318,15 @@ public void edsClustersWithRingHashEndpointLbPolicy() throws Exception { assertThat(addr1.getAddresses()) .isEqualTo(Arrays.asList(newInetSocketAddress("127.0.0.1", 8080))); assertThat(addr1.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_SERVER_WEIGHT)) - .isEqualTo(10); + .isEqualTo(CdsLoadBalancer2.pickFirstWeightedShuffling ? 0x0AAAAAAA /* 1/12 */ : 10); assertThat(addr2.getAddresses()) .isEqualTo(Arrays.asList(newInetSocketAddress("127.0.0.2", 8080))); assertThat(addr2.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_SERVER_WEIGHT)) - .isEqualTo(10); + .isEqualTo(CdsLoadBalancer2.pickFirstWeightedShuffling ? 0x0AAAAAAA /* 1/12 */ : 10); assertThat(addr3.getAddresses()) .isEqualTo(Arrays.asList(newInetSocketAddress("127.0.1.1", 8080))); assertThat(addr3.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_SERVER_WEIGHT)) - .isEqualTo(50 * 60); + .isEqualTo(CdsLoadBalancer2.pickFirstWeightedShuffling ? 0x6AAAAAAA /* 5/6 */ : 50 * 60); assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); PriorityLbConfig priorityLbConfig = (PriorityLbConfig) childBalancer.config; assertThat(priorityLbConfig.priorities).containsExactly(CLUSTER + "[child1]");