netty: Remove option to pass promise to WriteQueue

Passing a promise to WriteQueue was only misused to add a listener on
the promise before issuing the write. Although in this case the listener
ordering will be "random" because listeners are being added from two
different threads, in general we always want to add a listener after the
write returns to let any lower-level listeners be registered first.

Future work can resolve the "random" listener order by passing the
listener to the WriteQueue and adding the listener from the event loop.
This commit is contained in:
Eric Anderson 2018-05-30 17:46:40 -07:00
parent e393b4a1c1
commit ed4b5f30f4
7 changed files with 26 additions and 79 deletions

View File

@ -165,9 +165,8 @@ class NettyClientStream extends AbstractClientStream {
if (numBytes > 0) {
// Add the bytes to outbound flow control.
onSendingBytes(numBytes);
writeQueue.enqueue(
new SendGrpcFrameCommand(transportState(), bytebuf, endOfStream),
channel.newPromise().addListener(new ChannelFutureListener() {
writeQueue.enqueue(new SendGrpcFrameCommand(transportState(), bytebuf, endOfStream), flush)
.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
// If the future succeeds when http2stream is null, the stream has been cancelled
@ -179,7 +178,7 @@ class NettyClientStream extends AbstractClientStream {
NettyClientStream.this.getTransportTracer().reportMessageSent(numMessages);
}
}
}), flush);
});
} else {
// The frame is empty and will not impact outbound flow control. Just send it.
writeQueue.enqueue(new SendGrpcFrameCommand(transportState(), bytebuf, endOfStream), flush);

View File

@ -123,9 +123,8 @@ class NettyServerStream extends AbstractServerStream {
final int numBytes = bytebuf.readableBytes();
// Add the bytes to outbound flow control.
onSendingBytes(numBytes);
writeQueue.enqueue(
new SendGrpcFrameCommand(transportState(), bytebuf, false),
channel.newPromise().addListener(new ChannelFutureListener() {
writeQueue.enqueue(new SendGrpcFrameCommand(transportState(), bytebuf, false), flush)
.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
// Remove the bytes from outbound flow control, optionally notifying
@ -135,7 +134,7 @@ class NettyServerStream extends AbstractServerStream {
transportTracer.reportMessageSent(numMessages);
}
}
}), flush);
});
}
@Override

View File

@ -75,22 +75,10 @@ class WriteQueue {
*/
@CanIgnoreReturnValue
ChannelFuture enqueue(QueuedCommand command, boolean flush) {
return enqueue(command, channel.newPromise(), flush);
}
/**
* Enqueue a write command with a completion listener.
*
* @param command a write to be executed on the channel.
* @param promise to be marked on the completion of the write.
* @param flush true if a flush of the write should be schedule, false if a later call to
* enqueue will schedule the flush.
*/
@CanIgnoreReturnValue
ChannelFuture enqueue(QueuedCommand command, ChannelPromise promise, boolean flush) {
// Detect erroneous code that tries to reuse command objects.
Preconditions.checkArgument(command.promise() == null, "promise must not be set on command");
ChannelPromise promise = channel.newPromise();
command.promise(promise);
queue.add(command);
if (flush) {

View File

@ -30,7 +30,6 @@ import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.isA;
import static org.mockito.Matchers.same;
@ -58,6 +57,7 @@ import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.handler.codec.http2.DefaultHttp2Headers;
import io.netty.handler.codec.http2.Http2Headers;
import io.netty.util.AsciiString;
@ -180,7 +180,6 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
stream.flush();
verify(writeQueue).enqueue(
eq(new SendGrpcFrameCommand(stream.transportState(), messageFrame(MESSAGE), false)),
any(ChannelPromise.class),
eq(true));
}
@ -196,12 +195,10 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
verify(writeQueue).enqueue(
eq(new SendGrpcFrameCommand(
stream.transportState(), messageFrame(MESSAGE).slice(0, 5), false)),
any(ChannelPromise.class),
eq(false));
verify(writeQueue).enqueue(
eq(new SendGrpcFrameCommand(
stream.transportState(), messageFrame(MESSAGE).slice(5, 11), false)),
any(ChannelPromise.class),
eq(true));
}
@ -436,7 +433,10 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
metadata.put(GrpcUtil.USER_AGENT_KEY, "bad agent");
listener = mock(ClientStreamListener.class);
Mockito.reset(writeQueue);
when(writeQueue.enqueue(any(QueuedCommand.class), any(boolean.class))).thenReturn(future);
ChannelPromise completedPromise = new DefaultChannelPromise(channel)
.setSuccess();
when(writeQueue.enqueue(any(QueuedCommand.class), any(boolean.class)))
.thenReturn(completedPromise);
stream = new NettyClientStream(
new TransportStateImpl(handler, DEFAULT_MAX_MESSAGE_SIZE),
@ -485,8 +485,6 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
stream.writeMessage(new ByteArrayInputStream(msg));
stream.flush();
stream.halfClose();
verify(writeQueue, never()).enqueue(any(SendGrpcFrameCommand.class), any(ChannelPromise.class),
any(Boolean.class));
ArgumentCaptor<CreateStreamCommand> cmdCap = ArgumentCaptor.forClass(CreateStreamCommand.class);
verify(writeQueue).enqueue(cmdCap.capture(), eq(true));
ImmutableListMultimap<CharSequence, CharSequence> headers =
@ -501,16 +499,6 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
@Override
protected NettyClientStream createStream() {
when(handler.getWriteQueue()).thenReturn(writeQueue);
doAnswer(new Answer<Object>() {
@Override
public Object answer(InvocationOnMock invocation) throws Throwable {
if (future.isDone()) {
((ChannelPromise) invocation.getArguments()[1]).trySuccess();
}
return null;
}
}).when(writeQueue).enqueue(any(QueuedCommand.class), any(ChannelPromise.class), anyBoolean());
when(writeQueue.enqueue(any(QueuedCommand.class), anyBoolean())).thenReturn(future);
NettyClientStream stream = new NettyClientStream(
new TransportStateImpl(handler, DEFAULT_MAX_MESSAGE_SIZE),
methodDescriptor,

View File

@ -326,7 +326,7 @@ public abstract class NettyHandlerTestBase<T extends Http2ConnectionHandler> {
@CanIgnoreReturnValue
protected final ChannelFuture enqueue(WriteQueue.QueuedCommand command) {
ChannelFuture future = writeQueue.enqueue(command, newPromise(), true);
ChannelFuture future = writeQueue.enqueue(command, true);
channel.runPendingTasks();
return future;
}

View File

@ -21,9 +21,7 @@ import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static io.grpc.netty.NettyTestUtil.messageFrame;
import static org.junit.Assert.assertNull;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.isA;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doAnswer;
@ -43,10 +41,8 @@ import io.grpc.internal.ServerStreamListener;
import io.grpc.internal.StatsTraceContext;
import io.grpc.internal.StreamListener;
import io.grpc.internal.TransportTracer;
import io.grpc.netty.WriteQueue.QueuedCommand;
import io.netty.buffer.EmptyByteBuf;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http2.DefaultHttp2Headers;
import io.netty.util.AsciiString;
import java.io.ByteArrayInputStream;
@ -125,7 +121,6 @@ public class NettyServerStreamTest extends NettyStreamTestBase<NettyServerStream
verify(writeQueue).enqueue(
eq(new SendGrpcFrameCommand(stream.transportState(), messageFrame(MESSAGE), false)),
isA(ChannelPromise.class),
eq(true));
}
@ -288,16 +283,6 @@ public class NettyServerStreamTest extends NettyStreamTestBase<NettyServerStream
@Override
protected NettyServerStream createStream() {
when(handler.getWriteQueue()).thenReturn(writeQueue);
doAnswer(new Answer<Object>() {
@Override
public Object answer(InvocationOnMock invocation) throws Throwable {
if (future.isDone()) {
((ChannelPromise) invocation.getArguments()[1]).setSuccess();
}
return null;
}
}).when(writeQueue).enqueue(any(QueuedCommand.class), any(ChannelPromise.class), anyBoolean());
when(writeQueue.enqueue(any(QueuedCommand.class), anyBoolean())).thenReturn(future);
StatsTraceContext statsTraceCtx = StatsTraceContext.NOOP;
TransportTracer transportTracer = new TransportTracer();
NettyServerStream.TransportState state = new NettyServerStream.TransportState(

View File

@ -23,18 +23,17 @@ import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.grpc.internal.Stream;
import io.grpc.internal.StreamListener;
import io.grpc.netty.WriteQueue.QueuedCommand;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
@ -45,7 +44,6 @@ import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Queue;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
@ -69,11 +67,6 @@ public abstract class NettyStreamTestBase<T extends Stream> {
@Mock
private ChannelPipeline pipeline;
// ChannelFuture has too many methods to implement; we stubbed all necessary methods of Future.
@SuppressWarnings("DoNotMock")
@Mock
protected ChannelFuture future;
@Mock
protected EventLoop eventLoop;
@ -95,14 +88,16 @@ public abstract class NettyStreamTestBase<T extends Stream> {
public void setUp() {
MockitoAnnotations.initMocks(this);
mockFuture(true);
when(channel.write(any())).thenReturn(future);
when(channel.writeAndFlush(any())).thenReturn(future);
when(channel.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT);
when(channel.pipeline()).thenReturn(pipeline);
when(channel.eventLoop()).thenReturn(eventLoop);
when(channel.newPromise()).thenReturn(new DefaultChannelPromise(channel));
when(channel.voidPromise()).thenReturn(new DefaultChannelPromise(channel));
ChannelPromise completedPromise = new DefaultChannelPromise(channel)
.setSuccess();
when(channel.write(any())).thenReturn(completedPromise);
when(channel.writeAndFlush(any())).thenReturn(completedPromise);
when(writeQueue.enqueue(any(QueuedCommand.class), anyBoolean())).thenReturn(completedPromise);
when(pipeline.firstContext()).thenReturn(ctx);
when(eventLoop.inEventLoop()).thenReturn(true);
when(http2Stream.id()).thenReturn(STREAM_ID);
@ -155,7 +150,8 @@ public abstract class NettyStreamTestBase<T extends Stream> {
sendHeadersIfServer();
assertTrue(stream.isReady());
byte[] msg = largeMessage();
// The future is set up to automatically complete, indicating that the write is done.
// The channel.write future is set up to automatically complete, indicating that the write is
// done.
stream.writeMessage(new ByteArrayInputStream(msg));
stream.flush();
assertTrue(stream.isReady());
@ -166,7 +162,8 @@ public abstract class NettyStreamTestBase<T extends Stream> {
public void shouldBeReadyForDataAfterWritingSmallMessage() throws IOException {
sendHeadersIfServer();
// Make sure the writes don't complete so we "back up"
reset(future);
ChannelPromise uncompletedPromise = new DefaultChannelPromise(channel);
when(writeQueue.enqueue(any(QueuedCommand.class), anyBoolean())).thenReturn(uncompletedPromise);
assertTrue(stream.isReady());
byte[] msg = smallMessage();
@ -180,7 +177,8 @@ public abstract class NettyStreamTestBase<T extends Stream> {
public void shouldNotBeReadyForDataAfterWritingLargeMessage() throws IOException {
sendHeadersIfServer();
// Make sure the writes don't complete so we "back up"
reset(future);
ChannelPromise uncompletedPromise = new DefaultChannelPromise(channel);
when(writeQueue.enqueue(any(QueuedCommand.class), anyBoolean())).thenReturn(uncompletedPromise);
assertTrue(stream.isReady());
byte[] msg = largeMessage();
@ -213,14 +211,4 @@ public abstract class NettyStreamTestBase<T extends Stream> {
protected abstract Queue<InputStream> listenerMessageQueue();
protected abstract void closeStream();
private void mockFuture(boolean succeeded) {
when(future.isDone()).thenReturn(true);
when(future.isCancelled()).thenReturn(false);
when(future.isSuccess()).thenReturn(succeeded);
when(future.awaitUninterruptibly(anyLong(), any(TimeUnit.class))).thenReturn(true);
if (!succeeded) {
when(future.cause()).thenReturn(new Exception("fake"));
}
}
}