all: implement Helper.createResolvingOobChannelBuilder(target, creds)

- Add APIs to `ClientTransportFactory`:
```java
public interface ClientTransportFactory {
  /**
   * Swaps to a new ChannelCredentials with all other settings unchanged. Returns null if the
   * ChannelCredentials is not supported by the current ClientTransportFactory settings.
   */
  SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials channelCreds);

  final class SwapChannelCredentialsResult {
    final ClientTransportFactory transportFactory;
    @Nullable final CallCredentials callCredentials;
  }
}
```

- Add `ChannelCredentials` to constructor args of `ManagedChannelImplBuilder`:
 ```java
public ManagedChannelImplBuilder(
      String target, @Nullable ChannelCredentials channelCreds, @Nullable CallCredentials callCreds, ...)
  ```
This commit is contained in:
ZHANG Dapeng 2021-01-28 09:49:53 -08:00 committed by GitHub
parent a6df2b2ff4
commit 45a151810c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 328 additions and 49 deletions

View File

@ -19,6 +19,7 @@ package io.grpc.inprocess;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import io.grpc.ChannelCredentials;
import io.grpc.ChannelLogger;
import io.grpc.ExperimentalApi;
import io.grpc.Internal;
@ -246,6 +247,11 @@ public final class InProcessChannelBuilder extends
return timerService;
}
@Override
public SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials channelCreds) {
return null;
}
@Override
public void close() {
if (closed) {

View File

@ -20,9 +20,10 @@ import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.base.Preconditions.checkNotNull;
import io.grpc.Attributes;
import io.grpc.CallCredentials.RequestInfo;
import io.grpc.CallCredentials;
import io.grpc.CallCredentials.RequestInfo;
import io.grpc.CallOptions;
import io.grpc.ChannelCredentials;
import io.grpc.ChannelLogger;
import io.grpc.CompositeCallCredentials;
import io.grpc.Metadata;
@ -61,6 +62,11 @@ final class CallCredentialsApplyingTransportFactory implements ClientTransportFa
return delegate.getScheduledExecutorService();
}
@Override
public SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials channelCreds) {
throw new UnsupportedOperationException();
}
@Override
public void close() {
delegate.close();

View File

@ -19,11 +19,14 @@ package io.grpc.internal;
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import io.grpc.Attributes;
import io.grpc.CallCredentials;
import io.grpc.ChannelCredentials;
import io.grpc.ChannelLogger;
import io.grpc.HttpConnectProxiedSocketAddress;
import java.io.Closeable;
import java.net.SocketAddress;
import java.util.concurrent.ScheduledExecutorService;
import javax.annotation.CheckReturnValue;
import javax.annotation.Nullable;
/** Pre-configured factory for creating {@link ConnectionClientTransport} instances. */
@ -53,6 +56,14 @@ public interface ClientTransportFactory extends Closeable {
*/
ScheduledExecutorService getScheduledExecutorService();
/**
* Swaps to a new ChannelCredentials with all other settings unchanged. Returns null if the
* ChannelCredentials is not supported by the current ClientTransportFactory settings.
*/
@CheckReturnValue
@Nullable
SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials channelCreds);
/**
* Releases any resources.
*
@ -143,4 +154,15 @@ public interface ClientTransportFactory extends Closeable {
&& Objects.equal(this.connectProxiedSocketAddr, that.connectProxiedSocketAddr);
}
}
final class SwapChannelCredentialsResult {
final ClientTransportFactory transportFactory;
@Nullable final CallCredentials callCredentials;
public SwapChannelCredentialsResult(
ClientTransportFactory transportFactory, @Nullable CallCredentials callCredentials) {
this.transportFactory = Preconditions.checkNotNull(transportFactory, "transportFactory");
this.callCredentials = callCredentials;
}
}
}

View File

@ -30,8 +30,10 @@ import com.google.common.base.Supplier;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.grpc.Attributes;
import io.grpc.CallCredentials;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ChannelCredentials;
import io.grpc.ChannelLogger;
import io.grpc.ChannelLogger.ChannelLogLevel;
import io.grpc.ClientCall;
@ -46,6 +48,7 @@ import io.grpc.DecompressorRegistry;
import io.grpc.EquivalentAddressGroup;
import io.grpc.ForwardingChannelBuilder;
import io.grpc.ForwardingClientCall;
import io.grpc.Grpc;
import io.grpc.InternalChannelz;
import io.grpc.InternalChannelz.ChannelStats;
import io.grpc.InternalChannelz.ChannelTrace;
@ -74,8 +77,9 @@ import io.grpc.SynchronizationContext;
import io.grpc.SynchronizationContext.ScheduledHandle;
import io.grpc.internal.AutoConfiguredLoadBalancerFactory.AutoConfiguredLoadBalancer;
import io.grpc.internal.ClientCallImpl.ClientStreamProvider;
import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult;
import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder;
import io.grpc.internal.ManagedChannelImplBuilder.FixedPortProvider;
import io.grpc.internal.ManagedChannelImplBuilder.UnsupportedClientTransportFactoryBuilder;
import io.grpc.internal.ManagedChannelServiceConfig.MethodInfo;
import io.grpc.internal.ManagedChannelServiceConfig.ServiceConfigConvertedSelector;
import io.grpc.internal.RetriableStream.ChannelBufferMeter;
@ -153,6 +157,8 @@ final class ManagedChannelImpl extends ManagedChannel implements
private final NameResolver.Args nameResolverArgs;
private final AutoConfiguredLoadBalancerFactory loadBalancerFactory;
private final ClientTransportFactory originalTransportFactory;
@Nullable
private final ChannelCredentials originalChannelCreds;
private final ClientTransportFactory transportFactory;
private final ClientTransportFactory oobTransportFactory;
private final RestrictedScheduledExecutor scheduledExecutor;
@ -593,6 +599,7 @@ final class ManagedChannelImpl extends ManagedChannel implements
this.timeProvider = checkNotNull(timeProvider, "timeProvider");
this.executorPool = checkNotNull(builder.executorPool, "executorPool");
this.executor = checkNotNull(executorPool.getObject(), "executor");
this.originalChannelCreds = builder.channelCredentials;
this.originalTransportFactory = clientTransportFactory;
this.transportFactory = new CallCredentialsApplyingTransportFactory(
clientTransportFactory, builder.callCredentials, this.executor);
@ -1516,50 +1523,82 @@ final class ManagedChannelImpl extends ManagedChannel implements
@Override
public ManagedChannelBuilder<?> createResolvingOobChannelBuilder(String target) {
return createResolvingOobChannelBuilder(target, new DefaultChannelCreds());
}
// TODO(creamsoup) prevent main channel to shutdown if oob channel is not terminated
// TODO(zdapeng) register the channel as a subchannel of the parent channel in channelz.
@Override
public ManagedChannelBuilder<?> createResolvingOobChannelBuilder(
final String target, final ChannelCredentials channelCreds) {
checkNotNull(channelCreds, "channelCreds");
final class ResolvingOobChannelBuilder
extends ForwardingChannelBuilder<ResolvingOobChannelBuilder> {
private final ManagedChannelImplBuilder managedChannelImplBuilder;
final ManagedChannelBuilder<?> delegate;
ResolvingOobChannelBuilder(String target) {
managedChannelImplBuilder = new ManagedChannelImplBuilder(target,
new UnsupportedClientTransportFactoryBuilder(),
ResolvingOobChannelBuilder() {
final ClientTransportFactory transportFactory;
CallCredentials callCredentials;
if (channelCreds instanceof DefaultChannelCreds) {
transportFactory = originalTransportFactory;
callCredentials = null;
} else {
SwapChannelCredentialsResult swapResult =
originalTransportFactory.swapChannelCredentials(channelCreds);
if (swapResult == null) {
delegate = Grpc.newChannelBuilder(target, channelCreds);
return;
} else {
transportFactory = swapResult.transportFactory;
callCredentials = swapResult.callCredentials;
}
}
ClientTransportFactoryBuilder transportFactoryBuilder =
new ClientTransportFactoryBuilder() {
@Override
public ClientTransportFactory buildClientTransportFactory() {
return transportFactory;
}
};
delegate = new ManagedChannelImplBuilder(
target,
channelCreds,
callCredentials,
transportFactoryBuilder,
new FixedPortProvider(nameResolverArgs.getDefaultPort()));
managedChannelImplBuilder.executorPool = executorPool;
managedChannelImplBuilder.offloadExecutorPool = offloadExecutorHolder.pool;
}
@Override
protected ManagedChannelBuilder<?> delegate() {
return managedChannelImplBuilder;
}
@Override
public ManagedChannel build() {
// TODO(creamsoup) prevent main channel to shutdown if oob channel is not terminated
return new ManagedChannelImpl(
managedChannelImplBuilder,
originalTransportFactory,
backoffPolicyProvider,
balancerRpcExecutorPool,
stopwatchSupplier,
Collections.<ClientInterceptor>emptyList(),
timeProvider);
return delegate;
}
}
checkState(!terminated, "Channel is terminated");
@SuppressWarnings("deprecation")
ResolvingOobChannelBuilder builder = new ResolvingOobChannelBuilder(target)
ResolvingOobChannelBuilder builder = new ResolvingOobChannelBuilder()
.nameResolverFactory(nameResolverFactory);
return builder
.overrideAuthority(getAuthority())
.overrideAuthority(ManagedChannelImpl.this.authority())
// TODO(zdapeng): executors should not outlive the parent channel.
.executor(executor)
.offloadExecutor(offloadExecutorHolder.getExecutor())
.maxTraceEvents(maxTraceEvents)
.proxyDetector(nameResolverArgs.getProxyDetector())
.userAgent(userAgent);
}
@Override
public ChannelCredentials getUnsafeChannelCredentials() {
if (originalChannelCreds == null) {
return new DefaultChannelCreds();
}
return originalChannelCreds;
}
@Override
public void updateOobChannelAddresses(ManagedChannel channel, EquivalentAddressGroup eag) {
checkArgument(channel instanceof OobChannel,
@ -1596,6 +1635,18 @@ final class ManagedChannelImpl extends ManagedChannel implements
public NameResolverRegistry getNameResolverRegistry() {
return nameResolverRegistry;
}
/**
* A placeholder for channel creds if user did not specify channel creds for the channel.
*/
// TODO(zdapeng): get rid of this class and let all ChannelBuilders always provide a non-null
// channel creds.
final class DefaultChannelCreds extends ChannelCredentials {
@Override
public ChannelCredentials withoutBearerTokens() {
return this;
}
}
}
private final class NameResolverListener extends NameResolver.Listener2 {

View File

@ -24,6 +24,7 @@ import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.Attributes;
import io.grpc.BinaryLog;
import io.grpc.CallCredentials;
import io.grpc.ChannelCredentials;
import io.grpc.ClientInterceptor;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
@ -111,6 +112,8 @@ public final class ManagedChannelImplBuilder
final String target;
@Nullable
final ChannelCredentials channelCredentials;
@Nullable
final CallCredentials callCredentials;
@Nullable
@ -225,18 +228,23 @@ public final class ManagedChannelImplBuilder
public ManagedChannelImplBuilder(String target,
ClientTransportFactoryBuilder clientTransportFactoryBuilder,
@Nullable ChannelBuilderDefaultPortProvider channelBuilderDefaultPortProvider) {
this(target, null, clientTransportFactoryBuilder, channelBuilderDefaultPortProvider);
this(target, null, null, clientTransportFactoryBuilder, channelBuilderDefaultPortProvider);
}
/**
* Creates a new managed channel builder with a target string, which can be either a valid {@link
* io.grpc.NameResolver}-compliant URI, or an authority string. Transport implementors must
* provide client transport factory builder, and may set custom channel default port provider.
*
* @param channelCreds The ChannelCredentials provided by the user. These may be used when
* creating derivative channels.
*/
public ManagedChannelImplBuilder(String target, @Nullable CallCredentials callCreds,
public ManagedChannelImplBuilder(
String target, @Nullable ChannelCredentials channelCreds, @Nullable CallCredentials callCreds,
ClientTransportFactoryBuilder clientTransportFactoryBuilder,
@Nullable ChannelBuilderDefaultPortProvider channelBuilderDefaultPortProvider) {
this.target = Preconditions.checkNotNull(target, "target");
this.channelCredentials = channelCreds;
this.callCredentials = callCreds;
this.clientTransportFactoryBuilder = Preconditions
.checkNotNull(clientTransportFactoryBuilder, "clientTransportFactoryBuilder");
@ -273,6 +281,7 @@ public final class ManagedChannelImplBuilder
ClientTransportFactoryBuilder clientTransportFactoryBuilder,
@Nullable ChannelBuilderDefaultPortProvider channelBuilderDefaultPortProvider) {
this.target = makeTargetStringForDirectAddress(directServerAddress);
this.channelCredentials = null;
this.callCredentials = null;
this.clientTransportFactoryBuilder = Preconditions
.checkNotNull(clientTransportFactoryBuilder, "clientTransportFactoryBuilder");

View File

@ -16,9 +16,12 @@
package io.grpc.inprocess;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.internal.GrpcUtil.TIMER_SERVICE;
import static org.junit.Assert.assertSame;
import static org.mockito.Mockito.mock;
import io.grpc.ChannelCredentials;
import io.grpc.internal.ClientTransportFactory;
import io.grpc.internal.FakeClock;
import io.grpc.internal.SharedResourceHolder;
@ -60,4 +63,11 @@ public class InProcessChannelBuilderTest {
clientTransportFactory.close();
}
@Test
public void transportFactoryDoesNotSupportSwapChannelCreds() {
InProcessChannelBuilder builder = InProcessChannelBuilder.forName("foo");
ClientTransportFactory transportFactory = builder.buildTransportFactory();
assertThat(transportFactory.swapChannelCredentials(mock(ChannelCredentials.class))).isNull();
}
}

View File

@ -63,15 +63,18 @@ import io.grpc.CallCredentials;
import io.grpc.CallCredentials.RequestInfo;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ChannelCredentials;
import io.grpc.ChannelLogger;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.ClientInterceptors;
import io.grpc.ClientStreamTracer;
import io.grpc.CompositeChannelCredentials;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
import io.grpc.Context;
import io.grpc.EquivalentAddressGroup;
import io.grpc.InsecureChannelCredentials;
import io.grpc.IntegerMarshaller;
import io.grpc.InternalChannelz;
import io.grpc.InternalChannelz.ChannelStats;
@ -105,6 +108,7 @@ import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.StringMarshaller;
import io.grpc.internal.ClientTransportFactory.ClientTransportOptions;
import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult;
import io.grpc.internal.InternalSubchannel.TransportLogger;
import io.grpc.internal.ManagedChannelImpl.ScParser;
import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder;
@ -1662,7 +1666,8 @@ public class ManagedChannelImplTest {
Metadata.Key<String> metadataKey =
Metadata.Key.of("token", Metadata.ASCII_STRING_MARSHALLER);
String channelCredValue = "channel-provided call cred";
channelBuilder = new ManagedChannelImplBuilder(TARGET,
channelBuilder = new ManagedChannelImplBuilder(
TARGET, InsecureChannelCredentials.create(),
new FakeCallCredentials(metadataKey, channelCredValue),
new UnsupportedClientTransportFactoryBuilder(), new FixedPortProvider(DEFAULT_PORT));
configureBuilder(channelBuilder);
@ -1733,11 +1738,91 @@ public class ManagedChannelImplTest {
call = oob.newCall(method, callOptions);
call.start(mockCallListener2, headers);
verify(transportInfo.transport).newStream(same(method), same(headers), same(callOptions));
// CallOptions may contain StreamTracerFactory for census that is added by default.
verify(transportInfo.transport).newStream(same(method), same(headers), any(CallOptions.class));
assertThat(headers.getAll(metadataKey)).containsExactly(callCredValue);
oob.shutdownNow();
}
@Test
public void oobChannelWithOobChannelCredsHasChannelCallCredentials() {
Metadata.Key<String> metadataKey =
Metadata.Key.of("token", Metadata.ASCII_STRING_MARSHALLER);
String channelCredValue = "channel-provided call cred";
when(mockTransportFactory.swapChannelCredentials(any(CompositeChannelCredentials.class)))
.thenAnswer(new Answer<SwapChannelCredentialsResult>() {
@Override
public SwapChannelCredentialsResult answer(InvocationOnMock invocation) {
CompositeChannelCredentials c =
invocation.getArgument(0, CompositeChannelCredentials.class);
return new SwapChannelCredentialsResult(mockTransportFactory, c.getCallCredentials());
}
});
channelBuilder = new ManagedChannelImplBuilder(
TARGET, InsecureChannelCredentials.create(),
new FakeCallCredentials(metadataKey, channelCredValue),
new UnsupportedClientTransportFactoryBuilder(), new FixedPortProvider(DEFAULT_PORT));
configureBuilder(channelBuilder);
createChannel();
// Verify that the normal channel has call creds, to validate configuration
Subchannel subchannel =
createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener);
requestConnectionSafely(helper, subchannel);
MockClientTransportInfo transportInfo = transports.poll();
transportInfo.listener.transportReady();
when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn(
PickResult.withSubchannel(subchannel));
updateBalancingStateSafely(helper, READY, mockPicker);
String callCredValue = "per-RPC call cred";
CallOptions callOptions = CallOptions.DEFAULT
.withCallCredentials(new FakeCallCredentials(metadataKey, callCredValue));
Metadata headers = new Metadata();
ClientCall<String, Integer> call = channel.newCall(method, callOptions);
call.start(mockCallListener, headers);
verify(transportInfo.transport).newStream(same(method), same(headers), same(callOptions));
assertThat(headers.getAll(metadataKey))
.containsExactly(channelCredValue, callCredValue).inOrder();
// Verify that resolving oob channel with oob channel creds provides call creds
String oobChannelCredValue = "oob-channel-provided call cred";
ChannelCredentials oobChannelCreds = CompositeChannelCredentials.create(
InsecureChannelCredentials.create(),
new FakeCallCredentials(metadataKey, oobChannelCredValue));
ManagedChannel oob = helper.createResolvingOobChannelBuilder("oobauthority", oobChannelCreds)
.nameResolverFactory(
new FakeNameResolverFactory.Builder(URI.create("oobauthority")).build())
.defaultLoadBalancingPolicy(MOCK_POLICY_NAME)
.idleTimeout(ManagedChannelImplBuilder.IDLE_MODE_MAX_TIMEOUT_DAYS, TimeUnit.DAYS)
.build();
oob.getState(true);
ArgumentCaptor<Helper> helperCaptor = ArgumentCaptor.forClass(Helper.class);
verify(mockLoadBalancerProvider, times(2)).newLoadBalancer(helperCaptor.capture());
Helper oobHelper = helperCaptor.getValue();
subchannel =
createSubchannelSafely(oobHelper, addressGroup, Attributes.EMPTY, subchannelStateListener);
requestConnectionSafely(oobHelper, subchannel);
transportInfo = transports.poll();
transportInfo.listener.transportReady();
SubchannelPicker mockPicker2 = mock(SubchannelPicker.class);
when(mockPicker2.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn(
PickResult.withSubchannel(subchannel));
updateBalancingStateSafely(oobHelper, READY, mockPicker2);
headers = new Metadata();
call = oob.newCall(method, callOptions);
call.start(mockCallListener2, headers);
// CallOptions may contain StreamTracerFactory for census that is added by default.
verify(transportInfo.transport).newStream(same(method), same(headers), any(CallOptions.class));
assertThat(headers.getAll(metadataKey))
.containsExactly(oobChannelCredValue, callCredValue).inOrder();
oob.shutdownNow();
}
@Test
public void oobChannelsWhenChannelShutdownNow() {
createChannel();

View File

@ -24,6 +24,7 @@ import android.util.Log;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.ChannelCredentials;
import io.grpc.ChannelLogger;
import io.grpc.ExperimentalApi;
import io.grpc.Internal;
@ -269,6 +270,11 @@ public final class CronetChannelBuilder
return timeoutService;
}
@Override
public SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials channelCreds) {
return null;
}
@Override
public void close() {
if (usingSharedScheduler) {

View File

@ -46,6 +46,7 @@ import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder;
import io.grpc.internal.ObjectPool;
import io.grpc.internal.SharedResourcePool;
import io.grpc.internal.TransportTracer;
import io.grpc.netty.ProtocolNegotiators.FromChannelCredentialsResult;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFactory;
import io.netty.channel.ChannelOption;
@ -157,12 +158,11 @@ public final class NettyChannelBuilder extends
*/
@CheckReturnValue
public static NettyChannelBuilder forTarget(String target, ChannelCredentials creds) {
ProtocolNegotiators.FromChannelCredentialsResult result = ProtocolNegotiators.from(creds);
FromChannelCredentialsResult result = ProtocolNegotiators.from(creds);
if (result.error != null) {
throw new IllegalArgumentException(result.error);
}
return new NettyChannelBuilder(
target, result.negotiator, result.callCredentials);
return new NettyChannelBuilder(target, creds, result.callCredentials, result.negotiator);
}
private final class NettyChannelTransportFactoryBuilder implements ClientTransportFactoryBuilder {
@ -187,11 +187,11 @@ public final class NettyChannelBuilder extends
this.freezeProtocolNegotiatorFactory = false;
}
@CheckReturnValue
NettyChannelBuilder(
String target, ProtocolNegotiator.ClientFactory negotiator,
@Nullable CallCredentials callCreds) {
managedChannelImplBuilder = new ManagedChannelImplBuilder(target, callCreds,
String target, ChannelCredentials channelCreds, CallCredentials callCreds,
ProtocolNegotiator.ClientFactory negotiator) {
managedChannelImplBuilder = new ManagedChannelImplBuilder(
target, channelCreds, callCreds,
new NettyChannelTransportFactoryBuilder(),
new NettyChannelDefaultPortProvider());
this.protocolNegotiatorFactory = checkNotNull(negotiator, "negotiator");
@ -628,7 +628,8 @@ public final class NettyChannelBuilder extends
private final int flowControlWindow;
private final int maxMessageSize;
private final int maxHeaderListSize;
private final AtomicBackoff keepAliveTimeNanos;
private final long keepAliveTimeNanos;
private final AtomicBackoff keepAliveBackoff;
private final long keepAliveTimeoutNanos;
private final boolean keepAliveWithoutCalls;
private final TransportTracer.Factory transportTracerFactory;
@ -637,7 +638,8 @@ public final class NettyChannelBuilder extends
private boolean closed;
NettyTransportFactory(ProtocolNegotiator protocolNegotiator,
NettyTransportFactory(
ProtocolNegotiator protocolNegotiator,
ChannelFactory<? extends Channel> channelFactory,
Map<ChannelOption<?>, ?> channelOptions, ObjectPool<? extends EventLoopGroup> groupPool,
boolean autoFlowControl, int flowControlWindow, int maxMessageSize, int maxHeaderListSize,
@ -653,7 +655,8 @@ public final class NettyChannelBuilder extends
this.flowControlWindow = flowControlWindow;
this.maxMessageSize = maxMessageSize;
this.maxHeaderListSize = maxHeaderListSize;
this.keepAliveTimeNanos = new AtomicBackoff("keepalive time nanos", keepAliveTimeNanos);
this.keepAliveTimeNanos = keepAliveTimeNanos;
this.keepAliveBackoff = new AtomicBackoff("keepalive time nanos", keepAliveTimeNanos);
this.keepAliveTimeoutNanos = keepAliveTimeoutNanos;
this.keepAliveWithoutCalls = keepAliveWithoutCalls;
this.transportTracerFactory = transportTracerFactory;
@ -678,7 +681,7 @@ public final class NettyChannelBuilder extends
protocolNegotiator);
}
final AtomicBackoff.State keepAliveTimeNanosState = keepAliveTimeNanos.getState();
final AtomicBackoff.State keepAliveTimeNanosState = keepAliveBackoff.getState();
Runnable tooManyPingsRunnable = new Runnable() {
@Override
public void run() {
@ -702,6 +705,21 @@ public final class NettyChannelBuilder extends
return group;
}
@Override
public SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials channelCreds) {
checkNotNull(channelCreds, "channelCreds");
FromChannelCredentialsResult result = ProtocolNegotiators.from(channelCreds);
if (result.error != null) {
return null;
}
ClientTransportFactory factory = new NettyTransportFactory(
result.negotiator.newNegotiator(), channelFactory, channelOptions, groupPool,
autoFlowControl, flowControlWindow, maxMessageSize, maxHeaderListSize, keepAliveTimeNanos,
keepAliveTimeoutNanos, keepAliveWithoutCalls, transportTracerFactory, localSocketPicker,
useGetForSafeMethods);
return new SwapChannelCredentialsResult(factory, result.callCredentials);
}
@Override
public void close() {
if (closed) {

View File

@ -49,7 +49,7 @@ public final class NettyChannelProvider extends ManagedChannelProvider {
if (result.error != null) {
return NewChannelBuilderResult.error(result.error);
}
return NewChannelBuilderResult.channelBuilder(new NettyChannelBuilder(
target, result.negotiator, result.callCredentials));
return NewChannelBuilderResult.channelBuilder(
new NettyChannelBuilder(target, creds, result.callCredentials, result.negotiator));
}
}

View File

@ -16,13 +16,18 @@
package io.grpc.netty;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import io.grpc.ChannelCredentials;
import io.grpc.ManagedChannel;
import io.grpc.internal.ClientTransportFactory;
import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult;
import io.grpc.netty.NettyTestUtil.TrackingObjectPoolForTest;
import io.grpc.netty.ProtocolNegotiators.PlaintextProtocolNegotiatorClientFactory;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFactory;
import io.netty.channel.EventLoopGroup;
@ -282,4 +287,18 @@ public class NettyChannelBuilderTest {
builder.assertEventLoopAndChannelType();
}
@Test
public void transportFactorySupportsNettyChannelCreds() {
NettyChannelBuilder builder = NettyChannelBuilder.forTarget("foo");
ClientTransportFactory transportFactory = builder.buildTransportFactory();
SwapChannelCredentialsResult result = transportFactory.swapChannelCredentials(
mock(ChannelCredentials.class));
assertThat(result).isNull();
result = transportFactory.swapChannelCredentials(
NettyChannelCredentials.create(new PlaintextProtocolNegotiatorClientFactory()));
assertThat(result).isNotNull();
}
}

View File

@ -59,6 +59,7 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import javax.annotation.CheckReturnValue;
import javax.annotation.Nullable;
import javax.net.SocketFactory;
import javax.net.ssl.HostnameVerifier;
@ -145,7 +146,7 @@ public final class OkHttpChannelBuilder extends
if (result.error != null) {
throw new IllegalArgumentException(result.error);
}
return new OkHttpChannelBuilder(target, result.factory, result.callCredentials);
return new OkHttpChannelBuilder(target, creds, result.callCredentials, result.factory);
}
private Executor transportExecutor;
@ -181,9 +182,11 @@ public final class OkHttpChannelBuilder extends
this.freezeSecurityConfiguration = false;
}
OkHttpChannelBuilder(String target, @Nullable SSLSocketFactory factory,
@Nullable CallCredentials callCredentials) {
managedChannelImplBuilder = new ManagedChannelImplBuilder(target, callCredentials,
OkHttpChannelBuilder(
String target, ChannelCredentials channelCreds, CallCredentials callCreds,
SSLSocketFactory factory) {
managedChannelImplBuilder = new ManagedChannelImplBuilder(
target, channelCreds, callCreds,
new OkHttpChannelTransportFactoryBuilder(),
new OkHttpChannelDefaultPortProvider());
this.sslSocketFactory = factory;
@ -631,7 +634,8 @@ public final class OkHttpChannelBuilder extends
private final ConnectionSpec connectionSpec;
private final int maxMessageSize;
private final boolean enableKeepAlive;
private final AtomicBackoff keepAliveTimeNanos;
private final long keepAliveTimeNanos;
private final AtomicBackoff keepAliveBackoff;
private final long keepAliveTimeoutNanos;
private final int flowControlWindow;
private final boolean keepAliveWithoutCalls;
@ -665,7 +669,8 @@ public final class OkHttpChannelBuilder extends
this.connectionSpec = connectionSpec;
this.maxMessageSize = maxMessageSize;
this.enableKeepAlive = enableKeepAlive;
this.keepAliveTimeNanos = new AtomicBackoff("keepalive time nanos", keepAliveTimeNanos);
this.keepAliveTimeNanos = keepAliveTimeNanos;
this.keepAliveBackoff = new AtomicBackoff("keepalive time nanos", keepAliveTimeNanos);
this.keepAliveTimeoutNanos = keepAliveTimeoutNanos;
this.flowControlWindow = flowControlWindow;
this.keepAliveWithoutCalls = keepAliveWithoutCalls;
@ -689,7 +694,7 @@ public final class OkHttpChannelBuilder extends
if (closed) {
throw new IllegalStateException("The transport factory is closed.");
}
final AtomicBackoff.State keepAliveTimeNanosState = keepAliveTimeNanos.getState();
final AtomicBackoff.State keepAliveTimeNanosState = keepAliveBackoff.getState();
Runnable tooManyPingsRunnable = new Runnable() {
@Override
public void run() {
@ -727,6 +732,33 @@ public final class OkHttpChannelBuilder extends
return timeoutService;
}
@Nullable
@CheckReturnValue
@Override
public SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials channelCreds) {
SslSocketFactoryResult result = sslSocketFactoryFrom(channelCreds);
if (result.error != null) {
return null;
}
ClientTransportFactory factory = new OkHttpTransportFactory(
executor,
timeoutService,
socketFactory,
result.factory,
hostnameVerifier,
connectionSpec,
maxMessageSize,
enableKeepAlive,
keepAliveTimeNanos,
keepAliveTimeoutNanos,
flowControlWindow,
keepAliveWithoutCalls,
maxInboundMetadataSize,
transportTracerFactory,
useGetForSafeMethods);
return new SwapChannelCredentialsResult(factory, result.callCredentials);
}
@Override
public void close() {
if (closed) {

View File

@ -55,6 +55,6 @@ public final class OkHttpChannelProvider extends ManagedChannelProvider {
return NewChannelBuilderResult.error(result.error);
}
return NewChannelBuilderResult.channelBuilder(new OkHttpChannelBuilder(
target, result.factory, result.callCredentials));
target, creds, result.callCredentials, result.factory));
}
}

View File

@ -34,6 +34,7 @@ import io.grpc.InsecureChannelCredentials;
import io.grpc.ManagedChannel;
import io.grpc.TlsChannelCredentials;
import io.grpc.internal.ClientTransportFactory;
import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult;
import io.grpc.internal.FakeClock;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.SharedResourceHolder;
@ -358,6 +359,20 @@ public class OkHttpChannelBuilderTest {
transportFactory.close();
}
@Test
public void transportFactorySupportsOkHttpChannelCreds() {
OkHttpChannelBuilder builder = OkHttpChannelBuilder.forTarget("foo");
ClientTransportFactory transportFactory = builder.buildTransportFactory();
SwapChannelCredentialsResult result = transportFactory.swapChannelCredentials(
mock(ChannelCredentials.class));
assertThat(result).isNull();
result = transportFactory.swapChannelCredentials(
SslSocketFactoryChannelCredentials.create(mock(SSLSocketFactory.class)));
assertThat(result).isNotNull();
}
private static final class FakeChannelLogger extends ChannelLogger {
@Override