diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java index 8a0d97c57b..aaee161995 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java @@ -46,6 +46,8 @@ import java.util.PriorityQueue; import java.util.Random; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; /** * A {@link LoadBalancer} that provides weighted-round-robin load-balancing over @@ -54,6 +56,8 @@ import java.util.concurrent.TimeUnit; */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/9885") final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { + private static final Logger log = Logger.getLogger( + WeightedRoundRobinLoadBalancer.class.getName()); private volatile WeightedRoundRobinLoadBalancerConfig config; private final SynchronizationContext syncContext; private final ScheduledExecutorService timeService; @@ -76,6 +80,7 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService"); this.updateWeightTask = new UpdateWeightTask(); this.random = random; + log.log(Level.FINE, "weighted_round_robin LB created"); } @VisibleForTesting @@ -230,7 +235,6 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { final class WeightedRoundRobinPicker extends ReadyPicker { private final List list; private volatile EdfScheduler scheduler; - private volatile boolean rrMode; WeightedRoundRobinPicker(List list) { super(checkNotNull(list, "list"), random.nextInt(list.size())); @@ -241,16 +245,11 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { @Override public PickResult pickSubchannel(PickSubchannelArgs args) { - if (rrMode) { - return super.pickSubchannel(args); - } - int pickIndex = scheduler.pick(); - WrrSubchannel subchannel = (WrrSubchannel) list.get(pickIndex); + Subchannel subchannel = list.get(scheduler.pick()); if (!config.enableOobLoadReport) { - return PickResult.withSubchannel( - subchannel, - OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( - subchannel.perRpcListener)); + return PickResult.withSubchannel(subchannel, + OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( + ((WrrSubchannel)subchannel).perRpcListener)); } else { return PickResult.withSubchannel(subchannel); } @@ -266,25 +265,24 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { weightedChannelCount++; } } - if (weightedChannelCount < 2) { - rrMode = true; - return; - } EdfScheduler scheduler = new EdfScheduler(list.size(), random); - avgWeight /= 1.0 * weightedChannelCount; + if (weightedChannelCount >= 1) { + avgWeight /= 1.0 * weightedChannelCount; + } else { + avgWeight = 1; + } for (int i = 0; i < list.size(); i++) { WrrSubchannel subchannel = (WrrSubchannel) list.get(i); double newWeight = subchannel.getWeight(); scheduler.add(i, newWeight > 0 ? newWeight : avgWeight); } this.scheduler = scheduler; - rrMode = false; } @Override public String toString() { return MoreObjects.toStringHelper(WeightedRoundRobinPicker.class) - .add("list", list).add("rrMode", rrMode).toString(); + .add("list", list).toString(); } @VisibleForTesting diff --git a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java index 8ab45ef850..16b89cc4bf 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java @@ -29,6 +29,7 @@ import static org.mockito.Mockito.when; import com.github.xds.data.orca.v3.OrcaLoadReport; import com.github.xds.service.orca.v3.OrcaLoadReportRequest; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.protobuf.Duration; @@ -514,6 +515,62 @@ public class WeightedRoundRobinLoadBalancerTest { .isAtMost(0.001); } + @Test + public void rrFallback() { + syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) + .setAttributes(affinity).build())); + verify(helper, times(3)).createSubchannel( + any(CreateSubchannelArgs.class)); + assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel1 = it.next(); + subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); + Subchannel readySubchannel2 = it.next(); + subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); + verify(helper, times(2)).updateBalancingState( + eq(ConnectivityState.READY), pickerCaptor.capture()); + WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(1); + assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1); + WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); + WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); + Map qpsByChannel = ImmutableMap.of(weightedSubchannel1, 2, + weightedSubchannel2, 1); + Map pickCount = new HashMap<>(); + for (int i = 0; i < 1000; i++) { + PickResult pickResult = weightedPicker.pickSubchannel(mockArgs); + pickCount.put(pickResult.getSubchannel(), + pickCount.getOrDefault(pickResult.getSubchannel(), 0) + 1); + assertThat(pickResult.getStreamTracerFactory()).isNotNull(); + WrrSubchannel subchannel = (WrrSubchannel)pickResult.getSubchannel(); + subchannel.onLoadReport(InternalCallMetricRecorder.createMetricReport( + 0.1, 0.1, qpsByChannel.get(subchannel), new HashMap<>(), new HashMap<>())); + } + assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 1.0 / 2)) + .isAtMost(0.1); + assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 2)) + .isAtMost(0.1); + pickCount.clear(); + for (int i = 0; i < 1000; i++) { + PickResult pickResult = weightedPicker.pickSubchannel(mockArgs); + pickCount.put(pickResult.getSubchannel(), + pickCount.getOrDefault(pickResult.getSubchannel(), 0) + 1); + assertThat(pickResult.getStreamTracerFactory()).isNotNull(); + WrrSubchannel subchannel = (WrrSubchannel)pickResult.getSubchannel(); + subchannel.onLoadReport(InternalCallMetricRecorder.createMetricReport( + 0.1, 0.1, qpsByChannel.get(subchannel), new HashMap<>(), new HashMap<>())); + fakeClock.forwardTime(50, TimeUnit.MILLISECONDS); + } + assertThat(pickCount.size()).isEqualTo(2); + assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 2.0 / 3)) + .isAtMost(0.1); + assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 3)) + .isAtMost(0.1); + } + @Test public void unknownWeightIsAvgWeight() { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() @@ -584,7 +641,6 @@ public class WeightedRoundRobinLoadBalancerTest { 0.1, 0.1, 1, new HashMap<>(), new HashMap<>())); weightedSubchannel2.onLoadReport(InternalCallMetricRecorder.createMetricReport( 0.2, 0.1, 1, new HashMap<>(), new HashMap<>())); - assertThat(weightedPicker.toString()).contains("rrMode=true"); CyclicBarrier barrier = new CyclicBarrier(2); Map pickCount = new ConcurrentHashMap<>(); pickCount.put(weightedSubchannel1, new AtomicInteger(0));