mirror of https://github.com/grpc/grpc-java.git
xds: WRR scheduler clips weights (#10480)
This commit is contained in:
parent
26be0c7665
commit
97f4f8687c
|
@ -351,41 +351,69 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
|
|||
private final AtomicInteger sequence;
|
||||
private static final int K_MAX_WEIGHT = 0xFFFF;
|
||||
|
||||
// Assuming the mean of all known weights is M, StaticStrideScheduler will clamp
|
||||
// weights bigger than M*kMaxRatio and weights smaller than M*kMinRatio.
|
||||
//
|
||||
// This is done as a performance optimization by limiting the number of rounds for picks
|
||||
// for edge cases where channels have large differences in subchannel weights.
|
||||
// In this case, without these clips, it would potentially require the scheduler to
|
||||
// frequently traverse through the entire subchannel list within the pick method.
|
||||
//
|
||||
// The current values of 10 and 0.1 were chosen without any experimenting. It should
|
||||
// decrease the amount of sequences that the scheduler must traverse through in order
|
||||
// to pick a high weight subchannel in such corner cases.
|
||||
// But, it also makes WeightedRoundRobin to send slightly more requests to
|
||||
// potentially very bad tasks (that would have near-zero weights) than zero.
|
||||
// This is not necessarily a downside, though. Perhaps this is not a problem at
|
||||
// all, and we can increase this value if needed to save CPU cycles.
|
||||
private static final double K_MAX_RATIO = 10;
|
||||
private static final double K_MIN_RATIO = 0.1;
|
||||
|
||||
StaticStrideScheduler(float[] weights, AtomicInteger sequence) {
|
||||
checkArgument(weights.length >= 1, "Couldn't build scheduler: requires at least one weight");
|
||||
int numChannels = weights.length;
|
||||
int numWeightedChannels = 0;
|
||||
double sumWeight = 0;
|
||||
float maxWeight = 0;
|
||||
short meanWeight = 0;
|
||||
double unscaledMeanWeight;
|
||||
float unscaledMaxWeight = 0;
|
||||
for (float weight : weights) {
|
||||
if (weight > 0) {
|
||||
sumWeight += weight;
|
||||
maxWeight = Math.max(weight, maxWeight);
|
||||
unscaledMaxWeight = Math.max(weight, unscaledMaxWeight);
|
||||
numWeightedChannels++;
|
||||
}
|
||||
}
|
||||
|
||||
double scalingFactor = K_MAX_WEIGHT / maxWeight;
|
||||
// Adjust max value s.t. ratio does not exceed K_MAX_RATIO. This should
|
||||
// ensure that we on average do at most K_MAX_RATIO rounds for picks.
|
||||
if (numWeightedChannels > 0) {
|
||||
meanWeight = (short) Math.round(scalingFactor * sumWeight / numWeightedChannels);
|
||||
unscaledMeanWeight = sumWeight / numWeightedChannels;
|
||||
unscaledMaxWeight = Math.min(unscaledMaxWeight, (float) (K_MAX_RATIO * unscaledMeanWeight));
|
||||
} else {
|
||||
meanWeight = (short) Math.round(scalingFactor);
|
||||
// Fall back to round robin if all values are non-positives
|
||||
unscaledMeanWeight = 1;
|
||||
unscaledMaxWeight = 1;
|
||||
}
|
||||
|
||||
// scales weights s.t. max(weights) == K_MAX_WEIGHT, meanWeight is scaled accordingly
|
||||
// Scales weights s.t. max(weights) == K_MAX_WEIGHT, meanWeight is scaled accordingly.
|
||||
// Note that, since we cap the weights to stay within K_MAX_RATIO, meanWeight might not
|
||||
// match the actual mean of the values that end up in the scheduler.
|
||||
double scalingFactor = K_MAX_WEIGHT / unscaledMaxWeight;
|
||||
// We compute weightLowerBound and clamp it to 1 from below so that in the
|
||||
// worst case, we represent tiny weights as 1.
|
||||
int weightLowerBound = (int) Math.ceil(scalingFactor * unscaledMeanWeight * K_MIN_RATIO);
|
||||
short[] scaledWeights = new short[numChannels];
|
||||
for (int i = 0; i < numChannels; i++) {
|
||||
if (weights[i] <= 0) {
|
||||
scaledWeights[i] = meanWeight;
|
||||
scaledWeights[i] = (short) Math.round(scalingFactor * unscaledMeanWeight);
|
||||
} else {
|
||||
scaledWeights[i] = (short) Math.round(weights[i] * scalingFactor);
|
||||
int weight = (int) Math.round(scalingFactor * Math.min(weights[i], unscaledMaxWeight));
|
||||
scaledWeights[i] = (short) Math.max(weight, weightLowerBound);
|
||||
}
|
||||
}
|
||||
|
||||
this.scaledWeights = scaledWeights;
|
||||
this.sequence = sequence;
|
||||
|
||||
}
|
||||
|
||||
/** Returns the next sequence number and atomically increases sequence with wraparound. */
|
||||
|
|
|
@ -330,11 +330,11 @@ public class WeightedRoundRobinLoadBalancerTest {
|
|||
}
|
||||
assertThat(pickCount.size()).isEqualTo(3);
|
||||
assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 10000.0 - subchannel1PickRatio))
|
||||
.isAtMost(0.0001);
|
||||
.isLessThan(0.0002);
|
||||
assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 10000.0 - subchannel2PickRatio ))
|
||||
.isAtMost(0.0001);
|
||||
.isLessThan(0.0002);
|
||||
assertThat(Math.abs(pickCount.get(weightedSubchannel3) / 10000.0 - subchannel3PickRatio ))
|
||||
.isAtMost(0.0001);
|
||||
.isLessThan(0.0002);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -345,10 +345,11 @@ public class WeightedRoundRobinLoadBalancerTest {
|
|||
0.9, 0, 0.1, 2, 0, new HashMap<>(), new HashMap<>(), new HashMap<>());
|
||||
MetricReport report3 = InternalCallMetricRecorder.createMetricReport(
|
||||
0.86, 0, 0.1, 100, 0, new HashMap<>(), new HashMap<>(), new HashMap<>());
|
||||
double totalWeight = 999 / 0.1 + 2 / 0.9 + 100 / 0.86;
|
||||
|
||||
pickByWeight(report1, report2, report3, 999 / 0.1 / totalWeight, 2 / 0.9 / totalWeight,
|
||||
100 / 0.86 / totalWeight);
|
||||
double meanWeight = (999 / 0.1 + 2 / 0.9 + 100 / 0.86) / 3;
|
||||
double cappedMin = meanWeight * 0.1; // min capped at minRatio * meanWeight
|
||||
double totalWeight = 999 / 0.1 + cappedMin + cappedMin;
|
||||
pickByWeight(report1, report2, report3, 999 / 0.1 / totalWeight, cappedMin / totalWeight,
|
||||
cappedMin / totalWeight);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -359,10 +360,11 @@ public class WeightedRoundRobinLoadBalancerTest {
|
|||
0.12, 0.9, 0.1, 2, 0, new HashMap<>(), new HashMap<>(), new HashMap<>());
|
||||
MetricReport report3 = InternalCallMetricRecorder.createMetricReport(
|
||||
0.33, 0.86, 0.1, 100, 0, new HashMap<>(), new HashMap<>(), new HashMap<>());
|
||||
double totalWeight = 999 / 0.1 + 2 / 0.9 + 100 / 0.86;
|
||||
|
||||
pickByWeight(report1, report2, report3, 999 / 0.1 / totalWeight, 2 / 0.9 / totalWeight,
|
||||
100 / 0.86 / totalWeight);
|
||||
double meanWeight = (999 / 0.1 + 2 / 0.9 + 100 / 0.86) / 3;
|
||||
double cappedMin = meanWeight * 0.1;
|
||||
double totalWeight = 999 / 0.1 + cappedMin + cappedMin; // min capped at minRatio * meanWeight
|
||||
pickByWeight(report1, report2, report3, 999 / 0.1 / totalWeight, cappedMin / totalWeight,
|
||||
cappedMin / totalWeight);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -373,13 +375,14 @@ public class WeightedRoundRobinLoadBalancerTest {
|
|||
0.9, 0, 0.1, 2, 1.8, new HashMap<>(), new HashMap<>(), new HashMap<>());
|
||||
MetricReport report3 = InternalCallMetricRecorder.createMetricReport(
|
||||
0.86, 0, 0.1, 100, 3, new HashMap<>(), new HashMap<>(), new HashMap<>());
|
||||
double weight1 = 999 / (0.1 + 13 / 999F * weightedConfig.errorUtilizationPenalty);
|
||||
double weight2 = 2 / (0.9 + 1.8 / 2F * weightedConfig.errorUtilizationPenalty);
|
||||
double weight3 = 100 / (0.86 + 3 / 100F * weightedConfig.errorUtilizationPenalty);
|
||||
double totalWeight = weight1 + weight2 + weight3;
|
||||
|
||||
pickByWeight(report1, report2, report3, weight1 / totalWeight, weight2 / totalWeight,
|
||||
weight3 / totalWeight);
|
||||
double weight1 = 999 / (0.1 + 13 / 999F * weightedConfig.errorUtilizationPenalty); // ~5609.899
|
||||
double weight2 = 2 / (0.9 + 1.8 / 2F * weightedConfig.errorUtilizationPenalty); // ~0.317
|
||||
double weight3 = 100 / (0.86 + 3 / 100F * weightedConfig.errorUtilizationPenalty); // ~96.154
|
||||
double meanWeight = (weight1 + weight2 + weight3) / 3;
|
||||
double cappedMin = meanWeight * 0.1; // min capped at minRatio * meanWeight
|
||||
double totalWeight = weight1 + cappedMin + cappedMin;
|
||||
pickByWeight(report1, report2, report3, weight1 / totalWeight, cappedMin / totalWeight,
|
||||
cappedMin / totalWeight);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -835,7 +838,7 @@ public class WeightedRoundRobinLoadBalancerTest {
|
|||
@Test(expected = IllegalArgumentException.class)
|
||||
public void emptyWeights() {
|
||||
float[] weights = {};
|
||||
Random random = new Random();
|
||||
Random random = new Random(0);
|
||||
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
|
||||
new AtomicInteger(random.nextInt()));
|
||||
sss.pick();
|
||||
|
@ -844,7 +847,7 @@ public class WeightedRoundRobinLoadBalancerTest {
|
|||
@Test
|
||||
public void testPicksEqualsWeights() {
|
||||
float[] weights = {1.0f, 2.0f, 3.0f};
|
||||
Random random = new Random();
|
||||
Random random = new Random(0);
|
||||
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
|
||||
new AtomicInteger(random.nextInt()));
|
||||
int[] expectedPicks = new int[] {1, 2, 3};
|
||||
|
@ -858,7 +861,7 @@ public class WeightedRoundRobinLoadBalancerTest {
|
|||
@Test
|
||||
public void testContainsZeroWeightUseMean() {
|
||||
float[] weights = {3.0f, 0.0f, 1.0f};
|
||||
Random random = new Random();
|
||||
Random random = new Random(0);
|
||||
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
|
||||
new AtomicInteger(random.nextInt()));
|
||||
int[] expectedPicks = new int[] {3, 2, 1};
|
||||
|
@ -872,7 +875,7 @@ public class WeightedRoundRobinLoadBalancerTest {
|
|||
@Test
|
||||
public void testContainsNegativeWeightUseMean() {
|
||||
float[] weights = {3.0f, -1.0f, 1.0f};
|
||||
Random random = new Random();
|
||||
Random random = new Random(0);
|
||||
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
|
||||
new AtomicInteger(random.nextInt()));
|
||||
int[] expectedPicks = new int[] {3, 2, 1};
|
||||
|
@ -886,7 +889,7 @@ public class WeightedRoundRobinLoadBalancerTest {
|
|||
@Test
|
||||
public void testAllSameWeights() {
|
||||
float[] weights = {1.0f, 1.0f, 1.0f};
|
||||
Random random = new Random();
|
||||
Random random = new Random(0);
|
||||
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
|
||||
new AtomicInteger(random.nextInt()));
|
||||
int[] expectedPicks = new int[] {2, 2, 2};
|
||||
|
@ -898,9 +901,9 @@ public class WeightedRoundRobinLoadBalancerTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testAllZeroWeightsUseOne() {
|
||||
public void testAllZeroWeightsIsRoundRobin() {
|
||||
float[] weights = {0.0f, 0.0f, 0.0f};
|
||||
Random random = new Random();
|
||||
Random random = new Random(0);
|
||||
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
|
||||
new AtomicInteger(random.nextInt()));
|
||||
int[] expectedPicks = new int[] {2, 2, 2};
|
||||
|
@ -912,9 +915,9 @@ public class WeightedRoundRobinLoadBalancerTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testAllInvalidWeightsUseOne() {
|
||||
public void testAllInvalidWeightsIsRoundRobin() {
|
||||
float[] weights = {-3.1f, -0.0f, 0.0f};
|
||||
Random random = new Random();
|
||||
Random random = new Random(0);
|
||||
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
|
||||
new AtomicInteger(random.nextInt()));
|
||||
int[] expectedPicks = new int[] {2, 2, 2};
|
||||
|
@ -925,66 +928,13 @@ public class WeightedRoundRobinLoadBalancerTest {
|
|||
assertThat(picks).isEqualTo(expectedPicks);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLargestWeightIndexPickedEveryGeneration() {
|
||||
float[] weights = {1.0f, 2.0f, 3.0f};
|
||||
int largestWeightIndex = 2;
|
||||
Random random = new Random();
|
||||
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
|
||||
new AtomicInteger(random.nextInt()));
|
||||
int largestWeightPickCount = 0;
|
||||
int kMaxWeight = 65535;
|
||||
for (int i = 0; i < largestWeightIndex * kMaxWeight; i++) {
|
||||
if (sss.pick() == largestWeightIndex) {
|
||||
largestWeightPickCount += 1;
|
||||
}
|
||||
}
|
||||
assertThat(largestWeightPickCount).isEqualTo(kMaxWeight);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStaticStrideSchedulerNonIntegers1() {
|
||||
float[] weights = {2.0f, (float) (10.0 / 3.0), 1.0f};
|
||||
Random random = new Random();
|
||||
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
|
||||
new AtomicInteger(random.nextInt()));
|
||||
double totalWeight = 2 + 10.0 / 3.0 + 1.0;
|
||||
Map<Integer, Integer> pickCount = new HashMap<>();
|
||||
for (int i = 0; i < 1000; i++) {
|
||||
int result = sss.pick();
|
||||
pickCount.merge(result, 1, (o, v) -> o + v);
|
||||
}
|
||||
for (int i = 0; i < 3; i++) {
|
||||
assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight))
|
||||
.isLessThan(0.002);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStaticStrideSchedulerNonIntegers2() {
|
||||
float[] weights = {0.5f, 0.3f, 1.0f};
|
||||
Random random = new Random();
|
||||
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
|
||||
new AtomicInteger(random.nextInt()));
|
||||
double totalWeight = 1.8;
|
||||
Map<Integer, Integer> pickCount = new HashMap<>();
|
||||
for (int i = 0; i < 1000; i++) {
|
||||
int result = sss.pick();
|
||||
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
|
||||
}
|
||||
for (int i = 0; i < 3; i++) {
|
||||
assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight))
|
||||
.isLessThan(0.002);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTwoWeights() {
|
||||
float[] weights = {1.0f, 2.0f};
|
||||
Random random = new Random();
|
||||
float[] weights = {1.43f, 2.119f};
|
||||
Random random = new Random(0);
|
||||
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
|
||||
new AtomicInteger(random.nextInt()));
|
||||
double totalWeight = 3;
|
||||
double totalWeight = 1.43 + 2.119;
|
||||
Map<Integer, Integer> pickCount = new HashMap<>();
|
||||
for (int i = 0; i < 1000; i++) {
|
||||
int result = sss.pick();
|
||||
|
@ -998,11 +948,11 @@ public class WeightedRoundRobinLoadBalancerTest {
|
|||
|
||||
@Test
|
||||
public void testManyWeights() {
|
||||
float[] weights = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
|
||||
Random random = new Random();
|
||||
float[] weights = {1.3f, 2.5f, 3.23f, 4.11f, 7.001f};
|
||||
Random random = new Random(0);
|
||||
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
|
||||
new AtomicInteger(random.nextInt()));
|
||||
double totalWeight = 15;
|
||||
double totalWeight = 1.3 + 2.5 + 3.23 + 4.11 + 7.001;
|
||||
Map<Integer, Integer> pickCount = new HashMap<>();
|
||||
for (int i = 0; i < 1000; i++) {
|
||||
int result = sss.pick();
|
||||
|
@ -1015,21 +965,38 @@ public class WeightedRoundRobinLoadBalancerTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testManyComplexWeights() {
|
||||
float[] weights = {1.2f, 2.4f, 222.56f, 1.1f, 15.0f, 226342.0f, 5123.0f, 532.2f};
|
||||
Random random = new Random();
|
||||
public void testMaxClamped() {
|
||||
float[] weights = {81f, 1f, 1f, 1f, 1f, 1f, 1f, 1f,
|
||||
1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f};
|
||||
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
|
||||
new AtomicInteger(random.nextInt()));
|
||||
double totalWeight = 1.2 + 2.4 + 222.56 + 15.0 + 226342.0 + 5123.0 + 0.0001;
|
||||
Map<Integer, Integer> pickCount = new HashMap<>();
|
||||
for (int i = 0; i < 1000; i++) {
|
||||
int result = sss.pick();
|
||||
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
|
||||
new AtomicInteger(0));
|
||||
int[] picks = new int[weights.length];
|
||||
|
||||
// max gets clamped to mean*maxRatio = 50 for this set of weights. So if we
|
||||
// pick 50 + 19 times we should get all possible picks.
|
||||
for (int i = 1; i < 70; i++) {
|
||||
picks[sss.pick()] += 1;
|
||||
}
|
||||
for (int i = 0; i < 8; i++) {
|
||||
assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight))
|
||||
.isAtMost(0.004);
|
||||
int[] expectedPicks = new int[] {50, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
|
||||
assertThat(picks).isEqualTo(expectedPicks);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMinClamped() {
|
||||
float[] weights = {100f, 1e-10f};
|
||||
StaticStrideScheduler sss = new StaticStrideScheduler(weights,
|
||||
new AtomicInteger(0));
|
||||
int[] picks = new int[weights.length];
|
||||
|
||||
// We pick 201 elements and ensure that the second channel (with epsilon
|
||||
// weight) also gets picked. The math is: mean value of elements is ~50, so
|
||||
// the first channel keeps its weight of 100, but the second element's weight
|
||||
// gets capped from below to 50*0.1 = 5.
|
||||
for (int i = 0; i < 105; i++) {
|
||||
picks[sss.pick()] += 1;
|
||||
}
|
||||
int[] expectedPicks = new int[] {100, 5};
|
||||
assertThat(picks).isEqualTo(expectedPicks);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
Loading…
Reference in New Issue