xds:Allow big cluster total weight (#9864)

* xds:  allow sum of cluster weights above MAX_INT up to max of unsigned int.

* Define nextLong(long bound) method in FakeRandom for WeightedRandomPickerTest.
This commit is contained in:
Larry Safran 2023-02-03 18:53:50 +00:00 committed by GitHub
parent 04afea0fbd
commit 54c1f37093
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 105 additions and 17 deletions

View File

@ -25,6 +25,8 @@ interface ThreadSafeRandom {
long nextLong();
long nextLong(long bound);
final class ThreadSafeRandomImpl implements ThreadSafeRandom {
static final ThreadSafeRandom instance = new ThreadSafeRandomImpl();
@ -40,5 +42,10 @@ interface ThreadSafeRandom {
public long nextLong() {
return ThreadLocalRandom.current().nextLong();
}
@Override
public long nextLong(long bound) {
return ThreadLocalRandom.current().nextLong(bound);
}
}
}

View File

@ -21,6 +21,7 @@ import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import com.google.common.primitives.UnsignedInteger;
import io.grpc.LoadBalancer.PickResult;
import io.grpc.LoadBalancer.PickSubchannelArgs;
import io.grpc.LoadBalancer.SubchannelPicker;
@ -34,21 +35,22 @@ final class WeightedRandomPicker extends SubchannelPicker {
final List<WeightedChildPicker> weightedChildPickers;
private final ThreadSafeRandom random;
private final int totalWeight;
private final long totalWeight;
static final class WeightedChildPicker {
private final int weight;
private final long weight;
private final SubchannelPicker childPicker;
WeightedChildPicker(int weight, SubchannelPicker childPicker) {
WeightedChildPicker(long weight, SubchannelPicker childPicker) {
checkArgument(weight >= 0, "weight is negative");
checkArgument(weight <= UnsignedInteger.MAX_VALUE.longValue(), "weight is too large");
checkNotNull(childPicker, "childPicker is null");
this.weight = weight;
this.childPicker = childPicker;
}
int getWeight() {
long getWeight() {
return weight;
}
@ -93,12 +95,16 @@ final class WeightedRandomPicker extends SubchannelPicker {
this.weightedChildPickers = Collections.unmodifiableList(weightedChildPickers);
int totalWeight = 0;
long totalWeight = 0;
for (WeightedChildPicker weightedChildPicker : weightedChildPickers) {
int weight = weightedChildPicker.getWeight();
long weight = weightedChildPicker.getWeight();
checkArgument(weight >= 0, "weight is negative");
checkNotNull(weightedChildPicker.getPicker(), "childPicker is null");
totalWeight += weight;
}
this.totalWeight = totalWeight;
checkArgument(totalWeight <= UnsignedInteger.MAX_VALUE.longValue(),
"total weight greater than unsigned int can hold");
this.random = random;
}
@ -111,15 +117,15 @@ final class WeightedRandomPicker extends SubchannelPicker {
childPicker =
weightedChildPickers.get(random.nextInt(weightedChildPickers.size())).getPicker();
} else {
int rand = random.nextInt(totalWeight);
long rand = random.nextLong(totalWeight);
// Find the first idx such that rand < accumulatedWeights[idx]
// Not using Arrays.binarySearch for better readability.
int accumulatedWeight = 0;
for (int idx = 0; idx < weightedChildPickers.size(); idx++) {
accumulatedWeight += weightedChildPickers.get(idx).getWeight();
long accumulatedWeight = 0;
for (WeightedChildPicker weightedChildPicker : weightedChildPickers) {
accumulatedWeight += weightedChildPicker.getWeight();
if (rand < accumulatedWeight) {
childPicker = weightedChildPickers.get(idx).getPicker();
childPicker = weightedChildPicker.getPicker();
break;
}
}

View File

@ -437,12 +437,12 @@ final class XdsNameResolver extends NameResolver {
if (action.cluster() != null) {
cluster = prefixedClusterName(action.cluster());
} else if (action.weightedClusters() != null) {
int totalWeight = 0;
long totalWeight = 0;
for (ClusterWeight weightedCluster : action.weightedClusters()) {
totalWeight += weightedCluster.weight();
}
int select = random.nextInt(totalWeight);
int accumulator = 0;
long select = random.nextLong(totalWeight);
long accumulator = 0;
for (ClusterWeight weightedCluster : action.weightedClusters()) {
accumulator += weightedCluster.weight();
if (select < accumulator) {

View File

@ -24,6 +24,7 @@ import com.google.common.base.MoreObjects;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.primitives.UnsignedInteger;
import com.google.protobuf.Any;
import com.google.protobuf.Duration;
import com.google.protobuf.InvalidProtocolBufferException;
@ -477,7 +478,7 @@ class XdsRouteConfigureResource extends XdsResourceType<RdsUpdate> {
return StructOrError.fromError("No cluster found in weighted cluster list");
}
List<ClusterWeight> weightedClusters = new ArrayList<>();
int clusterWeightSum = 0;
long clusterWeightSum = 0;
for (io.envoyproxy.envoy.config.route.v3.WeightedCluster.ClusterWeight clusterWeight
: clusterWeights) {
StructOrError<ClusterWeight> clusterWeightOrError =
@ -492,6 +493,12 @@ class XdsRouteConfigureResource extends XdsResourceType<RdsUpdate> {
if (clusterWeightSum <= 0) {
return StructOrError.fromError("Sum of cluster weights should be above 0.");
}
if (clusterWeightSum > UnsignedInteger.MAX_VALUE.longValue()) {
return StructOrError.fromError(String.format(
"Sum of cluster weights should be less than the maximum unsigned integer (%d), but"
+ " was %d. ",
UnsignedInteger.MAX_VALUE.longValue(), clusterWeightSum));
}
return StructOrError.fromStruct(VirtualHost.Route.RouteAction.forWeightedClusters(
weightedClusters, hashPolicies, timeoutNano, retryPolicy));
case CLUSTER_SPECIFIER_PLUGIN:
@ -499,7 +506,7 @@ class XdsRouteConfigureResource extends XdsResourceType<RdsUpdate> {
String pluginName = proto.getClusterSpecifierPlugin();
PluginConfig pluginConfig = pluginConfigMap.get(pluginName);
if (pluginConfig == null) {
// Skip route if the plugin is not registered, but it's optional.
// Skip route if the plugin is not registered, but it is optional.
if (optionalPlugins.contains(pluginName)) {
return null;
}

View File

@ -87,7 +87,8 @@ public class WeightedRandomPickerTest {
private static final class FakeRandom implements ThreadSafeRandom {
int nextInt;
int bound;
long bound;
Long nextLong;
@Override
public int nextInt(int bound) {
@ -102,6 +103,23 @@ public class WeightedRandomPickerTest {
public long nextLong() {
throw new UnsupportedOperationException("Should not be called");
}
@Override
public long nextLong(long bound) {
this.bound = bound;
if (nextLong == null) {
assertThat(nextInt).isAtLeast(0);
if (bound <= Integer.MAX_VALUE) {
assertThat(nextInt).isLessThan((int)bound);
}
return nextInt;
}
assertThat(nextLong).isAtLeast(0);
assertThat(nextLong).isLessThan(bound);
return nextLong;
}
}
private final FakeRandom fakeRandom = new FakeRandom();
@ -120,6 +138,24 @@ public class WeightedRandomPickerTest {
new WeightedChildPicker(-1, childPicker0);
}
@Test
public void overWeightSingle() {
thrown.expect(IllegalArgumentException.class);
new WeightedChildPicker(Integer.MAX_VALUE * 3L, childPicker0);
}
@Test
public void overWeightAggregate() {
List<WeightedChildPicker> weightedChildPickers = Arrays.asList(
new WeightedChildPicker(Integer.MAX_VALUE, childPicker0),
new WeightedChildPicker(Integer.MAX_VALUE, childPicker1),
new WeightedChildPicker(10, childPicker2));
thrown.expect(IllegalArgumentException.class);
new WeightedRandomPicker(weightedChildPickers, fakeRandom);
}
@Test
public void pickWithFakeRandom() {
WeightedChildPicker weightedChildPicker0 = new WeightedChildPicker(0, childPicker0);
@ -156,6 +192,36 @@ public class WeightedRandomPickerTest {
assertThat(fakeRandom.bound).isEqualTo(25);
}
@Test
public void pickFromLargeTotal() {
List<WeightedChildPicker> weightedChildPickers = Arrays.asList(
new WeightedChildPicker(10, childPicker0),
new WeightedChildPicker(Integer.MAX_VALUE, childPicker1),
new WeightedChildPicker(10, childPicker2));
WeightedRandomPicker xdsPicker = new WeightedRandomPicker(weightedChildPickers,fakeRandom);
long totalWeight = weightedChildPickers.stream()
.mapToLong(WeightedChildPicker::getWeight)
.reduce(0, Long::sum);
fakeRandom.nextLong = 5L;
assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult0);
assertThat(fakeRandom.bound).isEqualTo(totalWeight);
fakeRandom.nextLong = 16L;
assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult1);
assertThat(fakeRandom.bound).isEqualTo(totalWeight);
fakeRandom.nextLong = Integer.MAX_VALUE + 10L;
assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult2);
assertThat(fakeRandom.bound).isEqualTo(totalWeight);
fakeRandom.nextLong = Integer.MAX_VALUE + 15L;
assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult2);
assertThat(fakeRandom.bound).isEqualTo(totalWeight);
}
@Test
public void allZeroWeights() {
WeightedChildPicker weightedChildPicker0 = new WeightedChildPicker(0, childPicker0);

View File

@ -24,6 +24,7 @@ import static io.grpc.xds.FaultFilter.HEADER_DELAY_KEY;
import static io.grpc.xds.FaultFilter.HEADER_DELAY_PERCENTAGE_KEY;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
@ -994,6 +995,7 @@ public class XdsNameResolverTest {
@Test
public void resolved_simpleCallSucceeds_routeToWeightedCluster() {
when(mockRandom.nextInt(anyInt())).thenReturn(90, 10);
when(mockRandom.nextLong(anyLong())).thenReturn(90L, 10L);
resolver.start(mockListener);
FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient();
xdsClient.deliverLdsUpdate(