s2a,netty: S2AHandshakerServiceChannel doesn't use custom event loop. (#11539)

* S2AHandshakerServiceChannel doesn't use custom event loop.

* use executorPool.

* log when channel not shutdown.

* use a cached threadpool.

* update non-executor version.
This commit is contained in:
Riya Mehta 2024-09-20 12:32:54 -07:00 committed by GitHub
parent 782a44ad62
commit e75a044107
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 56 additions and 65 deletions

View File

@ -17,12 +17,14 @@
package io.grpc.netty;
import io.grpc.ChannelLogger;
import io.grpc.internal.ObjectPool;
import io.grpc.netty.ProtocolNegotiators.ClientTlsHandler;
import io.grpc.netty.ProtocolNegotiators.GrpcNegotiationHandler;
import io.grpc.netty.ProtocolNegotiators.WaitUntilActiveHandler;
import io.netty.channel.ChannelHandler;
import io.netty.handler.ssl.SslContext;
import io.netty.util.AsciiString;
import java.util.concurrent.Executor;
/**
* Internal accessor for {@link ProtocolNegotiators}.
@ -35,9 +37,12 @@ public final class InternalProtocolNegotiators {
* Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will
* be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel}
* may happen immediately, even before the TLS Handshake is complete.
* @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks
*/
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext) {
final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext);
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext,
ObjectPool<? extends Executor> executorPool) {
final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext,
executorPool);
final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {
@Override
@ -58,6 +63,15 @@ public final class InternalProtocolNegotiators {
return new TlsNegotiator();
}
/**
* Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will
* be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel}
* may happen immediately, even before the TLS Handshake is complete.
*/
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext) {
return tls(sslContext, null);
}
/**
* Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will be

View File

@ -29,13 +29,11 @@ import io.grpc.ManagedChannel;
import io.grpc.MethodDescriptor;
import io.grpc.internal.SharedResourceHolder.Resource;
import io.grpc.netty.NettyChannelBuilder;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.util.concurrent.DefaultThreadFactory;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.ConcurrentMap;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.concurrent.ThreadSafe;
/**
@ -61,7 +59,6 @@ import javax.annotation.concurrent.ThreadSafe;
public final class S2AHandshakerServiceChannel {
private static final ConcurrentMap<String, Resource<Channel>> SHARED_RESOURCE_CHANNELS =
Maps.newConcurrentMap();
private static final Duration DELEGATE_TERMINATION_TIMEOUT = Duration.ofSeconds(2);
private static final Duration CHANNEL_SHUTDOWN_TIMEOUT = Duration.ofSeconds(10);
/**
@ -95,41 +92,34 @@ public final class S2AHandshakerServiceChannel {
}
/**
* Creates a {@code EventLoopHoldingChannel} instance to the service running at {@code
* targetAddress}. This channel uses a dedicated thread pool for its {@code EventLoopGroup}
* instance to avoid blocking.
* Creates a {@code HandshakerServiceChannel} instance to the service running at {@code
* targetAddress}.
*/
@Override
public Channel create() {
EventLoopGroup eventLoopGroup =
new NioEventLoopGroup(1, new DefaultThreadFactory("S2A channel pool", true));
ManagedChannel channel = null;
if (channelCredentials.isPresent()) {
// Create a secure channel.
channel =
NettyChannelBuilder.forTarget(targetAddress, channelCredentials.get())
.channelType(NioSocketChannel.class)
.directExecutor()
.eventLoopGroup(eventLoopGroup)
.build();
} else {
// Create a plaintext channel.
channel =
NettyChannelBuilder.forTarget(targetAddress)
.channelType(NioSocketChannel.class)
.directExecutor()
.eventLoopGroup(eventLoopGroup)
.usePlaintext()
.build();
}
return EventLoopHoldingChannel.create(channel, eventLoopGroup);
return HandshakerServiceChannel.create(channel);
}
/** Destroys a {@code EventLoopHoldingChannel} instance. */
/** Destroys a {@code HandshakerServiceChannel} instance. */
@Override
public void close(Channel instanceChannel) {
checkNotNull(instanceChannel);
EventLoopHoldingChannel channel = (EventLoopHoldingChannel) instanceChannel;
HandshakerServiceChannel channel = (HandshakerServiceChannel) instanceChannel;
channel.close();
}
@ -140,23 +130,21 @@ public final class S2AHandshakerServiceChannel {
}
/**
* Manages a channel using a {@link ManagedChannel} instance that belong to the {@code
* EventLoopGroup} thread pool.
* Manages a channel using a {@link ManagedChannel} instance.
*/
@VisibleForTesting
static class EventLoopHoldingChannel extends Channel {
static class HandshakerServiceChannel extends Channel {
private static final Logger logger =
Logger.getLogger(S2AHandshakerServiceChannel.class.getName());
private final ManagedChannel delegate;
private final EventLoopGroup eventLoopGroup;
static EventLoopHoldingChannel create(ManagedChannel delegate, EventLoopGroup eventLoopGroup) {
static HandshakerServiceChannel create(ManagedChannel delegate) {
checkNotNull(delegate);
checkNotNull(eventLoopGroup);
return new EventLoopHoldingChannel(delegate, eventLoopGroup);
return new HandshakerServiceChannel(delegate);
}
private EventLoopHoldingChannel(ManagedChannel delegate, EventLoopGroup eventLoopGroup) {
private HandshakerServiceChannel(ManagedChannel delegate) {
this.delegate = delegate;
this.eventLoopGroup = eventLoopGroup;
}
/**
@ -178,16 +166,12 @@ public final class S2AHandshakerServiceChannel {
@SuppressWarnings("FutureReturnValueIgnored")
public void close() {
delegate.shutdownNow();
boolean isDelegateTerminated;
try {
isDelegateTerminated =
delegate.awaitTermination(DELEGATE_TERMINATION_TIMEOUT.getSeconds(), SECONDS);
delegate.awaitTermination(CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS);
} catch (InterruptedException e) {
isDelegateTerminated = false;
Thread.currentThread().interrupt();
logger.log(Level.WARNING, "Channel to S2A was not shutdown.");
}
long quietPeriodSeconds = isDelegateTerminated ? 0 : 1;
eventLoopGroup.shutdownGracefully(
quietPeriodSeconds, CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS);
}
}

View File

@ -29,7 +29,9 @@ import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.errorprone.annotations.ThreadSafe;
import io.grpc.Channel;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.ObjectPool;
import io.grpc.internal.SharedResourcePool;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.InternalProtocolNegotiator;
import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
@ -227,7 +229,10 @@ public final class S2AProtocolNegotiatorFactory {
@Override
public void onSuccess(SslContext sslContext) {
ChannelHandler handler =
InternalProtocolNegotiators.tls(sslContext).newHandler(grpcHandler);
InternalProtocolNegotiators.tls(
sslContext,
SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR))
.newHandler(grpcHandler);
// Remove the bufferReads handler and delegate the rest of the handshake to the TLS
// handler.

View File

@ -18,11 +18,7 @@ package io.grpc.s2a.channel;
import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.extensions.proto.ProtoTruth.assertThat;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import io.grpc.CallOptions;
import io.grpc.Channel;
@ -39,15 +35,13 @@ import io.grpc.TlsServerCredentials;
import io.grpc.benchmarks.Utils;
import io.grpc.internal.SharedResourceHolder.Resource;
import io.grpc.netty.NettyServerBuilder;
import io.grpc.s2a.channel.S2AHandshakerServiceChannel.EventLoopHoldingChannel;
import io.grpc.s2a.channel.S2AHandshakerServiceChannel.HandshakerServiceChannel;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.GrpcCleanupRule;
import io.grpc.testing.protobuf.SimpleRequest;
import io.grpc.testing.protobuf.SimpleResponse;
import io.grpc.testing.protobuf.SimpleServiceGrpc;
import io.netty.channel.EventLoopGroup;
import java.io.File;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
@ -60,8 +54,6 @@ import org.junit.runners.JUnit4;
@RunWith(JUnit4.class)
public final class S2AHandshakerServiceChannelTest {
@ClassRule public static final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
private static final Duration CHANNEL_SHUTDOWN_TIMEOUT = Duration.ofSeconds(10);
private final EventLoopGroup mockEventLoopGroup = mock(EventLoopGroup.class);
private Server mtlsServer;
private Server plaintextServer;
@ -191,7 +183,7 @@ public final class S2AHandshakerServiceChannelTest {
}
/**
* Verifies that an {@code EventLoopHoldingChannel}'s {@code newCall} method can be used to
* Verifies that an {@code HandshakerServiceChannel}'s {@code newCall} method can be used to
* perform a simple RPC.
*/
@Test
@ -201,7 +193,7 @@ public final class S2AHandshakerServiceChannelTest {
"localhost:" + plaintextServer.getPort(),
/* s2aChannelCredentials= */ Optional.empty());
Channel channel = resource.create();
assertThat(channel).isInstanceOf(EventLoopHoldingChannel.class);
assertThat(channel).isInstanceOf(HandshakerServiceChannel.class);
assertThat(
SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance()))
.isEqualToDefaultInstance();
@ -214,53 +206,49 @@ public final class S2AHandshakerServiceChannelTest {
S2AHandshakerServiceChannel.getChannelResource(
"localhost:" + mtlsServer.getPort(), getTlsChannelCredentials());
Channel channel = resource.create();
assertThat(channel).isInstanceOf(EventLoopHoldingChannel.class);
assertThat(channel).isInstanceOf(HandshakerServiceChannel.class);
assertThat(
SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance()))
.isEqualToDefaultInstance();
}
/** Creates a {@code EventLoopHoldingChannel} instance and verifies its authority. */
/** Creates a {@code HandshakerServiceChannel} instance and verifies its authority. */
@Test
public void authority_success() throws Exception {
ManagedChannel channel = new FakeManagedChannel(true);
EventLoopHoldingChannel eventLoopHoldingChannel =
EventLoopHoldingChannel.create(channel, mockEventLoopGroup);
HandshakerServiceChannel eventLoopHoldingChannel =
HandshakerServiceChannel.create(channel);
assertThat(eventLoopHoldingChannel.authority()).isEqualTo("FakeManagedChannel");
}
/**
* Creates and closes a {@code EventLoopHoldingChannel} when its {@code ManagedChannel} terminates
* successfully.
* Creates and closes a {@code HandshakerServiceChannel} when its {@code ManagedChannel}
* terminates successfully.
*/
@Test
public void close_withDelegateTerminatedSuccess() throws Exception {
ManagedChannel channel = new FakeManagedChannel(true);
EventLoopHoldingChannel eventLoopHoldingChannel =
EventLoopHoldingChannel.create(channel, mockEventLoopGroup);
HandshakerServiceChannel eventLoopHoldingChannel =
HandshakerServiceChannel.create(channel);
eventLoopHoldingChannel.close();
assertThat(channel.isShutdown()).isTrue();
verify(mockEventLoopGroup, times(1))
.shutdownGracefully(0, CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS);
}
/**
* Creates and closes a {@code EventLoopHoldingChannel} when its {@code ManagedChannel} does not
* Creates and closes a {@code HandshakerServiceChannel} when its {@code ManagedChannel} does not
* terminate successfully.
*/
@Test
public void close_withDelegateTerminatedFailure() throws Exception {
ManagedChannel channel = new FakeManagedChannel(false);
EventLoopHoldingChannel eventLoopHoldingChannel =
EventLoopHoldingChannel.create(channel, mockEventLoopGroup);
HandshakerServiceChannel eventLoopHoldingChannel =
HandshakerServiceChannel.create(channel);
eventLoopHoldingChannel.close();
assertThat(channel.isShutdown()).isTrue();
verify(mockEventLoopGroup, times(1))
.shutdownGracefully(1, CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS);
}
/**
* Creates and closes a {@code EventLoopHoldingChannel}, creates a new channel from the same
* Creates and closes a {@code HandshakerServiceChannel}, creates a new channel from the same
* resource, and verifies that this second channel is useable.
*/
@Test
@ -273,7 +261,7 @@ public final class S2AHandshakerServiceChannelTest {
resource.close(channelOne);
Channel channelTwo = resource.create();
assertThat(channelTwo).isInstanceOf(EventLoopHoldingChannel.class);
assertThat(channelTwo).isInstanceOf(HandshakerServiceChannel.class);
assertThat(
SimpleServiceGrpc.newBlockingStub(channelTwo)
.unaryRpc(SimpleRequest.getDefaultInstance()))
@ -291,7 +279,7 @@ public final class S2AHandshakerServiceChannelTest {
resource.close(channelOne);
Channel channelTwo = resource.create();
assertThat(channelTwo).isInstanceOf(EventLoopHoldingChannel.class);
assertThat(channelTwo).isInstanceOf(HandshakerServiceChannel.class);
assertThat(
SimpleServiceGrpc.newBlockingStub(channelTwo)
.unaryRpc(SimpleRequest.getDefaultInstance()))