From 8d12baa4477d177bb4d27c07ee58b139b247820a Mon Sep 17 00:00:00 2001 From: yifeizhuang Date: Mon, 27 Feb 2023 10:34:51 -0800 Subject: [PATCH] xds: add weighted round robin LB policy support (#9873) --- .../io/grpc/util/RoundRobinLoadBalancer.java | 39 +- repositories.bzl | 6 +- xds/BUILD.bazel | 4 + .../grpc/xds/LoadBalancerConfigFactory.java | 75 +- .../xds/WeightedRoundRobinLoadBalancer.java | 471 ++++++++++++ ...eightedRoundRobinLoadBalancerProvider.java | 94 +++ .../java/io/grpc/xds/XdsClusterResource.java | 2 +- .../java/io/grpc/xds/XdsResourceType.java | 4 + .../services/io.grpc.LoadBalancerProvider | 1 + .../xds/LoadBalancerConfigFactoryTest.java | 95 ++- ...tedRoundRobinLoadBalancerProviderTest.java | 116 +++ .../WeightedRoundRobinLoadBalancerTest.java | 673 ++++++++++++++++++ .../io/grpc/xds/XdsClientImplDataTest.java | 68 ++ 13 files changed, 1599 insertions(+), 49 deletions(-) create mode 100644 xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java create mode 100644 xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProvider.java create mode 100644 xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProviderTest.java create mode 100644 xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java diff --git a/core/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java b/core/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java index b715f75614..4649302af1 100644 --- a/core/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java +++ b/core/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java @@ -31,6 +31,7 @@ import io.grpc.Attributes; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; +import io.grpc.Internal; import io.grpc.LoadBalancer; import io.grpc.NameResolver; import io.grpc.Status; @@ -50,7 +51,8 @@ import javax.annotation.Nonnull; * A {@link LoadBalancer} that provides round-robin load-balancing over the {@link * EquivalentAddressGroup}s from the {@link NameResolver}. */ -final class RoundRobinLoadBalancer extends LoadBalancer { +@Internal +public class RoundRobinLoadBalancer extends LoadBalancer { @VisibleForTesting static final Attributes.Key> STATE_INFO = Attributes.Key.create("state-info"); @@ -59,11 +61,10 @@ final class RoundRobinLoadBalancer extends LoadBalancer { private final Map subchannels = new HashMap<>(); private final Random random; - private ConnectivityState currentState; - private RoundRobinPicker currentPicker = new EmptyPicker(EMPTY_OK); + protected RoundRobinPicker currentPicker = new EmptyPicker(EMPTY_OK); - RoundRobinLoadBalancer(Helper helper) { + public RoundRobinLoadBalancer(Helper helper) { this.helper = checkNotNull(helper, "helper"); this.random = new Random(); } @@ -207,10 +208,7 @@ final class RoundRobinLoadBalancer extends LoadBalancer { // an arbitrary subchannel, otherwise return OK. new EmptyPicker(aggStatus)); } else { - // initialize the Picker to a random start index to ensure that a high frequency of Picker - // churn does not skew subchannel selection. - int startIndex = random.nextInt(activeList.size()); - updateBalancingState(READY, new ReadyPicker(activeList, startIndex)); + updateBalancingState(READY, createReadyPicker(activeList)); } } @@ -222,6 +220,13 @@ final class RoundRobinLoadBalancer extends LoadBalancer { } } + protected RoundRobinPicker createReadyPicker(List activeList) { + // initialize the Picker to a random start index to ensure that a high frequency of Picker + // churn does not skew subchannel selection. + int startIndex = random.nextInt(activeList.size()); + return new ReadyPicker(activeList, startIndex); + } + /** * Filters out non-ready subchannels. */ @@ -254,7 +259,7 @@ final class RoundRobinLoadBalancer extends LoadBalancer { } @VisibleForTesting - Collection getSubchannels() { + protected Collection getSubchannels() { return subchannels.values(); } @@ -275,12 +280,11 @@ final class RoundRobinLoadBalancer extends LoadBalancer { } // Only subclasses are ReadyPicker or EmptyPicker - private abstract static class RoundRobinPicker extends SubchannelPicker { - abstract boolean isEquivalentTo(RoundRobinPicker picker); + public abstract static class RoundRobinPicker extends SubchannelPicker { + public abstract boolean isEquivalentTo(RoundRobinPicker picker); } - @VisibleForTesting - static final class ReadyPicker extends RoundRobinPicker { + public static class ReadyPicker extends RoundRobinPicker { private static final AtomicIntegerFieldUpdater indexUpdater = AtomicIntegerFieldUpdater.newUpdater(ReadyPicker.class, "index"); @@ -288,7 +292,7 @@ final class RoundRobinLoadBalancer extends LoadBalancer { @SuppressWarnings("unused") private volatile int index; - ReadyPicker(List list, int startIndex) { + public ReadyPicker(List list, int startIndex) { Preconditions.checkArgument(!list.isEmpty(), "empty list"); this.list = list; this.index = startIndex - 1; @@ -321,7 +325,7 @@ final class RoundRobinLoadBalancer extends LoadBalancer { } @Override - boolean isEquivalentTo(RoundRobinPicker picker) { + public boolean isEquivalentTo(RoundRobinPicker picker) { if (!(picker instanceof ReadyPicker)) { return false; } @@ -332,8 +336,7 @@ final class RoundRobinLoadBalancer extends LoadBalancer { } } - @VisibleForTesting - static final class EmptyPicker extends RoundRobinPicker { + public static final class EmptyPicker extends RoundRobinPicker { private final Status status; @@ -347,7 +350,7 @@ final class RoundRobinLoadBalancer extends LoadBalancer { } @Override - boolean isEquivalentTo(RoundRobinPicker picker) { + public boolean isEquivalentTo(RoundRobinPicker picker) { return picker instanceof EmptyPicker && (Objects.equal(status, ((EmptyPicker) picker).status) || (status.isOk() && ((EmptyPicker) picker).status.isOk())); } diff --git a/repositories.bzl b/repositories.bzl index 6c586a69d2..3245427737 100644 --- a/repositories.bzl +++ b/repositories.bzl @@ -137,10 +137,10 @@ def grpc_java_repositories(): if not native.existing_rule("envoy_api"): http_archive( name = "envoy_api", - sha256 = "a0c58442cc2038ccccad9616dd1bab5ff1e65da2bbc0ae41020ef6010119eb0e", - strip_prefix = "data-plane-api-869b00336913138cad96a653458aab650c4e70ea", + sha256 = "74156c0d8738d0469f23047f0fd0f8846fdd0d59d7b55c76cd8cb9ebf2fa3a01", + strip_prefix = "data-plane-api-b1d2e441133c00bfe8412dfd6e93ea85e66da9bb", urls = [ - "https://github.com/envoyproxy/data-plane-api/archive/869b00336913138cad96a653458aab650c4e70ea.tar.gz", + "https://github.com/envoyproxy/data-plane-api/archive/b1d2e441133c00bfe8412dfd6e93ea85e66da9bb.tar.gz", ], ) diff --git a/xds/BUILD.bazel b/xds/BUILD.bazel index e62b183f9e..2d7e18daf1 100644 --- a/xds/BUILD.bazel +++ b/xds/BUILD.bazel @@ -32,6 +32,7 @@ java_library( ":envoy_service_load_stats_v3_java_grpc", ":envoy_service_status_v3_java_grpc", ":xds_protos_java", + ":orca", "//:auto_value_annotations", "//alts", "//api", @@ -40,6 +41,8 @@ java_library( "//core:util", "//netty", "//stub", + "//services:metrics", + "//services:metrics_internal", "@com_google_code_findbugs_jsr305//jar", "@com_google_code_gson_gson//jar", "@com_google_errorprone_error_prone_annotations//jar", @@ -83,6 +86,7 @@ java_proto_library( "@envoy_api//envoy/extensions/filters/http/rbac/v3:pkg", "@envoy_api//envoy/extensions/filters/http/router/v3:pkg", "@envoy_api//envoy/extensions/filters/network/http_connection_manager/v3:pkg", + "@envoy_api//envoy/extensions/load_balancing_policies/client_side_weighted_round_robin/v3:pkg", "@envoy_api//envoy/extensions/load_balancing_policies/least_request/v3:pkg", "@envoy_api//envoy/extensions/load_balancing_policies/ring_hash/v3:pkg", "@envoy_api//envoy/extensions/load_balancing_policies/round_robin/v3:pkg", diff --git a/xds/src/main/java/io/grpc/xds/LoadBalancerConfigFactory.java b/xds/src/main/java/io/grpc/xds/LoadBalancerConfigFactory.java index ce3e95f03d..4b919a4e6f 100644 --- a/xds/src/main/java/io/grpc/xds/LoadBalancerConfigFactory.java +++ b/xds/src/main/java/io/grpc/xds/LoadBalancerConfigFactory.java @@ -22,12 +22,14 @@ import com.google.common.collect.Iterables; import com.google.protobuf.Any; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Struct; +import com.google.protobuf.util.Durations; import com.google.protobuf.util.JsonFormat; import io.envoyproxy.envoy.config.cluster.v3.Cluster; import io.envoyproxy.envoy.config.cluster.v3.Cluster.LeastRequestLbConfig; import io.envoyproxy.envoy.config.cluster.v3.Cluster.RingHashLbConfig; import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy; import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy.Policy; +import io.envoyproxy.envoy.extensions.load_balancing_policies.client_side_weighted_round_robin.v3.ClientSideWeightedRoundRobin; import io.envoyproxy.envoy.extensions.load_balancing_policies.least_request.v3.LeastRequest; import io.envoyproxy.envoy.extensions.load_balancing_policies.ring_hash.v3.RingHash; import io.envoyproxy.envoy.extensions.load_balancing_policies.round_robin.v3.RoundRobin; @@ -73,6 +75,16 @@ class LoadBalancerConfigFactory { static final String WRR_LOCALITY_FIELD_NAME = "wrr_locality_experimental"; static final String CHILD_POLICY_FIELD = "childPolicy"; + static final String BLACK_OUT_PERIOD = "blackoutPeriod"; + + static final String WEIGHT_EXPIRATION_PERIOD = "weightExpirationPeriod"; + + static final String OOB_REPORTING_PERIOD = "oobReportingPeriod"; + + static final String ENABLE_OOB_LOAD_REPORT = "enableOobLoadReport"; + + static final String WEIGHT_UPDATE_PERIOD = "weightUpdatePeriod"; + /** * Factory method for creating a new {link LoadBalancerConfigConverter} for a given xDS {@link * Cluster}. @@ -80,14 +92,14 @@ class LoadBalancerConfigFactory { * @throws ResourceInvalidException If the {@link Cluster} has an invalid LB configuration. */ static ImmutableMap newConfig(Cluster cluster, boolean enableLeastRequest, - boolean enableCustomLbConfig) + boolean enableCustomLbConfig, boolean enableWrr) throws ResourceInvalidException { // The new load_balancing_policy will always be used if it is set, but for backward // compatibility we will fall back to using the old lb_policy field if the new field is not set. if (cluster.hasLoadBalancingPolicy() && enableCustomLbConfig) { try { return LoadBalancingPolicyConverter.convertToServiceConfig(cluster.getLoadBalancingPolicy(), - 0); + 0, enableWrr); } catch (MaxRecursionReachedException e) { throw new ResourceInvalidException("Maximum LB config recursion depth reached", e); } @@ -111,6 +123,35 @@ class LoadBalancerConfigFactory { return ImmutableMap.of(RING_HASH_FIELD_NAME, configBuilder.buildOrThrow()); } + /** + * Builds a service config JSON object for the weighted_round_robin load balancer config based on + * the given config values. + */ + private static ImmutableMap buildWrrConfig(String blackoutPeriod, + String weightExpirationPeriod, + String oobReportingPeriod, + Boolean enableOobLoadReport, + String weightUpdatePeriod) { + ImmutableMap.Builder configBuilder = ImmutableMap.builder(); + if (blackoutPeriod != null) { + configBuilder.put(BLACK_OUT_PERIOD, blackoutPeriod); + } + if (weightExpirationPeriod != null) { + configBuilder.put(WEIGHT_EXPIRATION_PERIOD, weightExpirationPeriod); + } + if (oobReportingPeriod != null) { + configBuilder.put(OOB_REPORTING_PERIOD, oobReportingPeriod); + } + if (enableOobLoadReport != null) { + configBuilder.put(ENABLE_OOB_LOAD_REPORT, enableOobLoadReport); + } + if (weightUpdatePeriod != null) { + configBuilder.put(WEIGHT_UPDATE_PERIOD, weightUpdatePeriod); + } + return ImmutableMap.of(WeightedRoundRobinLoadBalancerProvider.SCHEME, + configBuilder.buildOrThrow()); + } + /** * Builds a service config JSON object for the least_request load balancer config based on the * given config values.. @@ -151,7 +192,7 @@ class LoadBalancerConfigFactory { * Converts a {@link LoadBalancingPolicy} object to a service config JSON object. */ private static ImmutableMap convertToServiceConfig( - LoadBalancingPolicy loadBalancingPolicy, int recursionDepth) + LoadBalancingPolicy loadBalancingPolicy, int recursionDepth, boolean enableWrr) throws ResourceInvalidException, MaxRecursionReachedException { if (recursionDepth > MAX_RECURSION) { throw new MaxRecursionReachedException(); @@ -165,11 +206,16 @@ class LoadBalancerConfigFactory { serviceConfig = convertRingHashConfig(typedConfig.unpack(RingHash.class)); } else if (typedConfig.is(WrrLocality.class)) { serviceConfig = convertWrrLocalityConfig(typedConfig.unpack(WrrLocality.class), - recursionDepth); + recursionDepth, enableWrr); } else if (typedConfig.is(RoundRobin.class)) { serviceConfig = convertRoundRobinConfig(); } else if (typedConfig.is(LeastRequest.class)) { serviceConfig = convertLeastRequestConfig(typedConfig.unpack(LeastRequest.class)); + } else if (typedConfig.is(ClientSideWeightedRoundRobin.class)) { + if (enableWrr) { + serviceConfig = convertWeightedRoundRobinConfig( + typedConfig.unpack(ClientSideWeightedRoundRobin.class)); + } } else if (typedConfig.is(com.github.xds.type.v3.TypedStruct.class)) { serviceConfig = convertCustomConfig( typedConfig.unpack(com.github.xds.type.v3.TypedStruct.class)); @@ -217,14 +263,31 @@ class LoadBalancerConfigFactory { ringHash.hasMaximumRingSize() ? ringHash.getMaximumRingSize().getValue() : null); } + private static ImmutableMap convertWeightedRoundRobinConfig( + ClientSideWeightedRoundRobin wrr) throws ResourceInvalidException { + try { + return buildWrrConfig( + wrr.hasBlackoutPeriod() ? Durations.toString(wrr.getBlackoutPeriod()) : null, + wrr.hasWeightExpirationPeriod() + ? Durations.toString(wrr.getWeightExpirationPeriod()) : null, + wrr.hasOobReportingPeriod() ? Durations.toString(wrr.getOobReportingPeriod()) : null, + wrr.hasEnableOobLoadReport() ? wrr.getEnableOobLoadReport().getValue() : null, + wrr.hasWeightUpdatePeriod() ? Durations.toString(wrr.getWeightUpdatePeriod()) : null); + } catch (IllegalArgumentException ex) { + throw new ResourceInvalidException("Invalid duration in weighted round robin config: " + + ex.getMessage()); + } + } + /** * Converts a wrr_locality {@link Any} configuration to service config format. */ private static ImmutableMap convertWrrLocalityConfig(WrrLocality wrrLocality, - int recursionDepth) throws ResourceInvalidException, + int recursionDepth, boolean enableWrr) throws ResourceInvalidException, MaxRecursionReachedException { return buildWrrLocalityConfig( - convertToServiceConfig(wrrLocality.getEndpointPickingPolicy(), recursionDepth + 1)); + convertToServiceConfig(wrrLocality.getEndpointPickingPolicy(), + recursionDepth + 1, enableWrr)); } /** diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java new file mode 100644 index 0000000000..60804fec7b --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java @@ -0,0 +1,471 @@ +/* + * Copyright 2023 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.MoreObjects; +import com.google.common.base.Preconditions; +import io.grpc.ConnectivityState; +import io.grpc.ConnectivityStateInfo; +import io.grpc.Deadline.Ticker; +import io.grpc.EquivalentAddressGroup; +import io.grpc.ExperimentalApi; +import io.grpc.LoadBalancer; +import io.grpc.NameResolver; +import io.grpc.Status; +import io.grpc.SynchronizationContext; +import io.grpc.SynchronizationContext.ScheduledHandle; +import io.grpc.services.MetricReport; +import io.grpc.util.ForwardingLoadBalancerHelper; +import io.grpc.util.ForwardingSubchannel; +import io.grpc.util.RoundRobinLoadBalancer; +import io.grpc.xds.orca.OrcaOobUtil; +import io.grpc.xds.orca.OrcaOobUtil.OrcaOobReportListener; +import io.grpc.xds.orca.OrcaPerRequestUtil; +import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener; +import java.util.HashSet; +import java.util.List; +import java.util.PriorityQueue; +import java.util.Random; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +/** + * A {@link LoadBalancer} that provides weighted-round-robin load-balancing over + * the {@link EquivalentAddressGroup}s from the {@link NameResolver}. The subchannel weights are + * determined by backend metrics using ORCA. + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/9885") +final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { + private volatile WeightedRoundRobinLoadBalancerConfig config; + private final SynchronizationContext syncContext; + private final ScheduledExecutorService timeService; + private ScheduledHandle weightUpdateTimer; + private final Runnable updateWeightTask; + private final Random random; + private final long infTime; + private final Ticker ticker; + + public WeightedRoundRobinLoadBalancer(Helper helper, Ticker ticker) { + this(new WrrHelper(OrcaOobUtil.newOrcaReportingHelper(helper)), ticker); + } + + public WeightedRoundRobinLoadBalancer(WrrHelper helper, Ticker ticker) { + super(helper); + helper.setLoadBalancer(this); + this.ticker = checkNotNull(ticker, "ticker"); + this.infTime = ticker.nanoTime() + Long.MAX_VALUE; + this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); + this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService"); + this.updateWeightTask = new UpdateWeightTask(); + this.random = new Random(); + } + + @Override + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + if (resolvedAddresses.getLoadBalancingPolicyConfig() == null) { + handleNameResolutionError(Status.UNAVAILABLE.withDescription( + "NameResolver returned no WeightedRoundRobinLoadBalancerConfig. addrs=" + + resolvedAddresses.getAddresses() + + ", attrs=" + resolvedAddresses.getAttributes())); + return false; + } + config = + (WeightedRoundRobinLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); + boolean accepted = super.acceptResolvedAddresses(resolvedAddresses); + if (weightUpdateTimer != null && weightUpdateTimer.isPending()) { + weightUpdateTimer.cancel(); + } + updateWeightTask.run(); + afterAcceptAddresses(); + return accepted; + } + + @Override + public RoundRobinPicker createReadyPicker(List activeList) { + int startIndex = random.nextInt(activeList.size()); + return new WeightedRoundRobinPicker(activeList, startIndex); + } + + private final class UpdateWeightTask implements Runnable { + @Override + public void run() { + if (currentPicker != null && currentPicker instanceof WeightedRoundRobinPicker) { + ((WeightedRoundRobinPicker)currentPicker).updateWeight(); + } + weightUpdateTimer = syncContext.schedule(this, config.weightUpdatePeriodNanos, + TimeUnit.NANOSECONDS, timeService); + } + } + + private void afterAcceptAddresses() { + for (Subchannel subchannel : getSubchannels()) { + WrrSubchannel weightedSubchannel = (WrrSubchannel) subchannel; + if (config.enableOobLoadReport) { + OrcaOobUtil.setListener(weightedSubchannel, weightedSubchannel.oobListener, + OrcaOobUtil.OrcaReportingConfig.newBuilder() + .setReportInterval(config.oobReportingPeriodNanos, TimeUnit.NANOSECONDS) + .build()); + } else { + OrcaOobUtil.setListener(weightedSubchannel, null, null); + } + } + } + + @Override + public void shutdown() { + if (weightUpdateTimer != null) { + weightUpdateTimer.cancel(); + } + super.shutdown(); + } + + private static final class WrrHelper extends ForwardingLoadBalancerHelper { + private final Helper delegate; + private WeightedRoundRobinLoadBalancer wrr; + + WrrHelper(Helper helper) { + this.delegate = helper; + } + + void setLoadBalancer(WeightedRoundRobinLoadBalancer lb) { + this.wrr = lb; + } + + @Override + protected Helper delegate() { + return delegate; + } + + @Override + public Subchannel createSubchannel(CreateSubchannelArgs args) { + return wrr.new WrrSubchannel(delegate().createSubchannel(args)); + } + } + + @VisibleForTesting + final class WrrSubchannel extends ForwardingSubchannel { + private final Subchannel delegate; + private final OrcaOobReportListener oobListener = this::onLoadReport; + private final OrcaPerRequestReportListener perRpcListener = this::onLoadReport; + private volatile long lastUpdated; + private volatile long nonEmptySince; + private volatile double weight; + + WrrSubchannel(Subchannel delegate) { + this.delegate = checkNotNull(delegate, "delegate"); + } + + @VisibleForTesting + void onLoadReport(MetricReport report) { + double newWeight = report.getCpuUtilization() == 0 ? 0 : + report.getQps() / report.getCpuUtilization(); + if (newWeight == 0) { + return; + } + if (nonEmptySince == infTime) { + nonEmptySince = ticker.nanoTime(); + } + lastUpdated = ticker.nanoTime(); + weight = newWeight; + } + + @Override + public void start(SubchannelStateListener listener) { + delegate().start(new SubchannelStateListener() { + @Override + public void onSubchannelState(ConnectivityStateInfo newState) { + if (newState.getState().equals(ConnectivityState.READY)) { + nonEmptySince = infTime; + } + listener.onSubchannelState(newState); + } + }); + } + + private double getWeight() { + if (config == null) { + return 0; + } + long now = ticker.nanoTime(); + if (now - lastUpdated >= config.weightExpirationPeriodNanos) { + nonEmptySince = infTime; + return 0; + } else if (now - nonEmptySince < config.blackoutPeriodNanos + && config.blackoutPeriodNanos > 0) { + return 0; + } else { + return weight; + } + } + + @Override + protected Subchannel delegate() { + return delegate; + } + } + + @VisibleForTesting + final class WeightedRoundRobinPicker extends ReadyPicker { + private final List list; + private volatile EdfScheduler scheduler; + private volatile boolean rrMode; + + WeightedRoundRobinPicker(List list, int startIndex) { + super(checkNotNull(list, "list"), startIndex); + Preconditions.checkArgument(!list.isEmpty(), "empty list"); + this.list = list; + updateWeight(); + } + + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + if (rrMode) { + return super.pickSubchannel(args); + } + int pickIndex = scheduler.pick(); + WrrSubchannel subchannel = (WrrSubchannel) list.get(pickIndex); + if (!config.enableOobLoadReport) { + return PickResult.withSubchannel( + subchannel, + OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( + subchannel.perRpcListener)); + } else { + return PickResult.withSubchannel(subchannel); + } + } + + private void updateWeight() { + int weightedChannelCount = 0; + double avgWeight = 0; + for (Subchannel value : list) { + double newWeight = ((WrrSubchannel) value).getWeight(); + if (newWeight > 0) { + avgWeight += newWeight; + weightedChannelCount++; + } + } + if (weightedChannelCount < 2) { + rrMode = true; + return; + } + EdfScheduler scheduler = new EdfScheduler(list.size()); + avgWeight /= 1.0 * weightedChannelCount; + for (int i = 0; i < list.size(); i++) { + WrrSubchannel subchannel = (WrrSubchannel) list.get(i); + double newWeight = subchannel.getWeight(); + scheduler.add(i, newWeight > 0 ? newWeight : avgWeight); + } + this.scheduler = scheduler; + rrMode = false; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(WeightedRoundRobinPicker.class) + .add("list", list).add("rrMode", rrMode).toString(); + } + + @VisibleForTesting + List getList() { + return list; + } + + @Override + public boolean isEquivalentTo(RoundRobinPicker picker) { + if (!(picker instanceof WeightedRoundRobinPicker)) { + return false; + } + WeightedRoundRobinPicker other = (WeightedRoundRobinPicker) picker; + // the lists cannot contain duplicate subchannels + return other == this + || (list.size() == other.list.size() && new HashSet<>(list).containsAll(other.list)); + } + } + + /** + * The earliest deadline first implementation in which each object is + * chosen deterministically and periodically with frequency proportional to its weight. + * + *

Specifically, each object added to chooser is given a deadline equal to the multiplicative + * inverse of its weight. The place of each object in its deadline is tracked, and each call to + * choose returns the object with the least remaining time in its deadline. + * (Ties are broken by the order in which the children were added to the chooser.) The deadline + * advances by the multiplicative inverse of the object's weight. + * For example, if items A and B are added with weights 0.5 and 0.2, successive chooses return: + * + *

    + *
  • In the first call, the deadlines are A=2 (1/0.5) and B=5 (1/0.2), so A is returned. + * The deadline of A is updated to 4. + *
  • Next, the remaining deadlines are A=4 and B=5, so A is returned. The deadline of A (2) is + * updated to A=6. + *
  • Remaining deadlines are A=6 and B=5, so B is returned. The deadline of B is updated with + * with B=10. + *
  • Remaining deadlines are A=6 and B=10, so A is returned. The deadline of A is updated with + * A=8. + *
  • Remaining deadlines are A=8 and B=10, so A is returned. The deadline of A is updated with + * A=10. + *
  • Remaining deadlines are A=10 and B=10, so A is returned. The deadline of A is updated + * with A=12. + *
  • Remaining deadlines are A=12 and B=10, so B is returned. The deadline of B is updated + * with B=15. + *
  • etc. + *
+ * + *

In short: the entry with the highest weight is preferred. + * + *

    + *
  • add() - O(lg n) + *
  • pick() - O(lg n) + *
+ * + */ + @VisibleForTesting + static final class EdfScheduler { + private final PriorityQueue prioQueue; + + /** + * Weights below this value will be upped to this minimum weight. + */ + private static final double MINIMUM_WEIGHT = 0.0001; + + private final Object lock = new Object(); + + /** + * Use the item's deadline as the order in the priority queue. If the deadlines are the same, + * use the index. Index should be unique. + */ + EdfScheduler(int initialCapacity) { + this.prioQueue = new PriorityQueue(initialCapacity, (o1, o2) -> { + if (o1.deadline == o2.deadline) { + return Integer.compare(o1.index, o2.index); + } else { + return Double.compare(o1.deadline, o2.deadline); + } + }); + } + + /** + * Adds the item in the scheduler. This is not thread safe. + * + * @param index The field {@link ObjectState#index} to be added + * @param weight positive weight for the added object + */ + void add(int index, double weight) { + checkArgument(weight > 0.0, "Weights need to be positive."); + ObjectState state = new ObjectState(Math.max(weight, MINIMUM_WEIGHT), index); + state.deadline = 1 / state.weight; + // TODO(zivy): randomize the initial deadline. + prioQueue.add(state); + } + + /** + * Picks the next WRR object. + */ + int pick() { + synchronized (lock) { + ObjectState minObject = prioQueue.remove(); + minObject.deadline += 1.0 / minObject.weight; + prioQueue.add(minObject); + return minObject.index; + } + } + } + + /** Holds the state of the object. */ + @VisibleForTesting + static class ObjectState { + private final double weight; + private final int index; + private volatile double deadline; + + ObjectState(double weight, int index) { + this.weight = weight; + this.index = index; + } + } + + static final class WeightedRoundRobinLoadBalancerConfig { + final long blackoutPeriodNanos; + final long weightExpirationPeriodNanos; + final boolean enableOobLoadReport; + final long oobReportingPeriodNanos; + final long weightUpdatePeriodNanos; + + public static Builder newBuilder() { + return new Builder(); + } + + private WeightedRoundRobinLoadBalancerConfig(long blackoutPeriodNanos, + long weightExpirationPeriodNanos, + boolean enableOobLoadReport, + long oobReportingPeriodNanos, + long weightUpdatePeriodNanos) { + this.blackoutPeriodNanos = blackoutPeriodNanos; + this.weightExpirationPeriodNanos = weightExpirationPeriodNanos; + this.enableOobLoadReport = enableOobLoadReport; + this.oobReportingPeriodNanos = oobReportingPeriodNanos; + this.weightUpdatePeriodNanos = weightUpdatePeriodNanos; + } + + static final class Builder { + long blackoutPeriodNanos = 10_000_000_000L; // 10s + long weightExpirationPeriodNanos = 180_000_000_000L; //3min + boolean enableOobLoadReport = false; + long oobReportingPeriodNanos = 10_000_000_000L; // 10s + long weightUpdatePeriodNanos = 1_000_000_000L; // 1s + + private Builder() { + + } + + Builder setBlackoutPeriodNanos(long blackoutPeriodNanos) { + this.blackoutPeriodNanos = blackoutPeriodNanos; + return this; + } + + Builder setWeightExpirationPeriodNanos(long weightExpirationPeriodNanos) { + this.weightExpirationPeriodNanos = weightExpirationPeriodNanos; + return this; + } + + Builder setEnableOobLoadReport(boolean enableOobLoadReport) { + this.enableOobLoadReport = enableOobLoadReport; + return this; + } + + Builder setOobReportingPeriodNanos(long oobReportingPeriodNanos) { + this.oobReportingPeriodNanos = oobReportingPeriodNanos; + return this; + } + + Builder setWeightUpdatePeriodNanos(long weightUpdatePeriodNanos) { + this.weightUpdatePeriodNanos = weightUpdatePeriodNanos; + return this; + } + + WeightedRoundRobinLoadBalancerConfig build() { + return new WeightedRoundRobinLoadBalancerConfig(blackoutPeriodNanos, + weightExpirationPeriodNanos, enableOobLoadReport, oobReportingPeriodNanos, + weightUpdatePeriodNanos); + } + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProvider.java new file mode 100644 index 0000000000..b1d16d3904 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProvider.java @@ -0,0 +1,94 @@ +/* + * Copyright 2023 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import com.google.common.annotations.VisibleForTesting; +import io.grpc.Deadline; +import io.grpc.ExperimentalApi; +import io.grpc.Internal; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancerProvider; +import io.grpc.NameResolver.ConfigOrError; +import io.grpc.internal.JsonUtil; +import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig; +import java.util.Map; + +/** + * Provides a {@link WeightedRoundRobinLoadBalancer}. + * */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/9885") +@Internal +public final class WeightedRoundRobinLoadBalancerProvider extends LoadBalancerProvider { + + @VisibleForTesting + static final long MIN_WEIGHT_UPDATE_PERIOD_NANOS = 100_000_000L; // 100ms + + static final String SCHEME = "weighted_round_robin_experimental"; + + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + return new WeightedRoundRobinLoadBalancer(helper, Deadline.getSystemTicker()); + } + + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 5; + } + + @Override + public String getPolicyName() { + return SCHEME; + } + + @Override + public ConfigOrError parseLoadBalancingPolicyConfig(Map rawConfig) { + Long blackoutPeriodNanos = JsonUtil.getStringAsDuration(rawConfig, "blackoutPeriod"); + Long weightExpirationPeriodNanos = + JsonUtil.getStringAsDuration(rawConfig, "weightExpirationPeriod"); + Long oobReportingPeriodNanos = JsonUtil.getStringAsDuration(rawConfig, "oobReportingPeriod"); + Boolean enableOobLoadReport = JsonUtil.getBoolean(rawConfig, "enableOobLoadReport"); + Long weightUpdatePeriodNanos = JsonUtil.getStringAsDuration(rawConfig, "weightUpdatePeriod"); + + WeightedRoundRobinLoadBalancerConfig.Builder configBuilder = + WeightedRoundRobinLoadBalancerConfig.newBuilder(); + if (blackoutPeriodNanos != null) { + configBuilder.setBlackoutPeriodNanos(blackoutPeriodNanos); + } + if (weightExpirationPeriodNanos != null) { + configBuilder.setWeightExpirationPeriodNanos(weightExpirationPeriodNanos); + } + if (enableOobLoadReport != null) { + configBuilder.setEnableOobLoadReport(enableOobLoadReport); + } + if (oobReportingPeriodNanos != null) { + configBuilder.setOobReportingPeriodNanos(oobReportingPeriodNanos); + } + if (weightUpdatePeriodNanos != null) { + configBuilder.setWeightUpdatePeriodNanos(weightUpdatePeriodNanos); + if (weightUpdatePeriodNanos < MIN_WEIGHT_UPDATE_PERIOD_NANOS) { + configBuilder.setWeightUpdatePeriodNanos(MIN_WEIGHT_UPDATE_PERIOD_NANOS); + } + } + return ConfigOrError.fromConfig(configBuilder.build()); + } +} diff --git a/xds/src/main/java/io/grpc/xds/XdsClusterResource.java b/xds/src/main/java/io/grpc/xds/XdsClusterResource.java index 33f6176474..1dc59feb8b 100644 --- a/xds/src/main/java/io/grpc/xds/XdsClusterResource.java +++ b/xds/src/main/java/io/grpc/xds/XdsClusterResource.java @@ -133,7 +133,7 @@ class XdsClusterResource extends XdsResourceType { CdsUpdate.Builder updateBuilder = structOrError.getStruct(); ImmutableMap lbPolicyConfig = LoadBalancerConfigFactory.newConfig(cluster, - enableLeastRequest, enableCustomLbConfig); + enableLeastRequest, enableCustomLbConfig, enableWrr); // Validate the LB config by trying to parse it with the corresponding LB provider. LbConfig lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(lbPolicyConfig); diff --git a/xds/src/main/java/io/grpc/xds/XdsResourceType.java b/xds/src/main/java/io/grpc/xds/XdsResourceType.java index 1302f5a59e..4c19ebf776 100644 --- a/xds/src/main/java/io/grpc/xds/XdsResourceType.java +++ b/xds/src/main/java/io/grpc/xds/XdsResourceType.java @@ -59,6 +59,10 @@ abstract class XdsResourceType { !Strings.isNullOrEmpty(System.getenv("GRPC_EXPERIMENTAL_ENABLE_LEAST_REQUEST")) ? Boolean.parseBoolean(System.getenv("GRPC_EXPERIMENTAL_ENABLE_LEAST_REQUEST")) : Boolean.parseBoolean(System.getProperty("io.grpc.xds.experimentalEnableLeastRequest")); + + @VisibleForTesting + static boolean enableWrr = getFlag("GRPC_EXPERIMENTAL_XDS_WRR_LB", false); + @VisibleForTesting static boolean enableCustomLbConfig = getFlag("GRPC_EXPERIMENTAL_XDS_CUSTOM_LB_CONFIG", true); @VisibleForTesting diff --git a/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider b/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider index 6b6e3a392a..e1c4d4aa42 100644 --- a/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider +++ b/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider @@ -7,3 +7,4 @@ io.grpc.xds.ClusterImplLoadBalancerProvider io.grpc.xds.LeastRequestLoadBalancerProvider io.grpc.xds.RingHashLoadBalancerProvider io.grpc.xds.WrrLocalityLoadBalancerProvider +io.grpc.xds.WeightedRoundRobinLoadBalancerProvider diff --git a/xds/src/test/java/io/grpc/xds/LoadBalancerConfigFactoryTest.java b/xds/src/test/java/io/grpc/xds/LoadBalancerConfigFactoryTest.java index c7217cb45e..884f04b2f2 100644 --- a/xds/src/test/java/io/grpc/xds/LoadBalancerConfigFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/LoadBalancerConfigFactoryTest.java @@ -24,6 +24,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.protobuf.Any; +import com.google.protobuf.BoolValue; +import com.google.protobuf.Duration; import com.google.protobuf.Struct; import com.google.protobuf.UInt32Value; import com.google.protobuf.UInt64Value; @@ -36,6 +38,7 @@ import io.envoyproxy.envoy.config.cluster.v3.Cluster.RingHashLbConfig.HashFuncti import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy; import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy.Policy; import io.envoyproxy.envoy.config.core.v3.TypedExtensionConfig; +import io.envoyproxy.envoy.extensions.load_balancing_policies.client_side_weighted_round_robin.v3.ClientSideWeightedRoundRobin; import io.envoyproxy.envoy.extensions.load_balancing_policies.least_request.v3.LeastRequest; import io.envoyproxy.envoy.extensions.load_balancing_policies.ring_hash.v3.RingHash; import io.envoyproxy.envoy.extensions.load_balancing_policies.round_robin.v3.RoundRobin; @@ -78,6 +81,17 @@ public class LoadBalancerConfigFactoryTest { LeastRequest.newBuilder().setChoiceCount(UInt32Value.of(LEAST_REQUEST_CHOICE_COUNT)) .build()))).build(); + private static final Policy WRR_POLICY = Policy.newBuilder() + .setTypedExtensionConfig(TypedExtensionConfig.newBuilder() + .setName("backend") + .setTypedConfig( + Any.pack(ClientSideWeightedRoundRobin.newBuilder() + .setBlackoutPeriod(Duration.newBuilder().setSeconds(287).build()) + .setEnableOobLoadReport( + BoolValue.newBuilder().setValue(true).build()) + .build())) + .build()) + .build(); private static final String CUSTOM_POLICY_NAME = "myorg.MyCustomLeastRequestPolicy"; private static final String CUSTOM_POLICY_FIELD_KEY = "choiceCount"; private static final double CUSTOM_POLICY_FIELD_VALUE = 2; @@ -101,6 +115,11 @@ public class LoadBalancerConfigFactoryTest { private static final LbConfig VALID_ROUND_ROBIN_CONFIG = new LbConfig("wrr_locality_experimental", ImmutableMap.of("childPolicy", ImmutableList.of(ImmutableMap.of("round_robin", ImmutableMap.of())))); + + private static final LbConfig VALID_WRR_CONFIG = new LbConfig("wrr_locality_experimental", + ImmutableMap.of("childPolicy", ImmutableList.of( + ImmutableMap.of("weighted_round_robin_experimental", + ImmutableMap.of("blackoutPeriod","287s", "enableOobLoadReport", true ))))); private static final LbConfig VALID_RING_HASH_CONFIG = new LbConfig("ring_hash_experimental", ImmutableMap.of("minRingSize", (double) RING_HASH_MIN_RING_SIZE, "maxRingSize", (double) RING_HASH_MAX_RING_SIZE)); @@ -123,14 +142,46 @@ public class LoadBalancerConfigFactoryTest { public void roundRobin() throws ResourceInvalidException { Cluster cluster = newCluster(buildWrrPolicy(ROUND_ROBIN_POLICY)); - assertThat(newLbConfig(cluster, true, true)).isEqualTo(VALID_ROUND_ROBIN_CONFIG); + assertThat(newLbConfig(cluster, true, true, true)).isEqualTo(VALID_ROUND_ROBIN_CONFIG); + } + + @Test + public void weightedRoundRobin() throws ResourceInvalidException { + Cluster cluster = newCluster(buildWrrPolicy(WRR_POLICY)); + + assertThat(newLbConfig(cluster, true, true, true)).isEqualTo(VALID_WRR_CONFIG); + } + + @Test + public void weightedRoundRobin_invalid() throws ResourceInvalidException { + Cluster cluster = newCluster(buildWrrPolicy(Policy.newBuilder() + .setTypedExtensionConfig(TypedExtensionConfig.newBuilder() + .setName("backend") + .setTypedConfig( + Any.pack(ClientSideWeightedRoundRobin.newBuilder() + .setBlackoutPeriod(Duration.newBuilder().setNanos(1000000000).build()) + .setEnableOobLoadReport( + BoolValue.newBuilder().setValue(true).build()) + .build())) + .build()) + .build())); + + assertResourceInvalidExceptionThrown(cluster, true, true, true, + "Invalid duration in weighted round robin config"); + } + + @Test + public void weightedRoundRobin_fallback_roundrobin() throws ResourceInvalidException { + Cluster cluster = newCluster(buildWrrPolicy(WRR_POLICY, ROUND_ROBIN_POLICY)); + + assertThat(newLbConfig(cluster, true, true, false)).isEqualTo(VALID_ROUND_ROBIN_CONFIG); } @Test public void roundRobin_legacy() throws ResourceInvalidException { Cluster cluster = Cluster.newBuilder().setLbPolicy(LbPolicy.ROUND_ROBIN).build(); - assertThat(newLbConfig(cluster, true, true)).isEqualTo(VALID_ROUND_ROBIN_CONFIG); + assertThat(newLbConfig(cluster, true, true, true)).isEqualTo(VALID_ROUND_ROBIN_CONFIG); } @Test @@ -139,7 +190,7 @@ public class LoadBalancerConfigFactoryTest { .setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder().addPolicies(RING_HASH_POLICY)) .build(); - assertThat(newLbConfig(cluster, true, true)).isEqualTo(VALID_RING_HASH_CONFIG); + assertThat(newLbConfig(cluster, true, true, true)).isEqualTo(VALID_RING_HASH_CONFIG); } @Test @@ -149,7 +200,7 @@ public class LoadBalancerConfigFactoryTest { .setMaximumRingSize(UInt64Value.of(RING_HASH_MAX_RING_SIZE)) .setHashFunction(HashFunction.XX_HASH)).build(); - assertThat(newLbConfig(cluster, true, true)).isEqualTo(VALID_RING_HASH_CONFIG); + assertThat(newLbConfig(cluster, true, true, true)).isEqualTo(VALID_RING_HASH_CONFIG); } @Test @@ -161,7 +212,7 @@ public class LoadBalancerConfigFactoryTest { .setMaximumRingSize(UInt64Value.of(RING_HASH_MAX_RING_SIZE)) .setHashFunction(RingHash.HashFunction.MURMUR_HASH_2).build()))).build()); - assertResourceInvalidExceptionThrown(cluster, true, true, "Invalid ring hash function"); + assertResourceInvalidExceptionThrown(cluster, true, true, true, "Invalid ring hash function"); } @Test @@ -169,7 +220,7 @@ public class LoadBalancerConfigFactoryTest { Cluster cluster = Cluster.newBuilder().setLbPolicy(LbPolicy.RING_HASH).setRingHashLbConfig( RingHashLbConfig.newBuilder().setHashFunction(HashFunction.MURMUR_HASH_2)).build(); - assertResourceInvalidExceptionThrown(cluster, true, true, "invalid ring hash function"); + assertResourceInvalidExceptionThrown(cluster, true, true, true, "invalid ring hash function"); } @Test @@ -178,7 +229,7 @@ public class LoadBalancerConfigFactoryTest { .setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder().addPolicies(LEAST_REQUEST_POLICY)) .build(); - assertThat(newLbConfig(cluster, true, true)).isEqualTo(VALID_LEAST_REQUEST_CONFIG); + assertThat(newLbConfig(cluster, true, true, true)).isEqualTo(VALID_LEAST_REQUEST_CONFIG); } @Test @@ -190,7 +241,7 @@ public class LoadBalancerConfigFactoryTest { LeastRequestLbConfig.newBuilder() .setChoiceCount(UInt32Value.of(LEAST_REQUEST_CHOICE_COUNT))).build(); - LbConfig lbConfig = newLbConfig(cluster, true, true); + LbConfig lbConfig = newLbConfig(cluster, true, true, true); assertThat(lbConfig.getPolicyName()).isEqualTo("wrr_locality_experimental"); List childConfigs = ServiceConfigUtil.unwrapLoadBalancingConfigList( @@ -207,14 +258,15 @@ public class LoadBalancerConfigFactoryTest { Cluster cluster = Cluster.newBuilder().setLbPolicy(LbPolicy.LEAST_REQUEST).build(); - assertResourceInvalidExceptionThrown(cluster, false, true, "unsupported lb policy"); + assertResourceInvalidExceptionThrown(cluster, false, true, true, "unsupported lb policy"); } @Test public void customRootLb_providerRegistered() throws ResourceInvalidException { LoadBalancerRegistry.getDefaultRegistry().register(CUSTOM_POLICY_PROVIDER); - assertThat(newLbConfig(newCluster(CUSTOM_POLICY), false, true)).isEqualTo(VALID_CUSTOM_CONFIG); + assertThat(newLbConfig(newCluster(CUSTOM_POLICY), false, true, + true)).isEqualTo(VALID_CUSTOM_CONFIG); } @Test @@ -223,7 +275,7 @@ public class LoadBalancerConfigFactoryTest { .setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder().addPolicies(CUSTOM_POLICY)) .build(); - assertResourceInvalidExceptionThrown(cluster, false, true, "Invalid LoadBalancingPolicy"); + assertResourceInvalidExceptionThrown(cluster, false, true, true,"Invalid LoadBalancingPolicy"); } // When a provider for the endpoint picking custom policy is available, the configuration should @@ -235,7 +287,7 @@ public class LoadBalancerConfigFactoryTest { Cluster cluster = Cluster.newBuilder().setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder() .addPolicies(buildWrrPolicy(CUSTOM_POLICY, ROUND_ROBIN_POLICY))).build(); - assertThat(newLbConfig(cluster, false, true)).isEqualTo(VALID_CUSTOM_CONFIG_IN_WRR); + assertThat(newLbConfig(cluster, false, true, true)).isEqualTo(VALID_CUSTOM_CONFIG_IN_WRR); } // When a provider for the endpoint picking custom policy is available, the configuration should @@ -247,7 +299,7 @@ public class LoadBalancerConfigFactoryTest { Cluster cluster = Cluster.newBuilder().setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder() .addPolicies(buildWrrPolicy(CUSTOM_POLICY_UDPA, ROUND_ROBIN_POLICY))).build(); - assertThat(newLbConfig(cluster, false, true)).isEqualTo(VALID_CUSTOM_CONFIG_IN_WRR); + assertThat(newLbConfig(cluster, false, true, true)).isEqualTo(VALID_CUSTOM_CONFIG_IN_WRR); } // When a provider for the custom wrr_locality child policy is NOT available, we should fall back @@ -257,7 +309,7 @@ public class LoadBalancerConfigFactoryTest { Cluster cluster = Cluster.newBuilder().setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder() .addPolicies(buildWrrPolicy(CUSTOM_POLICY, ROUND_ROBIN_POLICY))).build(); - assertThat(newLbConfig(cluster, false, true)).isEqualTo(VALID_ROUND_ROBIN_CONFIG); + assertThat(newLbConfig(cluster, false, true, true)).isEqualTo(VALID_ROUND_ROBIN_CONFIG); } // When a provider for the custom wrr_locality child policy is NOT available and no alternative @@ -267,7 +319,7 @@ public class LoadBalancerConfigFactoryTest { Cluster cluster = Cluster.newBuilder().setLoadBalancingPolicy( LoadBalancingPolicy.newBuilder().addPolicies(buildWrrPolicy(CUSTOM_POLICY))).build(); - assertResourceInvalidExceptionThrown(cluster, false, true, "Invalid LoadBalancingPolicy"); + assertResourceInvalidExceptionThrown(cluster, false, true, true, "Invalid LoadBalancingPolicy"); } @Test @@ -278,7 +330,7 @@ public class LoadBalancerConfigFactoryTest { .build(); // Custom LB flag not set, so we use old logic that will default to round_robin. - assertThat(newLbConfig(cluster, true, false)).isEqualTo(VALID_ROUND_ROBIN_CONFIG); + assertThat(newLbConfig(cluster, true, false, true)).isEqualTo(VALID_ROUND_ROBIN_CONFIG); } @Test @@ -305,7 +357,7 @@ public class LoadBalancerConfigFactoryTest { buildWrrPolicy( ROUND_ROBIN_POLICY))))))))))))))))))).build(); - assertResourceInvalidExceptionThrown(cluster, false, true, + assertResourceInvalidExceptionThrown(cluster, false, true, true, "Maximum LB config recursion depth reached"); } @@ -322,16 +374,17 @@ public class LoadBalancerConfigFactoryTest { } private LbConfig newLbConfig(Cluster cluster, boolean enableLeastRequest, - boolean enableCustomConfig) + boolean enableCustomConfig, boolean enableWrr) throws ResourceInvalidException { return ServiceConfigUtil.unwrapLoadBalancingConfig( - LoadBalancerConfigFactory.newConfig(cluster, enableLeastRequest, enableCustomConfig)); + LoadBalancerConfigFactory.newConfig(cluster, enableLeastRequest, enableCustomConfig, + enableWrr)); } private void assertResourceInvalidExceptionThrown(Cluster cluster, boolean enableLeastRequest, - boolean enableCustomConfig, String expectedMessage) { + boolean enableCustomConfig, boolean enableWrr, String expectedMessage) { try { - newLbConfig(cluster, enableLeastRequest, enableCustomConfig); + newLbConfig(cluster, enableLeastRequest, enableCustomConfig, enableWrr); } catch (ResourceInvalidException e) { assertThat(e).hasMessageThat().contains(expectedMessage); return; diff --git a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProviderTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProviderTest.java new file mode 100644 index 0000000000..db72d85525 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProviderTest.java @@ -0,0 +1,116 @@ +/* + * Copyright 2023 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import io.grpc.InternalServiceProviders; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancerProvider; +import io.grpc.NameResolver.ConfigOrError; +import io.grpc.SynchronizationContext; +import io.grpc.internal.FakeClock; +import io.grpc.internal.JsonParser; +import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig; +import java.io.IOException; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link WeightedRoundRobinLoadBalancerProvider}. */ +@RunWith(JUnit4.class) +public class WeightedRoundRobinLoadBalancerProviderTest { + + private final WeightedRoundRobinLoadBalancerProvider provider = + new WeightedRoundRobinLoadBalancerProvider(); + + private final SynchronizationContext syncContext = new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw new AssertionError(e); + } + }); + + @Test + public void provided() { + for (LoadBalancerProvider current : InternalServiceProviders.getCandidatesViaServiceLoader( + LoadBalancerProvider.class, getClass().getClassLoader())) { + if (current instanceof WeightedRoundRobinLoadBalancerProvider) { + return; + } + } + fail("WeightedRoundRobinLoadBalancerProvider not registered"); + } + + @Test + public void providesLoadBalancer() { + LoadBalancer.Helper helper = mock(LoadBalancer.Helper.class); + when(helper.getSynchronizationContext()).thenReturn(syncContext); + when(helper.getScheduledExecutorService()).thenReturn( + new FakeClock().getScheduledExecutorService()); + assertThat(provider.newLoadBalancer(helper)) + .isInstanceOf(WeightedRoundRobinLoadBalancer.class); + } + + @Test + public void parseLoadBalancingConfig() throws IOException { + String lbConfig = + "{\"blackoutPeriod\" : \"20s\"," + + " \"weightExpirationPeriod\" : \"300s\"," + + " \"oobReportingPeriod\" : \"100s\"," + + " \"enableOobLoadReport\" : true," + + " \"weightUpdatePeriod\" : \"2s\"" + + " }"; + + ConfigOrError configOrError = provider.parseLoadBalancingPolicyConfig( + parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + WeightedRoundRobinLoadBalancerConfig config = + (WeightedRoundRobinLoadBalancerConfig) configOrError.getConfig(); + assertThat(config.blackoutPeriodNanos).isEqualTo(20_000_000_000L); + assertThat(config.weightExpirationPeriodNanos).isEqualTo(300_000_000_000L); + assertThat(config.oobReportingPeriodNanos).isEqualTo(100_000_000_000L); + assertThat(config.enableOobLoadReport).isEqualTo(true); + assertThat(config.weightUpdatePeriodNanos).isEqualTo(2_000_000_000L); + } + + @Test + public void parseLoadBalancingConfigDefaultValues() throws IOException { + String lbConfig = "{\"weightUpdatePeriod\" : \"0.02s\"}"; + + ConfigOrError configOrError = provider.parseLoadBalancingPolicyConfig( + parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + WeightedRoundRobinLoadBalancerConfig config = + (WeightedRoundRobinLoadBalancerConfig) configOrError.getConfig(); + assertThat(config.blackoutPeriodNanos).isEqualTo(10_000_000_000L); + assertThat(config.weightExpirationPeriodNanos).isEqualTo(180_000_000_000L); + assertThat(config.enableOobLoadReport).isEqualTo(false); + assertThat(config.weightUpdatePeriodNanos).isEqualTo(100_000_000L); + } + + + @SuppressWarnings("unchecked") + private static Map parseJsonObject(String json) throws IOException { + return (Map) JsonParser.parse(json); + } +} diff --git a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java new file mode 100644 index 0000000000..ed8540ff13 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java @@ -0,0 +1,673 @@ +/* + * Copyright 2023 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.github.xds.data.orca.v3.OrcaLoadReport; +import com.github.xds.service.orca.v3.OrcaLoadReportRequest; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.protobuf.Duration; +import io.grpc.Attributes; +import io.grpc.Channel; +import io.grpc.ChannelLogger; +import io.grpc.ClientCall; +import io.grpc.ConnectivityState; +import io.grpc.ConnectivityStateInfo; +import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.CreateSubchannelArgs; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.PickResult; +import io.grpc.LoadBalancer.ResolvedAddresses; +import io.grpc.LoadBalancer.Subchannel; +import io.grpc.LoadBalancer.SubchannelStateListener; +import io.grpc.SynchronizationContext; +import io.grpc.internal.FakeClock; +import io.grpc.services.InternalCallMetricRecorder; +import io.grpc.services.MetricReport; +import io.grpc.util.RoundRobinLoadBalancer.EmptyPicker; +import io.grpc.xds.WeightedRoundRobinLoadBalancer.EdfScheduler; +import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig; +import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinPicker; +import io.grpc.xds.WeightedRoundRobinLoadBalancer.WrrSubchannel; +import java.net.SocketAddress; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Random; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; +import org.mockito.stubbing.Answer; + +@RunWith(JUnit4.class) +public class WeightedRoundRobinLoadBalancerTest { + @Rule + public final MockitoRule mockito = MockitoJUnit.rule(); + + @Mock + Helper helper; + + @Mock + private LoadBalancer.PickSubchannelArgs mockArgs; + + @Captor + private ArgumentCaptor pickerCaptor; + + private final List servers = Lists.newArrayList(); + + private final Map, Subchannel> subchannels = Maps.newLinkedHashMap(); + + private final Map subchannelStateListeners = + Maps.newLinkedHashMap(); + + private final Queue> oobCalls = + new ConcurrentLinkedQueue<>(); + + private WeightedRoundRobinLoadBalancer wrr; + + private final FakeClock fakeClock = new FakeClock(); + + private WeightedRoundRobinLoadBalancerConfig weightedConfig = + WeightedRoundRobinLoadBalancerConfig.newBuilder().build(); + + private static final Attributes.Key MAJOR_KEY = Attributes.Key.create("major-key"); + + private final Attributes affinity = + Attributes.newBuilder().set(MAJOR_KEY, "I got the keys").build(); + + private final SynchronizationContext syncContext = new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw new AssertionError(e); + } + }); + + @Before + public void setup() { + for (int i = 0; i < 3; i++) { + SocketAddress addr = new FakeSocketAddress("server" + i); + EquivalentAddressGroup eag = new EquivalentAddressGroup(addr); + servers.add(eag); + Subchannel sc = mock(Subchannel.class); + Channel channel = mock(Channel.class); + when(channel.newCall(any(), any())).then( + new Answer>() { + @SuppressWarnings("unchecked") + @Override + public ClientCall answer( + InvocationOnMock invocation) throws Throwable { + ClientCall clientCall = mock(ClientCall.class); + oobCalls.add(clientCall); + return clientCall; + } + }); + when(sc.asChannel()).thenReturn(channel); + subchannels.put(Arrays.asList(eag), sc); + } + when(helper.getSynchronizationContext()).thenReturn(syncContext); + when(helper.getScheduledExecutorService()).thenReturn( + fakeClock.getScheduledExecutorService()); + when(helper.createSubchannel(any(CreateSubchannelArgs.class))) + .then(new Answer() { + @Override + public Subchannel answer(InvocationOnMock invocation) throws Throwable { + CreateSubchannelArgs args = (CreateSubchannelArgs) invocation.getArguments()[0]; + final Subchannel subchannel = subchannels.get(args.getAddresses()); + when(subchannel.getAllAddresses()).thenReturn(args.getAddresses()); + when(subchannel.getAttributes()).thenReturn(args.getAttributes()); + when(subchannel.getChannelLogger()).thenReturn(mock(ChannelLogger.class)); + doAnswer( + new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + subchannelStateListeners.put( + subchannel, (SubchannelStateListener) invocation.getArguments()[0]); + return null; + } + }).when(subchannel).start(any(SubchannelStateListener.class)); + return subchannel; + } + }); + wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker()); + } + + @Test + public void wrrLifeCycle() { + syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) + .setAttributes(affinity).build())); + verify(helper, times(3)).createSubchannel( + any(CreateSubchannelArgs.class)); + assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel1 = it.next(); + subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); + Subchannel readySubchannel2 = it.next(); + subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); + Subchannel connectingSubchannel = it.next(); + subchannelStateListeners.get(connectingSubchannel).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.CONNECTING)); + verify(helper, times(2)).updateBalancingState( + eq(ConnectivityState.READY), pickerCaptor.capture()); + assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2); + assertThat(pickerCaptor.getAllValues().get(0).getList().size()).isEqualTo(1); + WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(1); + assertThat(weightedPicker.getList().size()).isEqualTo(2); + WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); + WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); + weightedSubchannel1.onLoadReport(InternalCallMetricRecorder.createMetricReport( + 0.1, 0.1, 1, new HashMap<>(), new HashMap<>())); + weightedSubchannel2.onLoadReport(InternalCallMetricRecorder.createMetricReport( + 0.2, 0.1, 1, new HashMap<>(), new HashMap<>())); + assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); + assertThat(weightedPicker.pickSubchannel(mockArgs) + .getSubchannel()).isEqualTo(weightedSubchannel1); + assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); + weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder() + .setWeightUpdatePeriodNanos(500_000_000L) //.5s + .build(); + syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) + .setAttributes(affinity).build())); + assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); + + syncContext.execute(() -> wrr.shutdown()); + for (Subchannel subchannel: subchannels.values()) { + verify(subchannel).shutdown(); + } + assertThat(fakeClock.getPendingTasks().size()).isEqualTo(0); + verifyNoMoreInteractions(mockArgs); + } + + @Test + public void enableOobLoadReportConfig() { + syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) + .setAttributes(affinity).build())); + verify(helper, times(3)).createSubchannel( + any(CreateSubchannelArgs.class)); + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel1 = it.next(); + subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); + Subchannel readySubchannel2 = it.next(); + subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); + verify(helper, times(2)).updateBalancingState( + eq(ConnectivityState.READY), pickerCaptor.capture()); + WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(1); + WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); + WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); + weightedSubchannel1.onLoadReport(InternalCallMetricRecorder.createMetricReport( + 0.1, 0.1, 1, new HashMap<>(), new HashMap<>())); + weightedSubchannel2.onLoadReport(InternalCallMetricRecorder.createMetricReport( + 0.9, 0.1, 1, new HashMap<>(), new HashMap<>())); + assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); + PickResult pickResult = weightedPicker.pickSubchannel(mockArgs); + assertThat(pickResult.getSubchannel()).isEqualTo(weightedSubchannel1); + assertThat(pickResult.getStreamTracerFactory()).isNotNull(); // verify per-request listener + assertThat(oobCalls.isEmpty()).isTrue(); + weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder().setEnableOobLoadReport(true) + .setOobReportingPeriodNanos(20_030_000_000L) + .build(); + syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) + .setAttributes(affinity).build())); + pickResult = weightedPicker.pickSubchannel(mockArgs); + assertThat(pickResult.getSubchannel()).isEqualTo(weightedSubchannel1); + assertThat(pickResult.getStreamTracerFactory()).isNull(); + OrcaLoadReportRequest golden = OrcaLoadReportRequest.newBuilder().setReportInterval( + Duration.newBuilder().setSeconds(20).setNanos(30000000).build()).build(); + assertThat(oobCalls.size()).isEqualTo(2); + verify(oobCalls.poll()).sendMessage(eq(golden)); + verify(oobCalls.poll()).sendMessage(eq(golden)); + } + + private void pickByWeight(MetricReport r1, MetricReport r2, MetricReport r3, + double subchannel1PickRatio, double subchannel2PickRatio, + double subchannel3PickRatio) { + syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) + .setAttributes(affinity).build())); + verify(helper, times(3)).createSubchannel( + any(CreateSubchannelArgs.class)); + assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel1 = it.next(); + subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); + Subchannel readySubchannel2 = it.next(); + subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); + Subchannel readySubchannel3 = it.next(); + subchannelStateListeners.get(readySubchannel3).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); + verify(helper, times(3)).updateBalancingState( + eq(ConnectivityState.READY), pickerCaptor.capture()); + WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(2); + WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); + WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); + WrrSubchannel weightedSubchannel3 = (WrrSubchannel) weightedPicker.getList().get(2); + weightedSubchannel1.onLoadReport(r1); + weightedSubchannel2.onLoadReport(r2); + weightedSubchannel3.onLoadReport(r3); + assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); + Map pickCount = new HashMap<>(); + for (int i = 0; i < 10000; i++) { + Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); + pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); + } + assertThat(pickCount.size()).isEqualTo(3); + assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 10000.0 - subchannel1PickRatio)) + .isAtMost(0.001); + assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 10000.0 - subchannel2PickRatio )) + .isAtMost(0.001); + assertThat(Math.abs(pickCount.get(weightedSubchannel3) / 10000.0 - subchannel3PickRatio )) + .isAtMost(0.001); + } + + @Test + public void pickByWeight_LargeWeight() { + MetricReport report1 = InternalCallMetricRecorder.createMetricReport( + 0.1, 0.1, 999, new HashMap<>(), new HashMap<>()); + MetricReport report2 = InternalCallMetricRecorder.createMetricReport( + 0.9, 0.1, 2, new HashMap<>(), new HashMap<>()); + MetricReport report3 = InternalCallMetricRecorder.createMetricReport( + 0.86, 0.1, 100, new HashMap<>(), new HashMap<>()); + double totalWeight = 999 / 0.1 + 2 / 0.9 + 100 / 0.86; + + pickByWeight(report1, report2, report3, 999 / 0.1 / totalWeight, 2 / 0.9 / totalWeight, + 100 / 0.86 / totalWeight); + } + + @Test + public void pickByWeight_normalWeight() { + MetricReport report1 = InternalCallMetricRecorder.createMetricReport( + 0.12, 0.1, 22, new HashMap<>(), new HashMap<>()); + MetricReport report2 = InternalCallMetricRecorder.createMetricReport( + 0.28, 0.1, 40, new HashMap<>(), new HashMap<>()); + MetricReport report3 = InternalCallMetricRecorder.createMetricReport( + 0.86, 0.1, 100, new HashMap<>(), new HashMap<>()); + double totalWeight = 22 / 0.12 + 40 / 0.28 + 100 / 0.86; + pickByWeight(report1, report2, report3, 22 / 0.12 / totalWeight, + 40 / 0.28 / totalWeight, 100 / 0.86 / totalWeight + ); + } + + @Test + public void emptyConfig() { + assertThat(wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(null) + .setAttributes(affinity).build())).isFalse(); + verify(helper, never()).createSubchannel(any(CreateSubchannelArgs.class)); + verify(helper).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); + assertThat(fakeClock.getPendingTasks()).isEmpty(); + + syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) + .setAttributes(affinity).build())); + verify(helper, times(3)).createSubchannel( + any(CreateSubchannelArgs.class)); + verify(helper).updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); + assertThat(pickerCaptor.getValue()).isInstanceOf(EmptyPicker.class); + assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); + } + + @Test + public void blackoutPeriod() { + syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) + .setAttributes(affinity).build())); + verify(helper, times(3)).createSubchannel( + any(CreateSubchannelArgs.class)); + assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel1 = it.next(); + subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); + Subchannel readySubchannel2 = it.next(); + subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); + verify(helper, times(2)).updateBalancingState( + eq(ConnectivityState.READY), pickerCaptor.capture()); + WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(1); + WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); + WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); + weightedSubchannel1.onLoadReport(InternalCallMetricRecorder.createMetricReport( + 0.1, 0.1, 1, new HashMap<>(), new HashMap<>())); + weightedSubchannel2.onLoadReport(InternalCallMetricRecorder.createMetricReport( + 0.2, 0.1, 1, new HashMap<>(), new HashMap<>())); + assertThat(fakeClock.forwardTime(5, TimeUnit.SECONDS)).isEqualTo(1); + Map pickCount = new HashMap<>(); + for (int i = 0; i < 1000; i++) { + Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); + pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); + } + assertThat(pickCount.size()).isEqualTo(2); + // within blackout period, fallback to simple round robin + assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 0.5)).isAtMost(0.001); + assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 0.5)).isAtMost(0.001); + + assertThat(fakeClock.forwardTime(5, TimeUnit.SECONDS)).isEqualTo(1); + pickCount = new HashMap<>(); + for (int i = 0; i < 1000; i++) { + Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); + pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); + } + assertThat(pickCount.size()).isEqualTo(2); + // after blackout period + assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 2.0 / 3)) + .isAtMost(0.001); + assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 3)) + .isAtMost(0.001); + } + + @Test + public void updateWeightTimer() { + syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) + .setAttributes(affinity).build())); + verify(helper, times(3)).createSubchannel( + any(CreateSubchannelArgs.class)); + assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel1 = it.next(); + subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); + Subchannel readySubchannel2 = it.next(); + subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); + Subchannel connectingSubchannel = it.next(); + subchannelStateListeners.get(connectingSubchannel).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.CONNECTING)); + verify(helper, times(2)).updateBalancingState( + eq(ConnectivityState.READY), pickerCaptor.capture()); + assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2); + assertThat(pickerCaptor.getAllValues().get(0).getList().size()).isEqualTo(1); + WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(1); + assertThat(weightedPicker.getList().size()).isEqualTo(2); + WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); + WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); + weightedSubchannel1.onLoadReport(InternalCallMetricRecorder.createMetricReport( + 0.1, 0.1, 1, new HashMap<>(), new HashMap<>())); + weightedSubchannel2.onLoadReport(InternalCallMetricRecorder.createMetricReport( + 0.2, 0.1, 1, new HashMap<>(), new HashMap<>())); + assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); + assertThat(weightedPicker.pickSubchannel(mockArgs) + .getSubchannel()).isEqualTo(weightedSubchannel1); + assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); + weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder() + .setWeightUpdatePeriodNanos(500_000_000L) //.5s + .build(); + syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) + .setAttributes(affinity).build())); + assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); + weightedSubchannel1.onLoadReport(InternalCallMetricRecorder.createMetricReport( + 0.2, 0.1, 1, new HashMap<>(), new HashMap<>())); + weightedSubchannel2.onLoadReport(InternalCallMetricRecorder.createMetricReport( + 0.1, 0.1, 1, new HashMap<>(), new HashMap<>())); + //timer fires, new weight updated + assertThat(fakeClock.forwardTime(500, TimeUnit.MILLISECONDS)).isEqualTo(1); + assertThat(weightedPicker.pickSubchannel(mockArgs) + .getSubchannel()).isEqualTo(weightedSubchannel2); + } + + @Test + public void weightExpired() { + syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) + .setAttributes(affinity).build())); + verify(helper, times(3)).createSubchannel( + any(CreateSubchannelArgs.class)); + assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel1 = it.next(); + subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); + Subchannel readySubchannel2 = it.next(); + subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); + verify(helper, times(2)).updateBalancingState( + eq(ConnectivityState.READY), pickerCaptor.capture()); + WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(1); + WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); + WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); + weightedSubchannel1.onLoadReport(InternalCallMetricRecorder.createMetricReport( + 0.1, 0.1, 1, new HashMap<>(), new HashMap<>())); + weightedSubchannel2.onLoadReport(InternalCallMetricRecorder.createMetricReport( + 0.2, 0.1, 1, new HashMap<>(), new HashMap<>())); + assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1); + Map pickCount = new HashMap<>(); + for (int i = 0; i < 1000; i++) { + Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); + pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); + } + assertThat(pickCount.size()).isEqualTo(2); + assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 2.0 / 3)) + .isAtMost(0.001); + assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 3)) + .isAtMost(0.001); + + // weight expired, fallback to simple round robin + assertThat(fakeClock.forwardTime(300, TimeUnit.SECONDS)).isEqualTo(1); + pickCount = new HashMap<>(); + for (int i = 0; i < 1000; i++) { + Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); + pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); + } + assertThat(pickCount.size()).isEqualTo(2); + assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 0.5)) + .isAtMost(0.001); + assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 0.5)) + .isAtMost(0.001); + } + + @Test + public void unknownWeightIsAvgWeight() { + syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) + .setAttributes(affinity).build())); + verify(helper, times(3)).createSubchannel( + any(CreateSubchannelArgs.class)); + assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel1 = it.next(); + subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); + Subchannel readySubchannel2 = it.next(); + subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); + Subchannel readySubchannel3 = it.next(); + subchannelStateListeners.get(readySubchannel3).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); + verify(helper, times(3)).updateBalancingState( + eq(ConnectivityState.READY), pickerCaptor.capture()); + WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(2); + WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); + WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); + WrrSubchannel weightedSubchannel3 = (WrrSubchannel) weightedPicker.getList().get(2); + weightedSubchannel1.onLoadReport(InternalCallMetricRecorder.createMetricReport( + 0.1, 0.1, 1, new HashMap<>(), new HashMap<>())); + weightedSubchannel2.onLoadReport(InternalCallMetricRecorder.createMetricReport( + 0.2, 0.1, 1, new HashMap<>(), new HashMap<>())); + assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1); + Map pickCount = new HashMap<>(); + for (int i = 0; i < 1000; i++) { + Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); + pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); + } + assertThat(pickCount.size()).isEqualTo(3); + assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 4.0 / 9)) + .isAtMost(0.001); + assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 2.0 / 9)) + .isAtMost(0.001); + // subchannel3's weight is average of subchannel1 and subchannel2 + assertThat(Math.abs(pickCount.get(weightedSubchannel3) / 1000.0 - 3.0 / 9)) + .isAtMost(0.001); + } + + @Test + public void pickFromOtherThread() throws Exception { + syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) + .setAttributes(affinity).build())); + verify(helper, times(3)).createSubchannel( + any(CreateSubchannelArgs.class)); + assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel1 = it.next(); + subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); + Subchannel readySubchannel2 = it.next(); + subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); + verify(helper, times(2)).updateBalancingState( + eq(ConnectivityState.READY), pickerCaptor.capture()); + WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(1); + WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); + WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); + weightedSubchannel1.onLoadReport(InternalCallMetricRecorder.createMetricReport( + 0.1, 0.1, 1, new HashMap<>(), new HashMap<>())); + weightedSubchannel2.onLoadReport(InternalCallMetricRecorder.createMetricReport( + 0.2, 0.1, 1, new HashMap<>(), new HashMap<>())); + assertThat(weightedPicker.toString()).contains("rrMode=true"); + CyclicBarrier barrier = new CyclicBarrier(2); + Map pickCount = new ConcurrentHashMap<>(); + pickCount.put(weightedSubchannel1, new AtomicInteger(0)); + pickCount.put(weightedSubchannel2, new AtomicInteger(0)); + new Thread(new Runnable() { + @Override + public void run() { + try { + weightedPicker.pickSubchannel(mockArgs); + barrier.await(); + for (int i = 0; i < 1000; i++) { + Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); + pickCount.get(result).addAndGet(1); + } + barrier.await(); + } catch (Exception ex) { + throw new AssertionError(ex); + } + } + }).start(); + assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1); + barrier.await(); + for (int i = 0; i < 1000; i++) { + Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); + pickCount.get(result).addAndGet(1); + } + barrier.await(); + assertThat(pickCount.size()).isEqualTo(2); + // after blackout period + assertThat(Math.abs(pickCount.get(weightedSubchannel1).get() / 2000.0 - 2.0 / 3)) + .isAtMost(0.001); + assertThat(Math.abs(pickCount.get(weightedSubchannel2).get() / 2000.0 - 1.0 / 3)) + .isAtMost(0.001); + } + + @Test + public void edfScheduler() { + Random random = new Random(); + double totalWeight = 0; + int capacity = random.nextInt(10) + 1; + double[] weights = new double[capacity]; + EdfScheduler scheduler = new EdfScheduler(capacity); + for (int i = 0; i < capacity; i++) { + weights[i] = random.nextDouble(); + scheduler.add(i, weights[i]); + totalWeight += weights[i]; + } + Map pickCount = new HashMap<>(); + for (int i = 0; i < 1000; i++) { + int result = scheduler.pick(); + pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); + } + for (int i = 0; i < capacity; i++) { + assertThat(Math.abs(pickCount.get(i) / 1000.0 - weights[i] / totalWeight) ).isAtMost(0.01); + } + } + + @Test + public void edsScheduler_sameWeight() { + EdfScheduler scheduler = new EdfScheduler(2); + scheduler.add(0, 0.5); + scheduler.add(1, 0.5); + assertThat(scheduler.pick()).isEqualTo(0); + } + + @Test(expected = NullPointerException.class) + public void wrrConfig_TimeValueNonNull() { + WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos((Long) null); + } + + @Test(expected = NullPointerException.class) + public void wrrConfig_BooleanValueNonNull() { + WeightedRoundRobinLoadBalancerConfig.newBuilder().setEnableOobLoadReport((Boolean) null); + } + + private static class FakeSocketAddress extends SocketAddress { + final String name; + + FakeSocketAddress(String name) { + this.name = name; + } + + @Override public String toString() { + return "FakeSocketAddress-" + name; + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/XdsClientImplDataTest.java b/xds/src/test/java/io/grpc/xds/XdsClientImplDataTest.java index e66dfd7b62..a39ecb46fa 100644 --- a/xds/src/test/java/io/grpc/xds/XdsClientImplDataTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientImplDataTest.java @@ -27,6 +27,7 @@ import com.google.common.collect.Iterables; import com.google.protobuf.Any; import com.google.protobuf.BoolValue; import com.google.protobuf.ByteString; +import com.google.protobuf.Duration; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; import com.google.protobuf.StringValue; @@ -39,6 +40,7 @@ import io.envoyproxy.envoy.config.cluster.v3.Cluster; import io.envoyproxy.envoy.config.cluster.v3.Cluster.DiscoveryType; import io.envoyproxy.envoy.config.cluster.v3.Cluster.EdsClusterConfig; import io.envoyproxy.envoy.config.cluster.v3.Cluster.LbPolicy; +import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy; import io.envoyproxy.envoy.config.core.v3.Address; import io.envoyproxy.envoy.config.core.v3.AggregatedConfigSource; import io.envoyproxy.envoy.config.core.v3.CidrRange; @@ -85,6 +87,8 @@ import io.envoyproxy.envoy.extensions.filters.http.router.v3.Router; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.Rds; +import io.envoyproxy.envoy.extensions.load_balancing_policies.client_side_weighted_round_robin.v3.ClientSideWeightedRoundRobin; +import io.envoyproxy.envoy.extensions.load_balancing_policies.wrr_locality.v3.WrrLocality; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; @@ -128,6 +132,7 @@ import io.grpc.xds.VirtualHost.Route.RouteAction.ClusterWeight; import io.grpc.xds.VirtualHost.Route.RouteAction.HashPolicy; import io.grpc.xds.VirtualHost.Route.RouteMatch; import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher; +import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig; import io.grpc.xds.XdsClientImpl.ResourceInvalidException; import io.grpc.xds.XdsClusterResource.CdsUpdate; import io.grpc.xds.XdsResourceType.StructOrError; @@ -163,6 +168,7 @@ public class XdsClientImplDataTest { private boolean originalEnableRbac; private boolean originalEnableRouteLookup; private boolean originalEnableLeastRequest; + private boolean originalEnableWrr; @Before public void setUp() { @@ -174,6 +180,8 @@ public class XdsClientImplDataTest { assertThat(originalEnableRouteLookup).isFalse(); originalEnableLeastRequest = XdsResourceType.enableLeastRequest; assertThat(originalEnableLeastRequest).isFalse(); + originalEnableWrr = XdsResourceType.enableWrr; + assertThat(originalEnableWrr).isFalse(); } @After @@ -182,6 +190,7 @@ public class XdsClientImplDataTest { XdsResourceType.enableRbac = originalEnableRbac; XdsResourceType.enableRouteLookup = originalEnableRouteLookup; XdsResourceType.enableLeastRequest = originalEnableLeastRequest; + XdsResourceType.enableWrr = originalEnableWrr; } @Test @@ -1966,6 +1975,65 @@ public class XdsClientImplDataTest { assertThat(childConfigs.get(0).getPolicyName()).isEqualTo("least_request_experimental"); } + @Test + public void parseCluster_WrrLbPolicy_defaultLbConfig() throws ResourceInvalidException { + XdsResourceType.enableWrr = true; + + LoadBalancingPolicy wrrConfig = + LoadBalancingPolicy.newBuilder().addPolicies( + LoadBalancingPolicy.Policy.newBuilder() + .setTypedExtensionConfig(TypedExtensionConfig.newBuilder() + .setName("backend") + .setTypedConfig( + Any.pack(ClientSideWeightedRoundRobin.newBuilder() + .setBlackoutPeriod(Duration.newBuilder().setSeconds(17).build()) + .setEnableOobLoadReport( + BoolValue.newBuilder().setValue(true).build()) + .build())) + .build()) + .build()) + .build(); + + Cluster cluster = Cluster.newBuilder() + .setName("cluster-foo.googleapis.com") + .setType(DiscoveryType.EDS) + .setEdsClusterConfig( + EdsClusterConfig.newBuilder() + .setEdsConfig( + ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.getDefaultInstance())) + .setServiceName("service-foo.googleapis.com")) + .setLoadBalancingPolicy( + LoadBalancingPolicy.newBuilder().addPolicies( + LoadBalancingPolicy.Policy.newBuilder() + .setTypedExtensionConfig( + TypedExtensionConfig.newBuilder() + .setTypedConfig( + Any.pack(WrrLocality.newBuilder() + .setEndpointPickingPolicy(wrrConfig) + .build())) + .build()) + .build()) + .build()) + .build(); + CdsUpdate update = XdsClusterResource.processCluster( + cluster, null, LRS_SERVER_INFO, + LoadBalancerRegistry.getDefaultRegistry()); + LbConfig lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(update.lbPolicyConfig()); + assertThat(lbConfig.getPolicyName()).isEqualTo("wrr_locality_experimental"); + List childConfigs = ServiceConfigUtil.unwrapLoadBalancingConfigList( + JsonUtil.getListOfObjects(lbConfig.getRawConfigValue(), "childPolicy")); + assertThat(childConfigs.get(0).getPolicyName()).isEqualTo("weighted_round_robin_experimental"); + WeightedRoundRobinLoadBalancerConfig result = (WeightedRoundRobinLoadBalancerConfig) + new WeightedRoundRobinLoadBalancerProvider().parseLoadBalancingPolicyConfig( + childConfigs.get(0).getRawConfigValue()).getConfig(); + assertThat(result.blackoutPeriodNanos).isEqualTo(17_000_000_000L); + assertThat(result.enableOobLoadReport).isTrue(); + assertThat(result.oobReportingPeriodNanos).isEqualTo(10_000_000_000L); + assertThat(result.weightUpdatePeriodNanos).isEqualTo(1_000_000_000L); + assertThat(result.weightExpirationPeriodNanos).isEqualTo(180_000_000_000L); + } + @Test public void parseCluster_transportSocketMatches_exception() throws ResourceInvalidException { Cluster cluster = Cluster.newBuilder()