netty,alts: fire initial protocol negotiation event in WBAEH

This change is needed after trying to use the new style protocol negotiators internally.  The problem is that some handlers fire the event in handlerAdded, which is too early.  The followup PNE is fired after handlerAdded, which breaks the composibility of the negotiators.

To fix this, this change modifies the negotiation flow.  Specifically:

* Negotiators should NEVER fire a negotiation from handlerAdded, instead they should wait until userEventTriggered
* Negotiators now do state checking on the PNE.  If it is set twice, it fails.  If it has not been received when doing the next stage of negotiation, it fails.
* WBAEH now fires the initial, default event.  This is the only handler that can fire it from handlerAdded

The tests updated are ones not using WBAEH (which they probably should).  This change ensures attributes aren't lost when doing negotiation.
This commit is contained in:
Carl Mastrangelo 2019-06-18 09:33:40 -07:00 committed by GitHub
parent 40854dc9e1
commit 9c9ca659d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 35 additions and 14 deletions

View File

@ -17,6 +17,7 @@
package io.grpc.alts.internal;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import static io.grpc.alts.internal.AltsProtocolNegotiator.AUTH_CONTEXT_KEY;
import static io.grpc.alts.internal.AltsProtocolNegotiator.TSI_PEER_KEY;
@ -84,7 +85,7 @@ public final class TsiHandshakeHandler extends ByteToMessageDecoder {
private final HandshakeValidator handshakeValidator;
private final ChannelHandler next;
private ProtocolNegotiationEvent pne = InternalProtocolNegotiationEvent.getDefault();
private ProtocolNegotiationEvent pne;
/**
* Constructs a TsiHandshakeHandler.
@ -148,6 +149,7 @@ public final class TsiHandshakeHandler extends ByteToMessageDecoder {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof ProtocolNegotiationEvent) {
checkState(pne == null, "negotiation already started");
pne = (ProtocolNegotiationEvent) evt;
} else {
super.userEventTriggered(ctx, evt);
@ -156,6 +158,7 @@ public final class TsiHandshakeHandler extends ByteToMessageDecoder {
private void fireProtocolNegotiationEvent(
ChannelHandlerContext ctx, TsiPeer peer, Object authContext, SecurityDetails details) {
checkState(pne != null, "negotiation not yet complete");
InternalProtocolNegotiators.negotiationLogger(ctx)
.log(ChannelLogLevel.INFO, "TsiHandshake finished");
ProtocolNegotiationEvent localPne = pne;

View File

@ -36,6 +36,7 @@ import io.grpc.internal.FixedObjectPool;
import io.grpc.internal.GrpcAttributes;
import io.grpc.internal.ObjectPool;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.InternalProtocolNegotiationEvent;
import io.grpc.netty.NettyChannelBuilder;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
@ -149,6 +150,7 @@ public class AltsProtocolNegotiatorTest {
new AltsProtocolNegotiator.ServerAltsProtocolNegotiator(handshakerFactory, lazyFakeChannel)
.newHandler(grpcHandler);
channel = new EmbeddedChannel(uncaughtExceptionHandler, handler);
channel.pipeline().fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault());
}
@After

View File

@ -29,6 +29,7 @@ import io.grpc.internal.GrpcAttributes;
import io.grpc.internal.ObjectPool;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.InternalProtocolNegotiationEvent;
import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
@ -96,6 +97,7 @@ public final class GoogleDefaultProtocolNegotiatorTest {
// Add the negotiator handler last, but to the front. Putting this in ctor above would make it
// throw early.
chan.pipeline().addFirst(h);
chan.pipeline().fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault());
// Check that the message complained about the ALTS code, rather than SSL. ALTS throws on
// being added, so it's hard to catch it at the right time to make this assertion.
@ -111,6 +113,7 @@ public final class GoogleDefaultProtocolNegotiatorTest {
ChannelHandler h = googleProtocolNegotiator.newHandler(mockHandler);
EmbeddedChannel chan = new EmbeddedChannel(h);
chan.pipeline().fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault());
assertThat(chan.pipeline().first().getClass().getSimpleName()).isEqualTo("SslHandler");
}

View File

@ -17,6 +17,7 @@
package io.grpc.netty;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import static io.grpc.netty.GrpcSslContexts.NEXT_PROTOCOL_VERSIONS;
import com.google.common.annotations.VisibleForTesting;
@ -327,7 +328,7 @@ final class ProtocolNegotiators {
private final String host;
private final int port;
private ProtocolNegotiationEvent pne = ProtocolNegotiationEvent.DEFAULT;
private ProtocolNegotiationEvent pne;
ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority) {
this.next = checkNotNull(next, "next");
@ -351,6 +352,7 @@ final class ProtocolNegotiators {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof ProtocolNegotiationEvent) {
checkState(pne == null, "negotiation already started");
pne = (ProtocolNegotiationEvent) evt;
} else if (evt instanceof SslHandshakeCompletionEvent) {
SslHandshakeCompletionEvent handshakeEvent = (SslHandshakeCompletionEvent) evt;
@ -376,6 +378,7 @@ final class ProtocolNegotiators {
}
private void fireProtocolNegotiationEvent(ChannelHandlerContext ctx, SSLSession session) {
checkState(pne != null, "negotiation not yet complete");
negotiationLogger(ctx).log(ChannelLogLevel.INFO, "ClientTls finished");
Security security = new Security(new Tls(session));
Attributes attrs = pne.getAttributes().toBuilder()
@ -466,7 +469,7 @@ final class ProtocolNegotiators {
private final String authority;
private final GrpcHttp2ConnectionHandler next;
private ProtocolNegotiationEvent pne = ProtocolNegotiationEvent.DEFAULT;
private ProtocolNegotiationEvent pne;
Http2UpgradeAndGrpcHandler(String authority, GrpcHttp2ConnectionHandler next) {
this.authority = checkNotNull(authority, "authority");
@ -497,8 +500,10 @@ final class ProtocolNegotiators {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof ProtocolNegotiationEvent) {
checkState(pne == null, "negotiation already started");
pne = (ProtocolNegotiationEvent) evt;
} else if (evt == HttpClientUpgradeHandler.UpgradeEvent.UPGRADE_SUCCESSFUL) {
checkState(pne != null, "negotiation not yet complete");
negotiationLogger(ctx).log(ChannelLogLevel.INFO, "Http2Upgrade finished");
ctx.pipeline().remove(ctx.name());
next.handleProtocolNegotiationCompleted(pne.getAttributes(), pne.getSecurity());
@ -848,7 +853,7 @@ final class ProtocolNegotiators {
*/
static final class WaitUntilActiveHandler extends ChannelInboundHandlerAdapter {
private final ChannelHandler next;
private ProtocolNegotiationEvent pne = ProtocolNegotiationEvent.DEFAULT;
private ProtocolNegotiationEvent pne;
public WaitUntilActiveHandler(ChannelHandler next) {
this.next = checkNotNull(next, "next");
@ -859,31 +864,34 @@ final class ProtocolNegotiators {
negotiationLogger(ctx).log(ChannelLogLevel.INFO, "WaitUntilActive started");
// This should be a noop, but just in case...
super.handlerAdded(ctx);
if (ctx.channel().isActive()) {
ctx.pipeline().replace(ctx.name(), null, next);
}
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
// Still propagate channelActive to the new handler.
super.channelActive(ctx);
if (pne != null) {
fireProtocolNegotiationEvent(ctx);
}
}
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
ctx.pipeline().replace(ctx.name(), null, next);
// Still propagate channelActive to the new handler.
super.channelActive(ctx);
fireProtocolNegotiationEvent(ctx);
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof ProtocolNegotiationEvent) {
checkState(pne == null, "negotiation already started");
pne = (ProtocolNegotiationEvent) evt;
if (ctx.channel().isActive()) {
fireProtocolNegotiationEvent(ctx);
}
} else {
super.userEventTriggered(ctx, evt);
}
}
private void fireProtocolNegotiationEvent(ChannelHandlerContext ctx) {
checkState(pne != null, "negotiation not yet complete");
negotiationLogger(ctx).log(ChannelLogLevel.INFO, "WaitUntilActive finished");
ctx.pipeline().replace(ctx.name(), /* newName= */ null, next);
Attributes attrs = pne.getAttributes().toBuilder()
.set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, ctx.channel().localAddress())
.set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, ctx.channel().remoteAddress())

View File

@ -56,6 +56,8 @@ final class WriteBufferingAndExceptionHandler extends ChannelDuplexHandler {
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
ctx.pipeline().addBefore(ctx.name(), null, next);
super.handlerAdded(ctx);
// kick off protocol negotiation.
ctx.pipeline().fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
}
@Override

View File

@ -169,6 +169,7 @@ public class ProtocolNegotiatorsTest {
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
ctx.pipeline().addLast(handler);
ctx.pipeline().fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
// do not propagate channelActive().
}
};
@ -226,6 +227,7 @@ public class ProtocolNegotiatorsTest {
assertEquals(1, latch.getCount());
chan.connect(addr).sync();
chan.pipeline().fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
assertTrue(latch.await(TIMEOUT_SECONDS, TimeUnit.SECONDS));
assertNull(chan.pipeline().context(WaitUntilActiveHandler.class));
}
@ -571,6 +573,7 @@ public class ProtocolNegotiatorsTest {
.connect(addr)
.sync()
.channel();
c.pipeline().fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
SocketAddress localAddr = c.localAddress();
ProtocolNegotiationEvent expectedEvent = ProtocolNegotiationEvent.DEFAULT
.withAttributes(