xds: WRR scheduler clips weights (#10480)

This commit is contained in:
Tony An 2023-08-17 17:11:55 -07:00 committed by GitHub
parent 26be0c7665
commit 97f4f8687c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 103 additions and 108 deletions

View File

@ -351,41 +351,69 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
private final AtomicInteger sequence;
private static final int K_MAX_WEIGHT = 0xFFFF;
// Assuming the mean of all known weights is M, StaticStrideScheduler will clamp
// weights bigger than M*kMaxRatio and weights smaller than M*kMinRatio.
//
// This is done as a performance optimization by limiting the number of rounds for picks
// for edge cases where channels have large differences in subchannel weights.
// In this case, without these clips, it would potentially require the scheduler to
// frequently traverse through the entire subchannel list within the pick method.
//
// The current values of 10 and 0.1 were chosen without any experimenting. It should
// decrease the amount of sequences that the scheduler must traverse through in order
// to pick a high weight subchannel in such corner cases.
// But, it also makes WeightedRoundRobin to send slightly more requests to
// potentially very bad tasks (that would have near-zero weights) than zero.
// This is not necessarily a downside, though. Perhaps this is not a problem at
// all, and we can increase this value if needed to save CPU cycles.
private static final double K_MAX_RATIO = 10;
private static final double K_MIN_RATIO = 0.1;
StaticStrideScheduler(float[] weights, AtomicInteger sequence) {
checkArgument(weights.length >= 1, "Couldn't build scheduler: requires at least one weight");
int numChannels = weights.length;
int numWeightedChannels = 0;
double sumWeight = 0;
float maxWeight = 0;
short meanWeight = 0;
double unscaledMeanWeight;
float unscaledMaxWeight = 0;
for (float weight : weights) {
if (weight > 0) {
sumWeight += weight;
maxWeight = Math.max(weight, maxWeight);
unscaledMaxWeight = Math.max(weight, unscaledMaxWeight);
numWeightedChannels++;
}
}
double scalingFactor = K_MAX_WEIGHT / maxWeight;
// Adjust max value s.t. ratio does not exceed K_MAX_RATIO. This should
// ensure that we on average do at most K_MAX_RATIO rounds for picks.
if (numWeightedChannels > 0) {
meanWeight = (short) Math.round(scalingFactor * sumWeight / numWeightedChannels);
unscaledMeanWeight = sumWeight / numWeightedChannels;
unscaledMaxWeight = Math.min(unscaledMaxWeight, (float) (K_MAX_RATIO * unscaledMeanWeight));
} else {
meanWeight = (short) Math.round(scalingFactor);
// Fall back to round robin if all values are non-positives
unscaledMeanWeight = 1;
unscaledMaxWeight = 1;
}
// scales weights s.t. max(weights) == K_MAX_WEIGHT, meanWeight is scaled accordingly
// Scales weights s.t. max(weights) == K_MAX_WEIGHT, meanWeight is scaled accordingly.
// Note that, since we cap the weights to stay within K_MAX_RATIO, meanWeight might not
// match the actual mean of the values that end up in the scheduler.
double scalingFactor = K_MAX_WEIGHT / unscaledMaxWeight;
// We compute weightLowerBound and clamp it to 1 from below so that in the
// worst case, we represent tiny weights as 1.
int weightLowerBound = (int) Math.ceil(scalingFactor * unscaledMeanWeight * K_MIN_RATIO);
short[] scaledWeights = new short[numChannels];
for (int i = 0; i < numChannels; i++) {
if (weights[i] <= 0) {
scaledWeights[i] = meanWeight;
scaledWeights[i] = (short) Math.round(scalingFactor * unscaledMeanWeight);
} else {
scaledWeights[i] = (short) Math.round(weights[i] * scalingFactor);
int weight = (int) Math.round(scalingFactor * Math.min(weights[i], unscaledMaxWeight));
scaledWeights[i] = (short) Math.max(weight, weightLowerBound);
}
}
this.scaledWeights = scaledWeights;
this.sequence = sequence;
}
/** Returns the next sequence number and atomically increases sequence with wraparound. */

View File

@ -330,11 +330,11 @@ public class WeightedRoundRobinLoadBalancerTest {
}
assertThat(pickCount.size()).isEqualTo(3);
assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 10000.0 - subchannel1PickRatio))
.isAtMost(0.0001);
.isLessThan(0.0002);
assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 10000.0 - subchannel2PickRatio ))
.isAtMost(0.0001);
.isLessThan(0.0002);
assertThat(Math.abs(pickCount.get(weightedSubchannel3) / 10000.0 - subchannel3PickRatio ))
.isAtMost(0.0001);
.isLessThan(0.0002);
}
@Test
@ -345,10 +345,11 @@ public class WeightedRoundRobinLoadBalancerTest {
0.9, 0, 0.1, 2, 0, new HashMap<>(), new HashMap<>(), new HashMap<>());
MetricReport report3 = InternalCallMetricRecorder.createMetricReport(
0.86, 0, 0.1, 100, 0, new HashMap<>(), new HashMap<>(), new HashMap<>());
double totalWeight = 999 / 0.1 + 2 / 0.9 + 100 / 0.86;
pickByWeight(report1, report2, report3, 999 / 0.1 / totalWeight, 2 / 0.9 / totalWeight,
100 / 0.86 / totalWeight);
double meanWeight = (999 / 0.1 + 2 / 0.9 + 100 / 0.86) / 3;
double cappedMin = meanWeight * 0.1; // min capped at minRatio * meanWeight
double totalWeight = 999 / 0.1 + cappedMin + cappedMin;
pickByWeight(report1, report2, report3, 999 / 0.1 / totalWeight, cappedMin / totalWeight,
cappedMin / totalWeight);
}
@Test
@ -359,10 +360,11 @@ public class WeightedRoundRobinLoadBalancerTest {
0.12, 0.9, 0.1, 2, 0, new HashMap<>(), new HashMap<>(), new HashMap<>());
MetricReport report3 = InternalCallMetricRecorder.createMetricReport(
0.33, 0.86, 0.1, 100, 0, new HashMap<>(), new HashMap<>(), new HashMap<>());
double totalWeight = 999 / 0.1 + 2 / 0.9 + 100 / 0.86;
pickByWeight(report1, report2, report3, 999 / 0.1 / totalWeight, 2 / 0.9 / totalWeight,
100 / 0.86 / totalWeight);
double meanWeight = (999 / 0.1 + 2 / 0.9 + 100 / 0.86) / 3;
double cappedMin = meanWeight * 0.1;
double totalWeight = 999 / 0.1 + cappedMin + cappedMin; // min capped at minRatio * meanWeight
pickByWeight(report1, report2, report3, 999 / 0.1 / totalWeight, cappedMin / totalWeight,
cappedMin / totalWeight);
}
@Test
@ -373,13 +375,14 @@ public class WeightedRoundRobinLoadBalancerTest {
0.9, 0, 0.1, 2, 1.8, new HashMap<>(), new HashMap<>(), new HashMap<>());
MetricReport report3 = InternalCallMetricRecorder.createMetricReport(
0.86, 0, 0.1, 100, 3, new HashMap<>(), new HashMap<>(), new HashMap<>());
double weight1 = 999 / (0.1 + 13 / 999F * weightedConfig.errorUtilizationPenalty);
double weight2 = 2 / (0.9 + 1.8 / 2F * weightedConfig.errorUtilizationPenalty);
double weight3 = 100 / (0.86 + 3 / 100F * weightedConfig.errorUtilizationPenalty);
double totalWeight = weight1 + weight2 + weight3;
pickByWeight(report1, report2, report3, weight1 / totalWeight, weight2 / totalWeight,
weight3 / totalWeight);
double weight1 = 999 / (0.1 + 13 / 999F * weightedConfig.errorUtilizationPenalty); // ~5609.899
double weight2 = 2 / (0.9 + 1.8 / 2F * weightedConfig.errorUtilizationPenalty); // ~0.317
double weight3 = 100 / (0.86 + 3 / 100F * weightedConfig.errorUtilizationPenalty); // ~96.154
double meanWeight = (weight1 + weight2 + weight3) / 3;
double cappedMin = meanWeight * 0.1; // min capped at minRatio * meanWeight
double totalWeight = weight1 + cappedMin + cappedMin;
pickByWeight(report1, report2, report3, weight1 / totalWeight, cappedMin / totalWeight,
cappedMin / totalWeight);
}
@Test
@ -835,7 +838,7 @@ public class WeightedRoundRobinLoadBalancerTest {
@Test(expected = IllegalArgumentException.class)
public void emptyWeights() {
float[] weights = {};
Random random = new Random();
Random random = new Random(0);
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
sss.pick();
@ -844,7 +847,7 @@ public class WeightedRoundRobinLoadBalancerTest {
@Test
public void testPicksEqualsWeights() {
float[] weights = {1.0f, 2.0f, 3.0f};
Random random = new Random();
Random random = new Random(0);
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
int[] expectedPicks = new int[] {1, 2, 3};
@ -858,7 +861,7 @@ public class WeightedRoundRobinLoadBalancerTest {
@Test
public void testContainsZeroWeightUseMean() {
float[] weights = {3.0f, 0.0f, 1.0f};
Random random = new Random();
Random random = new Random(0);
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
int[] expectedPicks = new int[] {3, 2, 1};
@ -872,7 +875,7 @@ public class WeightedRoundRobinLoadBalancerTest {
@Test
public void testContainsNegativeWeightUseMean() {
float[] weights = {3.0f, -1.0f, 1.0f};
Random random = new Random();
Random random = new Random(0);
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
int[] expectedPicks = new int[] {3, 2, 1};
@ -886,7 +889,7 @@ public class WeightedRoundRobinLoadBalancerTest {
@Test
public void testAllSameWeights() {
float[] weights = {1.0f, 1.0f, 1.0f};
Random random = new Random();
Random random = new Random(0);
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
int[] expectedPicks = new int[] {2, 2, 2};
@ -898,9 +901,9 @@ public class WeightedRoundRobinLoadBalancerTest {
}
@Test
public void testAllZeroWeightsUseOne() {
public void testAllZeroWeightsIsRoundRobin() {
float[] weights = {0.0f, 0.0f, 0.0f};
Random random = new Random();
Random random = new Random(0);
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
int[] expectedPicks = new int[] {2, 2, 2};
@ -912,9 +915,9 @@ public class WeightedRoundRobinLoadBalancerTest {
}
@Test
public void testAllInvalidWeightsUseOne() {
public void testAllInvalidWeightsIsRoundRobin() {
float[] weights = {-3.1f, -0.0f, 0.0f};
Random random = new Random();
Random random = new Random(0);
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
int[] expectedPicks = new int[] {2, 2, 2};
@ -925,66 +928,13 @@ public class WeightedRoundRobinLoadBalancerTest {
assertThat(picks).isEqualTo(expectedPicks);
}
@Test
public void testLargestWeightIndexPickedEveryGeneration() {
float[] weights = {1.0f, 2.0f, 3.0f};
int largestWeightIndex = 2;
Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
int largestWeightPickCount = 0;
int kMaxWeight = 65535;
for (int i = 0; i < largestWeightIndex * kMaxWeight; i++) {
if (sss.pick() == largestWeightIndex) {
largestWeightPickCount += 1;
}
}
assertThat(largestWeightPickCount).isEqualTo(kMaxWeight);
}
@Test
public void testStaticStrideSchedulerNonIntegers1() {
float[] weights = {2.0f, (float) (10.0 / 3.0), 1.0f};
Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
double totalWeight = 2 + 10.0 / 3.0 + 1.0;
Map<Integer, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
int result = sss.pick();
pickCount.merge(result, 1, (o, v) -> o + v);
}
for (int i = 0; i < 3; i++) {
assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight))
.isLessThan(0.002);
}
}
@Test
public void testStaticStrideSchedulerNonIntegers2() {
float[] weights = {0.5f, 0.3f, 1.0f};
Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
double totalWeight = 1.8;
Map<Integer, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
int result = sss.pick();
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
for (int i = 0; i < 3; i++) {
assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight))
.isLessThan(0.002);
}
}
@Test
public void testTwoWeights() {
float[] weights = {1.0f, 2.0f};
Random random = new Random();
float[] weights = {1.43f, 2.119f};
Random random = new Random(0);
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
double totalWeight = 3;
double totalWeight = 1.43 + 2.119;
Map<Integer, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
int result = sss.pick();
@ -998,11 +948,11 @@ public class WeightedRoundRobinLoadBalancerTest {
@Test
public void testManyWeights() {
float[] weights = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
Random random = new Random();
float[] weights = {1.3f, 2.5f, 3.23f, 4.11f, 7.001f};
Random random = new Random(0);
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
double totalWeight = 15;
double totalWeight = 1.3 + 2.5 + 3.23 + 4.11 + 7.001;
Map<Integer, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
int result = sss.pick();
@ -1015,21 +965,38 @@ public class WeightedRoundRobinLoadBalancerTest {
}
@Test
public void testManyComplexWeights() {
float[] weights = {1.2f, 2.4f, 222.56f, 1.1f, 15.0f, 226342.0f, 5123.0f, 532.2f};
Random random = new Random();
public void testMaxClamped() {
float[] weights = {81f, 1f, 1f, 1f, 1f, 1f, 1f, 1f,
1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f};
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
double totalWeight = 1.2 + 2.4 + 222.56 + 15.0 + 226342.0 + 5123.0 + 0.0001;
Map<Integer, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
int result = sss.pick();
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
new AtomicInteger(0));
int[] picks = new int[weights.length];
// max gets clamped to mean*maxRatio = 50 for this set of weights. So if we
// pick 50 + 19 times we should get all possible picks.
for (int i = 1; i < 70; i++) {
picks[sss.pick()] += 1;
}
for (int i = 0; i < 8; i++) {
assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight))
.isAtMost(0.004);
int[] expectedPicks = new int[] {50, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
assertThat(picks).isEqualTo(expectedPicks);
}
@Test
public void testMinClamped() {
float[] weights = {100f, 1e-10f};
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(0));
int[] picks = new int[weights.length];
// We pick 201 elements and ensure that the second channel (with epsilon
// weight) also gets picked. The math is: mean value of elements is ~50, so
// the first channel keeps its weight of 100, but the second element's weight
// gets capped from below to 50*0.1 = 5.
for (int i = 0; i < 105; i++) {
picks[sss.pick()] += 1;
}
int[] expectedPicks = new int[] {100, 5};
assertThat(picks).isEqualTo(expectedPicks);
}
@Test