alts: Limit number of concurrent handshakes to 32

This commit is contained in:
Eric Anderson 2020-12-01 17:30:03 -08:00 committed by Eric Anderson
parent 7dc8ab1c6e
commit 814e36b541
1 changed files with 75 additions and 3 deletions

View File

@ -31,12 +31,16 @@ import io.grpc.netty.InternalProtocolNegotiationEvent;
import io.grpc.netty.InternalProtocolNegotiators;
import io.grpc.netty.ProtocolNegotiationEvent;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.ByteToMessageDecoder;
import java.security.GeneralSecurityException;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import javax.annotation.Nullable;
/**
@ -78,12 +82,17 @@ public final class TsiHandshakeHandler extends ByteToMessageDecoder {
}
private static final int HANDSHAKE_FRAME_SIZE = 1024;
// Avoid performing too many handshakes in parallel, as it may cause queuing in the handshake
// server and cause unbounded blocking on the event loop (b/168808426). This is a workaround until
// there is an async TSI handshaking API to avoid the blocking.
private static final AsyncSemaphore semaphore = new AsyncSemaphore(32);
private final NettyTsiHandshaker handshaker;
private final HandshakeValidator handshakeValidator;
private final ChannelHandler next;
private ProtocolNegotiationEvent pne;
private boolean semaphoreAcquired;
/**
* Constructs a TsiHandshakeHandler.
@ -137,13 +146,37 @@ public final class TsiHandshakeHandler extends ByteToMessageDecoder {
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
public void userEventTriggered(final ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof ProtocolNegotiationEvent) {
checkState(pne == null, "negotiation already started");
pne = (ProtocolNegotiationEvent) evt;
InternalProtocolNegotiators.negotiationLogger(ctx)
.log(ChannelLogLevel.INFO, "TsiHandshake started");
sendHandshake(ctx);
ChannelFuture acquire = semaphore.acquire(ctx);
if (acquire.isSuccess()) {
semaphoreAcquired = true;
sendHandshake(ctx);
} else {
acquire.addListener(new ChannelFutureListener() {
@Override public void operationComplete(ChannelFuture future) {
if (!future.isSuccess()) {
ctx.fireExceptionCaught(future.cause());
return;
}
if (ctx.isRemoved()) {
semaphore.release();
return;
}
semaphoreAcquired = true;
try {
sendHandshake(ctx);
} catch (Exception ex) {
ctx.fireExceptionCaught(ex);
}
ctx.flush();
}
});
}
} else {
super.userEventTriggered(ctx, evt);
}
@ -188,6 +221,45 @@ public final class TsiHandshakeHandler extends ByteToMessageDecoder {
@Override
protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
if (semaphoreAcquired) {
semaphore.release();
semaphoreAcquired = false;
}
handshaker.close();
}
}
private static class AsyncSemaphore {
private final Object lock = new Object();
@SuppressWarnings("JdkObsolete") // LinkedList avoids high watermark memory issues
private final Queue<ChannelPromise> queue = new LinkedList<>();
private int permits;
public AsyncSemaphore(int permits) {
this.permits = permits;
}
public ChannelFuture acquire(ChannelHandlerContext ctx) {
synchronized (lock) {
if (permits > 0) {
permits--;
return ctx.newSucceededFuture();
}
ChannelPromise promise = ctx.newPromise();
queue.add(promise);
return promise;
}
}
public void release() {
ChannelPromise next;
synchronized (lock) {
next = queue.poll();
if (next == null) {
permits++;
return;
}
}
next.setSuccess();
}
}
}