xds: add weighted round robin LB policy support (#9873)

This commit is contained in:
yifeizhuang 2023-02-27 10:34:51 -08:00 committed by GitHub
parent cc28dfdb36
commit 8d12baa447
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 1599 additions and 49 deletions

View File

@ -31,6 +31,7 @@ import io.grpc.Attributes;
import io.grpc.ConnectivityState; import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo; import io.grpc.ConnectivityStateInfo;
import io.grpc.EquivalentAddressGroup; import io.grpc.EquivalentAddressGroup;
import io.grpc.Internal;
import io.grpc.LoadBalancer; import io.grpc.LoadBalancer;
import io.grpc.NameResolver; import io.grpc.NameResolver;
import io.grpc.Status; import io.grpc.Status;
@ -50,7 +51,8 @@ import javax.annotation.Nonnull;
* A {@link LoadBalancer} that provides round-robin load-balancing over the {@link * A {@link LoadBalancer} that provides round-robin load-balancing over the {@link
* EquivalentAddressGroup}s from the {@link NameResolver}. * EquivalentAddressGroup}s from the {@link NameResolver}.
*/ */
final class RoundRobinLoadBalancer extends LoadBalancer { @Internal
public class RoundRobinLoadBalancer extends LoadBalancer {
@VisibleForTesting @VisibleForTesting
static final Attributes.Key<Ref<ConnectivityStateInfo>> STATE_INFO = static final Attributes.Key<Ref<ConnectivityStateInfo>> STATE_INFO =
Attributes.Key.create("state-info"); Attributes.Key.create("state-info");
@ -59,11 +61,10 @@ final class RoundRobinLoadBalancer extends LoadBalancer {
private final Map<EquivalentAddressGroup, Subchannel> subchannels = private final Map<EquivalentAddressGroup, Subchannel> subchannels =
new HashMap<>(); new HashMap<>();
private final Random random; private final Random random;
private ConnectivityState currentState; private ConnectivityState currentState;
private RoundRobinPicker currentPicker = new EmptyPicker(EMPTY_OK); protected RoundRobinPicker currentPicker = new EmptyPicker(EMPTY_OK);
RoundRobinLoadBalancer(Helper helper) { public RoundRobinLoadBalancer(Helper helper) {
this.helper = checkNotNull(helper, "helper"); this.helper = checkNotNull(helper, "helper");
this.random = new Random(); this.random = new Random();
} }
@ -207,10 +208,7 @@ final class RoundRobinLoadBalancer extends LoadBalancer {
// an arbitrary subchannel, otherwise return OK. // an arbitrary subchannel, otherwise return OK.
new EmptyPicker(aggStatus)); new EmptyPicker(aggStatus));
} else { } else {
// initialize the Picker to a random start index to ensure that a high frequency of Picker updateBalancingState(READY, createReadyPicker(activeList));
// churn does not skew subchannel selection.
int startIndex = random.nextInt(activeList.size());
updateBalancingState(READY, new ReadyPicker(activeList, startIndex));
} }
} }
@ -222,6 +220,13 @@ final class RoundRobinLoadBalancer extends LoadBalancer {
} }
} }
protected RoundRobinPicker createReadyPicker(List<Subchannel> activeList) {
// initialize the Picker to a random start index to ensure that a high frequency of Picker
// churn does not skew subchannel selection.
int startIndex = random.nextInt(activeList.size());
return new ReadyPicker(activeList, startIndex);
}
/** /**
* Filters out non-ready subchannels. * Filters out non-ready subchannels.
*/ */
@ -254,7 +259,7 @@ final class RoundRobinLoadBalancer extends LoadBalancer {
} }
@VisibleForTesting @VisibleForTesting
Collection<Subchannel> getSubchannels() { protected Collection<Subchannel> getSubchannels() {
return subchannels.values(); return subchannels.values();
} }
@ -275,12 +280,11 @@ final class RoundRobinLoadBalancer extends LoadBalancer {
} }
// Only subclasses are ReadyPicker or EmptyPicker // Only subclasses are ReadyPicker or EmptyPicker
private abstract static class RoundRobinPicker extends SubchannelPicker { public abstract static class RoundRobinPicker extends SubchannelPicker {
abstract boolean isEquivalentTo(RoundRobinPicker picker); public abstract boolean isEquivalentTo(RoundRobinPicker picker);
} }
@VisibleForTesting public static class ReadyPicker extends RoundRobinPicker {
static final class ReadyPicker extends RoundRobinPicker {
private static final AtomicIntegerFieldUpdater<ReadyPicker> indexUpdater = private static final AtomicIntegerFieldUpdater<ReadyPicker> indexUpdater =
AtomicIntegerFieldUpdater.newUpdater(ReadyPicker.class, "index"); AtomicIntegerFieldUpdater.newUpdater(ReadyPicker.class, "index");
@ -288,7 +292,7 @@ final class RoundRobinLoadBalancer extends LoadBalancer {
@SuppressWarnings("unused") @SuppressWarnings("unused")
private volatile int index; private volatile int index;
ReadyPicker(List<Subchannel> list, int startIndex) { public ReadyPicker(List<Subchannel> list, int startIndex) {
Preconditions.checkArgument(!list.isEmpty(), "empty list"); Preconditions.checkArgument(!list.isEmpty(), "empty list");
this.list = list; this.list = list;
this.index = startIndex - 1; this.index = startIndex - 1;
@ -321,7 +325,7 @@ final class RoundRobinLoadBalancer extends LoadBalancer {
} }
@Override @Override
boolean isEquivalentTo(RoundRobinPicker picker) { public boolean isEquivalentTo(RoundRobinPicker picker) {
if (!(picker instanceof ReadyPicker)) { if (!(picker instanceof ReadyPicker)) {
return false; return false;
} }
@ -332,8 +336,7 @@ final class RoundRobinLoadBalancer extends LoadBalancer {
} }
} }
@VisibleForTesting public static final class EmptyPicker extends RoundRobinPicker {
static final class EmptyPicker extends RoundRobinPicker {
private final Status status; private final Status status;
@ -347,7 +350,7 @@ final class RoundRobinLoadBalancer extends LoadBalancer {
} }
@Override @Override
boolean isEquivalentTo(RoundRobinPicker picker) { public boolean isEquivalentTo(RoundRobinPicker picker) {
return picker instanceof EmptyPicker && (Objects.equal(status, ((EmptyPicker) picker).status) return picker instanceof EmptyPicker && (Objects.equal(status, ((EmptyPicker) picker).status)
|| (status.isOk() && ((EmptyPicker) picker).status.isOk())); || (status.isOk() && ((EmptyPicker) picker).status.isOk()));
} }

View File

@ -137,10 +137,10 @@ def grpc_java_repositories():
if not native.existing_rule("envoy_api"): if not native.existing_rule("envoy_api"):
http_archive( http_archive(
name = "envoy_api", name = "envoy_api",
sha256 = "a0c58442cc2038ccccad9616dd1bab5ff1e65da2bbc0ae41020ef6010119eb0e", sha256 = "74156c0d8738d0469f23047f0fd0f8846fdd0d59d7b55c76cd8cb9ebf2fa3a01",
strip_prefix = "data-plane-api-869b00336913138cad96a653458aab650c4e70ea", strip_prefix = "data-plane-api-b1d2e441133c00bfe8412dfd6e93ea85e66da9bb",
urls = [ urls = [
"https://github.com/envoyproxy/data-plane-api/archive/869b00336913138cad96a653458aab650c4e70ea.tar.gz", "https://github.com/envoyproxy/data-plane-api/archive/b1d2e441133c00bfe8412dfd6e93ea85e66da9bb.tar.gz",
], ],
) )

View File

@ -32,6 +32,7 @@ java_library(
":envoy_service_load_stats_v3_java_grpc", ":envoy_service_load_stats_v3_java_grpc",
":envoy_service_status_v3_java_grpc", ":envoy_service_status_v3_java_grpc",
":xds_protos_java", ":xds_protos_java",
":orca",
"//:auto_value_annotations", "//:auto_value_annotations",
"//alts", "//alts",
"//api", "//api",
@ -40,6 +41,8 @@ java_library(
"//core:util", "//core:util",
"//netty", "//netty",
"//stub", "//stub",
"//services:metrics",
"//services:metrics_internal",
"@com_google_code_findbugs_jsr305//jar", "@com_google_code_findbugs_jsr305//jar",
"@com_google_code_gson_gson//jar", "@com_google_code_gson_gson//jar",
"@com_google_errorprone_error_prone_annotations//jar", "@com_google_errorprone_error_prone_annotations//jar",
@ -83,6 +86,7 @@ java_proto_library(
"@envoy_api//envoy/extensions/filters/http/rbac/v3:pkg", "@envoy_api//envoy/extensions/filters/http/rbac/v3:pkg",
"@envoy_api//envoy/extensions/filters/http/router/v3:pkg", "@envoy_api//envoy/extensions/filters/http/router/v3:pkg",
"@envoy_api//envoy/extensions/filters/network/http_connection_manager/v3:pkg", "@envoy_api//envoy/extensions/filters/network/http_connection_manager/v3:pkg",
"@envoy_api//envoy/extensions/load_balancing_policies/client_side_weighted_round_robin/v3:pkg",
"@envoy_api//envoy/extensions/load_balancing_policies/least_request/v3:pkg", "@envoy_api//envoy/extensions/load_balancing_policies/least_request/v3:pkg",
"@envoy_api//envoy/extensions/load_balancing_policies/ring_hash/v3:pkg", "@envoy_api//envoy/extensions/load_balancing_policies/ring_hash/v3:pkg",
"@envoy_api//envoy/extensions/load_balancing_policies/round_robin/v3:pkg", "@envoy_api//envoy/extensions/load_balancing_policies/round_robin/v3:pkg",

View File

@ -22,12 +22,14 @@ import com.google.common.collect.Iterables;
import com.google.protobuf.Any; import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Struct; import com.google.protobuf.Struct;
import com.google.protobuf.util.Durations;
import com.google.protobuf.util.JsonFormat; import com.google.protobuf.util.JsonFormat;
import io.envoyproxy.envoy.config.cluster.v3.Cluster; import io.envoyproxy.envoy.config.cluster.v3.Cluster;
import io.envoyproxy.envoy.config.cluster.v3.Cluster.LeastRequestLbConfig; import io.envoyproxy.envoy.config.cluster.v3.Cluster.LeastRequestLbConfig;
import io.envoyproxy.envoy.config.cluster.v3.Cluster.RingHashLbConfig; import io.envoyproxy.envoy.config.cluster.v3.Cluster.RingHashLbConfig;
import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy; import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy;
import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy.Policy; import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy.Policy;
import io.envoyproxy.envoy.extensions.load_balancing_policies.client_side_weighted_round_robin.v3.ClientSideWeightedRoundRobin;
import io.envoyproxy.envoy.extensions.load_balancing_policies.least_request.v3.LeastRequest; import io.envoyproxy.envoy.extensions.load_balancing_policies.least_request.v3.LeastRequest;
import io.envoyproxy.envoy.extensions.load_balancing_policies.ring_hash.v3.RingHash; import io.envoyproxy.envoy.extensions.load_balancing_policies.ring_hash.v3.RingHash;
import io.envoyproxy.envoy.extensions.load_balancing_policies.round_robin.v3.RoundRobin; import io.envoyproxy.envoy.extensions.load_balancing_policies.round_robin.v3.RoundRobin;
@ -73,6 +75,16 @@ class LoadBalancerConfigFactory {
static final String WRR_LOCALITY_FIELD_NAME = "wrr_locality_experimental"; static final String WRR_LOCALITY_FIELD_NAME = "wrr_locality_experimental";
static final String CHILD_POLICY_FIELD = "childPolicy"; static final String CHILD_POLICY_FIELD = "childPolicy";
static final String BLACK_OUT_PERIOD = "blackoutPeriod";
static final String WEIGHT_EXPIRATION_PERIOD = "weightExpirationPeriod";
static final String OOB_REPORTING_PERIOD = "oobReportingPeriod";
static final String ENABLE_OOB_LOAD_REPORT = "enableOobLoadReport";
static final String WEIGHT_UPDATE_PERIOD = "weightUpdatePeriod";
/** /**
* Factory method for creating a new {link LoadBalancerConfigConverter} for a given xDS {@link * Factory method for creating a new {link LoadBalancerConfigConverter} for a given xDS {@link
* Cluster}. * Cluster}.
@ -80,14 +92,14 @@ class LoadBalancerConfigFactory {
* @throws ResourceInvalidException If the {@link Cluster} has an invalid LB configuration. * @throws ResourceInvalidException If the {@link Cluster} has an invalid LB configuration.
*/ */
static ImmutableMap<String, ?> newConfig(Cluster cluster, boolean enableLeastRequest, static ImmutableMap<String, ?> newConfig(Cluster cluster, boolean enableLeastRequest,
boolean enableCustomLbConfig) boolean enableCustomLbConfig, boolean enableWrr)
throws ResourceInvalidException { throws ResourceInvalidException {
// The new load_balancing_policy will always be used if it is set, but for backward // The new load_balancing_policy will always be used if it is set, but for backward
// compatibility we will fall back to using the old lb_policy field if the new field is not set. // compatibility we will fall back to using the old lb_policy field if the new field is not set.
if (cluster.hasLoadBalancingPolicy() && enableCustomLbConfig) { if (cluster.hasLoadBalancingPolicy() && enableCustomLbConfig) {
try { try {
return LoadBalancingPolicyConverter.convertToServiceConfig(cluster.getLoadBalancingPolicy(), return LoadBalancingPolicyConverter.convertToServiceConfig(cluster.getLoadBalancingPolicy(),
0); 0, enableWrr);
} catch (MaxRecursionReachedException e) { } catch (MaxRecursionReachedException e) {
throw new ResourceInvalidException("Maximum LB config recursion depth reached", e); throw new ResourceInvalidException("Maximum LB config recursion depth reached", e);
} }
@ -111,6 +123,35 @@ class LoadBalancerConfigFactory {
return ImmutableMap.of(RING_HASH_FIELD_NAME, configBuilder.buildOrThrow()); return ImmutableMap.of(RING_HASH_FIELD_NAME, configBuilder.buildOrThrow());
} }
/**
* Builds a service config JSON object for the weighted_round_robin load balancer config based on
* the given config values.
*/
private static ImmutableMap<String, ?> buildWrrConfig(String blackoutPeriod,
String weightExpirationPeriod,
String oobReportingPeriod,
Boolean enableOobLoadReport,
String weightUpdatePeriod) {
ImmutableMap.Builder<String, Object> configBuilder = ImmutableMap.builder();
if (blackoutPeriod != null) {
configBuilder.put(BLACK_OUT_PERIOD, blackoutPeriod);
}
if (weightExpirationPeriod != null) {
configBuilder.put(WEIGHT_EXPIRATION_PERIOD, weightExpirationPeriod);
}
if (oobReportingPeriod != null) {
configBuilder.put(OOB_REPORTING_PERIOD, oobReportingPeriod);
}
if (enableOobLoadReport != null) {
configBuilder.put(ENABLE_OOB_LOAD_REPORT, enableOobLoadReport);
}
if (weightUpdatePeriod != null) {
configBuilder.put(WEIGHT_UPDATE_PERIOD, weightUpdatePeriod);
}
return ImmutableMap.of(WeightedRoundRobinLoadBalancerProvider.SCHEME,
configBuilder.buildOrThrow());
}
/** /**
* Builds a service config JSON object for the least_request load balancer config based on the * Builds a service config JSON object for the least_request load balancer config based on the
* given config values.. * given config values..
@ -151,7 +192,7 @@ class LoadBalancerConfigFactory {
* Converts a {@link LoadBalancingPolicy} object to a service config JSON object. * Converts a {@link LoadBalancingPolicy} object to a service config JSON object.
*/ */
private static ImmutableMap<String, ?> convertToServiceConfig( private static ImmutableMap<String, ?> convertToServiceConfig(
LoadBalancingPolicy loadBalancingPolicy, int recursionDepth) LoadBalancingPolicy loadBalancingPolicy, int recursionDepth, boolean enableWrr)
throws ResourceInvalidException, MaxRecursionReachedException { throws ResourceInvalidException, MaxRecursionReachedException {
if (recursionDepth > MAX_RECURSION) { if (recursionDepth > MAX_RECURSION) {
throw new MaxRecursionReachedException(); throw new MaxRecursionReachedException();
@ -165,11 +206,16 @@ class LoadBalancerConfigFactory {
serviceConfig = convertRingHashConfig(typedConfig.unpack(RingHash.class)); serviceConfig = convertRingHashConfig(typedConfig.unpack(RingHash.class));
} else if (typedConfig.is(WrrLocality.class)) { } else if (typedConfig.is(WrrLocality.class)) {
serviceConfig = convertWrrLocalityConfig(typedConfig.unpack(WrrLocality.class), serviceConfig = convertWrrLocalityConfig(typedConfig.unpack(WrrLocality.class),
recursionDepth); recursionDepth, enableWrr);
} else if (typedConfig.is(RoundRobin.class)) { } else if (typedConfig.is(RoundRobin.class)) {
serviceConfig = convertRoundRobinConfig(); serviceConfig = convertRoundRobinConfig();
} else if (typedConfig.is(LeastRequest.class)) { } else if (typedConfig.is(LeastRequest.class)) {
serviceConfig = convertLeastRequestConfig(typedConfig.unpack(LeastRequest.class)); serviceConfig = convertLeastRequestConfig(typedConfig.unpack(LeastRequest.class));
} else if (typedConfig.is(ClientSideWeightedRoundRobin.class)) {
if (enableWrr) {
serviceConfig = convertWeightedRoundRobinConfig(
typedConfig.unpack(ClientSideWeightedRoundRobin.class));
}
} else if (typedConfig.is(com.github.xds.type.v3.TypedStruct.class)) { } else if (typedConfig.is(com.github.xds.type.v3.TypedStruct.class)) {
serviceConfig = convertCustomConfig( serviceConfig = convertCustomConfig(
typedConfig.unpack(com.github.xds.type.v3.TypedStruct.class)); typedConfig.unpack(com.github.xds.type.v3.TypedStruct.class));
@ -217,14 +263,31 @@ class LoadBalancerConfigFactory {
ringHash.hasMaximumRingSize() ? ringHash.getMaximumRingSize().getValue() : null); ringHash.hasMaximumRingSize() ? ringHash.getMaximumRingSize().getValue() : null);
} }
private static ImmutableMap<String, ?> convertWeightedRoundRobinConfig(
ClientSideWeightedRoundRobin wrr) throws ResourceInvalidException {
try {
return buildWrrConfig(
wrr.hasBlackoutPeriod() ? Durations.toString(wrr.getBlackoutPeriod()) : null,
wrr.hasWeightExpirationPeriod()
? Durations.toString(wrr.getWeightExpirationPeriod()) : null,
wrr.hasOobReportingPeriod() ? Durations.toString(wrr.getOobReportingPeriod()) : null,
wrr.hasEnableOobLoadReport() ? wrr.getEnableOobLoadReport().getValue() : null,
wrr.hasWeightUpdatePeriod() ? Durations.toString(wrr.getWeightUpdatePeriod()) : null);
} catch (IllegalArgumentException ex) {
throw new ResourceInvalidException("Invalid duration in weighted round robin config: "
+ ex.getMessage());
}
}
/** /**
* Converts a wrr_locality {@link Any} configuration to service config format. * Converts a wrr_locality {@link Any} configuration to service config format.
*/ */
private static ImmutableMap<String, ?> convertWrrLocalityConfig(WrrLocality wrrLocality, private static ImmutableMap<String, ?> convertWrrLocalityConfig(WrrLocality wrrLocality,
int recursionDepth) throws ResourceInvalidException, int recursionDepth, boolean enableWrr) throws ResourceInvalidException,
MaxRecursionReachedException { MaxRecursionReachedException {
return buildWrrLocalityConfig( return buildWrrLocalityConfig(
convertToServiceConfig(wrrLocality.getEndpointPickingPolicy(), recursionDepth + 1)); convertToServiceConfig(wrrLocality.getEndpointPickingPolicy(),
recursionDepth + 1, enableWrr));
} }
/** /**

View File

@ -0,0 +1,471 @@
/*
* Copyright 2023 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
import io.grpc.Deadline.Ticker;
import io.grpc.EquivalentAddressGroup;
import io.grpc.ExperimentalApi;
import io.grpc.LoadBalancer;
import io.grpc.NameResolver;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
import io.grpc.SynchronizationContext.ScheduledHandle;
import io.grpc.services.MetricReport;
import io.grpc.util.ForwardingLoadBalancerHelper;
import io.grpc.util.ForwardingSubchannel;
import io.grpc.util.RoundRobinLoadBalancer;
import io.grpc.xds.orca.OrcaOobUtil;
import io.grpc.xds.orca.OrcaOobUtil.OrcaOobReportListener;
import io.grpc.xds.orca.OrcaPerRequestUtil;
import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener;
import java.util.HashSet;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Random;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
/**
* A {@link LoadBalancer} that provides weighted-round-robin load-balancing over
* the {@link EquivalentAddressGroup}s from the {@link NameResolver}. The subchannel weights are
* determined by backend metrics using ORCA.
*/
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/9885")
final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
private volatile WeightedRoundRobinLoadBalancerConfig config;
private final SynchronizationContext syncContext;
private final ScheduledExecutorService timeService;
private ScheduledHandle weightUpdateTimer;
private final Runnable updateWeightTask;
private final Random random;
private final long infTime;
private final Ticker ticker;
public WeightedRoundRobinLoadBalancer(Helper helper, Ticker ticker) {
this(new WrrHelper(OrcaOobUtil.newOrcaReportingHelper(helper)), ticker);
}
public WeightedRoundRobinLoadBalancer(WrrHelper helper, Ticker ticker) {
super(helper);
helper.setLoadBalancer(this);
this.ticker = checkNotNull(ticker, "ticker");
this.infTime = ticker.nanoTime() + Long.MAX_VALUE;
this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext");
this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService");
this.updateWeightTask = new UpdateWeightTask();
this.random = new Random();
}
@Override
public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
if (resolvedAddresses.getLoadBalancingPolicyConfig() == null) {
handleNameResolutionError(Status.UNAVAILABLE.withDescription(
"NameResolver returned no WeightedRoundRobinLoadBalancerConfig. addrs="
+ resolvedAddresses.getAddresses()
+ ", attrs=" + resolvedAddresses.getAttributes()));
return false;
}
config =
(WeightedRoundRobinLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig();
boolean accepted = super.acceptResolvedAddresses(resolvedAddresses);
if (weightUpdateTimer != null && weightUpdateTimer.isPending()) {
weightUpdateTimer.cancel();
}
updateWeightTask.run();
afterAcceptAddresses();
return accepted;
}
@Override
public RoundRobinPicker createReadyPicker(List<Subchannel> activeList) {
int startIndex = random.nextInt(activeList.size());
return new WeightedRoundRobinPicker(activeList, startIndex);
}
private final class UpdateWeightTask implements Runnable {
@Override
public void run() {
if (currentPicker != null && currentPicker instanceof WeightedRoundRobinPicker) {
((WeightedRoundRobinPicker)currentPicker).updateWeight();
}
weightUpdateTimer = syncContext.schedule(this, config.weightUpdatePeriodNanos,
TimeUnit.NANOSECONDS, timeService);
}
}
private void afterAcceptAddresses() {
for (Subchannel subchannel : getSubchannels()) {
WrrSubchannel weightedSubchannel = (WrrSubchannel) subchannel;
if (config.enableOobLoadReport) {
OrcaOobUtil.setListener(weightedSubchannel, weightedSubchannel.oobListener,
OrcaOobUtil.OrcaReportingConfig.newBuilder()
.setReportInterval(config.oobReportingPeriodNanos, TimeUnit.NANOSECONDS)
.build());
} else {
OrcaOobUtil.setListener(weightedSubchannel, null, null);
}
}
}
@Override
public void shutdown() {
if (weightUpdateTimer != null) {
weightUpdateTimer.cancel();
}
super.shutdown();
}
private static final class WrrHelper extends ForwardingLoadBalancerHelper {
private final Helper delegate;
private WeightedRoundRobinLoadBalancer wrr;
WrrHelper(Helper helper) {
this.delegate = helper;
}
void setLoadBalancer(WeightedRoundRobinLoadBalancer lb) {
this.wrr = lb;
}
@Override
protected Helper delegate() {
return delegate;
}
@Override
public Subchannel createSubchannel(CreateSubchannelArgs args) {
return wrr.new WrrSubchannel(delegate().createSubchannel(args));
}
}
@VisibleForTesting
final class WrrSubchannel extends ForwardingSubchannel {
private final Subchannel delegate;
private final OrcaOobReportListener oobListener = this::onLoadReport;
private final OrcaPerRequestReportListener perRpcListener = this::onLoadReport;
private volatile long lastUpdated;
private volatile long nonEmptySince;
private volatile double weight;
WrrSubchannel(Subchannel delegate) {
this.delegate = checkNotNull(delegate, "delegate");
}
@VisibleForTesting
void onLoadReport(MetricReport report) {
double newWeight = report.getCpuUtilization() == 0 ? 0 :
report.getQps() / report.getCpuUtilization();
if (newWeight == 0) {
return;
}
if (nonEmptySince == infTime) {
nonEmptySince = ticker.nanoTime();
}
lastUpdated = ticker.nanoTime();
weight = newWeight;
}
@Override
public void start(SubchannelStateListener listener) {
delegate().start(new SubchannelStateListener() {
@Override
public void onSubchannelState(ConnectivityStateInfo newState) {
if (newState.getState().equals(ConnectivityState.READY)) {
nonEmptySince = infTime;
}
listener.onSubchannelState(newState);
}
});
}
private double getWeight() {
if (config == null) {
return 0;
}
long now = ticker.nanoTime();
if (now - lastUpdated >= config.weightExpirationPeriodNanos) {
nonEmptySince = infTime;
return 0;
} else if (now - nonEmptySince < config.blackoutPeriodNanos
&& config.blackoutPeriodNanos > 0) {
return 0;
} else {
return weight;
}
}
@Override
protected Subchannel delegate() {
return delegate;
}
}
@VisibleForTesting
final class WeightedRoundRobinPicker extends ReadyPicker {
private final List<Subchannel> list;
private volatile EdfScheduler scheduler;
private volatile boolean rrMode;
WeightedRoundRobinPicker(List<Subchannel> list, int startIndex) {
super(checkNotNull(list, "list"), startIndex);
Preconditions.checkArgument(!list.isEmpty(), "empty list");
this.list = list;
updateWeight();
}
@Override
public PickResult pickSubchannel(PickSubchannelArgs args) {
if (rrMode) {
return super.pickSubchannel(args);
}
int pickIndex = scheduler.pick();
WrrSubchannel subchannel = (WrrSubchannel) list.get(pickIndex);
if (!config.enableOobLoadReport) {
return PickResult.withSubchannel(
subchannel,
OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(
subchannel.perRpcListener));
} else {
return PickResult.withSubchannel(subchannel);
}
}
private void updateWeight() {
int weightedChannelCount = 0;
double avgWeight = 0;
for (Subchannel value : list) {
double newWeight = ((WrrSubchannel) value).getWeight();
if (newWeight > 0) {
avgWeight += newWeight;
weightedChannelCount++;
}
}
if (weightedChannelCount < 2) {
rrMode = true;
return;
}
EdfScheduler scheduler = new EdfScheduler(list.size());
avgWeight /= 1.0 * weightedChannelCount;
for (int i = 0; i < list.size(); i++) {
WrrSubchannel subchannel = (WrrSubchannel) list.get(i);
double newWeight = subchannel.getWeight();
scheduler.add(i, newWeight > 0 ? newWeight : avgWeight);
}
this.scheduler = scheduler;
rrMode = false;
}
@Override
public String toString() {
return MoreObjects.toStringHelper(WeightedRoundRobinPicker.class)
.add("list", list).add("rrMode", rrMode).toString();
}
@VisibleForTesting
List<Subchannel> getList() {
return list;
}
@Override
public boolean isEquivalentTo(RoundRobinPicker picker) {
if (!(picker instanceof WeightedRoundRobinPicker)) {
return false;
}
WeightedRoundRobinPicker other = (WeightedRoundRobinPicker) picker;
// the lists cannot contain duplicate subchannels
return other == this
|| (list.size() == other.list.size() && new HashSet<>(list).containsAll(other.list));
}
}
/**
* The earliest deadline first implementation in which each object is
* chosen deterministically and periodically with frequency proportional to its weight.
*
* <p>Specifically, each object added to chooser is given a deadline equal to the multiplicative
* inverse of its weight. The place of each object in its deadline is tracked, and each call to
* choose returns the object with the least remaining time in its deadline.
* (Ties are broken by the order in which the children were added to the chooser.) The deadline
* advances by the multiplicative inverse of the object's weight.
* For example, if items A and B are added with weights 0.5 and 0.2, successive chooses return:
*
* <ul>
* <li>In the first call, the deadlines are A=2 (1/0.5) and B=5 (1/0.2), so A is returned.
* The deadline of A is updated to 4.
* <li>Next, the remaining deadlines are A=4 and B=5, so A is returned. The deadline of A (2) is
* updated to A=6.
* <li>Remaining deadlines are A=6 and B=5, so B is returned. The deadline of B is updated with
* with B=10.
* <li>Remaining deadlines are A=6 and B=10, so A is returned. The deadline of A is updated with
* A=8.
* <li>Remaining deadlines are A=8 and B=10, so A is returned. The deadline of A is updated with
* A=10.
* <li>Remaining deadlines are A=10 and B=10, so A is returned. The deadline of A is updated
* with A=12.
* <li>Remaining deadlines are A=12 and B=10, so B is returned. The deadline of B is updated
* with B=15.
* <li>etc.
* </ul>
*
* <p>In short: the entry with the highest weight is preferred.
*
* <ul>
* <li>add() - O(lg n)
* <li>pick() - O(lg n)
* </ul>
*
*/
@VisibleForTesting
static final class EdfScheduler {
private final PriorityQueue<ObjectState> prioQueue;
/**
* Weights below this value will be upped to this minimum weight.
*/
private static final double MINIMUM_WEIGHT = 0.0001;
private final Object lock = new Object();
/**
* Use the item's deadline as the order in the priority queue. If the deadlines are the same,
* use the index. Index should be unique.
*/
EdfScheduler(int initialCapacity) {
this.prioQueue = new PriorityQueue<ObjectState>(initialCapacity, (o1, o2) -> {
if (o1.deadline == o2.deadline) {
return Integer.compare(o1.index, o2.index);
} else {
return Double.compare(o1.deadline, o2.deadline);
}
});
}
/**
* Adds the item in the scheduler. This is not thread safe.
*
* @param index The field {@link ObjectState#index} to be added
* @param weight positive weight for the added object
*/
void add(int index, double weight) {
checkArgument(weight > 0.0, "Weights need to be positive.");
ObjectState state = new ObjectState(Math.max(weight, MINIMUM_WEIGHT), index);
state.deadline = 1 / state.weight;
// TODO(zivy): randomize the initial deadline.
prioQueue.add(state);
}
/**
* Picks the next WRR object.
*/
int pick() {
synchronized (lock) {
ObjectState minObject = prioQueue.remove();
minObject.deadline += 1.0 / minObject.weight;
prioQueue.add(minObject);
return minObject.index;
}
}
}
/** Holds the state of the object. */
@VisibleForTesting
static class ObjectState {
private final double weight;
private final int index;
private volatile double deadline;
ObjectState(double weight, int index) {
this.weight = weight;
this.index = index;
}
}
static final class WeightedRoundRobinLoadBalancerConfig {
final long blackoutPeriodNanos;
final long weightExpirationPeriodNanos;
final boolean enableOobLoadReport;
final long oobReportingPeriodNanos;
final long weightUpdatePeriodNanos;
public static Builder newBuilder() {
return new Builder();
}
private WeightedRoundRobinLoadBalancerConfig(long blackoutPeriodNanos,
long weightExpirationPeriodNanos,
boolean enableOobLoadReport,
long oobReportingPeriodNanos,
long weightUpdatePeriodNanos) {
this.blackoutPeriodNanos = blackoutPeriodNanos;
this.weightExpirationPeriodNanos = weightExpirationPeriodNanos;
this.enableOobLoadReport = enableOobLoadReport;
this.oobReportingPeriodNanos = oobReportingPeriodNanos;
this.weightUpdatePeriodNanos = weightUpdatePeriodNanos;
}
static final class Builder {
long blackoutPeriodNanos = 10_000_000_000L; // 10s
long weightExpirationPeriodNanos = 180_000_000_000L; //3min
boolean enableOobLoadReport = false;
long oobReportingPeriodNanos = 10_000_000_000L; // 10s
long weightUpdatePeriodNanos = 1_000_000_000L; // 1s
private Builder() {
}
Builder setBlackoutPeriodNanos(long blackoutPeriodNanos) {
this.blackoutPeriodNanos = blackoutPeriodNanos;
return this;
}
Builder setWeightExpirationPeriodNanos(long weightExpirationPeriodNanos) {
this.weightExpirationPeriodNanos = weightExpirationPeriodNanos;
return this;
}
Builder setEnableOobLoadReport(boolean enableOobLoadReport) {
this.enableOobLoadReport = enableOobLoadReport;
return this;
}
Builder setOobReportingPeriodNanos(long oobReportingPeriodNanos) {
this.oobReportingPeriodNanos = oobReportingPeriodNanos;
return this;
}
Builder setWeightUpdatePeriodNanos(long weightUpdatePeriodNanos) {
this.weightUpdatePeriodNanos = weightUpdatePeriodNanos;
return this;
}
WeightedRoundRobinLoadBalancerConfig build() {
return new WeightedRoundRobinLoadBalancerConfig(blackoutPeriodNanos,
weightExpirationPeriodNanos, enableOobLoadReport, oobReportingPeriodNanos,
weightUpdatePeriodNanos);
}
}
}
}

View File

@ -0,0 +1,94 @@
/*
* Copyright 2023 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.Deadline;
import io.grpc.ExperimentalApi;
import io.grpc.Internal;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancer.Helper;
import io.grpc.LoadBalancerProvider;
import io.grpc.NameResolver.ConfigOrError;
import io.grpc.internal.JsonUtil;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig;
import java.util.Map;
/**
* Provides a {@link WeightedRoundRobinLoadBalancer}.
* */
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/9885")
@Internal
public final class WeightedRoundRobinLoadBalancerProvider extends LoadBalancerProvider {
@VisibleForTesting
static final long MIN_WEIGHT_UPDATE_PERIOD_NANOS = 100_000_000L; // 100ms
static final String SCHEME = "weighted_round_robin_experimental";
@Override
public LoadBalancer newLoadBalancer(Helper helper) {
return new WeightedRoundRobinLoadBalancer(helper, Deadline.getSystemTicker());
}
@Override
public boolean isAvailable() {
return true;
}
@Override
public int getPriority() {
return 5;
}
@Override
public String getPolicyName() {
return SCHEME;
}
@Override
public ConfigOrError parseLoadBalancingPolicyConfig(Map<String, ?> rawConfig) {
Long blackoutPeriodNanos = JsonUtil.getStringAsDuration(rawConfig, "blackoutPeriod");
Long weightExpirationPeriodNanos =
JsonUtil.getStringAsDuration(rawConfig, "weightExpirationPeriod");
Long oobReportingPeriodNanos = JsonUtil.getStringAsDuration(rawConfig, "oobReportingPeriod");
Boolean enableOobLoadReport = JsonUtil.getBoolean(rawConfig, "enableOobLoadReport");
Long weightUpdatePeriodNanos = JsonUtil.getStringAsDuration(rawConfig, "weightUpdatePeriod");
WeightedRoundRobinLoadBalancerConfig.Builder configBuilder =
WeightedRoundRobinLoadBalancerConfig.newBuilder();
if (blackoutPeriodNanos != null) {
configBuilder.setBlackoutPeriodNanos(blackoutPeriodNanos);
}
if (weightExpirationPeriodNanos != null) {
configBuilder.setWeightExpirationPeriodNanos(weightExpirationPeriodNanos);
}
if (enableOobLoadReport != null) {
configBuilder.setEnableOobLoadReport(enableOobLoadReport);
}
if (oobReportingPeriodNanos != null) {
configBuilder.setOobReportingPeriodNanos(oobReportingPeriodNanos);
}
if (weightUpdatePeriodNanos != null) {
configBuilder.setWeightUpdatePeriodNanos(weightUpdatePeriodNanos);
if (weightUpdatePeriodNanos < MIN_WEIGHT_UPDATE_PERIOD_NANOS) {
configBuilder.setWeightUpdatePeriodNanos(MIN_WEIGHT_UPDATE_PERIOD_NANOS);
}
}
return ConfigOrError.fromConfig(configBuilder.build());
}
}

View File

@ -133,7 +133,7 @@ class XdsClusterResource extends XdsResourceType<CdsUpdate> {
CdsUpdate.Builder updateBuilder = structOrError.getStruct(); CdsUpdate.Builder updateBuilder = structOrError.getStruct();
ImmutableMap<String, ?> lbPolicyConfig = LoadBalancerConfigFactory.newConfig(cluster, ImmutableMap<String, ?> lbPolicyConfig = LoadBalancerConfigFactory.newConfig(cluster,
enableLeastRequest, enableCustomLbConfig); enableLeastRequest, enableCustomLbConfig, enableWrr);
// Validate the LB config by trying to parse it with the corresponding LB provider. // Validate the LB config by trying to parse it with the corresponding LB provider.
LbConfig lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(lbPolicyConfig); LbConfig lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(lbPolicyConfig);

View File

@ -59,6 +59,10 @@ abstract class XdsResourceType<T extends ResourceUpdate> {
!Strings.isNullOrEmpty(System.getenv("GRPC_EXPERIMENTAL_ENABLE_LEAST_REQUEST")) !Strings.isNullOrEmpty(System.getenv("GRPC_EXPERIMENTAL_ENABLE_LEAST_REQUEST"))
? Boolean.parseBoolean(System.getenv("GRPC_EXPERIMENTAL_ENABLE_LEAST_REQUEST")) ? Boolean.parseBoolean(System.getenv("GRPC_EXPERIMENTAL_ENABLE_LEAST_REQUEST"))
: Boolean.parseBoolean(System.getProperty("io.grpc.xds.experimentalEnableLeastRequest")); : Boolean.parseBoolean(System.getProperty("io.grpc.xds.experimentalEnableLeastRequest"));
@VisibleForTesting
static boolean enableWrr = getFlag("GRPC_EXPERIMENTAL_XDS_WRR_LB", false);
@VisibleForTesting @VisibleForTesting
static boolean enableCustomLbConfig = getFlag("GRPC_EXPERIMENTAL_XDS_CUSTOM_LB_CONFIG", true); static boolean enableCustomLbConfig = getFlag("GRPC_EXPERIMENTAL_XDS_CUSTOM_LB_CONFIG", true);
@VisibleForTesting @VisibleForTesting

View File

@ -7,3 +7,4 @@ io.grpc.xds.ClusterImplLoadBalancerProvider
io.grpc.xds.LeastRequestLoadBalancerProvider io.grpc.xds.LeastRequestLoadBalancerProvider
io.grpc.xds.RingHashLoadBalancerProvider io.grpc.xds.RingHashLoadBalancerProvider
io.grpc.xds.WrrLocalityLoadBalancerProvider io.grpc.xds.WrrLocalityLoadBalancerProvider
io.grpc.xds.WeightedRoundRobinLoadBalancerProvider

View File

@ -24,6 +24,8 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.protobuf.Any; import com.google.protobuf.Any;
import com.google.protobuf.BoolValue;
import com.google.protobuf.Duration;
import com.google.protobuf.Struct; import com.google.protobuf.Struct;
import com.google.protobuf.UInt32Value; import com.google.protobuf.UInt32Value;
import com.google.protobuf.UInt64Value; import com.google.protobuf.UInt64Value;
@ -36,6 +38,7 @@ import io.envoyproxy.envoy.config.cluster.v3.Cluster.RingHashLbConfig.HashFuncti
import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy; import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy;
import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy.Policy; import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy.Policy;
import io.envoyproxy.envoy.config.core.v3.TypedExtensionConfig; import io.envoyproxy.envoy.config.core.v3.TypedExtensionConfig;
import io.envoyproxy.envoy.extensions.load_balancing_policies.client_side_weighted_round_robin.v3.ClientSideWeightedRoundRobin;
import io.envoyproxy.envoy.extensions.load_balancing_policies.least_request.v3.LeastRequest; import io.envoyproxy.envoy.extensions.load_balancing_policies.least_request.v3.LeastRequest;
import io.envoyproxy.envoy.extensions.load_balancing_policies.ring_hash.v3.RingHash; import io.envoyproxy.envoy.extensions.load_balancing_policies.ring_hash.v3.RingHash;
import io.envoyproxy.envoy.extensions.load_balancing_policies.round_robin.v3.RoundRobin; import io.envoyproxy.envoy.extensions.load_balancing_policies.round_robin.v3.RoundRobin;
@ -78,6 +81,17 @@ public class LoadBalancerConfigFactoryTest {
LeastRequest.newBuilder().setChoiceCount(UInt32Value.of(LEAST_REQUEST_CHOICE_COUNT)) LeastRequest.newBuilder().setChoiceCount(UInt32Value.of(LEAST_REQUEST_CHOICE_COUNT))
.build()))).build(); .build()))).build();
private static final Policy WRR_POLICY = Policy.newBuilder()
.setTypedExtensionConfig(TypedExtensionConfig.newBuilder()
.setName("backend")
.setTypedConfig(
Any.pack(ClientSideWeightedRoundRobin.newBuilder()
.setBlackoutPeriod(Duration.newBuilder().setSeconds(287).build())
.setEnableOobLoadReport(
BoolValue.newBuilder().setValue(true).build())
.build()))
.build())
.build();
private static final String CUSTOM_POLICY_NAME = "myorg.MyCustomLeastRequestPolicy"; private static final String CUSTOM_POLICY_NAME = "myorg.MyCustomLeastRequestPolicy";
private static final String CUSTOM_POLICY_FIELD_KEY = "choiceCount"; private static final String CUSTOM_POLICY_FIELD_KEY = "choiceCount";
private static final double CUSTOM_POLICY_FIELD_VALUE = 2; private static final double CUSTOM_POLICY_FIELD_VALUE = 2;
@ -101,6 +115,11 @@ public class LoadBalancerConfigFactoryTest {
private static final LbConfig VALID_ROUND_ROBIN_CONFIG = new LbConfig("wrr_locality_experimental", private static final LbConfig VALID_ROUND_ROBIN_CONFIG = new LbConfig("wrr_locality_experimental",
ImmutableMap.of("childPolicy", ImmutableMap.of("childPolicy",
ImmutableList.of(ImmutableMap.of("round_robin", ImmutableMap.of())))); ImmutableList.of(ImmutableMap.of("round_robin", ImmutableMap.of()))));
private static final LbConfig VALID_WRR_CONFIG = new LbConfig("wrr_locality_experimental",
ImmutableMap.of("childPolicy", ImmutableList.of(
ImmutableMap.of("weighted_round_robin_experimental",
ImmutableMap.of("blackoutPeriod","287s", "enableOobLoadReport", true )))));
private static final LbConfig VALID_RING_HASH_CONFIG = new LbConfig("ring_hash_experimental", private static final LbConfig VALID_RING_HASH_CONFIG = new LbConfig("ring_hash_experimental",
ImmutableMap.of("minRingSize", (double) RING_HASH_MIN_RING_SIZE, "maxRingSize", ImmutableMap.of("minRingSize", (double) RING_HASH_MIN_RING_SIZE, "maxRingSize",
(double) RING_HASH_MAX_RING_SIZE)); (double) RING_HASH_MAX_RING_SIZE));
@ -123,14 +142,46 @@ public class LoadBalancerConfigFactoryTest {
public void roundRobin() throws ResourceInvalidException { public void roundRobin() throws ResourceInvalidException {
Cluster cluster = newCluster(buildWrrPolicy(ROUND_ROBIN_POLICY)); Cluster cluster = newCluster(buildWrrPolicy(ROUND_ROBIN_POLICY));
assertThat(newLbConfig(cluster, true, true)).isEqualTo(VALID_ROUND_ROBIN_CONFIG); assertThat(newLbConfig(cluster, true, true, true)).isEqualTo(VALID_ROUND_ROBIN_CONFIG);
}
@Test
public void weightedRoundRobin() throws ResourceInvalidException {
Cluster cluster = newCluster(buildWrrPolicy(WRR_POLICY));
assertThat(newLbConfig(cluster, true, true, true)).isEqualTo(VALID_WRR_CONFIG);
}
@Test
public void weightedRoundRobin_invalid() throws ResourceInvalidException {
Cluster cluster = newCluster(buildWrrPolicy(Policy.newBuilder()
.setTypedExtensionConfig(TypedExtensionConfig.newBuilder()
.setName("backend")
.setTypedConfig(
Any.pack(ClientSideWeightedRoundRobin.newBuilder()
.setBlackoutPeriod(Duration.newBuilder().setNanos(1000000000).build())
.setEnableOobLoadReport(
BoolValue.newBuilder().setValue(true).build())
.build()))
.build())
.build()));
assertResourceInvalidExceptionThrown(cluster, true, true, true,
"Invalid duration in weighted round robin config");
}
@Test
public void weightedRoundRobin_fallback_roundrobin() throws ResourceInvalidException {
Cluster cluster = newCluster(buildWrrPolicy(WRR_POLICY, ROUND_ROBIN_POLICY));
assertThat(newLbConfig(cluster, true, true, false)).isEqualTo(VALID_ROUND_ROBIN_CONFIG);
} }
@Test @Test
public void roundRobin_legacy() throws ResourceInvalidException { public void roundRobin_legacy() throws ResourceInvalidException {
Cluster cluster = Cluster.newBuilder().setLbPolicy(LbPolicy.ROUND_ROBIN).build(); Cluster cluster = Cluster.newBuilder().setLbPolicy(LbPolicy.ROUND_ROBIN).build();
assertThat(newLbConfig(cluster, true, true)).isEqualTo(VALID_ROUND_ROBIN_CONFIG); assertThat(newLbConfig(cluster, true, true, true)).isEqualTo(VALID_ROUND_ROBIN_CONFIG);
} }
@Test @Test
@ -139,7 +190,7 @@ public class LoadBalancerConfigFactoryTest {
.setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder().addPolicies(RING_HASH_POLICY)) .setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder().addPolicies(RING_HASH_POLICY))
.build(); .build();
assertThat(newLbConfig(cluster, true, true)).isEqualTo(VALID_RING_HASH_CONFIG); assertThat(newLbConfig(cluster, true, true, true)).isEqualTo(VALID_RING_HASH_CONFIG);
} }
@Test @Test
@ -149,7 +200,7 @@ public class LoadBalancerConfigFactoryTest {
.setMaximumRingSize(UInt64Value.of(RING_HASH_MAX_RING_SIZE)) .setMaximumRingSize(UInt64Value.of(RING_HASH_MAX_RING_SIZE))
.setHashFunction(HashFunction.XX_HASH)).build(); .setHashFunction(HashFunction.XX_HASH)).build();
assertThat(newLbConfig(cluster, true, true)).isEqualTo(VALID_RING_HASH_CONFIG); assertThat(newLbConfig(cluster, true, true, true)).isEqualTo(VALID_RING_HASH_CONFIG);
} }
@Test @Test
@ -161,7 +212,7 @@ public class LoadBalancerConfigFactoryTest {
.setMaximumRingSize(UInt64Value.of(RING_HASH_MAX_RING_SIZE)) .setMaximumRingSize(UInt64Value.of(RING_HASH_MAX_RING_SIZE))
.setHashFunction(RingHash.HashFunction.MURMUR_HASH_2).build()))).build()); .setHashFunction(RingHash.HashFunction.MURMUR_HASH_2).build()))).build());
assertResourceInvalidExceptionThrown(cluster, true, true, "Invalid ring hash function"); assertResourceInvalidExceptionThrown(cluster, true, true, true, "Invalid ring hash function");
} }
@Test @Test
@ -169,7 +220,7 @@ public class LoadBalancerConfigFactoryTest {
Cluster cluster = Cluster.newBuilder().setLbPolicy(LbPolicy.RING_HASH).setRingHashLbConfig( Cluster cluster = Cluster.newBuilder().setLbPolicy(LbPolicy.RING_HASH).setRingHashLbConfig(
RingHashLbConfig.newBuilder().setHashFunction(HashFunction.MURMUR_HASH_2)).build(); RingHashLbConfig.newBuilder().setHashFunction(HashFunction.MURMUR_HASH_2)).build();
assertResourceInvalidExceptionThrown(cluster, true, true, "invalid ring hash function"); assertResourceInvalidExceptionThrown(cluster, true, true, true, "invalid ring hash function");
} }
@Test @Test
@ -178,7 +229,7 @@ public class LoadBalancerConfigFactoryTest {
.setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder().addPolicies(LEAST_REQUEST_POLICY)) .setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder().addPolicies(LEAST_REQUEST_POLICY))
.build(); .build();
assertThat(newLbConfig(cluster, true, true)).isEqualTo(VALID_LEAST_REQUEST_CONFIG); assertThat(newLbConfig(cluster, true, true, true)).isEqualTo(VALID_LEAST_REQUEST_CONFIG);
} }
@Test @Test
@ -190,7 +241,7 @@ public class LoadBalancerConfigFactoryTest {
LeastRequestLbConfig.newBuilder() LeastRequestLbConfig.newBuilder()
.setChoiceCount(UInt32Value.of(LEAST_REQUEST_CHOICE_COUNT))).build(); .setChoiceCount(UInt32Value.of(LEAST_REQUEST_CHOICE_COUNT))).build();
LbConfig lbConfig = newLbConfig(cluster, true, true); LbConfig lbConfig = newLbConfig(cluster, true, true, true);
assertThat(lbConfig.getPolicyName()).isEqualTo("wrr_locality_experimental"); assertThat(lbConfig.getPolicyName()).isEqualTo("wrr_locality_experimental");
List<LbConfig> childConfigs = ServiceConfigUtil.unwrapLoadBalancingConfigList( List<LbConfig> childConfigs = ServiceConfigUtil.unwrapLoadBalancingConfigList(
@ -207,14 +258,15 @@ public class LoadBalancerConfigFactoryTest {
Cluster cluster = Cluster.newBuilder().setLbPolicy(LbPolicy.LEAST_REQUEST).build(); Cluster cluster = Cluster.newBuilder().setLbPolicy(LbPolicy.LEAST_REQUEST).build();
assertResourceInvalidExceptionThrown(cluster, false, true, "unsupported lb policy"); assertResourceInvalidExceptionThrown(cluster, false, true, true, "unsupported lb policy");
} }
@Test @Test
public void customRootLb_providerRegistered() throws ResourceInvalidException { public void customRootLb_providerRegistered() throws ResourceInvalidException {
LoadBalancerRegistry.getDefaultRegistry().register(CUSTOM_POLICY_PROVIDER); LoadBalancerRegistry.getDefaultRegistry().register(CUSTOM_POLICY_PROVIDER);
assertThat(newLbConfig(newCluster(CUSTOM_POLICY), false, true)).isEqualTo(VALID_CUSTOM_CONFIG); assertThat(newLbConfig(newCluster(CUSTOM_POLICY), false, true,
true)).isEqualTo(VALID_CUSTOM_CONFIG);
} }
@Test @Test
@ -223,7 +275,7 @@ public class LoadBalancerConfigFactoryTest {
.setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder().addPolicies(CUSTOM_POLICY)) .setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder().addPolicies(CUSTOM_POLICY))
.build(); .build();
assertResourceInvalidExceptionThrown(cluster, false, true, "Invalid LoadBalancingPolicy"); assertResourceInvalidExceptionThrown(cluster, false, true, true,"Invalid LoadBalancingPolicy");
} }
// When a provider for the endpoint picking custom policy is available, the configuration should // When a provider for the endpoint picking custom policy is available, the configuration should
@ -235,7 +287,7 @@ public class LoadBalancerConfigFactoryTest {
Cluster cluster = Cluster.newBuilder().setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder() Cluster cluster = Cluster.newBuilder().setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder()
.addPolicies(buildWrrPolicy(CUSTOM_POLICY, ROUND_ROBIN_POLICY))).build(); .addPolicies(buildWrrPolicy(CUSTOM_POLICY, ROUND_ROBIN_POLICY))).build();
assertThat(newLbConfig(cluster, false, true)).isEqualTo(VALID_CUSTOM_CONFIG_IN_WRR); assertThat(newLbConfig(cluster, false, true, true)).isEqualTo(VALID_CUSTOM_CONFIG_IN_WRR);
} }
// When a provider for the endpoint picking custom policy is available, the configuration should // When a provider for the endpoint picking custom policy is available, the configuration should
@ -247,7 +299,7 @@ public class LoadBalancerConfigFactoryTest {
Cluster cluster = Cluster.newBuilder().setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder() Cluster cluster = Cluster.newBuilder().setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder()
.addPolicies(buildWrrPolicy(CUSTOM_POLICY_UDPA, ROUND_ROBIN_POLICY))).build(); .addPolicies(buildWrrPolicy(CUSTOM_POLICY_UDPA, ROUND_ROBIN_POLICY))).build();
assertThat(newLbConfig(cluster, false, true)).isEqualTo(VALID_CUSTOM_CONFIG_IN_WRR); assertThat(newLbConfig(cluster, false, true, true)).isEqualTo(VALID_CUSTOM_CONFIG_IN_WRR);
} }
// When a provider for the custom wrr_locality child policy is NOT available, we should fall back // When a provider for the custom wrr_locality child policy is NOT available, we should fall back
@ -257,7 +309,7 @@ public class LoadBalancerConfigFactoryTest {
Cluster cluster = Cluster.newBuilder().setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder() Cluster cluster = Cluster.newBuilder().setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder()
.addPolicies(buildWrrPolicy(CUSTOM_POLICY, ROUND_ROBIN_POLICY))).build(); .addPolicies(buildWrrPolicy(CUSTOM_POLICY, ROUND_ROBIN_POLICY))).build();
assertThat(newLbConfig(cluster, false, true)).isEqualTo(VALID_ROUND_ROBIN_CONFIG); assertThat(newLbConfig(cluster, false, true, true)).isEqualTo(VALID_ROUND_ROBIN_CONFIG);
} }
// When a provider for the custom wrr_locality child policy is NOT available and no alternative // When a provider for the custom wrr_locality child policy is NOT available and no alternative
@ -267,7 +319,7 @@ public class LoadBalancerConfigFactoryTest {
Cluster cluster = Cluster.newBuilder().setLoadBalancingPolicy( Cluster cluster = Cluster.newBuilder().setLoadBalancingPolicy(
LoadBalancingPolicy.newBuilder().addPolicies(buildWrrPolicy(CUSTOM_POLICY))).build(); LoadBalancingPolicy.newBuilder().addPolicies(buildWrrPolicy(CUSTOM_POLICY))).build();
assertResourceInvalidExceptionThrown(cluster, false, true, "Invalid LoadBalancingPolicy"); assertResourceInvalidExceptionThrown(cluster, false, true, true, "Invalid LoadBalancingPolicy");
} }
@Test @Test
@ -278,7 +330,7 @@ public class LoadBalancerConfigFactoryTest {
.build(); .build();
// Custom LB flag not set, so we use old logic that will default to round_robin. // Custom LB flag not set, so we use old logic that will default to round_robin.
assertThat(newLbConfig(cluster, true, false)).isEqualTo(VALID_ROUND_ROBIN_CONFIG); assertThat(newLbConfig(cluster, true, false, true)).isEqualTo(VALID_ROUND_ROBIN_CONFIG);
} }
@Test @Test
@ -305,7 +357,7 @@ public class LoadBalancerConfigFactoryTest {
buildWrrPolicy( buildWrrPolicy(
ROUND_ROBIN_POLICY))))))))))))))))))).build(); ROUND_ROBIN_POLICY))))))))))))))))))).build();
assertResourceInvalidExceptionThrown(cluster, false, true, assertResourceInvalidExceptionThrown(cluster, false, true, true,
"Maximum LB config recursion depth reached"); "Maximum LB config recursion depth reached");
} }
@ -322,16 +374,17 @@ public class LoadBalancerConfigFactoryTest {
} }
private LbConfig newLbConfig(Cluster cluster, boolean enableLeastRequest, private LbConfig newLbConfig(Cluster cluster, boolean enableLeastRequest,
boolean enableCustomConfig) boolean enableCustomConfig, boolean enableWrr)
throws ResourceInvalidException { throws ResourceInvalidException {
return ServiceConfigUtil.unwrapLoadBalancingConfig( return ServiceConfigUtil.unwrapLoadBalancingConfig(
LoadBalancerConfigFactory.newConfig(cluster, enableLeastRequest, enableCustomConfig)); LoadBalancerConfigFactory.newConfig(cluster, enableLeastRequest, enableCustomConfig,
enableWrr));
} }
private void assertResourceInvalidExceptionThrown(Cluster cluster, boolean enableLeastRequest, private void assertResourceInvalidExceptionThrown(Cluster cluster, boolean enableLeastRequest,
boolean enableCustomConfig, String expectedMessage) { boolean enableCustomConfig, boolean enableWrr, String expectedMessage) {
try { try {
newLbConfig(cluster, enableLeastRequest, enableCustomConfig); newLbConfig(cluster, enableLeastRequest, enableCustomConfig, enableWrr);
} catch (ResourceInvalidException e) { } catch (ResourceInvalidException e) {
assertThat(e).hasMessageThat().contains(expectedMessage); assertThat(e).hasMessageThat().contains(expectedMessage);
return; return;

View File

@ -0,0 +1,116 @@
/*
* Copyright 2023 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import io.grpc.InternalServiceProviders;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancerProvider;
import io.grpc.NameResolver.ConfigOrError;
import io.grpc.SynchronizationContext;
import io.grpc.internal.FakeClock;
import io.grpc.internal.JsonParser;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig;
import java.io.IOException;
import java.util.Map;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link WeightedRoundRobinLoadBalancerProvider}. */
@RunWith(JUnit4.class)
public class WeightedRoundRobinLoadBalancerProviderTest {
private final WeightedRoundRobinLoadBalancerProvider provider =
new WeightedRoundRobinLoadBalancerProvider();
private final SynchronizationContext syncContext = new SynchronizationContext(
new Thread.UncaughtExceptionHandler() {
@Override
public void uncaughtException(Thread t, Throwable e) {
throw new AssertionError(e);
}
});
@Test
public void provided() {
for (LoadBalancerProvider current : InternalServiceProviders.getCandidatesViaServiceLoader(
LoadBalancerProvider.class, getClass().getClassLoader())) {
if (current instanceof WeightedRoundRobinLoadBalancerProvider) {
return;
}
}
fail("WeightedRoundRobinLoadBalancerProvider not registered");
}
@Test
public void providesLoadBalancer() {
LoadBalancer.Helper helper = mock(LoadBalancer.Helper.class);
when(helper.getSynchronizationContext()).thenReturn(syncContext);
when(helper.getScheduledExecutorService()).thenReturn(
new FakeClock().getScheduledExecutorService());
assertThat(provider.newLoadBalancer(helper))
.isInstanceOf(WeightedRoundRobinLoadBalancer.class);
}
@Test
public void parseLoadBalancingConfig() throws IOException {
String lbConfig =
"{\"blackoutPeriod\" : \"20s\","
+ " \"weightExpirationPeriod\" : \"300s\","
+ " \"oobReportingPeriod\" : \"100s\","
+ " \"enableOobLoadReport\" : true,"
+ " \"weightUpdatePeriod\" : \"2s\""
+ " }";
ConfigOrError configOrError = provider.parseLoadBalancingPolicyConfig(
parseJsonObject(lbConfig));
assertThat(configOrError.getConfig()).isNotNull();
WeightedRoundRobinLoadBalancerConfig config =
(WeightedRoundRobinLoadBalancerConfig) configOrError.getConfig();
assertThat(config.blackoutPeriodNanos).isEqualTo(20_000_000_000L);
assertThat(config.weightExpirationPeriodNanos).isEqualTo(300_000_000_000L);
assertThat(config.oobReportingPeriodNanos).isEqualTo(100_000_000_000L);
assertThat(config.enableOobLoadReport).isEqualTo(true);
assertThat(config.weightUpdatePeriodNanos).isEqualTo(2_000_000_000L);
}
@Test
public void parseLoadBalancingConfigDefaultValues() throws IOException {
String lbConfig = "{\"weightUpdatePeriod\" : \"0.02s\"}";
ConfigOrError configOrError = provider.parseLoadBalancingPolicyConfig(
parseJsonObject(lbConfig));
assertThat(configOrError.getConfig()).isNotNull();
WeightedRoundRobinLoadBalancerConfig config =
(WeightedRoundRobinLoadBalancerConfig) configOrError.getConfig();
assertThat(config.blackoutPeriodNanos).isEqualTo(10_000_000_000L);
assertThat(config.weightExpirationPeriodNanos).isEqualTo(180_000_000_000L);
assertThat(config.enableOobLoadReport).isEqualTo(false);
assertThat(config.weightUpdatePeriodNanos).isEqualTo(100_000_000L);
}
@SuppressWarnings("unchecked")
private static Map<String, ?> parseJsonObject(String json) throws IOException {
return (Map<String, ?>) JsonParser.parse(json);
}
}

View File

@ -0,0 +1,673 @@
/*
* Copyright 2023 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds;
import static com.google.common.truth.Truth.assertThat;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.github.xds.data.orca.v3.OrcaLoadReport;
import com.github.xds.service.orca.v3.OrcaLoadReportRequest;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.protobuf.Duration;
import io.grpc.Attributes;
import io.grpc.Channel;
import io.grpc.ChannelLogger;
import io.grpc.ClientCall;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
import io.grpc.EquivalentAddressGroup;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancer.CreateSubchannelArgs;
import io.grpc.LoadBalancer.Helper;
import io.grpc.LoadBalancer.PickResult;
import io.grpc.LoadBalancer.ResolvedAddresses;
import io.grpc.LoadBalancer.Subchannel;
import io.grpc.LoadBalancer.SubchannelStateListener;
import io.grpc.SynchronizationContext;
import io.grpc.internal.FakeClock;
import io.grpc.services.InternalCallMetricRecorder;
import io.grpc.services.MetricReport;
import io.grpc.util.RoundRobinLoadBalancer.EmptyPicker;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.EdfScheduler;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinPicker;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.WrrSubchannel;
import java.net.SocketAddress;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;
import org.mockito.stubbing.Answer;
@RunWith(JUnit4.class)
public class WeightedRoundRobinLoadBalancerTest {
@Rule
public final MockitoRule mockito = MockitoJUnit.rule();
@Mock
Helper helper;
@Mock
private LoadBalancer.PickSubchannelArgs mockArgs;
@Captor
private ArgumentCaptor<WeightedRoundRobinPicker> pickerCaptor;
private final List<EquivalentAddressGroup> servers = Lists.newArrayList();
private final Map<List<EquivalentAddressGroup>, Subchannel> subchannels = Maps.newLinkedHashMap();
private final Map<Subchannel, SubchannelStateListener> subchannelStateListeners =
Maps.newLinkedHashMap();
private final Queue<ClientCall<OrcaLoadReportRequest, OrcaLoadReport>> oobCalls =
new ConcurrentLinkedQueue<>();
private WeightedRoundRobinLoadBalancer wrr;
private final FakeClock fakeClock = new FakeClock();
private WeightedRoundRobinLoadBalancerConfig weightedConfig =
WeightedRoundRobinLoadBalancerConfig.newBuilder().build();
private static final Attributes.Key<String> MAJOR_KEY = Attributes.Key.create("major-key");
private final Attributes affinity =
Attributes.newBuilder().set(MAJOR_KEY, "I got the keys").build();
private final SynchronizationContext syncContext = new SynchronizationContext(
new Thread.UncaughtExceptionHandler() {
@Override
public void uncaughtException(Thread t, Throwable e) {
throw new AssertionError(e);
}
});
@Before
public void setup() {
for (int i = 0; i < 3; i++) {
SocketAddress addr = new FakeSocketAddress("server" + i);
EquivalentAddressGroup eag = new EquivalentAddressGroup(addr);
servers.add(eag);
Subchannel sc = mock(Subchannel.class);
Channel channel = mock(Channel.class);
when(channel.newCall(any(), any())).then(
new Answer<ClientCall<OrcaLoadReportRequest, OrcaLoadReport>>() {
@SuppressWarnings("unchecked")
@Override
public ClientCall<OrcaLoadReportRequest, OrcaLoadReport> answer(
InvocationOnMock invocation) throws Throwable {
ClientCall<OrcaLoadReportRequest, OrcaLoadReport> clientCall = mock(ClientCall.class);
oobCalls.add(clientCall);
return clientCall;
}
});
when(sc.asChannel()).thenReturn(channel);
subchannels.put(Arrays.asList(eag), sc);
}
when(helper.getSynchronizationContext()).thenReturn(syncContext);
when(helper.getScheduledExecutorService()).thenReturn(
fakeClock.getScheduledExecutorService());
when(helper.createSubchannel(any(CreateSubchannelArgs.class)))
.then(new Answer<Subchannel>() {
@Override
public Subchannel answer(InvocationOnMock invocation) throws Throwable {
CreateSubchannelArgs args = (CreateSubchannelArgs) invocation.getArguments()[0];
final Subchannel subchannel = subchannels.get(args.getAddresses());
when(subchannel.getAllAddresses()).thenReturn(args.getAddresses());
when(subchannel.getAttributes()).thenReturn(args.getAttributes());
when(subchannel.getChannelLogger()).thenReturn(mock(ChannelLogger.class));
doAnswer(
new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) throws Throwable {
subchannelStateListeners.put(
subchannel, (SubchannelStateListener) invocation.getArguments()[0]);
return null;
}
}).when(subchannel).start(any(SubchannelStateListener.class));
return subchannel;
}
});
wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker());
}
@Test
public void wrrLifeCycle() {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
verify(helper, times(3)).createSubchannel(
any(CreateSubchannelArgs.class));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel connectingSubchannel = it.next();
subchannelStateListeners.get(connectingSubchannel).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.CONNECTING));
verify(helper, times(2)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2);
assertThat(pickerCaptor.getAllValues().get(0).getList().size()).isEqualTo(1);
WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(1);
assertThat(weightedPicker.getList().size()).isEqualTo(2);
WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
weightedSubchannel1.onLoadReport(InternalCallMetricRecorder.createMetricReport(
0.1, 0.1, 1, new HashMap<>(), new HashMap<>()));
weightedSubchannel2.onLoadReport(InternalCallMetricRecorder.createMetricReport(
0.2, 0.1, 1, new HashMap<>(), new HashMap<>()));
assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1);
assertThat(weightedPicker.pickSubchannel(mockArgs)
.getSubchannel()).isEqualTo(weightedSubchannel1);
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder()
.setWeightUpdatePeriodNanos(500_000_000L) //.5s
.build();
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
syncContext.execute(() -> wrr.shutdown());
for (Subchannel subchannel: subchannels.values()) {
verify(subchannel).shutdown();
}
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(0);
verifyNoMoreInteractions(mockArgs);
}
@Test
public void enableOobLoadReportConfig() {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
verify(helper, times(3)).createSubchannel(
any(CreateSubchannelArgs.class));
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
verify(helper, times(2)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(1);
WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
weightedSubchannel1.onLoadReport(InternalCallMetricRecorder.createMetricReport(
0.1, 0.1, 1, new HashMap<>(), new HashMap<>()));
weightedSubchannel2.onLoadReport(InternalCallMetricRecorder.createMetricReport(
0.9, 0.1, 1, new HashMap<>(), new HashMap<>()));
assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1);
PickResult pickResult = weightedPicker.pickSubchannel(mockArgs);
assertThat(pickResult.getSubchannel()).isEqualTo(weightedSubchannel1);
assertThat(pickResult.getStreamTracerFactory()).isNotNull(); // verify per-request listener
assertThat(oobCalls.isEmpty()).isTrue();
weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder().setEnableOobLoadReport(true)
.setOobReportingPeriodNanos(20_030_000_000L)
.build();
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
pickResult = weightedPicker.pickSubchannel(mockArgs);
assertThat(pickResult.getSubchannel()).isEqualTo(weightedSubchannel1);
assertThat(pickResult.getStreamTracerFactory()).isNull();
OrcaLoadReportRequest golden = OrcaLoadReportRequest.newBuilder().setReportInterval(
Duration.newBuilder().setSeconds(20).setNanos(30000000).build()).build();
assertThat(oobCalls.size()).isEqualTo(2);
verify(oobCalls.poll()).sendMessage(eq(golden));
verify(oobCalls.poll()).sendMessage(eq(golden));
}
private void pickByWeight(MetricReport r1, MetricReport r2, MetricReport r3,
double subchannel1PickRatio, double subchannel2PickRatio,
double subchannel3PickRatio) {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
verify(helper, times(3)).createSubchannel(
any(CreateSubchannelArgs.class));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel3 = it.next();
subchannelStateListeners.get(readySubchannel3).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
verify(helper, times(3)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(2);
WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
WrrSubchannel weightedSubchannel3 = (WrrSubchannel) weightedPicker.getList().get(2);
weightedSubchannel1.onLoadReport(r1);
weightedSubchannel2.onLoadReport(r2);
weightedSubchannel3.onLoadReport(r3);
assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1);
Map<Subchannel, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 10000; i++) {
Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel();
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
assertThat(pickCount.size()).isEqualTo(3);
assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 10000.0 - subchannel1PickRatio))
.isAtMost(0.001);
assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 10000.0 - subchannel2PickRatio ))
.isAtMost(0.001);
assertThat(Math.abs(pickCount.get(weightedSubchannel3) / 10000.0 - subchannel3PickRatio ))
.isAtMost(0.001);
}
@Test
public void pickByWeight_LargeWeight() {
MetricReport report1 = InternalCallMetricRecorder.createMetricReport(
0.1, 0.1, 999, new HashMap<>(), new HashMap<>());
MetricReport report2 = InternalCallMetricRecorder.createMetricReport(
0.9, 0.1, 2, new HashMap<>(), new HashMap<>());
MetricReport report3 = InternalCallMetricRecorder.createMetricReport(
0.86, 0.1, 100, new HashMap<>(), new HashMap<>());
double totalWeight = 999 / 0.1 + 2 / 0.9 + 100 / 0.86;
pickByWeight(report1, report2, report3, 999 / 0.1 / totalWeight, 2 / 0.9 / totalWeight,
100 / 0.86 / totalWeight);
}
@Test
public void pickByWeight_normalWeight() {
MetricReport report1 = InternalCallMetricRecorder.createMetricReport(
0.12, 0.1, 22, new HashMap<>(), new HashMap<>());
MetricReport report2 = InternalCallMetricRecorder.createMetricReport(
0.28, 0.1, 40, new HashMap<>(), new HashMap<>());
MetricReport report3 = InternalCallMetricRecorder.createMetricReport(
0.86, 0.1, 100, new HashMap<>(), new HashMap<>());
double totalWeight = 22 / 0.12 + 40 / 0.28 + 100 / 0.86;
pickByWeight(report1, report2, report3, 22 / 0.12 / totalWeight,
40 / 0.28 / totalWeight, 100 / 0.86 / totalWeight
);
}
@Test
public void emptyConfig() {
assertThat(wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(null)
.setAttributes(affinity).build())).isFalse();
verify(helper, never()).createSubchannel(any(CreateSubchannelArgs.class));
verify(helper).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any());
assertThat(fakeClock.getPendingTasks()).isEmpty();
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
verify(helper, times(3)).createSubchannel(
any(CreateSubchannelArgs.class));
verify(helper).updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture());
assertThat(pickerCaptor.getValue()).isInstanceOf(EmptyPicker.class);
assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1);
}
@Test
public void blackoutPeriod() {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
verify(helper, times(3)).createSubchannel(
any(CreateSubchannelArgs.class));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
verify(helper, times(2)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(1);
WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
weightedSubchannel1.onLoadReport(InternalCallMetricRecorder.createMetricReport(
0.1, 0.1, 1, new HashMap<>(), new HashMap<>()));
weightedSubchannel2.onLoadReport(InternalCallMetricRecorder.createMetricReport(
0.2, 0.1, 1, new HashMap<>(), new HashMap<>()));
assertThat(fakeClock.forwardTime(5, TimeUnit.SECONDS)).isEqualTo(1);
Map<Subchannel, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel();
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
assertThat(pickCount.size()).isEqualTo(2);
// within blackout period, fallback to simple round robin
assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 0.5)).isAtMost(0.001);
assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 0.5)).isAtMost(0.001);
assertThat(fakeClock.forwardTime(5, TimeUnit.SECONDS)).isEqualTo(1);
pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel();
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
assertThat(pickCount.size()).isEqualTo(2);
// after blackout period
assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 2.0 / 3))
.isAtMost(0.001);
assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 3))
.isAtMost(0.001);
}
@Test
public void updateWeightTimer() {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
verify(helper, times(3)).createSubchannel(
any(CreateSubchannelArgs.class));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel connectingSubchannel = it.next();
subchannelStateListeners.get(connectingSubchannel).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.CONNECTING));
verify(helper, times(2)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2);
assertThat(pickerCaptor.getAllValues().get(0).getList().size()).isEqualTo(1);
WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(1);
assertThat(weightedPicker.getList().size()).isEqualTo(2);
WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
weightedSubchannel1.onLoadReport(InternalCallMetricRecorder.createMetricReport(
0.1, 0.1, 1, new HashMap<>(), new HashMap<>()));
weightedSubchannel2.onLoadReport(InternalCallMetricRecorder.createMetricReport(
0.2, 0.1, 1, new HashMap<>(), new HashMap<>()));
assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1);
assertThat(weightedPicker.pickSubchannel(mockArgs)
.getSubchannel()).isEqualTo(weightedSubchannel1);
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder()
.setWeightUpdatePeriodNanos(500_000_000L) //.5s
.build();
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
weightedSubchannel1.onLoadReport(InternalCallMetricRecorder.createMetricReport(
0.2, 0.1, 1, new HashMap<>(), new HashMap<>()));
weightedSubchannel2.onLoadReport(InternalCallMetricRecorder.createMetricReport(
0.1, 0.1, 1, new HashMap<>(), new HashMap<>()));
//timer fires, new weight updated
assertThat(fakeClock.forwardTime(500, TimeUnit.MILLISECONDS)).isEqualTo(1);
assertThat(weightedPicker.pickSubchannel(mockArgs)
.getSubchannel()).isEqualTo(weightedSubchannel2);
}
@Test
public void weightExpired() {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
verify(helper, times(3)).createSubchannel(
any(CreateSubchannelArgs.class));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
verify(helper, times(2)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(1);
WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
weightedSubchannel1.onLoadReport(InternalCallMetricRecorder.createMetricReport(
0.1, 0.1, 1, new HashMap<>(), new HashMap<>()));
weightedSubchannel2.onLoadReport(InternalCallMetricRecorder.createMetricReport(
0.2, 0.1, 1, new HashMap<>(), new HashMap<>()));
assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1);
Map<Subchannel, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel();
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
assertThat(pickCount.size()).isEqualTo(2);
assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 2.0 / 3))
.isAtMost(0.001);
assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 3))
.isAtMost(0.001);
// weight expired, fallback to simple round robin
assertThat(fakeClock.forwardTime(300, TimeUnit.SECONDS)).isEqualTo(1);
pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel();
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
assertThat(pickCount.size()).isEqualTo(2);
assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 0.5))
.isAtMost(0.001);
assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 0.5))
.isAtMost(0.001);
}
@Test
public void unknownWeightIsAvgWeight() {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
verify(helper, times(3)).createSubchannel(
any(CreateSubchannelArgs.class));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel3 = it.next();
subchannelStateListeners.get(readySubchannel3).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
verify(helper, times(3)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(2);
WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
WrrSubchannel weightedSubchannel3 = (WrrSubchannel) weightedPicker.getList().get(2);
weightedSubchannel1.onLoadReport(InternalCallMetricRecorder.createMetricReport(
0.1, 0.1, 1, new HashMap<>(), new HashMap<>()));
weightedSubchannel2.onLoadReport(InternalCallMetricRecorder.createMetricReport(
0.2, 0.1, 1, new HashMap<>(), new HashMap<>()));
assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1);
Map<Subchannel, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel();
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
assertThat(pickCount.size()).isEqualTo(3);
assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 4.0 / 9))
.isAtMost(0.001);
assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 2.0 / 9))
.isAtMost(0.001);
// subchannel3's weight is average of subchannel1 and subchannel2
assertThat(Math.abs(pickCount.get(weightedSubchannel3) / 1000.0 - 3.0 / 9))
.isAtMost(0.001);
}
@Test
public void pickFromOtherThread() throws Exception {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
verify(helper, times(3)).createSubchannel(
any(CreateSubchannelArgs.class));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
verify(helper, times(2)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(1);
WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
weightedSubchannel1.onLoadReport(InternalCallMetricRecorder.createMetricReport(
0.1, 0.1, 1, new HashMap<>(), new HashMap<>()));
weightedSubchannel2.onLoadReport(InternalCallMetricRecorder.createMetricReport(
0.2, 0.1, 1, new HashMap<>(), new HashMap<>()));
assertThat(weightedPicker.toString()).contains("rrMode=true");
CyclicBarrier barrier = new CyclicBarrier(2);
Map<Subchannel, AtomicInteger> pickCount = new ConcurrentHashMap<>();
pickCount.put(weightedSubchannel1, new AtomicInteger(0));
pickCount.put(weightedSubchannel2, new AtomicInteger(0));
new Thread(new Runnable() {
@Override
public void run() {
try {
weightedPicker.pickSubchannel(mockArgs);
barrier.await();
for (int i = 0; i < 1000; i++) {
Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel();
pickCount.get(result).addAndGet(1);
}
barrier.await();
} catch (Exception ex) {
throw new AssertionError(ex);
}
}
}).start();
assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1);
barrier.await();
for (int i = 0; i < 1000; i++) {
Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel();
pickCount.get(result).addAndGet(1);
}
barrier.await();
assertThat(pickCount.size()).isEqualTo(2);
// after blackout period
assertThat(Math.abs(pickCount.get(weightedSubchannel1).get() / 2000.0 - 2.0 / 3))
.isAtMost(0.001);
assertThat(Math.abs(pickCount.get(weightedSubchannel2).get() / 2000.0 - 1.0 / 3))
.isAtMost(0.001);
}
@Test
public void edfScheduler() {
Random random = new Random();
double totalWeight = 0;
int capacity = random.nextInt(10) + 1;
double[] weights = new double[capacity];
EdfScheduler scheduler = new EdfScheduler(capacity);
for (int i = 0; i < capacity; i++) {
weights[i] = random.nextDouble();
scheduler.add(i, weights[i]);
totalWeight += weights[i];
}
Map<Integer, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
int result = scheduler.pick();
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
for (int i = 0; i < capacity; i++) {
assertThat(Math.abs(pickCount.get(i) / 1000.0 - weights[i] / totalWeight) ).isAtMost(0.01);
}
}
@Test
public void edsScheduler_sameWeight() {
EdfScheduler scheduler = new EdfScheduler(2);
scheduler.add(0, 0.5);
scheduler.add(1, 0.5);
assertThat(scheduler.pick()).isEqualTo(0);
}
@Test(expected = NullPointerException.class)
public void wrrConfig_TimeValueNonNull() {
WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos((Long) null);
}
@Test(expected = NullPointerException.class)
public void wrrConfig_BooleanValueNonNull() {
WeightedRoundRobinLoadBalancerConfig.newBuilder().setEnableOobLoadReport((Boolean) null);
}
private static class FakeSocketAddress extends SocketAddress {
final String name;
FakeSocketAddress(String name) {
this.name = name;
}
@Override public String toString() {
return "FakeSocketAddress-" + name;
}
}
}

View File

@ -27,6 +27,7 @@ import com.google.common.collect.Iterables;
import com.google.protobuf.Any; import com.google.protobuf.Any;
import com.google.protobuf.BoolValue; import com.google.protobuf.BoolValue;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import com.google.protobuf.Duration;
import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message; import com.google.protobuf.Message;
import com.google.protobuf.StringValue; import com.google.protobuf.StringValue;
@ -39,6 +40,7 @@ import io.envoyproxy.envoy.config.cluster.v3.Cluster;
import io.envoyproxy.envoy.config.cluster.v3.Cluster.DiscoveryType; import io.envoyproxy.envoy.config.cluster.v3.Cluster.DiscoveryType;
import io.envoyproxy.envoy.config.cluster.v3.Cluster.EdsClusterConfig; import io.envoyproxy.envoy.config.cluster.v3.Cluster.EdsClusterConfig;
import io.envoyproxy.envoy.config.cluster.v3.Cluster.LbPolicy; import io.envoyproxy.envoy.config.cluster.v3.Cluster.LbPolicy;
import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy;
import io.envoyproxy.envoy.config.core.v3.Address; import io.envoyproxy.envoy.config.core.v3.Address;
import io.envoyproxy.envoy.config.core.v3.AggregatedConfigSource; import io.envoyproxy.envoy.config.core.v3.AggregatedConfigSource;
import io.envoyproxy.envoy.config.core.v3.CidrRange; import io.envoyproxy.envoy.config.core.v3.CidrRange;
@ -85,6 +87,8 @@ import io.envoyproxy.envoy.extensions.filters.http.router.v3.Router;
import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager;
import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter;
import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.Rds; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.Rds;
import io.envoyproxy.envoy.extensions.load_balancing_policies.client_side_weighted_round_robin.v3.ClientSideWeightedRoundRobin;
import io.envoyproxy.envoy.extensions.load_balancing_policies.wrr_locality.v3.WrrLocality;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
@ -128,6 +132,7 @@ import io.grpc.xds.VirtualHost.Route.RouteAction.ClusterWeight;
import io.grpc.xds.VirtualHost.Route.RouteAction.HashPolicy; import io.grpc.xds.VirtualHost.Route.RouteAction.HashPolicy;
import io.grpc.xds.VirtualHost.Route.RouteMatch; import io.grpc.xds.VirtualHost.Route.RouteMatch;
import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher; import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig;
import io.grpc.xds.XdsClientImpl.ResourceInvalidException; import io.grpc.xds.XdsClientImpl.ResourceInvalidException;
import io.grpc.xds.XdsClusterResource.CdsUpdate; import io.grpc.xds.XdsClusterResource.CdsUpdate;
import io.grpc.xds.XdsResourceType.StructOrError; import io.grpc.xds.XdsResourceType.StructOrError;
@ -163,6 +168,7 @@ public class XdsClientImplDataTest {
private boolean originalEnableRbac; private boolean originalEnableRbac;
private boolean originalEnableRouteLookup; private boolean originalEnableRouteLookup;
private boolean originalEnableLeastRequest; private boolean originalEnableLeastRequest;
private boolean originalEnableWrr;
@Before @Before
public void setUp() { public void setUp() {
@ -174,6 +180,8 @@ public class XdsClientImplDataTest {
assertThat(originalEnableRouteLookup).isFalse(); assertThat(originalEnableRouteLookup).isFalse();
originalEnableLeastRequest = XdsResourceType.enableLeastRequest; originalEnableLeastRequest = XdsResourceType.enableLeastRequest;
assertThat(originalEnableLeastRequest).isFalse(); assertThat(originalEnableLeastRequest).isFalse();
originalEnableWrr = XdsResourceType.enableWrr;
assertThat(originalEnableWrr).isFalse();
} }
@After @After
@ -182,6 +190,7 @@ public class XdsClientImplDataTest {
XdsResourceType.enableRbac = originalEnableRbac; XdsResourceType.enableRbac = originalEnableRbac;
XdsResourceType.enableRouteLookup = originalEnableRouteLookup; XdsResourceType.enableRouteLookup = originalEnableRouteLookup;
XdsResourceType.enableLeastRequest = originalEnableLeastRequest; XdsResourceType.enableLeastRequest = originalEnableLeastRequest;
XdsResourceType.enableWrr = originalEnableWrr;
} }
@Test @Test
@ -1966,6 +1975,65 @@ public class XdsClientImplDataTest {
assertThat(childConfigs.get(0).getPolicyName()).isEqualTo("least_request_experimental"); assertThat(childConfigs.get(0).getPolicyName()).isEqualTo("least_request_experimental");
} }
@Test
public void parseCluster_WrrLbPolicy_defaultLbConfig() throws ResourceInvalidException {
XdsResourceType.enableWrr = true;
LoadBalancingPolicy wrrConfig =
LoadBalancingPolicy.newBuilder().addPolicies(
LoadBalancingPolicy.Policy.newBuilder()
.setTypedExtensionConfig(TypedExtensionConfig.newBuilder()
.setName("backend")
.setTypedConfig(
Any.pack(ClientSideWeightedRoundRobin.newBuilder()
.setBlackoutPeriod(Duration.newBuilder().setSeconds(17).build())
.setEnableOobLoadReport(
BoolValue.newBuilder().setValue(true).build())
.build()))
.build())
.build())
.build();
Cluster cluster = Cluster.newBuilder()
.setName("cluster-foo.googleapis.com")
.setType(DiscoveryType.EDS)
.setEdsClusterConfig(
EdsClusterConfig.newBuilder()
.setEdsConfig(
ConfigSource.newBuilder()
.setAds(AggregatedConfigSource.getDefaultInstance()))
.setServiceName("service-foo.googleapis.com"))
.setLoadBalancingPolicy(
LoadBalancingPolicy.newBuilder().addPolicies(
LoadBalancingPolicy.Policy.newBuilder()
.setTypedExtensionConfig(
TypedExtensionConfig.newBuilder()
.setTypedConfig(
Any.pack(WrrLocality.newBuilder()
.setEndpointPickingPolicy(wrrConfig)
.build()))
.build())
.build())
.build())
.build();
CdsUpdate update = XdsClusterResource.processCluster(
cluster, null, LRS_SERVER_INFO,
LoadBalancerRegistry.getDefaultRegistry());
LbConfig lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(update.lbPolicyConfig());
assertThat(lbConfig.getPolicyName()).isEqualTo("wrr_locality_experimental");
List<LbConfig> childConfigs = ServiceConfigUtil.unwrapLoadBalancingConfigList(
JsonUtil.getListOfObjects(lbConfig.getRawConfigValue(), "childPolicy"));
assertThat(childConfigs.get(0).getPolicyName()).isEqualTo("weighted_round_robin_experimental");
WeightedRoundRobinLoadBalancerConfig result = (WeightedRoundRobinLoadBalancerConfig)
new WeightedRoundRobinLoadBalancerProvider().parseLoadBalancingPolicyConfig(
childConfigs.get(0).getRawConfigValue()).getConfig();
assertThat(result.blackoutPeriodNanos).isEqualTo(17_000_000_000L);
assertThat(result.enableOobLoadReport).isTrue();
assertThat(result.oobReportingPeriodNanos).isEqualTo(10_000_000_000L);
assertThat(result.weightUpdatePeriodNanos).isEqualTo(1_000_000_000L);
assertThat(result.weightExpirationPeriodNanos).isEqualTo(180_000_000_000L);
}
@Test @Test
public void parseCluster_transportSocketMatches_exception() throws ResourceInvalidException { public void parseCluster_transportSocketMatches_exception() throws ResourceInvalidException {
Cluster cluster = Cluster.newBuilder() Cluster cluster = Cluster.newBuilder()