mirror of https://github.com/grpc/grpc-java.git
xds:Allow big cluster total weight (#9864)
* xds: allow sum of cluster weights above MAX_INT up to max of unsigned int. * Define nextLong(long bound) method in FakeRandom for WeightedRandomPickerTest.
This commit is contained in:
parent
04afea0fbd
commit
54c1f37093
|
@ -25,6 +25,8 @@ interface ThreadSafeRandom {
|
||||||
|
|
||||||
long nextLong();
|
long nextLong();
|
||||||
|
|
||||||
|
long nextLong(long bound);
|
||||||
|
|
||||||
final class ThreadSafeRandomImpl implements ThreadSafeRandom {
|
final class ThreadSafeRandomImpl implements ThreadSafeRandom {
|
||||||
|
|
||||||
static final ThreadSafeRandom instance = new ThreadSafeRandomImpl();
|
static final ThreadSafeRandom instance = new ThreadSafeRandomImpl();
|
||||||
|
@ -40,5 +42,10 @@ interface ThreadSafeRandom {
|
||||||
public long nextLong() {
|
public long nextLong() {
|
||||||
return ThreadLocalRandom.current().nextLong();
|
return ThreadLocalRandom.current().nextLong();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long nextLong(long bound) {
|
||||||
|
return ThreadLocalRandom.current().nextLong(bound);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,7 @@ import static com.google.common.base.Preconditions.checkNotNull;
|
||||||
|
|
||||||
import com.google.common.annotations.VisibleForTesting;
|
import com.google.common.annotations.VisibleForTesting;
|
||||||
import com.google.common.base.MoreObjects;
|
import com.google.common.base.MoreObjects;
|
||||||
|
import com.google.common.primitives.UnsignedInteger;
|
||||||
import io.grpc.LoadBalancer.PickResult;
|
import io.grpc.LoadBalancer.PickResult;
|
||||||
import io.grpc.LoadBalancer.PickSubchannelArgs;
|
import io.grpc.LoadBalancer.PickSubchannelArgs;
|
||||||
import io.grpc.LoadBalancer.SubchannelPicker;
|
import io.grpc.LoadBalancer.SubchannelPicker;
|
||||||
|
@ -34,21 +35,22 @@ final class WeightedRandomPicker extends SubchannelPicker {
|
||||||
final List<WeightedChildPicker> weightedChildPickers;
|
final List<WeightedChildPicker> weightedChildPickers;
|
||||||
|
|
||||||
private final ThreadSafeRandom random;
|
private final ThreadSafeRandom random;
|
||||||
private final int totalWeight;
|
private final long totalWeight;
|
||||||
|
|
||||||
static final class WeightedChildPicker {
|
static final class WeightedChildPicker {
|
||||||
private final int weight;
|
private final long weight;
|
||||||
private final SubchannelPicker childPicker;
|
private final SubchannelPicker childPicker;
|
||||||
|
|
||||||
WeightedChildPicker(int weight, SubchannelPicker childPicker) {
|
WeightedChildPicker(long weight, SubchannelPicker childPicker) {
|
||||||
checkArgument(weight >= 0, "weight is negative");
|
checkArgument(weight >= 0, "weight is negative");
|
||||||
|
checkArgument(weight <= UnsignedInteger.MAX_VALUE.longValue(), "weight is too large");
|
||||||
checkNotNull(childPicker, "childPicker is null");
|
checkNotNull(childPicker, "childPicker is null");
|
||||||
|
|
||||||
this.weight = weight;
|
this.weight = weight;
|
||||||
this.childPicker = childPicker;
|
this.childPicker = childPicker;
|
||||||
}
|
}
|
||||||
|
|
||||||
int getWeight() {
|
long getWeight() {
|
||||||
return weight;
|
return weight;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -93,12 +95,16 @@ final class WeightedRandomPicker extends SubchannelPicker {
|
||||||
|
|
||||||
this.weightedChildPickers = Collections.unmodifiableList(weightedChildPickers);
|
this.weightedChildPickers = Collections.unmodifiableList(weightedChildPickers);
|
||||||
|
|
||||||
int totalWeight = 0;
|
long totalWeight = 0;
|
||||||
for (WeightedChildPicker weightedChildPicker : weightedChildPickers) {
|
for (WeightedChildPicker weightedChildPicker : weightedChildPickers) {
|
||||||
int weight = weightedChildPicker.getWeight();
|
long weight = weightedChildPicker.getWeight();
|
||||||
|
checkArgument(weight >= 0, "weight is negative");
|
||||||
|
checkNotNull(weightedChildPicker.getPicker(), "childPicker is null");
|
||||||
totalWeight += weight;
|
totalWeight += weight;
|
||||||
}
|
}
|
||||||
this.totalWeight = totalWeight;
|
this.totalWeight = totalWeight;
|
||||||
|
checkArgument(totalWeight <= UnsignedInteger.MAX_VALUE.longValue(),
|
||||||
|
"total weight greater than unsigned int can hold");
|
||||||
|
|
||||||
this.random = random;
|
this.random = random;
|
||||||
}
|
}
|
||||||
|
@ -111,15 +117,15 @@ final class WeightedRandomPicker extends SubchannelPicker {
|
||||||
childPicker =
|
childPicker =
|
||||||
weightedChildPickers.get(random.nextInt(weightedChildPickers.size())).getPicker();
|
weightedChildPickers.get(random.nextInt(weightedChildPickers.size())).getPicker();
|
||||||
} else {
|
} else {
|
||||||
int rand = random.nextInt(totalWeight);
|
long rand = random.nextLong(totalWeight);
|
||||||
|
|
||||||
// Find the first idx such that rand < accumulatedWeights[idx]
|
// Find the first idx such that rand < accumulatedWeights[idx]
|
||||||
// Not using Arrays.binarySearch for better readability.
|
// Not using Arrays.binarySearch for better readability.
|
||||||
int accumulatedWeight = 0;
|
long accumulatedWeight = 0;
|
||||||
for (int idx = 0; idx < weightedChildPickers.size(); idx++) {
|
for (WeightedChildPicker weightedChildPicker : weightedChildPickers) {
|
||||||
accumulatedWeight += weightedChildPickers.get(idx).getWeight();
|
accumulatedWeight += weightedChildPicker.getWeight();
|
||||||
if (rand < accumulatedWeight) {
|
if (rand < accumulatedWeight) {
|
||||||
childPicker = weightedChildPickers.get(idx).getPicker();
|
childPicker = weightedChildPicker.getPicker();
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -437,12 +437,12 @@ final class XdsNameResolver extends NameResolver {
|
||||||
if (action.cluster() != null) {
|
if (action.cluster() != null) {
|
||||||
cluster = prefixedClusterName(action.cluster());
|
cluster = prefixedClusterName(action.cluster());
|
||||||
} else if (action.weightedClusters() != null) {
|
} else if (action.weightedClusters() != null) {
|
||||||
int totalWeight = 0;
|
long totalWeight = 0;
|
||||||
for (ClusterWeight weightedCluster : action.weightedClusters()) {
|
for (ClusterWeight weightedCluster : action.weightedClusters()) {
|
||||||
totalWeight += weightedCluster.weight();
|
totalWeight += weightedCluster.weight();
|
||||||
}
|
}
|
||||||
int select = random.nextInt(totalWeight);
|
long select = random.nextLong(totalWeight);
|
||||||
int accumulator = 0;
|
long accumulator = 0;
|
||||||
for (ClusterWeight weightedCluster : action.weightedClusters()) {
|
for (ClusterWeight weightedCluster : action.weightedClusters()) {
|
||||||
accumulator += weightedCluster.weight();
|
accumulator += weightedCluster.weight();
|
||||||
if (select < accumulator) {
|
if (select < accumulator) {
|
||||||
|
|
|
@ -24,6 +24,7 @@ import com.google.common.base.MoreObjects;
|
||||||
import com.google.common.base.Splitter;
|
import com.google.common.base.Splitter;
|
||||||
import com.google.common.collect.ImmutableList;
|
import com.google.common.collect.ImmutableList;
|
||||||
import com.google.common.collect.ImmutableSet;
|
import com.google.common.collect.ImmutableSet;
|
||||||
|
import com.google.common.primitives.UnsignedInteger;
|
||||||
import com.google.protobuf.Any;
|
import com.google.protobuf.Any;
|
||||||
import com.google.protobuf.Duration;
|
import com.google.protobuf.Duration;
|
||||||
import com.google.protobuf.InvalidProtocolBufferException;
|
import com.google.protobuf.InvalidProtocolBufferException;
|
||||||
|
@ -477,7 +478,7 @@ class XdsRouteConfigureResource extends XdsResourceType<RdsUpdate> {
|
||||||
return StructOrError.fromError("No cluster found in weighted cluster list");
|
return StructOrError.fromError("No cluster found in weighted cluster list");
|
||||||
}
|
}
|
||||||
List<ClusterWeight> weightedClusters = new ArrayList<>();
|
List<ClusterWeight> weightedClusters = new ArrayList<>();
|
||||||
int clusterWeightSum = 0;
|
long clusterWeightSum = 0;
|
||||||
for (io.envoyproxy.envoy.config.route.v3.WeightedCluster.ClusterWeight clusterWeight
|
for (io.envoyproxy.envoy.config.route.v3.WeightedCluster.ClusterWeight clusterWeight
|
||||||
: clusterWeights) {
|
: clusterWeights) {
|
||||||
StructOrError<ClusterWeight> clusterWeightOrError =
|
StructOrError<ClusterWeight> clusterWeightOrError =
|
||||||
|
@ -492,6 +493,12 @@ class XdsRouteConfigureResource extends XdsResourceType<RdsUpdate> {
|
||||||
if (clusterWeightSum <= 0) {
|
if (clusterWeightSum <= 0) {
|
||||||
return StructOrError.fromError("Sum of cluster weights should be above 0.");
|
return StructOrError.fromError("Sum of cluster weights should be above 0.");
|
||||||
}
|
}
|
||||||
|
if (clusterWeightSum > UnsignedInteger.MAX_VALUE.longValue()) {
|
||||||
|
return StructOrError.fromError(String.format(
|
||||||
|
"Sum of cluster weights should be less than the maximum unsigned integer (%d), but"
|
||||||
|
+ " was %d. ",
|
||||||
|
UnsignedInteger.MAX_VALUE.longValue(), clusterWeightSum));
|
||||||
|
}
|
||||||
return StructOrError.fromStruct(VirtualHost.Route.RouteAction.forWeightedClusters(
|
return StructOrError.fromStruct(VirtualHost.Route.RouteAction.forWeightedClusters(
|
||||||
weightedClusters, hashPolicies, timeoutNano, retryPolicy));
|
weightedClusters, hashPolicies, timeoutNano, retryPolicy));
|
||||||
case CLUSTER_SPECIFIER_PLUGIN:
|
case CLUSTER_SPECIFIER_PLUGIN:
|
||||||
|
@ -499,7 +506,7 @@ class XdsRouteConfigureResource extends XdsResourceType<RdsUpdate> {
|
||||||
String pluginName = proto.getClusterSpecifierPlugin();
|
String pluginName = proto.getClusterSpecifierPlugin();
|
||||||
PluginConfig pluginConfig = pluginConfigMap.get(pluginName);
|
PluginConfig pluginConfig = pluginConfigMap.get(pluginName);
|
||||||
if (pluginConfig == null) {
|
if (pluginConfig == null) {
|
||||||
// Skip route if the plugin is not registered, but it's optional.
|
// Skip route if the plugin is not registered, but it is optional.
|
||||||
if (optionalPlugins.contains(pluginName)) {
|
if (optionalPlugins.contains(pluginName)) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
|
@ -87,7 +87,8 @@ public class WeightedRandomPickerTest {
|
||||||
|
|
||||||
private static final class FakeRandom implements ThreadSafeRandom {
|
private static final class FakeRandom implements ThreadSafeRandom {
|
||||||
int nextInt;
|
int nextInt;
|
||||||
int bound;
|
long bound;
|
||||||
|
Long nextLong;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int nextInt(int bound) {
|
public int nextInt(int bound) {
|
||||||
|
@ -102,6 +103,23 @@ public class WeightedRandomPickerTest {
|
||||||
public long nextLong() {
|
public long nextLong() {
|
||||||
throw new UnsupportedOperationException("Should not be called");
|
throw new UnsupportedOperationException("Should not be called");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long nextLong(long bound) {
|
||||||
|
this.bound = bound;
|
||||||
|
|
||||||
|
if (nextLong == null) {
|
||||||
|
assertThat(nextInt).isAtLeast(0);
|
||||||
|
if (bound <= Integer.MAX_VALUE) {
|
||||||
|
assertThat(nextInt).isLessThan((int)bound);
|
||||||
|
}
|
||||||
|
return nextInt;
|
||||||
|
}
|
||||||
|
|
||||||
|
assertThat(nextLong).isAtLeast(0);
|
||||||
|
assertThat(nextLong).isLessThan(bound);
|
||||||
|
return nextLong;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private final FakeRandom fakeRandom = new FakeRandom();
|
private final FakeRandom fakeRandom = new FakeRandom();
|
||||||
|
@ -120,6 +138,24 @@ public class WeightedRandomPickerTest {
|
||||||
new WeightedChildPicker(-1, childPicker0);
|
new WeightedChildPicker(-1, childPicker0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void overWeightSingle() {
|
||||||
|
thrown.expect(IllegalArgumentException.class);
|
||||||
|
new WeightedChildPicker(Integer.MAX_VALUE * 3L, childPicker0);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void overWeightAggregate() {
|
||||||
|
|
||||||
|
List<WeightedChildPicker> weightedChildPickers = Arrays.asList(
|
||||||
|
new WeightedChildPicker(Integer.MAX_VALUE, childPicker0),
|
||||||
|
new WeightedChildPicker(Integer.MAX_VALUE, childPicker1),
|
||||||
|
new WeightedChildPicker(10, childPicker2));
|
||||||
|
|
||||||
|
thrown.expect(IllegalArgumentException.class);
|
||||||
|
new WeightedRandomPicker(weightedChildPickers, fakeRandom);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void pickWithFakeRandom() {
|
public void pickWithFakeRandom() {
|
||||||
WeightedChildPicker weightedChildPicker0 = new WeightedChildPicker(0, childPicker0);
|
WeightedChildPicker weightedChildPicker0 = new WeightedChildPicker(0, childPicker0);
|
||||||
|
@ -156,6 +192,36 @@ public class WeightedRandomPickerTest {
|
||||||
assertThat(fakeRandom.bound).isEqualTo(25);
|
assertThat(fakeRandom.bound).isEqualTo(25);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void pickFromLargeTotal() {
|
||||||
|
|
||||||
|
List<WeightedChildPicker> weightedChildPickers = Arrays.asList(
|
||||||
|
new WeightedChildPicker(10, childPicker0),
|
||||||
|
new WeightedChildPicker(Integer.MAX_VALUE, childPicker1),
|
||||||
|
new WeightedChildPicker(10, childPicker2));
|
||||||
|
WeightedRandomPicker xdsPicker = new WeightedRandomPicker(weightedChildPickers,fakeRandom);
|
||||||
|
|
||||||
|
long totalWeight = weightedChildPickers.stream()
|
||||||
|
.mapToLong(WeightedChildPicker::getWeight)
|
||||||
|
.reduce(0, Long::sum);
|
||||||
|
|
||||||
|
fakeRandom.nextLong = 5L;
|
||||||
|
assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult0);
|
||||||
|
assertThat(fakeRandom.bound).isEqualTo(totalWeight);
|
||||||
|
|
||||||
|
fakeRandom.nextLong = 16L;
|
||||||
|
assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult1);
|
||||||
|
assertThat(fakeRandom.bound).isEqualTo(totalWeight);
|
||||||
|
|
||||||
|
fakeRandom.nextLong = Integer.MAX_VALUE + 10L;
|
||||||
|
assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult2);
|
||||||
|
assertThat(fakeRandom.bound).isEqualTo(totalWeight);
|
||||||
|
|
||||||
|
fakeRandom.nextLong = Integer.MAX_VALUE + 15L;
|
||||||
|
assertThat(xdsPicker.pickSubchannel(pickSubchannelArgs)).isSameInstanceAs(pickResult2);
|
||||||
|
assertThat(fakeRandom.bound).isEqualTo(totalWeight);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void allZeroWeights() {
|
public void allZeroWeights() {
|
||||||
WeightedChildPicker weightedChildPicker0 = new WeightedChildPicker(0, childPicker0);
|
WeightedChildPicker weightedChildPicker0 = new WeightedChildPicker(0, childPicker0);
|
||||||
|
|
|
@ -24,6 +24,7 @@ import static io.grpc.xds.FaultFilter.HEADER_DELAY_KEY;
|
||||||
import static io.grpc.xds.FaultFilter.HEADER_DELAY_PERCENTAGE_KEY;
|
import static io.grpc.xds.FaultFilter.HEADER_DELAY_PERCENTAGE_KEY;
|
||||||
import static org.mockito.ArgumentMatchers.any;
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
import static org.mockito.ArgumentMatchers.anyInt;
|
import static org.mockito.ArgumentMatchers.anyInt;
|
||||||
|
import static org.mockito.ArgumentMatchers.anyLong;
|
||||||
import static org.mockito.ArgumentMatchers.eq;
|
import static org.mockito.ArgumentMatchers.eq;
|
||||||
import static org.mockito.Mockito.mock;
|
import static org.mockito.Mockito.mock;
|
||||||
import static org.mockito.Mockito.never;
|
import static org.mockito.Mockito.never;
|
||||||
|
@ -994,6 +995,7 @@ public class XdsNameResolverTest {
|
||||||
@Test
|
@Test
|
||||||
public void resolved_simpleCallSucceeds_routeToWeightedCluster() {
|
public void resolved_simpleCallSucceeds_routeToWeightedCluster() {
|
||||||
when(mockRandom.nextInt(anyInt())).thenReturn(90, 10);
|
when(mockRandom.nextInt(anyInt())).thenReturn(90, 10);
|
||||||
|
when(mockRandom.nextLong(anyLong())).thenReturn(90L, 10L);
|
||||||
resolver.start(mockListener);
|
resolver.start(mockListener);
|
||||||
FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient();
|
FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient();
|
||||||
xdsClient.deliverLdsUpdate(
|
xdsClient.deliverLdsUpdate(
|
||||||
|
|
Loading…
Reference in New Issue