From 11006a4fb9abf13fc8a38ae34bd54f5b72b1c270 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Fri, 23 Jan 2026 12:50:09 -0800 Subject: [PATCH] xds: Normalize weights before combining endpoint and locality weights Previously, the number of endpoints in a locality would skew how much traffic was sent to that locality. Also, if endpoints in localities had wildly different weights, that would impact cross-locality weighting. For example, consider: LocalityA weight=1 endpointWeights=[100, 100, 100, 100] LocalityB weight=1 endpointWeights=[1] The endpoint in LocalityB should have an endpoint weight that is half the total sum of endpoint weights, in order to receive half the traffic. But the multiple endpoints in LocalityA would cause it to get 4x the traffic and the endpoint weights in LocalityA causes them to get 100x the traffic. See gRFC A113 --- .../java/io/grpc/xds/CdsLoadBalancer2.java | 75 ++++++++++++++++++- xds/src/main/java/io/grpc/xds/Endpoints.java | 5 +- .../xds/ClusterResolverLoadBalancerTest.java | 18 ++++- 3 files changed, 91 insertions(+), 7 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java index 67c5626aff5..5a59b47c529 100644 --- a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java +++ b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java @@ -22,6 +22,7 @@ import static io.grpc.xds.XdsLbPolicies.PRIORITY_POLICY_NAME; import com.google.common.collect.ImmutableMap; +import com.google.common.primitives.UnsignedInts; import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.Attributes; import io.grpc.EquivalentAddressGroup; @@ -33,6 +34,7 @@ import io.grpc.NameResolver; import io.grpc.Status; import io.grpc.StatusOr; +import io.grpc.internal.GrpcUtil; import io.grpc.util.GracefulSwitchLoadBalancer; import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig; import io.grpc.xds.CdsLoadBalancerProvider.CdsConfig; @@ -74,6 +76,9 @@ * by a group of sub-clusters in a tree hierarchy. */ final class CdsLoadBalancer2 extends LoadBalancer { + static boolean pickFirstWeightedShuffling = + GrpcUtil.getFlag("GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING", true); + private final XdsLogger logger; private final Helper helper; private final LoadBalancerRegistry lbRegistry; @@ -222,6 +227,26 @@ private String errorPrefix() { return "CdsLb for " + clusterName + ": "; } + /** + * The number of bits assigned to the fractional part of fixed-point values. We normalize weights + * to a fixed-point number between 0 and 1, representing that item's proportion of traffic (1 == + * 100% of traffic). We reserve at least one bit for the whole number so that we don't need to + * special case a single item, and so that we can round up very low values without risking uint32 + * overflow of the sum of weights. + */ + private static final int FIXED_POINT_FRACTIONAL_BITS = 31; + + /** Divide two uint32s and produce a fixed-point uint32 result. */ + private static long fractionToFixedPoint(long numerator, long denominator) { + long one = 1L << FIXED_POINT_FRACTIONAL_BITS; + return numerator * one / denominator; + } + + /** Multiply two uint32 fixed-point numbers, returning a uint32 fixed-point. */ + private static long fixedPointMultiply(long a, long b) { + return (a * b) >> FIXED_POINT_FRACTIONAL_BITS; + } + private static StatusOr getEdsUpdate(XdsConfig xdsConfig, String cluster) { StatusOr clusterConfig = xdsConfig.getClusters().get(cluster); if (clusterConfig == null) { @@ -286,17 +311,61 @@ StatusOr edsUpdateToResult( Map> prioritizedLocalityWeights = new HashMap<>(); List sortedPriorityNames = generatePriorityNames(clusterName, localityLbEndpoints); + Map priorityLocalityWeightSums; + if (pickFirstWeightedShuffling) { + priorityLocalityWeightSums = new HashMap<>(sortedPriorityNames.size() * 2); + for (Locality locality : localityLbEndpoints.keySet()) { + LocalityLbEndpoints localityLbInfo = localityLbEndpoints.get(locality); + String priorityName = localityPriorityNames.get(locality); + Long sum = priorityLocalityWeightSums.get(priorityName); + if (sum == null) { + sum = 0L; + } + long weight = UnsignedInts.toLong(localityLbInfo.localityWeight()); + priorityLocalityWeightSums.put(priorityName, sum + weight); + } + } else { + priorityLocalityWeightSums = null; + } + for (Locality locality : localityLbEndpoints.keySet()) { LocalityLbEndpoints localityLbInfo = localityLbEndpoints.get(locality); String priorityName = localityPriorityNames.get(locality); boolean discard = true; + // These sums _should_ fit in uint32, but XdsEndpointResource isn't actually verifying that + // is true today. Since we are using long to avoid signedness trouble, the math happens to + // still work if it turns out the sums exceed uint32. + long localityWeightSum = 0; + long endpointWeightSum = 0; + if (pickFirstWeightedShuffling) { + localityWeightSum = priorityLocalityWeightSums.get(priorityName); + for (LbEndpoint endpoint : localityLbInfo.endpoints()) { + if (endpoint.isHealthy()) { + endpointWeightSum += UnsignedInts.toLong(endpoint.loadBalancingWeight()); + } + } + } for (LbEndpoint endpoint : localityLbInfo.endpoints()) { if (endpoint.isHealthy()) { discard = false; - long weight = localityLbInfo.localityWeight(); - if (endpoint.loadBalancingWeight() != 0) { - weight *= endpoint.loadBalancingWeight(); + long weight; + if (pickFirstWeightedShuffling) { + // Combine locality and endpoint weights as defined by gRFC A113 + long localityWeight = fractionToFixedPoint( + UnsignedInts.toLong(localityLbInfo.localityWeight()), localityWeightSum); + long endpointWeight = fractionToFixedPoint( + UnsignedInts.toLong(endpoint.loadBalancingWeight()), endpointWeightSum); + weight = fixedPointMultiply(localityWeight, endpointWeight); + if (weight == 0) { + weight = 1; + } + } else { + weight = localityLbInfo.localityWeight(); + if (endpoint.loadBalancingWeight() != 0) { + weight *= endpoint.loadBalancingWeight(); + } } + String localityName = localityName(locality); Attributes attr = endpoint.eag().getAttributes().toBuilder() diff --git a/xds/src/main/java/io/grpc/xds/Endpoints.java b/xds/src/main/java/io/grpc/xds/Endpoints.java index dcb72f3e90d..558e3932ddc 100644 --- a/xds/src/main/java/io/grpc/xds/Endpoints.java +++ b/xds/src/main/java/io/grpc/xds/Endpoints.java @@ -59,7 +59,7 @@ abstract static class LbEndpoint { // The endpoint address to be connected to. abstract EquivalentAddressGroup eag(); - // Endpoint's weight for load balancing. If unspecified, value of 0 is returned. + // Endpoint's weight for load balancing. Guaranteed not to be 0. abstract int loadBalancingWeight(); // Whether the endpoint is healthy. @@ -71,6 +71,9 @@ abstract static class LbEndpoint { static LbEndpoint create(EquivalentAddressGroup eag, int loadBalancingWeight, boolean isHealthy, String hostname, ImmutableMap endpointMetadata) { + if (loadBalancingWeight == 0) { + loadBalancingWeight = 1; + } return new AutoValue_Endpoints_LbEndpoint( eag, loadBalancingWeight, isHealthy, hostname, endpointMetadata); } diff --git a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java index 5e524f79596..c6e5db08526 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java @@ -260,6 +260,18 @@ public void tearDown() throws Exception { assertThat(fakeClock.getPendingTasks()).isEmpty(); } + @Test + public void edsClustersWithRingHashEndpointLbPolicy_oppositePickFirstWeightedShuffling() + throws Exception { + boolean original = CdsLoadBalancer2.pickFirstWeightedShuffling; + CdsLoadBalancer2.pickFirstWeightedShuffling = !CdsLoadBalancer2.pickFirstWeightedShuffling; + try { + edsClustersWithRingHashEndpointLbPolicy(); + } finally { + CdsLoadBalancer2.pickFirstWeightedShuffling = original; + } + } + @Test public void edsClustersWithRingHashEndpointLbPolicy() throws Exception { boolean originalVal = LoadStatsManager2.isEnabledOrcaLrsPropagation; @@ -306,15 +318,15 @@ public void edsClustersWithRingHashEndpointLbPolicy() throws Exception { assertThat(addr1.getAddresses()) .isEqualTo(Arrays.asList(newInetSocketAddress("127.0.0.1", 8080))); assertThat(addr1.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_SERVER_WEIGHT)) - .isEqualTo(10); + .isEqualTo(CdsLoadBalancer2.pickFirstWeightedShuffling ? 0x0AAAAAAA /* 1/12 */ : 10); assertThat(addr2.getAddresses()) .isEqualTo(Arrays.asList(newInetSocketAddress("127.0.0.2", 8080))); assertThat(addr2.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_SERVER_WEIGHT)) - .isEqualTo(10); + .isEqualTo(CdsLoadBalancer2.pickFirstWeightedShuffling ? 0x0AAAAAAA /* 1/12 */ : 10); assertThat(addr3.getAddresses()) .isEqualTo(Arrays.asList(newInetSocketAddress("127.0.1.1", 8080))); assertThat(addr3.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_SERVER_WEIGHT)) - .isEqualTo(50 * 60); + .isEqualTo(CdsLoadBalancer2.pickFirstWeightedShuffling ? 0x6AAAAAAA /* 5/6 */ : 50 * 60); assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); PriorityLbConfig priorityLbConfig = (PriorityLbConfig) childBalancer.config; assertThat(priorityLbConfig.priorities).containsExactly(CLUSTER + "[child1]");