mirror of https://github.com/grpc/grpc-java.git
xds:Allow big cluster total weight (#9864)
* xds: allow sum of cluster weights above MAX_INT up to max of unsigned int. * Define nextLong(long bound) method in FakeRandom for WeightedRandomPickerTest.
This commit is contained in:
parent
04afea0fbd
commit
54c1f37093
|
@ -25,6 +25,8 @@ interface ThreadSafeRandom {
|
|||
|
||||
long nextLong();
|
||||
|
||||
long nextLong(long bound);
|
||||
|
||||
final class ThreadSafeRandomImpl implements ThreadSafeRandom {
|
||||
|
||||
static final ThreadSafeRandom instance = new ThreadSafeRandomImpl();
|
||||
|
@ -40,5 +42,10 @@ interface ThreadSafeRandom {
|
|||
public long nextLong() {
|
||||
return ThreadLocalRandom.current().nextLong();
|
||||
}
|
||||
|
||||
@Override
|
||||
public long nextLong(long bound) {
|
||||
return ThreadLocalRandom.current().nextLong(bound);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ import static com.google.common.base.Preconditions.checkNotNull;
|
|||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.base.MoreObjects;
|
||||
import com.google.common.primitives.UnsignedInteger;
|
||||
import io.grpc.LoadBalancer.PickResult;
|
||||
import io.grpc.LoadBalancer.PickSubchannelArgs;
|
||||
import io.grpc.LoadBalancer.SubchannelPicker;
|
||||
|
@ -34,21 +35,22 @@ final class WeightedRandomPicker extends SubchannelPicker {
|
|||
final List<WeightedChildPicker> weightedChildPickers;
|
||||
|
||||
private final ThreadSafeRandom random;
|
||||
private final int totalWeight;
|
||||
private final long totalWeight;
|
||||
|
||||
static final class WeightedChildPicker {
|
||||
private final int weight;
|
||||
private final long weight;
|
||||
private final SubchannelPicker childPicker;
|
||||
|
||||
WeightedChildPicker(int weight, SubchannelPicker childPicker) {
|
||||
WeightedChildPicker(long weight, SubchannelPicker childPicker) {
|
||||
checkArgument(weight >= 0, "weight is negative");
|
||||
checkArgument(weight <= UnsignedInteger.MAX_VALUE.longValue(), "weight is too large");
|
||||
checkNotNull(childPicker, "childPicker is null");
|
||||
|
||||
this.weight = weight;
|
||||
this.childPicker = childPicker;
|
||||
}
|
||||
|
||||
int getWeight() {
|
||||
long getWeight() {
|
||||
return weight;
|
||||
}
|
||||
|
||||
|
@ -93,12 +95,16 @@ final class WeightedRandomPicker extends SubchannelPicker {
|
|||
|
||||
this.weightedChildPickers = Collections.unmodifiableList(weightedChildPickers);
|
||||
|
||||
int totalWeight = 0;
|
||||
long totalWeight = 0;
|
||||
for (WeightedChildPicker weightedChildPicker : weightedChildPickers) {
|
||||
int weight = weightedChildPicker.getWeight();
|
||||
long weight = weightedChildPicker.getWeight();
|
||||
checkArgument(weight >= 0, "weight is negative");
|
||||
checkNotNull(weightedChildPicker.getPicker(), "childPicker is null");
|
||||
totalWeight += weight;
|
||||
}
|
||||
this.totalWeight = totalWeight;
|
||||
checkArgument(totalWeight <= UnsignedInteger.MAX_VALUE.longValue(),
|
||||
"total weight greater than unsigned int can hold");
|
||||
|
||||
this.random = random;
|
||||
}
|
||||
|
@ -111,15 +117,15 @@ final class WeightedRandomPicker extends SubchannelPicker {
|
|||
childPicker =
|
||||
weightedChildPickers.get(random.nextInt(weightedChildPickers.size())).getPicker();
|
||||
} else {
|
||||
int rand = random.nextInt(totalWeight);
|
||||
long rand = random.nextLong(totalWeight);
|
||||
|
||||
// Find the first idx such that rand < accumulatedWeights[idx]
|
||||
// Not using Arrays.binarySearch for better readability.
|
||||
int accumulatedWeight = 0;
|
||||
for (int idx = 0; idx < weightedChildPickers.size(); idx++) {
|
||||
accumulatedWeight += weightedChildPickers.get(idx).getWeight();
|
||||
long accumulatedWeight = 0;
|
||||
for (WeightedChildPicker weightedChildPicker : weightedChildPickers) {
|
||||
accumulatedWeight += weightedChildPicker.getWeight();
|
||||
if (rand < accumulatedWeight) {
|
||||
childPicker = weightedChildPickers.get(idx).getPicker();
|
||||
childPicker = weightedChildPicker.getPicker();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -437,12 +437,12 @@ final class XdsNameResolver extends NameResolver {
|
|||
if (action.cluster() != null) {
|
||||
cluster = prefixedClusterName(action.cluster());
|
||||
} else if (action.weightedClusters() != null) {
|
||||
int totalWeight = 0;
|
||||
long totalWeight = 0;
|
||||
for (ClusterWeight weightedCluster : action.weightedClusters()) {
|
||||
totalWeight += weightedCluster.weight();
|
||||
}
|
||||
int select = random.nextInt(totalWeight);
|
||||
int accumulator = 0;
|
||||
long select = random.nextLong(totalWeight);
|
||||
long accumulator = 0;
|
||||
for (ClusterWeight weightedCluster : action.weightedClusters()) {
|
||||
accumulator += weightedCluster.weight();
|
||||
if (select < accumulator) {
|
||||
|
|
|
@ -24,6 +24,7 @@ import com.google.common.base.MoreObjects;
|
|||
import com.google.common.base.Splitter;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.primitives.UnsignedInteger;
|
||||
import com.google.protobuf.Any;
|
||||
import com.google.protobuf.Duration;
|
||||
import com.google.protobuf.InvalidProtocolBufferException;
|
||||
|
@ -477,7 +478,7 @@ class XdsRouteConfigureResource extends XdsResourceType<RdsUpdate> {
|
|||
return StructOrError.fromError("No cluster found in weighted cluster list");
|
||||
}
|
||||
List<ClusterWeight> weightedClusters = new ArrayList<>();
|
||||
int clusterWeightSum = 0;
|
||||
long clusterWeightSum = 0;
|
||||
for (io.envoyproxy.envoy.config.route.v3.WeightedCluster.ClusterWeight clusterWeight
|
||||
: clusterWeights) {
|
||||
StructOrError<ClusterWeight> clusterWeightOrError =
|
||||
|
@ -492,6 +493,12 @@ class XdsRouteConfigureResource extends XdsResourceType<RdsUpdate> {
|
|||
if (clusterWeightSum <= 0) {
|
||||
return StructOrError.fromError("Sum of cluster weights should be above 0.");
|
||||
}
|
||||
if (clusterWeightSum > UnsignedInteger.MAX_VALUE.longValue()) {
|
||||
return StructOrError.fromError(String.format(
|
||||
"Sum of cluster weights should be less than the maximum unsigned integer (%d), but"
|
||||
+ " was %d. ",
|
||||
UnsignedInteger.MAX_VALUE.longValue(), clusterWeightSum));
|
||||
}
|
||||
return StructOrError.fromStruct(VirtualHost.Route.RouteAction.forWeightedClusters(
|
||||
weightedClusters, hashPolicies, timeoutNano, retryPolicy));
|
||||
case CLUSTER_SPECIFIER_PLUGIN:
|
||||
|
@ -499,7 +506,7 @@ class XdsRouteConfigureResource extends XdsResourceType<RdsUpdate> {
|
|||
String pluginName = proto.getClusterSpecifierPlugin();
|
||||
PluginConfig pluginConfig = pluginConfigMap.get(pluginName);
|
||||
if (pluginConfig == null) {
|
||||
// Skip route if the plugin is not registered, but it's optional.
|
||||
// Skip route if the plugin is not registered, but it is optional.
|
||||
if (optionalPlugins.contains(pluginName)) {
|
||||
return null;
|
||||
}
|
||||
|
|
|
@ -87,7 +87,8 @@ public class WeightedRandomPickerTest {
|
|||
|
||||
private static final class FakeRandom implements ThreadSafeRandom {
|
||||
int nextInt;
|
||||
int bound;
|
||||
long bound;
|
||||
Long nextLong;
|
||||
|
||||
@Override
|
||||
public int nextInt(int bound) {
|
||||
|
@ -102,6 +103,23 @@ public class WeightedRandomPickerTest {
|
|||
public long nextLong() {
|
||||
throw new UnsupportedOperationException("Should not be called");
|
||||
}
|
||||
|
||||
@Override
|
||||
public long nextLong(long bound) {
|
||||
this.bound = bound;
|
||||
|
||||
if (nextLong == null) {
|
||||
assertThat(nextInt).isAtLeast(0);
|
||||
if (bound <= Integer.MAX_VALUE) {
|
||||
assertThat(nextInt).isLessThan((int)bound);
|
||||
}
|
||||
return nextInt;
|
||||
}
|
||||
|
||||
assertThat(nextLong).isAtLeast(0);
|
||||
assertThat(nextLong).isLessThan(bound);
|
||||
return nextLong;
|
||||
}
|
||||
}
|
||||
|
||||
private final FakeRandom fakeRandom = new FakeRandom();
|
||||
|
@ -120,6 +138,24 @@ public class WeightedRandomPickerTest {
|
|||
new WeightedChildPicker(-1, childPicker0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void overWeightSingle() {
|
||||
thrown.expect(IllegalArgumentException.class);
|
||||
new WeightedChildPicker(Integer.MAX_VALUE * 3L, childPicker0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void overWeightAggregate() {
|
||||
|
||||
List<WeightedChildPicker> weightedChildPickers = Arrays.asList(
|
||||
new WeightedChildPicker(Integer.MAX_VALUE, childPicker0),
|
||||
new WeightedChildPicker(Integer.MAX_VALUE, childPicker1),
|
||||
new WeightedChildPicker(10, childPicker2));
|
||||
|
||||
thrown.expect(IllegalArgumentException.class);
|
||||
new WeightedRandomPicker(weightedChildPickers, fakeRandom);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void pickWithFakeRandom() {
|
||||
WeightedChildPicker weightedChildPicker0 = new WeightedChildPicker(0, childPicker0);
|
||||
|
@ -156,6 +192,36 @@ public class WeightedRandomPickerTest {
|
|||
assertThat(fakeRandom.bound).isEqualTo(25);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void pickFromLargeTotal() {
|
||||
|
||||
List<WeightedChildPicker> weightedChildPickers = Arrays.asList(
|
||||
new WeightedChildPicker(10, childPicker0),
|
||||
new WeightedChildPicker(Integer.MAX_VALUE, childPicker1),
|
||||
new WeightedChildPicker(10, childPicker2));
|
||||
WeightedRandomPicker xdsPicker = new WeightedRandomPicker(weightedChildPickers,fakeRandom);
|
||||
|
||||
long totalWeight = weightedChildPickers.stream()
|
||||
.mapToLong(WeightedChildPicker::getWeight)
|
||||
.reduce(0, Long::sum);
|
||||
|
||||
fakeRandom.nextLong = 5L;
|
||||
assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult0);
|
||||
assertThat(fakeRandom.bound).isEqualTo(totalWeight);
|
||||
|
||||
fakeRandom.nextLong = 16L;
|
||||
assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult1);
|
||||
assertThat(fakeRandom.bound).isEqualTo(totalWeight);
|
||||
|
||||
fakeRandom.nextLong = Integer.MAX_VALUE + 10L;
|
||||
assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult2);
|
||||
assertThat(fakeRandom.bound).isEqualTo(totalWeight);
|
||||
|
||||
fakeRandom.nextLong = Integer.MAX_VALUE + 15L;
|
||||
assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult2);
|
||||
assertThat(fakeRandom.bound).isEqualTo(totalWeight);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void allZeroWeights() {
|
||||
WeightedChildPicker weightedChildPicker0 = new WeightedChildPicker(0, childPicker0);
|
||||
|
|
|
@ -24,6 +24,7 @@ import static io.grpc.xds.FaultFilter.HEADER_DELAY_KEY;
|
|||
import static io.grpc.xds.FaultFilter.HEADER_DELAY_PERCENTAGE_KEY;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyInt;
|
||||
import static org.mockito.ArgumentMatchers.anyLong;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.never;
|
||||
|
@ -994,6 +995,7 @@ public class XdsNameResolverTest {
|
|||
@Test
|
||||
public void resolved_simpleCallSucceeds_routeToWeightedCluster() {
|
||||
when(mockRandom.nextInt(anyInt())).thenReturn(90, 10);
|
||||
when(mockRandom.nextLong(anyLong())).thenReturn(90L, 10L);
|
||||
resolver.start(mockListener);
|
||||
FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient();
|
||||
xdsClient.deliverLdsUpdate(
|
||||
|
|
Loading…
Reference in New Issue