xds: WRRPicker must not access unsynchronized data in ChildLbState

There was no point to using subchannels as keys to
subchannelToReportListenerMap, as the listener is per-child. That meant
the keys would be guaranteed to be known ahead-of-time and the
unsynchronized getOrCreateOrcaListener() during picking was unnecessary.

The picker still stores ChildLbStates to make sure that updating weights
uses the correct children, but the picker itself no longer references
ChildLbStates except in the constructor. That means weight calculation
is moved into the LB policy, as child.getWeight() is unsynchronized, and
the picker no longer needs a reference to helper.
This commit is contained in:
Eric Anderson 2024-08-12 11:23:37 -07:00 committed by GitHub
parent 0d2ad89016
commit 0d47f5bd1b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 67 additions and 67 deletions

View File

@ -44,11 +44,10 @@ import io.grpc.xds.orca.OrcaOobUtil;
import io.grpc.xds.orca.OrcaOobUtil.OrcaOobReportListener;
import io.grpc.xds.orca.OrcaPerRequestUtil;
import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ScheduledExecutorService;
@ -233,9 +232,44 @@ final class WeightedRoundRobinLoadBalancer extends MultiChildLoadBalancer {
}
private SubchannelPicker createReadyPicker(Collection<ChildLbState> activeList) {
return new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList),
config.enableOobLoadReport, config.errorUtilizationPenalty, sequence, getHelper(),
locality);
WeightedRoundRobinPicker picker = new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList),
config.enableOobLoadReport, config.errorUtilizationPenalty, sequence);
updateWeight(picker);
return picker;
}
private void updateWeight(WeightedRoundRobinPicker picker) {
Helper helper = getHelper();
float[] newWeights = new float[picker.children.size()];
AtomicInteger staleEndpoints = new AtomicInteger();
AtomicInteger notYetUsableEndpoints = new AtomicInteger();
for (int i = 0; i < picker.children.size(); i++) {
double newWeight = ((WeightedChildLbState) picker.children.get(i)).getWeight(staleEndpoints,
notYetUsableEndpoints);
helper.getMetricRecorder()
.recordDoubleHistogram(ENDPOINT_WEIGHTS_HISTOGRAM, newWeight,
ImmutableList.of(helper.getChannelTarget()),
ImmutableList.of(locality));
newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f;
}
if (staleEndpoints.get() > 0) {
helper.getMetricRecorder()
.addLongCounter(ENDPOINT_WEIGHT_STALE_COUNTER, staleEndpoints.get(),
ImmutableList.of(helper.getChannelTarget()),
ImmutableList.of(locality));
}
if (notYetUsableEndpoints.get() > 0) {
helper.getMetricRecorder()
.addLongCounter(ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER, notYetUsableEndpoints.get(),
ImmutableList.of(helper.getChannelTarget()), ImmutableList.of(locality));
}
boolean weightsEffective = picker.updateWeight(newWeights);
if (!weightsEffective) {
helper.getMetricRecorder()
.addLongCounter(RR_FALLBACK_COUNTER, 1, ImmutableList.of(helper.getChannelTarget()),
ImmutableList.of(locality));
}
}
private void updateBalancingState(ConnectivityState state, SubchannelPicker picker) {
@ -345,7 +379,7 @@ final class WeightedRoundRobinLoadBalancer extends MultiChildLoadBalancer {
@Override
public void run() {
if (currentPicker != null && currentPicker instanceof WeightedRoundRobinPicker) {
((WeightedRoundRobinPicker) currentPicker).updateWeight();
updateWeight((WeightedRoundRobinPicker) currentPicker);
}
weightUpdateTimer = syncContext.schedule(this, config.weightUpdatePeriodNanos,
TimeUnit.NANOSECONDS, timeService);
@ -415,53 +449,50 @@ final class WeightedRoundRobinLoadBalancer extends MultiChildLoadBalancer {
@VisibleForTesting
static final class WeightedRoundRobinPicker extends SubchannelPicker {
private final List<ChildLbState> children;
private final Map<Subchannel, OrcaPerRequestReportListener> subchannelToReportListenerMap =
new HashMap<>();
// Parallel lists (column-based storage instead of normal row-based storage of List<Struct>).
// The ith element of children corresponds to the ith element of pickers, listeners, and even
// updateWeight(float[]).
private final List<ChildLbState> children; // May only be accessed from sync context
private final List<SubchannelPicker> pickers;
private final List<OrcaPerRequestReportListener> reportListeners;
private final boolean enableOobLoadReport;
private final float errorUtilizationPenalty;
private final AtomicInteger sequence;
private final int hashCode;
private final LoadBalancer.Helper helper;
private final String locality;
private volatile StaticStrideScheduler scheduler;
WeightedRoundRobinPicker(List<ChildLbState> children, boolean enableOobLoadReport,
float errorUtilizationPenalty, AtomicInteger sequence, LoadBalancer.Helper helper,
String locality) {
float errorUtilizationPenalty, AtomicInteger sequence) {
checkNotNull(children, "children");
Preconditions.checkArgument(!children.isEmpty(), "empty child list");
this.children = children;
List<SubchannelPicker> pickers = new ArrayList<>(children.size());
List<OrcaPerRequestReportListener> reportListeners = new ArrayList<>(children.size());
for (ChildLbState child : children) {
WeightedChildLbState wChild = (WeightedChildLbState) child;
for (WrrSubchannel subchannel : wChild.subchannels) {
this.subchannelToReportListenerMap
.put(subchannel, wChild.getOrCreateOrcaListener(errorUtilizationPenalty));
}
pickers.add(wChild.getCurrentPicker());
reportListeners.add(wChild.getOrCreateOrcaListener(errorUtilizationPenalty));
}
this.pickers = pickers;
this.reportListeners = reportListeners;
this.enableOobLoadReport = enableOobLoadReport;
this.errorUtilizationPenalty = errorUtilizationPenalty;
this.sequence = checkNotNull(sequence, "sequence");
this.helper = helper;
this.locality = checkNotNull(locality, "locality");
// For equality we treat children as a set; use hash code as defined by Set
// For equality we treat pickers as a set; use hash code as defined by Set
int sum = 0;
for (ChildLbState child : children) {
sum += child.hashCode();
for (SubchannelPicker picker : pickers) {
sum += picker.hashCode();
}
this.hashCode = sum
^ Boolean.hashCode(enableOobLoadReport)
^ Float.hashCode(errorUtilizationPenalty);
updateWeight();
}
@Override
public PickResult pickSubchannel(PickSubchannelArgs args) {
ChildLbState childLbState = children.get(scheduler.pick());
WeightedChildLbState wChild = (WeightedChildLbState) childLbState;
PickResult pickResult = childLbState.getCurrentPicker().pickSubchannel(args);
int pick = scheduler.pick();
PickResult pickResult = pickers.get(pick).pickSubchannel(args);
Subchannel subchannel = pickResult.getSubchannel();
if (subchannel == null) {
return pickResult;
@ -469,48 +500,16 @@ final class WeightedRoundRobinLoadBalancer extends MultiChildLoadBalancer {
if (!enableOobLoadReport) {
return PickResult.withSubchannel(subchannel,
OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(
subchannelToReportListenerMap.getOrDefault(subchannel,
wChild.getOrCreateOrcaListener(errorUtilizationPenalty))));
reportListeners.get(pick)));
} else {
return PickResult.withSubchannel(subchannel);
}
}
private void updateWeight() {
float[] newWeights = new float[children.size()];
AtomicInteger staleEndpoints = new AtomicInteger();
AtomicInteger notYetUsableEndpoints = new AtomicInteger();
for (int i = 0; i < children.size(); i++) {
double newWeight = ((WeightedChildLbState) children.get(i)).getWeight(staleEndpoints,
notYetUsableEndpoints);
// TODO: add locality label once available
helper.getMetricRecorder()
.recordDoubleHistogram(ENDPOINT_WEIGHTS_HISTOGRAM, newWeight,
ImmutableList.of(helper.getChannelTarget()),
ImmutableList.of(locality));
newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f;
}
if (staleEndpoints.get() > 0) {
// TODO: add locality label once available
helper.getMetricRecorder()
.addLongCounter(ENDPOINT_WEIGHT_STALE_COUNTER, staleEndpoints.get(),
ImmutableList.of(helper.getChannelTarget()),
ImmutableList.of(locality));
}
if (notYetUsableEndpoints.get() > 0) {
// TODO: add locality label once available
helper.getMetricRecorder()
.addLongCounter(ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER, notYetUsableEndpoints.get(),
ImmutableList.of(helper.getChannelTarget()), ImmutableList.of(locality));
}
/** Returns {@code true} if weights are different than round_robin. */
private boolean updateWeight(float[] newWeights) {
this.scheduler = new StaticStrideScheduler(newWeights, sequence);
if (this.scheduler.usesRoundRobin()) {
// TODO: locality label once available
helper.getMetricRecorder()
.addLongCounter(RR_FALLBACK_COUNTER, 1, ImmutableList.of(helper.getChannelTarget()),
ImmutableList.of(locality));
}
return !this.scheduler.usesRoundRobin();
}
@Override
@ -518,7 +517,8 @@ final class WeightedRoundRobinLoadBalancer extends MultiChildLoadBalancer {
return MoreObjects.toStringHelper(WeightedRoundRobinPicker.class)
.add("enableOobLoadReport", enableOobLoadReport)
.add("errorUtilizationPenalty", errorUtilizationPenalty)
.add("list", children).toString();
.add("pickers", pickers)
.toString();
}
@VisibleForTesting
@ -545,8 +545,8 @@ final class WeightedRoundRobinLoadBalancer extends MultiChildLoadBalancer {
&& sequence == other.sequence
&& enableOobLoadReport == other.enableOobLoadReport
&& Float.compare(errorUtilizationPenalty, other.errorUtilizationPenalty) == 0
&& children.size() == other.children.size()
&& new HashSet<>(children).containsAll(other.children);
&& pickers.size() == other.pickers.size()
&& new HashSet<>(pickers).containsAll(other.pickers);
}
}

View File

@ -244,7 +244,7 @@ public class WeightedRoundRobinLoadBalancerTest {
String weightedPickerStr = weightedPicker.toString();
assertThat(weightedPickerStr).contains("enableOobLoadReport=false");
assertThat(weightedPickerStr).contains("errorUtilizationPenalty=1.0");
assertThat(weightedPickerStr).contains("list=");
assertThat(weightedPickerStr).contains("pickers=");
WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0);
WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1);