Adding maxMessageSize config option

Fixes #832
This commit is contained in:
nmittler 2015-08-21 16:20:01 -07:00
parent 40c66a17a9
commit 15f02ba19c
26 changed files with 295 additions and 105 deletions

View File

@ -32,19 +32,15 @@
package io.grpc.internal; package io.grpc.internal;
import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkArgument;
import static io.grpc.Status.Code.CANCELLED; import static io.grpc.internal.GrpcUtil.CANCEL_REASONS;
import static io.grpc.Status.Code.DEADLINE_EXCEEDED;
import com.google.common.base.Objects; import com.google.common.base.Objects;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.Status.Code;
import java.io.InputStream; import java.io.InputStream;
import java.util.EnumSet;
import java.util.Set;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
@ -55,8 +51,6 @@ public abstract class AbstractClientStream<IdT> extends AbstractStream<IdT>
implements ClientStream { implements ClientStream {
private static final Logger log = Logger.getLogger(AbstractClientStream.class.getName()); private static final Logger log = Logger.getLogger(AbstractClientStream.class.getName());
private static final Set<Code> CANCEL_REASONS =
EnumSet.of(CANCELLED, DEADLINE_EXCEEDED, Code.INTERNAL, Code.UNKNOWN);
private final ClientStreamListener listener; private final ClientStreamListener listener;
private boolean listenerClosed; private boolean listenerClosed;
@ -67,15 +61,10 @@ public abstract class AbstractClientStream<IdT> extends AbstractStream<IdT>
private Metadata trailers; private Metadata trailers;
private Runnable closeListenerTask; private Runnable closeListenerTask;
/**
* Constructor used by subclasses.
*
* @param listener the listener to receive notifications
*/
protected AbstractClientStream(WritableBufferAllocator bufferAllocator, protected AbstractClientStream(WritableBufferAllocator bufferAllocator,
ClientStreamListener listener) { ClientStreamListener listener,
super(bufferAllocator); int maxMessageSize) {
super(bufferAllocator, maxMessageSize);
this.listener = Preconditions.checkNotNull(listener); this.listener = Preconditions.checkNotNull(listener);
} }

View File

@ -63,8 +63,9 @@ public abstract class AbstractServerStream<IdT> extends AbstractStream<IdT>
/** Saved trailers from close() that need to be sent once the framer has sent all messages. */ /** Saved trailers from close() that need to be sent once the framer has sent all messages. */
private Metadata stashedTrailers; private Metadata stashedTrailers;
protected AbstractServerStream(WritableBufferAllocator bufferAllocator) { protected AbstractServerStream(WritableBufferAllocator bufferAllocator,
super(bufferAllocator); int maxMessageSize) {
super(bufferAllocator, maxMessageSize);
} }
/** /**

View File

@ -100,7 +100,7 @@ public abstract class AbstractStream<IdT> implements Stream {
private final Object onReadyLock = new Object(); private final Object onReadyLock = new Object();
AbstractStream(WritableBufferAllocator bufferAllocator) { AbstractStream(WritableBufferAllocator bufferAllocator, int maxMessageSize) {
MessageDeframer.Listener inboundMessageHandler = new MessageDeframer.Listener() { MessageDeframer.Listener inboundMessageHandler = new MessageDeframer.Listener() {
@Override @Override
public void bytesRead(int numBytes) { public void bytesRead(int numBytes) {
@ -130,7 +130,7 @@ public abstract class AbstractStream<IdT> implements Stream {
}; };
framer = new MessageFramer(outboundFrameHandler, bufferAllocator); framer = new MessageFramer(outboundFrameHandler, bufferAllocator);
deframer = new MessageDeframer(inboundMessageHandler); deframer = new MessageDeframer(inboundMessageHandler, MessageEncoding.NONE, maxMessageSize);
} }
@Override @Override

View File

@ -31,6 +31,9 @@
package io.grpc.internal; package io.grpc.internal;
import static io.grpc.Status.Code.CANCELLED;
import static io.grpc.Status.Code.DEADLINE_EXCEEDED;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
@ -38,6 +41,8 @@ import io.grpc.Metadata;
import io.grpc.Status; import io.grpc.Status;
import java.net.HttpURLConnection; import java.net.HttpURLConnection;
import java.util.EnumSet;
import java.util.Set;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadFactory;
@ -109,6 +114,17 @@ public final class GrpcUtil {
*/ */
public static final String MESSAGE_ENCODING = "grpc-encoding"; public static final String MESSAGE_ENCODING = "grpc-encoding";
/**
* The default maximum uncompressed size (in bytes) for inbound messages. Defaults to 100 MiB.
*/
public static final int DEFAULT_MAX_MESSAGE_SIZE = 100 * 1024 * 1024;
/**
* The set of valid status codes for client cancellation.
*/
public static final Set<Status.Code> CANCEL_REASONS =
EnumSet.of(CANCELLED, DEADLINE_EXCEEDED, Status.Code.INTERNAL, Status.Code.UNKNOWN);
/** /**
* Maps HTTP error response status codes to transport codes. * Maps HTTP error response status codes to transport codes.
*/ */

View File

@ -71,8 +71,9 @@ public abstract class Http2ClientStream extends AbstractClientStream<Integer> {
private boolean contentTypeChecked; private boolean contentTypeChecked;
protected Http2ClientStream(WritableBufferAllocator bufferAllocator, protected Http2ClientStream(WritableBufferAllocator bufferAllocator,
ClientStreamListener listener) { ClientStreamListener listener,
super(bufferAllocator, listener); int maxMessageSize) {
super(bufferAllocator, listener, maxMessageSize);
} }
/** /**

View File

@ -39,6 +39,7 @@ import io.grpc.MessageEncoding;
import io.grpc.Status; import io.grpc.Status;
import java.io.Closeable; import java.io.Closeable;
import java.io.FilterInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
@ -94,6 +95,7 @@ public class MessageDeframer implements Closeable {
} }
private final Listener listener; private final Listener listener;
private final int maxMessageSize;
private MessageEncoding.Decompressor decompressor; private MessageEncoding.Decompressor decompressor;
private State state = State.HEADER; private State state = State.HEADER;
private int requiredLength = HEADER_LENGTH; private int requiredLength = HEADER_LENGTH;
@ -105,25 +107,19 @@ public class MessageDeframer implements Closeable {
private boolean deliveryStalled = true; private boolean deliveryStalled = true;
private boolean inDelivery = false; private boolean inDelivery = false;
/**
* Creates a deframer. Compression will not be supported.
*
* @param listener listener for deframer events.
*/
public MessageDeframer(Listener listener) {
this(listener, MessageEncoding.NONE);
}
/** /**
* Create a deframer. * Create a deframer.
* *
* @param listener listener for deframer events. * @param listener listener for deframer events.
* @param decompressor the compression used if a compressed frame is encountered, with * @param decompressor the compression used if a compressed frame is encountered, with
* {@code NONE} meaning unsupported * {@code NONE} meaning unsupported
* @param maxMessageSize the maximum allowed size for received messages.
*/ */
public MessageDeframer(Listener listener, MessageEncoding.Decompressor decompressor) { public MessageDeframer(Listener listener, MessageEncoding.Decompressor decompressor,
this.listener = Preconditions.checkNotNull(listener, "listener"); int maxMessageSize) {
this.listener = Preconditions.checkNotNull(listener, "sink");
this.decompressor = Preconditions.checkNotNull(decompressor, "decompressor"); this.decompressor = Preconditions.checkNotNull(decompressor, "decompressor");
this.maxMessageSize = maxMessageSize;
} }
/** /**
@ -162,8 +158,7 @@ public class MessageDeframer implements Closeable {
* the remote endpoint. End of stream should not be used in the event of a transport * the remote endpoint. End of stream should not be used in the event of a transport
* error, such as a stream reset. * error, such as a stream reset.
* @throws IllegalStateException if {@link #close()} has been called previously or if * @throws IllegalStateException if {@link #close()} has been called previously or if
* {@link #deframe(ReadableBuffer, boolean)} has previously been called with * this method has previously been called with {@code endOfStream=true}.
* {@code endOfStream=true}.
*/ */
public void deframe(ReadableBuffer data, boolean endOfStream) { public void deframe(ReadableBuffer data, boolean endOfStream) {
Preconditions.checkNotNull(data, "data"); Preconditions.checkNotNull(data, "data");
@ -291,10 +286,6 @@ public class MessageDeframer implements Closeable {
} }
} }
private boolean isDataAvailable() {
return unprocessed.readableBytes() > 0 || (nextFrame != null && nextFrame.readableBytes() > 0);
}
/** /**
* Attempts to read the required bytes into nextFrame. * Attempts to read the required bytes into nextFrame.
* *
@ -340,6 +331,10 @@ public class MessageDeframer implements Closeable {
// Update the required length to include the length of the frame. // Update the required length to include the length of the frame.
requiredLength = nextFrame.readInt(); requiredLength = nextFrame.readInt();
if (requiredLength < 0 || requiredLength > maxMessageSize) {
throw Status.INTERNAL.withDescription(String.format("Frame size %d exceeds maximum: %d, ",
requiredLength, maxMessageSize)).asRuntimeException();
}
// Continue reading the frame body. // Continue reading the frame body.
state = State.BODY; state = State.BODY;
@ -370,9 +365,79 @@ public class MessageDeframer implements Closeable {
} }
try { try {
return decompressor.decompress(ReadableBuffers.openStream(nextFrame, true)); // Enforce the maxMessageSize limit on the returned stream.
return new SizeEnforcingInputStream(decompressor.decompress(
ReadableBuffers.openStream(nextFrame, true)));
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }
/**
* An {@link InputStream} that enforces the {@link #maxMessageSize} limit for compressed frames.
*/
private final class SizeEnforcingInputStream extends FilterInputStream {
private long count;
private long mark = -1;
public SizeEnforcingInputStream(InputStream in) {
super(in);
}
@Override
public int read() throws IOException {
int result = in.read();
if (result != -1) {
count++;
}
verifySize();
return result;
}
@Override
public int read(byte[] b, int off, int len) throws IOException {
int result = in.read(b, off, len);
if (result != -1) {
count += result;
}
verifySize();
return result;
}
@Override
public long skip(long n) throws IOException {
long result = in.skip(n);
count += result;
verifySize();
return result;
}
@Override
public synchronized void mark(int readlimit) {
in.mark(readlimit);
mark = count;
// it's okay to mark even if mark isn't supported, as reset won't work
}
@Override
public synchronized void reset() throws IOException {
if (!in.markSupported()) {
throw new IOException("Mark not supported");
}
if (mark == -1) {
throw new IOException("Mark not set");
}
in.reset();
count = mark;
}
private void verifySize() {
if (count > maxMessageSize) {
throw Status.INTERNAL.withDescription(String.format(
"Compressed frame exceeds maximum frame size: %d. Bytes read: %d",
maxMessageSize, count)).asRuntimeException();
}
}
}
} }

View File

@ -31,6 +31,7 @@
package io.grpc.internal; package io.grpc.internal;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.Matchers.isA; import static org.mockito.Matchers.isA;
@ -242,7 +243,7 @@ public class AbstractClientStreamTest {
private static class BaseAbstractClientStream<T> extends AbstractClientStream<T> { private static class BaseAbstractClientStream<T> extends AbstractClientStream<T> {
protected BaseAbstractClientStream( protected BaseAbstractClientStream(
WritableBufferAllocator allocator, ClientStreamListener listener) { WritableBufferAllocator allocator, ClientStreamListener listener) {
super(allocator, listener); super(allocator, listener, DEFAULT_MAX_MESSAGE_SIZE);
} }
@Override @Override

View File

@ -32,6 +32,7 @@
package io.grpc.internal; package io.grpc.internal;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -102,7 +103,7 @@ public class AbstractStreamTest {
*/ */
private class AbstractStreamBase<IdT> extends AbstractStream<IdT> { private class AbstractStreamBase<IdT> extends AbstractStream<IdT> {
private AbstractStreamBase(WritableBufferAllocator bufferAllocator) { private AbstractStreamBase(WritableBufferAllocator bufferAllocator) {
super(bufferAllocator); super(bufferAllocator, DEFAULT_MAX_MESSAGE_SIZE);
} }
@Override @Override

View File

@ -31,6 +31,7 @@
package io.grpc.internal; package io.grpc.internal;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.anyInt;
@ -67,7 +68,8 @@ import java.util.zip.GZIPOutputStream;
@RunWith(JUnit4.class) @RunWith(JUnit4.class)
public class MessageDeframerTest { public class MessageDeframerTest {
private Listener listener = mock(Listener.class); private Listener listener = mock(Listener.class);
private MessageDeframer deframer = new MessageDeframer(listener); private MessageDeframer deframer = new MessageDeframer(listener, MessageEncoding.NONE,
DEFAULT_MAX_MESSAGE_SIZE);
private ArgumentCaptor<InputStream> messages = ArgumentCaptor.forClass(InputStream.class); private ArgumentCaptor<InputStream> messages = ArgumentCaptor.forClass(InputStream.class);
@Test @Test
@ -176,7 +178,7 @@ public class MessageDeframerTest {
@Test @Test
public void compressed() { public void compressed() {
deframer = new MessageDeframer(listener, new MessageEncoding.Gzip()); deframer = new MessageDeframer(listener, new MessageEncoding.Gzip(), DEFAULT_MAX_MESSAGE_SIZE);
deframer.request(1); deframer.request(1);
byte[] payload = compress(new byte[1000]); byte[] payload = compress(new byte[1000]);

View File

@ -31,15 +31,12 @@
package io.grpc.netty; package io.grpc.netty;
import static io.grpc.Status.Code.CANCELLED; import static io.grpc.internal.GrpcUtil.CANCEL_REASONS;
import static io.grpc.Status.Code.DEADLINE_EXCEEDED;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import io.grpc.Status; import io.grpc.Status;
import java.util.EnumSet;
/** /**
* Command sent from a Netty client stream to the handler to cancel the stream. * Command sent from a Netty client stream to the handler to cancel the stream.
*/ */
@ -50,7 +47,7 @@ class CancelClientStreamCommand {
CancelClientStreamCommand(NettyClientStream stream, Status reason) { CancelClientStreamCommand(NettyClientStream stream, Status reason) {
this.stream = Preconditions.checkNotNull(stream, "stream"); this.stream = Preconditions.checkNotNull(stream, "stream");
Preconditions.checkNotNull(reason); Preconditions.checkNotNull(reason);
Preconditions.checkArgument(EnumSet.of(CANCELLED, DEADLINE_EXCEEDED).contains(reason.getCode()), Preconditions.checkArgument(CANCEL_REASONS.contains(reason.getCode()),
"Invalid cancellation reason"); "Invalid cancellation reason");
this.reason = reason; this.reason = reason;
} }

View File

@ -31,6 +31,9 @@
package io.grpc.netty; package io.grpc.netty;
import static com.google.common.base.Preconditions.checkArgument;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import io.grpc.AbstractChannelBuilder; import io.grpc.AbstractChannelBuilder;
@ -38,6 +41,7 @@ import io.grpc.internal.AbstractReferenceCounted;
import io.grpc.internal.ClientTransport; import io.grpc.internal.ClientTransport;
import io.grpc.internal.ClientTransportFactory; import io.grpc.internal.ClientTransportFactory;
import io.grpc.internal.SharedResourceHolder; import io.grpc.internal.SharedResourceHolder;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.EventLoopGroup; import io.netty.channel.EventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel;
@ -62,6 +66,7 @@ public final class NettyChannelBuilder extends AbstractChannelBuilder<NettyChann
private EventLoopGroup eventLoopGroup; private EventLoopGroup eventLoopGroup;
private SslContext sslContext; private SslContext sslContext;
private int flowControlWindow = DEFAULT_FLOW_CONTROL_WINDOW; private int flowControlWindow = DEFAULT_FLOW_CONTROL_WINDOW;
private int maxMessageSize = DEFAULT_MAX_MESSAGE_SIZE;
/** /**
* Creates a new builder with the given server address. * Creates a new builder with the given server address.
@ -132,10 +137,20 @@ public final class NettyChannelBuilder extends AbstractChannelBuilder<NettyChann
return this; return this;
} }
/**
* Sets the maximum message size allowed to be received on the channel. If not called,
* defaults to {@link io.grpc.internal.GrpcUtil#DEFAULT_MAX_MESSAGE_SIZE}.
*/
public NettyChannelBuilder maxMessageSize(int maxMessageSize) {
checkArgument(maxMessageSize >= 0, "maxMessageSize must be >= 0");
this.maxMessageSize = maxMessageSize;
return this;
}
@Override @Override
protected ClientTransportFactory buildTransportFactory() { protected ClientTransportFactory buildTransportFactory() {
return new NettyTransportFactory(serverAddress, channelType, eventLoopGroup, flowControlWindow, return new NettyTransportFactory(serverAddress, channelType, eventLoopGroup, flowControlWindow,
createProtocolNegotiator()); createProtocolNegotiator(), maxMessageSize);
} }
private ProtocolNegotiator createProtocolNegotiator() { private ProtocolNegotiator createProtocolNegotiator() {
@ -170,16 +185,19 @@ public final class NettyChannelBuilder extends AbstractChannelBuilder<NettyChann
private final boolean usingSharedGroup; private final boolean usingSharedGroup;
private final int flowControlWindow; private final int flowControlWindow;
private final ProtocolNegotiator negotiator; private final ProtocolNegotiator negotiator;
private final int maxMessageSize;
private NettyTransportFactory(SocketAddress serverAddress, private NettyTransportFactory(SocketAddress serverAddress,
Class<? extends Channel> channelType, Class<? extends Channel> channelType,
EventLoopGroup group, EventLoopGroup group,
int flowControlWindow, int flowControlWindow,
ProtocolNegotiator negotiator) { ProtocolNegotiator negotiator,
int maxMessageSize) {
this.serverAddress = serverAddress; this.serverAddress = serverAddress;
this.channelType = channelType; this.channelType = channelType;
this.flowControlWindow = flowControlWindow; this.flowControlWindow = flowControlWindow;
this.negotiator = negotiator; this.negotiator = negotiator;
this.maxMessageSize = maxMessageSize;
usingSharedGroup = group == null; usingSharedGroup = group == null;
if (usingSharedGroup) { if (usingSharedGroup) {
@ -193,7 +211,7 @@ public final class NettyChannelBuilder extends AbstractChannelBuilder<NettyChann
@Override @Override
public ClientTransport newClientTransport() { public ClientTransport newClientTransport() {
return new NettyClientTransport(serverAddress, channelType, group, negotiator, return new NettyClientTransport(serverAddress, channelType, group, negotiator,
flowControlWindow); flowControlWindow, maxMessageSize);
} }
@Override @Override

View File

@ -58,8 +58,9 @@ class NettyClientStream extends Http2ClientStream {
private Integer id; private Integer id;
private WriteQueue writeQueue; private WriteQueue writeQueue;
NettyClientStream(ClientStreamListener listener, Channel channel, NettyClientHandler handler) { NettyClientStream(ClientStreamListener listener, Channel channel, NettyClientHandler handler,
super(new NettyWritableBufferAllocator(channel.alloc()), listener); int maxMessageSize) {
super(new NettyWritableBufferAllocator(channel.alloc()), listener, maxMessageSize);
this.writeQueue = handler.getWriteQueue(); this.writeQueue = handler.getWriteQueue();
this.channel = checkNotNull(channel, "channel"); this.channel = checkNotNull(channel, "channel");
this.handler = checkNotNull(handler, "handler"); this.handler = checkNotNull(handler, "handler");

View File

@ -81,6 +81,7 @@ class NettyClientTransport implements ClientTransport {
private final NettyClientHandler handler; private final NettyClientHandler handler;
private final AsciiString authority; private final AsciiString authority;
private final int flowControlWindow; private final int flowControlWindow;
private final int maxMessageSize;
// We should not send on the channel until negotiation completes. This is a hard requirement // We should not send on the channel until negotiation completes. This is a hard requirement
// by SslHandler but is appropriate for HTTP/1.1 Upgrade as well. // by SslHandler but is appropriate for HTTP/1.1 Upgrade as well.
private Channel channel; private Channel channel;
@ -94,12 +95,13 @@ class NettyClientTransport implements ClientTransport {
NettyClientTransport(SocketAddress address, Class<? extends Channel> channelType, NettyClientTransport(SocketAddress address, Class<? extends Channel> channelType,
EventLoopGroup group, ProtocolNegotiator negotiator, EventLoopGroup group, ProtocolNegotiator negotiator,
int flowControlWindow) { int flowControlWindow, int maxMessageSize) {
Preconditions.checkNotNull(negotiator, "negotiator"); Preconditions.checkNotNull(negotiator, "negotiator");
this.address = Preconditions.checkNotNull(address, "address"); this.address = Preconditions.checkNotNull(address, "address");
this.group = Preconditions.checkNotNull(group, "group"); this.group = Preconditions.checkNotNull(group, "group");
this.channelType = Preconditions.checkNotNull(channelType, "channelType"); this.channelType = Preconditions.checkNotNull(channelType, "channelType");
this.flowControlWindow = flowControlWindow; this.flowControlWindow = flowControlWindow;
this.maxMessageSize = maxMessageSize;
if (address instanceof InetSocketAddress) { if (address instanceof InetSocketAddress) {
InetSocketAddress inetAddress = (InetSocketAddress) address; InetSocketAddress inetAddress = (InetSocketAddress) address;
@ -128,7 +130,8 @@ class NettyClientTransport implements ClientTransport {
Preconditions.checkNotNull(listener, "listener"); Preconditions.checkNotNull(listener, "listener");
// Create the stream. // Create the stream.
final NettyClientStream stream = new NettyClientStream(listener, channel, handler); final NettyClientStream stream = new NettyClientStream(listener, channel, handler,
maxMessageSize);
// Convert the headers into Netty HTTP/2 headers. // Convert the headers into Netty HTTP/2 headers.
AsciiString defaultPath = new AsciiString("/" + method.getFullMethodName()); AsciiString defaultPath = new AsciiString("/" + method.getFullMethodName());

View File

@ -72,11 +72,13 @@ public class NettyServer implements Server {
private EventLoopGroup workerGroup; private EventLoopGroup workerGroup;
private ServerListener listener; private ServerListener listener;
private Channel channel; private Channel channel;
private int flowControlWindow; private final int flowControlWindow;
private final int maxMessageSize;
NettyServer(SocketAddress address, Class<? extends ServerChannel> channelType, NettyServer(SocketAddress address, Class<? extends ServerChannel> channelType,
@Nullable EventLoopGroup bossGroup, @Nullable EventLoopGroup workerGroup, @Nullable EventLoopGroup bossGroup, @Nullable EventLoopGroup workerGroup,
@Nullable SslContext sslContext, int maxStreamsPerConnection, int flowControlWindow) { @Nullable SslContext sslContext, int maxStreamsPerConnection, int flowControlWindow,
int maxMessageSize) {
this.address = address; this.address = address;
this.channelType = checkNotNull(channelType, "channelType"); this.channelType = checkNotNull(channelType, "channelType");
this.bossGroup = bossGroup; this.bossGroup = bossGroup;
@ -86,6 +88,7 @@ public class NettyServer implements Server {
this.usingSharedWorkerGroup = workerGroup == null; this.usingSharedWorkerGroup = workerGroup == null;
this.maxStreamsPerConnection = maxStreamsPerConnection; this.maxStreamsPerConnection = maxStreamsPerConnection;
this.flowControlWindow = flowControlWindow; this.flowControlWindow = flowControlWindow;
this.maxMessageSize = maxMessageSize;
} }
@Override @Override
@ -106,7 +109,8 @@ public class NettyServer implements Server {
@Override @Override
public void initChannel(Channel ch) throws Exception { public void initChannel(Channel ch) throws Exception {
NettyServerTransport transport NettyServerTransport transport
= new NettyServerTransport(ch, sslContext, maxStreamsPerConnection, flowControlWindow); = new NettyServerTransport(ch, sslContext, maxStreamsPerConnection, flowControlWindow,
maxMessageSize);
transport.start(listener.transportCreated(transport)); transport.start(listener.transportCreated(transport));
} }
}); });

View File

@ -31,6 +31,9 @@
package io.grpc.netty; package io.grpc.netty;
import static com.google.common.base.Preconditions.checkArgument;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import io.grpc.AbstractServerBuilder; import io.grpc.AbstractServerBuilder;
@ -61,6 +64,7 @@ public final class NettyServerBuilder extends AbstractServerBuilder<NettyServerB
private SslContext sslContext; private SslContext sslContext;
private int maxConcurrentCallsPerConnection = Integer.MAX_VALUE; private int maxConcurrentCallsPerConnection = Integer.MAX_VALUE;
private int flowControlWindow = DEFAULT_FLOW_CONTROL_WINDOW; private int flowControlWindow = DEFAULT_FLOW_CONTROL_WINDOW;
private int maxMessageSize = DEFAULT_MAX_MESSAGE_SIZE;
/** /**
* Creates a server builder that will bind to the given port. * Creates a server builder that will bind to the given port.
@ -190,9 +194,20 @@ public final class NettyServerBuilder extends AbstractServerBuilder<NettyServerB
return this; return this;
} }
/**
* Sets the maximum message size allowed to be received on the server. If not called,
* defaults to 100 MiB.
*/
public NettyServerBuilder maxMessageSize(int maxMessageSize) {
checkArgument(maxMessageSize >= 0, "maxMessageSize must be >= 0");
this.maxMessageSize = maxMessageSize;
return this;
}
@Override @Override
protected NettyServer buildTransportServer() { protected NettyServer buildTransportServer() {
return new NettyServer(address, channelType, bossEventLoopGroup, return new NettyServer(address, channelType, bossEventLoopGroup,
workerEventLoopGroup, sslContext, maxConcurrentCallsPerConnection, flowControlWindow); workerEventLoopGroup, sslContext, maxConcurrentCallsPerConnection, flowControlWindow,
maxMessageSize);
} }
} }

View File

@ -31,6 +31,7 @@
package io.grpc.netty; package io.grpc.netty;
import static com.google.common.base.Preconditions.checkArgument;
import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC; import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC;
import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER; import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER;
import static io.grpc.netty.Utils.HTTP_METHOD; import static io.grpc.netty.Utils.HTTP_METHOD;
@ -81,6 +82,7 @@ class NettyServerHandler extends Http2ConnectionHandler {
private final Http2Connection.PropertyKey streamKey; private final Http2Connection.PropertyKey streamKey;
private final ServerTransportListener transportListener; private final ServerTransportListener transportListener;
private final int maxMessageSize;
private Throwable connectionError; private Throwable connectionError;
private ChannelHandlerContext ctx; private ChannelHandlerContext ctx;
private boolean teWarningLogged; private boolean teWarningLogged;
@ -93,10 +95,13 @@ class NettyServerHandler extends Http2ConnectionHandler {
Http2FrameReader frameReader, Http2FrameReader frameReader,
Http2FrameWriter frameWriter, Http2FrameWriter frameWriter,
int maxStreams, int maxStreams,
int flowControlWindow) { int flowControlWindow,
int maxMessageSize) {
super(connection, frameReader, frameWriter, new LazyFrameListener()); super(connection, frameReader, frameWriter, new LazyFrameListener());
Preconditions.checkArgument(flowControlWindow > 0, "flowControlWindow must be positive"); Preconditions.checkArgument(flowControlWindow > 0, "flowControlWindow must be positive");
this.flowControlWindow = flowControlWindow; this.flowControlWindow = flowControlWindow;
checkArgument(maxMessageSize >= 0, "maxMessageSize must be >= 0");
this.maxMessageSize = maxMessageSize;
streamKey = connection.newKey(); streamKey = connection.newKey();
this.transportListener = Preconditions.checkNotNull(transportListener, "transportListener"); this.transportListener = Preconditions.checkNotNull(transportListener, "transportListener");
@ -146,7 +151,8 @@ class NettyServerHandler extends Http2ConnectionHandler {
// The Http2Stream object was put by AbstractHttp2ConnectionHandler before calling this // The Http2Stream object was put by AbstractHttp2ConnectionHandler before calling this
// method. // method.
Http2Stream http2Stream = requireHttp2Stream(streamId); Http2Stream http2Stream = requireHttp2Stream(streamId);
NettyServerStream stream = new NettyServerStream(ctx.channel(), http2Stream, this); NettyServerStream stream = new NettyServerStream(ctx.channel(), http2Stream, this,
maxMessageSize);
http2Stream.setProperty(streamKey, stream); http2Stream.setProperty(streamKey, stream);
String method = determineMethod(streamId, headers); String method = determineMethod(streamId, headers);

View File

@ -54,8 +54,9 @@ class NettyServerStream extends AbstractServerStream<Integer> {
private final Http2Stream http2Stream; private final Http2Stream http2Stream;
private final WriteQueue writeQueue; private final WriteQueue writeQueue;
NettyServerStream(Channel channel, Http2Stream http2Stream, NettyServerHandler handler) { NettyServerStream(Channel channel, Http2Stream http2Stream, NettyServerHandler handler,
super(new NettyWritableBufferAllocator(channel.alloc())); int maxMessageSize) {
super(new NettyWritableBufferAllocator(channel.alloc()), maxMessageSize);
this.writeQueue = handler.getWriteQueue(); this.writeQueue = handler.getWriteQueue();
this.channel = checkNotNull(channel, "channel"); this.channel = checkNotNull(channel, "channel");
this.http2Stream = checkNotNull(http2Stream, "http2Stream"); this.http2Stream = checkNotNull(http2Stream, "http2Stream");

View File

@ -67,14 +67,16 @@ class NettyServerTransport implements ServerTransport {
private final int maxStreams; private final int maxStreams;
private ServerTransportListener listener; private ServerTransportListener listener;
private boolean terminated; private boolean terminated;
private int flowControlWindow; private final int flowControlWindow;
private final int maxMessageSize;
NettyServerTransport(Channel channel, @Nullable SslContext sslContext, int maxStreams, NettyServerTransport(Channel channel, @Nullable SslContext sslContext, int maxStreams,
int flowControlWindow) { int flowControlWindow, int maxMessageSize) {
this.channel = Preconditions.checkNotNull(channel, "channel"); this.channel = Preconditions.checkNotNull(channel, "channel");
this.sslContext = sslContext; this.sslContext = sslContext;
this.maxStreams = maxStreams; this.maxStreams = maxStreams;
this.flowControlWindow = flowControlWindow; this.flowControlWindow = flowControlWindow;
this.maxMessageSize = maxMessageSize;
} }
public void start(ServerTransportListener listener) { public void start(ServerTransportListener listener) {
@ -135,6 +137,6 @@ class NettyServerTransport implements ServerTransport {
new Http2OutboundFrameLogger(new DefaultHttp2FrameWriter(), frameLogger); new Http2OutboundFrameLogger(new DefaultHttp2FrameWriter(), frameLogger);
return new NettyServerHandler(transportListener, connection, frameReader, frameWriter, return new NettyServerHandler(transportListener, connection, frameReader, frameWriter,
maxStreams, flowControlWindow); maxStreams, flowControlWindow, maxMessageSize);
} }
} }

View File

@ -31,6 +31,7 @@
package io.grpc.netty; package io.grpc.netty;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static io.grpc.netty.NettyTestUtil.messageFrame; import static io.grpc.netty.NettyTestUtil.messageFrame;
import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC; import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC;
import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER; import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER;
@ -54,6 +55,7 @@ import static org.mockito.Mockito.when;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.internal.ClientStreamListener; import io.grpc.internal.ClientStreamListener;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
@ -321,7 +323,7 @@ public class NettyClientStreamTest extends NettyStreamTestBase {
@Test @Test
public void setHttp2StreamShouldNotifyReady() { public void setHttp2StreamShouldNotifyReady() {
listener = mock(ClientStreamListener.class); listener = mock(ClientStreamListener.class);
stream = new NettyClientStream(listener, channel, handler); stream = new NettyClientStream(listener, channel, handler, DEFAULT_MAX_MESSAGE_SIZE);
stream().id(STREAM_ID); stream().id(STREAM_ID);
verify(listener, never()).onReady(); verify(listener, never()).onReady();
assertFalse(stream.isReady()); assertFalse(stream.isReady());
@ -343,7 +345,8 @@ public class NettyClientStreamTest extends NettyStreamTestBase {
} }
}).when(writeQueue).enqueue(any(), any(ChannelPromise.class), anyBoolean()); }).when(writeQueue).enqueue(any(), any(ChannelPromise.class), anyBoolean());
when(writeQueue.enqueue(any(), anyBoolean())).thenReturn(future); when(writeQueue.enqueue(any(), anyBoolean())).thenReturn(future);
NettyClientStream stream = new NettyClientStream(listener, channel, handler); NettyClientStream stream = new NettyClientStream(listener, channel, handler,
DEFAULT_MAX_MESSAGE_SIZE);
assertTrue(stream.canSend()); assertTrue(stream.canSend());
assertTrue(stream.canReceive()); assertTrue(stream.canReceive());
stream.id(STREAM_ID); stream.id(STREAM_ID);

View File

@ -32,9 +32,12 @@
package io.grpc.netty; package io.grpc.netty;
import static com.google.common.base.Charsets.UTF_8; import static com.google.common.base.Charsets.UTF_8;
import static io.grpc.Status.Code.INTERNAL;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static io.grpc.internal.GrpcUtil.USER_AGENT_KEY; import static io.grpc.internal.GrpcUtil.USER_AGENT_KEY;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import com.google.common.io.ByteStreams; import com.google.common.io.ByteStreams;
@ -152,6 +155,25 @@ public class NettyClientTransportTest {
receivedHeaders.get(USER_AGENT_KEY)); receivedHeaders.get(USER_AGENT_KEY));
} }
@Test
public void maxMessageSizeShouldBeEnforced() throws Throwable {
startServer();
// Allow the response payloads of up to 1 byte.
NettyClientTransport transport = newTransport(newNegotiator(), 1);
transport.start(clientTransportListener);
try {
// Send a single RPC and wait for the response.
new Rpc(transport).halfClose().waitForResponse();
fail("Expected the stream to fail.");
} catch (ExecutionException e) {
Status status = Status.fromThrowable(e);
assertEquals(INTERNAL, status.getCode());
System.err.println(status.getDescription());
assertTrue(status.getDescription().contains("deframing"));
}
}
/** /**
* Verifies that we can create multiple TLS client transports from the same builder. * Verifies that we can create multiple TLS client transports from the same builder.
*/ */
@ -218,8 +240,12 @@ public class NettyClientTransportTest {
} }
private NettyClientTransport newTransport(ProtocolNegotiator negotiator) { private NettyClientTransport newTransport(ProtocolNegotiator negotiator) {
return newTransport(negotiator, DEFAULT_MAX_MESSAGE_SIZE);
}
private NettyClientTransport newTransport(ProtocolNegotiator negotiator, int maxMsgSize) {
NettyClientTransport transport = new NettyClientTransport(address, NioSocketChannel.class, NettyClientTransport transport = new NettyClientTransport(address, NioSocketChannel.class,
group, negotiator, DEFAULT_WINDOW_SIZE); group, negotiator, DEFAULT_WINDOW_SIZE, maxMsgSize);
transports.add(transport); transports.add(transport);
return transport; return transport;
} }
@ -235,7 +261,7 @@ public class NettyClientTransportTest {
.ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE).build(); .ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE).build();
server = new NettyServer(address, NioServerSocketChannel.class, server = new NettyServer(address, NioServerSocketChannel.class,
group, group, serverContext, maxStreamsPerConnection, group, group, serverContext, maxStreamsPerConnection,
DEFAULT_WINDOW_SIZE); DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE);
server.start(serverListener); server.start(serverListener);
} }

View File

@ -32,6 +32,7 @@
package io.grpc.netty; package io.grpc.netty;
import static com.google.common.base.Charsets.UTF_8; import static com.google.common.base.Charsets.UTF_8;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC; import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC;
import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER; import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER;
import static io.grpc.netty.Utils.HTTP_METHOD; import static io.grpc.netty.Utils.HTTP_METHOD;
@ -64,6 +65,7 @@ import io.grpc.internal.ServerStream;
import io.grpc.internal.ServerStreamListener; import io.grpc.internal.ServerStreamListener;
import io.grpc.internal.ServerTransportListener; import io.grpc.internal.ServerTransportListener;
import io.grpc.internal.WritableBuffer; import io.grpc.internal.WritableBuffer;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
@ -303,7 +305,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase {
Http2Connection connection = new DefaultHttp2Connection(true); Http2Connection connection = new DefaultHttp2Connection(true);
handler = handler =
new NettyServerHandler(transportListener, connection, new DefaultHttp2FrameReader(), new NettyServerHandler(transportListener, connection, new DefaultHttp2FrameReader(),
frameWriter, maxConcurrentStreams, DEFAULT_WINDOW_SIZE); frameWriter, maxConcurrentStreams, DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE);
when(channel.isActive()).thenReturn(true); when(channel.isActive()).thenReturn(true);
mockContext(); mockContext();
@ -403,7 +405,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase {
Http2FrameReader frameReader = new DefaultHttp2FrameReader(); Http2FrameReader frameReader = new DefaultHttp2FrameReader();
Http2FrameWriter frameWriter = new DefaultHttp2FrameWriter(); Http2FrameWriter frameWriter = new DefaultHttp2FrameWriter();
return new NettyServerHandler(transportListener, connection, frameReader, frameWriter, return new NettyServerHandler(transportListener, connection, frameReader, frameWriter,
Integer.MAX_VALUE, flowControlWindow); Integer.MAX_VALUE, flowControlWindow, DEFAULT_MAX_MESSAGE_SIZE);
} }
private static NettyServerHandler newHandler(ServerTransportListener transportListener) { private static NettyServerHandler newHandler(ServerTransportListener transportListener) {

View File

@ -31,6 +31,7 @@
package io.grpc.netty; package io.grpc.netty;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static io.grpc.netty.NettyTestUtil.messageFrame; import static io.grpc.netty.NettyTestUtil.messageFrame;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
@ -50,6 +51,7 @@ import static org.mockito.Mockito.when;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.internal.ServerStreamListener; import io.grpc.internal.ServerStreamListener;
import io.netty.buffer.EmptyByteBuf; import io.netty.buffer.EmptyByteBuf;
import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
@ -254,7 +256,8 @@ public class NettyServerStreamTest extends NettyStreamTestBase {
} }
}).when(writeQueue).enqueue(any(), any(ChannelPromise.class), anyBoolean()); }).when(writeQueue).enqueue(any(), any(ChannelPromise.class), anyBoolean());
when(writeQueue.enqueue(any(), anyBoolean())).thenReturn(future); when(writeQueue.enqueue(any(), anyBoolean())).thenReturn(future);
NettyServerStream stream = new NettyServerStream(channel, http2Stream, handler); NettyServerStream stream = new NettyServerStream(channel, http2Stream, handler,
DEFAULT_MAX_MESSAGE_SIZE);
stream.setListener(serverListener); stream.setListener(serverListener);
assertTrue(stream.canReceive()); assertTrue(stream.canReceive());
assertTrue(stream.canSend()); assertTrue(stream.canSend());

View File

@ -31,6 +31,9 @@
package io.grpc.okhttp; package io.grpc.okhttp;
import static com.google.common.base.Preconditions.checkArgument;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.google.common.util.concurrent.ThreadFactoryBuilder;
@ -96,6 +99,7 @@ public final class OkHttpChannelBuilder extends AbstractChannelBuilder<OkHttpCha
private SSLSocketFactory sslSocketFactory; private SSLSocketFactory sslSocketFactory;
private ConnectionSpec connectionSpec = DEFAULT_CONNECTION_SPEC; private ConnectionSpec connectionSpec = DEFAULT_CONNECTION_SPEC;
private NegotiationType negotiationType = NegotiationType.TLS; private NegotiationType negotiationType = NegotiationType.TLS;
private int maxMessageSize = DEFAULT_MAX_MESSAGE_SIZE;
private OkHttpChannelBuilder(String host, int port) { private OkHttpChannelBuilder(String host, int port) {
this.host = Preconditions.checkNotNull(host); this.host = Preconditions.checkNotNull(host);
@ -159,10 +163,20 @@ public final class OkHttpChannelBuilder extends AbstractChannelBuilder<OkHttpCha
return this; return this;
} }
/**
* Sets the maximum message size allowed to be received on the channel. If not called,
* defaults to {@link io.grpc.internal.GrpcUtil#DEFAULT_MAX_MESSAGE_SIZE}.
*/
public OkHttpChannelBuilder maxMessageSize(int maxMessageSize) {
checkArgument(maxMessageSize >= 0, "maxMessageSize must be >= 0");
this.maxMessageSize = maxMessageSize;
return this;
}
@Override @Override
protected ClientTransportFactory buildTransportFactory() { protected ClientTransportFactory buildTransportFactory() {
return new OkHttpTransportFactory(host, port, authorityHost, transportExecutor, return new OkHttpTransportFactory(host, port, authorityHost, transportExecutor,
createSocketFactory(), connectionSpec); createSocketFactory(), connectionSpec, maxMessageSize);
} }
private SSLSocketFactory createSocketFactory() { private SSLSocketFactory createSocketFactory() {
@ -186,18 +200,21 @@ public final class OkHttpChannelBuilder extends AbstractChannelBuilder<OkHttpCha
private final boolean usingSharedExecutor; private final boolean usingSharedExecutor;
private final SSLSocketFactory socketFactory; private final SSLSocketFactory socketFactory;
private final ConnectionSpec connectionSpec; private final ConnectionSpec connectionSpec;
private final int maxMessageSize;
private OkHttpTransportFactory(String host, private OkHttpTransportFactory(String host,
int port, int port,
String authorityHost, String authorityHost,
ExecutorService executor, ExecutorService executor,
SSLSocketFactory socketFactory, SSLSocketFactory socketFactory,
ConnectionSpec connectionSpec) { ConnectionSpec connectionSpec,
int maxMessageSize) {
this.host = host; this.host = host;
this.port = port; this.port = port;
this.authorityHost = authorityHost; this.authorityHost = authorityHost;
this.socketFactory = socketFactory; this.socketFactory = socketFactory;
this.connectionSpec = connectionSpec; this.connectionSpec = connectionSpec;
this.maxMessageSize = maxMessageSize;
usingSharedExecutor = executor == null; usingSharedExecutor = executor == null;
if (usingSharedExecutor) { if (usingSharedExecutor) {
@ -211,7 +228,7 @@ public final class OkHttpChannelBuilder extends AbstractChannelBuilder<OkHttpCha
@Override @Override
public ClientTransport newClientTransport() { public ClientTransport newClientTransport() {
return new OkHttpClientTransport(host, port, authorityHost, executor, socketFactory, return new OkHttpClientTransport(host, port, authorityHost, executor, socketFactory,
connectionSpec); connectionSpec, maxMessageSize);
} }
@Override @Override

View File

@ -58,26 +58,12 @@ import javax.annotation.concurrent.GuardedBy;
*/ */
class OkHttpClientStream extends Http2ClientStream { class OkHttpClientStream extends Http2ClientStream {
private static int WINDOW_UPDATE_THRESHOLD = Utils.DEFAULT_WINDOW_SIZE / 2; private static final int WINDOW_UPDATE_THRESHOLD = Utils.DEFAULT_WINDOW_SIZE / 2;
private static final Buffer EMPTY_BUFFER = new Buffer(); private static final Buffer EMPTY_BUFFER = new Buffer();
private final MethodType type; private final MethodType type;
/**
* Construct a new client stream.
*/
static OkHttpClientStream newStream(ClientStreamListener listener,
AsyncFrameWriter frameWriter,
OkHttpClientTransport transport,
OutboundFlowController outboundFlow,
MethodType type,
Object lock,
List<Header> requestHeaders) {
return new OkHttpClientStream(
listener, frameWriter, transport, outboundFlow, type, lock, requestHeaders);
}
@GuardedBy("lock") @GuardedBy("lock")
private int window = Utils.DEFAULT_WINDOW_SIZE; private int window = Utils.DEFAULT_WINDOW_SIZE;
@GuardedBy("lock") @GuardedBy("lock")
@ -94,15 +80,15 @@ class OkHttpClientStream extends Http2ClientStream {
@GuardedBy("lock") @GuardedBy("lock")
private boolean cancelSent = false; private boolean cancelSent = false;
OkHttpClientStream(ClientStreamListener listener,
private OkHttpClientStream(ClientStreamListener listener,
AsyncFrameWriter frameWriter, AsyncFrameWriter frameWriter,
OkHttpClientTransport transport, OkHttpClientTransport transport,
OutboundFlowController outboundFlow, OutboundFlowController outboundFlow,
MethodType type, MethodType type,
Object lock, Object lock,
List<Header> requestHeaders) { List<Header> requestHeaders,
super(new OkHttpWritableBufferAllocator(), listener); int maxMessageSize) {
super(new OkHttpWritableBufferAllocator(), listener, maxMessageSize);
this.frameWriter = frameWriter; this.frameWriter = frameWriter;
this.transport = transport; this.transport = transport;
this.outboundFlow = outboundFlow; this.outboundFlow = outboundFlow;

View File

@ -141,6 +141,7 @@ class OkHttpClientTransport implements ClientTransport {
private final Executor executor; private final Executor executor;
// Wrap on executor, to guarantee some operations be executed serially. // Wrap on executor, to guarantee some operations be executed serially.
private final SerializingExecutor serializingExecutor; private final SerializingExecutor serializingExecutor;
private final int maxMessageSize;
private int connectionUnacknowledgedBytesRead; private int connectionUnacknowledgedBytesRead;
private ClientFrameHandler clientFrameHandler; private ClientFrameHandler clientFrameHandler;
// The status used to finish all active streams when the transport is closed. // The status used to finish all active streams when the transport is closed.
@ -166,10 +167,13 @@ class OkHttpClientTransport implements ClientTransport {
SettableFuture<Void> connectedFuture; SettableFuture<Void> connectedFuture;
OkHttpClientTransport(String host, int port, String authorityHost, Executor executor, OkHttpClientTransport(String host, int port, String authorityHost, Executor executor,
@Nullable SSLSocketFactory sslSocketFactory, ConnectionSpec connectionSpec) {
@Nullable SSLSocketFactory sslSocketFactory, ConnectionSpec connectionSpec,
int maxMessageSize) {
this.host = Preconditions.checkNotNull(host, "host"); this.host = Preconditions.checkNotNull(host, "host");
this.port = port; this.port = port;
this.authorityHost = authorityHost; this.authorityHost = authorityHost;
this.maxMessageSize = maxMessageSize;
defaultAuthority = authorityHost + ":" + port; defaultAuthority = authorityHost + ":" + port;
this.executor = Preconditions.checkNotNull(executor, "executor"); this.executor = Preconditions.checkNotNull(executor, "executor");
serializingExecutor = new SerializingExecutor(executor); serializingExecutor = new SerializingExecutor(executor);
@ -187,10 +191,12 @@ class OkHttpClientTransport implements ClientTransport {
@VisibleForTesting @VisibleForTesting
OkHttpClientTransport(Executor executor, FrameReader frameReader, FrameWriter testFrameWriter, OkHttpClientTransport(Executor executor, FrameReader frameReader, FrameWriter testFrameWriter,
int nextStreamId, Socket socket, Ticker ticker, int nextStreamId, Socket socket, Ticker ticker,
@Nullable Runnable connectingCallback, SettableFuture<Void> connectedFuture) { @Nullable Runnable connectingCallback, SettableFuture<Void> connectedFuture,
int maxMessageSize) {
host = null; host = null;
port = 0; port = 0;
authorityHost = null; authorityHost = null;
this.maxMessageSize = maxMessageSize;
defaultAuthority = "notarealauthority:80"; defaultAuthority = "notarealauthority:80";
this.executor = Preconditions.checkNotNull(executor); this.executor = Preconditions.checkNotNull(executor);
serializingExecutor = new SerializingExecutor(executor); serializingExecutor = new SerializingExecutor(executor);
@ -248,9 +254,10 @@ class OkHttpClientTransport implements ClientTransport {
Preconditions.checkNotNull(listener, "listener"); Preconditions.checkNotNull(listener, "listener");
String defaultPath = "/" + method.getFullMethodName(); String defaultPath = "/" + method.getFullMethodName();
OkHttpClientStream clientStream = OkHttpClientStream.newStream( OkHttpClientStream clientStream = new OkHttpClientStream(
listener, frameWriter, this, outboundFlow, method.getType(), lock, listener, frameWriter, this, outboundFlow, method.getType(), lock,
Headers.createRequestHeaders(headers, defaultPath, defaultAuthority)); Headers.createRequestHeaders(headers, defaultPath, defaultAuthority),
maxMessageSize);
synchronized (lock) { synchronized (lock) {
if (goAway) { if (goAway) {

View File

@ -32,6 +32,8 @@
package io.grpc.okhttp; package io.grpc.okhttp;
import static com.google.common.base.Charsets.UTF_8; import static com.google.common.base.Charsets.UTF_8;
import static io.grpc.Status.Code.INTERNAL;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static io.grpc.okhttp.Headers.CONTENT_TYPE_HEADER; import static io.grpc.okhttp.Headers.CONTENT_TYPE_HEADER;
import static io.grpc.okhttp.Headers.METHOD_HEADER; import static io.grpc.okhttp.Headers.METHOD_HEADER;
import static io.grpc.okhttp.Headers.SCHEME_HEADER; import static io.grpc.okhttp.Headers.SCHEME_HEADER;
@ -148,20 +150,20 @@ public class OkHttpClientTransportTest {
} }
private void initTransport() throws Exception { private void initTransport() throws Exception {
startTransport(3, null, true); startTransport(3, null, true, DEFAULT_MAX_MESSAGE_SIZE);
} }
private void initTransport(int startId) throws Exception { private void initTransport(int startId) throws Exception {
startTransport(startId, null, true); startTransport(startId, null, true, DEFAULT_MAX_MESSAGE_SIZE);
} }
private void initTransportAndDelayConnected() throws Exception { private void initTransportAndDelayConnected() throws Exception {
delayConnectedCallback = new DelayConnectedCallback(); delayConnectedCallback = new DelayConnectedCallback();
startTransport(3, delayConnectedCallback, false); startTransport(3, delayConnectedCallback, false, DEFAULT_MAX_MESSAGE_SIZE);
} }
private void startTransport(int startId, @Nullable Runnable connectingCallback, private void startTransport(int startId, @Nullable Runnable connectingCallback,
boolean waitingForConnected) throws Exception { boolean waitingForConnected, int maxMessageSize) throws Exception {
connectedFuture = SettableFuture.create(); connectedFuture = SettableFuture.create();
Ticker ticker = new Ticker() { Ticker ticker = new Ticker() {
@Override @Override
@ -171,7 +173,8 @@ public class OkHttpClientTransportTest {
}; };
clientTransport = new OkHttpClientTransport( clientTransport = new OkHttpClientTransport(
executor, frameReader, frameWriter, startId, executor, frameReader, frameWriter, startId,
new MockSocket(frameReader), ticker, connectingCallback, connectedFuture); new MockSocket(frameReader), ticker, connectingCallback, connectedFuture,
maxMessageSize);
clientTransport.start(transportListener); clientTransport.start(transportListener);
if (waitingForConnected) { if (waitingForConnected) {
connectedFuture.get(TIME_OUT_MS, TimeUnit.MILLISECONDS); connectedFuture.get(TIME_OUT_MS, TimeUnit.MILLISECONDS);
@ -188,6 +191,26 @@ public class OkHttpClientTransportTest {
executor.shutdown(); executor.shutdown();
} }
@Test
public void maxMessageSizeShouldBeEnforced() throws Exception {
// Allow the response payloads of up to 1 byte.
startTransport(3, null, true, 1);
MockStreamListener listener = new MockStreamListener();
clientTransport.newStream(method, new Metadata.Headers(), listener).request(1);
assertContainStream(3);
frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS);
assertNotNull(listener.headers);
// Receive the message.
final String message = "Hello Client";
Buffer buffer = createMessageFrame(message);
frameHandler().data(false, 3, buffer, (int) buffer.size());
listener.waitUntilStreamClosed();
assertEquals(INTERNAL, listener.status.getCode());
}
/** /**
* When nextFrame throws IOException, the transport should be aborted. * When nextFrame throws IOException, the transport should be aborted.
*/ */