From 1b792d1ccf9af3f705d1d5e4e85d28cd25f9a155 Mon Sep 17 00:00:00 2001 From: sanjaypujare Date: Mon, 4 May 2020 10:02:42 -0700 Subject: [PATCH] xds: create XdsServer wrapper for correct start and shutdown semantics (#6978) --- .../xds/XdsClientWrapperForServerSds.java | 89 ++++++++++---- .../internal/sds/SdsProtocolNegotiators.java | 27 ++--- .../xds/internal/sds/ServerWrapperForXds.java | 113 ++++++++++++++++++ .../xds/internal/sds/XdsServerBuilder.java | 43 +------ .../xds/XdsClientWrapperForServerSdsTest.java | 12 +- .../XdsClientWrapperForServerSdsTestMisc.java | 10 +- .../io/grpc/xds/XdsServerBuilderTest.java | 7 +- .../sds/SdsProtocolNegotiatorsTest.java | 9 -- 8 files changed, 214 insertions(+), 96 deletions(-) create mode 100644 xds/src/main/java/io/grpc/xds/internal/sds/ServerWrapperForXds.java diff --git a/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java b/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java index d51265d993..f8295896c9 100644 --- a/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java +++ b/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java @@ -23,6 +23,7 @@ import com.google.common.annotations.VisibleForTesting; import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext; import io.envoyproxy.envoy.api.v2.core.Node; import io.grpc.Internal; +import io.grpc.InternalLogId; import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.internal.ExponentialBackoffPolicy; @@ -64,12 +65,11 @@ public final class XdsClientWrapperForServerSds { new TimeServiceResource("GrpcServerXdsClient"); private EnvoyServerProtoData.Listener curListener; - // TODO(sanjaypujare): implement shutting down XdsServer which will need xdsClient reference @SuppressWarnings("unused") - @Nullable private final XdsClient xdsClient; + @Nullable private XdsClient xdsClient; private final int port; - private final ScheduledExecutorService timeService; - private final XdsClient.ListenerWatcher listenerWatcher; + private ScheduledExecutorService timeService; + private XdsClient.ListenerWatcher listenerWatcher; /** * Thrown when no suitable management server was found in the bootstrap file. @@ -84,41 +84,83 @@ public final class XdsClientWrapperForServerSds { } /** - * Factory method for creating a {@link XdsClientWrapperForServerSds}. + * Creates a {@link XdsClientWrapperForServerSds}. * * @param port server's port for which listener config is needed. - * @param bootstrapper {@link Bootstrapper} instance to load bootstrap config. - * @param syncContext {@link SynchronizationContext} needed by {@link XdsClient}. */ - public static XdsClientWrapperForServerSds newInstance( - int port, Bootstrapper bootstrapper, SynchronizationContext syncContext) - throws IOException, ManagementServerNotFoundException { - Bootstrapper.BootstrapInfo bootstrapInfo = bootstrapper.readBootstrap(); - final List serverList = bootstrapInfo.getServers(); - if (serverList.isEmpty()) { - throw new ManagementServerNotFoundException("No management server provided by bootstrap"); + public XdsClientWrapperForServerSds(int port) { + this.port = port; + } + + private SynchronizationContext createSynchronizationContext() { + final InternalLogId logId = + InternalLogId.allocate("XdsClientWrapperForServerSds", Integer.toString(port)); + return new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { + // needed by syncContext + private boolean panicMode; + + @Override + public void uncaughtException(Thread t, Throwable e) { + logger.log( + Level.SEVERE, + "[" + logId + "] Uncaught exception in the SynchronizationContext. Panic!", + e); + panic(e); + } + + void panic(final Throwable t) { + if (panicMode) { + // Preserve the first panic information + return; + } + panicMode = true; + shutdown(); + } + }); + } + + public boolean hasXdsClient() { + return xdsClient != null; + } + + /** Creates an XdsClient and starts a watch. */ + public void createXdsClientAndStart() { + checkState(xdsClient == null, "start() called more than once"); + Bootstrapper.BootstrapInfo bootstrapInfo; + List serverList; + try { + bootstrapInfo = Bootstrapper.getInstance().readBootstrap(); + serverList = bootstrapInfo.getServers(); + if (serverList.isEmpty()) { + throw new ManagementServerNotFoundException("No management server provided by bootstrap"); + } + } catch (IOException | ManagementServerNotFoundException e) { + logger.log(Level.FINE, "Exception reading bootstrap", e); + logger.log(Level.INFO, "Fallback to plaintext for server at port {0}", port); + return; } - final Node node = bootstrapInfo.getNode(); - ScheduledExecutorService timeService = SharedResourceHolder.get(timeServiceResource); + Node node = bootstrapInfo.getNode(); + timeService = SharedResourceHolder.get(timeServiceResource); XdsClientImpl xdsClientImpl = new XdsClientImpl( "", serverList, XdsClient.XdsChannelFactory.getInstance(), node, - syncContext, + createSynchronizationContext(), timeService, new ExponentialBackoffPolicy.Provider(), GrpcUtil.STOPWATCH_SUPPLIER); - return new XdsClientWrapperForServerSds(port, xdsClientImpl, timeService); + start(xdsClientImpl); } + /** Accepts an XdsClient and starts a watch. */ @VisibleForTesting - XdsClientWrapperForServerSds(int port, XdsClient xdsClient, - ScheduledExecutorService timeService) { - this.port = port; + public void start(XdsClient xdsClient) { + checkState(this.xdsClient == null, "start() called more than once"); + checkNotNull(xdsClient, "xdsClient"); this.xdsClient = xdsClient; - this.timeService = timeService; this.listenerWatcher = new XdsClient.ListenerWatcher() { @Override @@ -271,9 +313,10 @@ public final class XdsClientWrapperForServerSds { logger.log(Level.FINER, "Shutdown"); if (xdsClient != null) { xdsClient.shutdown(); + xdsClient = null; } if (timeService != null) { - timeServiceResource.close(timeService); + timeService = SharedResourceHolder.release(timeServiceResource, timeService); } } diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java index 255632d0ff..382a27ce35 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java @@ -21,7 +21,6 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext; import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext; -import io.grpc.SynchronizationContext; import io.grpc.netty.GrpcHttp2ConnectionHandler; import io.grpc.netty.InternalNettyChannelBuilder; import io.grpc.netty.InternalNettyChannelBuilder.ProtocolNegotiatorFactory; @@ -31,17 +30,14 @@ import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; import io.grpc.netty.InternalProtocolNegotiators; import io.grpc.netty.NettyChannelBuilder; import io.grpc.netty.ProtocolNegotiationEvent; -import io.grpc.xds.Bootstrapper; import io.grpc.xds.XdsAttributes; import io.grpc.xds.XdsClientWrapperForServerSds; -import io.grpc.xds.XdsClientWrapperForServerSds.ManagementServerNotFoundException; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.handler.ssl.SslContext; import io.netty.util.AsciiString; -import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.logging.Level; @@ -74,17 +70,8 @@ public final class SdsProtocolNegotiators { * * @param port the listening port passed to {@link XdsServerBuilder#forPort(int)}. */ - public static ProtocolNegotiator serverProtocolNegotiator( - int port, SynchronizationContext syncContext) { - XdsClientWrapperForServerSds xdsClientWrapperForServerSds; - try { - xdsClientWrapperForServerSds = XdsClientWrapperForServerSds.newInstance( - port, Bootstrapper.getInstance(), syncContext); - return new ServerSdsProtocolNegotiator(xdsClientWrapperForServerSds); - } catch (IOException | ManagementServerNotFoundException e) { - logger.log(Level.INFO, "Fallback to plaintext for server at port {0}", port); - return InternalProtocolNegotiators.serverPlaintext(); - } + public static ServerSdsProtocolNegotiator serverProtocolNegotiator(int port) { + return new ServerSdsProtocolNegotiator(new XdsClientWrapperForServerSds(port)); } private static final class ClientSdsProtocolNegotiatorFactory @@ -253,6 +240,10 @@ public final class SdsProtocolNegotiators { checkNotNull(xdsClientWrapperForServerSds, "xdsClientWrapperForServerSds"); } + XdsClientWrapperForServerSds getXdsClientWrapperForServerSds() { + return xdsClientWrapperForServerSds; + } + @Override public AsciiString scheme() { return SCHEME; @@ -264,11 +255,7 @@ public final class SdsProtocolNegotiators { } @Override - public void close() { - if (xdsClientWrapperForServerSds != null) { - xdsClientWrapperForServerSds.shutdown(); - } - } + public void close() {} } @VisibleForTesting diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/ServerWrapperForXds.java b/xds/src/main/java/io/grpc/xds/internal/sds/ServerWrapperForXds.java new file mode 100644 index 0000000000..09bd879305 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/sds/ServerWrapperForXds.java @@ -0,0 +1,113 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.sds; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import io.grpc.Server; +import io.grpc.ServerServiceDefinition; +import io.grpc.xds.XdsClientWrapperForServerSds; +import java.io.IOException; +import java.net.SocketAddress; +import java.util.List; +import java.util.concurrent.TimeUnit; + +/** + * Wraps a {@link Server} delegate and {@link XdsClientWrapperForServerSds} and intercepts {@link + * Server#shutdown()} and {@link Server#start()} to shut down and start the + * {@link XdsClientWrapperForServerSds} object. + */ +@VisibleForTesting +public final class ServerWrapperForXds extends Server { + private final Server delegate; + private final XdsClientWrapperForServerSds xdsClientWrapperForServerSds; + + ServerWrapperForXds(Server delegate, XdsClientWrapperForServerSds xdsClientWrapperForServerSds) { + this.delegate = checkNotNull(delegate, "delegate"); + this.xdsClientWrapperForServerSds = + checkNotNull(xdsClientWrapperForServerSds, "xdsClientWrapperForServerSds"); + } + + @Override + public Server start() throws IOException { + delegate.start(); + if (!xdsClientWrapperForServerSds.hasXdsClient()) { + xdsClientWrapperForServerSds.createXdsClientAndStart(); + } + return this; + } + + @Override + public Server shutdown() { + xdsClientWrapperForServerSds.shutdown(); + delegate.shutdown(); + return this; + } + + @Override + public Server shutdownNow() { + xdsClientWrapperForServerSds.shutdown(); + delegate.shutdownNow(); + return this; + } + + @Override + public boolean isShutdown() { + return delegate.isShutdown(); + } + + @Override + public boolean isTerminated() { + return delegate.isTerminated(); + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + return delegate.awaitTermination(timeout, unit); + } + + @Override + public void awaitTermination() throws InterruptedException { + delegate.awaitTermination(); + } + + @Override + public int getPort() { + return delegate.getPort(); + } + + @Override + public List getListenSockets() { + return delegate.getListenSockets(); + } + + @Override + public List getServices() { + return delegate.getServices(); + } + + @Override + public List getImmutableServices() { + return delegate.getImmutableServices(); + } + + @Override + public List getMutableServices() { + return delegate.getMutableServices(); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/XdsServerBuilder.java b/xds/src/main/java/io/grpc/xds/internal/sds/XdsServerBuilder.java index 5c33162cb6..d972027639 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/XdsServerBuilder.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/XdsServerBuilder.java @@ -21,22 +21,18 @@ import io.grpc.BindableService; import io.grpc.CompressorRegistry; import io.grpc.DecompressorRegistry; import io.grpc.HandlerRegistry; -import io.grpc.InternalLogId; import io.grpc.Server; import io.grpc.ServerBuilder; import io.grpc.ServerInterceptor; import io.grpc.ServerServiceDefinition; import io.grpc.ServerStreamTracer; import io.grpc.ServerTransportFilter; -import io.grpc.SynchronizationContext; -import io.grpc.netty.InternalProtocolNegotiator; import io.grpc.netty.NettyServerBuilder; +import io.grpc.xds.internal.sds.SdsProtocolNegotiators.ServerSdsProtocolNegotiator; import java.io.File; import java.net.InetSocketAddress; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; -import java.util.logging.Level; -import java.util.logging.Logger; import javax.annotation.Nullable; /** @@ -44,8 +40,6 @@ import javax.annotation.Nullable; * with peers. Note, this is not ready to use yet. */ public final class XdsServerBuilder extends ServerBuilder { - private static final Logger logger = - Logger.getLogger(XdsServerBuilder.class.getName()); private final NettyServerBuilder delegate; private final int port; @@ -135,33 +129,8 @@ public final class XdsServerBuilder extends ServerBuilder { @Override public Server build() { // note: doing it in build() will overwrite any previously set ProtocolNegotiator - final InternalLogId logId = InternalLogId.allocate("XdsServerBuilder", Integer.toString(port)); - SynchronizationContext syncContext = - new SynchronizationContext( - new Thread.UncaughtExceptionHandler() { - // needed by syncContext - private boolean panicMode; - - @Override - public void uncaughtException(Thread t, Throwable e) { - logger.log( - Level.SEVERE, - "[" + logId + "] Uncaught exception in the SynchronizationContext. Panic!", - e); - panic(e); - } - - void panic(final Throwable t) { - if (panicMode) { - // Preserve the first panic information - return; - } - panicMode = true; - } - }); - // TODO(sanjaypujare): move this to start() after creating an XdsServer wrapper - InternalProtocolNegotiator.ProtocolNegotiator serverProtocolNegotiator = - SdsProtocolNegotiators.serverProtocolNegotiator(port, syncContext); + ServerSdsProtocolNegotiator serverProtocolNegotiator = + SdsProtocolNegotiators.serverProtocolNegotiator(port); return buildServer(serverProtocolNegotiator); } @@ -170,9 +139,9 @@ public final class XdsServerBuilder extends ServerBuilder { * getXdsClientWrapperForServerSds from the serverSdsProtocolNegotiator. */ @VisibleForTesting - public Server buildServer( - InternalProtocolNegotiator.ProtocolNegotiator serverProtocolNegotiator) { + public ServerWrapperForXds buildServer(ServerSdsProtocolNegotiator serverProtocolNegotiator) { delegate.protocolNegotiator(serverProtocolNegotiator); - return delegate.build(); + return new ServerWrapperForXds( + delegate.build(), serverProtocolNegotiator.getXdsClientWrapperForServerSds()); } } diff --git a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTest.java b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTest.java index e0882bf94c..860532fa26 100644 --- a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTest.java @@ -33,6 +33,7 @@ import java.net.InetSocketAddress; import java.net.UnknownHostException; import java.util.ArrayList; import java.util.Arrays; +import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -145,7 +146,8 @@ public class XdsClientWrapperForServerSdsTest { int port, DownstreamTlsContext downstreamTlsContext) { XdsClient mockXdsClient = mock(XdsClient.class); XdsClientWrapperForServerSds xdsClientWrapperForServerSds = - new XdsClientWrapperForServerSds(port, mockXdsClient, null); + new XdsClientWrapperForServerSds(port); + xdsClientWrapperForServerSds.start(mockXdsClient); generateListenerUpdateToWatcher( port, downstreamTlsContext, xdsClientWrapperForServerSds.getListenerWatcher()); return xdsClientWrapperForServerSds; @@ -163,12 +165,18 @@ public class XdsClientWrapperForServerSdsTest { @Before public void setUp() throws IOException { MockitoAnnotations.initMocks(this); - xdsClientWrapperForServerSds = new XdsClientWrapperForServerSds(PORT, xdsClient, null); + xdsClientWrapperForServerSds = new XdsClientWrapperForServerSds(PORT); + xdsClientWrapperForServerSds.start(xdsClient); tlsContexts[0] = null; tlsContexts[1] = CommonTlsContextTestsUtil.buildTestDownstreamTlsContext("CERT1", "VA1"); tlsContexts[2] = CommonTlsContextTestsUtil.buildTestDownstreamTlsContext("CERT2", "VA2"); } + @After + public void tearDown() { + xdsClientWrapperForServerSds.shutdown(); + } + /** * Common method called by most tests. Creates 2 filterChains each with 2 addresses. First * filterChain's destPort is always PORT. diff --git a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java index 13c451d630..5c14267908 100644 --- a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java @@ -32,7 +32,7 @@ import java.net.InetSocketAddress; import java.net.SocketAddress; import java.net.UnknownHostException; import java.util.Collections; - +import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -55,7 +55,13 @@ public class XdsClientWrapperForServerSdsTestMisc { @Before public void setUp() throws IOException { MockitoAnnotations.initMocks(this); - xdsClientWrapperForServerSds = new XdsClientWrapperForServerSds(PORT, xdsClient, null); + xdsClientWrapperForServerSds = new XdsClientWrapperForServerSds(PORT); + xdsClientWrapperForServerSds.start(xdsClient); + } + + @After + public void tearDown() { + xdsClientWrapperForServerSds.shutdown(); } @Test diff --git a/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java b/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java index d2546b1d23..90d8fda1a4 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java @@ -20,8 +20,8 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import io.grpc.Server; import io.grpc.xds.internal.sds.SdsProtocolNegotiators.ServerSdsProtocolNegotiator; +import io.grpc.xds.internal.sds.ServerWrapperForXds; import io.grpc.xds.internal.sds.XdsServerBuilder; import java.io.IOException; import java.net.ServerSocket; @@ -42,10 +42,11 @@ public class XdsServerBuilderTest { XdsServerBuilder builder = XdsServerBuilder.forPort(port); XdsClient mockXdsClient = mock(XdsClient.class); XdsClientWrapperForServerSds xdsClientWrapperForServerSds = - new XdsClientWrapperForServerSds(port, mockXdsClient, null); + new XdsClientWrapperForServerSds(port); + xdsClientWrapperForServerSds.start(mockXdsClient); ServerSdsProtocolNegotiator serverSdsProtocolNegotiator = new ServerSdsProtocolNegotiator(xdsClientWrapperForServerSds); - Server xdsServer = builder.buildServer(serverSdsProtocolNegotiator); + ServerWrapperForXds xdsServer = builder.buildServer(serverSdsProtocolNegotiator); xdsServer.start(); xdsServer.shutdown(); xdsServer.awaitTermination(500L, TimeUnit.MILLISECONDS); diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java index 9d49528821..8c53cd287a 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java @@ -38,7 +38,6 @@ import io.grpc.Attributes; import io.grpc.internal.testing.TestUtils; import io.grpc.netty.GrpcHttp2ConnectionHandler; import io.grpc.netty.InternalProtocolNegotiationEvent; -import io.grpc.netty.InternalProtocolNegotiator; import io.grpc.xds.XdsAttributes; import io.grpc.xds.XdsClientWrapperForServerSds; import io.grpc.xds.XdsClientWrapperForServerSdsTest; @@ -284,14 +283,6 @@ public class SdsProtocolNegotiatorsTest { assertTrue(channel.isOpen()); } - @Test - public void serverSdsProtocolNegotiator_nullSyncContext_expectPlaintext() { - InternalProtocolNegotiator.ProtocolNegotiator protocolNegotiator = - SdsProtocolNegotiators.serverProtocolNegotiator(/* port= */ 7000, /* syncContext= */ null); - assertThat(protocolNegotiator.getClass().getSimpleName()) - .isEqualTo("ServerPlaintextNegotiator"); - } - private static final class FakeGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler { FakeGrpcHttp2ConnectionHandler(