xds: WRRPicker must not access unsynchronized data in ChildLbState

There was no point to using subchannels as keys to
subchannelToReportListenerMap, as the listener is per-child. That meant
the keys would be guaranteed to be known ahead-of-time and the
unsynchronized getOrCreateOrcaListener() during picking was unnecessary.

The picker still stores ChildLbStates to make sure that updating weights
uses the correct children, but the picker itself no longer references
ChildLbStates except in the constructor. That means weight calculation
is moved into the LB policy, as child.getWeight() is unsynchronized, and
the picker no longer needs a reference to helper.
This commit is contained in:
Eric Anderson 2024-08-12 11:23:37 -07:00 committed by GitHub
parent 0d2ad89016
commit 0d47f5bd1b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 67 additions and 67 deletions

View File

@ -44,11 +44,10 @@ import io.grpc.xds.orca.OrcaOobUtil;
import io.grpc.xds.orca.OrcaOobUtil.OrcaOobReportListener; import io.grpc.xds.orca.OrcaOobUtil.OrcaOobReportListener;
import io.grpc.xds.orca.OrcaPerRequestUtil; import io.grpc.xds.orca.OrcaPerRequestUtil;
import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener; import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener;
import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Random; import java.util.Random;
import java.util.Set; import java.util.Set;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
@ -233,9 +232,44 @@ final class WeightedRoundRobinLoadBalancer extends MultiChildLoadBalancer {
} }
private SubchannelPicker createReadyPicker(Collection<ChildLbState> activeList) { private SubchannelPicker createReadyPicker(Collection<ChildLbState> activeList) {
return new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList), WeightedRoundRobinPicker picker = new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList),
config.enableOobLoadReport, config.errorUtilizationPenalty, sequence, getHelper(), config.enableOobLoadReport, config.errorUtilizationPenalty, sequence);
locality); updateWeight(picker);
return picker;
}
private void updateWeight(WeightedRoundRobinPicker picker) {
Helper helper = getHelper();
float[] newWeights = new float[picker.children.size()];
AtomicInteger staleEndpoints = new AtomicInteger();
AtomicInteger notYetUsableEndpoints = new AtomicInteger();
for (int i = 0; i < picker.children.size(); i++) {
double newWeight = ((WeightedChildLbState) picker.children.get(i)).getWeight(staleEndpoints,
notYetUsableEndpoints);
helper.getMetricRecorder()
.recordDoubleHistogram(ENDPOINT_WEIGHTS_HISTOGRAM, newWeight,
ImmutableList.of(helper.getChannelTarget()),
ImmutableList.of(locality));
newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f;
}
if (staleEndpoints.get() > 0) {
helper.getMetricRecorder()
.addLongCounter(ENDPOINT_WEIGHT_STALE_COUNTER, staleEndpoints.get(),
ImmutableList.of(helper.getChannelTarget()),
ImmutableList.of(locality));
}
if (notYetUsableEndpoints.get() > 0) {
helper.getMetricRecorder()
.addLongCounter(ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER, notYetUsableEndpoints.get(),
ImmutableList.of(helper.getChannelTarget()), ImmutableList.of(locality));
}
boolean weightsEffective = picker.updateWeight(newWeights);
if (!weightsEffective) {
helper.getMetricRecorder()
.addLongCounter(RR_FALLBACK_COUNTER, 1, ImmutableList.of(helper.getChannelTarget()),
ImmutableList.of(locality));
}
} }
private void updateBalancingState(ConnectivityState state, SubchannelPicker picker) { private void updateBalancingState(ConnectivityState state, SubchannelPicker picker) {
@ -345,7 +379,7 @@ final class WeightedRoundRobinLoadBalancer extends MultiChildLoadBalancer {
@Override @Override
public void run() { public void run() {
if (currentPicker != null && currentPicker instanceof WeightedRoundRobinPicker) { if (currentPicker != null && currentPicker instanceof WeightedRoundRobinPicker) {
((WeightedRoundRobinPicker) currentPicker).updateWeight(); updateWeight((WeightedRoundRobinPicker) currentPicker);
} }
weightUpdateTimer = syncContext.schedule(this, config.weightUpdatePeriodNanos, weightUpdateTimer = syncContext.schedule(this, config.weightUpdatePeriodNanos,
TimeUnit.NANOSECONDS, timeService); TimeUnit.NANOSECONDS, timeService);
@ -415,53 +449,50 @@ final class WeightedRoundRobinLoadBalancer extends MultiChildLoadBalancer {
@VisibleForTesting @VisibleForTesting
static final class WeightedRoundRobinPicker extends SubchannelPicker { static final class WeightedRoundRobinPicker extends SubchannelPicker {
private final List<ChildLbState> children; // Parallel lists (column-based storage instead of normal row-based storage of List<Struct>).
private final Map<Subchannel, OrcaPerRequestReportListener> subchannelToReportListenerMap = // The ith element of children corresponds to the ith element of pickers, listeners, and even
new HashMap<>(); // updateWeight(float[]).
private final List<ChildLbState> children; // May only be accessed from sync context
private final List<SubchannelPicker> pickers;
private final List<OrcaPerRequestReportListener> reportListeners;
private final boolean enableOobLoadReport; private final boolean enableOobLoadReport;
private final float errorUtilizationPenalty; private final float errorUtilizationPenalty;
private final AtomicInteger sequence; private final AtomicInteger sequence;
private final int hashCode; private final int hashCode;
private final LoadBalancer.Helper helper;
private final String locality;
private volatile StaticStrideScheduler scheduler; private volatile StaticStrideScheduler scheduler;
WeightedRoundRobinPicker(List<ChildLbState> children, boolean enableOobLoadReport, WeightedRoundRobinPicker(List<ChildLbState> children, boolean enableOobLoadReport,
float errorUtilizationPenalty, AtomicInteger sequence, LoadBalancer.Helper helper, float errorUtilizationPenalty, AtomicInteger sequence) {
String locality) {
checkNotNull(children, "children"); checkNotNull(children, "children");
Preconditions.checkArgument(!children.isEmpty(), "empty child list"); Preconditions.checkArgument(!children.isEmpty(), "empty child list");
this.children = children; this.children = children;
List<SubchannelPicker> pickers = new ArrayList<>(children.size());
List<OrcaPerRequestReportListener> reportListeners = new ArrayList<>(children.size());
for (ChildLbState child : children) { for (ChildLbState child : children) {
WeightedChildLbState wChild = (WeightedChildLbState) child; WeightedChildLbState wChild = (WeightedChildLbState) child;
for (WrrSubchannel subchannel : wChild.subchannels) { pickers.add(wChild.getCurrentPicker());
this.subchannelToReportListenerMap reportListeners.add(wChild.getOrCreateOrcaListener(errorUtilizationPenalty));
.put(subchannel, wChild.getOrCreateOrcaListener(errorUtilizationPenalty));
}
} }
this.pickers = pickers;
this.reportListeners = reportListeners;
this.enableOobLoadReport = enableOobLoadReport; this.enableOobLoadReport = enableOobLoadReport;
this.errorUtilizationPenalty = errorUtilizationPenalty; this.errorUtilizationPenalty = errorUtilizationPenalty;
this.sequence = checkNotNull(sequence, "sequence"); this.sequence = checkNotNull(sequence, "sequence");
this.helper = helper;
this.locality = checkNotNull(locality, "locality");
// For equality we treat children as a set; use hash code as defined by Set // For equality we treat pickers as a set; use hash code as defined by Set
int sum = 0; int sum = 0;
for (ChildLbState child : children) { for (SubchannelPicker picker : pickers) {
sum += child.hashCode(); sum += picker.hashCode();
} }
this.hashCode = sum this.hashCode = sum
^ Boolean.hashCode(enableOobLoadReport) ^ Boolean.hashCode(enableOobLoadReport)
^ Float.hashCode(errorUtilizationPenalty); ^ Float.hashCode(errorUtilizationPenalty);
updateWeight();
} }
@Override @Override
public PickResult pickSubchannel(PickSubchannelArgs args) { public PickResult pickSubchannel(PickSubchannelArgs args) {
ChildLbState childLbState = children.get(scheduler.pick()); int pick = scheduler.pick();
WeightedChildLbState wChild = (WeightedChildLbState) childLbState; PickResult pickResult = pickers.get(pick).pickSubchannel(args);
PickResult pickResult = childLbState.getCurrentPicker().pickSubchannel(args);
Subchannel subchannel = pickResult.getSubchannel(); Subchannel subchannel = pickResult.getSubchannel();
if (subchannel == null) { if (subchannel == null) {
return pickResult; return pickResult;
@ -469,48 +500,16 @@ final class WeightedRoundRobinLoadBalancer extends MultiChildLoadBalancer {
if (!enableOobLoadReport) { if (!enableOobLoadReport) {
return PickResult.withSubchannel(subchannel, return PickResult.withSubchannel(subchannel,
OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(
subchannelToReportListenerMap.getOrDefault(subchannel, reportListeners.get(pick)));
wChild.getOrCreateOrcaListener(errorUtilizationPenalty))));
} else { } else {
return PickResult.withSubchannel(subchannel); return PickResult.withSubchannel(subchannel);
} }
} }
private void updateWeight() { /** Returns {@code true} if weights are different than round_robin. */
float[] newWeights = new float[children.size()]; private boolean updateWeight(float[] newWeights) {
AtomicInteger staleEndpoints = new AtomicInteger();
AtomicInteger notYetUsableEndpoints = new AtomicInteger();
for (int i = 0; i < children.size(); i++) {
double newWeight = ((WeightedChildLbState) children.get(i)).getWeight(staleEndpoints,
notYetUsableEndpoints);
// TODO: add locality label once available
helper.getMetricRecorder()
.recordDoubleHistogram(ENDPOINT_WEIGHTS_HISTOGRAM, newWeight,
ImmutableList.of(helper.getChannelTarget()),
ImmutableList.of(locality));
newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f;
}
if (staleEndpoints.get() > 0) {
// TODO: add locality label once available
helper.getMetricRecorder()
.addLongCounter(ENDPOINT_WEIGHT_STALE_COUNTER, staleEndpoints.get(),
ImmutableList.of(helper.getChannelTarget()),
ImmutableList.of(locality));
}
if (notYetUsableEndpoints.get() > 0) {
// TODO: add locality label once available
helper.getMetricRecorder()
.addLongCounter(ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER, notYetUsableEndpoints.get(),
ImmutableList.of(helper.getChannelTarget()), ImmutableList.of(locality));
}
this.scheduler = new StaticStrideScheduler(newWeights, sequence); this.scheduler = new StaticStrideScheduler(newWeights, sequence);
if (this.scheduler.usesRoundRobin()) { return !this.scheduler.usesRoundRobin();
// TODO: locality label once available
helper.getMetricRecorder()
.addLongCounter(RR_FALLBACK_COUNTER, 1, ImmutableList.of(helper.getChannelTarget()),
ImmutableList.of(locality));
}
} }
@Override @Override
@ -518,7 +517,8 @@ final class WeightedRoundRobinLoadBalancer extends MultiChildLoadBalancer {
return MoreObjects.toStringHelper(WeightedRoundRobinPicker.class) return MoreObjects.toStringHelper(WeightedRoundRobinPicker.class)
.add("enableOobLoadReport", enableOobLoadReport) .add("enableOobLoadReport", enableOobLoadReport)
.add("errorUtilizationPenalty", errorUtilizationPenalty) .add("errorUtilizationPenalty", errorUtilizationPenalty)
.add("list", children).toString(); .add("pickers", pickers)
.toString();
} }
@VisibleForTesting @VisibleForTesting
@ -545,8 +545,8 @@ final class WeightedRoundRobinLoadBalancer extends MultiChildLoadBalancer {
&& sequence == other.sequence && sequence == other.sequence
&& enableOobLoadReport == other.enableOobLoadReport && enableOobLoadReport == other.enableOobLoadReport
&& Float.compare(errorUtilizationPenalty, other.errorUtilizationPenalty) == 0 && Float.compare(errorUtilizationPenalty, other.errorUtilizationPenalty) == 0
&& children.size() == other.children.size() && pickers.size() == other.pickers.size()
&& new HashSet<>(children).containsAll(other.children); && new HashSet<>(pickers).containsAll(other.pickers);
} }
} }

View File

@ -244,7 +244,7 @@ public class WeightedRoundRobinLoadBalancerTest {
String weightedPickerStr = weightedPicker.toString(); String weightedPickerStr = weightedPicker.toString();
assertThat(weightedPickerStr).contains("enableOobLoadReport=false"); assertThat(weightedPickerStr).contains("enableOobLoadReport=false");
assertThat(weightedPickerStr).contains("errorUtilizationPenalty=1.0"); assertThat(weightedPickerStr).contains("errorUtilizationPenalty=1.0");
assertThat(weightedPickerStr).contains("list="); assertThat(weightedPickerStr).contains("pickers=");
WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0);
WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1);