xds: fix wrr stuck in rr mode (#10061)

This commit is contained in:
yifeizhuang 2023-04-18 16:39:51 -07:00 committed by GitHub
parent 35852130d9
commit 111ff60e1c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 72 additions and 18 deletions

View File

@ -46,6 +46,8 @@ import java.util.PriorityQueue;
import java.util.Random;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
* A {@link LoadBalancer} that provides weighted-round-robin load-balancing over
@ -54,6 +56,8 @@ import java.util.concurrent.TimeUnit;
*/
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/9885")
final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
private static final Logger log = Logger.getLogger(
WeightedRoundRobinLoadBalancer.class.getName());
private volatile WeightedRoundRobinLoadBalancerConfig config;
private final SynchronizationContext syncContext;
private final ScheduledExecutorService timeService;
@ -76,6 +80,7 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService");
this.updateWeightTask = new UpdateWeightTask();
this.random = random;
log.log(Level.FINE, "weighted_round_robin LB created");
}
@VisibleForTesting
@ -230,7 +235,6 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
final class WeightedRoundRobinPicker extends ReadyPicker {
private final List<Subchannel> list;
private volatile EdfScheduler scheduler;
private volatile boolean rrMode;
WeightedRoundRobinPicker(List<Subchannel> list) {
super(checkNotNull(list, "list"), random.nextInt(list.size()));
@ -241,16 +245,11 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
@Override
public PickResult pickSubchannel(PickSubchannelArgs args) {
if (rrMode) {
return super.pickSubchannel(args);
}
int pickIndex = scheduler.pick();
WrrSubchannel subchannel = (WrrSubchannel) list.get(pickIndex);
Subchannel subchannel = list.get(scheduler.pick());
if (!config.enableOobLoadReport) {
return PickResult.withSubchannel(
subchannel,
OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(
subchannel.perRpcListener));
return PickResult.withSubchannel(subchannel,
OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(
((WrrSubchannel)subchannel).perRpcListener));
} else {
return PickResult.withSubchannel(subchannel);
}
@ -266,25 +265,24 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
weightedChannelCount++;
}
}
if (weightedChannelCount < 2) {
rrMode = true;
return;
}
EdfScheduler scheduler = new EdfScheduler(list.size(), random);
avgWeight /= 1.0 * weightedChannelCount;
if (weightedChannelCount >= 1) {
avgWeight /= 1.0 * weightedChannelCount;
} else {
avgWeight = 1;
}
for (int i = 0; i < list.size(); i++) {
WrrSubchannel subchannel = (WrrSubchannel) list.get(i);
double newWeight = subchannel.getWeight();
scheduler.add(i, newWeight > 0 ? newWeight : avgWeight);
}
this.scheduler = scheduler;
rrMode = false;
}
@Override
public String toString() {
return MoreObjects.toStringHelper(WeightedRoundRobinPicker.class)
.add("list", list).add("rrMode", rrMode).toString();
.add("list", list).toString();
}
@VisibleForTesting

View File

@ -29,6 +29,7 @@ import static org.mockito.Mockito.when;
import com.github.xds.data.orca.v3.OrcaLoadReport;
import com.github.xds.service.orca.v3.OrcaLoadReportRequest;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.protobuf.Duration;
@ -514,6 +515,62 @@ public class WeightedRoundRobinLoadBalancerTest {
.isAtMost(0.001);
}
@Test
public void rrFallback() {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
verify(helper, times(3)).createSubchannel(
any(CreateSubchannelArgs.class));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
verify(helper, times(2)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(1);
assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1);
WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
Map<WrrSubchannel, Integer> qpsByChannel = ImmutableMap.of(weightedSubchannel1, 2,
weightedSubchannel2, 1);
Map<Subchannel, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
PickResult pickResult = weightedPicker.pickSubchannel(mockArgs);
pickCount.put(pickResult.getSubchannel(),
pickCount.getOrDefault(pickResult.getSubchannel(), 0) + 1);
assertThat(pickResult.getStreamTracerFactory()).isNotNull();
WrrSubchannel subchannel = (WrrSubchannel)pickResult.getSubchannel();
subchannel.onLoadReport(InternalCallMetricRecorder.createMetricReport(
0.1, 0.1, qpsByChannel.get(subchannel), new HashMap<>(), new HashMap<>()));
}
assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 1.0 / 2))
.isAtMost(0.1);
assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 2))
.isAtMost(0.1);
pickCount.clear();
for (int i = 0; i < 1000; i++) {
PickResult pickResult = weightedPicker.pickSubchannel(mockArgs);
pickCount.put(pickResult.getSubchannel(),
pickCount.getOrDefault(pickResult.getSubchannel(), 0) + 1);
assertThat(pickResult.getStreamTracerFactory()).isNotNull();
WrrSubchannel subchannel = (WrrSubchannel)pickResult.getSubchannel();
subchannel.onLoadReport(InternalCallMetricRecorder.createMetricReport(
0.1, 0.1, qpsByChannel.get(subchannel), new HashMap<>(), new HashMap<>()));
fakeClock.forwardTime(50, TimeUnit.MILLISECONDS);
}
assertThat(pickCount.size()).isEqualTo(2);
assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 2.0 / 3))
.isAtMost(0.1);
assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 3))
.isAtMost(0.1);
}
@Test
public void unknownWeightIsAvgWeight() {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
@ -584,7 +641,6 @@ public class WeightedRoundRobinLoadBalancerTest {
0.1, 0.1, 1, new HashMap<>(), new HashMap<>()));
weightedSubchannel2.onLoadReport(InternalCallMetricRecorder.createMetricReport(
0.2, 0.1, 1, new HashMap<>(), new HashMap<>()));
assertThat(weightedPicker.toString()).contains("rrMode=true");
CyclicBarrier barrier = new CyclicBarrier(2);
Map<Subchannel, AtomicInteger> pickCount = new ConcurrentHashMap<>();
pickCount.put(weightedSubchannel1, new AtomicInteger(0));