all/tests: unmock ClientCall and ServerCall

This commit is contained in:
Carl Mastrangelo 2016-08-29 13:25:33 -07:00
parent 3bf8d94f02
commit 48c6b3d398
8 changed files with 406 additions and 131 deletions

View File

@ -31,11 +31,12 @@
package io.grpc.auth;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.isA;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ -102,8 +103,7 @@ public class ClientAuthInterceptorTest {
@Mock
Channel channel;
@Mock
ClientCall<String, Integer> call;
ClientCallRecorder call = new ClientCallRecorder();
ClientAuthInterceptor interceptor;
@ -130,7 +130,8 @@ public class ClientAuthInterceptorTest {
interceptor.interceptCall(descriptor, CallOptions.DEFAULT, channel);
Metadata headers = new Metadata();
interceptedCall.start(listener, headers);
verify(call).start(listener, headers);
assertEquals(listener, call.responseListener);
assertEquals(headers, call.headers);
Iterable<String> authorization = headers.getAll(AUTHORIZATION);
Assert.assertArrayEquals(new String[]{"token1", "token2"},
@ -150,7 +151,8 @@ public class ClientAuthInterceptorTest {
ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(Status.class);
Mockito.verify(listener).onClose(statusCaptor.capture(), isA(Metadata.class));
Assert.assertNull(headers.getAll(AUTHORIZATION));
Mockito.verify(call, never()).start(listener, headers);
assertNull(call.responseListener);
assertNull(call.headers);
Assert.assertEquals(Status.Code.UNAUTHENTICATED, statusCaptor.getValue().getCode());
Assert.assertNotNull(statusCaptor.getValue().getCause());
}
@ -169,7 +171,8 @@ public class ClientAuthInterceptorTest {
interceptor.interceptCall(descriptor, CallOptions.DEFAULT, channel);
Metadata headers = new Metadata();
interceptedCall.start(listener, headers);
verify(call).start(listener, headers);
assertEquals(listener, call.responseListener);
assertEquals(headers, call.headers);
Iterable<String> authorization = headers.getAll(AUTHORIZATION);
Assert.assertArrayEquals(new String[]{"Bearer allyourbase"},
Iterables.toArray(authorization, String.class));
@ -191,4 +194,42 @@ public class ClientAuthInterceptorTest {
verify(credentials).getRequestMetadata(URI.create("https://example.com:123/a.service"));
interceptedCall.cancel("Cancel for test", null);
}
private static final class ClientCallRecorder extends ClientCall<String, Integer> {
private ClientCall.Listener<Integer> responseListener;
private Metadata headers;
private int numMessages;
private String cancelMessage;
private Throwable cancelCause;
private boolean halfClosed;
private String sentMessage;
@Override
public void start(ClientCall.Listener<Integer> responseListener, Metadata headers) {
this.responseListener = responseListener;
this.headers = headers;
}
@Override
public void request(int numMessages) {
this.numMessages = numMessages;
}
@Override
public void cancel(String message, Throwable cause) {
this.cancelMessage = message;
this.cancelCause = cause;
}
@Override
public void halfClose() {
halfClosed = true;
}
@Override
public void sendMessage(String message) {
sentMessage = message;
}
}
}

View File

@ -31,17 +31,16 @@
package io.grpc;
import static com.google.common.truth.Truth.assertThat;
import static java.util.concurrent.TimeUnit.NANOSECONDS;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.isA;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
@ -61,8 +60,6 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.util.ArrayList;
import java.util.Arrays;
@ -75,8 +72,7 @@ public class ClientInterceptorsTest {
@Mock
private Channel channel;
@Mock
private ClientCall<String, Integer> call;
private BaseClientCall call = new BaseClientCall();
@Mock
private MethodDescriptor<String, Integer> method;
@ -89,18 +85,6 @@ public class ClientInterceptorsTest {
when(channel.newCall(
Mockito.<MethodDescriptor<String, Integer>>any(), any(CallOptions.class)))
.thenReturn(call);
// Emulate the precondition checks in ChannelImpl.CallImpl
Answer<Void> checkStartCalled = new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) {
verify(call).start(Mockito.<ClientCall.Listener<Integer>>any(), Mockito.<Metadata>any());
return null;
}
};
doAnswer(checkStartCalled).when(call).request(anyInt());
doAnswer(checkStartCalled).when(call).halfClose();
doAnswer(checkStartCalled).when(call).sendMessage(Mockito.<String>any());
}
@Test(expected = NullPointerException.class)
@ -290,11 +274,10 @@ public class ClientInterceptorsTest {
ClientCall<String, Integer> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT);
// start() on the intercepted call will eventually reach the call created by the real channel
interceptedCall.start(listener, new Metadata());
ArgumentCaptor<Metadata> captor = ArgumentCaptor.forClass(Metadata.class);
// The headers passed to the real channel call will contain the information inserted by the
// interceptor.
verify(call).start(same(listener), captor.capture());
assertEquals("abcd", captor.getValue().get(credKey));
assertSame(listener, call.listener);
assertEquals("abcd", call.headers.get(credKey));
}
@Test
@ -327,12 +310,11 @@ public class ClientInterceptorsTest {
ClientCall<String, Integer> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT);
interceptedCall.start(listener, new Metadata());
// Capture the underlying call listener that will receive headers from the transport.
ArgumentCaptor<ClientCall.Listener<Integer>> captor = ArgumentCaptor.forClass(null);
verify(call).start(captor.capture(), Mockito.<Metadata>any());
Metadata inboundHeaders = new Metadata();
// Simulate that a headers arrives on the underlying call listener.
captor.getValue().onHeaders(inboundHeaders);
assertEquals(Arrays.asList(inboundHeaders), examinedHeaders);
call.listener.onHeaders(inboundHeaders);
assertThat(examinedHeaders).contains(inboundHeaders);
}
@Test
@ -354,13 +336,14 @@ public class ClientInterceptorsTest {
ClientCall.Listener<Integer> listener = mock(ClientCall.Listener.class);
Metadata headers = new Metadata();
interceptedCall.start(listener, headers);
verify(call).start(same(listener), same(headers));
assertSame(listener, call.listener);
assertSame(headers, call.headers);
interceptedCall.sendMessage("request");
verify(call).sendMessage(eq("request"));
assertThat(call.messages).containsExactly("request");
interceptedCall.halfClose();
verify(call).halfClose();
assertTrue(call.halfClosed);
interceptedCall.request(1);
verify(call).request(1);
assertThat(call.requests).containsExactly(1);
}
@Test
@ -392,7 +375,7 @@ public class ClientInterceptorsTest {
interceptedCall.sendMessage("request");
interceptedCall.halfClose();
interceptedCall.request(1);
verifyNoMoreInteractions(call);
call.done = true;
ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
verify(listener).onClose(captor.capture(), any(Metadata.class));
assertSame(error, captor.getValue().getCause());
@ -406,7 +389,6 @@ public class ClientInterceptorsTest {
noop.halfClose();
noop.sendMessage(null);
assertFalse(noop.isReady());
verifyNoMoreInteractions(call);
}
@Test
@ -432,12 +414,12 @@ public class ClientInterceptorsTest {
CallOptions callOptions = CallOptions.DEFAULT.withOption(customOption, "value");
ArgumentCaptor<CallOptions> passedOptions = ArgumentCaptor.forClass(CallOptions.class);
ClientInterceptor interceptor = spy(new NoopInterceptor());
Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
assertSame(call, intercepted.newCall(method, callOptions));
verify(channel).newCall(same(method), same(callOptions));
verify(interceptor).interceptCall(same(method), passedOptions.capture(), isA(Channel.class));
assertSame("value", passedOptions.getValue().getOption(customOption));
}
@ -449,4 +431,64 @@ public class ClientInterceptorsTest {
return next.newCall(method, callOptions);
}
}
private static class BaseClientCall extends ClientCall<String, Integer> {
private boolean started;
private boolean done;
private ClientCall.Listener<Integer> listener;
private Metadata headers;
private List<Integer> requests = new ArrayList<Integer>();
private List<String> messages = new ArrayList<String>();
private boolean halfClosed;
private Throwable cancelCause;
private String cancelMessage;
@Override
public void start(ClientCall.Listener<Integer> listener, Metadata headers) {
checkNotDone();
started = true;
this.listener = listener;
this.headers = headers;
}
@Override
public void request(int numMessages) {
checkNotDone();
checkStarted();
requests.add(numMessages);
}
@Override
public void cancel(String message, Throwable cause) {
checkNotDone();
this.cancelMessage = message;
this.cancelCause = cause;
}
@Override
public void halfClose() {
checkNotDone();
checkStarted();
this.halfClosed = true;
}
@Override
public void sendMessage(String message) {
checkNotDone();
checkStarted();
messages.add(message);
}
private void checkNotDone() {
if (done) {
throw new IllegalStateException("no more methods should be called");
}
}
private void checkStarted() {
if (!started) {
throw new IllegalStateException("should have called start");
}
}
}
}

View File

@ -45,6 +45,7 @@ import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
import io.grpc.internal.FakeClock;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@ -66,7 +67,30 @@ public class ContextsTest {
@SuppressWarnings("unchecked")
private MethodDescriptor<Object, Object> method = mock(MethodDescriptor.class);
@SuppressWarnings("unchecked")
private ServerCall<Object, Object> call = mock(ServerCall.class);
private ServerCall<Object, Object> call = new ServerCall<Object, Object>() {
@Override
public void request(int numMessages) {}
@Override
public void sendHeaders(Metadata headers) {}
@Override
public void sendMessage(Object message) {}
@Override
public void close(Status status, Metadata trailers) {}
@Override
public boolean isCancelled() {
return false;
}
@Override
public MethodDescriptor<Object, Object> getMethodDescriptor() {
return null;
}
};
private Metadata headers = new Metadata();
@Test

View File

@ -45,7 +45,6 @@ import static org.mockito.Mockito.verifyZeroInteractions;
import io.grpc.MethodDescriptor.Marshaller;
import io.grpc.MethodDescriptor.MethodType;
import io.grpc.ServerCall.Listener;
import io.grpc.ServerMethodDefinition;
import org.junit.After;
import org.junit.Before;
@ -78,9 +77,8 @@ public class ServerInterceptorsTest {
private ServerCall.Listener<String> listener;
private MethodDescriptor<String, Integer> flowMethod;
@Mock
private ServerCall<String, Integer> call;
private ServerCall<String, Integer> call = new BaseServerCall<String, Integer>();
private ServerServiceDefinition serviceDefinition;
@ -282,7 +280,7 @@ public class ServerInterceptorsTest {
@Test
public void argumentsPassed() {
@SuppressWarnings("unchecked")
final ServerCall<String, Integer> call2 = mock(ServerCall.class);
final ServerCall<String, Integer> call2 = new BaseServerCall<String, Integer>();
@SuppressWarnings("unchecked")
final ServerCall.Listener<String> listener2 = mock(ServerCall.Listener.class);
@ -408,7 +406,7 @@ public class ServerInterceptorsTest {
.intercept(inputStreamMessageService, interceptor2);
ServerMethodDefinition<InputStream, InputStream> serverMethod =
(ServerMethodDefinition<InputStream, InputStream>) intercepted2.getMethod("basic/wrapped");
ServerCall<InputStream, InputStream> call2 = mock(ServerCall.class);
ServerCall<InputStream, InputStream> call2 = new BaseServerCall<InputStream, InputStream>();
byte[] bytes = {};
serverMethod
.getServerCallHandler()
@ -459,4 +457,29 @@ public class ServerInterceptorsTest {
return inputStream;
}
}
private static class BaseServerCall<ReqT, RespT> extends ServerCall<ReqT, RespT> {
@Override
public void request(int numMessages) {}
@Override
public void sendHeaders(Metadata headers) {}
@Override
public void sendMessage(RespT message) {}
@Override
public void close(Status status, Metadata trailers) {}
@Override
public boolean isCancelled() {
return false;
}
@Override
public MethodDescriptor<ReqT, RespT> getMethodDescriptor() {
return null;
}
}
}

View File

@ -48,6 +48,7 @@ import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.Deadline;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.testing.integration.Messages.SimpleRequest;
import io.grpc.testing.integration.Messages.SimpleResponse;
@ -73,17 +74,37 @@ public class StubConfigTest {
@Mock
private StreamObserver<SimpleResponse> responseObserver;
@Mock
private ClientCall<SimpleRequest, SimpleResponse> call;
/**
* Sets up mocks.
*/
@Before public void setUp() {
MockitoAnnotations.initMocks(this);
ClientCall<SimpleRequest, SimpleResponse> call =
new ClientCall<SimpleRequest, SimpleResponse>() {
@Override
public void start(
ClientCall.Listener<SimpleResponse> responseListener, Metadata headers) {
}
@Override
public void request(int numMessages) {
}
@Override
public void cancel(String message, Throwable cause) {
}
@Override
public void halfClose() {
}
@Override
public void sendMessage(SimpleRequest message) {
}
};
when(channel.newCall(
Mockito.<MethodDescriptor<SimpleRequest, SimpleResponse>>any(), any(CallOptions.class)))
.thenReturn(call);
Mockito.<MethodDescriptor<SimpleRequest, SimpleResponse>>any(), any(CallOptions.class)))
.thenReturn(call);
}
@Test

View File

@ -5,6 +5,7 @@ plugins {
description = "gRPC: Stub"
dependencies {
compile project(':grpc-core')
testCompile libraries.truth
}
// Configure the animal sniffer plugin

View File

@ -31,21 +31,18 @@
package io.grpc.stub;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.grpc.CallOptions;
import io.grpc.ClientCall;
import io.grpc.ClientCall.Listener;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
@ -64,12 +61,7 @@ import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.util.ArrayList;
import java.util.Arrays;
@ -80,6 +72,7 @@ import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
/**
* Unit tests for {@link ClientCalls}.
@ -96,9 +89,6 @@ public class ClientCallsTest {
private Server server;
private ManagedChannel channel;
@Mock
private ClientCall<Integer, String> call;
@Before
public void setUp() {
MockitoAnnotations.initMocks(this);
@ -121,16 +111,13 @@ public class ClientCallsTest {
final Status status = Status.OK;
final Metadata trailers = new Metadata();
doAnswer(new Answer<Void>() {
BaseClientCall call = new BaseClientCall() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
@SuppressWarnings("unchecked")
Listener<String> listener = (Listener<String>) in.getArguments()[0];
public void start(ClientCall.Listener<String> listener, Metadata headers) {
listener.onMessage(resp);
listener.onClose(status, trailers);
return null;
}
}).when(call).start(Mockito.<Listener<String>>any(), any(Metadata.class));
};
String actualResponse = ClientCalls.blockingUnaryCall(call, req);
assertEquals(resp, actualResponse);
@ -142,15 +129,12 @@ public class ClientCallsTest {
final Status status = Status.INTERNAL.withDescription("Unique status");
final Metadata trailers = new Metadata();
doAnswer(new Answer<Void>() {
BaseClientCall call = new BaseClientCall() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
@SuppressWarnings("unchecked")
Listener<String> listener = (Listener<String>) in.getArguments()[0];
public void start(io.grpc.ClientCall.Listener<String> listener, Metadata headers) {
listener.onClose(status, trailers);
return null;
}
}).when(call).start(Mockito.<Listener<String>>any(), any(Metadata.class));
};
try {
ClientCalls.blockingUnaryCall(call, req);
@ -163,27 +147,50 @@ public class ClientCallsTest {
@Test
public void unaryFutureCallSuccess() throws Exception {
final AtomicReference<ClientCall.Listener<String>> listener =
new AtomicReference<ClientCall.Listener<String>>();
final AtomicReference<Integer> message = new AtomicReference<Integer>();
final AtomicReference<Boolean> halfClosed = new AtomicReference<Boolean>();
BaseClientCall call = new BaseClientCall() {
@Override
public void start(io.grpc.ClientCall.Listener<String> responseListener, Metadata headers) {
listener.set(responseListener);
}
@Override
public void sendMessage(Integer msg) {
message.set(msg);
}
@Override
public void halfClose() {
halfClosed.set(true);
}
};
Integer req = 2;
ListenableFuture<String> future = ClientCalls.futureUnaryCall(call, req);
ArgumentCaptor<ClientCall.Listener<String>> listenerCaptor = ArgumentCaptor.forClass(null);
verify(call).start(listenerCaptor.capture(), any(Metadata.class));
ClientCall.Listener<String> listener = listenerCaptor.getValue();
verify(call).sendMessage(req);
verify(call).halfClose();
listener.onMessage("bar");
listener.onClose(Status.OK, new Metadata());
assertEquals(req, message.get());
assertTrue(halfClosed.get());
listener.get().onMessage("bar");
listener.get().onClose(Status.OK, new Metadata());
assertEquals("bar", future.get());
}
@Test
public void unaryFutureCallFailed() throws Exception {
final AtomicReference<ClientCall.Listener<String>> listener =
new AtomicReference<ClientCall.Listener<String>>();
BaseClientCall call = new BaseClientCall() {
@Override
public void start(io.grpc.ClientCall.Listener<String> responseListener, Metadata headers) {
listener.set(responseListener);
}
};
Integer req = 2;
ListenableFuture<String> future = ClientCalls.futureUnaryCall(call, req);
ArgumentCaptor<ClientCall.Listener<String>> listenerCaptor = ArgumentCaptor.forClass(null);
verify(call).start(listenerCaptor.capture(), any(Metadata.class));
ClientCall.Listener<String> listener = listenerCaptor.getValue();
Metadata trailers = new Metadata();
listener.onClose(Status.INTERNAL, trailers);
listener.get().onClose(Status.INTERNAL, trailers);
try {
future.get();
fail("Should fail");
@ -197,15 +204,29 @@ public class ClientCallsTest {
@Test
public void unaryFutureCallCancelled() throws Exception {
final AtomicReference<ClientCall.Listener<String>> listener =
new AtomicReference<ClientCall.Listener<String>>();
final AtomicReference<String> cancelMessage = new AtomicReference<String>();
final AtomicReference<Throwable> cancelCause = new AtomicReference<Throwable>();
BaseClientCall call = new BaseClientCall() {
@Override
public void start(io.grpc.ClientCall.Listener<String> responseListener, Metadata headers) {
listener.set(responseListener);
}
@Override
public void cancel(String message, Throwable cause) {
cancelMessage.set(message);
cancelCause.set(cause);
}
};
Integer req = 2;
ListenableFuture<String> future = ClientCalls.futureUnaryCall(call, req);
ArgumentCaptor<ClientCall.Listener<String>> listenerCaptor = ArgumentCaptor.forClass(null);
verify(call).start(listenerCaptor.capture(), any(Metadata.class));
ClientCall.Listener<String> listener = listenerCaptor.getValue();
future.cancel(true);
verify(call).cancel("GrpcFuture was cancelled", null);
listener.onMessage("bar");
listener.onClose(Status.OK, new Metadata());
assertEquals("GrpcFuture was cancelled", cancelMessage.get());
assertNull(cancelCause.get());
listener.get().onMessage("bar");
listener.get().onClose(Status.OK, new Metadata());
try {
future.get();
fail("Should fail");
@ -216,6 +237,7 @@ public class ClientCallsTest {
@Test
public void cannotSetOnReadyAfterCallStarted() throws Exception {
BaseClientCall call = new BaseClientCall();
CallStreamObserver<Integer> callStreamObserver =
(CallStreamObserver<Integer>) ClientCalls.asyncClientStreamingCall(call,
new NoopStreamObserver<String>());
@ -235,7 +257,20 @@ public class ClientCallsTest {
@Test
public void disablingInboundAutoFlowControlSuppressesRequestsForMoreMessages()
throws Exception {
ArgumentCaptor<ClientCall.Listener<String>> listenerCaptor = ArgumentCaptor.forClass(null);
final AtomicReference<ClientCall.Listener<String>> listener =
new AtomicReference<ClientCall.Listener<String>>();
final List<Integer> requests = new ArrayList<Integer>();
BaseClientCall call = new BaseClientCall() {
@Override
public void start(io.grpc.ClientCall.Listener<String> responseListener, Metadata headers) {
listener.set(responseListener);
}
@Override
public void request(int numMessages) {
requests.add(numMessages);
}
};
ClientCalls.asyncBidiStreamingCall(call, new ClientResponseObserver<Integer, String>() {
@Override
public void beforeStart(ClientCallStreamObserver<Integer> requestStream) {
@ -257,15 +292,13 @@ public class ClientCallsTest {
}
});
verify(call).start(listenerCaptor.capture(), any(Metadata.class));
listenerCaptor.getValue().onMessage("message");
verify(call, times(1)).request(1);
listener.get().onMessage("message");
assertThat(requests).containsExactly(1);
}
@Test
public void callStreamObserverPropagatesFlowControlRequestsToCall()
throws Exception {
ArgumentCaptor<ClientCall.Listener<String>> listenerCaptor = ArgumentCaptor.forClass(null);
ClientResponseObserver<Integer, String> responseObserver =
new ClientResponseObserver<Integer, String>() {
@Override
@ -285,19 +318,32 @@ public class ClientCallsTest {
public void onCompleted() {
}
};
final AtomicReference<ClientCall.Listener<String>> listener =
new AtomicReference<ClientCall.Listener<String>>();
final List<Integer> requests = new ArrayList<Integer>();
BaseClientCall call = new BaseClientCall() {
@Override
public void start(io.grpc.ClientCall.Listener<String> responseListener, Metadata headers) {
listener.set(responseListener);
}
@Override
public void request(int numMessages) {
requests.add(numMessages);
}
};
CallStreamObserver<Integer> requestObserver =
(CallStreamObserver<Integer>)
ClientCalls.asyncBidiStreamingCall(call, responseObserver);
verify(call).start(listenerCaptor.capture(), any(Metadata.class));
listenerCaptor.getValue().onMessage("message");
listener.get().onMessage("message");
requestObserver.request(5);
verify(call, times(1)).request(5);
assertThat(requests).contains(5);
}
@Test
public void canCaptureInboundFlowControlForServerStreamingObserver()
throws Exception {
ArgumentCaptor<ClientCall.Listener<String>> listenerCaptor = ArgumentCaptor.forClass(null);
ClientResponseObserver<Integer, String> responseObserver =
new ClientResponseObserver<Integer, String>() {
@Override
@ -318,11 +364,23 @@ public class ClientCallsTest {
public void onCompleted() {
}
};
final AtomicReference<ClientCall.Listener<String>> listener =
new AtomicReference<ClientCall.Listener<String>>();
final List<Integer> requests = new ArrayList<Integer>();
BaseClientCall call = new BaseClientCall() {
@Override
public void start(io.grpc.ClientCall.Listener<String> responseListener, Metadata headers) {
listener.set(responseListener);
}
@Override
public void request(int numMessages) {
requests.add(numMessages);
}
};
ClientCalls.asyncServerStreamingCall(call, 1, responseObserver);
verify(call).start(listenerCaptor.capture(), any(Metadata.class));
listenerCaptor.getValue().onMessage("message");
verify(call, times(1)).request(1);
verify(call, times(1)).request(5);
listener.get().onMessage("message");
assertThat(requests).containsExactly(5, 1).inOrder();
}
@Test
@ -497,13 +555,20 @@ public class ClientCallsTest {
@Test
public void blockingResponseStreamFailed() throws Exception {
final AtomicReference<ClientCall.Listener<String>> listener =
new AtomicReference<ClientCall.Listener<String>>();
BaseClientCall call = new BaseClientCall() {
@Override
public void start(io.grpc.ClientCall.Listener<String> responseListener, Metadata headers) {
listener.set(responseListener);
}
};
Integer req = 2;
Iterator<String> iter = ClientCalls.blockingServerStreamingCall(call, req);
ArgumentCaptor<ClientCall.Listener<String>> listenerCaptor = ArgumentCaptor.forClass(null);
verify(call).start(listenerCaptor.capture(), any(Metadata.class));
ClientCall.Listener<String> listener = listenerCaptor.getValue();
Metadata trailers = new Metadata();
listener.onClose(Status.INTERNAL, trailers);
listener.get().onClose(Status.INTERNAL, trailers);
try {
iter.next();
fail("Should fail");
@ -514,4 +579,21 @@ public class ClientCallsTest {
assertSame(trailers, metadata);
}
}
private static class BaseClientCall extends ClientCall<Integer, String> {
@Override
public void start(io.grpc.ClientCall.Listener<String> responseListener, Metadata headers) {}
@Override
public void request(int numMessages) {}
@Override
public void cancel(String message, Throwable cause) {}
@Override
public void halfClose() {}
@Override
public void sendMessage(Integer message) {}
}
}

View File

@ -31,12 +31,12 @@
package io.grpc.stub;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.times;
import io.grpc.CallOptions;
import io.grpc.ClientCall;
@ -51,13 +51,9 @@ import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.internal.ManagedChannelImpl;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
@ -65,6 +61,8 @@ import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
@ -88,13 +86,7 @@ public class ServerCallsTest {
"some/unarymethod",
new IntegerMarshaller(), new IntegerMarshaller());
@Mock
ServerCall<Integer, Integer> serverCall;
@Before
public void setUp() throws Exception {
MockitoAnnotations.initMocks(this);
}
private final ServerCallRecorder serverCall = new ServerCallRecorder();
@Test
public void runtimeStreamObserverIsServerCallStreamObserver() throws Exception {
@ -130,8 +122,8 @@ public class ServerCallsTest {
});
ServerCall.Listener<Integer> callListener =
callHandler.startCall(serverCall, new Metadata());
Mockito.when(serverCall.isReady()).thenReturn(true).thenReturn(false);
Mockito.when(serverCall.isCancelled()).thenReturn(false).thenReturn(true);
serverCall.isReady = true;
serverCall.isCancelled = false;
assertTrue(callObserver.get().isReady());
assertFalse(callObserver.get().isCancelled());
callListener.onReady();
@ -140,11 +132,13 @@ public class ServerCallsTest {
assertTrue(invokeCalled.get());
assertTrue(onReadyCalled.get());
assertTrue(onCancelCalled.get());
serverCall.isReady = false;
serverCall.isCancelled = true;
assertFalse(callObserver.get().isReady());
assertTrue(callObserver.get().isCancelled());
// Is called twice, once to permit the first message and once again after the first message
// has been processed (auto flow control)
Mockito.verify(serverCall, times(2)).request(1);
assertThat(serverCall.requestCalls).containsExactly(1, 1).inOrder();
}
@Test
@ -247,7 +241,7 @@ public class ServerCallsTest {
// to verify that message delivery does not trigger a call to request(1).
callListener.onMessage(1);
// Should never be called
Mockito.verify(serverCall, times(0)).request(1);
assertThat(serverCall.requestCalls).isEmpty();
}
@Test
@ -265,7 +259,7 @@ public class ServerCallsTest {
callHandler.startCall(serverCall, new Metadata());
// Auto inbound flow-control always requests 2 messages for unary to detect a violation
// of the unary semantic.
Mockito.verify(serverCall, times(1)).request(2);
assertThat(serverCall.requestCalls).containsExactly(2);
}
@Test
@ -288,8 +282,8 @@ public class ServerCallsTest {
});
ServerCall.Listener<Integer> callListener =
callHandler.startCall(serverCall, new Metadata());
Mockito.when(serverCall.isReady()).thenReturn(true).thenReturn(false);
Mockito.when(serverCall.isCancelled()).thenReturn(false).thenReturn(true);
serverCall.isReady = true;
serverCall.isCancelled = false;
callListener.onReady();
// On ready is not called until the unary request message is delivered
assertEquals(0, onReadyCalled.get());
@ -392,4 +386,51 @@ public class ServerCallsTest {
}
}
}
private static class ServerCallRecorder extends ServerCall<Integer, Integer> {
private List<Integer> requestCalls = new ArrayList<Integer>();
private Metadata headers;
private Integer message;
private Metadata trailers;
private Status status;
private boolean isCancelled;
private MethodDescriptor<Integer, Integer> methodDescriptor;
private boolean isReady;
@Override
public void request(int numMessages) {
requestCalls.add(numMessages);
}
@Override
public void sendHeaders(Metadata headers) {
this.headers = headers;
}
@Override
public void sendMessage(Integer message) {
this.message = message;
}
@Override
public void close(Status status, Metadata trailers) {
this.status = status;
this.trailers = trailers;
}
@Override
public boolean isCancelled() {
return isCancelled;
}
@Override
public boolean isReady() {
return isReady;
}
@Override
public MethodDescriptor<Integer, Integer> getMethodDescriptor() {
return methodDescriptor;
}
}
}