diff --git a/xds/src/main/java/io/grpc/xds/ThreadSafeRandom.java b/xds/src/main/java/io/grpc/xds/ThreadSafeRandom.java index 1e844cede3..533ccee237 100644 --- a/xds/src/main/java/io/grpc/xds/ThreadSafeRandom.java +++ b/xds/src/main/java/io/grpc/xds/ThreadSafeRandom.java @@ -25,6 +25,8 @@ interface ThreadSafeRandom { long nextLong(); + long nextLong(long bound); + final class ThreadSafeRandomImpl implements ThreadSafeRandom { static final ThreadSafeRandom instance = new ThreadSafeRandomImpl(); @@ -40,5 +42,10 @@ interface ThreadSafeRandom { public long nextLong() { return ThreadLocalRandom.current().nextLong(); } + + @Override + public long nextLong(long bound) { + return ThreadLocalRandom.current().nextLong(bound); + } } } diff --git a/xds/src/main/java/io/grpc/xds/WeightedRandomPicker.java b/xds/src/main/java/io/grpc/xds/WeightedRandomPicker.java index 1f5fc6d01d..904f3872b6 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedRandomPicker.java +++ b/xds/src/main/java/io/grpc/xds/WeightedRandomPicker.java @@ -21,6 +21,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; +import com.google.common.primitives.UnsignedInteger; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.SubchannelPicker; @@ -34,21 +35,22 @@ final class WeightedRandomPicker extends SubchannelPicker { final List weightedChildPickers; private final ThreadSafeRandom random; - private final int totalWeight; + private final long totalWeight; static final class WeightedChildPicker { - private final int weight; + private final long weight; private final SubchannelPicker childPicker; - WeightedChildPicker(int weight, SubchannelPicker childPicker) { + WeightedChildPicker(long weight, SubchannelPicker childPicker) { checkArgument(weight >= 0, "weight is negative"); + checkArgument(weight <= UnsignedInteger.MAX_VALUE.longValue(), "weight is too large"); checkNotNull(childPicker, "childPicker is null"); this.weight = weight; this.childPicker = childPicker; } - int getWeight() { + long getWeight() { return weight; } @@ -93,12 +95,16 @@ final class WeightedRandomPicker extends SubchannelPicker { this.weightedChildPickers = Collections.unmodifiableList(weightedChildPickers); - int totalWeight = 0; + long totalWeight = 0; for (WeightedChildPicker weightedChildPicker : weightedChildPickers) { - int weight = weightedChildPicker.getWeight(); + long weight = weightedChildPicker.getWeight(); + checkArgument(weight >= 0, "weight is negative"); + checkNotNull(weightedChildPicker.getPicker(), "childPicker is null"); totalWeight += weight; } this.totalWeight = totalWeight; + checkArgument(totalWeight <= UnsignedInteger.MAX_VALUE.longValue(), + "total weight greater than unsigned int can hold"); this.random = random; } @@ -111,15 +117,15 @@ final class WeightedRandomPicker extends SubchannelPicker { childPicker = weightedChildPickers.get(random.nextInt(weightedChildPickers.size())).getPicker(); } else { - int rand = random.nextInt(totalWeight); + long rand = random.nextLong(totalWeight); // Find the first idx such that rand < accumulatedWeights[idx] // Not using Arrays.binarySearch for better readability. - int accumulatedWeight = 0; - for (int idx = 0; idx < weightedChildPickers.size(); idx++) { - accumulatedWeight += weightedChildPickers.get(idx).getWeight(); + long accumulatedWeight = 0; + for (WeightedChildPicker weightedChildPicker : weightedChildPickers) { + accumulatedWeight += weightedChildPicker.getWeight(); if (rand < accumulatedWeight) { - childPicker = weightedChildPickers.get(idx).getPicker(); + childPicker = weightedChildPicker.getPicker(); break; } } diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java index 094bb944d8..8a5992ab61 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java @@ -437,12 +437,12 @@ final class XdsNameResolver extends NameResolver { if (action.cluster() != null) { cluster = prefixedClusterName(action.cluster()); } else if (action.weightedClusters() != null) { - int totalWeight = 0; + long totalWeight = 0; for (ClusterWeight weightedCluster : action.weightedClusters()) { totalWeight += weightedCluster.weight(); } - int select = random.nextInt(totalWeight); - int accumulator = 0; + long select = random.nextLong(totalWeight); + long accumulator = 0; for (ClusterWeight weightedCluster : action.weightedClusters()) { accumulator += weightedCluster.weight(); if (select < accumulator) { diff --git a/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java b/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java index ed109fd694..6ae23406d6 100644 --- a/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java +++ b/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java @@ -24,6 +24,7 @@ import com.google.common.base.MoreObjects; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import com.google.common.primitives.UnsignedInteger; import com.google.protobuf.Any; import com.google.protobuf.Duration; import com.google.protobuf.InvalidProtocolBufferException; @@ -477,7 +478,7 @@ class XdsRouteConfigureResource extends XdsResourceType { return StructOrError.fromError("No cluster found in weighted cluster list"); } List weightedClusters = new ArrayList<>(); - int clusterWeightSum = 0; + long clusterWeightSum = 0; for (io.envoyproxy.envoy.config.route.v3.WeightedCluster.ClusterWeight clusterWeight : clusterWeights) { StructOrError clusterWeightOrError = @@ -492,6 +493,12 @@ class XdsRouteConfigureResource extends XdsResourceType { if (clusterWeightSum <= 0) { return StructOrError.fromError("Sum of cluster weights should be above 0."); } + if (clusterWeightSum > UnsignedInteger.MAX_VALUE.longValue()) { + return StructOrError.fromError(String.format( + "Sum of cluster weights should be less than the maximum unsigned integer (%d), but" + + " was %d. ", + UnsignedInteger.MAX_VALUE.longValue(), clusterWeightSum)); + } return StructOrError.fromStruct(VirtualHost.Route.RouteAction.forWeightedClusters( weightedClusters, hashPolicies, timeoutNano, retryPolicy)); case CLUSTER_SPECIFIER_PLUGIN: @@ -499,7 +506,7 @@ class XdsRouteConfigureResource extends XdsResourceType { String pluginName = proto.getClusterSpecifierPlugin(); PluginConfig pluginConfig = pluginConfigMap.get(pluginName); if (pluginConfig == null) { - // Skip route if the plugin is not registered, but it's optional. + // Skip route if the plugin is not registered, but it is optional. if (optionalPlugins.contains(pluginName)) { return null; } diff --git a/xds/src/test/java/io/grpc/xds/WeightedRandomPickerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRandomPickerTest.java index ecdd96a734..d6240fb09b 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedRandomPickerTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedRandomPickerTest.java @@ -87,7 +87,8 @@ public class WeightedRandomPickerTest { private static final class FakeRandom implements ThreadSafeRandom { int nextInt; - int bound; + long bound; + Long nextLong; @Override public int nextInt(int bound) { @@ -102,6 +103,23 @@ public class WeightedRandomPickerTest { public long nextLong() { throw new UnsupportedOperationException("Should not be called"); } + + @Override + public long nextLong(long bound) { + this.bound = bound; + + if (nextLong == null) { + assertThat(nextInt).isAtLeast(0); + if (bound <= Integer.MAX_VALUE) { + assertThat(nextInt).isLessThan((int)bound); + } + return nextInt; + } + + assertThat(nextLong).isAtLeast(0); + assertThat(nextLong).isLessThan(bound); + return nextLong; + } } private final FakeRandom fakeRandom = new FakeRandom(); @@ -120,6 +138,24 @@ public class WeightedRandomPickerTest { new WeightedChildPicker(-1, childPicker0); } + @Test + public void overWeightSingle() { + thrown.expect(IllegalArgumentException.class); + new WeightedChildPicker(Integer.MAX_VALUE * 3L, childPicker0); + } + + @Test + public void overWeightAggregate() { + + List weightedChildPickers = Arrays.asList( + new WeightedChildPicker(Integer.MAX_VALUE, childPicker0), + new WeightedChildPicker(Integer.MAX_VALUE, childPicker1), + new WeightedChildPicker(10, childPicker2)); + + thrown.expect(IllegalArgumentException.class); + new WeightedRandomPicker(weightedChildPickers, fakeRandom); + } + @Test public void pickWithFakeRandom() { WeightedChildPicker weightedChildPicker0 = new WeightedChildPicker(0, childPicker0); @@ -156,6 +192,36 @@ public class WeightedRandomPickerTest { assertThat(fakeRandom.bound).isEqualTo(25); } + @Test + public void pickFromLargeTotal() { + + List weightedChildPickers = Arrays.asList( + new WeightedChildPicker(10, childPicker0), + new WeightedChildPicker(Integer.MAX_VALUE, childPicker1), + new WeightedChildPicker(10, childPicker2)); + WeightedRandomPicker xdsPicker = new WeightedRandomPicker(weightedChildPickers,fakeRandom); + + long totalWeight = weightedChildPickers.stream() + .mapToLong(WeightedChildPicker::getWeight) + .reduce(0, Long::sum); + + fakeRandom.nextLong = 5L; + assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult0); + assertThat(fakeRandom.bound).isEqualTo(totalWeight); + + fakeRandom.nextLong = 16L; + assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult1); + assertThat(fakeRandom.bound).isEqualTo(totalWeight); + + fakeRandom.nextLong = Integer.MAX_VALUE + 10L; + assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult2); + assertThat(fakeRandom.bound).isEqualTo(totalWeight); + + fakeRandom.nextLong = Integer.MAX_VALUE + 15L; + assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult2); + assertThat(fakeRandom.bound).isEqualTo(totalWeight); + } + @Test public void allZeroWeights() { WeightedChildPicker weightedChildPicker0 = new WeightedChildPicker(0, childPicker0); diff --git a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java index b6f8b3c366..3d934e16aa 100644 --- a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java @@ -24,6 +24,7 @@ import static io.grpc.xds.FaultFilter.HEADER_DELAY_KEY; import static io.grpc.xds.FaultFilter.HEADER_DELAY_PERCENTAGE_KEY; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -994,6 +995,7 @@ public class XdsNameResolverTest { @Test public void resolved_simpleCallSucceeds_routeToWeightedCluster() { when(mockRandom.nextInt(anyInt())).thenReturn(90, 10); + when(mockRandom.nextLong(anyLong())).thenReturn(90L, 10L); resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdate(