diff --git a/core/src/main/java/io/grpc/internal/AbstractClientStream.java b/core/src/main/java/io/grpc/internal/AbstractClientStream.java index 14279dcf23..c100998a13 100644 --- a/core/src/main/java/io/grpc/internal/AbstractClientStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractClientStream.java @@ -17,6 +17,7 @@ package io.grpc.internal; import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; import static io.grpc.internal.GrpcUtil.CONTENT_ENCODING_KEY; import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY; import static io.grpc.internal.GrpcUtil.TIMEOUT_KEY; @@ -227,6 +228,8 @@ public abstract class AbstractClientStream extends AbstractStream * #listenerClosed} because there may still be messages buffered to deliver to the application. */ private boolean statusReported; + private Metadata trailers; + private Status trailerStatus; protected TransportState( int maxMessageSize, @@ -241,20 +244,31 @@ public abstract class AbstractClientStream extends AbstractStream } private void setDecompressorRegistry(DecompressorRegistry decompressorRegistry) { - Preconditions.checkState(this.listener == null, "Already called start"); + checkState(this.listener == null, "Already called start"); this.decompressorRegistry = checkNotNull(decompressorRegistry, "decompressorRegistry"); } @VisibleForTesting public final void setListener(ClientStreamListener listener) { - Preconditions.checkState(this.listener == null, "Already called setListener"); + checkState(this.listener == null, "Already called setListener"); this.listener = checkNotNull(listener, "listener"); } @Override - public void deframerClosed(boolean hasPartialMessageIgnored) { + public void deframerClosed(boolean hasPartialMessage) { deframerClosed = true; + + if (trailerStatus != null) { + if (trailerStatus.isOk() && hasPartialMessage) { + trailerStatus = Status.INTERNAL.withDescription("Encountered end-of-stream mid-frame"); + trailers = new Metadata(); + } + transportReportStatus(trailerStatus, false, trailers); + } else { + checkState(statusReported, "status should have been reported on deframer closed"); + } + if (deframerClosedTask != null) { deframerClosedTask.run(); deframerClosedTask = null; @@ -280,7 +294,7 @@ public abstract class AbstractClientStream extends AbstractStream * @param headers the parsed headers */ protected void inboundHeadersReceived(Metadata headers) { - Preconditions.checkState(!statusReported, "Received headers on closed stream"); + checkState(!statusReported, "Received headers on closed stream"); statsTraceCtx.clientInboundHeaders(); boolean compressedStream = false; @@ -361,7 +375,9 @@ public abstract class AbstractClientStream extends AbstractStream new Object[]{status, trailers}); return; } - transportReportStatus(status, false, trailers); + this.trailers = trailers; + trailerStatus = status; + closeDeframer(false); } /** @@ -454,7 +470,7 @@ public abstract class AbstractClientStream extends AbstractStream @Override public void writePayload(InputStream message) { - Preconditions.checkState(payload == null, "writePayload should not be called multiple times"); + checkState(payload == null, "writePayload should not be called multiple times"); try { payload = IoUtils.toByteArray(message); } catch (java.io.IOException ex) { @@ -487,7 +503,7 @@ public abstract class AbstractClientStream extends AbstractStream @Override public void close() { closed = true; - Preconditions.checkState(payload != null, + checkState(payload != null, "Lack of request message. GET request is only supported for unary requests"); abstractClientStreamSink().writeHeaders(headers, payload); payload = null; diff --git a/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java b/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java index aad3ddf9ad..0d540a95f7 100644 --- a/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java @@ -21,6 +21,7 @@ import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED; import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.AdditionalAnswers.delegatesTo; @@ -40,6 +41,7 @@ import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.StreamTracer; import io.grpc.internal.AbstractClientStream.TransportState; +import io.grpc.internal.ClientStreamListener.RpcProgress; import io.grpc.internal.MessageFramerTest.ByteWritableBuffer; import io.grpc.internal.testing.TestClientStreamTracer; import java.io.ByteArrayInputStream; @@ -324,11 +326,50 @@ public class AbstractClientStreamTest { stream.transportState().requestMessagesFromDeframer(1); // Send first byte of 2 byte message stream.transportState().deframe(ReadableBuffers.wrap(new byte[] {0, 0, 0, 0, 2, 1})); - Status status = Status.INTERNAL; + Status status = Status.INTERNAL.withDescription("rst___stream"); // Simulate getting a reset stream.transportState().transportReportStatus(status, false /*stop delivery*/, new Metadata()); - verify(mockListener).closed(any(Status.class), same(PROCESSED), any(Metadata.class)); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(mockListener) + .closed(statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class)); + assertSame(Status.Code.INTERNAL, statusCaptor.getValue().getCode()); + assertEquals("rst___stream", statusCaptor.getValue().getDescription()); + } + + @Test + public void trailerOkWithTruncatedMessage() { + AbstractClientStream stream = + new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer); + stream.start(mockListener); + + stream.transportState().requestMessagesFromDeframer(1); + stream.transportState().deframe(ReadableBuffers.wrap(new byte[] {0, 0, 0, 0, 2, 1})); + stream.transportState().inboundTrailersReceived(new Metadata(), Status.OK); + + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(mockListener) + .closed(statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class)); + assertSame(Status.Code.INTERNAL, statusCaptor.getValue().getCode()); + assertEquals("Encountered end-of-stream mid-frame", statusCaptor.getValue().getDescription()); + } + + @Test + public void trailerNotOkWithTruncatedMessage() { + AbstractClientStream stream = + new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer); + stream.start(mockListener); + + stream.transportState().requestMessagesFromDeframer(1); + stream.transportState().deframe(ReadableBuffers.wrap(new byte[] {0, 0, 0, 0, 2, 1})); + stream.transportState().inboundTrailersReceived( + new Metadata(), Status.DATA_LOSS.withDescription("data___loss")); + + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(mockListener) + .closed(statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class)); + assertSame(Status.Code.DATA_LOSS, statusCaptor.getValue().getCode()); + assertEquals("data___loss", statusCaptor.getValue().getDescription()); } @Test diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java index 85be5aff1a..c1585b43b1 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java @@ -277,9 +277,9 @@ class OkHttpClientStream extends AbstractClientStream { @Override @GuardedBy("lock") - public void deframerClosed(boolean hasPartialMessageIgnored) { + public void deframerClosed(boolean hasPartialMessage) { onEndOfStream(); - super.deframerClosed(hasPartialMessageIgnored); + super.deframerClosed(hasPartialMessage); } @Override