mirror of https://github.com/grpc/grpc-java.git
parent
40c66a17a9
commit
15f02ba19c
|
@ -32,19 +32,15 @@
|
||||||
package io.grpc.internal;
|
package io.grpc.internal;
|
||||||
|
|
||||||
import static com.google.common.base.Preconditions.checkArgument;
|
import static com.google.common.base.Preconditions.checkArgument;
|
||||||
import static io.grpc.Status.Code.CANCELLED;
|
import static io.grpc.internal.GrpcUtil.CANCEL_REASONS;
|
||||||
import static io.grpc.Status.Code.DEADLINE_EXCEEDED;
|
|
||||||
|
|
||||||
import com.google.common.base.Objects;
|
import com.google.common.base.Objects;
|
||||||
import com.google.common.base.Preconditions;
|
import com.google.common.base.Preconditions;
|
||||||
|
|
||||||
import io.grpc.Metadata;
|
import io.grpc.Metadata;
|
||||||
import io.grpc.Status;
|
import io.grpc.Status;
|
||||||
import io.grpc.Status.Code;
|
|
||||||
|
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
import java.util.EnumSet;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.logging.Level;
|
import java.util.logging.Level;
|
||||||
import java.util.logging.Logger;
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
|
@ -55,8 +51,6 @@ public abstract class AbstractClientStream<IdT> extends AbstractStream<IdT>
|
||||||
implements ClientStream {
|
implements ClientStream {
|
||||||
|
|
||||||
private static final Logger log = Logger.getLogger(AbstractClientStream.class.getName());
|
private static final Logger log = Logger.getLogger(AbstractClientStream.class.getName());
|
||||||
private static final Set<Code> CANCEL_REASONS =
|
|
||||||
EnumSet.of(CANCELLED, DEADLINE_EXCEEDED, Code.INTERNAL, Code.UNKNOWN);
|
|
||||||
|
|
||||||
private final ClientStreamListener listener;
|
private final ClientStreamListener listener;
|
||||||
private boolean listenerClosed;
|
private boolean listenerClosed;
|
||||||
|
@ -67,15 +61,10 @@ public abstract class AbstractClientStream<IdT> extends AbstractStream<IdT>
|
||||||
private Metadata trailers;
|
private Metadata trailers;
|
||||||
private Runnable closeListenerTask;
|
private Runnable closeListenerTask;
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Constructor used by subclasses.
|
|
||||||
*
|
|
||||||
* @param listener the listener to receive notifications
|
|
||||||
*/
|
|
||||||
protected AbstractClientStream(WritableBufferAllocator bufferAllocator,
|
protected AbstractClientStream(WritableBufferAllocator bufferAllocator,
|
||||||
ClientStreamListener listener) {
|
ClientStreamListener listener,
|
||||||
super(bufferAllocator);
|
int maxMessageSize) {
|
||||||
|
super(bufferAllocator, maxMessageSize);
|
||||||
this.listener = Preconditions.checkNotNull(listener);
|
this.listener = Preconditions.checkNotNull(listener);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -63,8 +63,9 @@ public abstract class AbstractServerStream<IdT> extends AbstractStream<IdT>
|
||||||
/** Saved trailers from close() that need to be sent once the framer has sent all messages. */
|
/** Saved trailers from close() that need to be sent once the framer has sent all messages. */
|
||||||
private Metadata stashedTrailers;
|
private Metadata stashedTrailers;
|
||||||
|
|
||||||
protected AbstractServerStream(WritableBufferAllocator bufferAllocator) {
|
protected AbstractServerStream(WritableBufferAllocator bufferAllocator,
|
||||||
super(bufferAllocator);
|
int maxMessageSize) {
|
||||||
|
super(bufferAllocator, maxMessageSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -100,7 +100,7 @@ public abstract class AbstractStream<IdT> implements Stream {
|
||||||
|
|
||||||
private final Object onReadyLock = new Object();
|
private final Object onReadyLock = new Object();
|
||||||
|
|
||||||
AbstractStream(WritableBufferAllocator bufferAllocator) {
|
AbstractStream(WritableBufferAllocator bufferAllocator, int maxMessageSize) {
|
||||||
MessageDeframer.Listener inboundMessageHandler = new MessageDeframer.Listener() {
|
MessageDeframer.Listener inboundMessageHandler = new MessageDeframer.Listener() {
|
||||||
@Override
|
@Override
|
||||||
public void bytesRead(int numBytes) {
|
public void bytesRead(int numBytes) {
|
||||||
|
@ -130,7 +130,7 @@ public abstract class AbstractStream<IdT> implements Stream {
|
||||||
};
|
};
|
||||||
|
|
||||||
framer = new MessageFramer(outboundFrameHandler, bufferAllocator);
|
framer = new MessageFramer(outboundFrameHandler, bufferAllocator);
|
||||||
deframer = new MessageDeframer(inboundMessageHandler);
|
deframer = new MessageDeframer(inboundMessageHandler, MessageEncoding.NONE, maxMessageSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -31,6 +31,9 @@
|
||||||
|
|
||||||
package io.grpc.internal;
|
package io.grpc.internal;
|
||||||
|
|
||||||
|
import static io.grpc.Status.Code.CANCELLED;
|
||||||
|
import static io.grpc.Status.Code.DEADLINE_EXCEEDED;
|
||||||
|
|
||||||
import com.google.common.annotations.VisibleForTesting;
|
import com.google.common.annotations.VisibleForTesting;
|
||||||
import com.google.common.base.Preconditions;
|
import com.google.common.base.Preconditions;
|
||||||
|
|
||||||
|
@ -38,6 +41,8 @@ import io.grpc.Metadata;
|
||||||
import io.grpc.Status;
|
import io.grpc.Status;
|
||||||
|
|
||||||
import java.net.HttpURLConnection;
|
import java.net.HttpURLConnection;
|
||||||
|
import java.util.EnumSet;
|
||||||
|
import java.util.Set;
|
||||||
import java.util.concurrent.Executors;
|
import java.util.concurrent.Executors;
|
||||||
import java.util.concurrent.ScheduledExecutorService;
|
import java.util.concurrent.ScheduledExecutorService;
|
||||||
import java.util.concurrent.ThreadFactory;
|
import java.util.concurrent.ThreadFactory;
|
||||||
|
@ -109,6 +114,17 @@ public final class GrpcUtil {
|
||||||
*/
|
*/
|
||||||
public static final String MESSAGE_ENCODING = "grpc-encoding";
|
public static final String MESSAGE_ENCODING = "grpc-encoding";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The default maximum uncompressed size (in bytes) for inbound messages. Defaults to 100 MiB.
|
||||||
|
*/
|
||||||
|
public static final int DEFAULT_MAX_MESSAGE_SIZE = 100 * 1024 * 1024;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The set of valid status codes for client cancellation.
|
||||||
|
*/
|
||||||
|
public static final Set<Status.Code> CANCEL_REASONS =
|
||||||
|
EnumSet.of(CANCELLED, DEADLINE_EXCEEDED, Status.Code.INTERNAL, Status.Code.UNKNOWN);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Maps HTTP error response status codes to transport codes.
|
* Maps HTTP error response status codes to transport codes.
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -71,8 +71,9 @@ public abstract class Http2ClientStream extends AbstractClientStream<Integer> {
|
||||||
private boolean contentTypeChecked;
|
private boolean contentTypeChecked;
|
||||||
|
|
||||||
protected Http2ClientStream(WritableBufferAllocator bufferAllocator,
|
protected Http2ClientStream(WritableBufferAllocator bufferAllocator,
|
||||||
ClientStreamListener listener) {
|
ClientStreamListener listener,
|
||||||
super(bufferAllocator, listener);
|
int maxMessageSize) {
|
||||||
|
super(bufferAllocator, listener, maxMessageSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -39,6 +39,7 @@ import io.grpc.MessageEncoding;
|
||||||
import io.grpc.Status;
|
import io.grpc.Status;
|
||||||
|
|
||||||
import java.io.Closeable;
|
import java.io.Closeable;
|
||||||
|
import java.io.FilterInputStream;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
|
|
||||||
|
@ -94,6 +95,7 @@ public class MessageDeframer implements Closeable {
|
||||||
}
|
}
|
||||||
|
|
||||||
private final Listener listener;
|
private final Listener listener;
|
||||||
|
private final int maxMessageSize;
|
||||||
private MessageEncoding.Decompressor decompressor;
|
private MessageEncoding.Decompressor decompressor;
|
||||||
private State state = State.HEADER;
|
private State state = State.HEADER;
|
||||||
private int requiredLength = HEADER_LENGTH;
|
private int requiredLength = HEADER_LENGTH;
|
||||||
|
@ -105,25 +107,19 @@ public class MessageDeframer implements Closeable {
|
||||||
private boolean deliveryStalled = true;
|
private boolean deliveryStalled = true;
|
||||||
private boolean inDelivery = false;
|
private boolean inDelivery = false;
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a deframer. Compression will not be supported.
|
|
||||||
*
|
|
||||||
* @param listener listener for deframer events.
|
|
||||||
*/
|
|
||||||
public MessageDeframer(Listener listener) {
|
|
||||||
this(listener, MessageEncoding.NONE);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a deframer.
|
* Create a deframer.
|
||||||
*
|
*
|
||||||
* @param listener listener for deframer events.
|
* @param listener listener for deframer events.
|
||||||
* @param decompressor the compression used if a compressed frame is encountered, with
|
* @param decompressor the compression used if a compressed frame is encountered, with
|
||||||
* {@code NONE} meaning unsupported
|
* {@code NONE} meaning unsupported
|
||||||
|
* @param maxMessageSize the maximum allowed size for received messages.
|
||||||
*/
|
*/
|
||||||
public MessageDeframer(Listener listener, MessageEncoding.Decompressor decompressor) {
|
public MessageDeframer(Listener listener, MessageEncoding.Decompressor decompressor,
|
||||||
this.listener = Preconditions.checkNotNull(listener, "listener");
|
int maxMessageSize) {
|
||||||
|
this.listener = Preconditions.checkNotNull(listener, "sink");
|
||||||
this.decompressor = Preconditions.checkNotNull(decompressor, "decompressor");
|
this.decompressor = Preconditions.checkNotNull(decompressor, "decompressor");
|
||||||
|
this.maxMessageSize = maxMessageSize;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -162,8 +158,7 @@ public class MessageDeframer implements Closeable {
|
||||||
* the remote endpoint. End of stream should not be used in the event of a transport
|
* the remote endpoint. End of stream should not be used in the event of a transport
|
||||||
* error, such as a stream reset.
|
* error, such as a stream reset.
|
||||||
* @throws IllegalStateException if {@link #close()} has been called previously or if
|
* @throws IllegalStateException if {@link #close()} has been called previously or if
|
||||||
* {@link #deframe(ReadableBuffer, boolean)} has previously been called with
|
* this method has previously been called with {@code endOfStream=true}.
|
||||||
* {@code endOfStream=true}.
|
|
||||||
*/
|
*/
|
||||||
public void deframe(ReadableBuffer data, boolean endOfStream) {
|
public void deframe(ReadableBuffer data, boolean endOfStream) {
|
||||||
Preconditions.checkNotNull(data, "data");
|
Preconditions.checkNotNull(data, "data");
|
||||||
|
@ -291,10 +286,6 @@ public class MessageDeframer implements Closeable {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean isDataAvailable() {
|
|
||||||
return unprocessed.readableBytes() > 0 || (nextFrame != null && nextFrame.readableBytes() > 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Attempts to read the required bytes into nextFrame.
|
* Attempts to read the required bytes into nextFrame.
|
||||||
*
|
*
|
||||||
|
@ -340,6 +331,10 @@ public class MessageDeframer implements Closeable {
|
||||||
|
|
||||||
// Update the required length to include the length of the frame.
|
// Update the required length to include the length of the frame.
|
||||||
requiredLength = nextFrame.readInt();
|
requiredLength = nextFrame.readInt();
|
||||||
|
if (requiredLength < 0 || requiredLength > maxMessageSize) {
|
||||||
|
throw Status.INTERNAL.withDescription(String.format("Frame size %d exceeds maximum: %d, ",
|
||||||
|
requiredLength, maxMessageSize)).asRuntimeException();
|
||||||
|
}
|
||||||
|
|
||||||
// Continue reading the frame body.
|
// Continue reading the frame body.
|
||||||
state = State.BODY;
|
state = State.BODY;
|
||||||
|
@ -370,9 +365,79 @@ public class MessageDeframer implements Closeable {
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
return decompressor.decompress(ReadableBuffers.openStream(nextFrame, true));
|
// Enforce the maxMessageSize limit on the returned stream.
|
||||||
|
return new SizeEnforcingInputStream(decompressor.decompress(
|
||||||
|
ReadableBuffers.openStream(nextFrame, true)));
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An {@link InputStream} that enforces the {@link #maxMessageSize} limit for compressed frames.
|
||||||
|
*/
|
||||||
|
private final class SizeEnforcingInputStream extends FilterInputStream {
|
||||||
|
private long count;
|
||||||
|
private long mark = -1;
|
||||||
|
|
||||||
|
public SizeEnforcingInputStream(InputStream in) {
|
||||||
|
super(in);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int read() throws IOException {
|
||||||
|
int result = in.read();
|
||||||
|
if (result != -1) {
|
||||||
|
count++;
|
||||||
|
}
|
||||||
|
verifySize();
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int read(byte[] b, int off, int len) throws IOException {
|
||||||
|
int result = in.read(b, off, len);
|
||||||
|
if (result != -1) {
|
||||||
|
count += result;
|
||||||
|
}
|
||||||
|
verifySize();
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long skip(long n) throws IOException {
|
||||||
|
long result = in.skip(n);
|
||||||
|
count += result;
|
||||||
|
verifySize();
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public synchronized void mark(int readlimit) {
|
||||||
|
in.mark(readlimit);
|
||||||
|
mark = count;
|
||||||
|
// it's okay to mark even if mark isn't supported, as reset won't work
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public synchronized void reset() throws IOException {
|
||||||
|
if (!in.markSupported()) {
|
||||||
|
throw new IOException("Mark not supported");
|
||||||
|
}
|
||||||
|
if (mark == -1) {
|
||||||
|
throw new IOException("Mark not set");
|
||||||
|
}
|
||||||
|
|
||||||
|
in.reset();
|
||||||
|
count = mark;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void verifySize() {
|
||||||
|
if (count > maxMessageSize) {
|
||||||
|
throw Status.INTERNAL.withDescription(String.format(
|
||||||
|
"Compressed frame exceeds maximum frame size: %d. Bytes read: %d",
|
||||||
|
maxMessageSize, count)).asRuntimeException();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,6 +31,7 @@
|
||||||
|
|
||||||
package io.grpc.internal;
|
package io.grpc.internal;
|
||||||
|
|
||||||
|
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.fail;
|
import static org.junit.Assert.fail;
|
||||||
import static org.mockito.Matchers.isA;
|
import static org.mockito.Matchers.isA;
|
||||||
|
@ -242,7 +243,7 @@ public class AbstractClientStreamTest {
|
||||||
private static class BaseAbstractClientStream<T> extends AbstractClientStream<T> {
|
private static class BaseAbstractClientStream<T> extends AbstractClientStream<T> {
|
||||||
protected BaseAbstractClientStream(
|
protected BaseAbstractClientStream(
|
||||||
WritableBufferAllocator allocator, ClientStreamListener listener) {
|
WritableBufferAllocator allocator, ClientStreamListener listener) {
|
||||||
super(allocator, listener);
|
super(allocator, listener, DEFAULT_MAX_MESSAGE_SIZE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -32,6 +32,7 @@
|
||||||
|
|
||||||
package io.grpc.internal;
|
package io.grpc.internal;
|
||||||
|
|
||||||
|
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
|
||||||
import static org.junit.Assert.fail;
|
import static org.junit.Assert.fail;
|
||||||
import static org.mockito.Mockito.verify;
|
import static org.mockito.Mockito.verify;
|
||||||
|
|
||||||
|
@ -102,7 +103,7 @@ public class AbstractStreamTest {
|
||||||
*/
|
*/
|
||||||
private class AbstractStreamBase<IdT> extends AbstractStream<IdT> {
|
private class AbstractStreamBase<IdT> extends AbstractStream<IdT> {
|
||||||
private AbstractStreamBase(WritableBufferAllocator bufferAllocator) {
|
private AbstractStreamBase(WritableBufferAllocator bufferAllocator) {
|
||||||
super(bufferAllocator);
|
super(bufferAllocator, DEFAULT_MAX_MESSAGE_SIZE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -31,6 +31,7 @@
|
||||||
|
|
||||||
package io.grpc.internal;
|
package io.grpc.internal;
|
||||||
|
|
||||||
|
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
import static org.mockito.Matchers.anyInt;
|
import static org.mockito.Matchers.anyInt;
|
||||||
|
@ -67,7 +68,8 @@ import java.util.zip.GZIPOutputStream;
|
||||||
@RunWith(JUnit4.class)
|
@RunWith(JUnit4.class)
|
||||||
public class MessageDeframerTest {
|
public class MessageDeframerTest {
|
||||||
private Listener listener = mock(Listener.class);
|
private Listener listener = mock(Listener.class);
|
||||||
private MessageDeframer deframer = new MessageDeframer(listener);
|
private MessageDeframer deframer = new MessageDeframer(listener, MessageEncoding.NONE,
|
||||||
|
DEFAULT_MAX_MESSAGE_SIZE);
|
||||||
private ArgumentCaptor<InputStream> messages = ArgumentCaptor.forClass(InputStream.class);
|
private ArgumentCaptor<InputStream> messages = ArgumentCaptor.forClass(InputStream.class);
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -176,7 +178,7 @@ public class MessageDeframerTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void compressed() {
|
public void compressed() {
|
||||||
deframer = new MessageDeframer(listener, new MessageEncoding.Gzip());
|
deframer = new MessageDeframer(listener, new MessageEncoding.Gzip(), DEFAULT_MAX_MESSAGE_SIZE);
|
||||||
deframer.request(1);
|
deframer.request(1);
|
||||||
|
|
||||||
byte[] payload = compress(new byte[1000]);
|
byte[] payload = compress(new byte[1000]);
|
||||||
|
|
|
@ -31,15 +31,12 @@
|
||||||
|
|
||||||
package io.grpc.netty;
|
package io.grpc.netty;
|
||||||
|
|
||||||
import static io.grpc.Status.Code.CANCELLED;
|
import static io.grpc.internal.GrpcUtil.CANCEL_REASONS;
|
||||||
import static io.grpc.Status.Code.DEADLINE_EXCEEDED;
|
|
||||||
|
|
||||||
import com.google.common.base.Preconditions;
|
import com.google.common.base.Preconditions;
|
||||||
|
|
||||||
import io.grpc.Status;
|
import io.grpc.Status;
|
||||||
|
|
||||||
import java.util.EnumSet;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Command sent from a Netty client stream to the handler to cancel the stream.
|
* Command sent from a Netty client stream to the handler to cancel the stream.
|
||||||
*/
|
*/
|
||||||
|
@ -50,8 +47,8 @@ class CancelClientStreamCommand {
|
||||||
CancelClientStreamCommand(NettyClientStream stream, Status reason) {
|
CancelClientStreamCommand(NettyClientStream stream, Status reason) {
|
||||||
this.stream = Preconditions.checkNotNull(stream, "stream");
|
this.stream = Preconditions.checkNotNull(stream, "stream");
|
||||||
Preconditions.checkNotNull(reason);
|
Preconditions.checkNotNull(reason);
|
||||||
Preconditions.checkArgument(EnumSet.of(CANCELLED, DEADLINE_EXCEEDED).contains(reason.getCode()),
|
Preconditions.checkArgument(CANCEL_REASONS.contains(reason.getCode()),
|
||||||
"Invalid cancellation reason");
|
"Invalid cancellation reason");
|
||||||
this.reason = reason;
|
this.reason = reason;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,9 @@
|
||||||
|
|
||||||
package io.grpc.netty;
|
package io.grpc.netty;
|
||||||
|
|
||||||
|
import static com.google.common.base.Preconditions.checkArgument;
|
||||||
|
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
|
||||||
|
|
||||||
import com.google.common.base.Preconditions;
|
import com.google.common.base.Preconditions;
|
||||||
|
|
||||||
import io.grpc.AbstractChannelBuilder;
|
import io.grpc.AbstractChannelBuilder;
|
||||||
|
@ -38,6 +41,7 @@ import io.grpc.internal.AbstractReferenceCounted;
|
||||||
import io.grpc.internal.ClientTransport;
|
import io.grpc.internal.ClientTransport;
|
||||||
import io.grpc.internal.ClientTransportFactory;
|
import io.grpc.internal.ClientTransportFactory;
|
||||||
import io.grpc.internal.SharedResourceHolder;
|
import io.grpc.internal.SharedResourceHolder;
|
||||||
|
|
||||||
import io.netty.channel.Channel;
|
import io.netty.channel.Channel;
|
||||||
import io.netty.channel.EventLoopGroup;
|
import io.netty.channel.EventLoopGroup;
|
||||||
import io.netty.channel.socket.nio.NioSocketChannel;
|
import io.netty.channel.socket.nio.NioSocketChannel;
|
||||||
|
@ -62,6 +66,7 @@ public final class NettyChannelBuilder extends AbstractChannelBuilder<NettyChann
|
||||||
private EventLoopGroup eventLoopGroup;
|
private EventLoopGroup eventLoopGroup;
|
||||||
private SslContext sslContext;
|
private SslContext sslContext;
|
||||||
private int flowControlWindow = DEFAULT_FLOW_CONTROL_WINDOW;
|
private int flowControlWindow = DEFAULT_FLOW_CONTROL_WINDOW;
|
||||||
|
private int maxMessageSize = DEFAULT_MAX_MESSAGE_SIZE;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a new builder with the given server address.
|
* Creates a new builder with the given server address.
|
||||||
|
@ -132,10 +137,20 @@ public final class NettyChannelBuilder extends AbstractChannelBuilder<NettyChann
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the maximum message size allowed to be received on the channel. If not called,
|
||||||
|
* defaults to {@link io.grpc.internal.GrpcUtil#DEFAULT_MAX_MESSAGE_SIZE}.
|
||||||
|
*/
|
||||||
|
public NettyChannelBuilder maxMessageSize(int maxMessageSize) {
|
||||||
|
checkArgument(maxMessageSize >= 0, "maxMessageSize must be >= 0");
|
||||||
|
this.maxMessageSize = maxMessageSize;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected ClientTransportFactory buildTransportFactory() {
|
protected ClientTransportFactory buildTransportFactory() {
|
||||||
return new NettyTransportFactory(serverAddress, channelType, eventLoopGroup, flowControlWindow,
|
return new NettyTransportFactory(serverAddress, channelType, eventLoopGroup, flowControlWindow,
|
||||||
createProtocolNegotiator());
|
createProtocolNegotiator(), maxMessageSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
private ProtocolNegotiator createProtocolNegotiator() {
|
private ProtocolNegotiator createProtocolNegotiator() {
|
||||||
|
@ -170,16 +185,19 @@ public final class NettyChannelBuilder extends AbstractChannelBuilder<NettyChann
|
||||||
private final boolean usingSharedGroup;
|
private final boolean usingSharedGroup;
|
||||||
private final int flowControlWindow;
|
private final int flowControlWindow;
|
||||||
private final ProtocolNegotiator negotiator;
|
private final ProtocolNegotiator negotiator;
|
||||||
|
private final int maxMessageSize;
|
||||||
|
|
||||||
private NettyTransportFactory(SocketAddress serverAddress,
|
private NettyTransportFactory(SocketAddress serverAddress,
|
||||||
Class<? extends Channel> channelType,
|
Class<? extends Channel> channelType,
|
||||||
EventLoopGroup group,
|
EventLoopGroup group,
|
||||||
int flowControlWindow,
|
int flowControlWindow,
|
||||||
ProtocolNegotiator negotiator) {
|
ProtocolNegotiator negotiator,
|
||||||
|
int maxMessageSize) {
|
||||||
this.serverAddress = serverAddress;
|
this.serverAddress = serverAddress;
|
||||||
this.channelType = channelType;
|
this.channelType = channelType;
|
||||||
this.flowControlWindow = flowControlWindow;
|
this.flowControlWindow = flowControlWindow;
|
||||||
this.negotiator = negotiator;
|
this.negotiator = negotiator;
|
||||||
|
this.maxMessageSize = maxMessageSize;
|
||||||
|
|
||||||
usingSharedGroup = group == null;
|
usingSharedGroup = group == null;
|
||||||
if (usingSharedGroup) {
|
if (usingSharedGroup) {
|
||||||
|
@ -193,7 +211,7 @@ public final class NettyChannelBuilder extends AbstractChannelBuilder<NettyChann
|
||||||
@Override
|
@Override
|
||||||
public ClientTransport newClientTransport() {
|
public ClientTransport newClientTransport() {
|
||||||
return new NettyClientTransport(serverAddress, channelType, group, negotiator,
|
return new NettyClientTransport(serverAddress, channelType, group, negotiator,
|
||||||
flowControlWindow);
|
flowControlWindow, maxMessageSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -58,8 +58,9 @@ class NettyClientStream extends Http2ClientStream {
|
||||||
private Integer id;
|
private Integer id;
|
||||||
private WriteQueue writeQueue;
|
private WriteQueue writeQueue;
|
||||||
|
|
||||||
NettyClientStream(ClientStreamListener listener, Channel channel, NettyClientHandler handler) {
|
NettyClientStream(ClientStreamListener listener, Channel channel, NettyClientHandler handler,
|
||||||
super(new NettyWritableBufferAllocator(channel.alloc()), listener);
|
int maxMessageSize) {
|
||||||
|
super(new NettyWritableBufferAllocator(channel.alloc()), listener, maxMessageSize);
|
||||||
this.writeQueue = handler.getWriteQueue();
|
this.writeQueue = handler.getWriteQueue();
|
||||||
this.channel = checkNotNull(channel, "channel");
|
this.channel = checkNotNull(channel, "channel");
|
||||||
this.handler = checkNotNull(handler, "handler");
|
this.handler = checkNotNull(handler, "handler");
|
||||||
|
|
|
@ -81,6 +81,7 @@ class NettyClientTransport implements ClientTransport {
|
||||||
private final NettyClientHandler handler;
|
private final NettyClientHandler handler;
|
||||||
private final AsciiString authority;
|
private final AsciiString authority;
|
||||||
private final int flowControlWindow;
|
private final int flowControlWindow;
|
||||||
|
private final int maxMessageSize;
|
||||||
// We should not send on the channel until negotiation completes. This is a hard requirement
|
// We should not send on the channel until negotiation completes. This is a hard requirement
|
||||||
// by SslHandler but is appropriate for HTTP/1.1 Upgrade as well.
|
// by SslHandler but is appropriate for HTTP/1.1 Upgrade as well.
|
||||||
private Channel channel;
|
private Channel channel;
|
||||||
|
@ -94,12 +95,13 @@ class NettyClientTransport implements ClientTransport {
|
||||||
|
|
||||||
NettyClientTransport(SocketAddress address, Class<? extends Channel> channelType,
|
NettyClientTransport(SocketAddress address, Class<? extends Channel> channelType,
|
||||||
EventLoopGroup group, ProtocolNegotiator negotiator,
|
EventLoopGroup group, ProtocolNegotiator negotiator,
|
||||||
int flowControlWindow) {
|
int flowControlWindow, int maxMessageSize) {
|
||||||
Preconditions.checkNotNull(negotiator, "negotiator");
|
Preconditions.checkNotNull(negotiator, "negotiator");
|
||||||
this.address = Preconditions.checkNotNull(address, "address");
|
this.address = Preconditions.checkNotNull(address, "address");
|
||||||
this.group = Preconditions.checkNotNull(group, "group");
|
this.group = Preconditions.checkNotNull(group, "group");
|
||||||
this.channelType = Preconditions.checkNotNull(channelType, "channelType");
|
this.channelType = Preconditions.checkNotNull(channelType, "channelType");
|
||||||
this.flowControlWindow = flowControlWindow;
|
this.flowControlWindow = flowControlWindow;
|
||||||
|
this.maxMessageSize = maxMessageSize;
|
||||||
|
|
||||||
if (address instanceof InetSocketAddress) {
|
if (address instanceof InetSocketAddress) {
|
||||||
InetSocketAddress inetAddress = (InetSocketAddress) address;
|
InetSocketAddress inetAddress = (InetSocketAddress) address;
|
||||||
|
@ -128,7 +130,8 @@ class NettyClientTransport implements ClientTransport {
|
||||||
Preconditions.checkNotNull(listener, "listener");
|
Preconditions.checkNotNull(listener, "listener");
|
||||||
|
|
||||||
// Create the stream.
|
// Create the stream.
|
||||||
final NettyClientStream stream = new NettyClientStream(listener, channel, handler);
|
final NettyClientStream stream = new NettyClientStream(listener, channel, handler,
|
||||||
|
maxMessageSize);
|
||||||
|
|
||||||
// Convert the headers into Netty HTTP/2 headers.
|
// Convert the headers into Netty HTTP/2 headers.
|
||||||
AsciiString defaultPath = new AsciiString("/" + method.getFullMethodName());
|
AsciiString defaultPath = new AsciiString("/" + method.getFullMethodName());
|
||||||
|
|
|
@ -72,11 +72,13 @@ public class NettyServer implements Server {
|
||||||
private EventLoopGroup workerGroup;
|
private EventLoopGroup workerGroup;
|
||||||
private ServerListener listener;
|
private ServerListener listener;
|
||||||
private Channel channel;
|
private Channel channel;
|
||||||
private int flowControlWindow;
|
private final int flowControlWindow;
|
||||||
|
private final int maxMessageSize;
|
||||||
|
|
||||||
NettyServer(SocketAddress address, Class<? extends ServerChannel> channelType,
|
NettyServer(SocketAddress address, Class<? extends ServerChannel> channelType,
|
||||||
@Nullable EventLoopGroup bossGroup, @Nullable EventLoopGroup workerGroup,
|
@Nullable EventLoopGroup bossGroup, @Nullable EventLoopGroup workerGroup,
|
||||||
@Nullable SslContext sslContext, int maxStreamsPerConnection, int flowControlWindow) {
|
@Nullable SslContext sslContext, int maxStreamsPerConnection, int flowControlWindow,
|
||||||
|
int maxMessageSize) {
|
||||||
this.address = address;
|
this.address = address;
|
||||||
this.channelType = checkNotNull(channelType, "channelType");
|
this.channelType = checkNotNull(channelType, "channelType");
|
||||||
this.bossGroup = bossGroup;
|
this.bossGroup = bossGroup;
|
||||||
|
@ -86,6 +88,7 @@ public class NettyServer implements Server {
|
||||||
this.usingSharedWorkerGroup = workerGroup == null;
|
this.usingSharedWorkerGroup = workerGroup == null;
|
||||||
this.maxStreamsPerConnection = maxStreamsPerConnection;
|
this.maxStreamsPerConnection = maxStreamsPerConnection;
|
||||||
this.flowControlWindow = flowControlWindow;
|
this.flowControlWindow = flowControlWindow;
|
||||||
|
this.maxMessageSize = maxMessageSize;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -106,7 +109,8 @@ public class NettyServer implements Server {
|
||||||
@Override
|
@Override
|
||||||
public void initChannel(Channel ch) throws Exception {
|
public void initChannel(Channel ch) throws Exception {
|
||||||
NettyServerTransport transport
|
NettyServerTransport transport
|
||||||
= new NettyServerTransport(ch, sslContext, maxStreamsPerConnection, flowControlWindow);
|
= new NettyServerTransport(ch, sslContext, maxStreamsPerConnection, flowControlWindow,
|
||||||
|
maxMessageSize);
|
||||||
transport.start(listener.transportCreated(transport));
|
transport.start(listener.transportCreated(transport));
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
|
@ -31,6 +31,9 @@
|
||||||
|
|
||||||
package io.grpc.netty;
|
package io.grpc.netty;
|
||||||
|
|
||||||
|
import static com.google.common.base.Preconditions.checkArgument;
|
||||||
|
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
|
||||||
|
|
||||||
import com.google.common.base.Preconditions;
|
import com.google.common.base.Preconditions;
|
||||||
|
|
||||||
import io.grpc.AbstractServerBuilder;
|
import io.grpc.AbstractServerBuilder;
|
||||||
|
@ -61,6 +64,7 @@ public final class NettyServerBuilder extends AbstractServerBuilder<NettyServerB
|
||||||
private SslContext sslContext;
|
private SslContext sslContext;
|
||||||
private int maxConcurrentCallsPerConnection = Integer.MAX_VALUE;
|
private int maxConcurrentCallsPerConnection = Integer.MAX_VALUE;
|
||||||
private int flowControlWindow = DEFAULT_FLOW_CONTROL_WINDOW;
|
private int flowControlWindow = DEFAULT_FLOW_CONTROL_WINDOW;
|
||||||
|
private int maxMessageSize = DEFAULT_MAX_MESSAGE_SIZE;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a server builder that will bind to the given port.
|
* Creates a server builder that will bind to the given port.
|
||||||
|
@ -190,9 +194,20 @@ public final class NettyServerBuilder extends AbstractServerBuilder<NettyServerB
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the maximum message size allowed to be received on the server. If not called,
|
||||||
|
* defaults to 100 MiB.
|
||||||
|
*/
|
||||||
|
public NettyServerBuilder maxMessageSize(int maxMessageSize) {
|
||||||
|
checkArgument(maxMessageSize >= 0, "maxMessageSize must be >= 0");
|
||||||
|
this.maxMessageSize = maxMessageSize;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected NettyServer buildTransportServer() {
|
protected NettyServer buildTransportServer() {
|
||||||
return new NettyServer(address, channelType, bossEventLoopGroup,
|
return new NettyServer(address, channelType, bossEventLoopGroup,
|
||||||
workerEventLoopGroup, sslContext, maxConcurrentCallsPerConnection, flowControlWindow);
|
workerEventLoopGroup, sslContext, maxConcurrentCallsPerConnection, flowControlWindow,
|
||||||
|
maxMessageSize);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,6 +31,7 @@
|
||||||
|
|
||||||
package io.grpc.netty;
|
package io.grpc.netty;
|
||||||
|
|
||||||
|
import static com.google.common.base.Preconditions.checkArgument;
|
||||||
import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC;
|
import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC;
|
||||||
import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER;
|
import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER;
|
||||||
import static io.grpc.netty.Utils.HTTP_METHOD;
|
import static io.grpc.netty.Utils.HTTP_METHOD;
|
||||||
|
@ -81,6 +82,7 @@ class NettyServerHandler extends Http2ConnectionHandler {
|
||||||
|
|
||||||
private final Http2Connection.PropertyKey streamKey;
|
private final Http2Connection.PropertyKey streamKey;
|
||||||
private final ServerTransportListener transportListener;
|
private final ServerTransportListener transportListener;
|
||||||
|
private final int maxMessageSize;
|
||||||
private Throwable connectionError;
|
private Throwable connectionError;
|
||||||
private ChannelHandlerContext ctx;
|
private ChannelHandlerContext ctx;
|
||||||
private boolean teWarningLogged;
|
private boolean teWarningLogged;
|
||||||
|
@ -93,10 +95,13 @@ class NettyServerHandler extends Http2ConnectionHandler {
|
||||||
Http2FrameReader frameReader,
|
Http2FrameReader frameReader,
|
||||||
Http2FrameWriter frameWriter,
|
Http2FrameWriter frameWriter,
|
||||||
int maxStreams,
|
int maxStreams,
|
||||||
int flowControlWindow) {
|
int flowControlWindow,
|
||||||
|
int maxMessageSize) {
|
||||||
super(connection, frameReader, frameWriter, new LazyFrameListener());
|
super(connection, frameReader, frameWriter, new LazyFrameListener());
|
||||||
Preconditions.checkArgument(flowControlWindow > 0, "flowControlWindow must be positive");
|
Preconditions.checkArgument(flowControlWindow > 0, "flowControlWindow must be positive");
|
||||||
this.flowControlWindow = flowControlWindow;
|
this.flowControlWindow = flowControlWindow;
|
||||||
|
checkArgument(maxMessageSize >= 0, "maxMessageSize must be >= 0");
|
||||||
|
this.maxMessageSize = maxMessageSize;
|
||||||
|
|
||||||
streamKey = connection.newKey();
|
streamKey = connection.newKey();
|
||||||
this.transportListener = Preconditions.checkNotNull(transportListener, "transportListener");
|
this.transportListener = Preconditions.checkNotNull(transportListener, "transportListener");
|
||||||
|
@ -146,7 +151,8 @@ class NettyServerHandler extends Http2ConnectionHandler {
|
||||||
// The Http2Stream object was put by AbstractHttp2ConnectionHandler before calling this
|
// The Http2Stream object was put by AbstractHttp2ConnectionHandler before calling this
|
||||||
// method.
|
// method.
|
||||||
Http2Stream http2Stream = requireHttp2Stream(streamId);
|
Http2Stream http2Stream = requireHttp2Stream(streamId);
|
||||||
NettyServerStream stream = new NettyServerStream(ctx.channel(), http2Stream, this);
|
NettyServerStream stream = new NettyServerStream(ctx.channel(), http2Stream, this,
|
||||||
|
maxMessageSize);
|
||||||
http2Stream.setProperty(streamKey, stream);
|
http2Stream.setProperty(streamKey, stream);
|
||||||
String method = determineMethod(streamId, headers);
|
String method = determineMethod(streamId, headers);
|
||||||
|
|
||||||
|
|
|
@ -54,8 +54,9 @@ class NettyServerStream extends AbstractServerStream<Integer> {
|
||||||
private final Http2Stream http2Stream;
|
private final Http2Stream http2Stream;
|
||||||
private final WriteQueue writeQueue;
|
private final WriteQueue writeQueue;
|
||||||
|
|
||||||
NettyServerStream(Channel channel, Http2Stream http2Stream, NettyServerHandler handler) {
|
NettyServerStream(Channel channel, Http2Stream http2Stream, NettyServerHandler handler,
|
||||||
super(new NettyWritableBufferAllocator(channel.alloc()));
|
int maxMessageSize) {
|
||||||
|
super(new NettyWritableBufferAllocator(channel.alloc()), maxMessageSize);
|
||||||
this.writeQueue = handler.getWriteQueue();
|
this.writeQueue = handler.getWriteQueue();
|
||||||
this.channel = checkNotNull(channel, "channel");
|
this.channel = checkNotNull(channel, "channel");
|
||||||
this.http2Stream = checkNotNull(http2Stream, "http2Stream");
|
this.http2Stream = checkNotNull(http2Stream, "http2Stream");
|
||||||
|
|
|
@ -67,14 +67,16 @@ class NettyServerTransport implements ServerTransport {
|
||||||
private final int maxStreams;
|
private final int maxStreams;
|
||||||
private ServerTransportListener listener;
|
private ServerTransportListener listener;
|
||||||
private boolean terminated;
|
private boolean terminated;
|
||||||
private int flowControlWindow;
|
private final int flowControlWindow;
|
||||||
|
private final int maxMessageSize;
|
||||||
|
|
||||||
NettyServerTransport(Channel channel, @Nullable SslContext sslContext, int maxStreams,
|
NettyServerTransport(Channel channel, @Nullable SslContext sslContext, int maxStreams,
|
||||||
int flowControlWindow) {
|
int flowControlWindow, int maxMessageSize) {
|
||||||
this.channel = Preconditions.checkNotNull(channel, "channel");
|
this.channel = Preconditions.checkNotNull(channel, "channel");
|
||||||
this.sslContext = sslContext;
|
this.sslContext = sslContext;
|
||||||
this.maxStreams = maxStreams;
|
this.maxStreams = maxStreams;
|
||||||
this.flowControlWindow = flowControlWindow;
|
this.flowControlWindow = flowControlWindow;
|
||||||
|
this.maxMessageSize = maxMessageSize;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void start(ServerTransportListener listener) {
|
public void start(ServerTransportListener listener) {
|
||||||
|
@ -135,6 +137,6 @@ class NettyServerTransport implements ServerTransport {
|
||||||
new Http2OutboundFrameLogger(new DefaultHttp2FrameWriter(), frameLogger);
|
new Http2OutboundFrameLogger(new DefaultHttp2FrameWriter(), frameLogger);
|
||||||
|
|
||||||
return new NettyServerHandler(transportListener, connection, frameReader, frameWriter,
|
return new NettyServerHandler(transportListener, connection, frameReader, frameWriter,
|
||||||
maxStreams, flowControlWindow);
|
maxStreams, flowControlWindow, maxMessageSize);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,6 +31,7 @@
|
||||||
|
|
||||||
package io.grpc.netty;
|
package io.grpc.netty;
|
||||||
|
|
||||||
|
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
|
||||||
import static io.grpc.netty.NettyTestUtil.messageFrame;
|
import static io.grpc.netty.NettyTestUtil.messageFrame;
|
||||||
import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC;
|
import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC;
|
||||||
import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER;
|
import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER;
|
||||||
|
@ -54,6 +55,7 @@ import static org.mockito.Mockito.when;
|
||||||
import io.grpc.Metadata;
|
import io.grpc.Metadata;
|
||||||
import io.grpc.Status;
|
import io.grpc.Status;
|
||||||
import io.grpc.internal.ClientStreamListener;
|
import io.grpc.internal.ClientStreamListener;
|
||||||
|
|
||||||
import io.netty.buffer.ByteBuf;
|
import io.netty.buffer.ByteBuf;
|
||||||
import io.netty.buffer.Unpooled;
|
import io.netty.buffer.Unpooled;
|
||||||
import io.netty.channel.ChannelPromise;
|
import io.netty.channel.ChannelPromise;
|
||||||
|
@ -321,7 +323,7 @@ public class NettyClientStreamTest extends NettyStreamTestBase {
|
||||||
@Test
|
@Test
|
||||||
public void setHttp2StreamShouldNotifyReady() {
|
public void setHttp2StreamShouldNotifyReady() {
|
||||||
listener = mock(ClientStreamListener.class);
|
listener = mock(ClientStreamListener.class);
|
||||||
stream = new NettyClientStream(listener, channel, handler);
|
stream = new NettyClientStream(listener, channel, handler, DEFAULT_MAX_MESSAGE_SIZE);
|
||||||
stream().id(STREAM_ID);
|
stream().id(STREAM_ID);
|
||||||
verify(listener, never()).onReady();
|
verify(listener, never()).onReady();
|
||||||
assertFalse(stream.isReady());
|
assertFalse(stream.isReady());
|
||||||
|
@ -343,7 +345,8 @@ public class NettyClientStreamTest extends NettyStreamTestBase {
|
||||||
}
|
}
|
||||||
}).when(writeQueue).enqueue(any(), any(ChannelPromise.class), anyBoolean());
|
}).when(writeQueue).enqueue(any(), any(ChannelPromise.class), anyBoolean());
|
||||||
when(writeQueue.enqueue(any(), anyBoolean())).thenReturn(future);
|
when(writeQueue.enqueue(any(), anyBoolean())).thenReturn(future);
|
||||||
NettyClientStream stream = new NettyClientStream(listener, channel, handler);
|
NettyClientStream stream = new NettyClientStream(listener, channel, handler,
|
||||||
|
DEFAULT_MAX_MESSAGE_SIZE);
|
||||||
assertTrue(stream.canSend());
|
assertTrue(stream.canSend());
|
||||||
assertTrue(stream.canReceive());
|
assertTrue(stream.canReceive());
|
||||||
stream.id(STREAM_ID);
|
stream.id(STREAM_ID);
|
||||||
|
|
|
@ -32,9 +32,12 @@
|
||||||
package io.grpc.netty;
|
package io.grpc.netty;
|
||||||
|
|
||||||
import static com.google.common.base.Charsets.UTF_8;
|
import static com.google.common.base.Charsets.UTF_8;
|
||||||
|
import static io.grpc.Status.Code.INTERNAL;
|
||||||
|
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
|
||||||
import static io.grpc.internal.GrpcUtil.USER_AGENT_KEY;
|
import static io.grpc.internal.GrpcUtil.USER_AGENT_KEY;
|
||||||
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE;
|
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
import static org.junit.Assert.fail;
|
import static org.junit.Assert.fail;
|
||||||
|
|
||||||
import com.google.common.io.ByteStreams;
|
import com.google.common.io.ByteStreams;
|
||||||
|
@ -152,6 +155,25 @@ public class NettyClientTransportTest {
|
||||||
receivedHeaders.get(USER_AGENT_KEY));
|
receivedHeaders.get(USER_AGENT_KEY));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void maxMessageSizeShouldBeEnforced() throws Throwable {
|
||||||
|
startServer();
|
||||||
|
// Allow the response payloads of up to 1 byte.
|
||||||
|
NettyClientTransport transport = newTransport(newNegotiator(), 1);
|
||||||
|
transport.start(clientTransportListener);
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Send a single RPC and wait for the response.
|
||||||
|
new Rpc(transport).halfClose().waitForResponse();
|
||||||
|
fail("Expected the stream to fail.");
|
||||||
|
} catch (ExecutionException e) {
|
||||||
|
Status status = Status.fromThrowable(e);
|
||||||
|
assertEquals(INTERNAL, status.getCode());
|
||||||
|
System.err.println(status.getDescription());
|
||||||
|
assertTrue(status.getDescription().contains("deframing"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Verifies that we can create multiple TLS client transports from the same builder.
|
* Verifies that we can create multiple TLS client transports from the same builder.
|
||||||
*/
|
*/
|
||||||
|
@ -218,8 +240,12 @@ public class NettyClientTransportTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
private NettyClientTransport newTransport(ProtocolNegotiator negotiator) {
|
private NettyClientTransport newTransport(ProtocolNegotiator negotiator) {
|
||||||
|
return newTransport(negotiator, DEFAULT_MAX_MESSAGE_SIZE);
|
||||||
|
}
|
||||||
|
|
||||||
|
private NettyClientTransport newTransport(ProtocolNegotiator negotiator, int maxMsgSize) {
|
||||||
NettyClientTransport transport = new NettyClientTransport(address, NioSocketChannel.class,
|
NettyClientTransport transport = new NettyClientTransport(address, NioSocketChannel.class,
|
||||||
group, negotiator, DEFAULT_WINDOW_SIZE);
|
group, negotiator, DEFAULT_WINDOW_SIZE, maxMsgSize);
|
||||||
transports.add(transport);
|
transports.add(transport);
|
||||||
return transport;
|
return transport;
|
||||||
}
|
}
|
||||||
|
@ -235,7 +261,7 @@ public class NettyClientTransportTest {
|
||||||
.ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE).build();
|
.ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE).build();
|
||||||
server = new NettyServer(address, NioServerSocketChannel.class,
|
server = new NettyServer(address, NioServerSocketChannel.class,
|
||||||
group, group, serverContext, maxStreamsPerConnection,
|
group, group, serverContext, maxStreamsPerConnection,
|
||||||
DEFAULT_WINDOW_SIZE);
|
DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE);
|
||||||
server.start(serverListener);
|
server.start(serverListener);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -32,6 +32,7 @@
|
||||||
package io.grpc.netty;
|
package io.grpc.netty;
|
||||||
|
|
||||||
import static com.google.common.base.Charsets.UTF_8;
|
import static com.google.common.base.Charsets.UTF_8;
|
||||||
|
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
|
||||||
import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC;
|
import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC;
|
||||||
import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER;
|
import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER;
|
||||||
import static io.grpc.netty.Utils.HTTP_METHOD;
|
import static io.grpc.netty.Utils.HTTP_METHOD;
|
||||||
|
@ -64,6 +65,7 @@ import io.grpc.internal.ServerStream;
|
||||||
import io.grpc.internal.ServerStreamListener;
|
import io.grpc.internal.ServerStreamListener;
|
||||||
import io.grpc.internal.ServerTransportListener;
|
import io.grpc.internal.ServerTransportListener;
|
||||||
import io.grpc.internal.WritableBuffer;
|
import io.grpc.internal.WritableBuffer;
|
||||||
|
|
||||||
import io.netty.buffer.ByteBuf;
|
import io.netty.buffer.ByteBuf;
|
||||||
import io.netty.buffer.ByteBufAllocator;
|
import io.netty.buffer.ByteBufAllocator;
|
||||||
import io.netty.buffer.Unpooled;
|
import io.netty.buffer.Unpooled;
|
||||||
|
@ -303,7 +305,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase {
|
||||||
Http2Connection connection = new DefaultHttp2Connection(true);
|
Http2Connection connection = new DefaultHttp2Connection(true);
|
||||||
handler =
|
handler =
|
||||||
new NettyServerHandler(transportListener, connection, new DefaultHttp2FrameReader(),
|
new NettyServerHandler(transportListener, connection, new DefaultHttp2FrameReader(),
|
||||||
frameWriter, maxConcurrentStreams, DEFAULT_WINDOW_SIZE);
|
frameWriter, maxConcurrentStreams, DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE);
|
||||||
|
|
||||||
when(channel.isActive()).thenReturn(true);
|
when(channel.isActive()).thenReturn(true);
|
||||||
mockContext();
|
mockContext();
|
||||||
|
@ -403,7 +405,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase {
|
||||||
Http2FrameReader frameReader = new DefaultHttp2FrameReader();
|
Http2FrameReader frameReader = new DefaultHttp2FrameReader();
|
||||||
Http2FrameWriter frameWriter = new DefaultHttp2FrameWriter();
|
Http2FrameWriter frameWriter = new DefaultHttp2FrameWriter();
|
||||||
return new NettyServerHandler(transportListener, connection, frameReader, frameWriter,
|
return new NettyServerHandler(transportListener, connection, frameReader, frameWriter,
|
||||||
Integer.MAX_VALUE, flowControlWindow);
|
Integer.MAX_VALUE, flowControlWindow, DEFAULT_MAX_MESSAGE_SIZE);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static NettyServerHandler newHandler(ServerTransportListener transportListener) {
|
private static NettyServerHandler newHandler(ServerTransportListener transportListener) {
|
||||||
|
|
|
@ -31,6 +31,7 @@
|
||||||
|
|
||||||
package io.grpc.netty;
|
package io.grpc.netty;
|
||||||
|
|
||||||
|
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
|
||||||
import static io.grpc.netty.NettyTestUtil.messageFrame;
|
import static io.grpc.netty.NettyTestUtil.messageFrame;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
import static org.junit.Assert.fail;
|
import static org.junit.Assert.fail;
|
||||||
|
@ -50,6 +51,7 @@ import static org.mockito.Mockito.when;
|
||||||
import io.grpc.Metadata;
|
import io.grpc.Metadata;
|
||||||
import io.grpc.Status;
|
import io.grpc.Status;
|
||||||
import io.grpc.internal.ServerStreamListener;
|
import io.grpc.internal.ServerStreamListener;
|
||||||
|
|
||||||
import io.netty.buffer.EmptyByteBuf;
|
import io.netty.buffer.EmptyByteBuf;
|
||||||
import io.netty.buffer.UnpooledByteBufAllocator;
|
import io.netty.buffer.UnpooledByteBufAllocator;
|
||||||
import io.netty.channel.ChannelPromise;
|
import io.netty.channel.ChannelPromise;
|
||||||
|
@ -254,7 +256,8 @@ public class NettyServerStreamTest extends NettyStreamTestBase {
|
||||||
}
|
}
|
||||||
}).when(writeQueue).enqueue(any(), any(ChannelPromise.class), anyBoolean());
|
}).when(writeQueue).enqueue(any(), any(ChannelPromise.class), anyBoolean());
|
||||||
when(writeQueue.enqueue(any(), anyBoolean())).thenReturn(future);
|
when(writeQueue.enqueue(any(), anyBoolean())).thenReturn(future);
|
||||||
NettyServerStream stream = new NettyServerStream(channel, http2Stream, handler);
|
NettyServerStream stream = new NettyServerStream(channel, http2Stream, handler,
|
||||||
|
DEFAULT_MAX_MESSAGE_SIZE);
|
||||||
stream.setListener(serverListener);
|
stream.setListener(serverListener);
|
||||||
assertTrue(stream.canReceive());
|
assertTrue(stream.canReceive());
|
||||||
assertTrue(stream.canSend());
|
assertTrue(stream.canSend());
|
||||||
|
|
|
@ -31,6 +31,9 @@
|
||||||
|
|
||||||
package io.grpc.okhttp;
|
package io.grpc.okhttp;
|
||||||
|
|
||||||
|
import static com.google.common.base.Preconditions.checkArgument;
|
||||||
|
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
|
||||||
|
|
||||||
import com.google.common.base.Preconditions;
|
import com.google.common.base.Preconditions;
|
||||||
import com.google.common.util.concurrent.ThreadFactoryBuilder;
|
import com.google.common.util.concurrent.ThreadFactoryBuilder;
|
||||||
|
|
||||||
|
@ -96,6 +99,7 @@ public final class OkHttpChannelBuilder extends AbstractChannelBuilder<OkHttpCha
|
||||||
private SSLSocketFactory sslSocketFactory;
|
private SSLSocketFactory sslSocketFactory;
|
||||||
private ConnectionSpec connectionSpec = DEFAULT_CONNECTION_SPEC;
|
private ConnectionSpec connectionSpec = DEFAULT_CONNECTION_SPEC;
|
||||||
private NegotiationType negotiationType = NegotiationType.TLS;
|
private NegotiationType negotiationType = NegotiationType.TLS;
|
||||||
|
private int maxMessageSize = DEFAULT_MAX_MESSAGE_SIZE;
|
||||||
|
|
||||||
private OkHttpChannelBuilder(String host, int port) {
|
private OkHttpChannelBuilder(String host, int port) {
|
||||||
this.host = Preconditions.checkNotNull(host);
|
this.host = Preconditions.checkNotNull(host);
|
||||||
|
@ -159,10 +163,20 @@ public final class OkHttpChannelBuilder extends AbstractChannelBuilder<OkHttpCha
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the maximum message size allowed to be received on the channel. If not called,
|
||||||
|
* defaults to {@link io.grpc.internal.GrpcUtil#DEFAULT_MAX_MESSAGE_SIZE}.
|
||||||
|
*/
|
||||||
|
public OkHttpChannelBuilder maxMessageSize(int maxMessageSize) {
|
||||||
|
checkArgument(maxMessageSize >= 0, "maxMessageSize must be >= 0");
|
||||||
|
this.maxMessageSize = maxMessageSize;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected ClientTransportFactory buildTransportFactory() {
|
protected ClientTransportFactory buildTransportFactory() {
|
||||||
return new OkHttpTransportFactory(host, port, authorityHost, transportExecutor,
|
return new OkHttpTransportFactory(host, port, authorityHost, transportExecutor,
|
||||||
createSocketFactory(), connectionSpec);
|
createSocketFactory(), connectionSpec, maxMessageSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
private SSLSocketFactory createSocketFactory() {
|
private SSLSocketFactory createSocketFactory() {
|
||||||
|
@ -186,18 +200,21 @@ public final class OkHttpChannelBuilder extends AbstractChannelBuilder<OkHttpCha
|
||||||
private final boolean usingSharedExecutor;
|
private final boolean usingSharedExecutor;
|
||||||
private final SSLSocketFactory socketFactory;
|
private final SSLSocketFactory socketFactory;
|
||||||
private final ConnectionSpec connectionSpec;
|
private final ConnectionSpec connectionSpec;
|
||||||
|
private final int maxMessageSize;
|
||||||
|
|
||||||
private OkHttpTransportFactory(String host,
|
private OkHttpTransportFactory(String host,
|
||||||
int port,
|
int port,
|
||||||
String authorityHost,
|
String authorityHost,
|
||||||
ExecutorService executor,
|
ExecutorService executor,
|
||||||
SSLSocketFactory socketFactory,
|
SSLSocketFactory socketFactory,
|
||||||
ConnectionSpec connectionSpec) {
|
ConnectionSpec connectionSpec,
|
||||||
|
int maxMessageSize) {
|
||||||
this.host = host;
|
this.host = host;
|
||||||
this.port = port;
|
this.port = port;
|
||||||
this.authorityHost = authorityHost;
|
this.authorityHost = authorityHost;
|
||||||
this.socketFactory = socketFactory;
|
this.socketFactory = socketFactory;
|
||||||
this.connectionSpec = connectionSpec;
|
this.connectionSpec = connectionSpec;
|
||||||
|
this.maxMessageSize = maxMessageSize;
|
||||||
|
|
||||||
usingSharedExecutor = executor == null;
|
usingSharedExecutor = executor == null;
|
||||||
if (usingSharedExecutor) {
|
if (usingSharedExecutor) {
|
||||||
|
@ -211,7 +228,7 @@ public final class OkHttpChannelBuilder extends AbstractChannelBuilder<OkHttpCha
|
||||||
@Override
|
@Override
|
||||||
public ClientTransport newClientTransport() {
|
public ClientTransport newClientTransport() {
|
||||||
return new OkHttpClientTransport(host, port, authorityHost, executor, socketFactory,
|
return new OkHttpClientTransport(host, port, authorityHost, executor, socketFactory,
|
||||||
connectionSpec);
|
connectionSpec, maxMessageSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -58,26 +58,12 @@ import javax.annotation.concurrent.GuardedBy;
|
||||||
*/
|
*/
|
||||||
class OkHttpClientStream extends Http2ClientStream {
|
class OkHttpClientStream extends Http2ClientStream {
|
||||||
|
|
||||||
private static int WINDOW_UPDATE_THRESHOLD = Utils.DEFAULT_WINDOW_SIZE / 2;
|
private static final int WINDOW_UPDATE_THRESHOLD = Utils.DEFAULT_WINDOW_SIZE / 2;
|
||||||
|
|
||||||
private static final Buffer EMPTY_BUFFER = new Buffer();
|
private static final Buffer EMPTY_BUFFER = new Buffer();
|
||||||
|
|
||||||
private final MethodType type;
|
private final MethodType type;
|
||||||
|
|
||||||
/**
|
|
||||||
* Construct a new client stream.
|
|
||||||
*/
|
|
||||||
static OkHttpClientStream newStream(ClientStreamListener listener,
|
|
||||||
AsyncFrameWriter frameWriter,
|
|
||||||
OkHttpClientTransport transport,
|
|
||||||
OutboundFlowController outboundFlow,
|
|
||||||
MethodType type,
|
|
||||||
Object lock,
|
|
||||||
List<Header> requestHeaders) {
|
|
||||||
return new OkHttpClientStream(
|
|
||||||
listener, frameWriter, transport, outboundFlow, type, lock, requestHeaders);
|
|
||||||
}
|
|
||||||
|
|
||||||
@GuardedBy("lock")
|
@GuardedBy("lock")
|
||||||
private int window = Utils.DEFAULT_WINDOW_SIZE;
|
private int window = Utils.DEFAULT_WINDOW_SIZE;
|
||||||
@GuardedBy("lock")
|
@GuardedBy("lock")
|
||||||
|
@ -94,15 +80,15 @@ class OkHttpClientStream extends Http2ClientStream {
|
||||||
@GuardedBy("lock")
|
@GuardedBy("lock")
|
||||||
private boolean cancelSent = false;
|
private boolean cancelSent = false;
|
||||||
|
|
||||||
|
OkHttpClientStream(ClientStreamListener listener,
|
||||||
private OkHttpClientStream(ClientStreamListener listener,
|
|
||||||
AsyncFrameWriter frameWriter,
|
AsyncFrameWriter frameWriter,
|
||||||
OkHttpClientTransport transport,
|
OkHttpClientTransport transport,
|
||||||
OutboundFlowController outboundFlow,
|
OutboundFlowController outboundFlow,
|
||||||
MethodType type,
|
MethodType type,
|
||||||
Object lock,
|
Object lock,
|
||||||
List<Header> requestHeaders) {
|
List<Header> requestHeaders,
|
||||||
super(new OkHttpWritableBufferAllocator(), listener);
|
int maxMessageSize) {
|
||||||
|
super(new OkHttpWritableBufferAllocator(), listener, maxMessageSize);
|
||||||
this.frameWriter = frameWriter;
|
this.frameWriter = frameWriter;
|
||||||
this.transport = transport;
|
this.transport = transport;
|
||||||
this.outboundFlow = outboundFlow;
|
this.outboundFlow = outboundFlow;
|
||||||
|
|
|
@ -141,6 +141,7 @@ class OkHttpClientTransport implements ClientTransport {
|
||||||
private final Executor executor;
|
private final Executor executor;
|
||||||
// Wrap on executor, to guarantee some operations be executed serially.
|
// Wrap on executor, to guarantee some operations be executed serially.
|
||||||
private final SerializingExecutor serializingExecutor;
|
private final SerializingExecutor serializingExecutor;
|
||||||
|
private final int maxMessageSize;
|
||||||
private int connectionUnacknowledgedBytesRead;
|
private int connectionUnacknowledgedBytesRead;
|
||||||
private ClientFrameHandler clientFrameHandler;
|
private ClientFrameHandler clientFrameHandler;
|
||||||
// The status used to finish all active streams when the transport is closed.
|
// The status used to finish all active streams when the transport is closed.
|
||||||
|
@ -166,10 +167,13 @@ class OkHttpClientTransport implements ClientTransport {
|
||||||
SettableFuture<Void> connectedFuture;
|
SettableFuture<Void> connectedFuture;
|
||||||
|
|
||||||
OkHttpClientTransport(String host, int port, String authorityHost, Executor executor,
|
OkHttpClientTransport(String host, int port, String authorityHost, Executor executor,
|
||||||
@Nullable SSLSocketFactory sslSocketFactory, ConnectionSpec connectionSpec) {
|
|
||||||
|
@Nullable SSLSocketFactory sslSocketFactory, ConnectionSpec connectionSpec,
|
||||||
|
int maxMessageSize) {
|
||||||
this.host = Preconditions.checkNotNull(host, "host");
|
this.host = Preconditions.checkNotNull(host, "host");
|
||||||
this.port = port;
|
this.port = port;
|
||||||
this.authorityHost = authorityHost;
|
this.authorityHost = authorityHost;
|
||||||
|
this.maxMessageSize = maxMessageSize;
|
||||||
defaultAuthority = authorityHost + ":" + port;
|
defaultAuthority = authorityHost + ":" + port;
|
||||||
this.executor = Preconditions.checkNotNull(executor, "executor");
|
this.executor = Preconditions.checkNotNull(executor, "executor");
|
||||||
serializingExecutor = new SerializingExecutor(executor);
|
serializingExecutor = new SerializingExecutor(executor);
|
||||||
|
@ -187,10 +191,12 @@ class OkHttpClientTransport implements ClientTransport {
|
||||||
@VisibleForTesting
|
@VisibleForTesting
|
||||||
OkHttpClientTransport(Executor executor, FrameReader frameReader, FrameWriter testFrameWriter,
|
OkHttpClientTransport(Executor executor, FrameReader frameReader, FrameWriter testFrameWriter,
|
||||||
int nextStreamId, Socket socket, Ticker ticker,
|
int nextStreamId, Socket socket, Ticker ticker,
|
||||||
@Nullable Runnable connectingCallback, SettableFuture<Void> connectedFuture) {
|
@Nullable Runnable connectingCallback, SettableFuture<Void> connectedFuture,
|
||||||
|
int maxMessageSize) {
|
||||||
host = null;
|
host = null;
|
||||||
port = 0;
|
port = 0;
|
||||||
authorityHost = null;
|
authorityHost = null;
|
||||||
|
this.maxMessageSize = maxMessageSize;
|
||||||
defaultAuthority = "notarealauthority:80";
|
defaultAuthority = "notarealauthority:80";
|
||||||
this.executor = Preconditions.checkNotNull(executor);
|
this.executor = Preconditions.checkNotNull(executor);
|
||||||
serializingExecutor = new SerializingExecutor(executor);
|
serializingExecutor = new SerializingExecutor(executor);
|
||||||
|
@ -248,9 +254,10 @@ class OkHttpClientTransport implements ClientTransport {
|
||||||
Preconditions.checkNotNull(listener, "listener");
|
Preconditions.checkNotNull(listener, "listener");
|
||||||
|
|
||||||
String defaultPath = "/" + method.getFullMethodName();
|
String defaultPath = "/" + method.getFullMethodName();
|
||||||
OkHttpClientStream clientStream = OkHttpClientStream.newStream(
|
OkHttpClientStream clientStream = new OkHttpClientStream(
|
||||||
listener, frameWriter, this, outboundFlow, method.getType(), lock,
|
listener, frameWriter, this, outboundFlow, method.getType(), lock,
|
||||||
Headers.createRequestHeaders(headers, defaultPath, defaultAuthority));
|
Headers.createRequestHeaders(headers, defaultPath, defaultAuthority),
|
||||||
|
maxMessageSize);
|
||||||
|
|
||||||
synchronized (lock) {
|
synchronized (lock) {
|
||||||
if (goAway) {
|
if (goAway) {
|
||||||
|
|
|
@ -32,6 +32,8 @@
|
||||||
package io.grpc.okhttp;
|
package io.grpc.okhttp;
|
||||||
|
|
||||||
import static com.google.common.base.Charsets.UTF_8;
|
import static com.google.common.base.Charsets.UTF_8;
|
||||||
|
import static io.grpc.Status.Code.INTERNAL;
|
||||||
|
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
|
||||||
import static io.grpc.okhttp.Headers.CONTENT_TYPE_HEADER;
|
import static io.grpc.okhttp.Headers.CONTENT_TYPE_HEADER;
|
||||||
import static io.grpc.okhttp.Headers.METHOD_HEADER;
|
import static io.grpc.okhttp.Headers.METHOD_HEADER;
|
||||||
import static io.grpc.okhttp.Headers.SCHEME_HEADER;
|
import static io.grpc.okhttp.Headers.SCHEME_HEADER;
|
||||||
|
@ -148,20 +150,20 @@ public class OkHttpClientTransportTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
private void initTransport() throws Exception {
|
private void initTransport() throws Exception {
|
||||||
startTransport(3, null, true);
|
startTransport(3, null, true, DEFAULT_MAX_MESSAGE_SIZE);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void initTransport(int startId) throws Exception {
|
private void initTransport(int startId) throws Exception {
|
||||||
startTransport(startId, null, true);
|
startTransport(startId, null, true, DEFAULT_MAX_MESSAGE_SIZE);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void initTransportAndDelayConnected() throws Exception {
|
private void initTransportAndDelayConnected() throws Exception {
|
||||||
delayConnectedCallback = new DelayConnectedCallback();
|
delayConnectedCallback = new DelayConnectedCallback();
|
||||||
startTransport(3, delayConnectedCallback, false);
|
startTransport(3, delayConnectedCallback, false, DEFAULT_MAX_MESSAGE_SIZE);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void startTransport(int startId, @Nullable Runnable connectingCallback,
|
private void startTransport(int startId, @Nullable Runnable connectingCallback,
|
||||||
boolean waitingForConnected) throws Exception {
|
boolean waitingForConnected, int maxMessageSize) throws Exception {
|
||||||
connectedFuture = SettableFuture.create();
|
connectedFuture = SettableFuture.create();
|
||||||
Ticker ticker = new Ticker() {
|
Ticker ticker = new Ticker() {
|
||||||
@Override
|
@Override
|
||||||
|
@ -171,7 +173,8 @@ public class OkHttpClientTransportTest {
|
||||||
};
|
};
|
||||||
clientTransport = new OkHttpClientTransport(
|
clientTransport = new OkHttpClientTransport(
|
||||||
executor, frameReader, frameWriter, startId,
|
executor, frameReader, frameWriter, startId,
|
||||||
new MockSocket(frameReader), ticker, connectingCallback, connectedFuture);
|
new MockSocket(frameReader), ticker, connectingCallback, connectedFuture,
|
||||||
|
maxMessageSize);
|
||||||
clientTransport.start(transportListener);
|
clientTransport.start(transportListener);
|
||||||
if (waitingForConnected) {
|
if (waitingForConnected) {
|
||||||
connectedFuture.get(TIME_OUT_MS, TimeUnit.MILLISECONDS);
|
connectedFuture.get(TIME_OUT_MS, TimeUnit.MILLISECONDS);
|
||||||
|
@ -188,6 +191,26 @@ public class OkHttpClientTransportTest {
|
||||||
executor.shutdown();
|
executor.shutdown();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void maxMessageSizeShouldBeEnforced() throws Exception {
|
||||||
|
// Allow the response payloads of up to 1 byte.
|
||||||
|
startTransport(3, null, true, 1);
|
||||||
|
|
||||||
|
MockStreamListener listener = new MockStreamListener();
|
||||||
|
clientTransport.newStream(method, new Metadata.Headers(), listener).request(1);
|
||||||
|
assertContainStream(3);
|
||||||
|
frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS);
|
||||||
|
assertNotNull(listener.headers);
|
||||||
|
|
||||||
|
// Receive the message.
|
||||||
|
final String message = "Hello Client";
|
||||||
|
Buffer buffer = createMessageFrame(message);
|
||||||
|
frameHandler().data(false, 3, buffer, (int) buffer.size());
|
||||||
|
|
||||||
|
listener.waitUntilStreamClosed();
|
||||||
|
assertEquals(INTERNAL, listener.status.getCode());
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* When nextFrame throws IOException, the transport should be aborted.
|
* When nextFrame throws IOException, the transport should be aborted.
|
||||||
*/
|
*/
|
||||||
|
|
Loading…
Reference in New Issue