xds:Allow big cluster total weight (#9864)

* xds:  allow sum of cluster weights above MAX_INT up to max of unsigned int.

* Define nextLong(long bound) method in FakeRandom for WeightedRandomPickerTest.
This commit is contained in:
Larry Safran 2023-02-03 18:53:50 +00:00 committed by GitHub
parent 04afea0fbd
commit 54c1f37093
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 105 additions and 17 deletions

View File

@ -25,6 +25,8 @@ interface ThreadSafeRandom {
long nextLong(); long nextLong();
long nextLong(long bound);
final class ThreadSafeRandomImpl implements ThreadSafeRandom { final class ThreadSafeRandomImpl implements ThreadSafeRandom {
static final ThreadSafeRandom instance = new ThreadSafeRandomImpl(); static final ThreadSafeRandom instance = new ThreadSafeRandomImpl();
@ -40,5 +42,10 @@ interface ThreadSafeRandom {
public long nextLong() { public long nextLong() {
return ThreadLocalRandom.current().nextLong(); return ThreadLocalRandom.current().nextLong();
} }
@Override
public long nextLong(long bound) {
return ThreadLocalRandom.current().nextLong(bound);
}
} }
} }

View File

@ -21,6 +21,7 @@ import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects; import com.google.common.base.MoreObjects;
import com.google.common.primitives.UnsignedInteger;
import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickResult;
import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.PickSubchannelArgs;
import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancer.SubchannelPicker;
@ -34,21 +35,22 @@ final class WeightedRandomPicker extends SubchannelPicker {
final List<WeightedChildPicker> weightedChildPickers; final List<WeightedChildPicker> weightedChildPickers;
private final ThreadSafeRandom random; private final ThreadSafeRandom random;
private final int totalWeight; private final long totalWeight;
static final class WeightedChildPicker { static final class WeightedChildPicker {
private final int weight; private final long weight;
private final SubchannelPicker childPicker; private final SubchannelPicker childPicker;
WeightedChildPicker(int weight, SubchannelPicker childPicker) { WeightedChildPicker(long weight, SubchannelPicker childPicker) {
checkArgument(weight >= 0, "weight is negative"); checkArgument(weight >= 0, "weight is negative");
checkArgument(weight <= UnsignedInteger.MAX_VALUE.longValue(), "weight is too large");
checkNotNull(childPicker, "childPicker is null"); checkNotNull(childPicker, "childPicker is null");
this.weight = weight; this.weight = weight;
this.childPicker = childPicker; this.childPicker = childPicker;
} }
int getWeight() { long getWeight() {
return weight; return weight;
} }
@ -93,12 +95,16 @@ final class WeightedRandomPicker extends SubchannelPicker {
this.weightedChildPickers = Collections.unmodifiableList(weightedChildPickers); this.weightedChildPickers = Collections.unmodifiableList(weightedChildPickers);
int totalWeight = 0; long totalWeight = 0;
for (WeightedChildPicker weightedChildPicker : weightedChildPickers) { for (WeightedChildPicker weightedChildPicker : weightedChildPickers) {
int weight = weightedChildPicker.getWeight(); long weight = weightedChildPicker.getWeight();
checkArgument(weight >= 0, "weight is negative");
checkNotNull(weightedChildPicker.getPicker(), "childPicker is null");
totalWeight += weight; totalWeight += weight;
} }
this.totalWeight = totalWeight; this.totalWeight = totalWeight;
checkArgument(totalWeight <= UnsignedInteger.MAX_VALUE.longValue(),
"total weight greater than unsigned int can hold");
this.random = random; this.random = random;
} }
@ -111,15 +117,15 @@ final class WeightedRandomPicker extends SubchannelPicker {
childPicker = childPicker =
weightedChildPickers.get(random.nextInt(weightedChildPickers.size())).getPicker(); weightedChildPickers.get(random.nextInt(weightedChildPickers.size())).getPicker();
} else { } else {
int rand = random.nextInt(totalWeight); long rand = random.nextLong(totalWeight);
// Find the first idx such that rand < accumulatedWeights[idx] // Find the first idx such that rand < accumulatedWeights[idx]
// Not using Arrays.binarySearch for better readability. // Not using Arrays.binarySearch for better readability.
int accumulatedWeight = 0; long accumulatedWeight = 0;
for (int idx = 0; idx < weightedChildPickers.size(); idx++) { for (WeightedChildPicker weightedChildPicker : weightedChildPickers) {
accumulatedWeight += weightedChildPickers.get(idx).getWeight(); accumulatedWeight += weightedChildPicker.getWeight();
if (rand < accumulatedWeight) { if (rand < accumulatedWeight) {
childPicker = weightedChildPickers.get(idx).getPicker(); childPicker = weightedChildPicker.getPicker();
break; break;
} }
} }

View File

@ -437,12 +437,12 @@ final class XdsNameResolver extends NameResolver {
if (action.cluster() != null) { if (action.cluster() != null) {
cluster = prefixedClusterName(action.cluster()); cluster = prefixedClusterName(action.cluster());
} else if (action.weightedClusters() != null) { } else if (action.weightedClusters() != null) {
int totalWeight = 0; long totalWeight = 0;
for (ClusterWeight weightedCluster : action.weightedClusters()) { for (ClusterWeight weightedCluster : action.weightedClusters()) {
totalWeight += weightedCluster.weight(); totalWeight += weightedCluster.weight();
} }
int select = random.nextInt(totalWeight); long select = random.nextLong(totalWeight);
int accumulator = 0; long accumulator = 0;
for (ClusterWeight weightedCluster : action.weightedClusters()) { for (ClusterWeight weightedCluster : action.weightedClusters()) {
accumulator += weightedCluster.weight(); accumulator += weightedCluster.weight();
if (select < accumulator) { if (select < accumulator) {

View File

@ -24,6 +24,7 @@ import com.google.common.base.MoreObjects;
import com.google.common.base.Splitter; import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
import com.google.common.primitives.UnsignedInteger;
import com.google.protobuf.Any; import com.google.protobuf.Any;
import com.google.protobuf.Duration; import com.google.protobuf.Duration;
import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.InvalidProtocolBufferException;
@ -477,7 +478,7 @@ class XdsRouteConfigureResource extends XdsResourceType<RdsUpdate> {
return StructOrError.fromError("No cluster found in weighted cluster list"); return StructOrError.fromError("No cluster found in weighted cluster list");
} }
List<ClusterWeight> weightedClusters = new ArrayList<>(); List<ClusterWeight> weightedClusters = new ArrayList<>();
int clusterWeightSum = 0; long clusterWeightSum = 0;
for (io.envoyproxy.envoy.config.route.v3.WeightedCluster.ClusterWeight clusterWeight for (io.envoyproxy.envoy.config.route.v3.WeightedCluster.ClusterWeight clusterWeight
: clusterWeights) { : clusterWeights) {
StructOrError<ClusterWeight> clusterWeightOrError = StructOrError<ClusterWeight> clusterWeightOrError =
@ -492,6 +493,12 @@ class XdsRouteConfigureResource extends XdsResourceType<RdsUpdate> {
if (clusterWeightSum <= 0) { if (clusterWeightSum <= 0) {
return StructOrError.fromError("Sum of cluster weights should be above 0."); return StructOrError.fromError("Sum of cluster weights should be above 0.");
} }
if (clusterWeightSum > UnsignedInteger.MAX_VALUE.longValue()) {
return StructOrError.fromError(String.format(
"Sum of cluster weights should be less than the maximum unsigned integer (%d), but"
+ " was %d. ",
UnsignedInteger.MAX_VALUE.longValue(), clusterWeightSum));
}
return StructOrError.fromStruct(VirtualHost.Route.RouteAction.forWeightedClusters( return StructOrError.fromStruct(VirtualHost.Route.RouteAction.forWeightedClusters(
weightedClusters, hashPolicies, timeoutNano, retryPolicy)); weightedClusters, hashPolicies, timeoutNano, retryPolicy));
case CLUSTER_SPECIFIER_PLUGIN: case CLUSTER_SPECIFIER_PLUGIN:
@ -499,7 +506,7 @@ class XdsRouteConfigureResource extends XdsResourceType<RdsUpdate> {
String pluginName = proto.getClusterSpecifierPlugin(); String pluginName = proto.getClusterSpecifierPlugin();
PluginConfig pluginConfig = pluginConfigMap.get(pluginName); PluginConfig pluginConfig = pluginConfigMap.get(pluginName);
if (pluginConfig == null) { if (pluginConfig == null) {
// Skip route if the plugin is not registered, but it's optional. // Skip route if the plugin is not registered, but it is optional.
if (optionalPlugins.contains(pluginName)) { if (optionalPlugins.contains(pluginName)) {
return null; return null;
} }

View File

@ -87,7 +87,8 @@ public class WeightedRandomPickerTest {
private static final class FakeRandom implements ThreadSafeRandom { private static final class FakeRandom implements ThreadSafeRandom {
int nextInt; int nextInt;
int bound; long bound;
Long nextLong;
@Override @Override
public int nextInt(int bound) { public int nextInt(int bound) {
@ -102,6 +103,23 @@ public class WeightedRandomPickerTest {
public long nextLong() { public long nextLong() {
throw new UnsupportedOperationException("Should not be called"); throw new UnsupportedOperationException("Should not be called");
} }
@Override
public long nextLong(long bound) {
this.bound = bound;
if (nextLong == null) {
assertThat(nextInt).isAtLeast(0);
if (bound <= Integer.MAX_VALUE) {
assertThat(nextInt).isLessThan((int)bound);
}
return nextInt;
}
assertThat(nextLong).isAtLeast(0);
assertThat(nextLong).isLessThan(bound);
return nextLong;
}
} }
private final FakeRandom fakeRandom = new FakeRandom(); private final FakeRandom fakeRandom = new FakeRandom();
@ -120,6 +138,24 @@ public class WeightedRandomPickerTest {
new WeightedChildPicker(-1, childPicker0); new WeightedChildPicker(-1, childPicker0);
} }
@Test
public void overWeightSingle() {
thrown.expect(IllegalArgumentException.class);
new WeightedChildPicker(Integer.MAX_VALUE * 3L, childPicker0);
}
@Test
public void overWeightAggregate() {
List<WeightedChildPicker> weightedChildPickers = Arrays.asList(
new WeightedChildPicker(Integer.MAX_VALUE, childPicker0),
new WeightedChildPicker(Integer.MAX_VALUE, childPicker1),
new WeightedChildPicker(10, childPicker2));
thrown.expect(IllegalArgumentException.class);
new WeightedRandomPicker(weightedChildPickers, fakeRandom);
}
@Test @Test
public void pickWithFakeRandom() { public void pickWithFakeRandom() {
WeightedChildPicker weightedChildPicker0 = new WeightedChildPicker(0, childPicker0); WeightedChildPicker weightedChildPicker0 = new WeightedChildPicker(0, childPicker0);
@ -156,6 +192,36 @@ public class WeightedRandomPickerTest {
assertThat(fakeRandom.bound).isEqualTo(25); assertThat(fakeRandom.bound).isEqualTo(25);
} }
@Test
public void pickFromLargeTotal() {
List<WeightedChildPicker> weightedChildPickers = Arrays.asList(
new WeightedChildPicker(10, childPicker0),
new WeightedChildPicker(Integer.MAX_VALUE, childPicker1),
new WeightedChildPicker(10, childPicker2));
WeightedRandomPicker xdsPicker = new WeightedRandomPicker(weightedChildPickers,fakeRandom);
long totalWeight = weightedChildPickers.stream()
.mapToLong(WeightedChildPicker::getWeight)
.reduce(0, Long::sum);
fakeRandom.nextLong = 5L;
assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult0);
assertThat(fakeRandom.bound).isEqualTo(totalWeight);
fakeRandom.nextLong = 16L;
assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult1);
assertThat(fakeRandom.bound).isEqualTo(totalWeight);
fakeRandom.nextLong = Integer.MAX_VALUE + 10L;
assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult2);
assertThat(fakeRandom.bound).isEqualTo(totalWeight);
fakeRandom.nextLong = Integer.MAX_VALUE + 15L;
assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult2);
assertThat(fakeRandom.bound).isEqualTo(totalWeight);
}
@Test @Test
public void allZeroWeights() { public void allZeroWeights() {
WeightedChildPicker weightedChildPicker0 = new WeightedChildPicker(0, childPicker0); WeightedChildPicker weightedChildPicker0 = new WeightedChildPicker(0, childPicker0);

View File

@ -24,6 +24,7 @@ import static io.grpc.xds.FaultFilter.HEADER_DELAY_KEY;
import static io.grpc.xds.FaultFilter.HEADER_DELAY_PERCENTAGE_KEY; import static io.grpc.xds.FaultFilter.HEADER_DELAY_PERCENTAGE_KEY;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
@ -994,6 +995,7 @@ public class XdsNameResolverTest {
@Test @Test
public void resolved_simpleCallSucceeds_routeToWeightedCluster() { public void resolved_simpleCallSucceeds_routeToWeightedCluster() {
when(mockRandom.nextInt(anyInt())).thenReturn(90, 10); when(mockRandom.nextInt(anyInt())).thenReturn(90, 10);
when(mockRandom.nextLong(anyLong())).thenReturn(90L, 10L);
resolver.start(mockListener); resolver.start(mockListener);
FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient();
xdsClient.deliverLdsUpdate( xdsClient.deliverLdsUpdate(