core: fix client does not detect truncated message

Resolves #3264
This commit is contained in:
ZHANG Dapeng 2018-05-21 13:51:15 -07:00 committed by GitHub
parent 10291d5ccc
commit 451c412354
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 68 additions and 11 deletions

View File

@ -17,6 +17,7 @@
package io.grpc.internal;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import static io.grpc.internal.GrpcUtil.CONTENT_ENCODING_KEY;
import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY;
import static io.grpc.internal.GrpcUtil.TIMEOUT_KEY;
@ -227,6 +228,8 @@ public abstract class AbstractClientStream extends AbstractStream
* #listenerClosed} because there may still be messages buffered to deliver to the application.
*/
private boolean statusReported;
private Metadata trailers;
private Status trailerStatus;
protected TransportState(
int maxMessageSize,
@ -241,20 +244,31 @@ public abstract class AbstractClientStream extends AbstractStream
}
private void setDecompressorRegistry(DecompressorRegistry decompressorRegistry) {
Preconditions.checkState(this.listener == null, "Already called start");
checkState(this.listener == null, "Already called start");
this.decompressorRegistry =
checkNotNull(decompressorRegistry, "decompressorRegistry");
}
@VisibleForTesting
public final void setListener(ClientStreamListener listener) {
Preconditions.checkState(this.listener == null, "Already called setListener");
checkState(this.listener == null, "Already called setListener");
this.listener = checkNotNull(listener, "listener");
}
@Override
public void deframerClosed(boolean hasPartialMessageIgnored) {
public void deframerClosed(boolean hasPartialMessage) {
deframerClosed = true;
if (trailerStatus != null) {
if (trailerStatus.isOk() && hasPartialMessage) {
trailerStatus = Status.INTERNAL.withDescription("Encountered end-of-stream mid-frame");
trailers = new Metadata();
}
transportReportStatus(trailerStatus, false, trailers);
} else {
checkState(statusReported, "status should have been reported on deframer closed");
}
if (deframerClosedTask != null) {
deframerClosedTask.run();
deframerClosedTask = null;
@ -280,7 +294,7 @@ public abstract class AbstractClientStream extends AbstractStream
* @param headers the parsed headers
*/
protected void inboundHeadersReceived(Metadata headers) {
Preconditions.checkState(!statusReported, "Received headers on closed stream");
checkState(!statusReported, "Received headers on closed stream");
statsTraceCtx.clientInboundHeaders();
boolean compressedStream = false;
@ -361,7 +375,9 @@ public abstract class AbstractClientStream extends AbstractStream
new Object[]{status, trailers});
return;
}
transportReportStatus(status, false, trailers);
this.trailers = trailers;
trailerStatus = status;
closeDeframer(false);
}
/**
@ -454,7 +470,7 @@ public abstract class AbstractClientStream extends AbstractStream
@Override
public void writePayload(InputStream message) {
Preconditions.checkState(payload == null, "writePayload should not be called multiple times");
checkState(payload == null, "writePayload should not be called multiple times");
try {
payload = IoUtils.toByteArray(message);
} catch (java.io.IOException ex) {
@ -487,7 +503,7 @@ public abstract class AbstractClientStream extends AbstractStream
@Override
public void close() {
closed = true;
Preconditions.checkState(payload != null,
checkState(payload != null,
"Lack of request message. GET request is only supported for unary requests");
abstractClientStreamSink().writeHeaders(headers, payload);
payload = null;

View File

@ -21,6 +21,7 @@ import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.AdditionalAnswers.delegatesTo;
@ -40,6 +41,7 @@ import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.StreamTracer;
import io.grpc.internal.AbstractClientStream.TransportState;
import io.grpc.internal.ClientStreamListener.RpcProgress;
import io.grpc.internal.MessageFramerTest.ByteWritableBuffer;
import io.grpc.internal.testing.TestClientStreamTracer;
import java.io.ByteArrayInputStream;
@ -324,11 +326,50 @@ public class AbstractClientStreamTest {
stream.transportState().requestMessagesFromDeframer(1);
// Send first byte of 2 byte message
stream.transportState().deframe(ReadableBuffers.wrap(new byte[] {0, 0, 0, 0, 2, 1}));
Status status = Status.INTERNAL;
Status status = Status.INTERNAL.withDescription("rst___stream");
// Simulate getting a reset
stream.transportState().transportReportStatus(status, false /*stop delivery*/, new Metadata());
verify(mockListener).closed(any(Status.class), same(PROCESSED), any(Metadata.class));
ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(Status.class);
verify(mockListener)
.closed(statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class));
assertSame(Status.Code.INTERNAL, statusCaptor.getValue().getCode());
assertEquals("rst___stream", statusCaptor.getValue().getDescription());
}
@Test
public void trailerOkWithTruncatedMessage() {
AbstractClientStream stream =
new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer);
stream.start(mockListener);
stream.transportState().requestMessagesFromDeframer(1);
stream.transportState().deframe(ReadableBuffers.wrap(new byte[] {0, 0, 0, 0, 2, 1}));
stream.transportState().inboundTrailersReceived(new Metadata(), Status.OK);
ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(Status.class);
verify(mockListener)
.closed(statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class));
assertSame(Status.Code.INTERNAL, statusCaptor.getValue().getCode());
assertEquals("Encountered end-of-stream mid-frame", statusCaptor.getValue().getDescription());
}
@Test
public void trailerNotOkWithTruncatedMessage() {
AbstractClientStream stream =
new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer);
stream.start(mockListener);
stream.transportState().requestMessagesFromDeframer(1);
stream.transportState().deframe(ReadableBuffers.wrap(new byte[] {0, 0, 0, 0, 2, 1}));
stream.transportState().inboundTrailersReceived(
new Metadata(), Status.DATA_LOSS.withDescription("data___loss"));
ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(Status.class);
verify(mockListener)
.closed(statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class));
assertSame(Status.Code.DATA_LOSS, statusCaptor.getValue().getCode());
assertEquals("data___loss", statusCaptor.getValue().getDescription());
}
@Test

View File

@ -277,9 +277,9 @@ class OkHttpClientStream extends AbstractClientStream {
@Override
@GuardedBy("lock")
public void deframerClosed(boolean hasPartialMessageIgnored) {
public void deframerClosed(boolean hasPartialMessage) {
onEndOfStream();
super.deframerClosed(hasPartialMessageIgnored);
super.deframerClosed(hasPartialMessage);
}
@Override