core,netty: support GET verb in AbstractClientStream2

This commit is contained in:
Eric Anderson 2017-02-17 15:53:53 -08:00 committed by Xiao Hang
parent d4c9d5f087
commit 4096d4b668
11 changed files with 307 additions and 52 deletions

View File

@ -89,7 +89,8 @@ public class OutboundHeadersBenchmark {
@BenchmarkMode(Mode.SampleTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
public Http2Headers convertClientHeaders() {
return Utils.convertClientHeaders(metadata, scheme, defaultPath, authority, userAgent);
return Utils.convertClientHeaders(metadata, scheme, defaultPath, authority, Utils.HTTP_METHOD,
userAgent);
}
@Benchmark
@ -108,7 +109,8 @@ public class OutboundHeadersBenchmark {
public ByteBuf encodeClientHeaders() throws Exception {
scratchBuffer.clear();
Http2Headers headers =
Utils.convertClientHeaders(metadata, scheme, defaultPath, authority, userAgent);
Utils.convertClientHeaders(metadata, scheme, defaultPath, authority, Utils.HTTP_METHOD,
userAgent);
headersEncoder.encodeHeaders(1, headers, scratchBuffer);
return scratchBuffer;
}

View File

@ -33,8 +33,10 @@ package io.grpc.internal;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.grpc.Compressor;
import io.grpc.Metadata;
import io.grpc.Status;
import java.io.InputStream;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;
@ -54,6 +56,15 @@ public abstract class AbstractClientStream2 extends AbstractStream2
* collisions/confusion. Only called from application thread.
*/
protected interface Sink {
/**
* Sends the request headers to the remote end point.
*
* @param metadata the metadata to be sent
* @param payload the payload needs to be sent in the headers if not null. Should only be used
* when sending an unary GET request
*/
void writeHeaders(Metadata metadata, @Nullable byte[] payload);
/**
* Sends an outbound frame to the remote end point.
*
@ -82,7 +93,9 @@ public abstract class AbstractClientStream2 extends AbstractStream2
void cancel(Status status);
}
private final MessageFramer framer;
private final Framer framer;
private boolean useGet;
private Metadata headers;
private boolean outboundClosed;
/**
* Whether cancel() has been called. This is not strictly necessary, but removes the delay between
@ -92,8 +105,15 @@ public abstract class AbstractClientStream2 extends AbstractStream2
private volatile boolean cancelled;
protected AbstractClientStream2(WritableBufferAllocator bufferAllocator,
StatsTraceContext statsTraceCtx) {
framer = new MessageFramer(this, bufferAllocator, statsTraceCtx);
StatsTraceContext statsTraceCtx, Metadata headers, boolean useGet) {
Preconditions.checkNotNull(headers, "headers");
this.useGet = useGet;
if (!useGet) {
framer = new MessageFramer(this, bufferAllocator, statsTraceCtx);
this.headers = headers;
} else {
framer = new GetFramer(headers, statsTraceCtx);
}
}
@Override
@ -111,8 +131,12 @@ public abstract class AbstractClientStream2 extends AbstractStream2
protected abstract TransportState transportState();
@Override
public void start(ClientStreamListener listener) {
public final void start(ClientStreamListener listener) {
transportState().setListener(listener);
if (!useGet) {
abstractClientStreamSink().writeHeaders(headers, null);
headers = null;
}
}
/**
@ -122,7 +146,7 @@ public abstract class AbstractClientStream2 extends AbstractStream2
protected abstract Sink abstractClientStreamSink();
@Override
protected final MessageFramer framer() {
protected final Framer framer() {
return framer;
}
@ -308,4 +332,71 @@ public abstract class AbstractClientStream2 extends AbstractStream2
}
}
}
public static final String GRPC_PAYLOAD_BIN_KEY = "grpc-payload-bin";
private class GetFramer implements Framer {
private Metadata headers;
private boolean closed;
private final StatsTraceContext statsTraceCtx;
private byte[] payload;
public GetFramer(Metadata headers, StatsTraceContext statsTraceCtx) {
this.headers = Preconditions.checkNotNull(headers, "headers");
this.statsTraceCtx = Preconditions.checkNotNull(statsTraceCtx, "statsTraceCtx");
}
@Override
public void writePayload(InputStream message) {
Preconditions.checkState(payload == null, "writePayload should not be called multiple times");
try {
payload = IoUtils.toByteArray(message);
} catch (java.io.IOException ex) {
throw new RuntimeException(ex);
}
}
@Override
public void flush() {}
@Override
public boolean isClosed() {
return closed;
}
/** Closes, with flush. */
@Override
public void close() {
closed = true;
Preconditions.checkState(payload != null,
"Lack of request message. GET request is only supported for unary requests");
abstractClientStreamSink().writeHeaders(headers, payload);
statsTraceCtx.wireBytesSent(payload.length);
payload = null;
headers = null;
}
/** Closes, without flush. */
@Override
public void dispose() {
closed = true;
payload = null;
headers = null;
}
// Compression is not supported for GET encoding.
@Override
public Framer setMessageCompression(boolean enable) {
return this;
}
@Override
public Framer setCompressor(Compressor compressor) {
return this;
}
// TODO(zsurocking): support this
@Override
public void setMaxOutboundMessageSize(int maxSize) {}
}
}

View File

@ -46,7 +46,7 @@ import javax.annotation.concurrent.GuardedBy;
*/
public abstract class AbstractStream2 implements Stream {
/** The framer to use for sending messages. */
protected abstract MessageFramer framer();
protected abstract Framer framer();
/**
* Obtain the transport state corresponding to this stream. Each stream must have its own unique

View File

@ -0,0 +1,66 @@
/*
* Copyright 2017, Google Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
*
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package io.grpc.internal;
import io.grpc.Compressor;
import java.io.InputStream;
/** Interface for framing gRPC messages. */
interface Framer {
/**
* Writes out a payload message.
*
* @param message contains the message to be written out. It will be completely consumed.
*/
void writePayload(InputStream message);
/** Flush the buffered payload. */
void flush();
/** Returns whether the framer is closed. */
boolean isClosed();
/** Closes, with flush. */
void close();
/** Closes, without flush. */
void dispose();
/** Enable or disable compression. */
Framer setMessageCompression(boolean enable);
/** Set the compressor used for compression. */
Framer setCompressor(Compressor compressor);
/** Set a size limit for each outbound message. */
void setMaxOutboundMessageSize(int maxSize);
}

View File

@ -55,7 +55,7 @@ import javax.annotation.Nullable;
* Encodes gRPC messages to be delivered via the transport layer which implements {@link
* MessageFramer.Sink}.
*/
public class MessageFramer {
public class MessageFramer implements Framer {
private static final int NO_MAX_OUTBOUND_MESSAGE_SIZE = -1;
@ -104,17 +104,20 @@ public class MessageFramer {
this.statsTraceCtx = checkNotNull(statsTraceCtx, "statsTraceCtx");
}
MessageFramer setCompressor(Compressor compressor) {
@Override
public MessageFramer setCompressor(Compressor compressor) {
this.compressor = checkNotNull(compressor, "Can't pass an empty compressor");
return this;
}
MessageFramer setMessageCompression(boolean enable) {
@Override
public MessageFramer setMessageCompression(boolean enable) {
messageCompression = enable;
return this;
}
void setMaxOutboundMessageSize(int maxSize) {
@Override
public void setMaxOutboundMessageSize(int maxSize) {
checkState(maxOutboundMessageSize == NO_MAX_OUTBOUND_MESSAGE_SIZE, "max size already set");
maxOutboundMessageSize = maxSize;
}
@ -124,6 +127,7 @@ public class MessageFramer {
*
* @param message contains the message to be written out. It will be completely consumed.
*/
@Override
public void writePayload(InputStream message) {
verifyNotClosed();
boolean compressed = messageCompression && compressor != Codec.Identity.NONE;
@ -286,6 +290,7 @@ public class MessageFramer {
/**
* Flushes any buffered data in the framer to the sink.
*/
@Override
public void flush() {
if (buffer != null && buffer.readableBytes() > 0) {
commitToSink(false, true);
@ -296,6 +301,7 @@ public class MessageFramer {
* Indicates whether or not this framer has been closed via a call to either
* {@link #close()} or {@link #dispose()}.
*/
@Override
public boolean isClosed() {
return closed;
}
@ -304,6 +310,7 @@ public class MessageFramer {
* Flushes and closes the framer and releases any buffers. After the framer is closed or
* disposed, additional calls to this method will have no affect.
*/
@Override
public void close() {
if (!isClosed()) {
closed = true;
@ -320,6 +327,7 @@ public class MessageFramer {
* Closes the framer and releases any buffers, but does not flush. After the framer is
* closed or disposed, additional calls to this method will have no affect.
*/
@Override
public void dispose() {
closed = true;
releaseBuffer();

View File

@ -33,8 +33,11 @@ package io.grpc.internal;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import io.grpc.Attributes;
@ -44,6 +47,7 @@ import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.internal.AbstractClientStream2.TransportState;
import io.grpc.internal.MessageFramerTest.ByteWritableBuffer;
import java.io.ByteArrayInputStream;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
@ -227,6 +231,25 @@ public class AbstractClientStream2Test {
verify(mockListener).closed(any(Status.class), any(Metadata.class));
}
@Test
public void getRequest() {
AbstractClientStream2.Sink sink = mock(AbstractClientStream2.Sink.class);
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator,
new BaseTransportState(statsTraceCtx), sink, statsTraceCtx, true);
stream.start(mockListener);
stream.writeMessage(new ByteArrayInputStream(new byte[1]));
// writeHeaders will be delayed since we're sending a GET request.
verify(sink, never()).writeHeaders(any(Metadata.class), any(byte[].class));
// halfClose will trigger writeHeaders.
stream.halfClose();
ArgumentCaptor<byte[]> payloadCaptor = ArgumentCaptor.forClass(byte[].class);
verify(sink).writeHeaders(any(Metadata.class), payloadCaptor.capture());
assertTrue(payloadCaptor.getValue() != null);
// GET requests don't have BODY.
verify(sink, never())
.writeFrame(any(WritableBuffer.class), any(Boolean.class), any(Boolean.class));
}
/**
* No-op base class for testing.
@ -242,7 +265,12 @@ public class AbstractClientStream2Test {
public BaseAbstractClientStream(WritableBufferAllocator allocator, TransportState state,
Sink sink, StatsTraceContext statsTraceCtx) {
super(allocator, statsTraceCtx);
this(allocator, state, sink, statsTraceCtx, false);
}
public BaseAbstractClientStream(WritableBufferAllocator allocator, TransportState state,
Sink sink, StatsTraceContext statsTraceCtx, boolean useGet) {
super(allocator, statsTraceCtx, new Metadata(), useGet);
this.state = state;
this.sink = sink;
}
@ -273,6 +301,9 @@ public class AbstractClientStream2Test {
}
private static class BaseSink implements AbstractClientStream2.Sink {
@Override
public void writeHeaders(Metadata headers, byte[] payload) {}
@Override
public void request(int numMessages) {}

View File

@ -41,11 +41,18 @@ import io.netty.handler.codec.http2.Http2Headers;
class CreateStreamCommand extends WriteQueue.AbstractQueuedCommand {
private final Http2Headers headers;
private final NettyClientStream.TransportState stream;
private final boolean get;
CreateStreamCommand(Http2Headers headers,
NettyClientStream.TransportState stream) {
this(headers, stream, false);
}
CreateStreamCommand(Http2Headers headers,
NettyClientStream.TransportState stream, boolean get) {
this.stream = Preconditions.checkNotNull(stream, "stream");
this.headers = Preconditions.checkNotNull(headers, "headers");
this.get = get;
}
NettyClientStream.TransportState stream() {
@ -55,4 +62,8 @@ class CreateStreamCommand extends WriteQueue.AbstractQueuedCommand {
Http2Headers headers() {
return headers;
}
boolean isGet() {
return get;
}
}

View File

@ -413,7 +413,7 @@ class NettyClientHandler extends AbstractNettyHandler {
// Create an intermediate promise so that we can intercept the failure reported back to the
// application.
ChannelPromise tempPromise = ctx().newPromise();
encoder().writeHeaders(ctx(), streamId, headers, 0, false, tempPromise)
encoder().writeHeaders(ctx(), streamId, headers, 0, command.isGet(), tempPromise)
.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {

View File

@ -36,6 +36,7 @@ import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import static io.netty.buffer.Unpooled.EMPTY_BUFFER;
import com.google.common.io.BaseEncoding;
import io.grpc.Attributes;
import io.grpc.InternalKnownTransport;
import io.grpc.InternalMethodDescriptor;
@ -43,7 +44,6 @@ import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.internal.AbstractClientStream2;
import io.grpc.internal.ClientStreamListener;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.Http2ClientStreamTransportState;
import io.grpc.internal.StatsTraceContext;
@ -69,8 +69,6 @@ class NettyClientStream extends AbstractClientStream2 {
private final TransportState state;
private final WriteQueue writeQueue;
private final MethodDescriptor<?, ?> method;
/** {@code null} after start. */
private Metadata headers;
private final Channel channel;
private AsciiString authority;
private final AsciiString scheme;
@ -80,11 +78,13 @@ class NettyClientStream extends AbstractClientStream2 {
TransportState state, MethodDescriptor<?, ?> method, Metadata headers,
Channel channel, AsciiString authority, AsciiString scheme, AsciiString userAgent,
StatsTraceContext statsTraceCtx) {
super(new NettyWritableBufferAllocator(channel.alloc()), statsTraceCtx);
super(new NettyWritableBufferAllocator(channel.alloc()),
statsTraceCtx,
headers,
useGet(method));
this.state = checkNotNull(state, "transportState");
this.writeQueue = state.handler.getWriteQueue();
this.method = checkNotNull(method, "method");
this.headers = checkNotNull(headers, "headers");
this.channel = checkNotNull(channel, "channel");
this.authority = checkNotNull(authority, "authority");
this.scheme = checkNotNull(scheme, "scheme");
@ -103,47 +103,57 @@ class NettyClientStream extends AbstractClientStream2 {
@Override
public void setAuthority(String authority) {
checkState(headers != null, "must be call before start");
this.authority = AsciiString.of(checkNotNull(authority, "authority"));
}
@Override
public void start(ClientStreamListener listener) {
super.start(listener);
// Convert the headers into Netty HTTP/2 headers.
AsciiString defaultPath = (AsciiString) methodDescriptorAccessor.geRawMethodName(method);
if (defaultPath == null) {
defaultPath = new AsciiString("/" + method.getFullMethodName());
methodDescriptorAccessor.setRawMethodName(method, defaultPath);
}
headers.discardAll(GrpcUtil.USER_AGENT_KEY);
Http2Headers http2Headers
= Utils.convertClientHeaders(headers, scheme, defaultPath, authority, userAgent);
headers = null;
ChannelFutureListener failureListener = new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) {
// Stream creation failed. Close the stream if not already closed.
Status s = transportState().statusFromFailedFuture(future);
transportState().transportReportStatus(s, true, new Metadata());
}
}
};
// Write the command requesting the creation of the stream.
writeQueue.enqueue(new CreateStreamCommand(http2Headers, transportState()),
!method.getType().clientSendsOneMessage()).addListener(failureListener);
}
@Override
public Attributes getAttributes() {
return state.handler.getAttributes();
}
private static boolean useGet(MethodDescriptor<?, ?> method) {
return method.isSafe();
}
private class Sink implements AbstractClientStream2.Sink {
@Override
public void writeHeaders(Metadata headers, byte[] requestPayload) {
// Convert the headers into Netty HTTP/2 headers.
AsciiString defaultPath = (AsciiString) methodDescriptorAccessor.geRawMethodName(method);
if (defaultPath == null) {
defaultPath = new AsciiString("/" + method.getFullMethodName());
methodDescriptorAccessor.setRawMethodName(method, defaultPath);
}
boolean get = (requestPayload != null);
AsciiString httpMethod;
if (get) {
// Forge the query string
defaultPath = new AsciiString(defaultPath + "?" + GRPC_PAYLOAD_BIN_KEY + "="
+ BaseEncoding.base64().encode(requestPayload));
httpMethod = Utils.HTTP_GET_METHOD;
} else {
httpMethod = Utils.HTTP_METHOD;
}
headers.discardAll(GrpcUtil.USER_AGENT_KEY);
Http2Headers http2Headers = Utils.convertClientHeaders(headers, scheme, defaultPath,
authority, httpMethod, userAgent);
ChannelFutureListener failureListener = new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) {
// Stream creation failed. Close the stream if not already closed.
Status s = transportState().statusFromFailedFuture(future);
transportState().transportReportStatus(s, true, new Metadata());
}
}
};
// Write the command requesting the creation of the stream.
writeQueue.enqueue(new CreateStreamCommand(http2Headers, transportState(), get),
!method.getType().clientSendsOneMessage() || get).addListener(failureListener);
}
@Override
public void writeFrame(WritableBuffer frame, boolean endOfStream, boolean flush) {
ByteBuf bytebuf = frame == null ? EMPTY_BUFFER : ((NettyWritableBuffer) frame).bytebuf();

View File

@ -65,6 +65,7 @@ class Utils {
public static final AsciiString STATUS_OK = AsciiString.of("200");
public static final AsciiString HTTP_METHOD = AsciiString.of(GrpcUtil.HTTP_METHOD);
public static final AsciiString HTTP_GET_METHOD = AsciiString.of("GET");
public static final AsciiString HTTPS = AsciiString.of("https");
public static final AsciiString HTTP = AsciiString.of("http");
public static final AsciiString CONTENT_TYPE_HEADER = AsciiString.of(CONTENT_TYPE_KEY.name());
@ -116,15 +117,17 @@ class Utils {
AsciiString scheme,
AsciiString defaultPath,
AsciiString authority,
AsciiString method,
AsciiString userAgent) {
Preconditions.checkNotNull(defaultPath, "defaultPath");
Preconditions.checkNotNull(authority, "authority");
Preconditions.checkNotNull(method, "method");
return GrpcHttp2OutboundHeaders.clientRequestHeaders(
toHttp2Headers(headers),
authority,
defaultPath,
HTTP_METHOD,
method,
scheme,
userAgent);
}

View File

@ -55,6 +55,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.io.BaseEncoding;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
@ -392,6 +393,38 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
.containsEntry(Utils.USER_AGENT, AsciiString.of("good agent"));
}
@Test
public void getRequestSentThroughHeader() {
// Creating a GET method
MethodDescriptor<?, ?> descriptor = MethodDescriptor.<Void, Void>newBuilder()
.setType(MethodDescriptor.MethodType.UNARY)
.setFullMethodName("/testService/test")
.setRequestMarshaller(marshaller)
.setResponseMarshaller(marshaller)
.setIdempotent(true)
.setSafe(true)
.build();
NettyClientStream stream = new NettyClientStream(
new TransportStateImpl(handler, DEFAULT_MAX_MESSAGE_SIZE), descriptor, new Metadata(),
channel, AsciiString.of("localhost"), AsciiString.of("http"), AsciiString.of("agent"),
StatsTraceContext.NOOP);
stream.start(listener);
stream.transportState().setId(STREAM_ID);
stream.transportState().setHttp2Stream(http2Stream);
byte[] msg = smallMessage();
stream.writeMessage(new ByteArrayInputStream(msg));
stream.flush();
stream.halfClose();
verify(writeQueue, never()).enqueue(any(SendGrpcFrameCommand.class), any(ChannelPromise.class),
any(Boolean.class));
ArgumentCaptor<CreateStreamCommand> cmdCap = ArgumentCaptor.forClass(CreateStreamCommand.class);
verify(writeQueue).enqueue(cmdCap.capture(), eq(true));
assertThat(ImmutableListMultimap.copyOf(cmdCap.getValue().headers()))
.containsEntry(AsciiString.of(":path"), AsciiString.of(
"//testService/test?grpc-payload-bin=" + BaseEncoding.base64().encode(msg)));
}
@Override
protected NettyClientStream createStream() {
when(handler.getWriteQueue()).thenReturn(writeQueue);