okhttp: add okhttpServerBuilder permitKeepAliveTime() and permitKeepAliveWithoutCalls() for server keepAlive enforcement (#9544)

This commit is contained in:
yifeizhuang 2022-09-15 16:08:55 -07:00 committed by GitHub
parent 15033caf1c
commit a3c1d7711f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 194 additions and 13 deletions

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package io.grpc.netty;
package io.grpc.internal;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
@ -22,11 +22,11 @@ import java.util.concurrent.TimeUnit;
import javax.annotation.CheckReturnValue;
/** Monitors the client's PING usage to make sure the rate is permitted. */
class KeepAliveEnforcer {
public final class KeepAliveEnforcer {
@VisibleForTesting
static final int MAX_PING_STRIKES = 2;
public static final int MAX_PING_STRIKES = 2;
@VisibleForTesting
static final long IMPLICIT_PERMIT_TIME_NANOS = TimeUnit.HOURS.toNanos(2);
public static final long IMPLICIT_PERMIT_TIME_NANOS = TimeUnit.HOURS.toNanos(2);
private final boolean permitWithoutCalls;
private final long minTimeNanos;

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package io.grpc.netty;
package io.grpc.internal;
import static com.google.common.truth.Truth.assertThat;

View File

@ -44,6 +44,7 @@ import io.grpc.Metadata;
import io.grpc.ServerStreamTracer;
import io.grpc.Status;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.KeepAliveEnforcer;
import io.grpc.internal.KeepAliveManager;
import io.grpc.internal.LogExceptionRunnable;
import io.grpc.internal.MaxConnectionIdleManager;

View File

@ -61,6 +61,7 @@ import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.StreamTracer;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.KeepAliveEnforcer;
import io.grpc.internal.KeepAliveManager;
import io.grpc.internal.ServerStream;
import io.grpc.internal.ServerStreamListener;

View File

@ -20,6 +20,7 @@ import static com.google.common.base.Preconditions.checkArgument;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.google.errorprone.annotations.DoNotCall;
import io.grpc.ChoiceServerCredentials;
import io.grpc.ExperimentalApi;
@ -117,6 +118,8 @@ public final class OkHttpServerBuilder extends ForwardingServerBuilder<OkHttpSer
int maxInboundMetadataSize = GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE;
int maxInboundMessageSize = GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
long maxConnectionIdleInNanos = MAX_CONNECTION_IDLE_NANOS_DISABLED;
boolean permitKeepAliveWithoutCalls;
long permitKeepAliveTimeInNanos = TimeUnit.MINUTES.toNanos(5);
@VisibleForTesting
OkHttpServerBuilder(
@ -223,6 +226,41 @@ public final class OkHttpServerBuilder extends ForwardingServerBuilder<OkHttpSer
return this;
}
/**
* Specify the most aggressive keep-alive time clients are permitted to configure. The server will
* try to detect clients exceeding this rate and when detected will forcefully close the
* connection. The default is 5 minutes.
*
* <p>Even though a default is defined that allows some keep-alives, clients must not use
* keep-alive without approval from the service owner. Otherwise, they may experience failures in
* the future if the service becomes more restrictive. When unthrottled, keep-alives can cause a
* significant amount of traffic and CPU usage, so clients and servers should be conservative in
* what they use and accept.
*
* @see #permitKeepAliveWithoutCalls(boolean)
*/
@CanIgnoreReturnValue
@Override
public OkHttpServerBuilder permitKeepAliveTime(long keepAliveTime, TimeUnit timeUnit) {
checkArgument(keepAliveTime >= 0, "permit keepalive time must be non-negative: %s",
keepAliveTime);
permitKeepAliveTimeInNanos = timeUnit.toNanos(keepAliveTime);
return this;
}
/**
* Sets whether to allow clients to send keep-alive HTTP/2 PINGs even if there are no outstanding
* RPCs on the connection. Defaults to {@code false}.
*
* @see #permitKeepAliveTime(long, TimeUnit)
*/
@CanIgnoreReturnValue
@Override
public OkHttpServerBuilder permitKeepAliveWithoutCalls(boolean permit) {
permitKeepAliveWithoutCalls = permit;
return this;
}
/**
* Sets the flow control window in bytes. If not called, the default value is 64 KiB.
*/

View File

@ -29,6 +29,7 @@ import io.grpc.Metadata;
import io.grpc.ServerStreamTracer;
import io.grpc.Status;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.KeepAliveEnforcer;
import io.grpc.internal.KeepAliveManager;
import io.grpc.internal.MaxConnectionIdleManager;
import io.grpc.internal.ObjectPool;
@ -95,6 +96,7 @@ final class OkHttpServerTransport implements ServerTransport,
private Attributes attributes;
private KeepAliveManager keepAliveManager;
private MaxConnectionIdleManager maxConnectionIdleManager;
private final KeepAliveEnforcer keepAliveEnforcer;
private final Object lock = new Object();
@GuardedBy("lock")
@ -137,6 +139,8 @@ final class OkHttpServerTransport implements ServerTransport,
logId = InternalLogId.allocate(getClass(), bareSocket.getRemoteSocketAddress().toString());
transportExecutor = config.transportExecutorPool.getObject();
scheduledExecutorService = config.scheduledExecutorServicePool.getObject();
keepAliveEnforcer = new KeepAliveEnforcer(config.permitKeepAliveWithoutCalls,
config.permitKeepAliveTimeInNanos, TimeUnit.NANOSECONDS);
}
public void start(ServerTransportListener listener) {
@ -159,6 +163,27 @@ final class OkHttpServerTransport implements ServerTransport,
asyncSink.becomeConnected(Okio.sink(socket), socket);
FrameWriter rawFrameWriter = asyncSink.limitControlFramesWriter(
variant.newWriter(Okio.buffer(asyncSink), false));
FrameWriter writeMonitoringFrameWriter = new ForwardingFrameWriter(rawFrameWriter) {
@Override
public void synReply(boolean outFinished, int streamId, List<Header> headerBlock)
throws IOException {
keepAliveEnforcer.resetCounters();
super.synReply(outFinished, streamId, headerBlock);
}
@Override
public void headers(int streamId, List<Header> headerBlock) throws IOException {
keepAliveEnforcer.resetCounters();
super.headers(streamId, headerBlock);
}
@Override
public void data(boolean outFinished, int streamId, Buffer source, int byteCount)
throws IOException {
keepAliveEnforcer.resetCounters();
super.data(outFinished, streamId, source, byteCount);
}
};
synchronized (lock) {
this.securityInfo = result.securityInfo;
@ -167,7 +192,7 @@ final class OkHttpServerTransport implements ServerTransport,
// does not propagate syscall errors through the FrameWriter. But we handle the
// AsyncSink failures with the same TransportExceptionHandler instance so it is all
// mixed back together.
frameWriter = new ExceptionHandlingFrameWriter(this, rawFrameWriter);
frameWriter = new ExceptionHandlingFrameWriter(this, writeMonitoringFrameWriter);
outboundFlow = new OutboundFlowController(this, frameWriter);
// These writes will be queued in the serializingExecutor waiting for this function to
@ -381,8 +406,11 @@ final class OkHttpServerTransport implements ServerTransport,
void streamClosed(int streamId, boolean flush) {
synchronized (lock) {
streams.remove(streamId);
if (maxConnectionIdleManager != null && streams.isEmpty()) {
maxConnectionIdleManager.onTransportIdle();
if (streams.isEmpty()) {
keepAliveEnforcer.onTransportIdle();
if (maxConnectionIdleManager != null) {
maxConnectionIdleManager.onTransportIdle();
}
}
if (gracefulShutdown && streams.isEmpty()) {
frameWriter.close();
@ -449,6 +477,8 @@ final class OkHttpServerTransport implements ServerTransport,
final int maxInboundMessageSize;
final int maxInboundMetadataSize;
final long maxConnectionIdleNanos;
final boolean permitKeepAliveWithoutCalls;
final long permitKeepAliveTimeInNanos;
public Config(
OkHttpServerBuilder builder,
@ -469,6 +499,8 @@ final class OkHttpServerTransport implements ServerTransport,
maxInboundMessageSize = builder.maxInboundMessageSize;
maxInboundMetadataSize = builder.maxInboundMetadataSize;
maxConnectionIdleNanos = builder.maxConnectionIdleInNanos;
permitKeepAliveWithoutCalls = builder.permitKeepAliveWithoutCalls;
permitKeepAliveTimeInNanos = builder.permitKeepAliveTimeInNanos;
}
}
@ -714,8 +746,11 @@ final class OkHttpServerTransport implements ServerTransport,
authority == null ? null : asciiString(authority),
statsTraceCtx,
tracer);
if (maxConnectionIdleManager != null && streams.isEmpty()) {
maxConnectionIdleManager.onTransportActive();
if (streams.isEmpty()) {
keepAliveEnforcer.onTransportActive();
if (maxConnectionIdleManager != null) {
maxConnectionIdleManager.onTransportActive();
}
}
streams.put(streamId, stream);
listener.streamCreated(streamForApp, method, metadata);
@ -849,6 +884,11 @@ final class OkHttpServerTransport implements ServerTransport,
@Override
public void ping(boolean ack, int payload1, int payload2) {
if (!keepAliveEnforcer.pingAcceptable()) {
abruptShutdown(ErrorCode.ENHANCE_YOUR_CALM, "too_many_pings",
Status.RESOURCE_EXHAUSTED.withDescription("Too many pings from client"), false);
return;
}
long payload = (((long) payload1) << 32) | (payload2 & 0xffffffffL);
if (!ack) {
frameLogger.logPing(OkHttpFrameLogger.Direction.INBOUND, payload);
@ -973,8 +1013,11 @@ final class OkHttpServerTransport implements ServerTransport,
synchronized (lock) {
Http2ErrorStreamState stream =
new Http2ErrorStreamState(streamId, lock, outboundFlow, config.flowControlWindow);
if (maxConnectionIdleManager != null && streams.isEmpty()) {
maxConnectionIdleManager.onTransportActive();
if (streams.isEmpty()) {
keepAliveEnforcer.onTransportActive();
if (maxConnectionIdleManager != null) {
maxConnectionIdleManager.onTransportActive();
}
}
streams.put(streamId, stream);
if (inFinished) {

View File

@ -41,6 +41,7 @@ import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.internal.FakeClock;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.KeepAliveEnforcer;
import io.grpc.internal.ServerStream;
import io.grpc.internal.ServerStreamListener;
import io.grpc.internal.ServerTransportListener;
@ -122,7 +123,9 @@ public class OkHttpServerTransportTest {
}
})
.flowControlWindow(INITIAL_WINDOW_SIZE)
.maxConnectionIdle(MAX_CONNECTION_IDLE, TimeUnit.NANOSECONDS);
.maxConnectionIdle(MAX_CONNECTION_IDLE, TimeUnit.NANOSECONDS)
.permitKeepAliveWithoutCalls(true)
.permitKeepAliveTime(0, TimeUnit.SECONDS);
@Rule public final Timeout globalTimeout = Timeout.seconds(10);
@ -1054,6 +1057,101 @@ public class OkHttpServerTransportTest {
assertThat(stats.remote).isEqualTo(new InetSocketAddress("127.0.0.2", 5000));
}
@Test
public void keepAliveEnforcer_enforcesPings() throws Exception {
serverBuilder.permitKeepAliveTime(1, TimeUnit.HOURS)
.permitKeepAliveWithoutCalls(false);
initTransport();
handshake();
for (int i = 0; i < KeepAliveEnforcer.MAX_PING_STRIKES; i++) {
pingPong();
}
pingPongId++;
clientFrameWriter.ping(false, pingPongId, 0);
clientFrameWriter.flush();
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead).goAway(0, ErrorCode.ENHANCE_YOUR_CALM,
ByteString.encodeString("too_many_pings", GrpcUtil.US_ASCII));
}
@Test
public void keepAliveEnforcer_sendingDataResetsCounters() throws Exception {
serverBuilder.permitKeepAliveTime(1, TimeUnit.HOURS)
.permitKeepAliveWithoutCalls(false);
initTransport();
handshake();
clientFrameWriter.headers(1, Arrays.asList(
HTTP_SCHEME_HEADER,
METHOD_HEADER,
new Header(Header.TARGET_AUTHORITY, "example.com:80"),
new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"),
CONTENT_TYPE_HEADER,
TE_HEADER,
new Header("some-metadata", "this could be anything")));
Buffer requestMessageFrame = createMessageFrame("Hello server");
clientFrameWriter.data(false, 1, requestMessageFrame, (int) requestMessageFrame.size());
pingPong();
MockStreamListener streamListener = mockTransportListener.newStreams.pop();
streamListener.stream.request(1);
pingPong();
assertThat(streamListener.messages.pop()).isEqualTo("Hello server");
streamListener.stream.writeHeaders(metadata("User-Data", "best data"));
streamListener.stream.writeMessage(new ByteArrayInputStream("Howdy client".getBytes(UTF_8)));
streamListener.stream.flush();
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
for (int i = 0; i < 10; i++) {
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
pingPong();
streamListener.stream.writeMessage(new ByteArrayInputStream("Howdy client".getBytes(UTF_8)));
streamListener.stream.flush();
}
}
@Test
public void keepAliveEnforcer_initialIdle() throws Exception {
serverBuilder.permitKeepAliveTime(0, TimeUnit.SECONDS)
.permitKeepAliveWithoutCalls(false);
initTransport();
handshake();
for (int i = 0; i < KeepAliveEnforcer.MAX_PING_STRIKES; i++) {
pingPong();
}
pingPongId++;
clientFrameWriter.ping(false, pingPongId, 0);
clientFrameWriter.flush();
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead).goAway(0, ErrorCode.ENHANCE_YOUR_CALM,
ByteString.encodeString("too_many_pings", GrpcUtil.US_ASCII));
}
@Test
public void keepAliveEnforcer_noticesActive() throws Exception {
serverBuilder.permitKeepAliveTime(0, TimeUnit.SECONDS)
.permitKeepAliveWithoutCalls(false);
initTransport();
handshake();
clientFrameWriter.headers(1, Arrays.asList(
HTTP_SCHEME_HEADER,
METHOD_HEADER,
new Header(Header.TARGET_AUTHORITY, "example.com:80"),
new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"),
CONTENT_TYPE_HEADER,
TE_HEADER,
new Header("some-metadata", "this could be anything")));
for (int i = 0; i < 10; i++) {
pingPong();
}
verify(clientFramesRead, never()).goAway(anyInt(), eq(ErrorCode.ENHANCE_YOUR_CALM),
eq(ByteString.encodeString("too_many_pings", GrpcUtil.US_ASCII)));
}
private void initTransport() throws Exception {
serverTransport = new OkHttpServerTransport(
new OkHttpServerTransport.Config(serverBuilder, Arrays.asList()),