diff --git a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java index 67c5626aff5..5a59b47c529 100644 --- a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java +++ b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java @@ -22,6 +22,7 @@ import static io.grpc.xds.XdsLbPolicies.PRIORITY_POLICY_NAME; import com.google.common.collect.ImmutableMap; +import com.google.common.primitives.UnsignedInts; import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.Attributes; import io.grpc.EquivalentAddressGroup; @@ -33,6 +34,7 @@ import io.grpc.NameResolver; import io.grpc.Status; import io.grpc.StatusOr; +import io.grpc.internal.GrpcUtil; import io.grpc.util.GracefulSwitchLoadBalancer; import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig; import io.grpc.xds.CdsLoadBalancerProvider.CdsConfig; @@ -74,6 +76,9 @@ * by a group of sub-clusters in a tree hierarchy. */ final class CdsLoadBalancer2 extends LoadBalancer { + static boolean pickFirstWeightedShuffling = + GrpcUtil.getFlag("GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING", true); + private final XdsLogger logger; private final Helper helper; private final LoadBalancerRegistry lbRegistry; @@ -222,6 +227,26 @@ private String errorPrefix() { return "CdsLb for " + clusterName + ": "; } + /** + * The number of bits assigned to the fractional part of fixed-point values. We normalize weights + * to a fixed-point number between 0 and 1, representing that item's proportion of traffic (1 == + * 100% of traffic). We reserve at least one bit for the whole number so that we don't need to + * special case a single item, and so that we can round up very low values without risking uint32 + * overflow of the sum of weights. + */ + private static final int FIXED_POINT_FRACTIONAL_BITS = 31; + + /** Divide two uint32s and produce a fixed-point uint32 result. */ + private static long fractionToFixedPoint(long numerator, long denominator) { + long one = 1L << FIXED_POINT_FRACTIONAL_BITS; + return numerator * one / denominator; + } + + /** Multiply two uint32 fixed-point numbers, returning a uint32 fixed-point. */ + private static long fixedPointMultiply(long a, long b) { + return (a * b) >> FIXED_POINT_FRACTIONAL_BITS; + } + private static StatusOr getEdsUpdate(XdsConfig xdsConfig, String cluster) { StatusOr clusterConfig = xdsConfig.getClusters().get(cluster); if (clusterConfig == null) { @@ -286,17 +311,61 @@ StatusOr edsUpdateToResult( Map> prioritizedLocalityWeights = new HashMap<>(); List sortedPriorityNames = generatePriorityNames(clusterName, localityLbEndpoints); + Map priorityLocalityWeightSums; + if (pickFirstWeightedShuffling) { + priorityLocalityWeightSums = new HashMap<>(sortedPriorityNames.size() * 2); + for (Locality locality : localityLbEndpoints.keySet()) { + LocalityLbEndpoints localityLbInfo = localityLbEndpoints.get(locality); + String priorityName = localityPriorityNames.get(locality); + Long sum = priorityLocalityWeightSums.get(priorityName); + if (sum == null) { + sum = 0L; + } + long weight = UnsignedInts.toLong(localityLbInfo.localityWeight()); + priorityLocalityWeightSums.put(priorityName, sum + weight); + } + } else { + priorityLocalityWeightSums = null; + } + for (Locality locality : localityLbEndpoints.keySet()) { LocalityLbEndpoints localityLbInfo = localityLbEndpoints.get(locality); String priorityName = localityPriorityNames.get(locality); boolean discard = true; + // These sums _should_ fit in uint32, but XdsEndpointResource isn't actually verifying that + // is true today. Since we are using long to avoid signedness trouble, the math happens to + // still work if it turns out the sums exceed uint32. + long localityWeightSum = 0; + long endpointWeightSum = 0; + if (pickFirstWeightedShuffling) { + localityWeightSum = priorityLocalityWeightSums.get(priorityName); + for (LbEndpoint endpoint : localityLbInfo.endpoints()) { + if (endpoint.isHealthy()) { + endpointWeightSum += UnsignedInts.toLong(endpoint.loadBalancingWeight()); + } + } + } for (LbEndpoint endpoint : localityLbInfo.endpoints()) { if (endpoint.isHealthy()) { discard = false; - long weight = localityLbInfo.localityWeight(); - if (endpoint.loadBalancingWeight() != 0) { - weight *= endpoint.loadBalancingWeight(); + long weight; + if (pickFirstWeightedShuffling) { + // Combine locality and endpoint weights as defined by gRFC A113 + long localityWeight = fractionToFixedPoint( + UnsignedInts.toLong(localityLbInfo.localityWeight()), localityWeightSum); + long endpointWeight = fractionToFixedPoint( + UnsignedInts.toLong(endpoint.loadBalancingWeight()), endpointWeightSum); + weight = fixedPointMultiply(localityWeight, endpointWeight); + if (weight == 0) { + weight = 1; + } + } else { + weight = localityLbInfo.localityWeight(); + if (endpoint.loadBalancingWeight() != 0) { + weight *= endpoint.loadBalancingWeight(); + } } + String localityName = localityName(locality); Attributes attr = endpoint.eag().getAttributes().toBuilder() diff --git a/xds/src/main/java/io/grpc/xds/Endpoints.java b/xds/src/main/java/io/grpc/xds/Endpoints.java index dcb72f3e90d..558e3932ddc 100644 --- a/xds/src/main/java/io/grpc/xds/Endpoints.java +++ b/xds/src/main/java/io/grpc/xds/Endpoints.java @@ -59,7 +59,7 @@ abstract static class LbEndpoint { // The endpoint address to be connected to. abstract EquivalentAddressGroup eag(); - // Endpoint's weight for load balancing. If unspecified, value of 0 is returned. + // Endpoint's weight for load balancing. Guaranteed not to be 0. abstract int loadBalancingWeight(); // Whether the endpoint is healthy. @@ -71,6 +71,9 @@ abstract static class LbEndpoint { static LbEndpoint create(EquivalentAddressGroup eag, int loadBalancingWeight, boolean isHealthy, String hostname, ImmutableMap endpointMetadata) { + if (loadBalancingWeight == 0) { + loadBalancingWeight = 1; + } return new AutoValue_Endpoints_LbEndpoint( eag, loadBalancingWeight, isHealthy, hostname, endpointMetadata); } diff --git a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java index 5e524f79596..c6e5db08526 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java @@ -260,6 +260,18 @@ public void tearDown() throws Exception { assertThat(fakeClock.getPendingTasks()).isEmpty(); } + @Test + public void edsClustersWithRingHashEndpointLbPolicy_oppositePickFirstWeightedShuffling() + throws Exception { + boolean original = CdsLoadBalancer2.pickFirstWeightedShuffling; + CdsLoadBalancer2.pickFirstWeightedShuffling = !CdsLoadBalancer2.pickFirstWeightedShuffling; + try { + edsClustersWithRingHashEndpointLbPolicy(); + } finally { + CdsLoadBalancer2.pickFirstWeightedShuffling = original; + } + } + @Test public void edsClustersWithRingHashEndpointLbPolicy() throws Exception { boolean originalVal = LoadStatsManager2.isEnabledOrcaLrsPropagation; @@ -306,15 +318,15 @@ public void edsClustersWithRingHashEndpointLbPolicy() throws Exception { assertThat(addr1.getAddresses()) .isEqualTo(Arrays.asList(newInetSocketAddress("127.0.0.1", 8080))); assertThat(addr1.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_SERVER_WEIGHT)) - .isEqualTo(10); + .isEqualTo(CdsLoadBalancer2.pickFirstWeightedShuffling ? 0x0AAAAAAA /* 1/12 */ : 10); assertThat(addr2.getAddresses()) .isEqualTo(Arrays.asList(newInetSocketAddress("127.0.0.2", 8080))); assertThat(addr2.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_SERVER_WEIGHT)) - .isEqualTo(10); + .isEqualTo(CdsLoadBalancer2.pickFirstWeightedShuffling ? 0x0AAAAAAA /* 1/12 */ : 10); assertThat(addr3.getAddresses()) .isEqualTo(Arrays.asList(newInetSocketAddress("127.0.1.1", 8080))); assertThat(addr3.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_SERVER_WEIGHT)) - .isEqualTo(50 * 60); + .isEqualTo(CdsLoadBalancer2.pickFirstWeightedShuffling ? 0x6AAAAAAA /* 5/6 */ : 50 * 60); assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); PriorityLbConfig priorityLbConfig = (PriorityLbConfig) childBalancer.config; assertThat(priorityLbConfig.priorities).containsExactly(CLUSTER + "[child1]");