xds: create XdsServer wrapper for correct start and shutdown semantics (#6978)

This commit is contained in:
sanjaypujare 2020-05-04 10:02:42 -07:00 committed by GitHub
parent ce9d217920
commit 1b792d1ccf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 214 additions and 96 deletions

View File

@ -23,6 +23,7 @@ import com.google.common.annotations.VisibleForTesting;
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
import io.envoyproxy.envoy.api.v2.core.Node;
import io.grpc.Internal;
import io.grpc.InternalLogId;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
import io.grpc.internal.ExponentialBackoffPolicy;
@ -64,12 +65,11 @@ public final class XdsClientWrapperForServerSds {
new TimeServiceResource("GrpcServerXdsClient");
private EnvoyServerProtoData.Listener curListener;
// TODO(sanjaypujare): implement shutting down XdsServer which will need xdsClient reference
@SuppressWarnings("unused")
@Nullable private final XdsClient xdsClient;
@Nullable private XdsClient xdsClient;
private final int port;
private final ScheduledExecutorService timeService;
private final XdsClient.ListenerWatcher listenerWatcher;
private ScheduledExecutorService timeService;
private XdsClient.ListenerWatcher listenerWatcher;
/**
* Thrown when no suitable management server was found in the bootstrap file.
@ -84,41 +84,83 @@ public final class XdsClientWrapperForServerSds {
}
/**
* Factory method for creating a {@link XdsClientWrapperForServerSds}.
* Creates a {@link XdsClientWrapperForServerSds}.
*
* @param port server's port for which listener config is needed.
* @param bootstrapper {@link Bootstrapper} instance to load bootstrap config.
* @param syncContext {@link SynchronizationContext} needed by {@link XdsClient}.
*/
public static XdsClientWrapperForServerSds newInstance(
int port, Bootstrapper bootstrapper, SynchronizationContext syncContext)
throws IOException, ManagementServerNotFoundException {
Bootstrapper.BootstrapInfo bootstrapInfo = bootstrapper.readBootstrap();
final List<Bootstrapper.ServerInfo> serverList = bootstrapInfo.getServers();
if (serverList.isEmpty()) {
throw new ManagementServerNotFoundException("No management server provided by bootstrap");
public XdsClientWrapperForServerSds(int port) {
this.port = port;
}
private SynchronizationContext createSynchronizationContext() {
final InternalLogId logId =
InternalLogId.allocate("XdsClientWrapperForServerSds", Integer.toString(port));
return new SynchronizationContext(
new Thread.UncaughtExceptionHandler() {
// needed by syncContext
private boolean panicMode;
@Override
public void uncaughtException(Thread t, Throwable e) {
logger.log(
Level.SEVERE,
"[" + logId + "] Uncaught exception in the SynchronizationContext. Panic!",
e);
panic(e);
}
void panic(final Throwable t) {
if (panicMode) {
// Preserve the first panic information
return;
}
panicMode = true;
shutdown();
}
});
}
public boolean hasXdsClient() {
return xdsClient != null;
}
/** Creates an XdsClient and starts a watch. */
public void createXdsClientAndStart() {
checkState(xdsClient == null, "start() called more than once");
Bootstrapper.BootstrapInfo bootstrapInfo;
List<Bootstrapper.ServerInfo> serverList;
try {
bootstrapInfo = Bootstrapper.getInstance().readBootstrap();
serverList = bootstrapInfo.getServers();
if (serverList.isEmpty()) {
throw new ManagementServerNotFoundException("No management server provided by bootstrap");
}
} catch (IOException | ManagementServerNotFoundException e) {
logger.log(Level.FINE, "Exception reading bootstrap", e);
logger.log(Level.INFO, "Fallback to plaintext for server at port {0}", port);
return;
}
final Node node = bootstrapInfo.getNode();
ScheduledExecutorService timeService = SharedResourceHolder.get(timeServiceResource);
Node node = bootstrapInfo.getNode();
timeService = SharedResourceHolder.get(timeServiceResource);
XdsClientImpl xdsClientImpl =
new XdsClientImpl(
"",
serverList,
XdsClient.XdsChannelFactory.getInstance(),
node,
syncContext,
createSynchronizationContext(),
timeService,
new ExponentialBackoffPolicy.Provider(),
GrpcUtil.STOPWATCH_SUPPLIER);
return new XdsClientWrapperForServerSds(port, xdsClientImpl, timeService);
start(xdsClientImpl);
}
/** Accepts an XdsClient and starts a watch. */
@VisibleForTesting
XdsClientWrapperForServerSds(int port, XdsClient xdsClient,
ScheduledExecutorService timeService) {
this.port = port;
public void start(XdsClient xdsClient) {
checkState(this.xdsClient == null, "start() called more than once");
checkNotNull(xdsClient, "xdsClient");
this.xdsClient = xdsClient;
this.timeService = timeService;
this.listenerWatcher =
new XdsClient.ListenerWatcher() {
@Override
@ -271,9 +313,10 @@ public final class XdsClientWrapperForServerSds {
logger.log(Level.FINER, "Shutdown");
if (xdsClient != null) {
xdsClient.shutdown();
xdsClient = null;
}
if (timeService != null) {
timeServiceResource.close(timeService);
timeService = SharedResourceHolder.release(timeServiceResource, timeService);
}
}

View File

@ -21,7 +21,6 @@ import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting;
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
import io.grpc.SynchronizationContext;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.InternalNettyChannelBuilder;
import io.grpc.netty.InternalNettyChannelBuilder.ProtocolNegotiatorFactory;
@ -31,17 +30,14 @@ import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
import io.grpc.netty.InternalProtocolNegotiators;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.ProtocolNegotiationEvent;
import io.grpc.xds.Bootstrapper;
import io.grpc.xds.XdsAttributes;
import io.grpc.xds.XdsClientWrapperForServerSds;
import io.grpc.xds.XdsClientWrapperForServerSds.ManagementServerNotFoundException;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.ssl.SslContext;
import io.netty.util.AsciiString;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Level;
@ -74,17 +70,8 @@ public final class SdsProtocolNegotiators {
*
* @param port the listening port passed to {@link XdsServerBuilder#forPort(int)}.
*/
public static ProtocolNegotiator serverProtocolNegotiator(
int port, SynchronizationContext syncContext) {
XdsClientWrapperForServerSds xdsClientWrapperForServerSds;
try {
xdsClientWrapperForServerSds = XdsClientWrapperForServerSds.newInstance(
port, Bootstrapper.getInstance(), syncContext);
return new ServerSdsProtocolNegotiator(xdsClientWrapperForServerSds);
} catch (IOException | ManagementServerNotFoundException e) {
logger.log(Level.INFO, "Fallback to plaintext for server at port {0}", port);
return InternalProtocolNegotiators.serverPlaintext();
}
public static ServerSdsProtocolNegotiator serverProtocolNegotiator(int port) {
return new ServerSdsProtocolNegotiator(new XdsClientWrapperForServerSds(port));
}
private static final class ClientSdsProtocolNegotiatorFactory
@ -253,6 +240,10 @@ public final class SdsProtocolNegotiators {
checkNotNull(xdsClientWrapperForServerSds, "xdsClientWrapperForServerSds");
}
XdsClientWrapperForServerSds getXdsClientWrapperForServerSds() {
return xdsClientWrapperForServerSds;
}
@Override
public AsciiString scheme() {
return SCHEME;
@ -264,11 +255,7 @@ public final class SdsProtocolNegotiators {
}
@Override
public void close() {
if (xdsClientWrapperForServerSds != null) {
xdsClientWrapperForServerSds.shutdown();
}
}
public void close() {}
}
@VisibleForTesting

View File

@ -0,0 +1,113 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.sds;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.Server;
import io.grpc.ServerServiceDefinition;
import io.grpc.xds.XdsClientWrapperForServerSds;
import java.io.IOException;
import java.net.SocketAddress;
import java.util.List;
import java.util.concurrent.TimeUnit;
/**
* Wraps a {@link Server} delegate and {@link XdsClientWrapperForServerSds} and intercepts {@link
* Server#shutdown()} and {@link Server#start()} to shut down and start the
* {@link XdsClientWrapperForServerSds} object.
*/
@VisibleForTesting
public final class ServerWrapperForXds extends Server {
private final Server delegate;
private final XdsClientWrapperForServerSds xdsClientWrapperForServerSds;
ServerWrapperForXds(Server delegate, XdsClientWrapperForServerSds xdsClientWrapperForServerSds) {
this.delegate = checkNotNull(delegate, "delegate");
this.xdsClientWrapperForServerSds =
checkNotNull(xdsClientWrapperForServerSds, "xdsClientWrapperForServerSds");
}
@Override
public Server start() throws IOException {
delegate.start();
if (!xdsClientWrapperForServerSds.hasXdsClient()) {
xdsClientWrapperForServerSds.createXdsClientAndStart();
}
return this;
}
@Override
public Server shutdown() {
xdsClientWrapperForServerSds.shutdown();
delegate.shutdown();
return this;
}
@Override
public Server shutdownNow() {
xdsClientWrapperForServerSds.shutdown();
delegate.shutdownNow();
return this;
}
@Override
public boolean isShutdown() {
return delegate.isShutdown();
}
@Override
public boolean isTerminated() {
return delegate.isTerminated();
}
@Override
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
return delegate.awaitTermination(timeout, unit);
}
@Override
public void awaitTermination() throws InterruptedException {
delegate.awaitTermination();
}
@Override
public int getPort() {
return delegate.getPort();
}
@Override
public List<? extends SocketAddress> getListenSockets() {
return delegate.getListenSockets();
}
@Override
public List<ServerServiceDefinition> getServices() {
return delegate.getServices();
}
@Override
public List<ServerServiceDefinition> getImmutableServices() {
return delegate.getImmutableServices();
}
@Override
public List<ServerServiceDefinition> getMutableServices() {
return delegate.getMutableServices();
}
}

View File

@ -21,22 +21,18 @@ import io.grpc.BindableService;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.HandlerRegistry;
import io.grpc.InternalLogId;
import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.ServerInterceptor;
import io.grpc.ServerServiceDefinition;
import io.grpc.ServerStreamTracer;
import io.grpc.ServerTransportFilter;
import io.grpc.SynchronizationContext;
import io.grpc.netty.InternalProtocolNegotiator;
import io.grpc.netty.NettyServerBuilder;
import io.grpc.xds.internal.sds.SdsProtocolNegotiators.ServerSdsProtocolNegotiator;
import java.io.File;
import java.net.InetSocketAddress;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;
/**
@ -44,8 +40,6 @@ import javax.annotation.Nullable;
* with peers. Note, this is not ready to use yet.
*/
public final class XdsServerBuilder extends ServerBuilder<XdsServerBuilder> {
private static final Logger logger =
Logger.getLogger(XdsServerBuilder.class.getName());
private final NettyServerBuilder delegate;
private final int port;
@ -135,33 +129,8 @@ public final class XdsServerBuilder extends ServerBuilder<XdsServerBuilder> {
@Override
public Server build() {
// note: doing it in build() will overwrite any previously set ProtocolNegotiator
final InternalLogId logId = InternalLogId.allocate("XdsServerBuilder", Integer.toString(port));
SynchronizationContext syncContext =
new SynchronizationContext(
new Thread.UncaughtExceptionHandler() {
// needed by syncContext
private boolean panicMode;
@Override
public void uncaughtException(Thread t, Throwable e) {
logger.log(
Level.SEVERE,
"[" + logId + "] Uncaught exception in the SynchronizationContext. Panic!",
e);
panic(e);
}
void panic(final Throwable t) {
if (panicMode) {
// Preserve the first panic information
return;
}
panicMode = true;
}
});
// TODO(sanjaypujare): move this to start() after creating an XdsServer wrapper
InternalProtocolNegotiator.ProtocolNegotiator serverProtocolNegotiator =
SdsProtocolNegotiators.serverProtocolNegotiator(port, syncContext);
ServerSdsProtocolNegotiator serverProtocolNegotiator =
SdsProtocolNegotiators.serverProtocolNegotiator(port);
return buildServer(serverProtocolNegotiator);
}
@ -170,9 +139,9 @@ public final class XdsServerBuilder extends ServerBuilder<XdsServerBuilder> {
* getXdsClientWrapperForServerSds from the serverSdsProtocolNegotiator.
*/
@VisibleForTesting
public Server buildServer(
InternalProtocolNegotiator.ProtocolNegotiator serverProtocolNegotiator) {
public ServerWrapperForXds buildServer(ServerSdsProtocolNegotiator serverProtocolNegotiator) {
delegate.protocolNegotiator(serverProtocolNegotiator);
return delegate.build();
return new ServerWrapperForXds(
delegate.build(), serverProtocolNegotiator.getXdsClientWrapperForServerSds());
}
}

View File

@ -33,6 +33,7 @@ import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Arrays;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -145,7 +146,8 @@ public class XdsClientWrapperForServerSdsTest {
int port, DownstreamTlsContext downstreamTlsContext) {
XdsClient mockXdsClient = mock(XdsClient.class);
XdsClientWrapperForServerSds xdsClientWrapperForServerSds =
new XdsClientWrapperForServerSds(port, mockXdsClient, null);
new XdsClientWrapperForServerSds(port);
xdsClientWrapperForServerSds.start(mockXdsClient);
generateListenerUpdateToWatcher(
port, downstreamTlsContext, xdsClientWrapperForServerSds.getListenerWatcher());
return xdsClientWrapperForServerSds;
@ -163,12 +165,18 @@ public class XdsClientWrapperForServerSdsTest {
@Before
public void setUp() throws IOException {
MockitoAnnotations.initMocks(this);
xdsClientWrapperForServerSds = new XdsClientWrapperForServerSds(PORT, xdsClient, null);
xdsClientWrapperForServerSds = new XdsClientWrapperForServerSds(PORT);
xdsClientWrapperForServerSds.start(xdsClient);
tlsContexts[0] = null;
tlsContexts[1] = CommonTlsContextTestsUtil.buildTestDownstreamTlsContext("CERT1", "VA1");
tlsContexts[2] = CommonTlsContextTestsUtil.buildTestDownstreamTlsContext("CERT2", "VA2");
}
@After
public void tearDown() {
xdsClientWrapperForServerSds.shutdown();
}
/**
* Common method called by most tests. Creates 2 filterChains each with 2 addresses. First
* filterChain's destPort is always PORT.

View File

@ -32,7 +32,7 @@ import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.UnknownHostException;
import java.util.Collections;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -55,7 +55,13 @@ public class XdsClientWrapperForServerSdsTestMisc {
@Before
public void setUp() throws IOException {
MockitoAnnotations.initMocks(this);
xdsClientWrapperForServerSds = new XdsClientWrapperForServerSds(PORT, xdsClient, null);
xdsClientWrapperForServerSds = new XdsClientWrapperForServerSds(PORT);
xdsClientWrapperForServerSds.start(xdsClient);
}
@After
public void tearDown() {
xdsClientWrapperForServerSds.shutdown();
}
@Test

View File

@ -20,8 +20,8 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import io.grpc.Server;
import io.grpc.xds.internal.sds.SdsProtocolNegotiators.ServerSdsProtocolNegotiator;
import io.grpc.xds.internal.sds.ServerWrapperForXds;
import io.grpc.xds.internal.sds.XdsServerBuilder;
import java.io.IOException;
import java.net.ServerSocket;
@ -42,10 +42,11 @@ public class XdsServerBuilderTest {
XdsServerBuilder builder = XdsServerBuilder.forPort(port);
XdsClient mockXdsClient = mock(XdsClient.class);
XdsClientWrapperForServerSds xdsClientWrapperForServerSds =
new XdsClientWrapperForServerSds(port, mockXdsClient, null);
new XdsClientWrapperForServerSds(port);
xdsClientWrapperForServerSds.start(mockXdsClient);
ServerSdsProtocolNegotiator serverSdsProtocolNegotiator =
new ServerSdsProtocolNegotiator(xdsClientWrapperForServerSds);
Server xdsServer = builder.buildServer(serverSdsProtocolNegotiator);
ServerWrapperForXds xdsServer = builder.buildServer(serverSdsProtocolNegotiator);
xdsServer.start();
xdsServer.shutdown();
xdsServer.awaitTermination(500L, TimeUnit.MILLISECONDS);

View File

@ -38,7 +38,6 @@ import io.grpc.Attributes;
import io.grpc.internal.testing.TestUtils;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.InternalProtocolNegotiationEvent;
import io.grpc.netty.InternalProtocolNegotiator;
import io.grpc.xds.XdsAttributes;
import io.grpc.xds.XdsClientWrapperForServerSds;
import io.grpc.xds.XdsClientWrapperForServerSdsTest;
@ -284,14 +283,6 @@ public class SdsProtocolNegotiatorsTest {
assertTrue(channel.isOpen());
}
@Test
public void serverSdsProtocolNegotiator_nullSyncContext_expectPlaintext() {
InternalProtocolNegotiator.ProtocolNegotiator protocolNegotiator =
SdsProtocolNegotiators.serverProtocolNegotiator(/* port= */ 7000, /* syncContext= */ null);
assertThat(protocolNegotiator.getClass().getSimpleName())
.isEqualTo("ServerPlaintextNegotiator");
}
private static final class FakeGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler {
FakeGrpcHttp2ConnectionHandler(