From 0b53dd7304365e20e4523ab0cc3cd113cfae8133 Mon Sep 17 00:00:00 2001 From: Tony An <40644135+tonyjongyoonan@users.noreply.github.com> Date: Thu, 6 Jul 2023 10:03:08 -0700 Subject: [PATCH] implemented and tested static stride scheduler for weighted round robin load balancing policy (#10272) --- .../xds/WeightedRoundRobinLoadBalancer.java | 220 ++++++------ .../WeightedRoundRobinLoadBalancerTest.java | 317 +++++++++++++++--- 2 files changed, 388 insertions(+), 149 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java index 48442a84b2..d5d8c4d9e4 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java @@ -44,10 +44,10 @@ import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.PriorityQueue; import java.util.Random; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.logging.Level; import java.util.logging.Logger; @@ -120,7 +120,7 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { @Override public void run() { if (currentPicker != null && currentPicker instanceof WeightedRoundRobinPicker) { - ((WeightedRoundRobinPicker)currentPicker).updateWeight(); + ((WeightedRoundRobinPicker) currentPicker).updateWeight(); } weightUpdateTimer = syncContext.schedule(this, config.weightUpdatePeriodNanos, TimeUnit.NANOSECONDS, timeService); @@ -258,7 +258,7 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { new HashMap<>(); private final boolean enableOobLoadReport; private final float errorUtilizationPenalty; - private volatile EdfScheduler scheduler; + private volatile StaticStrideScheduler scheduler; WeightedRoundRobinPicker(List list, boolean enableOobLoadReport, float errorUtilizationPenalty) { @@ -279,7 +279,7 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { Subchannel subchannel = list.get(scheduler.pick()); if (!enableOobLoadReport) { return PickResult.withSubchannel(subchannel, - OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( + OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( subchannelToReportListenerMap.getOrDefault(subchannel, ((WrrSubchannel) subchannel).new OrcaReportListener(errorUtilizationPenalty)))); } else { @@ -288,26 +288,14 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { } private void updateWeight() { - int weightedChannelCount = 0; - double avgWeight = 0; - for (Subchannel value : list) { - double newWeight = ((WrrSubchannel) value).getWeight(); - if (newWeight > 0) { - avgWeight += newWeight; - weightedChannelCount++; - } - } - EdfScheduler scheduler = new EdfScheduler(list.size(), random); - if (weightedChannelCount >= 1) { - avgWeight /= 1.0 * weightedChannelCount; - } else { - avgWeight = 1; - } + float[] newWeights = new float[list.size()]; for (int i = 0; i < list.size(); i++) { WrrSubchannel subchannel = (WrrSubchannel) list.get(i); double newWeight = subchannel.getWeight(); - scheduler.add(i, newWeight > 0 ? newWeight : avgWeight); + newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f; } + + StaticStrideScheduler scheduler = new StaticStrideScheduler(newWeights, random); this.scheduler = scheduler; } @@ -340,111 +328,125 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { } } - /** - * The earliest deadline first implementation in which each object is - * chosen deterministically and periodically with frequency proportional to its weight. - * - *

Specifically, each object added to chooser is given a deadline equal to the multiplicative - * inverse of its weight. The place of each object in its deadline is tracked, and each call to - * choose returns the object with the least remaining time in its deadline. - * (Ties are broken by the order in which the children were added to the chooser.) The deadline - * advances by the multiplicative inverse of the object's weight. - * For example, if items A and B are added with weights 0.5 and 0.2, successive chooses return: + /* + * The Static Stride Scheduler is an implementation of an earliest deadline first (EDF) scheduler + * in which each object's deadline is the multiplicative inverse of the object's weight. + *

+ * The way in which this is implemented is through a static stride scheduler. + * The Static Stride Scheduler works by iterating through the list of subchannel weights + * and using modular arithmetic to proportionally distribute picks, favoring entries + * with higher weights. It is based on the observation that the intended sequence generated + * from an EDF scheduler is a periodic one that can be achieved through modular arithmetic. + * The Static Stride Scheduler is more performant than other implementations of the EDF + * Scheduler, as it removes the need for a priority queue (and thus mutex locks). + *

+ * go/static-stride-scheduler + *

* *

- * - *

In short: the entry with the highest weight is preferred. - * - *

- * + *
  • nextSequence() - O(1) + *
  • pick() - O(n) */ @VisibleForTesting - static final class EdfScheduler { - private final PriorityQueue prioQueue; + static final class StaticStrideScheduler { + private final short[] scaledWeights; + private final int sizeDivisor; + private final AtomicInteger sequence; + private static final int K_MAX_WEIGHT = 0xFFFF; - /** - * Weights below this value will be upped to this minimum weight. - */ - private static final double MINIMUM_WEIGHT = 0.0001; - - private final Object lock = new Object(); - - private final Random random; - - /** - * Use the item's deadline as the order in the priority queue. If the deadlines are the same, - * use the index. Index should be unique. - */ - EdfScheduler(int initialCapacity, Random random) { - this.prioQueue = new PriorityQueue(initialCapacity, (o1, o2) -> { - if (o1.deadline == o2.deadline) { - return Integer.compare(o1.index, o2.index); - } else { - return Double.compare(o1.deadline, o2.deadline); + StaticStrideScheduler(float[] weights, Random random) { + checkArgument(weights.length >= 1, "Couldn't build scheduler: requires at least one weight"); + int numChannels = weights.length; + int numWeightedChannels = 0; + double sumWeight = 0; + float maxWeight = 0; + short meanWeight = 0; + for (float weight : weights) { + if (weight > 0) { + sumWeight += weight; + maxWeight = Math.max(weight, maxWeight); + numWeightedChannels++; } - }); - this.random = random; + } + + double scalingFactor = K_MAX_WEIGHT / maxWeight; + if (numWeightedChannels > 0) { + meanWeight = (short) Math.round(scalingFactor * sumWeight / numWeightedChannels); + } else { + meanWeight = 1; + } + + // scales weights s.t. max(weights) == K_MAX_WEIGHT, meanWeight is scaled accordingly + short[] scaledWeights = new short[numChannels]; + for (int i = 0; i < numChannels; i++) { + if (weights[i] <= 0) { + scaledWeights[i] = meanWeight; + } else { + scaledWeights[i] = (short) Math.round(weights[i] * scalingFactor); + } + } + + this.scaledWeights = scaledWeights; + this.sizeDivisor = numChannels; + this.sequence = new AtomicInteger(random.nextInt()); + } - /** - * Adds the item in the scheduler. This is not thread safe. - * - * @param index The field {@link ObjectState#index} to be added - * @param weight positive weight for the added object - */ - void add(int index, double weight) { - checkArgument(weight > 0.0, "Weights need to be positive."); - ObjectState state = new ObjectState(Math.max(weight, MINIMUM_WEIGHT), index); - // Randomize the initial deadline. - state.deadline = random.nextDouble() * (1 / state.weight); - prioQueue.add(state); + /** Returns the next sequence number and atomically increases sequence with wraparound. */ + private long nextSequence() { + return Integer.toUnsignedLong(sequence.getAndIncrement()); } - /** - * Picks the next WRR object. + @VisibleForTesting + long getSequence() { + return Integer.toUnsignedLong(sequence.get()); + } + + /* + * Selects index of next backend server. + *

    + * A 2D array is compactly represented as a function of W(backend), where the row + * represents the generation and the column represents the backend index: + * X(backend,generation) | generation ∈ [0,kMaxWeight). + * Each element in the conceptual array is a boolean indicating whether the backend at + * this index should be picked now. If false, the counter is incremented again, + * and the new element is checked. An atomically incremented counter keeps track of our + * backend and generation through modular arithmetic within the pick() method. + *

    + * Modular arithmetic allows us to evenly distribute picks and skips between + * generations based on W(backend). + * X(backend,generation) = (W(backend) * generation) % kMaxWeight >= kMaxWeight - W(backend) + * If we have the same three backends with weights: + * W(backend) = {2,3,6} scaled to max(W(backend)) = 6, then X(backend,generation) is: + *

    + * B0 B1 B2 + * T T T + * F F T + * F T T + * T F T + * F T T + * F F T + * The sequence of picked backend indices is given by + * walking across and down: {0,1,2,2,1,2,0,2,1,2,2}. + *

    + * To reduce the variance and spread the wasted work among different picks, + * an offset that varies per backend index is also included to the calculation. */ int pick() { - synchronized (lock) { - ObjectState minObject = prioQueue.remove(); - minObject.deadline += 1.0 / minObject.weight; - prioQueue.add(minObject); - return minObject.index; + while (true) { + long sequence = this.nextSequence(); + int backendIndex = (int) (sequence % this.sizeDivisor); + long generation = sequence / this.sizeDivisor; + int weight = Short.toUnsignedInt(this.scaledWeights[backendIndex]); + long offset = (long) K_MAX_WEIGHT / 2 * backendIndex; + if ((weight * generation + offset) % K_MAX_WEIGHT < K_MAX_WEIGHT - weight) { + continue; + } + return backendIndex; } } } - /** Holds the state of the object. */ - @VisibleForTesting - static class ObjectState { - private final double weight; - private final int index; - private volatile double deadline; - - ObjectState(double weight, int index) { - this.weight = weight; - this.index = index; - } - } - static final class WeightedRoundRobinLoadBalancerConfig { final long blackoutPeriodNanos; final long weightExpirationPeriodNanos; diff --git a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java index daf58a174d..58a19af96a 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java @@ -52,7 +52,7 @@ import io.grpc.SynchronizationContext; import io.grpc.internal.FakeClock; import io.grpc.services.InternalCallMetricRecorder; import io.grpc.services.MetricReport; -import io.grpc.xds.WeightedRoundRobinLoadBalancer.EdfScheduler; +import io.grpc.xds.WeightedRoundRobinLoadBalancer.StaticStrideScheduler; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinPicker; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WrrSubchannel; @@ -175,7 +175,7 @@ public class WeightedRoundRobinLoadBalancerTest { } }); wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker(), - new FakeRandom()); + new FakeRandom(0)); } @Test @@ -220,7 +220,7 @@ public class WeightedRoundRobinLoadBalancerTest { 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>())); assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); assertThat(weightedPicker.pickSubchannel(mockArgs) - .getSubchannel()).isEqualTo(weightedSubchannel1); + .getSubchannel()).isEqualTo(weightedSubchannel1); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder() .setWeightUpdatePeriodNanos(500_000_000L) //.5s @@ -338,7 +338,7 @@ public class WeightedRoundRobinLoadBalancerTest { } @Test - public void pickByWeight_LargeWeight() { + public void pickByWeight_largeWeight() { MetricReport report1 = InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 999, 0, new HashMap<>(), new HashMap<>()); MetricReport report2 = InternalCallMetricRecorder.createMetricReport( @@ -593,6 +593,7 @@ public class WeightedRoundRobinLoadBalancerTest { assertThat(fakeClock.forwardTime(500, TimeUnit.MILLISECONDS)).isEqualTo(1); assertThat(weightedPicker.pickSubchannel(mockArgs) .getSubchannel()).isEqualTo(weightedSubchannel2); + } @Test @@ -750,12 +751,12 @@ public class WeightedRoundRobinLoadBalancerTest { } assertThat(pickCount.size()).isEqualTo(3); assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 4.0 / 9)) - .isAtMost(0.001); + .isAtMost(0.002); assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 2.0 / 9)) - .isAtMost(0.001); + .isAtMost(0.002); // subchannel3's weight is average of subchannel1 and subchannel2 assertThat(Math.abs(pickCount.get(weightedSubchannel3) / 1000.0 - 3.0 / 9)) - .isAtMost(0.001); + .isAtMost(0.002); } @Test @@ -821,37 +822,6 @@ public class WeightedRoundRobinLoadBalancerTest { .isAtMost(0.001); } - @Test - public void edfScheduler() { - Random random = new Random(); - double totalWeight = 0; - int capacity = random.nextInt(10) + 1; - double[] weights = new double[capacity]; - EdfScheduler scheduler = new EdfScheduler(capacity, random); - for (int i = 0; i < capacity; i++) { - weights[i] = random.nextDouble(); - scheduler.add(i, weights[i]); - totalWeight += weights[i]; - } - Map pickCount = new HashMap<>(); - for (int i = 0; i < 1000; i++) { - int result = scheduler.pick(); - pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); - } - for (int i = 0; i < capacity; i++) { - assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight) ) - .isAtMost(0.01); - } - } - - @Test - public void edsScheduler_sameWeight() { - EdfScheduler scheduler = new EdfScheduler(2, new FakeRandom()); - scheduler.add(0, 0.5); - scheduler.add(1, 0.5); - assertThat(scheduler.pick()).isEqualTo(0); - } - @Test(expected = NullPointerException.class) public void wrrConfig_TimeValueNonNull() { WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos((Long) null); @@ -862,6 +832,267 @@ public class WeightedRoundRobinLoadBalancerTest { WeightedRoundRobinLoadBalancerConfig.newBuilder().setEnableOobLoadReport((Boolean) null); } + @Test(expected = IllegalArgumentException.class) + public void emptyWeights() { + float[] weights = {}; + Random random = new Random(); + StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); + sss.pick(); + } + + @Test + public void testPicksEqualsWeights() { + float[] weights = {1.0f, 2.0f, 3.0f}; + Random random = new Random(); + StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); + int[] expectedPicks = new int[] {1, 2, 3}; + int[] picks = new int[3]; + for (int i = 0; i < 6; i++) { + picks[sss.pick()] += 1; + } + assertThat(picks).isEqualTo(expectedPicks); + } + + @Test + public void testContainsZeroWeightUseMean() { + float[] weights = {3.0f, 0.0f, 1.0f}; + Random random = new Random(); + StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); + int[] expectedPicks = new int[] {3, 2, 1}; + int[] picks = new int[3]; + for (int i = 0; i < 6; i++) { + picks[sss.pick()] += 1; + } + assertThat(picks).isEqualTo(expectedPicks); + } + + @Test + public void testContainsNegativeWeightUseMean() { + float[] weights = {3.0f, -1.0f, 1.0f}; + Random random = new Random(); + StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); + int[] expectedPicks = new int[] {3, 2, 1}; + int[] picks = new int[3]; + for (int i = 0; i < 6; i++) { + picks[sss.pick()] += 1; + } + assertThat(picks).isEqualTo(expectedPicks); + } + + @Test + public void testAllSameWeights() { + float[] weights = {1.0f, 1.0f, 1.0f}; + Random random = new Random(); + StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); + int[] expectedPicks = new int[] {2, 2, 2}; + int[] picks = new int[3]; + for (int i = 0; i < 6; i++) { + picks[sss.pick()] += 1; + } + assertThat(picks).isEqualTo(expectedPicks); + } + + @Test + public void testAllZeroWeightsUseOne() { + float[] weights = {0.0f, 0.0f, 0.0f}; + Random random = new Random(); + StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); + int[] expectedPicks = new int[] {2, 2, 2}; + int[] picks = new int[3]; + for (int i = 0; i < 6; i++) { + picks[sss.pick()] += 1; + } + assertThat(picks).isEqualTo(expectedPicks); + } + + @Test + public void testAllInvalidWeightsUseOne() { + float[] weights = {-3.1f, -0.0f, 0.0f}; + Random random = new Random(); + StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); + int[] expectedPicks = new int[] {2, 2, 2}; + int[] picks = new int[3]; + for (int i = 0; i < 6; i++) { + picks[sss.pick()] += 1; + } + assertThat(picks).isEqualTo(expectedPicks); + } + + @Test + public void testLargestWeightIndexPickedEveryGeneration() { + float[] weights = {1.0f, 2.0f, 3.0f}; + int largestWeightIndex = 2; + Random random = new Random(); + StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); + int largestWeightPickCount = 0; + int kMaxWeight = 65535; + for (int i = 0; i < largestWeightIndex * kMaxWeight; i++) { + if (sss.pick() == largestWeightIndex) { + largestWeightPickCount += 1; + } + } + assertThat(largestWeightPickCount).isEqualTo(kMaxWeight); + } + + @Test + public void testStaticStrideSchedulerNonIntegers1() { + float[] weights = {2.0f, (float) (10.0 / 3.0), 1.0f}; + Random random = new Random(); + StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); + double totalWeight = 2 + 10.0 / 3.0 + 1.0; + Map pickCount = new HashMap<>(); + for (int i = 0; i < 1000; i++) { + int result = sss.pick(); + pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); + } + for (int i = 0; i < 3; i++) { + assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight)) + .isAtMost(0.01); + } + } + + @Test + public void testStaticStrideSchedulerNonIntegers2() { + float[] weights = {0.5f, 0.3f, 1.0f}; + Random random = new Random(); + StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); + double totalWeight = 1.8; + Map pickCount = new HashMap<>(); + for (int i = 0; i < 1000; i++) { + int result = sss.pick(); + pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); + } + for (int i = 0; i < 3; i++) { + assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight)) + .isAtMost(0.01); + } + } + + @Test + public void testTwoWeights() { + float[] weights = {1.0f, 2.0f}; + Random random = new Random(); + StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); + double totalWeight = 3; + Map pickCount = new HashMap<>(); + for (int i = 0; i < 1000; i++) { + int result = sss.pick(); + pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); + } + for (int i = 0; i < 2; i++) { + assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight)) + .isAtMost(0.01); + } + } + + @Test + public void testManyWeights() { + float[] weights = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; + Random random = new Random(); + StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); + double totalWeight = 15; + Map pickCount = new HashMap<>(); + for (int i = 0; i < 1000; i++) { + int result = sss.pick(); + pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); + } + for (int i = 0; i < 5; i++) { + assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight)) + .isAtMost(0.0011); + } + } + + @Test + public void testManyComplexWeights() { + float[] weights = {1.2f, 2.4f, 222.56f, 1.1f, 15.0f, 226342.0f, 5123.0f, 532.2f}; + Random random = new Random(); + StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); + double totalWeight = 1.2 + 2.4 + 222.56 + 15.0 + 226342.0 + 5123.0 + 0.0001; + Map pickCount = new HashMap<>(); + for (int i = 0; i < 1000; i++) { + int result = sss.pick(); + pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); + } + for (int i = 0; i < 8; i++) { + assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight)) + .isAtMost(0.01); + } + } + + @Test + public void testDeterministicPicks() { + float[] weights = {2.0f, 3.0f, 6.0f}; + Random random = new FakeRandom(0); + StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); + assertThat(sss.getSequence()).isEqualTo(0); + assertThat(sss.pick()).isEqualTo(1); + assertThat(sss.getSequence()).isEqualTo(2); + assertThat(sss.pick()).isEqualTo(2); + assertThat(sss.getSequence()).isEqualTo(3); + assertThat(sss.pick()).isEqualTo(2); + assertThat(sss.getSequence()).isEqualTo(6); + assertThat(sss.pick()).isEqualTo(0); + assertThat(sss.getSequence()).isEqualTo(7); + assertThat(sss.pick()).isEqualTo(1); + assertThat(sss.getSequence()).isEqualTo(8); + assertThat(sss.pick()).isEqualTo(2); + assertThat(sss.getSequence()).isEqualTo(9); + } + + @Test + public void testImmediateWraparound() { + float[] weights = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; + Random random = new FakeRandom(-1); + StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); + double totalWeight = 15; + Map pickCount = new HashMap<>(); + for (int i = 0; i < 1000; i++) { + int result = sss.pick(); + pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); + } + for (int i = 0; i < 5; i++) { + assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight)) + .isAtMost(0.001); + } + } + + @Test + public void testWraparound() { + float[] weights = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; + Random random = new FakeRandom(-500); + StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); + double totalWeight = 15; + Map pickCount = new HashMap<>(); + for (int i = 0; i < 1000; i++) { + int result = sss.pick(); + pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); + } + for (int i = 0; i < 5; i++) { + assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight)) + .isAtMost(0.0011); + } + } + + @Test + public void testDeterministicWraparound() { + float[] weights = {2.0f, 3.0f, 6.0f}; + Random random = new FakeRandom(-1); + StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); + assertThat(sss.getSequence()).isEqualTo(0xFFFF_FFFFL); + assertThat(sss.pick()).isEqualTo(1); + assertThat(sss.getSequence()).isEqualTo(2); + assertThat(sss.pick()).isEqualTo(2); + assertThat(sss.getSequence()).isEqualTo(3); + assertThat(sss.pick()).isEqualTo(2); + assertThat(sss.getSequence()).isEqualTo(6); + assertThat(sss.pick()).isEqualTo(0); + assertThat(sss.getSequence()).isEqualTo(7); + assertThat(sss.pick()).isEqualTo(1); + assertThat(sss.getSequence()).isEqualTo(8); + assertThat(sss.pick()).isEqualTo(2); + assertThat(sss.getSequence()).isEqualTo(9); + } + private static class FakeSocketAddress extends SocketAddress { final String name; @@ -875,10 +1106,16 @@ public class WeightedRoundRobinLoadBalancerTest { } private static class FakeRandom extends Random { + private int nextInt; + + public FakeRandom(int nextInt) { + this.nextInt = nextInt; + } + @Override - public double nextDouble() { + public int nextInt() { // return constant value to disable init deadline randomization in the scheduler - return 0.322023; + return nextInt; } } }