core/internal: add 3-arg newStream method to ClientTransport interface (#1898)

adding 
ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers, CallOptions callOptions);
to ClientTransport interface

Created this PR first because both fail fast implementation and another change will be using this interface change
This commit is contained in:
ZHANG Dapeng 2016-06-06 19:43:15 -07:00
parent 53cd333531
commit b88ea27b53
13 changed files with 136 additions and 49 deletions

View File

@ -34,6 +34,7 @@ package io.grpc.inprocess;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import io.grpc.Attributes; import io.grpc.Attributes;
import io.grpc.CallOptions;
import io.grpc.Compressor; import io.grpc.Compressor;
import io.grpc.Decompressor; import io.grpc.Decompressor;
import io.grpc.Metadata; import io.grpc.Metadata;
@ -126,7 +127,7 @@ class InProcessTransport implements ServerTransport, ManagedClientTransport {
@Override @Override
public synchronized ClientStream newStream( public synchronized ClientStream newStream(
final MethodDescriptor<?, ?> method, final Metadata headers) { final MethodDescriptor<?, ?> method, final Metadata headers, final CallOptions callOptions) {
if (shutdownStatus != null) { if (shutdownStatus != null) {
final Status capturedStatus = shutdownStatus; final Status capturedStatus = shutdownStatus;
return new NoopClientStream() { return new NoopClientStream() {
@ -140,6 +141,12 @@ class InProcessTransport implements ServerTransport, ManagedClientTransport {
return new InProcessStream(method, headers).clientStream; return new InProcessStream(method, headers).clientStream;
} }
@Override
public synchronized ClientStream newStream(
final MethodDescriptor<?, ?> method, final Metadata headers) {
return newStream(method, headers, CallOptions.DEFAULT);
}
@Override @Override
public synchronized void ping(final PingCallback callback, Executor executor) { public synchronized void ping(final PingCallback callback, Executor executor) {
if (terminated) { if (terminated) {

View File

@ -205,7 +205,7 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
updateTimeoutHeaders(effectiveDeadline, callOptions.getDeadline(), updateTimeoutHeaders(effectiveDeadline, callOptions.getDeadline(),
parentContext.getDeadline(), headers); parentContext.getDeadline(), headers);
ClientTransport transport = clientTransportProvider.get(callOptions); ClientTransport transport = clientTransportProvider.get(callOptions);
stream = transport.newStream(method, headers); stream = transport.newStream(method, headers, callOptions);
} else { } else {
stream = new FailingClientStream(DEADLINE_EXCEEDED); stream = new FailingClientStream(DEADLINE_EXCEEDED);
} }

View File

@ -31,6 +31,7 @@
package io.grpc.internal; package io.grpc.internal;
import io.grpc.CallOptions;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
@ -56,9 +57,13 @@ public interface ClientTransport {
* *
* @param method the descriptor of the remote method to be called for this stream. * @param method the descriptor of the remote method to be called for this stream.
* @param headers to send at the beginning of the call * @param headers to send at the beginning of the call
* @param callOptions runtime options of the call
* @return the newly created stream. * @return the newly created stream.
*/ */
// TODO(nmittler): Consider also throwing for stopping. // TODO(nmittler): Consider also throwing for stopping.
ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers, CallOptions callOptions);
// TODO(zdapeng): Remove tow-argument version in favor of three-argument overload.
ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers); ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers);
/** /**

View File

@ -36,6 +36,7 @@ import com.google.common.base.Preconditions;
import com.google.common.base.Supplier; import com.google.common.base.Supplier;
import com.google.common.base.Suppliers; import com.google.common.base.Suppliers;
import io.grpc.CallOptions;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.Status; import io.grpc.Status;
@ -85,25 +86,31 @@ class DelayedClientTransport implements ManagedClientTransport {
} }
@Override @Override
public ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers) { public ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers, CallOptions
callOptions) {
Supplier<ClientTransport> supplier = transportSupplier; Supplier<ClientTransport> supplier = transportSupplier;
if (supplier == null) { if (supplier == null) {
synchronized (lock) { synchronized (lock) {
// Check again, since it may have changed while waiting for lock // Check again, since it may have changed while waiting for lock
supplier = transportSupplier; supplier = transportSupplier;
if (supplier == null && !shutdown) { if (supplier == null && !shutdown) {
PendingStream pendingStream = new PendingStream(method, headers); PendingStream pendingStream = new PendingStream(method, headers, callOptions);
pendingStreams.add(pendingStream); pendingStreams.add(pendingStream);
return pendingStream; return pendingStream;
} }
} }
} }
if (supplier != null) { if (supplier != null) {
return supplier.get().newStream(method, headers); return supplier.get().newStream(method, headers, callOptions);
} }
return new FailingClientStream(Status.UNAVAILABLE.withDescription("transport shutdown")); return new FailingClientStream(Status.UNAVAILABLE.withDescription("transport shutdown"));
} }
@Override
public ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers) {
return newStream(method, headers, CallOptions.DEFAULT);
}
@Override @Override
public void ping(final PingCallback callback, Executor executor) { public void ping(final PingCallback callback, Executor executor) {
Supplier<ClientTransport> supplier = transportSupplier; Supplier<ClientTransport> supplier = transportSupplier;
@ -133,7 +140,7 @@ class DelayedClientTransport implements ManagedClientTransport {
/** /**
* Prevents creating any new streams until {@link #setTransport} is called. Buffered streams are * Prevents creating any new streams until {@link #setTransport} is called. Buffered streams are
* not failed, so if {@link #shutdown} is called when {@link #setTransport} has not been called, * not failed, so if {@link #shutdown} is called when {@link #setTransport} has not been called,
* you still need to call {@link setTransport} to make this transport terminated. * you still need to call {@link #setTransport} to make this transport terminated.
*/ */
@Override @Override
public void shutdown() { public void shutdown() {
@ -257,14 +264,17 @@ class DelayedClientTransport implements ManagedClientTransport {
private class PendingStream extends DelayedStream { private class PendingStream extends DelayedStream {
private final MethodDescriptor<?, ?> method; private final MethodDescriptor<?, ?> method;
private final Metadata headers; private final Metadata headers;
private final CallOptions callOptions;
private PendingStream(MethodDescriptor<?, ?> method, Metadata headers) { private PendingStream(MethodDescriptor<?, ?> method, Metadata headers, CallOptions
callOptions) {
this.method = method; this.method = method;
this.headers = headers; this.headers = headers;
this.callOptions = callOptions;
} }
private void createRealStream(ClientTransport transport) { private void createRealStream(ClientTransport transport) {
setStream(transport.newStream(method, headers)); setStream(transport.newStream(method, headers, callOptions));
} }
@Override @Override

View File

@ -34,6 +34,7 @@ package io.grpc.internal;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import io.grpc.CallOptions;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.Status; import io.grpc.Status;
@ -54,10 +55,16 @@ class FailingClientTransport implements ClientTransport {
} }
@Override @Override
public ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers) { public ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers, CallOptions
callOptions) {
return new FailingClientStream(error); return new FailingClientStream(error);
} }
@Override
public ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers) {
return newStream(method, headers, CallOptions.DEFAULT);
}
@Override @Override
public void ping(final PingCallback callback, Executor executor) { public void ping(final PingCallback callback, Executor executor) {
executor.execute(new Runnable() { executor.execute(new Runnable() {

View File

@ -42,6 +42,7 @@ import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
@ -135,7 +136,8 @@ public class ClientCallImplTest {
public void setUp() { public void setUp() {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
when(provider.get(any(CallOptions.class))).thenReturn(transport); when(provider.get(any(CallOptions.class))).thenReturn(transport);
when(transport.newStream(any(MethodDescriptor.class), any(Metadata.class))).thenReturn(stream); when(transport.newStream(any(MethodDescriptor.class), any(Metadata.class),
any(CallOptions.class))).thenReturn(stream);
} }
@After @After
@ -156,7 +158,7 @@ public class ClientCallImplTest {
call.start(callListener, new Metadata()); call.start(callListener, new Metadata());
ArgumentCaptor<Metadata> metadataCaptor = ArgumentCaptor.forClass(Metadata.class); ArgumentCaptor<Metadata> metadataCaptor = ArgumentCaptor.forClass(Metadata.class);
verify(transport).newStream(eq(method), metadataCaptor.capture()); verify(transport).newStream(eq(method), metadataCaptor.capture(), same(CallOptions.DEFAULT));
Metadata actual = metadataCaptor.getValue(); Metadata actual = metadataCaptor.getValue();
Set<String> acceptedEncodings = Set<String> acceptedEncodings =
@ -178,6 +180,23 @@ public class ClientCallImplTest {
verify(stream).setAuthority("overridden-authority"); verify(stream).setAuthority("overridden-authority");
} }
@Test
public void callOptionsPropagatedToTransport() {
final CallOptions callOptions = CallOptions.DEFAULT.withAuthority("dummy_value");
final ClientCallImpl<Void, Void> call = new ClientCallImpl<Void, Void>(
method,
MoreExecutors.directExecutor(),
callOptions,
provider,
deadlineCancellationExecutor)
.setDecompressorRegistry(decompressorRegistry);
final Metadata metadata = new Metadata();
call.start(callListener, metadata);
verify(transport).newStream(same(method), same(metadata), same(callOptions));
}
@Test @Test
public void authorityNotPropagatedToStream() { public void authorityNotPropagatedToStream() {
ClientCallImpl<Void, Void> call = new ClientCallImpl<Void, Void>( ClientCallImpl<Void, Void> call = new ClientCallImpl<Void, Void>(

View File

@ -43,12 +43,12 @@ import static org.mockito.Mockito.when;
import com.google.common.base.Supplier; import com.google.common.base.Supplier;
import io.grpc.CallOptions;
import io.grpc.IntegerMarshaller; import io.grpc.IntegerMarshaller;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.StringMarshaller; import io.grpc.StringMarshaller;
import io.grpc.internal.ClientTransport;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
@ -90,14 +90,19 @@ public class DelayedClientTransportTest {
private final Metadata headers = new Metadata(); private final Metadata headers = new Metadata();
private final Metadata headers2 = new Metadata(); private final Metadata headers2 = new Metadata();
private final CallOptions callOptions = CallOptions.DEFAULT.withAuthority("dummy_value");
private final CallOptions callOptions2 = CallOptions.DEFAULT.withAuthority("dummy_value2");
private final FakeClock fakeExecutor = new FakeClock(); private final FakeClock fakeExecutor = new FakeClock();
private final DelayedClientTransport delayedTransport = new DelayedClientTransport( private final DelayedClientTransport delayedTransport = new DelayedClientTransport(
fakeExecutor.scheduledExecutorService); fakeExecutor.scheduledExecutorService);
@Before public void setUp() { @Before public void setUp() {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
when(mockRealTransport.newStream(same(method), same(headers))).thenReturn(mockRealStream); when(mockRealTransport.newStream(same(method), same(headers), same(callOptions)))
when(mockRealTransport2.newStream(same(method2), same(headers2))).thenReturn(mockRealStream2); .thenReturn(mockRealStream);
when(mockRealTransport2.newStream(same(method2), same(headers2), same(callOptions2)))
.thenReturn(mockRealStream2);
delayedTransport.start(transportListener); delayedTransport.start(transportListener);
} }
@ -106,8 +111,8 @@ public class DelayedClientTransportTest {
} }
@Test public void transportsAreUsedInOrder() { @Test public void transportsAreUsedInOrder() {
delayedTransport.newStream(method, headers); delayedTransport.newStream(method, headers, callOptions);
delayedTransport.newStream(method2, headers2); delayedTransport.newStream(method2, headers2, callOptions2);
assertEquals(0, fakeExecutor.numPendingTasks()); assertEquals(0, fakeExecutor.numPendingTasks());
delayedTransport.setTransportSupplier(new Supplier<ClientTransport>() { delayedTransport.setTransportSupplier(new Supplier<ClientTransport>() {
final Iterator<ClientTransport> it = final Iterator<ClientTransport> it =
@ -118,13 +123,13 @@ public class DelayedClientTransportTest {
} }
}); });
assertEquals(1, fakeExecutor.runDueTasks()); assertEquals(1, fakeExecutor.runDueTasks());
verify(mockRealTransport).newStream(same(method), same(headers)); verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions));
verify(mockRealTransport2).newStream(same(method2), same(headers2)); verify(mockRealTransport2).newStream(same(method2), same(headers2), same(callOptions2));
} }
@Test public void streamStartThenSetTransport() { @Test public void streamStartThenSetTransport() {
assertFalse(delayedTransport.hasPendingStreams()); assertFalse(delayedTransport.hasPendingStreams());
ClientStream stream = delayedTransport.newStream(method, headers); ClientStream stream = delayedTransport.newStream(method, headers, callOptions);
stream.start(streamListener); stream.start(streamListener);
assertEquals(1, delayedTransport.getPendingStreamsCount()); assertEquals(1, delayedTransport.getPendingStreamsCount());
assertTrue(delayedTransport.hasPendingStreams()); assertTrue(delayedTransport.hasPendingStreams());
@ -134,12 +139,12 @@ public class DelayedClientTransportTest {
assertEquals(0, delayedTransport.getPendingStreamsCount()); assertEquals(0, delayedTransport.getPendingStreamsCount());
assertFalse(delayedTransport.hasPendingStreams()); assertFalse(delayedTransport.hasPendingStreams());
assertEquals(1, fakeExecutor.runDueTasks()); assertEquals(1, fakeExecutor.runDueTasks());
verify(mockRealTransport).newStream(same(method), same(headers)); verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions));
verify(mockRealStream).start(same(streamListener)); verify(mockRealStream).start(same(streamListener));
} }
@Test public void newStreamThenSetTransportThenShutdown() { @Test public void newStreamThenSetTransportThenShutdown() {
ClientStream stream = delayedTransport.newStream(method, headers); ClientStream stream = delayedTransport.newStream(method, headers, callOptions);
assertEquals(1, delayedTransport.getPendingStreamsCount()); assertEquals(1, delayedTransport.getPendingStreamsCount());
assertTrue(stream instanceof DelayedStream); assertTrue(stream instanceof DelayedStream);
delayedTransport.setTransport(mockRealTransport); delayedTransport.setTransport(mockRealTransport);
@ -148,7 +153,7 @@ public class DelayedClientTransportTest {
verify(transportListener).transportShutdown(any(Status.class)); verify(transportListener).transportShutdown(any(Status.class));
verify(transportListener).transportTerminated(); verify(transportListener).transportTerminated();
assertEquals(1, fakeExecutor.runDueTasks()); assertEquals(1, fakeExecutor.runDueTasks());
verify(mockRealTransport).newStream(same(method), same(headers)); verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions));
stream.start(streamListener); stream.start(streamListener);
verify(mockRealStream).start(same(streamListener)); verify(mockRealStream).start(same(streamListener));
} }
@ -166,11 +171,11 @@ public class DelayedClientTransportTest {
delayedTransport.shutdown(); delayedTransport.shutdown();
verify(transportListener).transportShutdown(any(Status.class)); verify(transportListener).transportShutdown(any(Status.class));
verify(transportListener).transportTerminated(); verify(transportListener).transportTerminated();
ClientStream stream = delayedTransport.newStream(method, headers); ClientStream stream = delayedTransport.newStream(method, headers, callOptions);
assertEquals(0, delayedTransport.getPendingStreamsCount()); assertEquals(0, delayedTransport.getPendingStreamsCount());
stream.start(streamListener); stream.start(streamListener);
assertFalse(stream instanceof DelayedStream); assertFalse(stream instanceof DelayedStream);
verify(mockRealTransport).newStream(same(method), same(headers)); verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions));
verify(mockRealStream).start(same(streamListener)); verify(mockRealStream).start(same(streamListener));
} }
@ -179,11 +184,11 @@ public class DelayedClientTransportTest {
delayedTransport.shutdownNow(Status.UNAVAILABLE); delayedTransport.shutdownNow(Status.UNAVAILABLE);
verify(transportListener).transportShutdown(any(Status.class)); verify(transportListener).transportShutdown(any(Status.class));
verify(transportListener).transportTerminated(); verify(transportListener).transportTerminated();
ClientStream stream = delayedTransport.newStream(method, headers); ClientStream stream = delayedTransport.newStream(method, headers, callOptions);
assertEquals(0, delayedTransport.getPendingStreamsCount()); assertEquals(0, delayedTransport.getPendingStreamsCount());
stream.start(streamListener); stream.start(streamListener);
assertFalse(stream instanceof DelayedStream); assertFalse(stream instanceof DelayedStream);
verify(mockRealTransport).newStream(same(method), same(headers)); verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions));
verify(mockRealStream).start(same(streamListener)); verify(mockRealStream).start(same(streamListener));
} }

View File

@ -200,14 +200,16 @@ public class ManagedChannelImplTest {
when(mockTransportFactory.newClientTransport( when(mockTransportFactory.newClientTransport(
any(SocketAddress.class), any(String.class), any(String.class))) any(SocketAddress.class), any(String.class), any(String.class)))
.thenReturn(mockTransport); .thenReturn(mockTransport);
when(mockTransport.newStream(same(method), same(headers))).thenReturn(mockStream); when(mockTransport.newStream(same(method), same(headers), same(CallOptions.DEFAULT)))
.thenReturn(mockStream);
call.start(mockCallListener, headers); call.start(mockCallListener, headers);
verify(mockTransportFactory, timeout(1000)) verify(mockTransportFactory, timeout(1000))
.newClientTransport(same(socketAddress), eq(authority), eq(userAgent)); .newClientTransport(same(socketAddress), eq(authority), eq(userAgent));
verify(mockTransport, timeout(1000)).start(transportListenerCaptor.capture()); verify(mockTransport, timeout(1000)).start(transportListenerCaptor.capture());
ManagedClientTransport.Listener transportListener = transportListenerCaptor.getValue(); ManagedClientTransport.Listener transportListener = transportListenerCaptor.getValue();
transportListener.transportReady(); transportListener.transportReady();
verify(mockTransport, timeout(1000)).newStream(same(method), same(headers)); verify(mockTransport, timeout(1000)).newStream(same(method), same(headers),
same(CallOptions.DEFAULT));
verify(mockStream, timeout(1000)).start(streamListenerCaptor.capture()); verify(mockStream, timeout(1000)).start(streamListenerCaptor.capture());
verify(mockStream).setCompressor(isA(Compressor.class)); verify(mockStream).setCompressor(isA(Compressor.class));
// Depends on how quick the real transport is created, ClientCallImpl may start on mockStream // Depends on how quick the real transport is created, ClientCallImpl may start on mockStream
@ -221,9 +223,11 @@ public class ManagedChannelImplTest {
ClientCall<String, Integer> call2 = channel.newCall(method, CallOptions.DEFAULT); ClientCall<String, Integer> call2 = channel.newCall(method, CallOptions.DEFAULT);
ClientStream mockStream2 = mock(ClientStream.class); ClientStream mockStream2 = mock(ClientStream.class);
Metadata headers2 = new Metadata(); Metadata headers2 = new Metadata();
when(mockTransport.newStream(same(method), same(headers2))).thenReturn(mockStream2); when(mockTransport.newStream(same(method), same(headers2), same(CallOptions.DEFAULT)))
.thenReturn(mockStream2);
call2.start(mockCallListener2, headers2); call2.start(mockCallListener2, headers2);
verify(mockTransport, timeout(1000)).newStream(same(method), same(headers2)); verify(mockTransport, timeout(1000)).newStream(same(method), same(headers2),
same(CallOptions.DEFAULT));
verify(mockStream2, timeout(1000)).start(streamListenerCaptor.capture()); verify(mockStream2, timeout(1000)).start(streamListenerCaptor.capture());
ClientStreamListener streamListener2 = streamListenerCaptor.getValue(); ClientStreamListener streamListener2 = streamListenerCaptor.getValue();
Metadata trailers = new Metadata(); Metadata trailers = new Metadata();
@ -278,14 +282,16 @@ public class ManagedChannelImplTest {
when(mockTransportFactory.newClientTransport( when(mockTransportFactory.newClientTransport(
any(SocketAddress.class), any(String.class), any(String.class))) any(SocketAddress.class), any(String.class), any(String.class)))
.thenReturn(mockTransport); .thenReturn(mockTransport);
when(mockTransport.newStream(same(method), same(headers))).thenReturn(mockStream); when(mockTransport.newStream(same(method), same(headers), same(CallOptions.DEFAULT)))
.thenReturn(mockStream);
call.start(mockCallListener, headers); call.start(mockCallListener, headers);
verify(mockTransportFactory, timeout(1000)) verify(mockTransportFactory, timeout(1000))
.newClientTransport(same(socketAddress), eq(authority), any(String.class)); .newClientTransport(same(socketAddress), eq(authority), any(String.class));
verify(mockTransport, timeout(1000)).start(transportListenerCaptor.capture()); verify(mockTransport, timeout(1000)).start(transportListenerCaptor.capture());
ManagedClientTransport.Listener transportListener = transportListenerCaptor.getValue(); ManagedClientTransport.Listener transportListener = transportListenerCaptor.getValue();
transportListener.transportReady(); transportListener.transportReady();
verify(mockTransport, timeout(1000)).newStream(same(method), same(headers)); verify(mockTransport, timeout(1000)).newStream(same(method), same(headers),
same(CallOptions.DEFAULT));
verify(mockStream, timeout(1000)).start(streamListenerCaptor.capture()); verify(mockStream, timeout(1000)).start(streamListenerCaptor.capture());
verify(mockStream).setCompressor(isA(Compressor.class)); verify(mockStream).setCompressor(isA(Compressor.class));
// Depends on how quick the real transport is created, ClientCallImpl may start on mockStream // Depends on how quick the real transport is created, ClientCallImpl may start on mockStream
@ -342,7 +348,8 @@ public class ManagedChannelImplTest {
// Create transport and call // Create transport and call
ClientStream mockStream = mock(ClientStream.class); ClientStream mockStream = mock(ClientStream.class);
Metadata headers = new Metadata(); Metadata headers = new Metadata();
when(mockTransport.newStream(same(method), same(headers))).thenReturn(mockStream); when(mockTransport.newStream(same(method), same(headers), same(CallOptions.DEFAULT)))
.thenReturn(mockStream);
call.start(mockCallListener, headers); call.start(mockCallListener, headers);
verify(mockTransport, timeout(1000)).start(transportListenerCaptor.capture()); verify(mockTransport, timeout(1000)).start(transportListenerCaptor.capture());
ManagedClientTransport.Listener transportListener = transportListenerCaptor.getValue(); ManagedClientTransport.Listener transportListener = transportListenerCaptor.getValue();
@ -444,18 +451,19 @@ public class ManagedChannelImplTest {
public void callOptionsExecutor() { public void callOptionsExecutor() {
Metadata headers = new Metadata(); Metadata headers = new Metadata();
ClientStream mockStream = mock(ClientStream.class); ClientStream mockStream = mock(ClientStream.class);
when(mockTransport.newStream(same(method), same(headers))).thenReturn(mockStream); when(mockTransport.newStream(same(method), same(headers), any(CallOptions.class)))
.thenReturn(mockStream);
FakeClock fakeExecutor = new FakeClock(); FakeClock fakeExecutor = new FakeClock();
ManagedChannel channel = createChannel( ManagedChannel channel = createChannel(
new FakeNameResolverFactory(true), NO_INTERCEPTOR); new FakeNameResolverFactory(true), NO_INTERCEPTOR);
CallOptions options = CallOptions.DEFAULT.withExecutor(fakeExecutor.scheduledExecutorService);
ClientCall<String, Integer> call = channel.newCall(method, CallOptions.DEFAULT.withExecutor( ClientCall<String, Integer> call = channel.newCall(method, options);
fakeExecutor.scheduledExecutorService));
call.start(mockCallListener, headers); call.start(mockCallListener, headers);
verify(mockTransport, timeout(1000)).start(transportListenerCaptor.capture()); verify(mockTransport, timeout(1000)).start(transportListenerCaptor.capture());
transportListenerCaptor.getValue().transportReady(); transportListenerCaptor.getValue().transportReady();
verify(mockTransport, timeout(1000)).newStream(same(method), same(headers)); verify(mockTransport, timeout(1000)).newStream(same(method), same(headers), same(options));
verify(mockStream, timeout(1000)).start(streamListenerCaptor.capture()); verify(mockStream, timeout(1000)).start(streamListenerCaptor.capture());
ClientStreamListener streamListener = streamListenerCaptor.getValue(); ClientStreamListener streamListener = streamListenerCaptor.getValue();
Metadata trailers = new Metadata(); Metadata trailers = new Metadata();
@ -608,7 +616,8 @@ public class ManagedChannelImplTest {
.newClientTransport(same(goodAddress), any(String.class), any(String.class)); .newClientTransport(same(goodAddress), any(String.class), any(String.class));
verify(goodTransport, timeout(1000)).start(goodTransportListenerCaptor.capture()); verify(goodTransport, timeout(1000)).start(goodTransportListenerCaptor.capture());
goodTransportListenerCaptor.getValue().transportReady(); goodTransportListenerCaptor.getValue().transportReady();
verify(goodTransport, timeout(1000)).newStream(same(method), same(headers)); verify(goodTransport, timeout(1000)).newStream(same(method), same(headers),
same(CallOptions.DEFAULT));
// The bad transport was never used. // The bad transport was never used.
verify(badTransport, times(0)).newStream(any(MethodDescriptor.class), any(Metadata.class)); verify(badTransport, times(0)).newStream(any(MethodDescriptor.class), any(Metadata.class));
} }
@ -707,7 +716,8 @@ public class ManagedChannelImplTest {
.newClientTransport(same(addr1), any(String.class), any(String.class)); .newClientTransport(same(addr1), any(String.class), any(String.class));
verify(transport1, timeout(1000)).start(transportListenerCaptor.capture()); verify(transport1, timeout(1000)).start(transportListenerCaptor.capture());
transportListenerCaptor.getValue().transportReady(); transportListenerCaptor.getValue().transportReady();
verify(transport1, timeout(1000)).newStream(same(method), same(headers)); verify(transport1, timeout(1000)).newStream(same(method), same(headers),
same(CallOptions.DEFAULT));
transportListenerCaptor.getValue().transportShutdown(Status.UNAVAILABLE); transportListenerCaptor.getValue().transportShutdown(Status.UNAVAILABLE);
// Second call still use the first address, since it was successfully connected. // Second call still use the first address, since it was successfully connected.
@ -717,7 +727,8 @@ public class ManagedChannelImplTest {
verify(mockTransportFactory, times(2)) verify(mockTransportFactory, times(2))
.newClientTransport(same(addr1), any(String.class), any(String.class)); .newClientTransport(same(addr1), any(String.class), any(String.class));
transportListenerCaptor.getValue().transportReady(); transportListenerCaptor.getValue().transportReady();
verify(transport2, timeout(1000)).newStream(same(method), same(headers)); verify(transport2, timeout(1000)).newStream(same(method), same(headers),
same(CallOptions.DEFAULT));
} }
@Test @Test

View File

@ -47,6 +47,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import io.grpc.Attributes; import io.grpc.Attributes;
import io.grpc.CallOptions;
import io.grpc.ClientInterceptor; import io.grpc.ClientInterceptor;
import io.grpc.CompressorRegistry; import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry; import io.grpc.DecompressorRegistry;
@ -98,6 +99,8 @@ public class ManagedChannelImplTransportManagerTest {
private final MethodDescriptor<String, String> method2 = MethodDescriptor.create( private final MethodDescriptor<String, String> method2 = MethodDescriptor.create(
MethodDescriptor.MethodType.UNKNOWN, "/service/method2", MethodDescriptor.MethodType.UNKNOWN, "/service/method2",
new StringMarshaller(), new StringMarshaller()); new StringMarshaller(), new StringMarshaller());
private final CallOptions callOptions = CallOptions.DEFAULT.withAuthority("dummy_value");
private final CallOptions callOptions2 = CallOptions.DEFAULT.withAuthority("dummy_value2");
private ManagedChannelImpl channel; private ManagedChannelImpl channel;
@ -188,7 +191,7 @@ public class ManagedChannelImplTransportManagerTest {
// Subsequent getTransport() will use the next address // Subsequent getTransport() will use the next address
ClientTransport t2 = tm.getTransport(addressGroup); ClientTransport t2 = tm.getTransport(addressGroup);
assertNotNull(t2); assertNotNull(t2);
t2.newStream(method, new Metadata()); t2.newStream(method, new Metadata(), callOptions);
// Will keep the previous back-off policy, and not consult back-off policy // Will keep the previous back-off policy, and not consult back-off policy
verify(mockTransportFactory, timeout(1000)).newClientTransport(addr2, authority, userAgent); verify(mockTransportFactory, timeout(1000)).newClientTransport(addr2, authority, userAgent);
verify(mockBackoffPolicyProvider, times(backoffReset)).get(); verify(mockBackoffPolicyProvider, times(backoffReset)).get();
@ -196,7 +199,8 @@ public class ManagedChannelImplTransportManagerTest {
ClientTransport rt2 = transportInfo.transport; ClientTransport rt2 = transportInfo.transport;
// Make the second transport ready // Make the second transport ready
transportInfo.listener.transportReady(); transportInfo.listener.transportReady();
verify(rt2, timeout(1000)).newStream(same(method), any(Metadata.class)); verify(rt2, timeout(1000)).newStream(same(method), any(Metadata.class),
same(callOptions));
verify(mockNameResolver, times(0)).refresh(); verify(mockNameResolver, times(0)).refresh();
// Disconnect the second transport // Disconnect the second transport
transportInfo.listener.transportShutdown(Status.UNAVAILABLE); transportInfo.listener.transportShutdown(Status.UNAVAILABLE);
@ -205,7 +209,7 @@ public class ManagedChannelImplTransportManagerTest {
// Subsequent getTransport() will use the first address, since last attempt was successful. // Subsequent getTransport() will use the first address, since last attempt was successful.
ClientTransport t3 = tm.getTransport(addressGroup); ClientTransport t3 = tm.getTransport(addressGroup);
t3.newStream(method2, new Metadata()); t3.newStream(method2, new Metadata(), callOptions2);
verify(mockTransportFactory, timeout(1000).times(2)) verify(mockTransportFactory, timeout(1000).times(2))
.newClientTransport(addr1, authority, userAgent); .newClientTransport(addr1, authority, userAgent);
// Still no back-off policy creation, because an address succeeded. // Still no back-off policy creation, because an address succeeded.
@ -213,7 +217,8 @@ public class ManagedChannelImplTransportManagerTest {
transportInfo = transports.poll(1, TimeUnit.SECONDS); transportInfo = transports.poll(1, TimeUnit.SECONDS);
ClientTransport rt3 = transportInfo.transport; ClientTransport rt3 = transportInfo.transport;
transportInfo.listener.transportReady(); transportInfo.listener.transportReady();
verify(rt3, timeout(1000)).newStream(same(method2), any(Metadata.class)); verify(rt3, timeout(1000)).newStream(same(method2), any(Metadata.class),
same(callOptions2));
verify(rt1, times(0)).newStream(any(MethodDescriptor.class), any(Metadata.class)); verify(rt1, times(0)).newStream(any(MethodDescriptor.class), any(Metadata.class));
// Back-off policy was never consulted. // Back-off policy was never consulted.

View File

@ -36,6 +36,7 @@ import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import io.grpc.CallOptions;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
@ -84,8 +85,8 @@ final class TestUtils {
@Override @Override
public ManagedClientTransport answer(InvocationOnMock invocation) throws Throwable { public ManagedClientTransport answer(InvocationOnMock invocation) throws Throwable {
final ManagedClientTransport mockTransport = mock(ManagedClientTransport.class); final ManagedClientTransport mockTransport = mock(ManagedClientTransport.class);
when(mockTransport.newStream(any(MethodDescriptor.class), any(Metadata.class))) when(mockTransport.newStream(any(MethodDescriptor.class), any(Metadata.class),
.thenReturn(mock(ClientStream.class)); any(CallOptions.class))).thenReturn(mock(ClientStream.class));
// Save the listener // Save the listener
doAnswer(new Answer<Void>() { doAnswer(new Answer<Void>() {
@Override @Override

View File

@ -48,6 +48,7 @@ import static org.mockito.Mockito.when;
import com.google.common.base.Stopwatch; import com.google.common.base.Stopwatch;
import io.grpc.CallOptions;
import io.grpc.EquivalentAddressGroup; import io.grpc.EquivalentAddressGroup;
import io.grpc.IntegerMarshaller; import io.grpc.IntegerMarshaller;
import io.grpc.LoadBalancer; import io.grpc.LoadBalancer;
@ -450,7 +451,8 @@ public class TransportSetTest {
verify(transportInfo.transport, times(0)).newStream( verify(transportInfo.transport, times(0)).newStream(
any(MethodDescriptor.class), any(Metadata.class)); any(MethodDescriptor.class), any(Metadata.class));
assertEquals(1, fakeExecutor.runDueTasks()); assertEquals(1, fakeExecutor.runDueTasks());
verify(transportInfo.transport).newStream(same(method), same(headers)); verify(transportInfo.transport).newStream(same(method), same(headers),
same(CallOptions.DEFAULT));
verify(transportInfo.transport).shutdown(); verify(transportInfo.transport).shutdown();
transportInfo.listener.transportShutdown(Status.UNAVAILABLE); transportInfo.listener.transportShutdown(Status.UNAVAILABLE);
verify(mockTransportSetCallback, never()).onTerminated(); verify(mockTransportSetCallback, never()).onTerminated();

View File

@ -36,6 +36,7 @@ import static io.netty.channel.ChannelOption.SO_KEEPALIVE;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.base.Ticker; import com.google.common.base.Ticker;
import io.grpc.CallOptions;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.Status; import io.grpc.Status;
@ -112,7 +113,8 @@ class NettyClientTransport implements ManagedClientTransport {
} }
@Override @Override
public ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers) { public ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers, CallOptions
callOptions) {
Preconditions.checkNotNull(method, "method"); Preconditions.checkNotNull(method, "method");
Preconditions.checkNotNull(headers, "headers"); Preconditions.checkNotNull(headers, "headers");
return new NettyClientStream(method, headers, channel, handler, maxMessageSize, authority, return new NettyClientStream(method, headers, channel, handler, maxMessageSize, authority,
@ -124,6 +126,11 @@ class NettyClientTransport implements ManagedClientTransport {
}; };
} }
@Override
public ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers) {
return newStream(method, headers, CallOptions.DEFAULT);
}
@Override @Override
public void start(Listener transportListener) { public void start(Listener transportListener) {
lifecycleManager = new ClientTransportLifecycleManager( lifecycleManager = new ClientTransportLifecycleManager(

View File

@ -39,6 +39,7 @@ import com.google.common.base.Stopwatch;
import com.google.common.base.Ticker; import com.google.common.base.Ticker;
import com.google.common.util.concurrent.SettableFuture; import com.google.common.util.concurrent.SettableFuture;
import io.grpc.CallOptions;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.MethodType; import io.grpc.MethodDescriptor.MethodType;
@ -246,13 +247,20 @@ class OkHttpClientTransport implements ManagedClientTransport {
} }
@Override @Override
public OkHttpClientStream newStream(final MethodDescriptor<?, ?> method, final Metadata headers) { public OkHttpClientStream newStream(final MethodDescriptor<?, ?> method, final Metadata
headers, CallOptions callOptions) {
Preconditions.checkNotNull(method, "method"); Preconditions.checkNotNull(method, "method");
Preconditions.checkNotNull(headers, "headers"); Preconditions.checkNotNull(headers, "headers");
return new OkHttpClientStream(method, headers, frameWriter, OkHttpClientTransport.this, return new OkHttpClientStream(method, headers, frameWriter, OkHttpClientTransport.this,
outboundFlow, lock, maxMessageSize, defaultAuthority, userAgent); outboundFlow, lock, maxMessageSize, defaultAuthority, userAgent);
} }
@Override
public OkHttpClientStream newStream(final MethodDescriptor<?, ?> method, final Metadata
headers) {
return newStream(method, headers, CallOptions.DEFAULT);
}
@GuardedBy("lock") @GuardedBy("lock")
void streamReadyToStart(OkHttpClientStream clientStream) { void streamReadyToStart(OkHttpClientStream clientStream) {
synchronized (lock) { synchronized (lock) {