implemented and tested static stride scheduler for weighted round robin load balancing policy (#10272)

This commit is contained in:
Tony An 2023-07-06 10:03:08 -07:00 committed by GitHub
parent 361616ae7c
commit 0b53dd7304
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 388 additions and 149 deletions

View File

@ -44,10 +44,10 @@ import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.PriorityQueue;
import java.util.Random; import java.util.Random;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
@ -120,7 +120,7 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
@Override @Override
public void run() { public void run() {
if (currentPicker != null && currentPicker instanceof WeightedRoundRobinPicker) { if (currentPicker != null && currentPicker instanceof WeightedRoundRobinPicker) {
((WeightedRoundRobinPicker)currentPicker).updateWeight(); ((WeightedRoundRobinPicker) currentPicker).updateWeight();
} }
weightUpdateTimer = syncContext.schedule(this, config.weightUpdatePeriodNanos, weightUpdateTimer = syncContext.schedule(this, config.weightUpdatePeriodNanos,
TimeUnit.NANOSECONDS, timeService); TimeUnit.NANOSECONDS, timeService);
@ -258,7 +258,7 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
new HashMap<>(); new HashMap<>();
private final boolean enableOobLoadReport; private final boolean enableOobLoadReport;
private final float errorUtilizationPenalty; private final float errorUtilizationPenalty;
private volatile EdfScheduler scheduler; private volatile StaticStrideScheduler scheduler;
WeightedRoundRobinPicker(List<Subchannel> list, boolean enableOobLoadReport, WeightedRoundRobinPicker(List<Subchannel> list, boolean enableOobLoadReport,
float errorUtilizationPenalty) { float errorUtilizationPenalty) {
@ -279,7 +279,7 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
Subchannel subchannel = list.get(scheduler.pick()); Subchannel subchannel = list.get(scheduler.pick());
if (!enableOobLoadReport) { if (!enableOobLoadReport) {
return PickResult.withSubchannel(subchannel, return PickResult.withSubchannel(subchannel,
OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(
subchannelToReportListenerMap.getOrDefault(subchannel, subchannelToReportListenerMap.getOrDefault(subchannel,
((WrrSubchannel) subchannel).new OrcaReportListener(errorUtilizationPenalty)))); ((WrrSubchannel) subchannel).new OrcaReportListener(errorUtilizationPenalty))));
} else { } else {
@ -288,26 +288,14 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
} }
private void updateWeight() { private void updateWeight() {
int weightedChannelCount = 0; float[] newWeights = new float[list.size()];
double avgWeight = 0;
for (Subchannel value : list) {
double newWeight = ((WrrSubchannel) value).getWeight();
if (newWeight > 0) {
avgWeight += newWeight;
weightedChannelCount++;
}
}
EdfScheduler scheduler = new EdfScheduler(list.size(), random);
if (weightedChannelCount >= 1) {
avgWeight /= 1.0 * weightedChannelCount;
} else {
avgWeight = 1;
}
for (int i = 0; i < list.size(); i++) { for (int i = 0; i < list.size(); i++) {
WrrSubchannel subchannel = (WrrSubchannel) list.get(i); WrrSubchannel subchannel = (WrrSubchannel) list.get(i);
double newWeight = subchannel.getWeight(); double newWeight = subchannel.getWeight();
scheduler.add(i, newWeight > 0 ? newWeight : avgWeight); newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f;
} }
StaticStrideScheduler scheduler = new StaticStrideScheduler(newWeights, random);
this.scheduler = scheduler; this.scheduler = scheduler;
} }
@ -340,111 +328,125 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
} }
} }
/** /*
* The earliest deadline first implementation in which each object is * The Static Stride Scheduler is an implementation of an earliest deadline first (EDF) scheduler
* chosen deterministically and periodically with frequency proportional to its weight. * in which each object's deadline is the multiplicative inverse of the object's weight.
* * <p>
* <p>Specifically, each object added to chooser is given a deadline equal to the multiplicative * The way in which this is implemented is through a static stride scheduler.
* inverse of its weight. The place of each object in its deadline is tracked, and each call to * The Static Stride Scheduler works by iterating through the list of subchannel weights
* choose returns the object with the least remaining time in its deadline. * and using modular arithmetic to proportionally distribute picks, favoring entries
* (Ties are broken by the order in which the children were added to the chooser.) The deadline * with higher weights. It is based on the observation that the intended sequence generated
* advances by the multiplicative inverse of the object's weight. * from an EDF scheduler is a periodic one that can be achieved through modular arithmetic.
* For example, if items A and B are added with weights 0.5 and 0.2, successive chooses return: * The Static Stride Scheduler is more performant than other implementations of the EDF
* Scheduler, as it removes the need for a priority queue (and thus mutex locks).
* <p>
* go/static-stride-scheduler
* <p>
* *
* <ul> * <ul>
* <li>In the first call, the deadlines are A=2 (1/0.5) and B=5 (1/0.2), so A is returned. * <li>nextSequence() - O(1)
* The deadline of A is updated to 4. * <li>pick() - O(n)
* <li>Next, the remaining deadlines are A=4 and B=5, so A is returned. The deadline of A (2) is
* updated to A=6.
* <li>Remaining deadlines are A=6 and B=5, so B is returned. The deadline of B is updated with
* with B=10.
* <li>Remaining deadlines are A=6 and B=10, so A is returned. The deadline of A is updated with
* A=8.
* <li>Remaining deadlines are A=8 and B=10, so A is returned. The deadline of A is updated with
* A=10.
* <li>Remaining deadlines are A=10 and B=10, so A is returned. The deadline of A is updated
* with A=12.
* <li>Remaining deadlines are A=12 and B=10, so B is returned. The deadline of B is updated
* with B=15.
* <li>etc.
* </ul>
*
* <p>In short: the entry with the highest weight is preferred.
*
* <ul>
* <li>add() - O(lg n)
* <li>pick() - O(lg n)
* </ul>
*
*/ */
@VisibleForTesting @VisibleForTesting
static final class EdfScheduler { static final class StaticStrideScheduler {
private final PriorityQueue<ObjectState> prioQueue; private final short[] scaledWeights;
private final int sizeDivisor;
private final AtomicInteger sequence;
private static final int K_MAX_WEIGHT = 0xFFFF;
/** StaticStrideScheduler(float[] weights, Random random) {
* Weights below this value will be upped to this minimum weight. checkArgument(weights.length >= 1, "Couldn't build scheduler: requires at least one weight");
*/ int numChannels = weights.length;
private static final double MINIMUM_WEIGHT = 0.0001; int numWeightedChannels = 0;
double sumWeight = 0;
private final Object lock = new Object(); float maxWeight = 0;
short meanWeight = 0;
private final Random random; for (float weight : weights) {
if (weight > 0) {
/** sumWeight += weight;
* Use the item's deadline as the order in the priority queue. If the deadlines are the same, maxWeight = Math.max(weight, maxWeight);
* use the index. Index should be unique. numWeightedChannels++;
*/
EdfScheduler(int initialCapacity, Random random) {
this.prioQueue = new PriorityQueue<ObjectState>(initialCapacity, (o1, o2) -> {
if (o1.deadline == o2.deadline) {
return Integer.compare(o1.index, o2.index);
} else {
return Double.compare(o1.deadline, o2.deadline);
} }
}); }
this.random = random;
double scalingFactor = K_MAX_WEIGHT / maxWeight;
if (numWeightedChannels > 0) {
meanWeight = (short) Math.round(scalingFactor * sumWeight / numWeightedChannels);
} else {
meanWeight = 1;
}
// scales weights s.t. max(weights) == K_MAX_WEIGHT, meanWeight is scaled accordingly
short[] scaledWeights = new short[numChannels];
for (int i = 0; i < numChannels; i++) {
if (weights[i] <= 0) {
scaledWeights[i] = meanWeight;
} else {
scaledWeights[i] = (short) Math.round(weights[i] * scalingFactor);
}
}
this.scaledWeights = scaledWeights;
this.sizeDivisor = numChannels;
this.sequence = new AtomicInteger(random.nextInt());
} }
/** /** Returns the next sequence number and atomically increases sequence with wraparound. */
* Adds the item in the scheduler. This is not thread safe. private long nextSequence() {
* return Integer.toUnsignedLong(sequence.getAndIncrement());
* @param index The field {@link ObjectState#index} to be added
* @param weight positive weight for the added object
*/
void add(int index, double weight) {
checkArgument(weight > 0.0, "Weights need to be positive.");
ObjectState state = new ObjectState(Math.max(weight, MINIMUM_WEIGHT), index);
// Randomize the initial deadline.
state.deadline = random.nextDouble() * (1 / state.weight);
prioQueue.add(state);
} }
/** @VisibleForTesting
* Picks the next WRR object. long getSequence() {
return Integer.toUnsignedLong(sequence.get());
}
/*
* Selects index of next backend server.
* <p>
* A 2D array is compactly represented as a function of W(backend), where the row
* represents the generation and the column represents the backend index:
* X(backend,generation) | generation [0,kMaxWeight).
* Each element in the conceptual array is a boolean indicating whether the backend at
* this index should be picked now. If false, the counter is incremented again,
* and the new element is checked. An atomically incremented counter keeps track of our
* backend and generation through modular arithmetic within the pick() method.
* <p>
* Modular arithmetic allows us to evenly distribute picks and skips between
* generations based on W(backend).
* X(backend,generation) = (W(backend) * generation) % kMaxWeight >= kMaxWeight - W(backend)
* If we have the same three backends with weights:
* W(backend) = {2,3,6} scaled to max(W(backend)) = 6, then X(backend,generation) is:
* <p>
* B0 B1 B2
* T T T
* F F T
* F T T
* T F T
* F T T
* F F T
* The sequence of picked backend indices is given by
* walking across and down: {0,1,2,2,1,2,0,2,1,2,2}.
* <p>
* To reduce the variance and spread the wasted work among different picks,
* an offset that varies per backend index is also included to the calculation.
*/ */
int pick() { int pick() {
synchronized (lock) { while (true) {
ObjectState minObject = prioQueue.remove(); long sequence = this.nextSequence();
minObject.deadline += 1.0 / minObject.weight; int backendIndex = (int) (sequence % this.sizeDivisor);
prioQueue.add(minObject); long generation = sequence / this.sizeDivisor;
return minObject.index; int weight = Short.toUnsignedInt(this.scaledWeights[backendIndex]);
long offset = (long) K_MAX_WEIGHT / 2 * backendIndex;
if ((weight * generation + offset) % K_MAX_WEIGHT < K_MAX_WEIGHT - weight) {
continue;
}
return backendIndex;
} }
} }
} }
/** Holds the state of the object. */
@VisibleForTesting
static class ObjectState {
private final double weight;
private final int index;
private volatile double deadline;
ObjectState(double weight, int index) {
this.weight = weight;
this.index = index;
}
}
static final class WeightedRoundRobinLoadBalancerConfig { static final class WeightedRoundRobinLoadBalancerConfig {
final long blackoutPeriodNanos; final long blackoutPeriodNanos;
final long weightExpirationPeriodNanos; final long weightExpirationPeriodNanos;

View File

@ -52,7 +52,7 @@ import io.grpc.SynchronizationContext;
import io.grpc.internal.FakeClock; import io.grpc.internal.FakeClock;
import io.grpc.services.InternalCallMetricRecorder; import io.grpc.services.InternalCallMetricRecorder;
import io.grpc.services.MetricReport; import io.grpc.services.MetricReport;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.EdfScheduler; import io.grpc.xds.WeightedRoundRobinLoadBalancer.StaticStrideScheduler;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinPicker; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinPicker;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.WrrSubchannel; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WrrSubchannel;
@ -175,7 +175,7 @@ public class WeightedRoundRobinLoadBalancerTest {
} }
}); });
wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker(), wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker(),
new FakeRandom()); new FakeRandom(0));
} }
@Test @Test
@ -220,7 +220,7 @@ public class WeightedRoundRobinLoadBalancerTest {
0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>())); 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>()));
assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1);
assertThat(weightedPicker.pickSubchannel(mockArgs) assertThat(weightedPicker.pickSubchannel(mockArgs)
.getSubchannel()).isEqualTo(weightedSubchannel1); .getSubchannel()).isEqualTo(weightedSubchannel1);
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder() weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder()
.setWeightUpdatePeriodNanos(500_000_000L) //.5s .setWeightUpdatePeriodNanos(500_000_000L) //.5s
@ -338,7 +338,7 @@ public class WeightedRoundRobinLoadBalancerTest {
} }
@Test @Test
public void pickByWeight_LargeWeight() { public void pickByWeight_largeWeight() {
MetricReport report1 = InternalCallMetricRecorder.createMetricReport( MetricReport report1 = InternalCallMetricRecorder.createMetricReport(
0.1, 0, 0.1, 999, 0, new HashMap<>(), new HashMap<>()); 0.1, 0, 0.1, 999, 0, new HashMap<>(), new HashMap<>());
MetricReport report2 = InternalCallMetricRecorder.createMetricReport( MetricReport report2 = InternalCallMetricRecorder.createMetricReport(
@ -593,6 +593,7 @@ public class WeightedRoundRobinLoadBalancerTest {
assertThat(fakeClock.forwardTime(500, TimeUnit.MILLISECONDS)).isEqualTo(1); assertThat(fakeClock.forwardTime(500, TimeUnit.MILLISECONDS)).isEqualTo(1);
assertThat(weightedPicker.pickSubchannel(mockArgs) assertThat(weightedPicker.pickSubchannel(mockArgs)
.getSubchannel()).isEqualTo(weightedSubchannel2); .getSubchannel()).isEqualTo(weightedSubchannel2);
} }
@Test @Test
@ -750,12 +751,12 @@ public class WeightedRoundRobinLoadBalancerTest {
} }
assertThat(pickCount.size()).isEqualTo(3); assertThat(pickCount.size()).isEqualTo(3);
assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 4.0 / 9)) assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 4.0 / 9))
.isAtMost(0.001); .isAtMost(0.002);
assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 2.0 / 9)) assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 2.0 / 9))
.isAtMost(0.001); .isAtMost(0.002);
// subchannel3's weight is average of subchannel1 and subchannel2 // subchannel3's weight is average of subchannel1 and subchannel2
assertThat(Math.abs(pickCount.get(weightedSubchannel3) / 1000.0 - 3.0 / 9)) assertThat(Math.abs(pickCount.get(weightedSubchannel3) / 1000.0 - 3.0 / 9))
.isAtMost(0.001); .isAtMost(0.002);
} }
@Test @Test
@ -821,37 +822,6 @@ public class WeightedRoundRobinLoadBalancerTest {
.isAtMost(0.001); .isAtMost(0.001);
} }
@Test
public void edfScheduler() {
Random random = new Random();
double totalWeight = 0;
int capacity = random.nextInt(10) + 1;
double[] weights = new double[capacity];
EdfScheduler scheduler = new EdfScheduler(capacity, random);
for (int i = 0; i < capacity; i++) {
weights[i] = random.nextDouble();
scheduler.add(i, weights[i]);
totalWeight += weights[i];
}
Map<Integer, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
int result = scheduler.pick();
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
for (int i = 0; i < capacity; i++) {
assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight) )
.isAtMost(0.01);
}
}
@Test
public void edsScheduler_sameWeight() {
EdfScheduler scheduler = new EdfScheduler(2, new FakeRandom());
scheduler.add(0, 0.5);
scheduler.add(1, 0.5);
assertThat(scheduler.pick()).isEqualTo(0);
}
@Test(expected = NullPointerException.class) @Test(expected = NullPointerException.class)
public void wrrConfig_TimeValueNonNull() { public void wrrConfig_TimeValueNonNull() {
WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos((Long) null); WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos((Long) null);
@ -862,6 +832,267 @@ public class WeightedRoundRobinLoadBalancerTest {
WeightedRoundRobinLoadBalancerConfig.newBuilder().setEnableOobLoadReport((Boolean) null); WeightedRoundRobinLoadBalancerConfig.newBuilder().setEnableOobLoadReport((Boolean) null);
} }
@Test(expected = IllegalArgumentException.class)
public void emptyWeights() {
float[] weights = {};
Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
sss.pick();
}
@Test
public void testPicksEqualsWeights() {
float[] weights = {1.0f, 2.0f, 3.0f};
Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
int[] expectedPicks = new int[] {1, 2, 3};
int[] picks = new int[3];
for (int i = 0; i < 6; i++) {
picks[sss.pick()] += 1;
}
assertThat(picks).isEqualTo(expectedPicks);
}
@Test
public void testContainsZeroWeightUseMean() {
float[] weights = {3.0f, 0.0f, 1.0f};
Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
int[] expectedPicks = new int[] {3, 2, 1};
int[] picks = new int[3];
for (int i = 0; i < 6; i++) {
picks[sss.pick()] += 1;
}
assertThat(picks).isEqualTo(expectedPicks);
}
@Test
public void testContainsNegativeWeightUseMean() {
float[] weights = {3.0f, -1.0f, 1.0f};
Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
int[] expectedPicks = new int[] {3, 2, 1};
int[] picks = new int[3];
for (int i = 0; i < 6; i++) {
picks[sss.pick()] += 1;
}
assertThat(picks).isEqualTo(expectedPicks);
}
@Test
public void testAllSameWeights() {
float[] weights = {1.0f, 1.0f, 1.0f};
Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
int[] expectedPicks = new int[] {2, 2, 2};
int[] picks = new int[3];
for (int i = 0; i < 6; i++) {
picks[sss.pick()] += 1;
}
assertThat(picks).isEqualTo(expectedPicks);
}
@Test
public void testAllZeroWeightsUseOne() {
float[] weights = {0.0f, 0.0f, 0.0f};
Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
int[] expectedPicks = new int[] {2, 2, 2};
int[] picks = new int[3];
for (int i = 0; i < 6; i++) {
picks[sss.pick()] += 1;
}
assertThat(picks).isEqualTo(expectedPicks);
}
@Test
public void testAllInvalidWeightsUseOne() {
float[] weights = {-3.1f, -0.0f, 0.0f};
Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
int[] expectedPicks = new int[] {2, 2, 2};
int[] picks = new int[3];
for (int i = 0; i < 6; i++) {
picks[sss.pick()] += 1;
}
assertThat(picks).isEqualTo(expectedPicks);
}
@Test
public void testLargestWeightIndexPickedEveryGeneration() {
float[] weights = {1.0f, 2.0f, 3.0f};
int largestWeightIndex = 2;
Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
int largestWeightPickCount = 0;
int kMaxWeight = 65535;
for (int i = 0; i < largestWeightIndex * kMaxWeight; i++) {
if (sss.pick() == largestWeightIndex) {
largestWeightPickCount += 1;
}
}
assertThat(largestWeightPickCount).isEqualTo(kMaxWeight);
}
@Test
public void testStaticStrideSchedulerNonIntegers1() {
float[] weights = {2.0f, (float) (10.0 / 3.0), 1.0f};
Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
double totalWeight = 2 + 10.0 / 3.0 + 1.0;
Map<Integer, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
int result = sss.pick();
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
for (int i = 0; i < 3; i++) {
assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight))
.isAtMost(0.01);
}
}
@Test
public void testStaticStrideSchedulerNonIntegers2() {
float[] weights = {0.5f, 0.3f, 1.0f};
Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
double totalWeight = 1.8;
Map<Integer, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
int result = sss.pick();
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
for (int i = 0; i < 3; i++) {
assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight))
.isAtMost(0.01);
}
}
@Test
public void testTwoWeights() {
float[] weights = {1.0f, 2.0f};
Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
double totalWeight = 3;
Map<Integer, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
int result = sss.pick();
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
for (int i = 0; i < 2; i++) {
assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight))
.isAtMost(0.01);
}
}
@Test
public void testManyWeights() {
float[] weights = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
double totalWeight = 15;
Map<Integer, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
int result = sss.pick();
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
for (int i = 0; i < 5; i++) {
assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight))
.isAtMost(0.0011);
}
}
@Test
public void testManyComplexWeights() {
float[] weights = {1.2f, 2.4f, 222.56f, 1.1f, 15.0f, 226342.0f, 5123.0f, 532.2f};
Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
double totalWeight = 1.2 + 2.4 + 222.56 + 15.0 + 226342.0 + 5123.0 + 0.0001;
Map<Integer, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
int result = sss.pick();
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
for (int i = 0; i < 8; i++) {
assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight))
.isAtMost(0.01);
}
}
@Test
public void testDeterministicPicks() {
float[] weights = {2.0f, 3.0f, 6.0f};
Random random = new FakeRandom(0);
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
assertThat(sss.getSequence()).isEqualTo(0);
assertThat(sss.pick()).isEqualTo(1);
assertThat(sss.getSequence()).isEqualTo(2);
assertThat(sss.pick()).isEqualTo(2);
assertThat(sss.getSequence()).isEqualTo(3);
assertThat(sss.pick()).isEqualTo(2);
assertThat(sss.getSequence()).isEqualTo(6);
assertThat(sss.pick()).isEqualTo(0);
assertThat(sss.getSequence()).isEqualTo(7);
assertThat(sss.pick()).isEqualTo(1);
assertThat(sss.getSequence()).isEqualTo(8);
assertThat(sss.pick()).isEqualTo(2);
assertThat(sss.getSequence()).isEqualTo(9);
}
@Test
public void testImmediateWraparound() {
float[] weights = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
Random random = new FakeRandom(-1);
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
double totalWeight = 15;
Map<Integer, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
int result = sss.pick();
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
for (int i = 0; i < 5; i++) {
assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight))
.isAtMost(0.001);
}
}
@Test
public void testWraparound() {
float[] weights = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
Random random = new FakeRandom(-500);
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
double totalWeight = 15;
Map<Integer, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
int result = sss.pick();
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
for (int i = 0; i < 5; i++) {
assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight))
.isAtMost(0.0011);
}
}
@Test
public void testDeterministicWraparound() {
float[] weights = {2.0f, 3.0f, 6.0f};
Random random = new FakeRandom(-1);
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
assertThat(sss.getSequence()).isEqualTo(0xFFFF_FFFFL);
assertThat(sss.pick()).isEqualTo(1);
assertThat(sss.getSequence()).isEqualTo(2);
assertThat(sss.pick()).isEqualTo(2);
assertThat(sss.getSequence()).isEqualTo(3);
assertThat(sss.pick()).isEqualTo(2);
assertThat(sss.getSequence()).isEqualTo(6);
assertThat(sss.pick()).isEqualTo(0);
assertThat(sss.getSequence()).isEqualTo(7);
assertThat(sss.pick()).isEqualTo(1);
assertThat(sss.getSequence()).isEqualTo(8);
assertThat(sss.pick()).isEqualTo(2);
assertThat(sss.getSequence()).isEqualTo(9);
}
private static class FakeSocketAddress extends SocketAddress { private static class FakeSocketAddress extends SocketAddress {
final String name; final String name;
@ -875,10 +1106,16 @@ public class WeightedRoundRobinLoadBalancerTest {
} }
private static class FakeRandom extends Random { private static class FakeRandom extends Random {
private int nextInt;
public FakeRandom(int nextInt) {
this.nextInt = nextInt;
}
@Override @Override
public double nextDouble() { public int nextInt() {
// return constant value to disable init deadline randomization in the scheduler // return constant value to disable init deadline randomization in the scheduler
return 0.322023; return nextInt;
} }
} }
} }