Allow Netty server override for ProtocolNegotiator

Server implementation had to be refactored to use a ProtocolNegotiator to make this work.
This commit is contained in:
nmittler 2015-11-02 12:14:39 -08:00
parent 7767138c96
commit 8c7e251f41
8 changed files with 128 additions and 63 deletions

View File

@ -39,6 +39,7 @@ import com.google.common.base.Preconditions;
import io.grpc.Attributes;
import io.grpc.ExperimentalApi;
import io.grpc.Internal;
import io.grpc.NameResolver;
import io.grpc.internal.AbstractManagedChannelImplBuilder;
import io.grpc.internal.AbstractReferenceCounted;
@ -140,6 +141,7 @@ public class NettyChannelBuilder extends AbstractManagedChannelImplBuilder<Netty
*
* <p>Default: {@code null}.
*/
@Internal
public final NettyChannelBuilder protocolNegotiator(
@Nullable ProtocolNegotiator protocolNegotiator) {
this.protocolNegotiator = protocolNegotiator;

View File

@ -38,7 +38,6 @@ import static io.netty.channel.ChannelOption.SO_KEEPALIVE;
import io.grpc.internal.Server;
import io.grpc.internal.ServerListener;
import io.grpc.internal.SharedResourceHolder;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
@ -47,7 +46,6 @@ import io.netty.channel.ChannelInitializer;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.ServerChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.ssl.SslContext;
import io.netty.util.AbstractReferenceCounted;
import io.netty.util.ReferenceCounted;
@ -66,7 +64,7 @@ public class NettyServer implements Server {
private final SocketAddress address;
private final Class<? extends ServerChannel> channelType;
private final SslContext sslContext;
private final ProtocolNegotiator protocolNegotiator;
private final int maxStreamsPerConnection;
private final boolean usingSharedBossGroup;
private final boolean usingSharedWorkerGroup;
@ -81,13 +79,13 @@ public class NettyServer implements Server {
NettyServer(SocketAddress address, Class<? extends ServerChannel> channelType,
@Nullable EventLoopGroup bossGroup, @Nullable EventLoopGroup workerGroup,
@Nullable SslContext sslContext, int maxStreamsPerConnection, int flowControlWindow,
int maxMessageSize, int maxHeaderListSize) {
ProtocolNegotiator protocolNegotiator, int maxStreamsPerConnection,
int flowControlWindow, int maxMessageSize, int maxHeaderListSize) {
this.address = address;
this.channelType = checkNotNull(channelType, "channelType");
this.bossGroup = bossGroup;
this.workerGroup = workerGroup;
this.sslContext = sslContext;
this.protocolNegotiator = checkNotNull(protocolNegotiator, "protocolNegotiator");
this.usingSharedBossGroup = bossGroup == null;
this.usingSharedWorkerGroup = workerGroup == null;
this.maxStreamsPerConnection = maxStreamsPerConnection;
@ -120,8 +118,8 @@ public class NettyServer implements Server {
}
});
NettyServerTransport transport
= new NettyServerTransport(ch, sslContext, maxStreamsPerConnection, flowControlWindow,
maxMessageSize, maxHeaderListSize);
= new NettyServerTransport(ch, protocolNegotiator, maxStreamsPerConnection,
flowControlWindow, maxMessageSize, maxHeaderListSize);
transport.start(listener.transportCreated(transport));
}
});

View File

@ -38,6 +38,7 @@ import com.google.common.base.Preconditions;
import io.grpc.ExperimentalApi;
import io.grpc.HandlerRegistry;
import io.grpc.Internal;
import io.grpc.internal.AbstractServerImplBuilder;
import io.grpc.internal.GrpcUtil;
import io.netty.channel.EventLoopGroup;
@ -66,6 +67,7 @@ public final class NettyServerBuilder extends AbstractServerImplBuilder<NettySer
@Nullable
private EventLoopGroup workerEventLoopGroup;
private SslContext sslContext;
private ProtocolNegotiator protocolNegotiator;
private int maxConcurrentCallsPerConnection = Integer.MAX_VALUE;
private int flowControlWindow = DEFAULT_FLOW_CONTROL_WINDOW;
private int maxMessageSize = DEFAULT_MAX_MESSAGE_SIZE;
@ -179,6 +181,19 @@ public final class NettyServerBuilder extends AbstractServerImplBuilder<NettySer
return this;
}
/**
* Sets the {@link ProtocolNegotiator} to be used. If non-{@code null}, overrides the value
* specified in {@link #sslContext(SslContext)}.
*
* <p>Default: {@code null}.
*/
@Internal
public final NettyServerBuilder protocolNegotiator(
@Nullable ProtocolNegotiator protocolNegotiator) {
this.protocolNegotiator = protocolNegotiator;
return this;
}
/**
* The maximum number of concurrent calls permitted for each incoming connection. Defaults to no
* limit.
@ -221,8 +236,13 @@ public final class NettyServerBuilder extends AbstractServerImplBuilder<NettySer
@Override
protected NettyServer buildTransportServer() {
ProtocolNegotiator negotiator = protocolNegotiator;
if (negotiator == null) {
negotiator = sslContext != null ? ProtocolNegotiators.serverTls(sslContext) :
ProtocolNegotiators.serverPlaintext();
}
return new NettyServer(address, channelType, bossEventLoopGroup,
workerEventLoopGroup, sslContext, maxConcurrentCallsPerConnection, flowControlWindow,
workerEventLoopGroup, negotiator, maxConcurrentCallsPerConnection, flowControlWindow,
maxMessageSize, maxHeaderListSize);
}

View File

@ -52,14 +52,10 @@ import io.netty.handler.codec.http2.Http2HeadersDecoder;
import io.netty.handler.codec.http2.Http2InboundFrameLogger;
import io.netty.handler.codec.http2.Http2OutboundFrameLogger;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.ssl.SslContext;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;
import javax.net.ssl.SSLEngine;
/**
* The Netty-based server transport.
*/
@ -67,7 +63,7 @@ class NettyServerTransport implements ServerTransport {
private static final Logger log = Logger.getLogger(NettyServerTransport.class.getName());
private final Channel channel;
private final SslContext sslContext;
private final ProtocolNegotiator protocolNegotiator;
private final int maxStreams;
private ServerTransportListener listener;
private boolean terminated;
@ -75,10 +71,10 @@ class NettyServerTransport implements ServerTransport {
private final int maxMessageSize;
private final int maxHeaderListSize;
NettyServerTransport(Channel channel, @Nullable SslContext sslContext, int maxStreams,
NettyServerTransport(Channel channel, ProtocolNegotiator protocolNegotiator, int maxStreams,
int flowControlWindow, int maxMessageSize, int maxHeaderListSize) {
this.channel = Preconditions.checkNotNull(channel, "channel");
this.sslContext = sslContext;
this.protocolNegotiator = Preconditions.checkNotNull(protocolNegotiator, "protocolNegotiator");
this.maxStreams = maxStreams;
this.flowControlWindow = flowControlWindow;
this.maxMessageSize = maxMessageSize;
@ -100,12 +96,8 @@ class NettyServerTransport implements ServerTransport {
}
});
ChannelHandler handler = grpcHandler;
if (sslContext != null) {
SSLEngine sslEngine = sslContext.newEngine(channel.alloc());
handler = ProtocolNegotiators.serverTls(sslEngine, grpcHandler);
}
channel.pipeline().addLast(handler);
ChannelHandler negotiationHandler = protocolNegotiator.newHandler(grpcHandler);
channel.pipeline().addLast(negotiationHandler);
}
@Override

View File

@ -31,6 +31,7 @@
package io.grpc.netty;
import io.grpc.Internal;
import io.netty.channel.ChannelHandler;
import io.netty.handler.codec.http2.Http2ConnectionHandler;
import io.netty.util.ByteString;
@ -38,6 +39,7 @@ import io.netty.util.ByteString;
/**
* A class that provides a Netty handler to control protocol negotiation.
*/
@Internal
public interface ProtocolNegotiator {
/**

View File

@ -37,6 +37,7 @@ import static io.grpc.netty.GrpcSslContexts.HTTP2_VERSIONS;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.grpc.Internal;
import io.grpc.Status;
import io.grpc.internal.GrpcUtil;
import io.netty.channel.ChannelDuplexHandler;
@ -44,6 +45,7 @@ import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.HttpClientCodec;
@ -73,6 +75,7 @@ import javax.net.ssl.SSLParameters;
/**
* Common {@link ProtocolNegotiator}s used by gRPC.
*/
@Internal
public final class ProtocolNegotiators {
private static final Logger log = Logger.getLogger(ProtocolNegotiators.class.getName());
@ -80,28 +83,69 @@ public final class ProtocolNegotiators {
}
/**
* Create a TLS handler for HTTP/2 capable of using ALPN/NPN.
* Create a server plaintext handler for gRPC.
*/
public static ChannelHandler serverTls(SSLEngine sslEngine, ChannelHandler grpcHandler) {
Preconditions.checkNotNull(sslEngine, "sslEngine");
public static ProtocolNegotiator serverPlaintext() {
return new ProtocolNegotiator() {
@Override
public Handler newHandler(final Http2ConnectionHandler handler) {
return new Handler() {
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
// Just replace this handler with the gRPC handler.
ctx.pipeline().replace(this, null, handler);
}
return new TlsChannelInboundHandlerAdapter(new SslHandler(sslEngine, false), grpcHandler);
@Override
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
// Don't care.
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
// Should never happen.
ctx.fireExceptionCaught(cause);
}
@Override
public ByteString scheme() {
return Utils.HTTP;
}
};
}
};
}
/**
* Create a server TLS handler for HTTP/2 capable of using ALPN/NPN.
*/
public static ProtocolNegotiator serverTls(final SslContext sslContext) {
Preconditions.checkNotNull(sslContext, "sslContext");
return new ProtocolNegotiator() {
@Override
public Handler newHandler(Http2ConnectionHandler handler) {
return new ServerTlsHandler(sslContext, handler);
}
};
}
@VisibleForTesting
static final class TlsChannelInboundHandlerAdapter extends ChannelInboundHandlerAdapter {
static final class ServerTlsHandler extends ChannelInboundHandlerAdapter
implements ProtocolNegotiator.Handler {
private final ChannelHandler grpcHandler;
private final SslHandler sslHandler;
private final SslContext sslContext;
TlsChannelInboundHandlerAdapter(SslHandler sslHandler, ChannelHandler grpcHandler) {
this.sslHandler = sslHandler;
ServerTlsHandler(SslContext sslContext, ChannelHandler grpcHandler) {
this.sslContext = sslContext;
this.grpcHandler = grpcHandler;
}
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
super.handlerAdded(ctx);
ctx.pipeline().addFirst(sslHandler);
SSLEngine sslEngine = sslContext.newEngine(ctx.alloc());
ctx.pipeline().addFirst(new SslHandler(sslEngine, false));
}
@Override
@ -114,7 +158,7 @@ public final class ProtocolNegotiators {
if (evt instanceof SslHandshakeCompletionEvent) {
SslHandshakeCompletionEvent handshakeEvent = (SslHandshakeCompletionEvent) evt;
if (handshakeEvent.isSuccess()) {
if (HTTP2_VERSIONS.contains(sslHandler(ctx).applicationProtocol())) {
if (HTTP2_VERSIONS.contains(sslHandler(ctx.pipeline()).applicationProtocol())) {
// Successfully negotiated the protocol. Replace this handler with
// the GRPC handler.
ctx.pipeline().replace(this, null, grpcHandler);
@ -129,13 +173,18 @@ public final class ProtocolNegotiators {
super.userEventTriggered(ctx, evt);
}
private SslHandler sslHandler(ChannelPipeline pipeline) {
return pipeline.get(SslHandler.class);
}
private void fail(ChannelHandlerContext ctx, Throwable exception) {
logSslEngineDetails(Level.FINE, ctx, "TLS negotiation failed for new client.", exception);
ctx.close();
}
private SslHandler sslHandler(ChannelHandlerContext ctx) {
return ctx.pipeline().get(SslHandler.class);
@Override
public ByteString scheme() {
return Utils.HTTPS;
}
}

View File

@ -310,8 +310,9 @@ public class NettyClientTransportTest {
File key = TestUtils.loadCert("server1.key");
SslContext serverContext = GrpcSslContexts.forServer(serverCert, key)
.ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE).build();
ProtocolNegotiator negotiator = ProtocolNegotiators.serverTls(serverContext);
server = new NettyServer(address, NioServerSocketChannel.class,
group, group, serverContext, maxStreamsPerConnection,
group, group, negotiator, maxStreamsPerConnection,
DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, maxHeaderListSize);
server.start(serverListener);
}

View File

@ -38,18 +38,17 @@ import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import com.google.common.collect.Iterables;
import io.grpc.netty.ProtocolNegotiators.TlsChannelInboundHandlerAdapter;
import io.grpc.netty.ProtocolNegotiators.ServerTlsHandler;
import io.grpc.netty.ProtocolNegotiators.TlsNegotiator;
import io.grpc.testing.TestUtils;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import io.netty.handler.ssl.SupportedCipherSuiteFilter;
import org.junit.Before;
import org.junit.Rule;
@ -58,6 +57,7 @@ import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import java.io.File;
import java.util.logging.Filter;
import java.util.logging.Level;
import java.util.logging.LogRecord;
@ -75,20 +75,17 @@ public class ProtocolNegotiatorsTest {
private EmbeddedChannel channel = new EmbeddedChannel();
private ChannelPipeline pipeline = channel.pipeline();
private SslHandler sslHandler;
private SslContext sslContext;
private SSLEngine engine;
private ChannelHandlerContext channelHandlerCtx;
@Before
public void setUp() throws Exception {
File serverCert = TestUtils.loadCert("server1.pem");
File key = TestUtils.loadCert("server1.key");
sslContext = GrpcSslContexts.forServer(serverCert, key)
.ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE).build();
engine = SSLContext.getDefault().createSSLEngine();
sslHandler = new SslHandler(engine, false) {
@Override
public String applicationProtocol() {
// Just get any of them.
return Iterables.getFirst(GrpcSslContexts.HTTP2_VERSIONS, "");
}
};
}
@Test
@ -96,13 +93,12 @@ public class ProtocolNegotiatorsTest {
thrown.expect(NullPointerException.class);
thrown.expectMessage("ssl");
ProtocolNegotiators.serverTls(null, null);
ProtocolNegotiators.serverTls(null);
}
@Test
public void tlsAdapter_exceptionClosesChannel() throws Exception {
ChannelInboundHandlerAdapter handler =
new TlsChannelInboundHandlerAdapter(sslHandler, grpcHandler);
ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler);
// Use addFirst due to the funny error handling in EmbeddedChannel.
pipeline.addFirst(handler);
@ -114,18 +110,16 @@ public class ProtocolNegotiatorsTest {
@Test
public void tlsHandler_handlerAddedAddsSslHandler() throws Exception {
ChannelInboundHandlerAdapter handler =
new TlsChannelInboundHandlerAdapter(sslHandler, grpcHandler);
ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler);
pipeline.addLast(handler);
assertEquals(sslHandler, pipeline.first());
assertTrue(pipeline.first() instanceof SslHandler);
}
@Test
public void tlsHandler_userEventTriggeredNonSslEvent() throws Exception {
ChannelInboundHandlerAdapter handler =
new TlsChannelInboundHandlerAdapter(sslHandler, grpcHandler);
ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler);
pipeline.addLast(handler);
channelHandlerCtx = pipeline.context(handler);
Object nonSslEvent = new Object();
@ -146,9 +140,10 @@ public class ProtocolNegotiatorsTest {
}
};
ChannelInboundHandlerAdapter handler =
new TlsChannelInboundHandlerAdapter(badSslHandler, grpcHandler);
ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler);
pipeline.addLast(handler);
pipeline.replace(SslHandler.class, null, badSslHandler);
channelHandlerCtx = pipeline.context(handler);
Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
@ -162,8 +157,7 @@ public class ProtocolNegotiatorsTest {
@Test
public void tlsHandler_userEventTriggeredSslEvent_handshakeFailure() throws Exception {
ChannelInboundHandlerAdapter handler =
new TlsChannelInboundHandlerAdapter(sslHandler, grpcHandler);
ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler);
pipeline.addLast(handler);
channelHandlerCtx = pipeline.context(handler);
Object sslEvent = new SslHandshakeCompletionEvent(new RuntimeException("bad"));
@ -178,9 +172,17 @@ public class ProtocolNegotiatorsTest {
@Test
public void tlsHandler_userEventTriggeredSslEvent_supportedProtocol() throws Exception {
ChannelInboundHandlerAdapter handler =
new TlsChannelInboundHandlerAdapter(sslHandler, grpcHandler);
SslHandler goodSslHandler = new SslHandler(engine, false) {
@Override
public String applicationProtocol() {
return "h2";
}
};
ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler);
pipeline.addLast(handler);
pipeline.replace(SslHandler.class, null, goodSslHandler);
channelHandlerCtx = pipeline.context(handler);
Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
@ -193,8 +195,7 @@ public class ProtocolNegotiatorsTest {
@Test
public void engineLog() {
ChannelInboundHandlerAdapter handler =
new TlsChannelInboundHandlerAdapter(sslHandler, grpcHandler);
ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler);
pipeline.addLast(handler);
channelHandlerCtx = pipeline.context(handler);