implemented and tested static stride scheduler for weighted round robin load balancing policy (#10272)

This commit is contained in:
Tony An 2023-07-06 10:03:08 -07:00 committed by GitHub
parent 361616ae7c
commit 0b53dd7304
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 388 additions and 149 deletions

View File

@ -44,10 +44,10 @@ import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Random;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;
@ -258,7 +258,7 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
new HashMap<>();
private final boolean enableOobLoadReport;
private final float errorUtilizationPenalty;
private volatile EdfScheduler scheduler;
private volatile StaticStrideScheduler scheduler;
WeightedRoundRobinPicker(List<Subchannel> list, boolean enableOobLoadReport,
float errorUtilizationPenalty) {
@ -288,26 +288,14 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
}
private void updateWeight() {
int weightedChannelCount = 0;
double avgWeight = 0;
for (Subchannel value : list) {
double newWeight = ((WrrSubchannel) value).getWeight();
if (newWeight > 0) {
avgWeight += newWeight;
weightedChannelCount++;
}
}
EdfScheduler scheduler = new EdfScheduler(list.size(), random);
if (weightedChannelCount >= 1) {
avgWeight /= 1.0 * weightedChannelCount;
} else {
avgWeight = 1;
}
float[] newWeights = new float[list.size()];
for (int i = 0; i < list.size(); i++) {
WrrSubchannel subchannel = (WrrSubchannel) list.get(i);
double newWeight = subchannel.getWeight();
scheduler.add(i, newWeight > 0 ? newWeight : avgWeight);
newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f;
}
StaticStrideScheduler scheduler = new StaticStrideScheduler(newWeights, random);
this.scheduler = scheduler;
}
@ -340,109 +328,123 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
}
}
/**
* The earliest deadline first implementation in which each object is
* chosen deterministically and periodically with frequency proportional to its weight.
*
* <p>Specifically, each object added to chooser is given a deadline equal to the multiplicative
* inverse of its weight. The place of each object in its deadline is tracked, and each call to
* choose returns the object with the least remaining time in its deadline.
* (Ties are broken by the order in which the children were added to the chooser.) The deadline
* advances by the multiplicative inverse of the object's weight.
* For example, if items A and B are added with weights 0.5 and 0.2, successive chooses return:
/*
* The Static Stride Scheduler is an implementation of an earliest deadline first (EDF) scheduler
* in which each object's deadline is the multiplicative inverse of the object's weight.
* <p>
* The way in which this is implemented is through a static stride scheduler.
* The Static Stride Scheduler works by iterating through the list of subchannel weights
* and using modular arithmetic to proportionally distribute picks, favoring entries
* with higher weights. It is based on the observation that the intended sequence generated
* from an EDF scheduler is a periodic one that can be achieved through modular arithmetic.
* The Static Stride Scheduler is more performant than other implementations of the EDF
* Scheduler, as it removes the need for a priority queue (and thus mutex locks).
* <p>
* go/static-stride-scheduler
* <p>
*
* <ul>
* <li>In the first call, the deadlines are A=2 (1/0.5) and B=5 (1/0.2), so A is returned.
* The deadline of A is updated to 4.
* <li>Next, the remaining deadlines are A=4 and B=5, so A is returned. The deadline of A (2) is
* updated to A=6.
* <li>Remaining deadlines are A=6 and B=5, so B is returned. The deadline of B is updated
* with B=10.
* <li>Remaining deadlines are A=6 and B=10, so A is returned. The deadline of A is updated with
* A=8.
* <li>Remaining deadlines are A=8 and B=10, so A is returned. The deadline of A is updated with
* A=10.
* <li>Remaining deadlines are A=10 and B=10, so A is returned. The deadline of A is updated
* with A=12.
* <li>Remaining deadlines are A=12 and B=10, so B is returned. The deadline of B is updated
* with B=15.
* <li>etc.
* </ul>
*
* <p>In short: the entry with the highest weight is preferred.
*
* <ul>
* <li>add() - O(lg n)
* <li>pick() - O(lg n)
* </ul>
*
* <li>nextSequence() - O(1)
* <li>pick() - O(n)
*/
@VisibleForTesting
static final class EdfScheduler {
private final PriorityQueue<ObjectState> prioQueue;
static final class StaticStrideScheduler {
private final short[] scaledWeights;
private final int sizeDivisor;
private final AtomicInteger sequence;
private static final int K_MAX_WEIGHT = 0xFFFF;
/**
* Weights below this value will be upped to this minimum weight.
*/
private static final double MINIMUM_WEIGHT = 0.0001;
StaticStrideScheduler(float[] weights, Random random) {
checkArgument(weights.length >= 1, "Couldn't build scheduler: requires at least one weight");
int numChannels = weights.length;
int numWeightedChannels = 0;
double sumWeight = 0;
float maxWeight = 0;
short meanWeight = 0;
for (float weight : weights) {
if (weight > 0) {
sumWeight += weight;
maxWeight = Math.max(weight, maxWeight);
numWeightedChannels++;
}
}
private final Object lock = new Object();
private final Random random;
/**
* Use the item's deadline as the order in the priority queue. If the deadlines are the same,
* use the index. Index should be unique.
*/
EdfScheduler(int initialCapacity, Random random) {
this.prioQueue = new PriorityQueue<ObjectState>(initialCapacity, (o1, o2) -> {
if (o1.deadline == o2.deadline) {
return Integer.compare(o1.index, o2.index);
double scalingFactor = K_MAX_WEIGHT / maxWeight;
if (numWeightedChannels > 0) {
meanWeight = (short) Math.round(scalingFactor * sumWeight / numWeightedChannels);
} else {
return Double.compare(o1.deadline, o2.deadline);
}
});
this.random = random;
meanWeight = 1;
}
/**
* Adds the item in the scheduler. This is not thread safe.
*
* @param index The field {@link ObjectState#index} to be added
* @param weight positive weight for the added object
*/
void add(int index, double weight) {
checkArgument(weight > 0.0, "Weights need to be positive.");
ObjectState state = new ObjectState(Math.max(weight, MINIMUM_WEIGHT), index);
// Randomize the initial deadline.
state.deadline = random.nextDouble() * (1 / state.weight);
prioQueue.add(state);
// scales weights s.t. max(weights) == K_MAX_WEIGHT, meanWeight is scaled accordingly
short[] scaledWeights = new short[numChannels];
for (int i = 0; i < numChannels; i++) {
if (weights[i] <= 0) {
scaledWeights[i] = meanWeight;
} else {
scaledWeights[i] = (short) Math.round(weights[i] * scalingFactor);
}
}
/**
* Picks the next WRR object.
this.scaledWeights = scaledWeights;
this.sizeDivisor = numChannels;
this.sequence = new AtomicInteger(random.nextInt());
}
/**
 * Hands out the current sequence number and atomically advances the counter. The counter is a
 * 32-bit int that wraps around on overflow; the value is returned as an unsigned quantity.
 */
private long nextSequence() {
  int current = sequence.getAndIncrement();
  return Integer.toUnsignedLong(current);
}
/** Reads the current sequence number, as an unsigned value, without advancing it. */
@VisibleForTesting
long getSequence() {
  int current = sequence.get();
  return Integer.toUnsignedLong(current);
}
/*
* Selects index of next backend server.
* <p>
* A 2D array is compactly represented as a function of W(backend), where the row
* represents the generation and the column represents the backend index:
* X(backend,generation) | generation [0,kMaxWeight).
* Each element in the conceptual array is a boolean indicating whether the backend at
* this index should be picked now. If false, the counter is incremented again,
* and the new element is checked. An atomically incremented counter keeps track of our
* backend and generation through modular arithmetic within the pick() method.
* <p>
* Modular arithmetic allows us to evenly distribute picks and skips between
* generations based on W(backend).
* X(backend,generation) = (W(backend) * generation) % kMaxWeight >= kMaxWeight - W(backend)
* If we have the same three backends with weights:
* W(backend) = {2,3,6} scaled to max(W(backend)) = 6, then X(backend,generation) is:
* <p>
* B0 B1 B2
* T T T
* F F T
* F T T
* T F T
* F T T
* F F T
* The sequence of picked backend indices is given by
* walking across and down: {0,1,2,2,1,2,0,2,1,2,2}.
* <p>
* To reduce the variance and spread the wasted work among different picks,
* an offset that varies per backend index is also included to the calculation.
*/
int pick() {
synchronized (lock) {
ObjectState minObject = prioQueue.remove();
minObject.deadline += 1.0 / minObject.weight;
prioQueue.add(minObject);
return minObject.index;
while (true) {
long sequence = this.nextSequence();
int backendIndex = (int) (sequence % this.sizeDivisor);
long generation = sequence / this.sizeDivisor;
int weight = Short.toUnsignedInt(this.scaledWeights[backendIndex]);
long offset = (long) K_MAX_WEIGHT / 2 * backendIndex;
if ((weight * generation + offset) % K_MAX_WEIGHT < K_MAX_WEIGHT - weight) {
continue;
}
return backendIndex;
}
}
/**
 * Holds the scheduling state of one entry in the EDF priority queue: its fixed weight, its
 * index, and its current (mutable) deadline.
 */
@VisibleForTesting
static class ObjectState {
// Weight supplied at construction; the scheduler advances the deadline by 1 / weight.
private final double weight;
// Value returned by the scheduler's pick(); also the tie-breaker between equal deadlines.
private final int index;
// Ordering key in the priority queue; assigned and advanced by the scheduler, not here.
private volatile double deadline;
ObjectState(double weight, int index) {
this.weight = weight;
this.index = index;
}
}
static final class WeightedRoundRobinLoadBalancerConfig {

View File

@ -52,7 +52,7 @@ import io.grpc.SynchronizationContext;
import io.grpc.internal.FakeClock;
import io.grpc.services.InternalCallMetricRecorder;
import io.grpc.services.MetricReport;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.EdfScheduler;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.StaticStrideScheduler;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinPicker;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.WrrSubchannel;
@ -175,7 +175,7 @@ public class WeightedRoundRobinLoadBalancerTest {
}
});
wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker(),
new FakeRandom());
new FakeRandom(0));
}
@Test
@ -338,7 +338,7 @@ public class WeightedRoundRobinLoadBalancerTest {
}
@Test
public void pickByWeight_LargeWeight() {
public void pickByWeight_largeWeight() {
MetricReport report1 = InternalCallMetricRecorder.createMetricReport(
0.1, 0, 0.1, 999, 0, new HashMap<>(), new HashMap<>());
MetricReport report2 = InternalCallMetricRecorder.createMetricReport(
@ -593,6 +593,7 @@ public class WeightedRoundRobinLoadBalancerTest {
assertThat(fakeClock.forwardTime(500, TimeUnit.MILLISECONDS)).isEqualTo(1);
assertThat(weightedPicker.pickSubchannel(mockArgs)
.getSubchannel()).isEqualTo(weightedSubchannel2);
}
@Test
@ -750,12 +751,12 @@ public class WeightedRoundRobinLoadBalancerTest {
}
assertThat(pickCount.size()).isEqualTo(3);
assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 4.0 / 9))
.isAtMost(0.001);
.isAtMost(0.002);
assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 2.0 / 9))
.isAtMost(0.001);
.isAtMost(0.002);
// subchannel3's weight is average of subchannel1 and subchannel2
assertThat(Math.abs(pickCount.get(weightedSubchannel3) / 1000.0 - 3.0 / 9))
.isAtMost(0.001);
.isAtMost(0.002);
}
@Test
@ -821,37 +822,6 @@ public class WeightedRoundRobinLoadBalancerTest {
.isAtMost(0.001);
}
@Test
public void edfScheduler() {
  // Random weights for 1..10 entries; after 1000 picks every entry's share of the picks must be
  // within 0.01 of its share of the total weight.
  Random random = new Random();
  double totalWeight = 0;
  int capacity = random.nextInt(10) + 1;
  double[] weights = new double[capacity];
  EdfScheduler scheduler = new EdfScheduler(capacity, random);
  for (int slot = 0; slot < capacity; slot++) {
    double drawn = random.nextDouble();
    weights[slot] = drawn;
    scheduler.add(slot, drawn);
    totalWeight += drawn;
  }
  Map<Integer, Integer> tally = new HashMap<>();
  for (int round = 0; round < 1000; round++) {
    tally.merge(scheduler.pick(), 1, Integer::sum);
  }
  for (int slot = 0; slot < capacity; slot++) {
    double observedShare = tally.getOrDefault(slot, 0) / 1000.0;
    double expectedShare = weights[slot] / totalWeight;
    assertThat(Math.abs(observedShare - expectedShare)).isAtMost(0.01);
  }
}
@Test
public void edsScheduler_sameWeight() {
  // Two entries with the same weight: the first pick is expected to be index 0.
  EdfScheduler tied = new EdfScheduler(2, new FakeRandom());
  for (int index = 0; index < 2; index++) {
    tied.add(index, 0.5);
  }
  assertThat(tied.pick()).isEqualTo(0);
}
@Test(expected = NullPointerException.class)
public void wrrConfig_TimeValueNonNull() {
WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos((Long) null);
@ -862,6 +832,267 @@ public class WeightedRoundRobinLoadBalancerTest {
WeightedRoundRobinLoadBalancerConfig.newBuilder().setEnableOobLoadReport((Boolean) null);
}
@Test(expected = IllegalArgumentException.class)
public void emptyWeights() {
  // A scheduler built from an empty weight array must be rejected.
  StaticStrideScheduler scheduler = new StaticStrideScheduler(new float[0], new Random());
  scheduler.pick();
}
@Test
public void testPicksEqualsWeights() {
  // Weights 1:2:3 -> six consecutive picks are expected to split exactly 1/2/3.
  StaticStrideScheduler scheduler =
      new StaticStrideScheduler(new float[] {1.0f, 2.0f, 3.0f}, new Random());
  int[] observed = new int[3];
  for (int round = 0; round < 6; round++) {
    observed[scheduler.pick()]++;
  }
  assertThat(observed).isEqualTo(new int[] {1, 2, 3});
}
@Test
public void testContainsZeroWeightUseMean() {
  // The zero weight is expected to behave like the mean of the positive weights (3 and 1 -> 2).
  StaticStrideScheduler scheduler =
      new StaticStrideScheduler(new float[] {3.0f, 0.0f, 1.0f}, new Random());
  int[] observed = new int[3];
  for (int round = 0; round < 6; round++) {
    observed[scheduler.pick()]++;
  }
  assertThat(observed).isEqualTo(new int[] {3, 2, 1});
}
@Test
public void testContainsNegativeWeightUseMean() {
  // A negative weight is expected to behave like the mean of the positive weights (3 and 1 -> 2).
  StaticStrideScheduler scheduler =
      new StaticStrideScheduler(new float[] {3.0f, -1.0f, 1.0f}, new Random());
  int[] observed = new int[3];
  for (int round = 0; round < 6; round++) {
    observed[scheduler.pick()]++;
  }
  assertThat(observed).isEqualTo(new int[] {3, 2, 1});
}
@Test
public void testAllSameWeights() {
  // Equal weights -> six picks are expected to be shared evenly, two per backend.
  StaticStrideScheduler scheduler =
      new StaticStrideScheduler(new float[] {1.0f, 1.0f, 1.0f}, new Random());
  int[] observed = new int[3];
  for (int round = 0; round < 6; round++) {
    observed[scheduler.pick()]++;
  }
  assertThat(observed).isEqualTo(new int[] {2, 2, 2});
}
@Test
public void testAllZeroWeightsUseOne() {
  // With no positive weight at all, every backend is expected to be treated equally.
  StaticStrideScheduler scheduler =
      new StaticStrideScheduler(new float[] {0.0f, 0.0f, 0.0f}, new Random());
  int[] observed = new int[3];
  for (int round = 0; round < 6; round++) {
    observed[scheduler.pick()]++;
  }
  assertThat(observed).isEqualTo(new int[] {2, 2, 2});
}
@Test
public void testAllInvalidWeightsUseOne() {
  // Negative, negative-zero and zero weights are all non-positive; the backends are expected to
  // be treated equally.
  StaticStrideScheduler scheduler =
      new StaticStrideScheduler(new float[] {-3.1f, -0.0f, 0.0f}, new Random());
  int[] observed = new int[3];
  for (int round = 0; round < 6; round++) {
    observed[scheduler.pick()]++;
  }
  assertThat(observed).isEqualTo(new int[] {2, 2, 2});
}
@Test
public void testLargestWeightIndexPickedEveryGeneration() {
  final int kMaxWeight = 65535;
  final int largestWeightIndex = 2;
  StaticStrideScheduler scheduler =
      new StaticStrideScheduler(new float[] {1.0f, 2.0f, 3.0f}, new Random());
  // NOTE(review): the pick budget reuses largestWeightIndex (2) as a multiplier; presumably it
  // stands for sum(weights) / max(weights) = 6 / 3, which happens to be the same value - confirm.
  int pickBudget = largestWeightIndex * kMaxWeight;
  int hits = 0;
  for (int round = 0; round < pickBudget; round++) {
    if (scheduler.pick() == largestWeightIndex) {
      hits++;
    }
  }
  // The heaviest backend is expected to be chosen exactly kMaxWeight times within that budget.
  assertThat(hits).isEqualTo(kMaxWeight);
}
@Test
public void testStaticStrideSchedulerNonIntegers1() {
  // Over 1000 picks each backend's share must be within 0.01 of weight / totalWeight.
  float[] weights = {2.0f, (float) (10.0 / 3.0), 1.0f};
  StaticStrideScheduler scheduler = new StaticStrideScheduler(weights, new Random());
  double totalWeight = 2 + 10.0 / 3.0 + 1.0;
  Map<Integer, Integer> tally = new HashMap<>();
  for (int round = 0; round < 1000; round++) {
    tally.merge(scheduler.pick(), 1, Integer::sum);
  }
  for (int backend = 0; backend < 3; backend++) {
    double observedShare = tally.getOrDefault(backend, 0) / 1000.0;
    double expectedShare = weights[backend] / totalWeight;
    assertThat(Math.abs(observedShare - expectedShare)).isAtMost(0.01);
  }
}
@Test
public void testStaticStrideSchedulerNonIntegers2() {
  // Fractional weights below 1: shares over 1000 picks must be within 0.01 of weight / 1.8.
  float[] weights = {0.5f, 0.3f, 1.0f};
  StaticStrideScheduler scheduler = new StaticStrideScheduler(weights, new Random());
  double totalWeight = 1.8;
  Map<Integer, Integer> tally = new HashMap<>();
  for (int round = 0; round < 1000; round++) {
    tally.merge(scheduler.pick(), 1, Integer::sum);
  }
  for (int backend = 0; backend < 3; backend++) {
    double observedShare = tally.getOrDefault(backend, 0) / 1000.0;
    double expectedShare = weights[backend] / totalWeight;
    assertThat(Math.abs(observedShare - expectedShare)).isAtMost(0.01);
  }
}
@Test
public void testTwoWeights() {
  // Two backends weighted 1:2; shares over 1000 picks must be within 0.01 of 1/3 and 2/3.
  float[] weights = {1.0f, 2.0f};
  StaticStrideScheduler scheduler = new StaticStrideScheduler(weights, new Random());
  double totalWeight = 3;
  Map<Integer, Integer> tally = new HashMap<>();
  for (int round = 0; round < 1000; round++) {
    tally.merge(scheduler.pick(), 1, Integer::sum);
  }
  for (int backend = 0; backend < 2; backend++) {
    double observedShare = tally.getOrDefault(backend, 0) / 1000.0;
    double expectedShare = weights[backend] / totalWeight;
    assertThat(Math.abs(observedShare - expectedShare)).isAtMost(0.01);
  }
}
@Test
public void testManyWeights() {
  // Five backends weighted 1..5; shares over 1000 picks must be within 0.0011 of weight / 15.
  float[] weights = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
  StaticStrideScheduler scheduler = new StaticStrideScheduler(weights, new Random());
  double totalWeight = 15;
  Map<Integer, Integer> tally = new HashMap<>();
  for (int round = 0; round < 1000; round++) {
    tally.merge(scheduler.pick(), 1, Integer::sum);
  }
  for (int backend = 0; backend < 5; backend++) {
    double observedShare = tally.getOrDefault(backend, 0) / 1000.0;
    double expectedShare = weights[backend] / totalWeight;
    assertThat(Math.abs(observedShare - expectedShare)).isAtMost(0.0011);
  }
}
@Test
public void testManyComplexWeights() {
  // Weights spanning several orders of magnitude; over 1000 picks each backend's share must be
  // within 0.01 of weight / totalWeight.
  float[] weights = {1.2f, 2.4f, 222.56f, 1.1f, 15.0f, 226342.0f, 5123.0f, 532.2f};
  Random random = new Random();
  StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
  // Derive the denominator from the array itself so it cannot drift from the inputs: the
  // previous hand-written sum left out 1.1 and 532.2 and added a stray 0.0001.
  double totalWeight = 0;
  for (float weight : weights) {
    totalWeight += weight;
  }
  Map<Integer, Integer> pickCount = new HashMap<>();
  for (int i = 0; i < 1000; i++) {
    int result = sss.pick();
    pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
  }
  for (int i = 0; i < weights.length; i++) {
    assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight))
        .isAtMost(0.01);
  }
}
@Test
public void testDeterministicPicks() {
  // Starting from sequence 0, the picks and the sequence value observed after each pick are
  // fully determined by the weights {2, 3, 6}.
  float[] weights = {2.0f, 3.0f, 6.0f};
  StaticStrideScheduler scheduler = new StaticStrideScheduler(weights, new FakeRandom(0));
  int[] expectedPicks = {1, 2, 2, 0, 1, 2};
  long[] expectedSequenceAfterPick = {2, 3, 6, 7, 8, 9};
  assertThat(scheduler.getSequence()).isEqualTo(0);
  for (int step = 0; step < expectedPicks.length; step++) {
    assertThat(scheduler.pick()).isEqualTo(expectedPicks[step]);
    assertThat(scheduler.getSequence()).isEqualTo(expectedSequenceAfterPick[step]);
  }
}
@Test
public void testImmediateWraparound() {
  // Seeding the sequence with -1 makes the counter wrap on the very first increment; the pick
  // distribution over 1000 picks must still be within 0.001 of weight / 15.
  float[] weights = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
  StaticStrideScheduler scheduler = new StaticStrideScheduler(weights, new FakeRandom(-1));
  double totalWeight = 15;
  Map<Integer, Integer> tally = new HashMap<>();
  for (int round = 0; round < 1000; round++) {
    tally.merge(scheduler.pick(), 1, Integer::sum);
  }
  for (int backend = 0; backend < 5; backend++) {
    double observedShare = tally.getOrDefault(backend, 0) / 1000.0;
    double expectedShare = weights[backend] / totalWeight;
    assertThat(Math.abs(observedShare - expectedShare)).isAtMost(0.001);
  }
}
@Test
public void testWraparound() {
  // Seeding the sequence with -500 makes the counter wrap part-way through the run; the pick
  // distribution over 1000 picks must still be within 0.0011 of weight / 15.
  float[] weights = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
  StaticStrideScheduler scheduler = new StaticStrideScheduler(weights, new FakeRandom(-500));
  double totalWeight = 15;
  Map<Integer, Integer> tally = new HashMap<>();
  for (int round = 0; round < 1000; round++) {
    tally.merge(scheduler.pick(), 1, Integer::sum);
  }
  for (int backend = 0; backend < 5; backend++) {
    double observedShare = tally.getOrDefault(backend, 0) / 1000.0;
    double expectedShare = weights[backend] / totalWeight;
    assertThat(Math.abs(observedShare - expectedShare)).isAtMost(0.0011);
  }
}
@Test
public void testDeterministicWraparound() {
  // Seeded with -1 the sequence reads as 0xFFFFFFFF (unsigned) and wraps on the first increment;
  // the picks and sequence values that follow are fully determined by the weights {2, 3, 6}.
  float[] weights = {2.0f, 3.0f, 6.0f};
  StaticStrideScheduler scheduler = new StaticStrideScheduler(weights, new FakeRandom(-1));
  int[] expectedPicks = {1, 2, 2, 0, 1, 2};
  long[] expectedSequenceAfterPick = {2, 3, 6, 7, 8, 9};
  assertThat(scheduler.getSequence()).isEqualTo(0xFFFF_FFFFL);
  for (int step = 0; step < expectedPicks.length; step++) {
    assertThat(scheduler.pick()).isEqualTo(expectedPicks[step]);
    assertThat(scheduler.getSequence()).isEqualTo(expectedSequenceAfterPick[step]);
  }
}
private static class FakeSocketAddress extends SocketAddress {
final String name;
@ -875,10 +1106,16 @@ public class WeightedRoundRobinLoadBalancerTest {
}
private static class FakeRandom extends Random {
private int nextInt;
public FakeRandom(int nextInt) {
this.nextInt = nextInt;
}
@Override
public double nextDouble() {
public int nextInt() {
// return the value supplied at construction so the scheduler's initial sequence is deterministic
return 0.322023;
return nextInt;
}
}
}