diff --git a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java index 0db0f59eaa..a640bbd78b 100644 --- a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java +++ b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java @@ -43,10 +43,14 @@ import io.grpc.xds.XdsSubchannelPickers.ErrorPicker; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Queue; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import javax.annotation.Nullable; /** @@ -120,7 +124,9 @@ final class CdsLoadBalancer2 extends LoadBalancer { * receiving the CDS LB policy config with the top-level cluster name. */ private final class CdsLbState { + private final ClusterState root; + private final Map clusterStates = new ConcurrentHashMap<>(); private LoadBalancer childLb; private CdsLbState(String rootCluster) { @@ -140,6 +146,11 @@ final class CdsLoadBalancer2 extends LoadBalancer { private void handleClusterDiscovered() { List instances = new ArrayList<>(); + + // Used for loop detection to break the infinite recursion that loops would cause + Map> parentClusters = new HashMap<>(); + Status loopStatus = null; + // Level-order traversal. // Collect configurations for all non-aggregate (leaf) clusters. Queue queue = new ArrayDeque<>(); @@ -155,26 +166,56 @@ final class CdsLoadBalancer2 extends LoadBalancer { continue; } if (clusterState.isLeaf) { - DiscoveryMechanism instance; - if (clusterState.result.clusterType() == ClusterType.EDS) { - instance = DiscoveryMechanism.forEds( - clusterState.name, clusterState.result.edsServiceName(), - clusterState.result.lrsServerInfo(), clusterState.result.maxConcurrentRequests(), - clusterState.result.upstreamTlsContext(), clusterState.result.outlierDetection()); - } else { // logical DNS - instance = DiscoveryMechanism.forLogicalDns( - clusterState.name, clusterState.result.dnsHostName(), - clusterState.result.lrsServerInfo(), clusterState.result.maxConcurrentRequests(), - clusterState.result.upstreamTlsContext()); + if (instances.stream().map(inst -> inst.cluster).noneMatch(clusterState.name::equals)) { + DiscoveryMechanism instance; + if (clusterState.result.clusterType() == ClusterType.EDS) { + instance = DiscoveryMechanism.forEds( + clusterState.name, clusterState.result.edsServiceName(), + clusterState.result.lrsServerInfo(), + clusterState.result.maxConcurrentRequests(), + clusterState.result.upstreamTlsContext(), + clusterState.result.outlierDetection()); + } else { // logical DNS + instance = DiscoveryMechanism.forLogicalDns( + clusterState.name, clusterState.result.dnsHostName(), + clusterState.result.lrsServerInfo(), + clusterState.result.maxConcurrentRequests(), + clusterState.result.upstreamTlsContext()); + } + instances.add(instance); } - instances.add(instance); } else { - if (clusterState.childClusterStates != null) { + if (clusterState.childClusterStates == null) { + continue; + } + // Do loop detection and break recursion if detected + List namesCausingLoops = identifyLoops(clusterState, parentClusters); + if (namesCausingLoops.isEmpty()) { queue.addAll(clusterState.childClusterStates.values()); + } else { + // Do cleanup + if (childLb != null) { + childLb.shutdown(); + childLb = null; + } + if (loopStatus != null) { + logger.log(XdsLogLevel.WARNING, + "Multiple loops in CDS config. Old msg: " + loopStatus.getDescription()); + } + loopStatus = Status.UNAVAILABLE.withDescription(String.format( + "CDS error: circular aggregate clusters directly under %s for " + + "root cluster %s, named %s", + clusterState.name, root.name, namesCausingLoops)); } } } } + + if (loopStatus != null) { + helper.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(loopStatus)); + return; + } + if (instances.isEmpty()) { // none of non-aggregate clusters exists if (childLb != null) { childLb.shutdown(); @@ -214,6 +255,43 @@ final class CdsLoadBalancer2 extends LoadBalancer { resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(config).build()); } + /** + * Returns children that would cause loops and builds up the parentClusters map. + **/ + + private List identifyLoops(ClusterState clusterState, + Map> parentClusters) { + Set ancestors = new HashSet<>(); + ancestors.add(clusterState.name); + addAncestors(ancestors, clusterState, parentClusters); + + List namesCausingLoops = new ArrayList<>(); + for (ClusterState state : clusterState.childClusterStates.values()) { + if (ancestors.contains(state.name)) { + namesCausingLoops.add(state.name); + } + } + + // Update parent map with entries from remaining children to clusterState + clusterState.childClusterStates.values().stream() + .filter(child -> !namesCausingLoops.contains(child.name)) + .forEach( + child -> parentClusters.computeIfAbsent(child, k -> new ArrayList<>()) + .add(clusterState)); + + return namesCausingLoops; + } + + /** Recursively add all parents to the ancestors list. **/ + private void addAncestors(Set ancestors, ClusterState clusterState, + Map> parentClusters) { + List directParents = parentClusters.get(clusterState); + if (directParents != null) { + directParents.stream().map(c -> c.name).forEach(ancestors::add); + directParents.forEach(p -> addAncestors(ancestors, p, parentClusters)); + } + } + private void handleClusterDiscoveryError(Status error) { if (childLb != null) { childLb.handleNameResolutionError(error); @@ -238,16 +316,18 @@ final class CdsLoadBalancer2 extends LoadBalancer { } private void start() { + shutdown = false; xdsClient.watchXdsResource(XdsClusterResource.getInstance(), name, this); } void shutdown() { shutdown = true; xdsClient.cancelXdsResourceWatch(XdsClusterResource.getInstance(), name, this); - if (childClusterStates != null) { // recursively shut down all descendants - for (ClusterState state : childClusterStates.values()) { - state.shutdown(); - } + if (childClusterStates != null) { + // recursively shut down all descendants + childClusterStates.values().stream() + .filter(state -> !state.shutdown) + .forEach(ClusterState::shutdown); } } @@ -311,9 +391,24 @@ final class CdsLoadBalancer2 extends LoadBalancer { update.clusterName(), update.prioritizedClusterNames()); Map newChildStates = new LinkedHashMap<>(); for (String cluster : update.prioritizedClusterNames()) { + if (newChildStates.containsKey(cluster)) { + logger.log(XdsLogLevel.WARNING, + String.format("duplicate cluster name %s in aggregate %s is being ignored", + cluster, update.clusterName())); + continue; + } if (childClusterStates == null || !childClusterStates.containsKey(cluster)) { - ClusterState childState = new ClusterState(cluster); - childState.start(); + ClusterState childState; + if (clusterStates.containsKey(cluster)) { + childState = clusterStates.get(cluster); + if (childState.shutdown) { + childState.start(); + } + } else { + childState = new ClusterState(cluster); + clusterStates.put(cluster, childState); + childState.start(); + } newChildStates.put(cluster, childState); } else { newChildStates.put(cluster, childClusterStates.remove(cluster)); diff --git a/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java b/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java index 57985b7a57..4f30ec3ec4 100644 --- a/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java +++ b/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java @@ -23,9 +23,11 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import io.grpc.Attributes; @@ -458,6 +460,136 @@ public class CdsLoadBalancer2Test { assertThat(childBalancers).isEmpty(); } + @Test + public void aggregateCluster_withLoops() { + String cluster1 = "cluster-01.googleapis.com"; + // CLUSTER (aggr.) -> [cluster1] + CdsUpdate update = + CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) + .roundRobinLbPolicy().build(); + xdsClient.deliverCdsUpdate(CLUSTER, update); + assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1); + + // CLUSTER (aggr.) -> [cluster2 (aggr.)] + String cluster2 = "cluster-02.googleapis.com"; + update = + CdsUpdate.forAggregate(cluster1, Collections.singletonList(cluster2)) + .roundRobinLbPolicy().build(); + xdsClient.deliverCdsUpdate(cluster1, update); + assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2); + + // cluster2 (aggr.) -> [cluster3 (EDS), cluster1 (parent), cluster2 (self), cluster3 (dup)] + String cluster3 = "cluster-03.googleapis.com"; + CdsUpdate update2 = + CdsUpdate.forAggregate(cluster2, Arrays.asList(cluster3, cluster1, cluster2, cluster3)) + .roundRobinLbPolicy().build(); + xdsClient.deliverCdsUpdate(cluster2, update2); + assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2, cluster3); + + reset(helper); + CdsUpdate update3 = CdsUpdate.forEds(cluster3, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, + upstreamTlsContext, outlierDetection).roundRobinLbPolicy().build(); + xdsClient.deliverCdsUpdate(cluster3, update3); + verify(helper).updateBalancingState( + eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); + Status unavailable = Status.UNAVAILABLE.withDescription( + "CDS error: circular aggregate clusters directly under cluster-02.googleapis.com for root" + + " cluster cluster-foo.googleapis.com, named [cluster-01.googleapis.com," + + " cluster-02.googleapis.com]"); + assertPicker(pickerCaptor.getValue(), unavailable, null); + } + + @Test + public void aggregateCluster_withLoops_afterEds() { + String cluster1 = "cluster-01.googleapis.com"; + // CLUSTER (aggr.) -> [cluster1] + CdsUpdate update = + CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) + .roundRobinLbPolicy().build(); + xdsClient.deliverCdsUpdate(CLUSTER, update); + assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1); + + // CLUSTER (aggr.) -> [cluster2 (aggr.)] + String cluster2 = "cluster-02.googleapis.com"; + update = + CdsUpdate.forAggregate(cluster1, Collections.singletonList(cluster2)) + .roundRobinLbPolicy().build(); + xdsClient.deliverCdsUpdate(cluster1, update); + assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2); + + String cluster3 = "cluster-03.googleapis.com"; + CdsUpdate update2 = + CdsUpdate.forAggregate(cluster2, Arrays.asList(cluster3)) + .roundRobinLbPolicy().build(); + xdsClient.deliverCdsUpdate(cluster2, update2); + CdsUpdate update3 = CdsUpdate.forEds(cluster3, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, + upstreamTlsContext, outlierDetection).roundRobinLbPolicy().build(); + xdsClient.deliverCdsUpdate(cluster3, update3); + + // cluster2 (aggr.) -> [cluster3 (EDS)] + CdsUpdate update2a = + CdsUpdate.forAggregate(cluster2, Arrays.asList(cluster3, cluster1, cluster2, cluster3)) + .roundRobinLbPolicy().build(); + xdsClient.deliverCdsUpdate(cluster2, update2a); + assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2, cluster3); + verify(helper).updateBalancingState( + eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); + Status unavailable = Status.UNAVAILABLE.withDescription( + "CDS error: circular aggregate clusters directly under cluster-02.googleapis.com for root" + + " cluster cluster-foo.googleapis.com, named [cluster-01.googleapis.com," + + " cluster-02.googleapis.com]"); + assertPicker(pickerCaptor.getValue(), unavailable, null); + } + + @Test + public void aggregateCluster_duplicateChildren() { + String cluster1 = "cluster-01.googleapis.com"; + String cluster2 = "cluster-02.googleapis.com"; + String cluster3 = "cluster-03.googleapis.com"; + String cluster4 = "cluster-04.googleapis.com"; + + // CLUSTER (aggr.) -> [cluster1] + CdsUpdate update = + CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) + .roundRobinLbPolicy().build(); + xdsClient.deliverCdsUpdate(CLUSTER, update); + assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1); + + // cluster1 (aggr) -> [cluster3 (EDS), cluster2 (aggr), cluster4 (aggr)] + CdsUpdate update1 = + CdsUpdate.forAggregate(cluster1, Arrays.asList(cluster3, cluster2, cluster4, cluster3)) + .roundRobinLbPolicy().build(); + xdsClient.deliverCdsUpdate(cluster1, update1); + assertThat(xdsClient.watchers.keySet()).containsExactly( + cluster3, cluster4, cluster2, cluster1, CLUSTER); + xdsClient.watchers.values().forEach(list -> assertThat(list.size()).isEqualTo(1)); + + // cluster2 (agg) -> [cluster3 (EDS), cluster4 {agg}] with dups + CdsUpdate update2 = + CdsUpdate.forAggregate(cluster2, Arrays.asList(cluster3, cluster4, cluster3)) + .roundRobinLbPolicy().build(); + xdsClient.deliverCdsUpdate(cluster2, update2); + + // Define EDS cluster + CdsUpdate update3 = CdsUpdate.forEds(cluster3, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, + upstreamTlsContext, outlierDetection).roundRobinLbPolicy().build(); + xdsClient.deliverCdsUpdate(cluster3, update3); + + // cluster4 (agg) -> [cluster3 (EDS)] with dups (3 copies) + CdsUpdate update4 = + CdsUpdate.forAggregate(cluster4, Arrays.asList(cluster3, cluster3, cluster3)) + .roundRobinLbPolicy().build(); + xdsClient.deliverCdsUpdate(cluster4, update4); + xdsClient.watchers.values().forEach(list -> assertThat(list.size()).isEqualTo(1)); + + FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); + ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; + assertThat(childLbConfig.discoveryMechanisms).hasSize(1); + DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); + assertDiscoveryMechanism(instance, cluster3, DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, + null, LRS_SERVER_INFO, 100L, upstreamTlsContext, outlierDetection); + } + @Test public void aggregateCluster_discoveryErrorBeforeChildLbCreated_returnErrorPicker() { String cluster1 = "cluster-01.googleapis.com"; @@ -550,7 +682,7 @@ public class CdsLoadBalancer2Test { assertThat(e).hasCauseThat().hasMessageThat().contains("Unable to parse"); return; } - fail("Expected the invalid config to casue an exception"); + fail("Expected the invalid config to cause an exception"); } private static void assertPicker(SubchannelPicker picker, Status expectedStatus, @@ -651,15 +783,16 @@ public class CdsLoadBalancer2Test { } private final class FakeXdsClient extends XdsClient { - private final Map> watchers = new HashMap<>(); + // watchers needs to support any non-cyclic shaped graphs + private final Map>> watchers = new HashMap<>(); @Override @SuppressWarnings("unchecked") void watchXdsResource(XdsResourceType type, String resourceName, ResourceWatcher watcher) { assertThat(type.typeName()).isEqualTo("CDS"); - assertThat(watchers).doesNotContainKey(resourceName); - watchers.put(resourceName, (ResourceWatcher)watcher); + watchers.computeIfAbsent(resourceName, k -> new ArrayList<>()) + .add((ResourceWatcher)watcher); } @Override @@ -669,25 +802,32 @@ public class CdsLoadBalancer2Test { ResourceWatcher watcher) { assertThat(type.typeName()).isEqualTo("CDS"); assertThat(watchers).containsKey(resourceName); - watchers.remove(resourceName); + List> watcherList = watchers.get(resourceName); + assertThat(watcherList.remove(watcher)).isTrue(); + if (watcherList.isEmpty()) { + watchers.remove(resourceName); + } } private void deliverCdsUpdate(String clusterName, CdsUpdate update) { if (watchers.containsKey(clusterName)) { - watchers.get(clusterName).onChanged(update); + List> resourceWatchers = + ImmutableList.copyOf(watchers.get(clusterName)); + resourceWatchers.forEach(w -> w.onChanged(update)); } } private void deliverResourceNotExist(String clusterName) { if (watchers.containsKey(clusterName)) { - watchers.get(clusterName).onResourceDoesNotExist(clusterName); + ImmutableList.copyOf(watchers.get(clusterName)) + .forEach(w -> w.onResourceDoesNotExist(clusterName)); } } private void deliverError(Status error) { - for (ResourceWatcher watcher : watchers.values()) { - watcher.onError(error); - } + watchers.values().stream() + .flatMap(List::stream) + .forEach(w -> w.onError(error)); } } }