From 132f7a9a3385d58d06fbe4b80d3290139c76b14a Mon Sep 17 00:00:00 2001 From: Kun Zhang Date: Thu, 6 Oct 2016 17:15:24 -0700 Subject: [PATCH] core: Census integration for stats (#2262) Highlights ========== StatsTraceContext ----------------- The bridge between gRPC library and Census. It keeps track of the total payload sizes and the elapsed time of a Call. The rest of the gRPC code doesn't invoke Census directly. Context propagation ------------------- StatsTraceContext carries CensusContext (and the upcoming TraceContext) and is attached to the gRPC Context. 1. StatsTraceContext is created by ManagedChannelImpl, by calling createClientContext(), which inherits the current CensusContext if available. 2. ManagedChannelImpl passes StatsTraceContext to ClientCallImpl, then to the stream, then to the framer and deframer explicitly. 3. ClientCallImpl propagates the CensusContext to the headers. 1. ServerImpl creates a StatsTraceContext by implementing a new callback method StatsTraceContext methodDetermined(MethodDescriptor, Metadata) on ServerTransportListener. 2. NettyServerHandler calls methodDetermined() before creating the stream, and passes the StatsTraceContext to the stream. 3. When ServerImpl creates the gRPC Context for the new ServerCall, it calls the new method statsTraceContext() on ServerStream and puts the StatsTraceContext in the Context. Metrics recording ----------------- 1. Client-side start time: when ClientCallImpl is created 2. Server-side start time: when methodDetermined() is called 3. Server-side end time: in ServerStreamListener.closed(), but before calling onComplete() or onCancel() on ServerCall.Listener. 4. Client-side end time: in ClientStreamListener.closed(), but before calling onClonse() on ClientCall.Listener Message sizes are recorded in MessageFramer and MessageDeframer. Both the uncompressed and wire (possibly compressed) payload sizes are counted. TODOs ===== The CensusContext created from headers on the server side should be attached to the gRPC Context for the call. It's not done at this moment because Census lacks the proper API to do it. It only affects tracing and resource accounting, but doesn't affect stats functionality --- build.gradle | 1 + core/build.gradle | 3 +- .../inprocess/InProcessChannelBuilder.java | 16 ++ .../inprocess/InProcessServerBuilder.java | 17 ++ .../io/grpc/inprocess/InProcessTransport.java | 25 +- .../grpc/internal/AbstractClientStream.java | 6 +- .../grpc/internal/AbstractClientStream2.java | 9 +- .../AbstractManagedChannelImplBuilder.java | 20 +- .../internal/AbstractServerImplBuilder.java | 21 +- .../grpc/internal/AbstractServerStream.java | 16 +- .../java/io/grpc/internal/AbstractStream.java | 8 +- .../io/grpc/internal/AbstractStream2.java | 4 +- ...llCredentialsApplyingTransportFactory.java | 7 +- .../java/io/grpc/internal/ClientCallImpl.java | 24 +- .../io/grpc/internal/ClientTransport.java | 6 +- .../grpc/internal/DelayedClientTransport.java | 17 +- .../grpc/internal/FailingClientTransport.java | 6 +- .../ForwardingConnectionClientTransport.java | 5 +- .../main/java/io/grpc/internal/GrpcUtil.java | 2 +- .../io/grpc/internal/Http2ClientStream.java | 6 +- .../Http2ClientStreamTransportState.java | 4 +- .../io/grpc/internal/ManagedChannelImpl.java | 10 +- .../io/grpc/internal/MessageDeframer.java | 26 +- .../java/io/grpc/internal/MessageFramer.java | 8 +- .../io/grpc/internal/MetadataApplierImpl.java | 6 +- .../internal/NoopCensusContextFactory.java | 90 +++++++ .../java/io/grpc/internal/ServerCallImpl.java | 11 +- .../java/io/grpc/internal/ServerImpl.java | 46 +++- .../java/io/grpc/internal/ServerStream.java | 5 + .../internal/ServerTransportListener.java | 8 + .../grpc/internal/SingleTransportChannel.java | 16 +- .../io/grpc/internal/StatsTraceContext.java | 235 +++++++++++++++++ .../java/io/grpc/internal/TransportSet.java | 2 +- .../internal/AbstractClientStream2Test.java | 40 +-- .../internal/AbstractClientStreamTest.java | 2 +- .../internal/AbstractServerStreamTest.java | 4 +- .../io/grpc/internal/AbstractStreamTest.java | 2 +- .../internal/CallCredentialsApplyingTest.java | 36 +-- .../io/grpc/internal/ClientCallImplTest.java | 93 ++++++- .../internal/DelayedClientTransportTest.java | 54 ++-- .../ManagedChannelImplIdlenessTest.java | 3 +- .../grpc/internal/ManagedChannelImplTest.java | 72 +++-- ...anagedChannelImplTransportManagerTest.java | 23 +- .../io/grpc/internal/MessageDeframerTest.java | 88 ++++++- .../io/grpc/internal/MessageFramerTest.java | 67 ++++- .../io/grpc/internal/ServerCallImplTest.java | 66 ++++- .../java/io/grpc/internal/ServerImplTest.java | 175 +++++++++++-- .../test/java/io/grpc/internal/TestUtils.java | 3 +- .../io/grpc/internal/TransportSetTest.java | 49 ++-- .../integration/AbstractInteropTest.java | 225 +++++++++++++++- .../testing/integration/StressTestClient.java | 6 + .../integration/TestServiceClient.java | 1 + .../integration/AutoWindowSizingOnTest.java | 1 + .../Http2NettyLocalChannelTest.java | 1 + .../testing/integration/Http2NettyTest.java | 1 + .../testing/integration/Http2OkHttpTest.java | 8 +- .../testing/integration/InProcessTest.java | 10 +- .../integration/TransportCompressionTest.java | 1 + .../java/io/grpc/netty/NettyClientStream.java | 10 +- .../io/grpc/netty/NettyClientTransport.java | 13 +- .../io/grpc/netty/NettyServerHandler.java | 15 +- .../java/io/grpc/netty/NettyServerStream.java | 11 +- .../io/grpc/netty/NettyClientHandlerTest.java | 3 +- .../io/grpc/netty/NettyClientStreamTest.java | 10 +- .../grpc/netty/NettyClientTransportTest.java | 9 + .../io/grpc/netty/NettyHandlerTestBase.java | 3 +- .../io/grpc/netty/NettyServerHandlerTest.java | 5 + .../io/grpc/netty/NettyServerStreamTest.java | 9 +- .../io/grpc/okhttp/OkHttpClientStream.java | 6 +- .../io/grpc/okhttp/OkHttpClientTransport.java | 10 +- .../grpc/okhttp/OkHttpClientStreamTest.java | 9 +- .../testing/AbstractTransportTest.java | 10 +- .../internal/testing/CensusTestUtils.java | 246 ++++++++++++++++++ 73 files changed, 1778 insertions(+), 308 deletions(-) create mode 100644 core/src/main/java/io/grpc/internal/NoopCensusContextFactory.java create mode 100644 core/src/main/java/io/grpc/internal/StatsTraceContext.java create mode 100644 testing/src/main/java/io/grpc/internal/testing/CensusTestUtils.java diff --git a/build.gradle b/build.gradle index ce1b55d195..18d0494d8a 100644 --- a/build.gradle +++ b/build.gradle @@ -151,6 +151,7 @@ subprojects { google_auth_credentials: 'com.google.auth:google-auth-library-credentials:0.4.0', okhttp: 'com.squareup.okhttp:okhttp:2.5.0', okio: 'com.squareup.okio:okio:1.6.0', + census_api: 'com.google.census:census-api:0.2.0', protobuf: "com.google.protobuf:protobuf-java:${protobufVersion}", // swap to ${protobufVersion} after versions align again protobuf_lite: "com.google.protobuf:protobuf-lite:3.0.1", diff --git a/core/build.gradle b/core/build.gradle index 6d6147ce38..8e9a067a25 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -8,7 +8,8 @@ dependencies { compile libraries.guava, libraries.errorprone, libraries.jsr305, - project(':grpc-context') + project(':grpc-context'), + libraries.census_api testCompile project(':grpc-testing') } diff --git a/core/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java b/core/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java index 7bea513f5d..98b8331285 100644 --- a/core/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java +++ b/core/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java @@ -31,6 +31,7 @@ package io.grpc.inprocess; +import com.google.census.CensusContextFactory; import com.google.common.base.Preconditions; import io.grpc.ExperimentalApi; @@ -38,6 +39,7 @@ import io.grpc.Internal; import io.grpc.internal.AbstractManagedChannelImplBuilder; import io.grpc.internal.ClientTransportFactory; import io.grpc.internal.ConnectionClientTransport; +import io.grpc.internal.NoopCensusContextFactory; import java.net.SocketAddress; @@ -65,6 +67,10 @@ public class InProcessChannelBuilder extends private InProcessChannelBuilder(String name) { super(new InProcessSocketAddress(name), "localhost"); this.name = Preconditions.checkNotNull(name, "name"); + // TODO(zhangkun83): InProcessTransport by-passes framer and deframer, thus message sizses are + // not counted. Therefore, we disable Census for now. + // (https://github.com/grpc/grpc-java/issues/2284) + super.censusContextFactory(NoopCensusContextFactory.INSTANCE); } /** @@ -80,6 +86,16 @@ public class InProcessChannelBuilder extends return new InProcessClientTransportFactory(name); } + @Internal + @Override + public InProcessChannelBuilder censusContextFactory(CensusContextFactory censusFactory) { + // TODO(zhangkun83): InProcessTransport by-passes framer and deframer, thus message sizses are + // not counted. Census is disabled by using a NOOP Census factory in the constructor, and here + // we prevent the user from overriding it. + // (https://github.com/grpc/grpc-java/issues/2284) + return this; + } + /** * Creates InProcess transports. Exposed for internal use, as it should be private. */ diff --git a/core/src/main/java/io/grpc/inprocess/InProcessServerBuilder.java b/core/src/main/java/io/grpc/inprocess/InProcessServerBuilder.java index 58cad724fe..0e75188026 100644 --- a/core/src/main/java/io/grpc/inprocess/InProcessServerBuilder.java +++ b/core/src/main/java/io/grpc/inprocess/InProcessServerBuilder.java @@ -31,10 +31,13 @@ package io.grpc.inprocess; +import com.google.census.CensusContextFactory; import com.google.common.base.Preconditions; import io.grpc.ExperimentalApi; +import io.grpc.Internal; import io.grpc.internal.AbstractServerImplBuilder; +import io.grpc.internal.NoopCensusContextFactory; import java.io.File; @@ -61,6 +64,10 @@ public final class InProcessServerBuilder private InProcessServerBuilder(String name) { this.name = Preconditions.checkNotNull(name, "name"); + // TODO(zhangkun83): InProcessTransport by-passes framer and deframer, thus message sizses are + // not counted. Therefore, we disable Census for now. + // (https://github.com/grpc/grpc-java/issues/2284) + super.censusContextFactory(NoopCensusContextFactory.INSTANCE); } @Override @@ -72,4 +79,14 @@ public final class InProcessServerBuilder public InProcessServerBuilder useTransportSecurity(File certChain, File privateKey) { throw new UnsupportedOperationException("TLS not supported in InProcessServer"); } + + @Internal + @Override + public InProcessServerBuilder censusContextFactory(CensusContextFactory censusFactory) { + // TODO(zhangkun83): InProcessTransport by-passes framer and deframer, thus message sizses are + // not counted. Census is disabled by using a NOOP Census factory in the constructor, and here + // we prevent the user from overriding it. + // (https://github.com/grpc/grpc-java/issues/2284) + return this; + } } diff --git a/core/src/main/java/io/grpc/inprocess/InProcessTransport.java b/core/src/main/java/io/grpc/inprocess/InProcessTransport.java index 853e4a283e..59578c4cb9 100644 --- a/core/src/main/java/io/grpc/inprocess/InProcessTransport.java +++ b/core/src/main/java/io/grpc/inprocess/InProcessTransport.java @@ -51,6 +51,7 @@ import io.grpc.internal.ServerStream; import io.grpc.internal.ServerStreamListener; import io.grpc.internal.ServerTransport; import io.grpc.internal.ServerTransportListener; +import io.grpc.internal.StatsTraceContext; import java.io.InputStream; import java.util.ArrayDeque; @@ -125,7 +126,8 @@ class InProcessTransport implements ServerTransport, ConnectionClientTransport { @Override public synchronized ClientStream newStream( - final MethodDescriptor method, final Metadata headers, final CallOptions callOptions) { + final MethodDescriptor method, final Metadata headers, final CallOptions callOptions, + StatsTraceContext clientStatsTraceContext) { if (shutdownStatus != null) { final Status capturedStatus = shutdownStatus; return new NoopClientStream() { @@ -135,14 +137,15 @@ class InProcessTransport implements ServerTransport, ConnectionClientTransport { } }; } - - return new InProcessStream(method, headers).clientStream; + StatsTraceContext serverStatsTraceContext = serverTransportListener.methodDetermined( + method.getFullMethodName(), headers); + return new InProcessStream(method, headers, serverStatsTraceContext).clientStream; } @Override public synchronized ClientStream newStream( final MethodDescriptor method, final Metadata headers) { - return newStream(method, headers, CallOptions.DEFAULT); + return newStream(method, headers, CallOptions.DEFAULT, StatsTraceContext.NOOP); } @Override @@ -231,13 +234,16 @@ class InProcessTransport implements ServerTransport, ConnectionClientTransport { private class InProcessStream { private final InProcessServerStream serverStream = new InProcessServerStream(); private final InProcessClientStream clientStream = new InProcessClientStream(); + private final StatsTraceContext serverStatsTraceContext; private final Metadata headers; - private MethodDescriptor method; + private final MethodDescriptor method; - private InProcessStream(MethodDescriptor method, Metadata headers) { + private InProcessStream(MethodDescriptor method, Metadata headers, + StatsTraceContext serverStatsTraceContext) { this.method = checkNotNull(method, "method"); this.headers = checkNotNull(headers, "headers"); - + this.serverStatsTraceContext = + checkNotNull(serverStatsTraceContext, "serverStatsTraceContext"); } // Can be called multiple times due to races on both client and server closing at same time. @@ -408,6 +414,11 @@ class InProcessTransport implements ServerTransport, ConnectionClientTransport { @Override public Attributes attributes() { return serverStreamAttributes; } + + @Override + public StatsTraceContext statsTraceContext() { + return serverStatsTraceContext; + } } private class InProcessClientStream implements ClientStream { diff --git a/core/src/main/java/io/grpc/internal/AbstractClientStream.java b/core/src/main/java/io/grpc/internal/AbstractClientStream.java index 8b18f017b0..ffdc6caf3f 100644 --- a/core/src/main/java/io/grpc/internal/AbstractClientStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractClientStream.java @@ -63,9 +63,9 @@ public abstract class AbstractClientStream extends AbstractStream private Runnable closeListenerTask; private volatile boolean cancelled; - protected AbstractClientStream(WritableBufferAllocator bufferAllocator, - int maxMessageSize) { - super(bufferAllocator, maxMessageSize); + protected AbstractClientStream(WritableBufferAllocator bufferAllocator, int maxMessageSize, + StatsTraceContext statsTraceCtx) { + super(bufferAllocator, maxMessageSize, statsTraceCtx); } @Override diff --git a/core/src/main/java/io/grpc/internal/AbstractClientStream2.java b/core/src/main/java/io/grpc/internal/AbstractClientStream2.java index 2a6dcacabd..a0745dfbd9 100644 --- a/core/src/main/java/io/grpc/internal/AbstractClientStream2.java +++ b/core/src/main/java/io/grpc/internal/AbstractClientStream2.java @@ -94,8 +94,9 @@ public abstract class AbstractClientStream2 extends AbstractStream2 */ private volatile boolean cancelled; - protected AbstractClientStream2(WritableBufferAllocator bufferAllocator) { - framer = new MessageFramer(this, bufferAllocator); + protected AbstractClientStream2(WritableBufferAllocator bufferAllocator, + StatsTraceContext statsTraceCtx) { + framer = new MessageFramer(this, bufferAllocator, statsTraceCtx); } /** {@inheritDoc} */ @@ -164,8 +165,8 @@ public abstract class AbstractClientStream2 extends AbstractStream2 */ private boolean statusReported; - protected TransportState(int maxMessageSize) { - super(maxMessageSize); + protected TransportState(int maxMessageSize, StatsTraceContext statsTraceCtx) { + super(maxMessageSize, statsTraceCtx); } @VisibleForTesting diff --git a/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java b/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java index 39d78b6fe4..857ad78427 100644 --- a/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java @@ -34,6 +34,8 @@ package io.grpc.internal; import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Preconditions.checkArgument; +import com.google.census.Census; +import com.google.census.CensusContextFactory; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.util.concurrent.MoreExecutors; @@ -42,6 +44,7 @@ import io.grpc.Attributes; import io.grpc.ClientInterceptor; import io.grpc.CompressorRegistry; import io.grpc.DecompressorRegistry; +import io.grpc.Internal; import io.grpc.LoadBalancer; import io.grpc.ManagedChannelBuilder; import io.grpc.NameResolver; @@ -118,6 +121,9 @@ public abstract class AbstractManagedChannelImplBuilder private long idleTimeoutMillis = IDLE_MODE_DEFAULT_TIMEOUT_MILLIS; + @Nullable + private CensusContextFactory censusFactory; + protected AbstractManagedChannelImplBuilder(String target) { this.target = Preconditions.checkNotNull(target, "target"); this.directServerAddress = null; @@ -227,6 +233,16 @@ public abstract class AbstractManagedChannelImplBuilder return thisT(); } + /** + * Override the default Census implementation. This is meant to be used in tests. + */ + @VisibleForTesting + @Internal + public T censusContextFactory(CensusContextFactory censusFactory) { + this.censusFactory = censusFactory; + return thisT(); + } + @VisibleForTesting final long getIdleTimeoutMillis() { return idleTimeoutMillis; @@ -266,7 +282,9 @@ public abstract class AbstractManagedChannelImplBuilder firstNonNull(decompressorRegistry, DecompressorRegistry.getDefaultInstance()), firstNonNull(compressorRegistry, CompressorRegistry.getDefaultInstance()), GrpcUtil.TIMER_SERVICE, GrpcUtil.STOPWATCH_SUPPLIER, idleTimeoutMillis, - executor, userAgent, interceptors); + executor, userAgent, interceptors, + firstNonNull(censusFactory, + firstNonNull(Census.getCensusContextFactory(), NoopCensusContextFactory.INSTANCE))); } /** diff --git a/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java b/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java index 87f0945234..b9a34370b8 100644 --- a/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/AbstractServerImplBuilder.java @@ -34,6 +34,9 @@ package io.grpc.internal; import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Preconditions.checkNotNull; +import com.google.census.Census; +import com.google.census.CensusContextFactory; +import com.google.common.annotations.VisibleForTesting; import com.google.common.util.concurrent.MoreExecutors; import io.grpc.BindableService; @@ -87,6 +90,9 @@ public abstract class AbstractServerImplBuilder method, Metadata headers, CallOptions callOptions) { + MethodDescriptor method, Metadata headers, CallOptions callOptions, + StatsTraceContext statsTraceCtx) { CallCredentials creds = callOptions.getCredentials(); if (creds != null) { MetadataApplierImpl applier = new MetadataApplierImpl( - delegate, method, headers, callOptions); + delegate, method, headers, callOptions, statsTraceCtx); Attributes.Builder effectiveAttrsBuilder = Attributes.newBuilder() .set(CallCredentials.ATTR_AUTHORITY, authority) .set(CallCredentials.ATTR_SECURITY_LEVEL, SecurityLevel.NONE) @@ -99,7 +100,7 @@ final class CallCredentialsApplyingTransportFactory implements ClientTransportFa firstNonNull(callOptions.getExecutor(), appExecutor), applier); return applier.returnStream(); } else { - return delegate.newStream(method, headers, callOptions); + return delegate.newStream(method, headers, callOptions, statsTraceCtx); } } } diff --git a/core/src/main/java/io/grpc/internal/ClientCallImpl.java b/core/src/main/java/io/grpc/internal/ClientCallImpl.java index 86fd4ebccc..5a964614e3 100644 --- a/core/src/main/java/io/grpc/internal/ClientCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ClientCallImpl.java @@ -84,6 +84,7 @@ final class ClientCallImpl extends ClientCall private volatile ScheduledFuture deadlineCancellationFuture; private final boolean unaryRequest; private final CallOptions callOptions; + private final StatsTraceContext statsTraceCtx; private ClientStream stream; private volatile boolean cancelListenersShouldBeRemoved; private boolean cancelCalled; @@ -94,7 +95,8 @@ final class ClientCallImpl extends ClientCall private CompressorRegistry compressorRegistry = CompressorRegistry.getDefaultInstance(); ClientCallImpl(MethodDescriptor method, Executor executor, - CallOptions callOptions, ClientTransportProvider clientTransportProvider, + CallOptions callOptions, StatsTraceContext statsTraceCtx, + ClientTransportProvider clientTransportProvider, ScheduledExecutorService deadlineCancellationExecutor) { this.method = method; // If we know that the executor is a direct executor, we don't need to wrap it with a @@ -105,6 +107,7 @@ final class ClientCallImpl extends ClientCall : new SerializingExecutor(executor); // Propagate the context from the thread which initiated the call to all callbacks. this.context = Context.current(); + this.statsTraceCtx = Preconditions.checkNotNull(statsTraceCtx, "statsTraceCtx"); this.unaryRequest = method.getType() == MethodType.UNARY || method.getType() == MethodType.SERVER_STREAMING; this.callOptions = callOptions; @@ -139,7 +142,7 @@ final class ClientCallImpl extends ClientCall @VisibleForTesting static void prepareHeaders(Metadata headers, DecompressorRegistry decompressorRegistry, - Compressor compressor) { + Compressor compressor, StatsTraceContext statsTraceCtx) { headers.discardAll(MESSAGE_ENCODING_KEY); if (compressor != Codec.Identity.NONE) { headers.put(MESSAGE_ENCODING_KEY, compressor.getMessageEncoding()); @@ -150,6 +153,7 @@ final class ClientCallImpl extends ClientCall if (!advertisedEncodings.isEmpty()) { headers.put(MESSAGE_ACCEPT_ENCODING_KEY, advertisedEncodings); } + statsTraceCtx.propagateToHeaders(headers); } @Override @@ -169,7 +173,7 @@ final class ClientCallImpl extends ClientCall @Override public void runInContext() { - observer.onClose(statusFromCancelled(context), new Metadata()); + closeObserver(observer, statusFromCancelled(context), new Metadata()); } } @@ -189,7 +193,8 @@ final class ClientCallImpl extends ClientCall @Override public void runInContext() { - observer.onClose( + closeObserver( + observer, Status.INTERNAL.withDescription( String.format("Unable to find compressor by name %s", compressorName)), new Metadata()); @@ -203,7 +208,7 @@ final class ClientCallImpl extends ClientCall compressor = Codec.Identity.NONE; } - prepareHeaders(headers, decompressorRegistry, compressor); + prepareHeaders(headers, decompressorRegistry, compressor, statsTraceCtx); Deadline effectiveDeadline = effectiveDeadline(); boolean deadlineExceeded = effectiveDeadline != null && effectiveDeadline.isExpired(); @@ -213,7 +218,7 @@ final class ClientCallImpl extends ClientCall ClientTransport transport = clientTransportProvider.get(callOptions); Context origContext = context.attach(); try { - stream = transport.newStream(method, headers, callOptions); + stream = transport.newStream(method, headers, callOptions, statsTraceCtx); } finally { context.detach(origContext); } @@ -400,6 +405,11 @@ final class ClientCallImpl extends ClientCall return stream.isReady(); } + private void closeObserver(Listener observer, Status status, Metadata trailers) { + statsTraceCtx.callEnded(status); + observer.onClose(status, trailers); + } + private class ClientStreamListenerImpl implements ClientStreamListener { private final Listener observer; private boolean closed; @@ -483,7 +493,7 @@ final class ClientCallImpl extends ClientCall closed = true; cancelListenersShouldBeRemoved = true; try { - observer.onClose(status, trailers); + closeObserver(observer, status, trailers); } finally { removeContextListenerAndCancelDeadlineFuture(); } diff --git a/core/src/main/java/io/grpc/internal/ClientTransport.java b/core/src/main/java/io/grpc/internal/ClientTransport.java index a6d360f83a..41619b9384 100644 --- a/core/src/main/java/io/grpc/internal/ClientTransport.java +++ b/core/src/main/java/io/grpc/internal/ClientTransport.java @@ -60,12 +60,14 @@ public interface ClientTransport { * @param method the descriptor of the remote method to be called for this stream. * @param headers to send at the beginning of the call * @param callOptions runtime options of the call + * @param statsTraceCtx carries stats and tracing information * @return the newly created stream. */ // TODO(nmittler): Consider also throwing for stopping. - ClientStream newStream(MethodDescriptor method, Metadata headers, CallOptions callOptions); + ClientStream newStream(MethodDescriptor method, Metadata headers, CallOptions callOptions, + StatsTraceContext statsTraceCtx); - // TODO(zdapeng): Remove tow-argument version in favor of three-argument overload. + // TODO(zdapeng): Remove two-argument version in favor of four-argument overload. ClientStream newStream(MethodDescriptor method, Metadata headers); /** diff --git a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java index 5c841189d5..d898a45c9c 100644 --- a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java +++ b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java @@ -113,8 +113,8 @@ class DelayedClientTransport implements ManagedClientTransport { * {@link FailingClientStream} is returned. */ @Override - public ClientStream newStream(MethodDescriptor method, Metadata headers, CallOptions - callOptions) { + public ClientStream newStream(MethodDescriptor method, Metadata headers, + CallOptions callOptions, StatsTraceContext statsTraceCtx) { Supplier supplier = transportSupplier; if (supplier == null) { synchronized (lock) { @@ -124,7 +124,8 @@ class DelayedClientTransport implements ManagedClientTransport { if (backoffStatus != null && !callOptions.isWaitForReady()) { return new FailingClientStream(backoffStatus); } - PendingStream pendingStream = new PendingStream(method, headers, callOptions); + PendingStream pendingStream = new PendingStream(method, headers, callOptions, + statsTraceCtx); pendingStreams.add(pendingStream); if (pendingStreams.size() == 1) { listener.transportInUse(true); @@ -134,14 +135,14 @@ class DelayedClientTransport implements ManagedClientTransport { } } if (supplier != null) { - return supplier.get().newStream(method, headers, callOptions); + return supplier.get().newStream(method, headers, callOptions, statsTraceCtx); } return new FailingClientStream(Status.UNAVAILABLE.withDescription("transport shutdown")); } @Override public ClientStream newStream(MethodDescriptor method, Metadata headers) { - return newStream(method, headers, CallOptions.DEFAULT); + return newStream(method, headers, CallOptions.DEFAULT, StatsTraceContext.NOOP); } @Override @@ -382,20 +383,22 @@ class DelayedClientTransport implements ManagedClientTransport { private final Metadata headers; private final CallOptions callOptions; private final Context context; + private final StatsTraceContext statsTraceCtx; private PendingStream(MethodDescriptor method, Metadata headers, - CallOptions callOptions) { + CallOptions callOptions, StatsTraceContext statsTraceCtx) { this.method = method; this.headers = headers; this.callOptions = callOptions; this.context = Context.current(); + this.statsTraceCtx = statsTraceCtx; } private void createRealStream(ClientTransport transport) { ClientStream realStream; Context origContext = context.attach(); try { - realStream = transport.newStream(method, headers, callOptions); + realStream = transport.newStream(method, headers, callOptions, statsTraceCtx); } finally { context.detach(origContext); } diff --git a/core/src/main/java/io/grpc/internal/FailingClientTransport.java b/core/src/main/java/io/grpc/internal/FailingClientTransport.java index 874a97880b..806e357aca 100644 --- a/core/src/main/java/io/grpc/internal/FailingClientTransport.java +++ b/core/src/main/java/io/grpc/internal/FailingClientTransport.java @@ -55,14 +55,14 @@ class FailingClientTransport implements ClientTransport { } @Override - public ClientStream newStream(MethodDescriptor method, Metadata headers, CallOptions - callOptions) { + public ClientStream newStream(MethodDescriptor method, Metadata headers, + CallOptions callOptions, StatsTraceContext statsTraceCtx) { return new FailingClientStream(error); } @Override public ClientStream newStream(MethodDescriptor method, Metadata headers) { - return newStream(method, headers, CallOptions.DEFAULT); + return newStream(method, headers, CallOptions.DEFAULT, StatsTraceContext.NOOP); } @Override diff --git a/core/src/main/java/io/grpc/internal/ForwardingConnectionClientTransport.java b/core/src/main/java/io/grpc/internal/ForwardingConnectionClientTransport.java index b79feb6d71..b5e1e25e18 100644 --- a/core/src/main/java/io/grpc/internal/ForwardingConnectionClientTransport.java +++ b/core/src/main/java/io/grpc/internal/ForwardingConnectionClientTransport.java @@ -57,8 +57,9 @@ abstract class ForwardingConnectionClientTransport implements ConnectionClientTr @Override public ClientStream newStream( - MethodDescriptor method, Metadata headers, CallOptions callOptions) { - return delegate().newStream(method, headers, callOptions); + MethodDescriptor method, Metadata headers, CallOptions callOptions, + StatsTraceContext statsTraceCtx) { + return delegate().newStream(method, headers, callOptions, statsTraceCtx); } @Override diff --git a/core/src/main/java/io/grpc/internal/GrpcUtil.java b/core/src/main/java/io/grpc/internal/GrpcUtil.java index 2c02ed923c..da3dc63693 100644 --- a/core/src/main/java/io/grpc/internal/GrpcUtil.java +++ b/core/src/main/java/io/grpc/internal/GrpcUtil.java @@ -456,7 +456,7 @@ public final class GrpcUtil { /** * The factory of default Stopwatches. */ - static final Supplier STOPWATCH_SUPPLIER = new Supplier() { + public static final Supplier STOPWATCH_SUPPLIER = new Supplier() { @Override public Stopwatch get() { return Stopwatch.createUnstarted(); diff --git a/core/src/main/java/io/grpc/internal/Http2ClientStream.java b/core/src/main/java/io/grpc/internal/Http2ClientStream.java index 9cafee4960..c2d28333c3 100644 --- a/core/src/main/java/io/grpc/internal/Http2ClientStream.java +++ b/core/src/main/java/io/grpc/internal/Http2ClientStream.java @@ -81,9 +81,9 @@ public abstract class Http2ClientStream extends AbstractClientStream { private Charset errorCharset = Charsets.UTF_8; private boolean contentTypeChecked; - protected Http2ClientStream(WritableBufferAllocator bufferAllocator, - int maxMessageSize) { - super(bufferAllocator, maxMessageSize); + protected Http2ClientStream(WritableBufferAllocator bufferAllocator, int maxMessageSize, + StatsTraceContext statsTraceCtx) { + super(bufferAllocator, maxMessageSize, statsTraceCtx); } /** diff --git a/core/src/main/java/io/grpc/internal/Http2ClientStreamTransportState.java b/core/src/main/java/io/grpc/internal/Http2ClientStreamTransportState.java index 89aa2421bc..4fb85bd101 100644 --- a/core/src/main/java/io/grpc/internal/Http2ClientStreamTransportState.java +++ b/core/src/main/java/io/grpc/internal/Http2ClientStreamTransportState.java @@ -81,8 +81,8 @@ public abstract class Http2ClientStreamTransportState extends AbstractClientStre private Charset errorCharset = Charsets.UTF_8; private boolean contentTypeChecked; - protected Http2ClientStreamTransportState(int maxMessageSize) { - super(maxMessageSize); + protected Http2ClientStreamTransportState(int maxMessageSize, StatsTraceContext statsTraceCtx) { + super(maxMessageSize, statsTraceCtx); } /** diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index f9425eef0e..cd1282f299 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -35,6 +35,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import com.google.census.CensusContextFactory; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Stopwatch; import com.google.common.base.Supplier; @@ -115,6 +116,7 @@ public final class ManagedChannelImpl extends ManagedChannel implements WithLogI private final SharedResourceHolder.Resource timerService; private final Supplier stopwatchSupplier; private final long idleTimeoutMillis; + private final CensusContextFactory censusFactory; /** * Executor that runs deadline timers for requests. @@ -325,7 +327,7 @@ public final class ManagedChannelImpl extends ManagedChannel implements WithLogI SharedResourceHolder.Resource timerService, Supplier stopwatchSupplier, long idleTimeoutMillis, @Nullable Executor executor, @Nullable String userAgent, - List interceptors) { + List interceptors, CensusContextFactory censusFactory) { this.target = checkNotNull(target, "target"); this.nameResolverFactory = checkNotNull(nameResolverFactory, "nameResolverFactory"); this.nameResolverParams = checkNotNull(nameResolverParams, "nameResolverParams"); @@ -351,6 +353,7 @@ public final class ManagedChannelImpl extends ManagedChannel implements WithLogI this.decompressorRegistry = decompressorRegistry; this.compressorRegistry = compressorRegistry; this.userAgent = userAgent; + this.censusFactory = checkNotNull(censusFactory, "censusFactory"); if (log.isLoggable(Level.INFO)) { log.log(Level.INFO, "[{0}] Created with target {1}", new Object[] {getLogId(), target}); @@ -544,10 +547,13 @@ public final class ManagedChannelImpl extends ManagedChannel implements WithLogI if (executor == null) { executor = ManagedChannelImpl.this.executor; } + StatsTraceContext statsTraceCtx = StatsTraceContext.newClientContext( + method.getFullMethodName(), censusFactory, stopwatchSupplier); return new ClientCallImpl( method, executor, callOptions, + statsTraceCtx, transportProvider, scheduledExecutor) .setDecompressorRegistry(decompressorRegistry) @@ -652,7 +658,7 @@ public final class ManagedChannelImpl extends ManagedChannel implements WithLogI @Override public Channel makeChannel(ClientTransport transport) { return new SingleTransportChannel( - transport, executor, scheduledExecutor, authority()); + censusFactory, transport, executor, scheduledExecutor, authority(), stopwatchSupplier); } @Override diff --git a/core/src/main/java/io/grpc/internal/MessageDeframer.java b/core/src/main/java/io/grpc/internal/MessageDeframer.java index c7b55d340a..d1f5e2c5c6 100644 --- a/core/src/main/java/io/grpc/internal/MessageDeframer.java +++ b/core/src/main/java/io/grpc/internal/MessageDeframer.java @@ -98,6 +98,7 @@ public class MessageDeframer implements Closeable { private final Listener listener; private final int maxMessageSize; + private final StatsTraceContext statsTraceCtx; private Decompressor decompressor; private State state = State.HEADER; private int requiredLength = HEADER_LENGTH; @@ -117,10 +118,12 @@ public class MessageDeframer implements Closeable { * {@code NONE} meaning unsupported * @param maxMessageSize the maximum allowed size for received messages. */ - public MessageDeframer(Listener listener, Decompressor decompressor, int maxMessageSize) { + public MessageDeframer(Listener listener, Decompressor decompressor, int maxMessageSize, + StatsTraceContext statsTraceCtx) { this.listener = Preconditions.checkNotNull(listener, "sink"); this.decompressor = Preconditions.checkNotNull(decompressor, "decompressor"); this.maxMessageSize = maxMessageSize; + this.statsTraceCtx = checkNotNull(statsTraceCtx, "statsTraceCtx"); } /** @@ -314,6 +317,9 @@ public class MessageDeframer implements Closeable { } finally { if (totalBytesRead > 0) { listener.bytesRead(totalBytesRead); + if (state == State.BODY) { + statsTraceCtx.wireBytesReceived(totalBytesRead); + } } } } @@ -357,6 +363,7 @@ public class MessageDeframer implements Closeable { } private InputStream getUncompressedBody() { + statsTraceCtx.uncompressedBytesReceived(nextFrame.readableBytes()); return ReadableBuffers.openStream(nextFrame, true); } @@ -370,7 +377,7 @@ public class MessageDeframer implements Closeable { // Enforce the maxMessageSize limit on the returned stream. InputStream unlimitedStream = decompressor.decompress(ReadableBuffers.openStream(nextFrame, true)); - return new SizeEnforcingInputStream(unlimitedStream, maxMessageSize); + return new SizeEnforcingInputStream(unlimitedStream, maxMessageSize, statsTraceCtx); } catch (IOException e) { throw new RuntimeException(e); } @@ -382,12 +389,15 @@ public class MessageDeframer implements Closeable { @VisibleForTesting static final class SizeEnforcingInputStream extends FilterInputStream { private final int maxMessageSize; + private final StatsTraceContext statsTraceCtx; + private long maxCount; private long count; private long mark = -1; - SizeEnforcingInputStream(InputStream in, int maxMessageSize) { + SizeEnforcingInputStream(InputStream in, int maxMessageSize, StatsTraceContext statsTraceCtx) { super(in); this.maxMessageSize = maxMessageSize; + this.statsTraceCtx = statsTraceCtx; } @Override @@ -397,6 +407,7 @@ public class MessageDeframer implements Closeable { count++; } verifySize(); + reportCount(); return result; } @@ -407,6 +418,7 @@ public class MessageDeframer implements Closeable { count += result; } verifySize(); + reportCount(); return result; } @@ -415,6 +427,7 @@ public class MessageDeframer implements Closeable { long result = in.skip(n); count += result; verifySize(); + reportCount(); return result; } @@ -438,6 +451,13 @@ public class MessageDeframer implements Closeable { count = mark; } + private void reportCount() { + if (count > maxCount) { + statsTraceCtx.uncompressedBytesReceived(count - maxCount); + maxCount = count; + } + } + private void verifySize() { if (count > maxMessageSize) { throw Status.INTERNAL.withDescription(String.format( diff --git a/core/src/main/java/io/grpc/internal/MessageFramer.java b/core/src/main/java/io/grpc/internal/MessageFramer.java index 753859b49c..28c022129a 100644 --- a/core/src/main/java/io/grpc/internal/MessageFramer.java +++ b/core/src/main/java/io/grpc/internal/MessageFramer.java @@ -85,6 +85,7 @@ public class MessageFramer { private final OutputStreamAdapter outputStreamAdapter = new OutputStreamAdapter(); private final byte[] headerScratch = new byte[HEADER_LENGTH]; private final WritableBufferAllocator bufferAllocator; + private final StatsTraceContext statsTraceCtx; private boolean closed; /** @@ -93,9 +94,11 @@ public class MessageFramer { * @param sink the sink used to deliver frames to the transport * @param bufferAllocator allocates buffers that the transport can commit to the wire. */ - public MessageFramer(Sink sink, WritableBufferAllocator bufferAllocator) { + public MessageFramer(Sink sink, WritableBufferAllocator bufferAllocator, + StatsTraceContext statsTraceCtx) { this.sink = checkNotNull(sink, "sink"); this.bufferAllocator = checkNotNull(bufferAllocator, "bufferAllocator"); + this.statsTraceCtx = checkNotNull(statsTraceCtx, "statsTraceCtx"); } MessageFramer setCompressor(Compressor compressor) { @@ -142,10 +145,12 @@ public class MessageFramer { String err = String.format("Message length inaccurate %s != %s", written, messageLength); throw Status.INTERNAL.withDescription(err).asRuntimeException(); } + statsTraceCtx.uncompressedBytesSent(written); } private int writeUncompressed(InputStream message, int messageLength) throws IOException { if (messageLength != -1) { + statsTraceCtx.wireBytesSent(messageLength); return writeKnownLengthUncompressed(message, messageLength); } BufferChainOutputStream bufferChain = new BufferChainOutputStream(); @@ -220,6 +225,7 @@ public class MessageFramer { // Assign the current buffer to the last in the chain so it can be used // for future writes or written with end-of-stream=true on close. buffer = bufferList.get(bufferList.size() - 1); + statsTraceCtx.wireBytesSent(messageLength); } private static int writeToOutputStream(InputStream message, OutputStream outputStream) diff --git a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java index b5d9bbcaaf..8990446627 100644 --- a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java +++ b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java @@ -51,6 +51,7 @@ final class MetadataApplierImpl implements MetadataApplier { private final Metadata origHeaders; private final CallOptions callOptions; private final Context ctx; + private final StatsTraceContext statsTraceCtx; private final Object lock = new Object(); @@ -66,12 +67,13 @@ final class MetadataApplierImpl implements MetadataApplier { DelayedStream delayedStream; MetadataApplierImpl(ClientTransport transport, MethodDescriptor method, - Metadata origHeaders, CallOptions callOptions) { + Metadata origHeaders, CallOptions callOptions, StatsTraceContext statsTraceCtx) { this.transport = transport; this.method = method; this.origHeaders = origHeaders; this.callOptions = callOptions; this.ctx = Context.current(); + this.statsTraceCtx = statsTraceCtx; } @Override @@ -82,7 +84,7 @@ final class MetadataApplierImpl implements MetadataApplier { ClientStream realStream; Context origCtx = ctx.attach(); try { - realStream = transport.newStream(method, origHeaders, callOptions); + realStream = transport.newStream(method, origHeaders, callOptions, statsTraceCtx); } finally { ctx.detach(origCtx); } diff --git a/core/src/main/java/io/grpc/internal/NoopCensusContextFactory.java b/core/src/main/java/io/grpc/internal/NoopCensusContextFactory.java new file mode 100644 index 0000000000..05b07f6362 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/NoopCensusContextFactory.java @@ -0,0 +1,90 @@ +/* + * Copyright 2016, Google Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package io.grpc.internal; + +import com.google.census.CensusContext; +import com.google.census.CensusContextFactory; +import com.google.census.MetricMap; +import com.google.census.TagKey; +import com.google.census.TagValue; + +import java.nio.ByteBuffer; + +public final class NoopCensusContextFactory extends CensusContextFactory { + private static final ByteBuffer SERIALIZED_BYTES = ByteBuffer.allocate(0).asReadOnlyBuffer(); + private static final CensusContext DEFAULT_CONTEXT = new NoopCensusContext(); + private static final CensusContext.Builder BUILDER = new NoopContextBuilder(); + + public static final CensusContextFactory INSTANCE = new NoopCensusContextFactory(); + + private NoopCensusContextFactory() { + } + + @Override + public CensusContext deserialize(ByteBuffer buffer) { + return DEFAULT_CONTEXT; + } + + @Override + public CensusContext getDefault() { + return DEFAULT_CONTEXT; + } + + private static class NoopCensusContext extends CensusContext { + @Override + public Builder builder() { + return BUILDER; + } + + @Override + public CensusContext record(MetricMap metrics) { + return DEFAULT_CONTEXT; + } + + @Override + public ByteBuffer serialize() { + return SERIALIZED_BYTES; + } + } + + private static class NoopContextBuilder extends CensusContext.Builder { + @Override + public CensusContext.Builder set(TagKey key, TagValue value) { + return this; + } + + @Override + public CensusContext build() { + return DEFAULT_CONTEXT; + } + } +} diff --git a/core/src/main/java/io/grpc/internal/ServerCallImpl.java b/core/src/main/java/io/grpc/internal/ServerCallImpl.java index 21f92aa727..b5a9b1dfe2 100644 --- a/core/src/main/java/io/grpc/internal/ServerCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ServerCallImpl.java @@ -65,6 +65,7 @@ final class ServerCallImpl extends ServerCall { private final String messageAcceptEncoding; private final DecompressorRegistry decompressorRegistry; private final CompressorRegistry compressorRegistry; + private final StatsTraceContext statsTraceCtx; // state private volatile boolean cancelled; @@ -73,7 +74,7 @@ final class ServerCallImpl extends ServerCall { private Compressor compressor; ServerCallImpl(ServerStream stream, MethodDescriptor method, - Metadata inboundHeaders, Context.CancellableContext context, + Metadata inboundHeaders, Context.CancellableContext context, StatsTraceContext statsTraceCtx, DecompressorRegistry decompressorRegistry, CompressorRegistry compressorRegistry) { this.stream = stream; this.method = method; @@ -81,6 +82,7 @@ final class ServerCallImpl extends ServerCall { this.messageAcceptEncoding = inboundHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY); this.decompressorRegistry = decompressorRegistry; this.compressorRegistry = compressorRegistry; + this.statsTraceCtx = checkNotNull(statsTraceCtx, "statsTraceCtx"); if (inboundHeaders.containsKey(MESSAGE_ENCODING_KEY)) { String encoding = inboundHeaders.get(MESSAGE_ENCODING_KEY); @@ -186,7 +188,7 @@ final class ServerCallImpl extends ServerCall { } ServerStreamListener newServerStreamListener(ServerCall.Listener listener) { - return new ServerStreamListenerImpl(this, listener, context); + return new ServerStreamListenerImpl(this, listener, context, statsTraceCtx); } @Override @@ -208,14 +210,16 @@ final class ServerCallImpl extends ServerCall { private final ServerCallImpl call; private final ServerCall.Listener listener; private final Context.CancellableContext context; + private final StatsTraceContext statsTraceCtx; private boolean messageReceived; public ServerStreamListenerImpl( ServerCallImpl call, ServerCall.Listener listener, - Context.CancellableContext context) { + Context.CancellableContext context, StatsTraceContext statsTraceCtx) { this.call = checkNotNull(call, "call"); this.listener = checkNotNull(listener, "listener must not be null"); this.context = checkNotNull(context, "context"); + this.statsTraceCtx = checkNotNull(statsTraceCtx, "statsTraceCtx"); } @Override @@ -263,6 +267,7 @@ final class ServerCallImpl extends ServerCall { @Override public void closed(Status status) { try { + statsTraceCtx.callEnded(status); if (status.isOk()) { listener.onComplete(); } else { diff --git a/core/src/main/java/io/grpc/internal/ServerImpl.java b/core/src/main/java/io/grpc/internal/ServerImpl.java index b66591f085..8b2365a2bd 100644 --- a/core/src/main/java/io/grpc/internal/ServerImpl.java +++ b/core/src/main/java/io/grpc/internal/ServerImpl.java @@ -39,7 +39,10 @@ import static io.grpc.internal.GrpcUtil.TIMEOUT_KEY; import static io.grpc.internal.GrpcUtil.TIMER_SERVICE; import static java.util.concurrent.TimeUnit.NANOSECONDS; +import com.google.census.CensusContextFactory; import com.google.common.base.Preconditions; +import com.google.common.base.Stopwatch; +import com.google.common.base.Supplier; import io.grpc.Attributes; import io.grpc.CompressorRegistry; @@ -90,6 +93,7 @@ public final class ServerImpl extends io.grpc.Server { private final InternalHandlerRegistry registry; private final HandlerRegistry fallbackRegistry; private final List transportFilters; + private final CensusContextFactory censusFactory; @GuardedBy("lock") private boolean started; @GuardedBy("lock") private boolean shutdown; /** non-{@code null} if immediate shutdown has been requested. */ @@ -110,6 +114,7 @@ public final class ServerImpl extends io.grpc.Server { private final DecompressorRegistry decompressorRegistry; private final CompressorRegistry compressorRegistry; + private final Supplier stopwatchSupplier; /** * Construct a server. @@ -122,7 +127,8 @@ public final class ServerImpl extends io.grpc.Server { ServerImpl(Executor executor, InternalHandlerRegistry registry, HandlerRegistry fallbackRegistry, InternalServer transportServer, Context rootContext, DecompressorRegistry decompressorRegistry, CompressorRegistry compressorRegistry, - List transportFilters) { + List transportFilters, CensusContextFactory censusFactory, + Supplier stopwatchSupplier) { this.executor = executor; this.registry = Preconditions.checkNotNull(registry, "registry"); this.fallbackRegistry = Preconditions.checkNotNull(fallbackRegistry, "fallbackRegistry"); @@ -134,6 +140,8 @@ public final class ServerImpl extends io.grpc.Server { this.compressorRegistry = compressorRegistry; this.transportFilters = Collections.unmodifiableList( new ArrayList(transportFilters)); + this.censusFactory = Preconditions.checkNotNull(censusFactory, "censusFactory"); + this.stopwatchSupplier = Preconditions.checkNotNull(stopwatchSupplier, "stopwatchSupplier"); } /** @@ -347,10 +355,20 @@ public final class ServerImpl extends io.grpc.Server { transportClosed(transport); } + @Override + public StatsTraceContext methodDetermined(String methodName, Metadata headers) { + return StatsTraceContext.newServerContext( + methodName, censusFactory, headers, stopwatchSupplier); + } + @Override public ServerStreamListener streamCreated(final ServerStream stream, final String methodName, final Metadata headers) { - final Context.CancellableContext context = createContext(stream, headers); + + final StatsTraceContext statsTraceCtx = Preconditions.checkNotNull( + stream.statsTraceContext(), "statsTraceCtx not present from stream"); + + final Context.CancellableContext context = createContext(stream, headers, statsTraceCtx); final Executor wrappedExecutor; // This is a performance optimization that avoids the synchronization and queuing overhead // that comes with SerializingExecutor. @@ -375,9 +393,13 @@ public final class ServerImpl extends io.grpc.Server { method = fallbackRegistry.lookupMethod(methodName); } if (method == null) { - stream.close( - Status.UNIMPLEMENTED.withDescription("Method not found: " + methodName), - new Metadata()); + Status status = Status.UNIMPLEMENTED.withDescription( + "Method not found: " + methodName); + stream.close(status, new Metadata()); + // TODO(zhangkun83): this would allow a misbehaving client to blow up the server + // in-memory stats storage by sending large number of distinct unimplemented method + // names. (https://github.com/grpc/grpc-java/issues/2285) + statsTraceCtx.callEnded(status); context.cancel(null); return; } @@ -398,15 +420,19 @@ public final class ServerImpl extends io.grpc.Server { return jumpListener; } - private Context.CancellableContext createContext(final ServerStream stream, Metadata headers) { + private Context.CancellableContext createContext( + final ServerStream stream, Metadata headers, StatsTraceContext statsTraceCtx) { Long timeoutNanos = headers.get(TIMEOUT_KEY); + // TODO(zhangkun83): attach the CensusContext from StatsTraceContext to baseContext + Context baseContext = rootContext; + if (timeoutNanos == null) { - return rootContext.withCancellation(); + return baseContext.withCancellation(); } Context.CancellableContext context = - rootContext.withDeadlineAfter(timeoutNanos, NANOSECONDS, timeoutService); + baseContext.withDeadlineAfter(timeoutNanos, NANOSECONDS, timeoutService); context.addListener(new Context.CancellationListener() { @Override public void cancelled(Context context) { @@ -428,8 +454,8 @@ public final class ServerImpl extends io.grpc.Server { Context.CancellableContext context) { // TODO(ejona86): should we update fullMethodName to have the canonical path of the method? ServerCallImpl call = new ServerCallImpl( - stream, methodDef.getMethodDescriptor(), headers, context, decompressorRegistry, - compressorRegistry); + stream, methodDef.getMethodDescriptor(), headers, context, stream.statsTraceContext(), + decompressorRegistry, compressorRegistry); ServerCall.Listener listener = methodDef.getServerCallHandler().startCall(call, headers); if (listener == null) { diff --git a/core/src/main/java/io/grpc/internal/ServerStream.java b/core/src/main/java/io/grpc/internal/ServerStream.java index 050e1063c1..435e7b50cc 100644 --- a/core/src/main/java/io/grpc/internal/ServerStream.java +++ b/core/src/main/java/io/grpc/internal/ServerStream.java @@ -75,4 +75,9 @@ public interface ServerStream extends Stream { * @return Attributes container */ Attributes attributes(); + + /** + * The context for recording stats and traces for this stream. + */ + StatsTraceContext statsTraceContext(); } diff --git a/core/src/main/java/io/grpc/internal/ServerTransportListener.java b/core/src/main/java/io/grpc/internal/ServerTransportListener.java index eecf404742..bc5fe817a6 100644 --- a/core/src/main/java/io/grpc/internal/ServerTransportListener.java +++ b/core/src/main/java/io/grpc/internal/ServerTransportListener.java @@ -40,6 +40,14 @@ import io.grpc.Metadata; */ public interface ServerTransportListener { + /** + * Called when the method name for a new stream has been determined, which happens before the + * stream is actually created and {@link #streamCreated} is called. + * + * @return a context object for recording stats and tracing for the new stream. + */ + StatsTraceContext methodDetermined(String methodName, Metadata headers); + /** * Called when a new stream was created by the remote client. * diff --git a/core/src/main/java/io/grpc/internal/SingleTransportChannel.java b/core/src/main/java/io/grpc/internal/SingleTransportChannel.java index 0148e76578..19ba590430 100644 --- a/core/src/main/java/io/grpc/internal/SingleTransportChannel.java +++ b/core/src/main/java/io/grpc/internal/SingleTransportChannel.java @@ -31,7 +31,10 @@ package io.grpc.internal; +import com.google.census.CensusContextFactory; import com.google.common.base.Preconditions; +import com.google.common.base.Stopwatch; +import com.google.common.base.Supplier; import io.grpc.CallOptions; import io.grpc.Channel; @@ -47,10 +50,12 @@ import java.util.concurrent.ScheduledExecutorService; */ final class SingleTransportChannel extends Channel { + private final CensusContextFactory censusFactory; private final ClientTransport transport; private final Executor executor; private final String authority; private final ScheduledExecutorService deadlineCancellationExecutor; + private final Supplier stopwatchSupplier; private final ClientTransportProvider transportProvider = new ClientTransportProvider() { @Override @@ -62,20 +67,25 @@ final class SingleTransportChannel extends Channel { /** * Creates a new channel with a connected transport. */ - public SingleTransportChannel(ClientTransport transport, Executor executor, - ScheduledExecutorService deadlineCancellationExecutor, String authority) { + public SingleTransportChannel(CensusContextFactory censusFactory, ClientTransport transport, + Executor executor, ScheduledExecutorService deadlineCancellationExecutor, String authority, + Supplier stopwatchSupplier) { + this.censusFactory = Preconditions.checkNotNull(censusFactory, "censusFactory"); this.transport = Preconditions.checkNotNull(transport, "transport"); this.executor = Preconditions.checkNotNull(executor, "executor"); this.deadlineCancellationExecutor = Preconditions.checkNotNull( deadlineCancellationExecutor, "deadlineCancellationExecutor"); this.authority = Preconditions.checkNotNull(authority, "authority"); + this.stopwatchSupplier = Preconditions.checkNotNull(stopwatchSupplier, "stopwatchSupplier"); } @Override public ClientCall newCall( MethodDescriptor methodDescriptor, CallOptions callOptions) { + StatsTraceContext statsTraceCtx = StatsTraceContext.newClientContext( + methodDescriptor.getFullMethodName(), censusFactory, stopwatchSupplier); return new ClientCallImpl(methodDescriptor, - new SerializingExecutor(executor), callOptions, transportProvider, + new SerializingExecutor(executor), callOptions, statsTraceCtx, transportProvider, deadlineCancellationExecutor); } diff --git a/core/src/main/java/io/grpc/internal/StatsTraceContext.java b/core/src/main/java/io/grpc/internal/StatsTraceContext.java new file mode 100644 index 0000000000..2705a37856 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/StatsTraceContext.java @@ -0,0 +1,235 @@ +/* + * Copyright 2016, Google Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package io.grpc.internal; + +import static com.google.common.base.Preconditions.checkState; + +import com.google.census.CensusContext; +import com.google.census.CensusContextFactory; +import com.google.census.MetricMap; +import com.google.census.MetricName; +import com.google.census.RpcConstants; +import com.google.census.TagKey; +import com.google.census.TagValue; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Stopwatch; +import com.google.common.base.Supplier; + +import io.grpc.Metadata; +import io.grpc.Status; + +import java.nio.ByteBuffer; +import java.util.concurrent.TimeUnit; + +/** + * The stats and tracing information for a call. + */ +public final class StatsTraceContext { + public static final StatsTraceContext NOOP = StatsTraceContext.newClientContext( + "noopservice/noopmethod", NoopCensusContextFactory.INSTANCE, + GrpcUtil.STOPWATCH_SUPPLIER); + + private enum Side { + CLIENT, SERVER + } + + private final CensusContext censusCtx; + private final Stopwatch stopwatch; + private final Side side; + private final Metadata.Key censusHeader; + private long wireBytesSent; + private long wireBytesReceived; + private long uncompressedBytesSent; + private long uncompressedBytesReceived; + private boolean callEnded; + + private StatsTraceContext(Side side, String fullMethodName, CensusContext parentCtx, + Supplier stopwatchSupplier, Metadata.Key censusHeader) { + this.side = side; + TagKey methodTagKey = + side == Side.CLIENT ? RpcConstants.RPC_CLIENT_METHOD : RpcConstants.RPC_SERVER_METHOD; + // TODO(carl-mastrangelo): maybe cache TagValue in MethodDescriptor + this.censusCtx = parentCtx.with(methodTagKey, new TagValue(fullMethodName)); + this.stopwatch = stopwatchSupplier.get().start(); + this.censusHeader = censusHeader; + } + + /** + * Creates a {@code StatsTraceContext} for an outgoing RPC, using the current CensusContext. + * + *

The current time is used as the start time of the RPC. + */ + public static StatsTraceContext newClientContext(String methodName, + CensusContextFactory censusFactory, Supplier stopwatchSupplier) { + return new StatsTraceContext(Side.CLIENT, methodName, + // TODO(zhangkun83): use the CensusContext out of the current Context + censusFactory.getDefault(), + stopwatchSupplier, createCensusHeader(censusFactory)); + } + + @VisibleForTesting + static StatsTraceContext newClientContextForTesting(String methodName, + CensusContextFactory censusFactory, CensusContext parent, + Supplier stopwatchSupplier) { + return new StatsTraceContext(Side.CLIENT, methodName, parent, stopwatchSupplier, + createCensusHeader(censusFactory)); + } + + /** + * Creates a {@code StatsTraceContext} for an incoming RPC, using the CensusContext deserialized + * from the headers. + * + *

The current time is used as the start time of the RPC. + */ + public static StatsTraceContext newServerContext(String methodName, + CensusContextFactory censusFactory, Metadata headers, + Supplier stopwatchSupplier) { + Metadata.Key censusHeader = createCensusHeader(censusFactory); + CensusContext parentCtx = headers.get(censusHeader); + if (parentCtx == null) { + parentCtx = censusFactory.getDefault(); + } + return new StatsTraceContext(Side.SERVER, methodName, parentCtx, stopwatchSupplier, + censusHeader); + } + + /** + * Propagate the context to the outgoing headers. + */ + void propagateToHeaders(Metadata headers) { + headers.discardAll(censusHeader); + headers.put(censusHeader, censusCtx); + } + + Metadata.Key getCensusHeader() { + return censusHeader; + } + + @VisibleForTesting + CensusContext getCensusContext() { + return censusCtx; + } + + @VisibleForTesting + static Metadata.Key createCensusHeader( + final CensusContextFactory censusCtxFactory) { + return Metadata.Key.of("grpc-census-bin", new Metadata.BinaryMarshaller() { + @Override + public byte[] toBytes(CensusContext context) { + ByteBuffer buffer = context.serialize(); + // TODO(carl-mastrangelo): currently we only make sure the correctness. We may need to + // optimize out the allocation and copy in the future. + byte[] bytes = new byte[buffer.remaining()]; + buffer.get(bytes); + return bytes; + } + + @Override + public CensusContext parseBytes(byte[] serialized) { + return censusCtxFactory.deserialize(ByteBuffer.wrap(serialized)); + } + }); + } + + /** + * Record the outgoing number of payload bytes as on the wire. + */ + void wireBytesSent(long bytes) { + // TODO(zhangkun83): maybe change of the checkState() to assert after this class is stabilized. + checkState(!callEnded, "already eneded"); + wireBytesSent += bytes; + } + + /** + * Record the incoming number of payload bytes as on the wire. + */ + void wireBytesReceived(long bytes) { + checkState(!callEnded, "already eneded"); + wireBytesReceived += bytes; + } + + /** + * Record the outgoing number of payload bytes in uncompressed form. + * + *

The time this method is called is unrelated to the actual time when those byte are sent. + */ + void uncompressedBytesSent(long bytes) { + checkState(!callEnded, "already ended"); + uncompressedBytesSent += bytes; + } + + /** + * Record the incoming number of payload bytes in uncompressed form. + * + *

The time this method is called is unrelated to the actual time when those byte are received. + */ + void uncompressedBytesReceived(long bytes) { + checkState(!callEnded, "already ended"); + uncompressedBytesReceived += bytes; + } + + /** + * Record a finished all and mark the current time as the end time. + */ + void callEnded(Status status) { + checkState(!callEnded, "already ended"); + callEnded = true; + stopwatch.stop(); + MetricName latencyMetric; + MetricName wireBytesSentMetric; + MetricName wireBytesReceivedMetric; + MetricName uncompressedBytesSentMetric; + MetricName uncompressedBytesReceivedMetric; + if (side == Side.CLIENT) { + latencyMetric = RpcConstants.RPC_CLIENT_ROUNDTRIP_LATENCY; + wireBytesSentMetric = RpcConstants.RPC_CLIENT_REQUEST_BYTES; + wireBytesReceivedMetric = RpcConstants.RPC_CLIENT_RESPONSE_BYTES; + uncompressedBytesSentMetric = RpcConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES; + uncompressedBytesReceivedMetric = RpcConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES; + } else { + latencyMetric = RpcConstants.RPC_SERVER_SERVER_LATENCY; + wireBytesSentMetric = RpcConstants.RPC_SERVER_RESPONSE_BYTES; + wireBytesReceivedMetric = RpcConstants.RPC_SERVER_REQUEST_BYTES; + uncompressedBytesSentMetric = RpcConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES; + uncompressedBytesReceivedMetric = RpcConstants.RPC_SERVER_UNCOMPRESSED_REQUEST_BYTES; + } + censusCtx + .with(RpcConstants.RPC_STATUS, new TagValue(status.getCode().toString())) + .record(MetricMap.builder() + .put(latencyMetric, stopwatch.elapsed(TimeUnit.MILLISECONDS)) + .put(wireBytesSentMetric, wireBytesSent) + .put(wireBytesReceivedMetric, wireBytesReceived) + .put(uncompressedBytesSentMetric, uncompressedBytesSent) + .put(uncompressedBytesReceivedMetric, uncompressedBytesReceived) + .build()); + } +} diff --git a/core/src/main/java/io/grpc/internal/TransportSet.java b/core/src/main/java/io/grpc/internal/TransportSet.java index bd8098e66f..dfbdf82b11 100644 --- a/core/src/main/java/io/grpc/internal/TransportSet.java +++ b/core/src/main/java/io/grpc/internal/TransportSet.java @@ -360,7 +360,7 @@ final class TransportSet extends ManagedChannel implements WithLogId { public final ClientCall newCall( MethodDescriptor methodDescriptor, CallOptions callOptions) { return new ClientCallImpl(methodDescriptor, - new SerializingExecutor(appExecutor), callOptions, + new SerializingExecutor(appExecutor), callOptions, StatsTraceContext.NOOP, new ClientTransportProvider() { @Override public ClientTransport get(CallOptions callOptions) { diff --git a/core/src/test/java/io/grpc/internal/AbstractClientStream2Test.java b/core/src/test/java/io/grpc/internal/AbstractClientStream2Test.java index dca32686d4..9ba9fc2650 100644 --- a/core/src/test/java/io/grpc/internal/AbstractClientStream2Test.java +++ b/core/src/test/java/io/grpc/internal/AbstractClientStream2Test.java @@ -63,6 +63,7 @@ public class AbstractClientStream2Test { @Rule public final ExpectedException thrown = ExpectedException.none(); + private final StatsTraceContext statsTraceCtx = StatsTraceContext.NOOP; @Mock private ClientStreamListener mockListener; @Captor private ArgumentCaptor statusCaptor; @@ -82,7 +83,7 @@ public class AbstractClientStream2Test { public void cancel_doNotAcceptOk() { for (Code code : Code.values()) { ClientStreamListener listener = new NoopClientStreamListener(); - AbstractClientStream2 stream = new BaseAbstractClientStream(allocator); + AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, statsTraceCtx); stream.start(listener); if (code != Code.OK) { stream.cancel(Status.fromCodeValue(code.value())); @@ -100,7 +101,7 @@ public class AbstractClientStream2Test { @Test public void cancel_failsOnNull() { ClientStreamListener listener = new NoopClientStreamListener(); - AbstractClientStream2 stream = new BaseAbstractClientStream(allocator); + AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, statsTraceCtx); stream.start(listener); thrown.expect(NullPointerException.class); @@ -109,14 +110,14 @@ public class AbstractClientStream2Test { @Test public void cancel_notifiesOnlyOnce() { - final BaseTransportState state = new BaseTransportState(); + final BaseTransportState state = new BaseTransportState(statsTraceCtx); AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, state, new BaseSink() { @Override public void cancel(Status errorStatus) { // Cancel should eventually result in a transportReportStatus on the transport thread state.transportReportStatus(errorStatus, true/*stop delivery*/, new Metadata()); } - }); + }, statsTraceCtx); stream.start(mockListener); stream.cancel(Status.DEADLINE_EXCEEDED); @@ -127,7 +128,7 @@ public class AbstractClientStream2Test { @Test public void startFailsOnNullListener() { - AbstractClientStream2 stream = new BaseAbstractClientStream(allocator); + AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, statsTraceCtx); thrown.expect(NullPointerException.class); @@ -136,7 +137,7 @@ public class AbstractClientStream2Test { @Test public void cantCallStartTwice() { - AbstractClientStream2 stream = new BaseAbstractClientStream(allocator); + AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, statsTraceCtx); stream.start(mockListener); thrown.expect(IllegalStateException.class); @@ -146,7 +147,7 @@ public class AbstractClientStream2Test { @Test public void inboundDataReceived_failsOnNullFrame() { ClientStreamListener listener = new NoopClientStreamListener(); - AbstractClientStream2 stream = new BaseAbstractClientStream(allocator); + AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, statsTraceCtx); stream.start(listener); thrown.expect(NullPointerException.class); @@ -155,7 +156,7 @@ public class AbstractClientStream2Test { @Test public void inboundDataReceived_failsOnNoHeaders() { - AbstractClientStream2 stream = new BaseAbstractClientStream(allocator); + AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, statsTraceCtx); stream.start(mockListener); stream.transportState().inboundDataReceived(ReadableBuffers.empty()); @@ -166,7 +167,7 @@ public class AbstractClientStream2Test { @Test public void inboundHeadersReceived_notifiesListener() { - AbstractClientStream2 stream = new BaseAbstractClientStream(allocator); + AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, statsTraceCtx); stream.start(mockListener); Metadata headers = new Metadata(); @@ -176,7 +177,7 @@ public class AbstractClientStream2Test { @Test public void inboundHeadersReceived_failsIfStatusReported() { - AbstractClientStream2 stream = new BaseAbstractClientStream(allocator); + AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, statsTraceCtx); stream.start(mockListener); stream.transportState().transportReportStatus(Status.CANCELLED, false, new Metadata()); @@ -186,7 +187,7 @@ public class AbstractClientStream2Test { @Test public void inboundHeadersReceived_acceptsGzipEncoding() { - AbstractClientStream2 stream = new BaseAbstractClientStream(allocator); + AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, statsTraceCtx); stream.start(mockListener); Metadata headers = new Metadata(); headers.put(GrpcUtil.MESSAGE_ENCODING_KEY, new Codec.Gzip().getMessageEncoding()); @@ -197,7 +198,7 @@ public class AbstractClientStream2Test { @Test public void inboundHeadersReceived_acceptsIdentityEncoding() { - AbstractClientStream2 stream = new BaseAbstractClientStream(allocator); + AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, statsTraceCtx); stream.start(mockListener); Metadata headers = new Metadata(); headers.put(GrpcUtil.MESSAGE_ENCODING_KEY, Codec.Identity.NONE.getMessageEncoding()); @@ -208,7 +209,7 @@ public class AbstractClientStream2Test { @Test public void rstStreamClosesStream() { - AbstractClientStream2 stream = new BaseAbstractClientStream(allocator); + AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, statsTraceCtx); stream.start(mockListener); // The application will call request when waiting for a message, which will in turn call this // on the transport thread. @@ -229,13 +230,14 @@ public class AbstractClientStream2Test { private final TransportState state; private final Sink sink; - public BaseAbstractClientStream(WritableBufferAllocator allocator) { - this(allocator, new BaseTransportState(), new BaseSink()); + public BaseAbstractClientStream(WritableBufferAllocator allocator, + StatsTraceContext statsTraceCtx) { + this(allocator, new BaseTransportState(statsTraceCtx), new BaseSink(), statsTraceCtx); } public BaseAbstractClientStream(WritableBufferAllocator allocator, TransportState state, - Sink sink) { - super(allocator); + Sink sink, StatsTraceContext statsTraceCtx) { + super(allocator, statsTraceCtx); this.state = state; this.sink = sink; } @@ -266,8 +268,8 @@ public class AbstractClientStream2Test { } private static class BaseTransportState extends AbstractClientStream2.TransportState { - public BaseTransportState() { - super(DEFAULT_MAX_MESSAGE_SIZE); + public BaseTransportState(StatsTraceContext statsTraceCtx) { + super(DEFAULT_MAX_MESSAGE_SIZE, statsTraceCtx); } @Override diff --git a/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java b/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java index 0bed491be5..d2511c5bbd 100644 --- a/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java @@ -257,7 +257,7 @@ public class AbstractClientStreamTest { */ private static class BaseAbstractClientStream extends AbstractClientStream { protected BaseAbstractClientStream(WritableBufferAllocator allocator) { - super(allocator, DEFAULT_MAX_MESSAGE_SIZE); + super(allocator, DEFAULT_MAX_MESSAGE_SIZE, StatsTraceContext.NOOP); } @Override diff --git a/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java b/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java index ff98336ee5..64986a987d 100644 --- a/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java @@ -240,7 +240,7 @@ public class AbstractServerStreamTest { protected AbstractServerStreamBase(WritableBufferAllocator bufferAllocator, Sink sink, AbstractServerStream.TransportState state) { - super(bufferAllocator); + super(bufferAllocator, StatsTraceContext.NOOP); this.sink = sink; this.state = state; } @@ -257,7 +257,7 @@ public class AbstractServerStreamTest { static class TransportState extends AbstractServerStream.TransportState { protected TransportState(int maxMessageSize) { - super(maxMessageSize); + super(maxMessageSize, StatsTraceContext.NOOP); } @Override diff --git a/core/src/test/java/io/grpc/internal/AbstractStreamTest.java b/core/src/test/java/io/grpc/internal/AbstractStreamTest.java index 7daa32ec8d..1336cb46a4 100644 --- a/core/src/test/java/io/grpc/internal/AbstractStreamTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractStreamTest.java @@ -120,7 +120,7 @@ public class AbstractStreamTest { */ private class AbstractStreamBase extends AbstractStream { private AbstractStreamBase(WritableBufferAllocator bufferAllocator) { - super(allocator, DEFAULT_MAX_MESSAGE_SIZE); + super(allocator, DEFAULT_MAX_MESSAGE_SIZE, StatsTraceContext.NOOP); } private AbstractStreamBase(MessageFramer framer, MessageDeframer deframer) { diff --git a/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java b/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java index 8a0f4f1085..19424926fd 100644 --- a/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java +++ b/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java @@ -105,6 +105,8 @@ public class CallCredentialsApplyingTest { private static final String CREDS_VALUE = "some credentials"; private final Metadata origHeaders = new Metadata(); + private final StatsTraceContext statsTraceCtx = StatsTraceContext.newClientContext( + method.getFullMethodName(), NoopCensusContextFactory.INSTANCE, GrpcUtil.STOPWATCH_SUPPLIER); private ForwardingConnectionClientTransport transport; private CallOptions callOptions; @@ -114,7 +116,8 @@ public class CallCredentialsApplyingTest { origHeaders.put(ORIG_HEADER_KEY, ORIG_HEADER_VALUE); when(mockTransportFactory.newClientTransport(address, AUTHORITY, USER_AGENT)) .thenReturn(mockTransport); - when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) + when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class), + any(StatsTraceContext.class))) .thenReturn(mockStream); ClientTransportFactory transportFactory = new CallCredentialsApplyingTransportFactory( mockTransportFactory, mockExecutor); @@ -130,7 +133,7 @@ public class CallCredentialsApplyingTest { Attributes transportAttrs = Attributes.newBuilder().set(ATTR_KEY, ATTR_VALUE).build(); when(mockTransport.getAttrs()).thenReturn(transportAttrs); - transport.newStream(method, origHeaders, callOptions); + transport.newStream(method, origHeaders, callOptions, statsTraceCtx); ArgumentCaptor attrsCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata(same(method), attrsCaptor.capture(), same(mockExecutor), @@ -150,7 +153,7 @@ public class CallCredentialsApplyingTest { .build(); when(mockTransport.getAttrs()).thenReturn(transportAttrs); - transport.newStream(method, origHeaders, callOptions); + transport.newStream(method, origHeaders, callOptions, statsTraceCtx); ArgumentCaptor attrsCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata(same(method), attrsCaptor.capture(), same(mockExecutor), @@ -172,7 +175,8 @@ public class CallCredentialsApplyingTest { Executor anotherExecutor = mock(Executor.class); transport.newStream(method, origHeaders, - callOptions.withAuthority("calloptions-authority").withExecutor(anotherExecutor)); + callOptions.withAuthority("calloptions-authority").withExecutor(anotherExecutor), + statsTraceCtx); ArgumentCaptor attrsCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata(same(method), attrsCaptor.capture(), @@ -198,9 +202,9 @@ public class CallCredentialsApplyingTest { }).when(mockCreds).applyRequestMetadata(same(method), any(Attributes.class), same(mockExecutor), any(MetadataApplier.class)); - ClientStream stream = transport.newStream(method, origHeaders, callOptions); + ClientStream stream = transport.newStream(method, origHeaders, callOptions, statsTraceCtx); - verify(mockTransport).newStream(method, origHeaders, callOptions); + verify(mockTransport).newStream(method, origHeaders, callOptions, statsTraceCtx); assertSame(mockStream, stream); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); @@ -221,9 +225,9 @@ public class CallCredentialsApplyingTest { same(mockExecutor), any(MetadataApplier.class)); FailingClientStream stream = - (FailingClientStream) transport.newStream(method, origHeaders, callOptions); + (FailingClientStream) transport.newStream(method, origHeaders, callOptions, statsTraceCtx); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream(method, origHeaders, callOptions, statsTraceCtx); assertSame(error, stream.getError()); } @@ -232,18 +236,19 @@ public class CallCredentialsApplyingTest { when(mockTransport.getAttrs()).thenReturn(Attributes.EMPTY); // Will call applyRequestMetadata(), which is no-op. - DelayedStream stream = (DelayedStream) transport.newStream(method, origHeaders, callOptions); + DelayedStream stream = (DelayedStream) transport.newStream(method, origHeaders, callOptions, + statsTraceCtx); ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata(same(method), any(Attributes.class), same(mockExecutor), applierCaptor.capture()); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream(method, origHeaders, callOptions, statsTraceCtx); Metadata headers = new Metadata(); headers.put(CREDS_KEY, CREDS_VALUE); applierCaptor.getValue().apply(headers); - verify(mockTransport).newStream(method, origHeaders, callOptions); + verify(mockTransport).newStream(method, origHeaders, callOptions, statsTraceCtx); assertSame(mockStream, stream.getRealStream()); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); @@ -254,7 +259,8 @@ public class CallCredentialsApplyingTest { when(mockTransport.getAttrs()).thenReturn(Attributes.EMPTY); // Will call applyRequestMetadata(), which is no-op. - DelayedStream stream = (DelayedStream) transport.newStream(method, origHeaders, callOptions); + DelayedStream stream = (DelayedStream) transport.newStream(method, origHeaders, callOptions, + statsTraceCtx); ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata(same(method), any(Attributes.class), @@ -263,7 +269,7 @@ public class CallCredentialsApplyingTest { Status error = Status.FAILED_PRECONDITION.withDescription("channel not secure for creds"); applierCaptor.getValue().fail(error); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream(method, origHeaders, callOptions, statsTraceCtx); FailingClientStream failingStream = (FailingClientStream) stream.getRealStream(); assertSame(error, failingStream.getError()); } @@ -271,9 +277,9 @@ public class CallCredentialsApplyingTest { @Test public void noCreds() { callOptions = callOptions.withCallCredentials(null); - ClientStream stream = transport.newStream(method, origHeaders, callOptions); + ClientStream stream = transport.newStream(method, origHeaders, callOptions, statsTraceCtx); - verify(mockTransport).newStream(method, origHeaders, callOptions); + verify(mockTransport).newStream(method, origHeaders, callOptions, statsTraceCtx); assertSame(mockStream, stream); assertNull(origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); diff --git a/core/src/test/java/io/grpc/internal/ClientCallImplTest.java b/core/src/test/java/io/grpc/internal/ClientCallImplTest.java index c021c39460..646c58ae9e 100644 --- a/core/src/test/java/io/grpc/internal/ClientCallImplTest.java +++ b/core/src/test/java/io/grpc/internal/ClientCallImplTest.java @@ -51,6 +51,9 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; +import com.google.census.CensusContext; +import com.google.census.RpcConstants; +import com.google.census.TagValue; import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.SettableFuture; @@ -68,6 +71,8 @@ import io.grpc.MethodDescriptor.Marshaller; import io.grpc.MethodDescriptor.MethodType; import io.grpc.Status; import io.grpc.internal.ClientCallImpl.ClientTransportProvider; +import io.grpc.internal.testing.CensusTestUtils.FakeCensusContextFactory; +import io.grpc.internal.testing.CensusTestUtils; import org.junit.After; import org.junit.Before; @@ -116,6 +121,14 @@ public class ClientCallImplTest { new TestMarshaller(), new TestMarshaller()); + private final FakeCensusContextFactory censusCtxFactory = new FakeCensusContextFactory(); + private final CensusContext parentCensusContext = censusCtxFactory.getDefault().with( + CensusTestUtils.EXTRA_TAG, new TagValue("extra-tag-value")); + private final StatsTraceContext statsTraceCtx = StatsTraceContext.newClientContextForTesting( + method.getFullMethodName(), censusCtxFactory, parentCensusContext, + fakeClock.getStopwatchSupplier()); + private final CensusContext censusCtx = censusCtxFactory.contexts.poll(); + @Mock private ClientStreamListener streamListener; @Mock private ClientTransport clientTransport; @Captor private ArgumentCaptor statusCaptor; @@ -141,9 +154,10 @@ public class ClientCallImplTest { @Before public void setUp() { MockitoAnnotations.initMocks(this); + assertNotNull(censusCtx); when(provider.get(any(CallOptions.class))).thenReturn(transport); when(transport.newStream(any(MethodDescriptor.class), any(Metadata.class), - any(CallOptions.class))).thenReturn(stream); + any(CallOptions.class), any(StatsTraceContext.class))).thenReturn(stream); } @After @@ -151,6 +165,29 @@ public class ClientCallImplTest { Context.ROOT.attach(); } + @Test + public void statusPropagatedFromStreamToCallListener() { + DelayedExecutor executor = new DelayedExecutor(); + ClientCallImpl call = new ClientCallImpl( + method, + executor, + CallOptions.DEFAULT, + statsTraceCtx, + provider, + deadlineCancellationExecutor); + call.start(callListener, new Metadata()); + verify(stream).start(listenerArgumentCaptor.capture()); + final ClientStreamListener streamListener = listenerArgumentCaptor.getValue(); + streamListener.headersRead(new Metadata()); + Status status = Status.RESOURCE_EXHAUSTED.withDescription("simulated"); + streamListener.closed(status , new Metadata()); + executor.release(); + + verify(callListener).onClose(statusArgumentCaptor.capture(), Matchers.isA(Metadata.class)); + assertThat(statusArgumentCaptor.getValue()).isSameAs(status); + assertStatusInStats(status.getCode()); + } + @Test public void exceptionInOnMessageTakesPrecedenceOverServer() { DelayedExecutor executor = new DelayedExecutor(); @@ -158,6 +195,7 @@ public class ClientCallImplTest { method, executor, CallOptions.DEFAULT, + statsTraceCtx, provider, deadlineCancellationExecutor); call.start(callListener, new Metadata()); @@ -182,6 +220,7 @@ public class ClientCallImplTest { assertThat(statusArgumentCaptor.getValue().getCode()).isEqualTo(Status.Code.CANCELLED); assertThat(statusArgumentCaptor.getValue().getCause()).isSameAs(failure); verify(stream).cancel(statusArgumentCaptor.getValue()); + assertStatusInStats(Status.Code.CANCELLED); } @Test @@ -191,6 +230,7 @@ public class ClientCallImplTest { method, executor, CallOptions.DEFAULT, + statsTraceCtx, provider, deadlineCancellationExecutor); call.start(callListener, new Metadata()); @@ -214,6 +254,7 @@ public class ClientCallImplTest { assertThat(statusArgumentCaptor.getValue().getCode()).isEqualTo(Status.Code.CANCELLED); assertThat(statusArgumentCaptor.getValue().getCause()).isSameAs(failure); verify(stream).cancel(statusArgumentCaptor.getValue()); + assertStatusInStats(Status.Code.CANCELLED); } @Test @@ -223,6 +264,7 @@ public class ClientCallImplTest { method, executor, CallOptions.DEFAULT, + statsTraceCtx, provider, deadlineCancellationExecutor); call.start(callListener, new Metadata()); @@ -246,6 +288,7 @@ public class ClientCallImplTest { assertThat(statusArgumentCaptor.getValue().getCode()).isEqualTo(Status.Code.CANCELLED); assertThat(statusArgumentCaptor.getValue().getCause()).isSameAs(failure); verify(stream).cancel(statusArgumentCaptor.getValue()); + assertStatusInStats(Status.Code.CANCELLED); } @Test @@ -254,6 +297,7 @@ public class ClientCallImplTest { method, MoreExecutors.directExecutor(), CallOptions.DEFAULT, + statsTraceCtx, provider, deadlineCancellationExecutor) .setDecompressorRegistry(decompressorRegistry); @@ -261,7 +305,8 @@ public class ClientCallImplTest { call.start(callListener, new Metadata()); ArgumentCaptor metadataCaptor = ArgumentCaptor.forClass(Metadata.class); - verify(transport).newStream(eq(method), metadataCaptor.capture(), same(CallOptions.DEFAULT)); + verify(transport).newStream(eq(method), metadataCaptor.capture(), same(CallOptions.DEFAULT), + same(statsTraceCtx)); Metadata actual = metadataCaptor.getValue(); Set acceptedEncodings = @@ -275,6 +320,7 @@ public class ClientCallImplTest { method, MoreExecutors.directExecutor(), CallOptions.DEFAULT.withAuthority("overridden-authority"), + statsTraceCtx, provider, deadlineCancellationExecutor) .setDecompressorRegistry(decompressorRegistry); @@ -290,6 +336,7 @@ public class ClientCallImplTest { method, MoreExecutors.directExecutor(), callOptions, + statsTraceCtx, provider, deadlineCancellationExecutor) .setDecompressorRegistry(decompressorRegistry); @@ -297,7 +344,8 @@ public class ClientCallImplTest { call.start(callListener, metadata); - verify(transport).newStream(same(method), same(metadata), same(callOptions)); + verify(transport).newStream(same(method), same(metadata), same(callOptions), + same(statsTraceCtx)); } @Test @@ -307,6 +355,7 @@ public class ClientCallImplTest { MoreExecutors.directExecutor(), // Don't provide an authority CallOptions.DEFAULT, + statsTraceCtx, provider, deadlineCancellationExecutor) .setDecompressorRegistry(decompressorRegistry); @@ -319,7 +368,7 @@ public class ClientCallImplTest { public void prepareHeaders_userAgentIgnored() { Metadata m = new Metadata(); m.put(GrpcUtil.USER_AGENT_KEY, "batmobile"); - ClientCallImpl.prepareHeaders(m, decompressorRegistry, Codec.Identity.NONE); + ClientCallImpl.prepareHeaders(m, decompressorRegistry, Codec.Identity.NONE, statsTraceCtx); // User Agent is removed and set by the transport assertThat(m.get(GrpcUtil.USER_AGENT_KEY)).isNotNull(); @@ -328,7 +377,7 @@ public class ClientCallImplTest { @Test public void prepareHeaders_ignoreIdentityEncoding() { Metadata m = new Metadata(); - ClientCallImpl.prepareHeaders(m, decompressorRegistry, Codec.Identity.NONE); + ClientCallImpl.prepareHeaders(m, decompressorRegistry, Codec.Identity.NONE, statsTraceCtx); assertNull(m.get(GrpcUtil.MESSAGE_ENCODING_KEY)); } @@ -371,7 +420,7 @@ public class ClientCallImplTest { } }, false); // not advertised - ClientCallImpl.prepareHeaders(m, customRegistry, Codec.Identity.NONE); + ClientCallImpl.prepareHeaders(m, customRegistry, Codec.Identity.NONE, statsTraceCtx); Iterable acceptedEncodings = ACCEPT_ENCODING_SPLITER.split(m.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY)); @@ -386,12 +435,20 @@ public class ClientCallImplTest { m.put(GrpcUtil.MESSAGE_ENCODING_KEY, "gzip"); m.put(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY, "gzip"); - ClientCallImpl.prepareHeaders(m, DecompressorRegistry.emptyInstance(), Codec.Identity.NONE); + ClientCallImpl.prepareHeaders(m, DecompressorRegistry.emptyInstance(), Codec.Identity.NONE, + statsTraceCtx); assertNull(m.get(GrpcUtil.MESSAGE_ENCODING_KEY)); assertNull(m.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY)); } + @Test + public void prepareHeaders_censusCtxAdded() { + Metadata m = new Metadata(); + ClientCallImpl.prepareHeaders(m, decompressorRegistry, Codec.Identity.NONE, statsTraceCtx); + assertEquals(parentCensusContext, m.get(statsTraceCtx.getCensusHeader())); + } + @Test public void callerContextPropagatedToListener() throws Exception { // Attach the context which is recorded when the call is created @@ -402,6 +459,7 @@ public class ClientCallImplTest { DESCRIPTOR, new SerializingExecutor(Executors.newSingleThreadExecutor()), CallOptions.DEFAULT, + statsTraceCtx, provider, deadlineCancellationExecutor) .setDecompressorRegistry(decompressorRegistry); @@ -475,6 +533,7 @@ public class ClientCallImplTest { DESCRIPTOR, new SerializingExecutor(Executors.newSingleThreadExecutor()), CallOptions.DEFAULT, + statsTraceCtx, provider, deadlineCancellationExecutor) .setDecompressorRegistry(decompressorRegistry); @@ -504,6 +563,7 @@ public class ClientCallImplTest { DESCRIPTOR, new SerializingExecutor(Executors.newSingleThreadExecutor()), CallOptions.DEFAULT, + statsTraceCtx, provider, deadlineCancellationExecutor) .setDecompressorRegistry(decompressorRegistry); @@ -522,6 +582,7 @@ public class ClientCallImplTest { Status status = statusFuture.get(5, TimeUnit.SECONDS); assertEquals(Status.Code.CANCELLED, status.getCode()); assertSame(cause, status.getCause()); + assertStatusInStats(Status.Code.CANCELLED); // Following operations should be no-op. call.request(1); @@ -547,6 +608,7 @@ public class ClientCallImplTest { DESCRIPTOR, new SerializingExecutor(Executors.newSingleThreadExecutor()), callOptions, + statsTraceCtx, provider, deadlineCancellationExecutor) .setDecompressorRegistry(decompressorRegistry); @@ -554,6 +616,7 @@ public class ClientCallImplTest { verify(transport, times(0)).newStream(any(MethodDescriptor.class), any(Metadata.class)); verify(callListener, timeout(1000)).onClose(statusCaptor.capture(), any(Metadata.class)); assertEquals(Status.Code.DEADLINE_EXCEEDED, statusCaptor.getValue().getCode()); + assertStatusInStats(Status.Code.DEADLINE_EXCEEDED); verifyZeroInteractions(provider); } @@ -568,6 +631,7 @@ public class ClientCallImplTest { DESCRIPTOR, MoreExecutors.directExecutor(), CallOptions.DEFAULT, + statsTraceCtx, provider, deadlineCancellationExecutor); @@ -595,6 +659,7 @@ public class ClientCallImplTest { DESCRIPTOR, MoreExecutors.directExecutor(), callOpts, + statsTraceCtx, provider, deadlineCancellationExecutor); @@ -622,6 +687,7 @@ public class ClientCallImplTest { DESCRIPTOR, MoreExecutors.directExecutor(), callOpts, + statsTraceCtx, provider, deadlineCancellationExecutor); @@ -645,6 +711,7 @@ public class ClientCallImplTest { DESCRIPTOR, MoreExecutors.directExecutor(), CallOptions.DEFAULT.withDeadline(Deadline.after(1000, TimeUnit.MILLISECONDS)), + statsTraceCtx, provider, deadlineCancellationExecutor); @@ -668,6 +735,7 @@ public class ClientCallImplTest { DESCRIPTOR, MoreExecutors.directExecutor(), CallOptions.DEFAULT, + statsTraceCtx, provider, deadlineCancellationExecutor); @@ -687,6 +755,7 @@ public class ClientCallImplTest { DESCRIPTOR, MoreExecutors.directExecutor(), CallOptions.DEFAULT.withDeadline(Deadline.after(1000, TimeUnit.MILLISECONDS)), + statsTraceCtx, provider, deadlineCancellationExecutor); call.start(callListener, new Metadata()); @@ -710,6 +779,7 @@ public class ClientCallImplTest { DESCRIPTOR, MoreExecutors.directExecutor(), CallOptions.DEFAULT, + statsTraceCtx, provider, deadlineCancellationExecutor); @@ -726,6 +796,7 @@ public class ClientCallImplTest { DESCRIPTOR, MoreExecutors.directExecutor(), CallOptions.DEFAULT, + statsTraceCtx, provider, deadlineCancellationExecutor); final Exception cause = new Exception(); @@ -753,6 +824,14 @@ public class ClientCallImplTest { assertSame(cause, status.getCause()); } + private void assertStatusInStats(Status.Code statusCode) { + CensusTestUtils.MetricsRecord record = censusCtxFactory.pollRecord(); + assertNotNull(record); + TagValue statusTag = record.tags.get(RpcConstants.RPC_STATUS); + assertNotNull(statusTag); + assertEquals(statusCode.toString(), statusTag.toString()); + } + private static class TestMarshaller implements Marshaller { @Override public InputStream stream(T value) { diff --git a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java index c10d7939e4..44b027c1bb 100644 --- a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java @@ -94,6 +94,12 @@ public class DelayedClientTransportTest { private final CallOptions callOptions = CallOptions.DEFAULT.withAuthority("dummy_value"); private final CallOptions callOptions2 = CallOptions.DEFAULT.withAuthority("dummy_value2"); + private final StatsTraceContext statsTraceCtx = StatsTraceContext.newClientContext( + method.getFullMethodName(), NoopCensusContextFactory.INSTANCE, + GrpcUtil.STOPWATCH_SUPPLIER); + private final StatsTraceContext statsTraceCtx2 = StatsTraceContext.newClientContext( + method2.getFullMethodName(), NoopCensusContextFactory.INSTANCE, + GrpcUtil.STOPWATCH_SUPPLIER); private final FakeClock fakeExecutor = new FakeClock(); private final DelayedClientTransport delayedTransport = new DelayedClientTransport( @@ -101,9 +107,11 @@ public class DelayedClientTransportTest { @Before public void setUp() { MockitoAnnotations.initMocks(this); - when(mockRealTransport.newStream(same(method), same(headers), same(callOptions))) + when(mockRealTransport.newStream(same(method), same(headers), same(callOptions), + same(statsTraceCtx))) .thenReturn(mockRealStream); - when(mockRealTransport2.newStream(same(method2), same(headers2), same(callOptions2))) + when(mockRealTransport2.newStream(same(method2), same(headers2), same(callOptions2), + same(statsTraceCtx2))) .thenReturn(mockRealStream2); delayedTransport.start(transportListener); } @@ -113,8 +121,8 @@ public class DelayedClientTransportTest { } @Test public void transportsAreUsedInOrder() { - delayedTransport.newStream(method, headers, callOptions); - delayedTransport.newStream(method2, headers2, callOptions2); + delayedTransport.newStream(method, headers, callOptions, statsTraceCtx); + delayedTransport.newStream(method2, headers2, callOptions2, statsTraceCtx2); assertEquals(0, fakeExecutor.numPendingTasks()); delayedTransport.setTransportSupplier(new Supplier() { final Iterator it = @@ -125,13 +133,15 @@ public class DelayedClientTransportTest { } }); assertEquals(1, fakeExecutor.runDueTasks()); - verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions)); - verify(mockRealTransport2).newStream(same(method2), same(headers2), same(callOptions2)); + verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions), + same(statsTraceCtx)); + verify(mockRealTransport2).newStream(same(method2), same(headers2), same(callOptions2), + same(statsTraceCtx2)); } @Test public void streamStartThenSetTransport() { assertFalse(delayedTransport.hasPendingStreams()); - ClientStream stream = delayedTransport.newStream(method, headers, callOptions); + ClientStream stream = delayedTransport.newStream(method, headers, callOptions, statsTraceCtx); stream.start(streamListener); assertEquals(1, delayedTransport.getPendingStreamsCount()); assertTrue(delayedTransport.hasPendingStreams()); @@ -141,7 +151,8 @@ public class DelayedClientTransportTest { assertEquals(0, delayedTransport.getPendingStreamsCount()); assertFalse(delayedTransport.hasPendingStreams()); assertEquals(1, fakeExecutor.runDueTasks()); - verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions)); + verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions), + same(statsTraceCtx)); verify(mockRealStream).start(listenerCaptor.capture()); verifyNoMoreInteractions(streamListener); listenerCaptor.getValue().onReady(); @@ -150,7 +161,7 @@ public class DelayedClientTransportTest { } @Test public void newStreamThenSetTransportThenShutdown() { - ClientStream stream = delayedTransport.newStream(method, headers, callOptions); + ClientStream stream = delayedTransport.newStream(method, headers, callOptions, statsTraceCtx); assertEquals(1, delayedTransport.getPendingStreamsCount()); assertTrue(stream instanceof DelayedStream); delayedTransport.setTransport(mockRealTransport); @@ -159,7 +170,8 @@ public class DelayedClientTransportTest { verify(transportListener).transportShutdown(any(Status.class)); verify(transportListener).transportTerminated(); assertEquals(1, fakeExecutor.runDueTasks()); - verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions)); + verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions), + same(statsTraceCtx)); stream.start(streamListener); verify(mockRealStream).start(same(streamListener)); } @@ -177,11 +189,12 @@ public class DelayedClientTransportTest { delayedTransport.shutdown(); verify(transportListener).transportShutdown(any(Status.class)); verify(transportListener).transportTerminated(); - ClientStream stream = delayedTransport.newStream(method, headers, callOptions); + ClientStream stream = delayedTransport.newStream(method, headers, callOptions, statsTraceCtx); assertEquals(0, delayedTransport.getPendingStreamsCount()); stream.start(streamListener); assertFalse(stream instanceof DelayedStream); - verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions)); + verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions), + same(statsTraceCtx)); verify(mockRealStream).start(same(streamListener)); } @@ -190,11 +203,12 @@ public class DelayedClientTransportTest { delayedTransport.shutdownNow(Status.UNAVAILABLE); verify(transportListener).transportShutdown(any(Status.class)); verify(transportListener).transportTerminated(); - ClientStream stream = delayedTransport.newStream(method, headers, callOptions); + ClientStream stream = delayedTransport.newStream(method, headers, callOptions, statsTraceCtx); assertEquals(0, delayedTransport.getPendingStreamsCount()); stream.start(streamListener); assertFalse(stream instanceof DelayedStream); - verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions)); + verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions), + same(statsTraceCtx)); verify(mockRealStream).start(same(streamListener)); } @@ -290,10 +304,11 @@ public class DelayedClientTransportTest { final Status cause = Status.UNAVAILABLE.withDescription("some error when connecting"); final CallOptions failFastCallOptions = CallOptions.DEFAULT; final CallOptions waitForReadyCallOptions = CallOptions.DEFAULT.withWaitForReady(); - final ClientStream ffStream = delayedTransport.newStream(method, headers, failFastCallOptions); + final ClientStream ffStream = delayedTransport.newStream(method, headers, failFastCallOptions, + statsTraceCtx); ffStream.start(streamListener); - delayedTransport.newStream(method, headers, waitForReadyCallOptions); - delayedTransport.newStream(method, headers, failFastCallOptions); + delayedTransport.newStream(method, headers, waitForReadyCallOptions, statsTraceCtx); + delayedTransport.newStream(method, headers, failFastCallOptions, statsTraceCtx); assertEquals(3, delayedTransport.getPendingStreamsCount()); delayedTransport.startBackoff(cause); @@ -315,13 +330,14 @@ public class DelayedClientTransportTest { delayedTransport.startBackoff(cause); assertTrue(delayedTransport.isInBackoffPeriod()); - final ClientStream ffStream = delayedTransport.newStream(method, headers, failFastCallOptions); + final ClientStream ffStream = delayedTransport.newStream(method, headers, failFastCallOptions, + statsTraceCtx); ffStream.start(streamListener); assertEquals(0, delayedTransport.getPendingStreamsCount()); verify(streamListener).closed(statusCaptor.capture(), any(Metadata.class)); assertEquals(cause, Status.fromThrowable(statusCaptor.getValue().getCause())); - delayedTransport.newStream(method, headers, waitForReadyCallOptions); + delayedTransport.newStream(method, headers, waitForReadyCallOptions, statsTraceCtx); assertEquals(1, delayedTransport.getPendingStreamsCount()); } diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java index f1aed26075..6db2bbc214 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java @@ -138,7 +138,8 @@ public class ManagedChannelImplIdlenessTest { CompressorRegistry.getDefaultInstance(), timerService, timer.getStopwatchSupplier(), TimeUnit.SECONDS.toMillis(IDLE_TIMEOUT_SECONDS), executor.getScheduledExecutorService(), USER_AGENT, - Collections.emptyList()); + Collections.emptyList(), + NoopCensusContextFactory.INSTANCE); newTransports = TestUtils.captureTransports(mockTransportFactory); for (int i = 0; i < 2; i++) { diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index cbb705310d..a719c39bbb 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -78,6 +78,7 @@ import io.grpc.SecurityLevel; import io.grpc.Status; import io.grpc.StringMarshaller; import io.grpc.TransportManager; +import io.grpc.internal.testing.CensusTestUtils.FakeCensusContextFactory; import org.junit.After; import org.junit.Before; @@ -126,6 +127,7 @@ public class ManagedChannelImplTest { private final ResolvedServerInfo server = new ResolvedServerInfo(socketAddress, Attributes.EMPTY); private final FakeClock timer = new FakeClock(); private final FakeClock executor = new FakeClock(); + private final FakeCensusContextFactory censusCtxFactory = new FakeCensusContextFactory(); private SpyingLoadBalancerFactory loadBalancerFactory = new SpyingLoadBalancerFactory(PickFirstBalancerFactory.getInstance()); @@ -134,6 +136,8 @@ public class ManagedChannelImplTest { private ManagedChannelImpl channel; @Captor private ArgumentCaptor statusCaptor; + @Captor + private ArgumentCaptor statsTraceCtxCaptor; @Mock private ConnectionClientTransport mockTransport; @Mock @@ -161,7 +165,7 @@ public class ManagedChannelImplTest { mockTransportFactory, DecompressorRegistry.getDefaultInstance(), CompressorRegistry.getDefaultInstance(), timerService, timer.getStopwatchSupplier(), ManagedChannelImpl.IDLE_TIMEOUT_MILLIS_DISABLE, - executor.getScheduledExecutorService(), userAgent, interceptors); + executor.getScheduledExecutorService(), userAgent, interceptors, censusCtxFactory); // Force-exit the initial idle-mode channel.exitIdleMode(); // Will start NameResolver in the scheduled executor @@ -237,7 +241,8 @@ public class ManagedChannelImplTest { when(mockTransportFactory.newClientTransport( any(SocketAddress.class), any(String.class), any(String.class))) .thenReturn(mockTransport); - when(mockTransport.newStream(same(method), same(headers), same(CallOptions.DEFAULT))) + when(mockTransport.newStream(same(method), same(headers), same(CallOptions.DEFAULT), + any(StatsTraceContext.class))) .thenReturn(mockStream); call.start(mockCallListener, headers); timer.runDueTasks(); @@ -250,7 +255,10 @@ public class ManagedChannelImplTest { transportListener.transportReady(); executor.runDueTasks(); - verify(mockTransport).newStream(same(method), same(headers), same(CallOptions.DEFAULT)); + verify(mockTransport).newStream(same(method), same(headers), same(CallOptions.DEFAULT), + statsTraceCtxCaptor.capture()); + assertEquals(censusCtxFactory.pollContextOrFail(), + statsTraceCtxCaptor.getValue().getCensusContext()); verify(mockStream).start(streamListenerCaptor.capture()); verify(mockStream).setCompressor(isA(Compressor.class)); ClientStreamListener streamListener = streamListenerCaptor.getValue(); @@ -259,10 +267,15 @@ public class ManagedChannelImplTest { ClientCall call2 = channel.newCall(method, CallOptions.DEFAULT); ClientStream mockStream2 = mock(ClientStream.class); Metadata headers2 = new Metadata(); - when(mockTransport.newStream(same(method), same(headers2), same(CallOptions.DEFAULT))) + when(mockTransport.newStream(same(method), same(headers2), same(CallOptions.DEFAULT), + any(StatsTraceContext.class))) .thenReturn(mockStream2); call2.start(mockCallListener2, headers2); - verify(mockTransport).newStream(same(method), same(headers2), same(CallOptions.DEFAULT)); + verify(mockTransport).newStream(same(method), same(headers2), same(CallOptions.DEFAULT), + statsTraceCtxCaptor.capture()); + assertEquals(censusCtxFactory.pollContextOrFail(), + statsTraceCtxCaptor.getValue().getCensusContext()); + verify(mockStream2).start(streamListenerCaptor.capture()); ClientStreamListener streamListener2 = streamListenerCaptor.getValue(); Metadata trailers = new Metadata(); @@ -323,7 +336,8 @@ public class ManagedChannelImplTest { when(mockTransportFactory.newClientTransport( any(SocketAddress.class), any(String.class), any(String.class))) .thenReturn(mockTransport); - when(mockTransport.newStream(same(method), same(headers), same(CallOptions.DEFAULT))) + when(mockTransport.newStream(same(method), same(headers), same(CallOptions.DEFAULT), + any(StatsTraceContext.class))) .thenReturn(mockStream); call.start(mockCallListener, headers); timer.runDueTasks(); @@ -336,7 +350,9 @@ public class ManagedChannelImplTest { transportListener.transportReady(); executor.runDueTasks(); - verify(mockTransport).newStream(same(method), same(headers), same(CallOptions.DEFAULT)); + verify(mockTransport).newStream(same(method), same(headers), same(CallOptions.DEFAULT), + any(StatsTraceContext.class)); + verify(mockStream).start(streamListenerCaptor.capture()); verify(mockStream).setCompressor(isA(Compressor.class)); ClientStreamListener streamListener = streamListenerCaptor.getValue(); @@ -391,7 +407,8 @@ public class ManagedChannelImplTest { // Create transport and call ClientStream mockStream = mock(ClientStream.class); Metadata headers = new Metadata(); - when(mockTransport.newStream(same(method), same(headers), same(CallOptions.DEFAULT))) + when(mockTransport.newStream(same(method), same(headers), same(CallOptions.DEFAULT), + any(StatsTraceContext.class))) .thenReturn(mockStream); call.start(mockCallListener, headers); timer.runDueTasks(); @@ -502,7 +519,8 @@ public class ManagedChannelImplTest { public void callOptionsExecutor() { Metadata headers = new Metadata(); ClientStream mockStream = mock(ClientStream.class); - when(mockTransport.newStream(same(method), same(headers), any(CallOptions.class))) + when(mockTransport.newStream(same(method), same(headers), any(CallOptions.class), + any(StatsTraceContext.class))) .thenReturn(mockStream); FakeClock callExecutor = new FakeClock(); createChannel(new FakeNameResolverFactory(true), NO_INTERCEPTOR); @@ -520,7 +538,8 @@ public class ManagedChannelImplTest { // Real streams are started in the channel's executor assertEquals(1, executor.runDueTasks()); - verify(mockTransport).newStream(same(method), same(headers), same(options)); + verify(mockTransport).newStream(same(method), same(headers), same(options), + any(StatsTraceContext.class)); verify(mockStream).start(streamListenerCaptor.capture()); ClientStreamListener streamListener = streamListenerCaptor.getValue(); Metadata trailers = new Metadata(); @@ -653,7 +672,8 @@ public class ManagedChannelImplTest { final ConnectionClientTransport goodTransport = mock(ConnectionClientTransport.class); final ConnectionClientTransport badTransport = mock(ConnectionClientTransport.class); when(goodTransport.newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class))) + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + any(StatsTraceContext.class))) .thenReturn(mock(ClientStream.class)); when(mockTransportFactory.newClientTransport( same(goodAddress), any(String.class), any(String.class))) @@ -691,7 +711,8 @@ public class ManagedChannelImplTest { goodTransportListenerCaptor.getValue().transportReady(); executor.runDueTasks(); - verify(goodTransport).newStream(same(method), same(headers), same(CallOptions.DEFAULT)); + verify(goodTransport).newStream(same(method), same(headers), same(CallOptions.DEFAULT), + any(StatsTraceContext.class)); // The bad transport was never used. verify(badTransport, times(0)).newStream(any(MethodDescriptor.class), any(Metadata.class)); } @@ -776,10 +797,12 @@ public class ManagedChannelImplTest { final ConnectionClientTransport transport1 = mock(ConnectionClientTransport.class); final ConnectionClientTransport transport2 = mock(ConnectionClientTransport.class); when(transport1.newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class))) + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + any(StatsTraceContext.class))) .thenReturn(mock(ClientStream.class)); when(transport2.newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class))) + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + any(StatsTraceContext.class))) .thenReturn(mock(ClientStream.class)); when(mockTransportFactory.newClientTransport(same(addr1), any(String.class), any(String.class))) .thenReturn(transport1, transport2); @@ -801,7 +824,8 @@ public class ManagedChannelImplTest { transportListenerCaptor.getValue().transportReady(); executor.runDueTasks(); - verify(transport1).newStream(same(method), same(headers), same(CallOptions.DEFAULT)); + verify(transport1).newStream(same(method), same(headers), same(CallOptions.DEFAULT), + any(StatsTraceContext.class)); transportListenerCaptor.getValue().transportShutdown(Status.UNAVAILABLE); // Second call still use the first address, since it was successfully connected. @@ -813,7 +837,8 @@ public class ManagedChannelImplTest { transportListenerCaptor.getValue().transportReady(); executor.runDueTasks(); - verify(transport2).newStream(same(method), same(headers), same(CallOptions.DEFAULT)); + verify(transport2).newStream(same(method), same(headers), same(CallOptions.DEFAULT), + any(StatsTraceContext.class)); } @Test @@ -859,7 +884,8 @@ public class ManagedChannelImplTest { return mock(ClientStream.class); } }).when(transport).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + any(StatsTraceContext.class)); // First call will be on delayed transport. Only newCall() is run within the expected context, // so that we can verify that the context is explicitly attached before calling newStream() and @@ -892,11 +918,13 @@ public class ManagedChannelImplTest { assertEquals(SecurityLevel.NONE, attrsCaptor.getValue().get(CallCredentials.ATTR_SECURITY_LEVEL)); verify(transport, never()).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + any(StatsTraceContext.class)); // newStream() is called after apply() is called applierCaptor.getValue().apply(new Metadata()); - verify(transport).newStream(same(method), any(Metadata.class), same(callOptions)); + verify(transport).newStream(same(method), any(Metadata.class), same(callOptions), + any(StatsTraceContext.class)); assertEquals("testValue", testKey.get(newStreamContexts.poll())); // The context should not live beyond the scope of newStream() and applyRequestMetadata() assertNull(testKey.get()); @@ -916,11 +944,13 @@ public class ManagedChannelImplTest { attrsCaptor.getValue().get(CallCredentials.ATTR_SECURITY_LEVEL)); // This is from the first call verify(transport).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + any(StatsTraceContext.class)); // Still, newStream() is called after apply() is called applierCaptor.getValue().apply(new Metadata()); - verify(transport, times(2)).newStream(same(method), any(Metadata.class), same(callOptions)); + verify(transport, times(2)).newStream(same(method), any(Metadata.class), same(callOptions), + any(StatsTraceContext.class)); assertEquals("testValue", testKey.get(newStreamContexts.poll())); assertNull(testKey.get()); diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTransportManagerTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTransportManagerTest.java index 425a26dc1f..4e22278ae4 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTransportManagerTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTransportManagerTest.java @@ -101,6 +101,12 @@ public class ManagedChannelImplTransportManagerTest { new StringMarshaller(), new StringMarshaller()); private final CallOptions callOptions = CallOptions.DEFAULT.withAuthority("dummy_value"); private final CallOptions callOptions2 = CallOptions.DEFAULT.withAuthority("dummy_value2"); + private final StatsTraceContext statsTraceCtx = StatsTraceContext.newClientContext( + method.getFullMethodName(), NoopCensusContextFactory.INSTANCE, + GrpcUtil.STOPWATCH_SUPPLIER); + private final StatsTraceContext statsTraceCtx2 = StatsTraceContext.newClientContext( + method2.getFullMethodName(), NoopCensusContextFactory.INSTANCE, + GrpcUtil.STOPWATCH_SUPPLIER); private ManagedChannelImpl channel; @@ -135,7 +141,8 @@ public class ManagedChannelImplTransportManagerTest { mockTransportFactory, DecompressorRegistry.getDefaultInstance(), CompressorRegistry.getDefaultInstance(), GrpcUtil.TIMER_SERVICE, GrpcUtil.STOPWATCH_SUPPLIER, ManagedChannelImpl.IDLE_TIMEOUT_MILLIS_DISABLE, - executor, USER_AGENT, Collections.emptyList()); + executor, USER_AGENT, Collections.emptyList(), + NoopCensusContextFactory.INSTANCE); ArgumentCaptor> tmCaptor = ArgumentCaptor.forClass(null); @@ -195,7 +202,7 @@ public class ManagedChannelImplTransportManagerTest { // Subsequent getTransport() will use the next address ClientTransport t2 = tm.getTransport(addressGroup); assertNotNull(t2); - t2.newStream(method, new Metadata(), callOptions); + t2.newStream(method, new Metadata(), callOptions, statsTraceCtx); // Will keep the previous back-off policy, and not consult back-off policy verify(mockTransportFactory, timeout(1000)).newClientTransport(addr2, AUTHORITY, USER_AGENT); verify(mockBackoffPolicyProvider, times(backoffReset)).get(); @@ -203,8 +210,8 @@ public class ManagedChannelImplTransportManagerTest { ClientTransport rt2 = transportInfo.transport; // Make the second transport ready transportInfo.listener.transportReady(); - verify(rt2, timeout(1000)).newStream(same(method), any(Metadata.class), - same(callOptions)); + verify(rt2, timeout(1000)).newStream( + same(method), any(Metadata.class), same(callOptions), same(statsTraceCtx)); verify(mockNameResolver, times(0)).refresh(); // Disconnect the second transport transportInfo.listener.transportShutdown(Status.UNAVAILABLE); @@ -213,7 +220,7 @@ public class ManagedChannelImplTransportManagerTest { // Subsequent getTransport() will use the first address, since last attempt was successful. ClientTransport t3 = tm.getTransport(addressGroup); - t3.newStream(method2, new Metadata(), callOptions2); + t3.newStream(method2, new Metadata(), callOptions2, statsTraceCtx2); verify(mockTransportFactory, timeout(1000).times(2)) .newClientTransport(addr1, AUTHORITY, USER_AGENT); // Still no back-off policy creation, because an address succeeded. @@ -221,8 +228,8 @@ public class ManagedChannelImplTransportManagerTest { transportInfo = transports.poll(1, TimeUnit.SECONDS); ClientTransport rt3 = transportInfo.transport; transportInfo.listener.transportReady(); - verify(rt3, timeout(1000)).newStream(same(method2), any(Metadata.class), - same(callOptions2)); + verify(rt3, timeout(1000)).newStream( + same(method2), any(Metadata.class), same(callOptions2), same(statsTraceCtx2)); verify(rt1, times(0)).newStream(any(MethodDescriptor.class), any(Metadata.class)); // Back-off policy was never consulted. @@ -283,7 +290,7 @@ public class ManagedChannelImplTransportManagerTest { ClientTransport t4 = tm.getTransport(addressGroup); assertNotNull(t4); // If backoff's DelayedTransport is still active, this is necessary. Otherwise it would be racy. - t4.newStream(method, new Metadata(), CallOptions.DEFAULT.withWaitForReady()); + t4.newStream(method, new Metadata(), CallOptions.DEFAULT.withWaitForReady(), statsTraceCtx); verify(mockTransportFactory, timeout(1000).times(++transportsAddr1)) .newClientTransport(addr1, AUTHORITY, USER_AGENT); // Back-off policy was reset and consulted. diff --git a/core/src/test/java/io/grpc/internal/MessageDeframerTest.java b/core/src/test/java/io/grpc/internal/MessageDeframerTest.java index 7e9d8a2f37..ba5e47fdae 100644 --- a/core/src/test/java/io/grpc/internal/MessageDeframerTest.java +++ b/core/src/test/java/io/grpc/internal/MessageDeframerTest.java @@ -33,6 +33,7 @@ package io.grpc.internal; import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.anyInt; import static org.mockito.Mockito.atLeastOnce; @@ -42,14 +43,18 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import com.google.census.RpcConstants; import com.google.common.base.Charsets; import com.google.common.io.ByteStreams; import com.google.common.primitives.Bytes; import io.grpc.Codec; +import io.grpc.Status; import io.grpc.StatusRuntimeException; import io.grpc.internal.MessageDeframer.Listener; import io.grpc.internal.MessageDeframer.SizeEnforcingInputStream; +import io.grpc.internal.testing.CensusTestUtils.FakeCensusContextFactory; +import io.grpc.internal.testing.CensusTestUtils.MetricsRecord; import org.junit.Rule; import org.junit.Test; @@ -76,8 +81,15 @@ public class MessageDeframerTest { @Rule public final ExpectedException thrown = ExpectedException.none(); private Listener listener = mock(Listener.class); + private final FakeCensusContextFactory censusCtxFactory = new FakeCensusContextFactory(); + // MessageFramerTest tests with a server-side StatsTraceContext, so here we test with a + // client-side StatsTraceContext. + private StatsTraceContext statsTraceCtx = StatsTraceContext.newClientContext( + "service/method", censusCtxFactory, GrpcUtil.STOPWATCH_SUPPLIER); + private MessageDeframer deframer = new MessageDeframer(listener, Codec.Identity.NONE, - DEFAULT_MAX_MESSAGE_SIZE); + DEFAULT_MAX_MESSAGE_SIZE, statsTraceCtx); + private ArgumentCaptor messages = ArgumentCaptor.forClass(InputStream.class); @Test @@ -88,6 +100,7 @@ public class MessageDeframerTest { assertEquals(Bytes.asList(new byte[]{3, 14}), bytes(messages)); verify(listener, atLeastOnce()).bytesRead(anyInt()); verifyNoMoreInteractions(listener); + checkStats(2, 2); } @Test @@ -101,6 +114,7 @@ public class MessageDeframerTest { verify(listener, atLeastOnce()).bytesRead(anyInt()); assertEquals(Bytes.asList(new byte[] {14, 15}), bytes(streams.get(1))); verifyNoMoreInteractions(listener); + checkStats(3, 3); } @Test @@ -112,6 +126,7 @@ public class MessageDeframerTest { verify(listener).endOfStream(); verify(listener, atLeastOnce()).bytesRead(anyInt()); verifyNoMoreInteractions(listener); + checkStats(1, 1); } @Test @@ -119,6 +134,7 @@ public class MessageDeframerTest { deframer.deframe(buffer(new byte[0]), true); verify(listener).endOfStream(); verifyNoMoreInteractions(listener); + checkStats(0, 0); } @Test @@ -133,6 +149,7 @@ public class MessageDeframerTest { verify(listener, atLeastOnce()).bytesRead(anyInt()); assertTrue(deframer.isStalled()); verifyNoMoreInteractions(listener); + checkStats(7, 7); } @Test @@ -148,6 +165,7 @@ public class MessageDeframerTest { verify(listener, atLeastOnce()).bytesRead(anyInt()); assertTrue(deframer.isStalled()); verifyNoMoreInteractions(listener); + checkStats(1, 1); } @Test @@ -158,6 +176,7 @@ public class MessageDeframerTest { assertEquals(Bytes.asList(), bytes(messages)); verify(listener, atLeastOnce()).bytesRead(anyInt()); verifyNoMoreInteractions(listener); + checkStats(0, 0); } @Test @@ -169,6 +188,7 @@ public class MessageDeframerTest { assertEquals(Bytes.asList(new byte[1000]), bytes(messages)); verify(listener, atLeastOnce()).bytesRead(anyInt()); verifyNoMoreInteractions(listener); + checkStats(1000, 1000); } @Test @@ -182,11 +202,13 @@ public class MessageDeframerTest { verify(listener).endOfStream(); verify(listener, atLeastOnce()).bytesRead(anyInt()); verifyNoMoreInteractions(listener); + checkStats(1, 1); } @Test public void compressed() { - deframer = new MessageDeframer(listener, new Codec.Gzip(), DEFAULT_MAX_MESSAGE_SIZE); + deframer = new MessageDeframer(listener, new Codec.Gzip(), DEFAULT_MAX_MESSAGE_SIZE, + statsTraceCtx); deframer.request(1); byte[] payload = compress(new byte[1000]); @@ -197,6 +219,7 @@ public class MessageDeframerTest { assertEquals(Bytes.asList(new byte[1000]), bytes(messages)); verify(listener, atLeastOnce()).bytesRead(anyInt()); verifyNoMoreInteractions(listener); + checkStats(payload.length, 1000); } @Test @@ -222,27 +245,34 @@ public class MessageDeframerTest { @Test public void sizeEnforcingInputStream_readByteBelowLimit() throws IOException { ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); - SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 4); + SizeEnforcingInputStream stream = + new MessageDeframer.SizeEnforcingInputStream(in, 4, statsTraceCtx); while (stream.read() != -1) {} stream.close(); + // SizeEnforcingInputStream only reports uncompressed bytes + checkStats(0, 3); } @Test public void sizeEnforcingInputStream_readByteAtLimit() throws IOException { ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); - SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 3); + SizeEnforcingInputStream stream = + new MessageDeframer.SizeEnforcingInputStream(in, 3, statsTraceCtx); while (stream.read() != -1) {} stream.close(); + // SizeEnforcingInputStream only reports uncompressed bytes + checkStats(0, 3); } @Test public void sizeEnforcingInputStream_readByteAboveLimit() throws IOException { ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); - SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 2); + SizeEnforcingInputStream stream = + new MessageDeframer.SizeEnforcingInputStream(in, 2, statsTraceCtx); thrown.expect(StatusRuntimeException.class); thrown.expectMessage("INTERNAL: Compressed frame exceeds"); @@ -256,31 +286,38 @@ public class MessageDeframerTest { @Test public void sizeEnforcingInputStream_readBelowLimit() throws IOException { ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); - SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 4); + SizeEnforcingInputStream stream = + new MessageDeframer.SizeEnforcingInputStream(in, 4, statsTraceCtx); byte[] buf = new byte[10]; int read = stream.read(buf, 0, buf.length); assertEquals(3, read); stream.close(); + // SizeEnforcingInputStream only reports uncompressed bytes + checkStats(0, 3); } @Test public void sizeEnforcingInputStream_readAtLimit() throws IOException { ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); - SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 3); + SizeEnforcingInputStream stream = + new MessageDeframer.SizeEnforcingInputStream(in, 3, statsTraceCtx); byte[] buf = new byte[10]; int read = stream.read(buf, 0, buf.length); assertEquals(3, read); stream.close(); + // SizeEnforcingInputStream only reports uncompressed bytes + checkStats(0, 3); } @Test public void sizeEnforcingInputStream_readAboveLimit() throws IOException { ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); - SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 2); + SizeEnforcingInputStream stream = + new MessageDeframer.SizeEnforcingInputStream(in, 2, statsTraceCtx); byte[] buf = new byte[10]; thrown.expect(StatusRuntimeException.class); @@ -295,30 +332,37 @@ public class MessageDeframerTest { @Test public void sizeEnforcingInputStream_skipBelowLimit() throws IOException { ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); - SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 4); + SizeEnforcingInputStream stream = + new MessageDeframer.SizeEnforcingInputStream(in, 4, statsTraceCtx); long skipped = stream.skip(4); assertEquals(3, skipped); stream.close(); + // SizeEnforcingInputStream only reports uncompressed bytes + checkStats(0, 3); } @Test public void sizeEnforcingInputStream_skipAtLimit() throws IOException { ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); - SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 3); + SizeEnforcingInputStream stream = + new MessageDeframer.SizeEnforcingInputStream(in, 3, statsTraceCtx); long skipped = stream.skip(4); assertEquals(3, skipped); stream.close(); + // SizeEnforcingInputStream only reports uncompressed bytes + checkStats(0, 3); } @Test public void sizeEnforcingInputStream_skipAboveLimit() throws IOException { ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); - SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 2); + SizeEnforcingInputStream stream = + new MessageDeframer.SizeEnforcingInputStream(in, 2, statsTraceCtx); thrown.expect(StatusRuntimeException.class); thrown.expectMessage("INTERNAL: Compressed frame exceeds"); @@ -332,7 +376,8 @@ public class MessageDeframerTest { @Test public void sizeEnforcingInputStream_markReset() throws IOException { ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); - SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 3); + SizeEnforcingInputStream stream = + new MessageDeframer.SizeEnforcingInputStream(in, 3, statsTraceCtx); // stream currently looks like: |foo stream.skip(1); // f|oo stream.mark(10); // any large number will work. @@ -342,6 +387,25 @@ public class MessageDeframerTest { assertEquals(2, skipped); stream.close(); + // SizeEnforcingInputStream only reports uncompressed bytes + checkStats(0, 3); + } + + private void checkStats(long wireBytesReceived, long uncompressedBytesReceived) { + statsTraceCtx.callEnded(Status.OK); + MetricsRecord record = censusCtxFactory.pollRecord(); + assertEquals(0, record.getMetricAsLongOrFail( + RpcConstants.RPC_CLIENT_REQUEST_BYTES)); + assertEquals(0, record.getMetricAsLongOrFail( + RpcConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES)); + assertEquals(wireBytesReceived, + record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_RESPONSE_BYTES)); + assertEquals(uncompressedBytesReceived, + record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES)); + assertNull(record.getMetric(RpcConstants.RPC_SERVER_REQUEST_BYTES)); + assertNull(record.getMetric(RpcConstants.RPC_SERVER_RESPONSE_BYTES)); + assertNull(record.getMetric(RpcConstants.RPC_SERVER_UNCOMPRESSED_REQUEST_BYTES)); + assertNull(record.getMetric(RpcConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES)); } private static List bytes(ArgumentCaptor captor) { diff --git a/core/src/test/java/io/grpc/internal/MessageFramerTest.java b/core/src/test/java/io/grpc/internal/MessageFramerTest.java index 440008da4f..aad366b20d 100644 --- a/core/src/test/java/io/grpc/internal/MessageFramerTest.java +++ b/core/src/test/java/io/grpc/internal/MessageFramerTest.java @@ -32,6 +32,7 @@ package io.grpc.internal; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Matchers.eq; @@ -40,7 +41,14 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyZeroInteractions; +import com.google.census.RpcConstants; + import io.grpc.Codec; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.internal.StatsTraceContext; +import io.grpc.internal.testing.CensusTestUtils.FakeCensusContextFactory; +import io.grpc.internal.testing.CensusTestUtils.MetricsRecord; import org.junit.Before; import org.junit.Test; @@ -70,12 +78,19 @@ public class MessageFramerTest { private ArgumentCaptor frameCaptor; private BytesWritableBufferAllocator allocator = new BytesWritableBufferAllocator(1000, 1000); + private FakeCensusContextFactory censusCtxFactory; + private StatsTraceContext statsTraceCtx; /** Set up for test. */ @Before public void setUp() { MockitoAnnotations.initMocks(this); - framer = new MessageFramer(sink, allocator); + censusCtxFactory = new FakeCensusContextFactory(); + // MessageDeframerTest tests with a client-side StatsTraceContext, so here we test with a + // server-side StatsTraceContext. + statsTraceCtx = StatsTraceContext.newServerContext( + "service/method", censusCtxFactory, new Metadata(), GrpcUtil.STOPWATCH_SUPPLIER); + framer = new MessageFramer(sink, allocator, statsTraceCtx); } @Test @@ -83,9 +98,11 @@ public class MessageFramerTest { writeKnownLength(framer, new byte[]{3, 14}); verifyNoMoreInteractions(sink); framer.flush(); + verify(sink).deliverFrame(toWriteBuffer(new byte[] {0, 0, 0, 0, 2, 3, 14}), false, true); assertEquals(1, allocator.allocCount); verifyNoMoreInteractions(sink); + checkStats(2, 2); } @Test @@ -97,6 +114,7 @@ public class MessageFramerTest { verify(sink).deliverFrame(toWriteBuffer(new byte[] {3, 14}), false, true); assertEquals(2, allocator.allocCount); verifyNoMoreInteractions(sink); + checkStats(2, 2); } @Test @@ -110,6 +128,7 @@ public class MessageFramerTest { toWriteBuffer(new byte[] {0, 0, 0, 0, 1, 3, 0, 0, 0, 0, 1, 14}), false, true); verifyNoMoreInteractions(sink); assertEquals(1, allocator.allocCount); + checkStats(2, 2); } @Test @@ -121,6 +140,7 @@ public class MessageFramerTest { toWriteBuffer(new byte[] {0, 0, 0, 0, 7, 3, 14, 1, 5, 9, 2, 6}), true, true); verifyNoMoreInteractions(sink); assertEquals(1, allocator.allocCount); + checkStats(7, 7); } @Test @@ -129,12 +149,13 @@ public class MessageFramerTest { verify(sink).deliverFrame(null, true, true); verifyNoMoreInteractions(sink); assertEquals(0, allocator.allocCount); + checkStats(0, 0); } @Test public void payloadSplitBetweenSinks() { allocator = new BytesWritableBufferAllocator(12, 12); - framer = new MessageFramer(sink, allocator); + framer = new MessageFramer(sink, allocator, statsTraceCtx); writeKnownLength(framer, new byte[]{3, 14, 1, 5, 9, 2, 6, 5}); verify(sink).deliverFrame( toWriteBuffer(new byte[] {0, 0, 0, 0, 8, 3, 14, 1, 5, 9, 2, 6}), false, false); @@ -144,12 +165,13 @@ public class MessageFramerTest { verify(sink).deliverFrame(toWriteBuffer(new byte[] {5}), false, true); verifyNoMoreInteractions(sink); assertEquals(2, allocator.allocCount); + checkStats(8, 8); } @Test public void frameHeaderSplitBetweenSinks() { allocator = new BytesWritableBufferAllocator(12, 12); - framer = new MessageFramer(sink, allocator); + framer = new MessageFramer(sink, allocator, statsTraceCtx); writeKnownLength(framer, new byte[]{3, 14, 1}); writeKnownLength(framer, new byte[]{3}); verify(sink).deliverFrame( @@ -160,6 +182,7 @@ public class MessageFramerTest { verify(sink).deliverFrame(toWriteBufferWithMinSize(new byte[] {1, 3}, 12), false, true); verifyNoMoreInteractions(sink); assertEquals(2, allocator.allocCount); + checkStats(4, 4); } @Test @@ -168,6 +191,7 @@ public class MessageFramerTest { framer.flush(); verify(sink).deliverFrame(toWriteBuffer(new byte[] {0, 0, 0, 0, 0}), false, true); assertEquals(1, allocator.allocCount); + checkStats(0, 0); } @Test @@ -178,6 +202,7 @@ public class MessageFramerTest { verify(sink).deliverFrame(toWriteBuffer(new byte[] {0, 0, 0, 0, 0}), false, true); // One alloc for the header assertEquals(1, allocator.allocCount); + checkStats(0, 0); } @Test @@ -188,12 +213,13 @@ public class MessageFramerTest { verify(sink).deliverFrame(toWriteBuffer(new byte[] {0, 0, 0, 0, 2, 3, 14}), false, true); verifyNoMoreInteractions(sink); assertEquals(1, allocator.allocCount); + checkStats(2, 2); } @Test public void largerFrameSize() throws Exception { allocator = new BytesWritableBufferAllocator(0, 10000); - framer = new MessageFramer(sink, allocator); + framer = new MessageFramer(sink, allocator, statsTraceCtx); writeKnownLength(framer, new byte[1000]); framer.flush(); verify(sink).deliverFrame(frameCaptor.capture(), eq(false), eq(true)); @@ -207,13 +233,14 @@ public class MessageFramerTest { assertEquals(toWriteBuffer(data), buffer); verifyNoMoreInteractions(sink); assertEquals(1, allocator.allocCount); + checkStats(1000, 1000); } @Test public void largerFrameSizeUnknownLength() throws Exception { // Force payload to be split into two chunks allocator = new BytesWritableBufferAllocator(500, 500); - framer = new MessageFramer(sink, allocator); + framer = new MessageFramer(sink, allocator, statsTraceCtx); writeUnknownLength(framer, new byte[1000]); framer.flush(); // Header and first chunk written with flush = false @@ -233,13 +260,14 @@ public class MessageFramerTest { verifyNoMoreInteractions(sink); assertEquals(3, allocator.allocCount); + checkStats(1000, 1000); } @Test public void compressed() throws Exception { allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE); // setMessageCompression should default to true - framer = new MessageFramer(sink, allocator).setCompressor(new Codec.Gzip()); + framer = new MessageFramer(sink, allocator, statsTraceCtx).setCompressor(new Codec.Gzip()); writeKnownLength(framer, new byte[1000]); framer.flush(); // The GRPC header is written first as a separate frame. @@ -257,12 +285,13 @@ public class MessageFramerTest { assertTrue(length < 1000); assertEquals(frameCaptor.getAllValues().get(1).size(), length); + checkStats(length, 1000); } @Test public void dontCompressIfNoEncoding() throws Exception { allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE); - framer = new MessageFramer(sink, allocator) + framer = new MessageFramer(sink, allocator, statsTraceCtx) .setMessageCompression(true); writeKnownLength(framer, new byte[1000]); framer.flush(); @@ -281,12 +310,13 @@ public class MessageFramerTest { assertEquals(1000, length); assertEquals(buffer.data.length - 5 , length); + checkStats(1000, 1000); } @Test public void dontCompressIfNotRequested() throws Exception { allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE); - framer = new MessageFramer(sink, allocator) + framer = new MessageFramer(sink, allocator, statsTraceCtx) .setCompressor(new Codec.Gzip()) .setMessageCompression(false); writeKnownLength(framer, new byte[1000]); @@ -306,6 +336,7 @@ public class MessageFramerTest { assertEquals(1000, length); assertEquals(buffer.data.length - 5 , length); + checkStats(1000, 1000); } @Test @@ -322,7 +353,7 @@ public class MessageFramerTest { } } }; - framer = new MessageFramer(reentrant, allocator); + framer = new MessageFramer(reentrant, allocator, statsTraceCtx); writeKnownLength(framer, new byte[]{3, 14}); framer.close(); } @@ -334,6 +365,7 @@ public class MessageFramerTest { writeKnownLength(framer, new byte[]{}); framer.flush(); verify(sink).deliverFrame(toWriteBuffer(new byte[] {0, 0, 0, 0, 0}), false, true); + checkStats(0, 0); } private static WritableBuffer toWriteBuffer(byte[] data) { @@ -355,6 +387,23 @@ public class MessageFramerTest { // TODO(carl-mastrangelo): add framer.flush() here. } + private void checkStats(long wireBytesSent, long uncompressedBytesSent) { + statsTraceCtx.callEnded(Status.OK); + MetricsRecord record = censusCtxFactory.pollRecord(); + assertEquals(0, record.getMetricAsLongOrFail( + RpcConstants.RPC_SERVER_REQUEST_BYTES)); + assertEquals(0, record.getMetricAsLongOrFail( + RpcConstants.RPC_SERVER_UNCOMPRESSED_REQUEST_BYTES)); + assertEquals(wireBytesSent, + record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_RESPONSE_BYTES)); + assertEquals(uncompressedBytesSent, + record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES)); + assertNull(record.getMetric(RpcConstants.RPC_CLIENT_REQUEST_BYTES)); + assertNull(record.getMetric(RpcConstants.RPC_CLIENT_RESPONSE_BYTES)); + assertNull(record.getMetric(RpcConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES)); + assertNull(record.getMetric(RpcConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES)); + } + static class ByteWritableBuffer implements WritableBuffer { byte[] data; private int writeIdx; diff --git a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java index 62725cfcbf..a636a33fa8 100644 --- a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java @@ -32,6 +32,7 @@ package io.grpc.internal; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -40,6 +41,8 @@ import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.census.RpcConstants; +import com.google.census.TagValue; import com.google.common.io.CharStreams; import com.google.common.util.concurrent.Futures; @@ -53,6 +56,8 @@ import io.grpc.MethodDescriptor.MethodType; import io.grpc.ServerCall; import io.grpc.Status; import io.grpc.internal.ServerCallImpl.ServerStreamListenerImpl; +import io.grpc.internal.testing.CensusTestUtils.FakeCensusContextFactory; +import io.grpc.internal.testing.CensusTestUtils; import org.junit.Before; import org.junit.Rule; @@ -86,12 +91,18 @@ public class ServerCallImplTest { private final MethodDescriptor method = MethodDescriptor.create( MethodType.UNARY, "/service/method", new LongMarshaller(), new LongMarshaller()); + private final Metadata requestHeaders = new Metadata(); + private final FakeCensusContextFactory censusCtxFactory = new FakeCensusContextFactory(); + private final StatsTraceContext statsTraceCtx = StatsTraceContext.newServerContext( + method.getFullMethodName(), censusCtxFactory, requestHeaders, GrpcUtil.STOPWATCH_SUPPLIER); + @Before public void setUp() { MockitoAnnotations.initMocks(this); context = Context.ROOT.withCancellation(); - call = new ServerCallImpl(stream, method, new Metadata(), context, - DecompressorRegistry.getDefaultInstance(), CompressorRegistry.getDefaultInstance()); + call = new ServerCallImpl(stream, method, requestHeaders, context, + statsTraceCtx, DecompressorRegistry.getDefaultInstance(), + CompressorRegistry.getDefaultInstance()); } @Test @@ -189,7 +200,8 @@ public class ServerCallImplTest { @Test public void streamListener_halfClosed() { ServerStreamListenerImpl streamListener = - new ServerCallImpl.ServerStreamListenerImpl(call, callListener, context); + new ServerCallImpl.ServerStreamListenerImpl( + call, callListener, context, statsTraceCtx); streamListener.halfClosed(); @@ -199,7 +211,8 @@ public class ServerCallImplTest { @Test public void streamListener_halfClosed_onlyOnce() { ServerStreamListenerImpl streamListener = - new ServerCallImpl.ServerStreamListenerImpl(call, callListener, context); + new ServerCallImpl.ServerStreamListenerImpl( + call, callListener, context, statsTraceCtx); streamListener.halfClosed(); // canceling the call should short circuit future halfClosed() calls. streamListener.closed(Status.CANCELLED); @@ -212,31 +225,36 @@ public class ServerCallImplTest { @Test public void streamListener_closedOk() { ServerStreamListenerImpl streamListener = - new ServerCallImpl.ServerStreamListenerImpl(call, callListener, context); + new ServerCallImpl.ServerStreamListenerImpl( + call, callListener, context, statsTraceCtx); streamListener.closed(Status.OK); verify(callListener).onComplete(); assertTrue(context.isCancelled()); assertNull(context.cancellationCause()); + checkStats(Status.Code.OK); } @Test public void streamListener_closedCancelled() { ServerStreamListenerImpl streamListener = - new ServerCallImpl.ServerStreamListenerImpl(call, callListener, context); + new ServerCallImpl.ServerStreamListenerImpl( + call, callListener, context, statsTraceCtx); streamListener.closed(Status.CANCELLED); verify(callListener).onCancel(); assertTrue(context.isCancelled()); assertNull(context.cancellationCause()); + checkStats(Status.Code.CANCELLED); } @Test public void streamListener_onReady() { ServerStreamListenerImpl streamListener = - new ServerCallImpl.ServerStreamListenerImpl(call, callListener, context); + new ServerCallImpl.ServerStreamListenerImpl( + call, callListener, context, statsTraceCtx); streamListener.onReady(); @@ -246,7 +264,8 @@ public class ServerCallImplTest { @Test public void streamListener_onReady_onlyOnce() { ServerStreamListenerImpl streamListener = - new ServerCallImpl.ServerStreamListenerImpl(call, callListener, context); + new ServerCallImpl.ServerStreamListenerImpl( + call, callListener, context, statsTraceCtx); streamListener.onReady(); // canceling the call should short circuit future halfClosed() calls. streamListener.closed(Status.CANCELLED); @@ -259,7 +278,8 @@ public class ServerCallImplTest { @Test public void streamListener_messageRead() { ServerStreamListenerImpl streamListener = - new ServerCallImpl.ServerStreamListenerImpl(call, callListener, context); + new ServerCallImpl.ServerStreamListenerImpl( + call, callListener, context, statsTraceCtx); streamListener.messageRead(method.streamRequest(1234L)); verify(callListener).onMessage(1234L); @@ -268,7 +288,8 @@ public class ServerCallImplTest { @Test public void streamListener_messageRead_unaryFailsOnMultiple() { ServerStreamListenerImpl streamListener = - new ServerCallImpl.ServerStreamListenerImpl(call, callListener, context); + new ServerCallImpl.ServerStreamListenerImpl( + call, callListener, context, statsTraceCtx); streamListener.messageRead(method.streamRequest(1234L)); streamListener.messageRead(method.streamRequest(1234L)); @@ -282,7 +303,8 @@ public class ServerCallImplTest { @Test public void streamListener_messageRead_onlyOnce() { ServerStreamListenerImpl streamListener = - new ServerCallImpl.ServerStreamListenerImpl(call, callListener, context); + new ServerCallImpl.ServerStreamListenerImpl( + call, callListener, context, statsTraceCtx); streamListener.messageRead(method.streamRequest(1234L)); // canceling the call should short circuit future halfClosed() calls. streamListener.closed(Status.CANCELLED); @@ -292,6 +314,28 @@ public class ServerCallImplTest { verify(callListener).onMessage(1234L); } + private void checkStats(Status.Code statusCode) { + CensusTestUtils.MetricsRecord record = censusCtxFactory.pollRecord(); + assertNotNull(record); + TagValue statusTag = record.tags.get(RpcConstants.RPC_STATUS); + assertNotNull(statusTag); + assertEquals(statusCode.toString(), statusTag.toString()); + assertNull(record.getMetric(RpcConstants.RPC_CLIENT_REQUEST_BYTES)); + assertNull(record.getMetric(RpcConstants.RPC_CLIENT_RESPONSE_BYTES)); + assertNull(record.getMetric(RpcConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES)); + assertNull(record.getMetric(RpcConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES)); + // The test doesn't invoke MessageFramer and MessageDeframer which keep the sizes. + // Thus the sizes reported to stats would be zero. + assertEquals(0, + record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_REQUEST_BYTES)); + assertEquals(0, + record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_RESPONSE_BYTES)); + assertEquals(0, + record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_UNCOMPRESSED_REQUEST_BYTES)); + assertEquals(0, + record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES)); + } + private static class LongMarshaller implements Marshaller { @Override public InputStream stream(Long value) { diff --git a/core/src/test/java/io/grpc/internal/ServerImplTest.java b/core/src/test/java/io/grpc/internal/ServerImplTest.java index ad7ce0e264..3022db506b 100644 --- a/core/src/test/java/io/grpc/internal/ServerImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerImplTest.java @@ -35,19 +35,26 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; import static org.mockito.Matchers.isA; import static org.mockito.Matchers.isNotNull; import static org.mockito.Matchers.notNull; import static org.mockito.Matchers.same; +import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; +import com.google.census.CensusContext; +import com.google.census.RpcConstants; +import com.google.census.TagValue; import com.google.common.collect.ImmutableList; import com.google.common.truth.Truth; import com.google.common.util.concurrent.MoreExecutors; @@ -70,6 +77,8 @@ import io.grpc.ServerTransportFilter; import io.grpc.ServiceDescriptor; import io.grpc.Status; import io.grpc.StringMarshaller; +import io.grpc.internal.testing.CensusTestUtils.FakeCensusContextFactory; +import io.grpc.internal.testing.CensusTestUtils; import io.grpc.util.MutableHandlerRegistry; import org.junit.After; @@ -81,6 +90,7 @@ import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.Captor; import org.mockito.Matchers; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -108,6 +118,8 @@ public class ServerImplTest { private static final Context.CancellableContext SERVER_CONTEXT = Context.ROOT.withValue(SERVER_ONLY, "yes").withCancellation(); private static final ImmutableList NO_FILTERS = ImmutableList.of(); + + private final FakeCensusContextFactory censusCtxFactory = new FakeCensusContextFactory(); private final CompressorRegistry compressorRegistry = CompressorRegistry.getDefaultInstance(); private final DecompressorRegistry decompressorRegistry = DecompressorRegistry.getDefaultInstance(); @@ -126,7 +138,11 @@ public class ServerImplTest { private MutableHandlerRegistry fallbackRegistry = new MutableHandlerRegistry(); private SimpleServer transportServer = new SimpleServer(); private ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, - SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS); + SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, censusCtxFactory, + GrpcUtil.STOPWATCH_SUPPLIER); + + @Captor + private ArgumentCaptor statusCaptor; @Mock private ServerStream stream; @@ -158,7 +174,8 @@ public class ServerImplTest { public void shutdown() {} }; ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, - SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS); + SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, censusCtxFactory, + GrpcUtil.STOPWATCH_SUPPLIER); server.start(); server.shutdown(); assertTrue(server.isShutdown()); @@ -176,7 +193,8 @@ public class ServerImplTest { } }; ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, - SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS); + SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, censusCtxFactory, + GrpcUtil.STOPWATCH_SUPPLIER); server.shutdown(); assertTrue(server.isShutdown()); assertTrue(server.isTerminated()); @@ -185,7 +203,8 @@ public class ServerImplTest { @Test public void startStopImmediateWithChildTransport() throws IOException { ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, - SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS); + SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, censusCtxFactory, + GrpcUtil.STOPWATCH_SUPPLIER); server.start(); class DelayedShutdownServerTransport extends SimpleServerTransport { boolean shutdown; @@ -209,7 +228,8 @@ public class ServerImplTest { @Test public void startShutdownNowImmediateWithChildTransport() throws IOException { ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, - SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS); + SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, censusCtxFactory, + GrpcUtil.STOPWATCH_SUPPLIER); server.start(); class DelayedShutdownServerTransport extends SimpleServerTransport { boolean shutdown; @@ -236,7 +256,8 @@ public class ServerImplTest { @Test public void shutdownNowAfterShutdown() throws IOException { ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, - SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS); + SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, censusCtxFactory, + GrpcUtil.STOPWATCH_SUPPLIER); server.start(); class DelayedShutdownServerTransport extends SimpleServerTransport { boolean shutdown; @@ -270,7 +291,8 @@ public class ServerImplTest { } }; ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, - SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS); + SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, censusCtxFactory, + GrpcUtil.STOPWATCH_SUPPLIER); server.start(); class DelayedShutdownServerTransport extends SimpleServerTransport { boolean shutdown; @@ -307,7 +329,7 @@ public class ServerImplTest { ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, new FailingStartupServer(), SERVER_CONTEXT, decompressorRegistry, compressorRegistry, - NO_FILTERS); + NO_FILTERS, censusCtxFactory, GrpcUtil.STOPWATCH_SUPPLIER); try { server.start(); fail("expected exception"); @@ -316,10 +338,41 @@ public class ServerImplTest { } } + @Test + public void methodNotFound() throws Exception { + ServerTransportListener transportListener + = transportServer.registerNewServerTransport(new SimpleServerTransport()); + Metadata requestHeaders = new Metadata(); + StatsTraceContext statsTraceCtx = + transportListener.methodDetermined("Waiter/nonexist", requestHeaders); + when(stream.statsTraceContext()).thenReturn(statsTraceCtx); + ServerStreamListener streamListener + = transportListener.streamCreated(stream, "Waiter/nonexist", requestHeaders); + assertNotNull(streamListener); + verify(stream, atLeast(1)).statsTraceContext(); + + executeBarrier(executor).await(); + verify(stream).close(statusCaptor.capture(), any(Metadata.class)); + Status status = statusCaptor.getValue(); + assertEquals(Status.Code.UNIMPLEMENTED, status.getCode()); + assertEquals("Method not found: Waiter/nonexist", status.getDescription()); + + CensusTestUtils.MetricsRecord record = censusCtxFactory.pollRecord(); + assertNotNull(record); + TagValue methodTag = record.tags.get(RpcConstants.RPC_SERVER_METHOD); + assertNotNull(methodTag); + assertEquals("Waiter/nonexist", methodTag.toString()); + TagValue statusTag = record.tags.get(RpcConstants.RPC_STATUS); + assertNotNull(statusTag); + assertEquals(Status.Code.UNIMPLEMENTED.toString(), statusTag.toString()); + } + @Test public void basicExchangeSuccessful() throws Exception { final Metadata.Key metadataKey = Metadata.Key.of("inception", Metadata.ASCII_STRING_MARSHALLER); + final Metadata.Key censusHeaderKey + = StatsTraceContext.createCensusHeader(censusCtxFactory); final AtomicReference> callReference = new AtomicReference>(); MethodDescriptor method = MethodDescriptor.create( @@ -346,9 +399,18 @@ public class ServerImplTest { Metadata requestHeaders = new Metadata(); requestHeaders.put(metadataKey, "value"); + CensusContext censusContextOnClient = censusCtxFactory.getDefault().with( + CensusTestUtils.EXTRA_TAG, new TagValue("extraTagValue")); + requestHeaders.put(censusHeaderKey, censusContextOnClient); + StatsTraceContext statsTraceCtx = + transportListener.methodDetermined("Waiter/serve", requestHeaders); + assertNotNull(statsTraceCtx); + when(stream.statsTraceContext()).thenReturn(statsTraceCtx); + ServerStreamListener streamListener = transportListener.streamCreated(stream, "Waiter/serve", requestHeaders); assertNotNull(streamListener); + verify(stream, atLeast(1)).statsTraceContext(); executeBarrier(executor).await(); ServerCall call = callReference.get(); @@ -389,8 +451,34 @@ public class ServerImplTest { executeBarrier(executor).await(); verify(callListener).onComplete(); + verify(stream, atLeast(1)).statsTraceContext(); verifyNoMoreInteractions(stream); verifyNoMoreInteractions(callListener); + + // Check stats + CensusTestUtils.MetricsRecord record = censusCtxFactory.pollRecord(); + assertNotNull(record); + TagValue methodTag = record.tags.get(RpcConstants.RPC_SERVER_METHOD); + assertNotNull(methodTag); + assertEquals("Waiter/serve", methodTag.toString()); + TagValue statusTag = record.tags.get(RpcConstants.RPC_STATUS); + assertNotNull(statusTag); + assertEquals(Status.Code.OK.toString(), statusTag.toString()); + TagValue extraTag = record.tags.get(CensusTestUtils.EXTRA_TAG); + assertNotNull(extraTag); + assertEquals("extraTagValue", extraTag.toString()); + assertNull(record.getMetric(RpcConstants.RPC_CLIENT_REQUEST_BYTES)); + assertNull(record.getMetric(RpcConstants.RPC_CLIENT_RESPONSE_BYTES)); + assertNull(record.getMetric(RpcConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES)); + assertNull(record.getMetric(RpcConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES)); + // The test doesn't invoke MessageFramer and MessageDeframer which keep the sizes. + // Thus the sizes reported to stats would be zero. + assertEquals(0, record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_REQUEST_BYTES)); + assertEquals(0, record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_RESPONSE_BYTES)); + assertEquals(0, + record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_UNCOMPRESSED_REQUEST_BYTES)); + assertEquals(0, + record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES)); } @Test @@ -454,7 +542,7 @@ public class ServerImplTest { ServerImpl server = new ServerImpl(MoreExecutors.directExecutor(), registry, fallbackRegistry, transportServer, SERVER_CONTEXT, decompressorRegistry, compressorRegistry, - ImmutableList.of(filter1, filter2)); + ImmutableList.of(filter1, filter2), censusCtxFactory, GrpcUtil.STOPWATCH_SUPPLIER); server.start(); ServerTransportListener transportListener = transportServer.registerNewServerTransport(new SimpleServerTransport()); @@ -493,14 +581,22 @@ public class ServerImplTest { ServerTransportListener transportListener = transportServer.registerNewServerTransport(new SimpleServerTransport()); + Metadata requestHeaders = new Metadata(); + StatsTraceContext statsTraceCtx = + transportListener.methodDetermined("Waiter/serve", requestHeaders); + assertNotNull(statsTraceCtx); + when(stream.statsTraceContext()).thenReturn(statsTraceCtx); + ServerStreamListener streamListener - = transportListener.streamCreated(stream, "Waiter/serve", new Metadata()); + = transportListener.streamCreated(stream, "Waiter/serve", requestHeaders); assertNotNull(streamListener); + verify(stream, atLeast(1)).statsTraceContext(); verifyNoMoreInteractions(stream); barrier.await(); executeBarrier(executor).await(); verify(stream).close(same(status), notNull(Metadata.class)); + verify(stream, atLeast(1)).statsTraceContext(); verifyNoMoreInteractions(stream); } @@ -526,7 +622,8 @@ public class ServerImplTest { transportServer = new MaybeDeadlockingServer(); ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, - SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS); + SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, censusCtxFactory, + GrpcUtil.STOPWATCH_SUPPLIER); server.start(); new Thread() { @Override @@ -589,6 +686,10 @@ public class ServerImplTest { public void testCallContextIsBoundInListenerCallbacks() throws Exception { MethodDescriptor method = MethodDescriptor.create( MethodType.UNKNOWN, "Waiter/serve", STRING_MARSHALLER, INTEGER_MARSHALLER); + final CountDownLatch onReadyCalled = new CountDownLatch(1); + final CountDownLatch onMessageCalled = new CountDownLatch(1); + final CountDownLatch onHalfCloseCalled = new CountDownLatch(1); + final CountDownLatch onCancelCalled = new CountDownLatch(1); fallbackRegistry.addService(ServerServiceDefinition.builder( new ServiceDescriptor("Waiter", method)) .addMethod( @@ -608,31 +709,30 @@ public class ServerImplTest { @Override public void onReady() { checkContext(); - super.onReady(); + onReadyCalled.countDown(); } @Override public void onMessage(String message) { checkContext(); - super.onMessage(message); + onMessageCalled.countDown(); } @Override public void onHalfClose() { checkContext(); - super.onHalfClose(); + onHalfCloseCalled.countDown(); } @Override public void onCancel() { checkContext(); - super.onCancel(); + onCancelCalled.countDown(); } @Override public void onComplete() { checkContext(); - super.onComplete(); } private void checkContext() { @@ -645,8 +745,14 @@ public class ServerImplTest { ServerTransportListener transportListener = transportServer.registerNewServerTransport(new SimpleServerTransport()); + Metadata requestHeaders = new Metadata(); + StatsTraceContext statsTraceCtx = + transportListener.methodDetermined("Waiter/serve", requestHeaders); + assertNotNull(statsTraceCtx); + when(stream.statsTraceContext()).thenReturn(statsTraceCtx); + ServerStreamListener streamListener - = transportListener.streamCreated(stream, "Waiter/serve", new Metadata()); + = transportListener.streamCreated(stream, "Waiter/serve", requestHeaders); assertNotNull(streamListener); streamListener.onReady(); @@ -654,6 +760,11 @@ public class ServerImplTest { streamListener.halfClosed(); streamListener.closed(Status.CANCELLED); + assertTrue(onReadyCalled.await(5, TimeUnit.SECONDS)); + assertTrue(onMessageCalled.await(5, TimeUnit.SECONDS)); + assertTrue(onHalfCloseCalled.await(5, TimeUnit.SECONDS)); + assertTrue(onCancelCalled.await(5, TimeUnit.SECONDS)); + // Close should never be called if asserts in listener pass. verify(stream, times(0)).close(isA(Status.class), isNotNull(Metadata.class)); } @@ -691,9 +802,13 @@ public class ServerImplTest { }).build()); ServerTransportListener transportListener = transportServer.registerNewServerTransport(new SimpleServerTransport()); - + Metadata requestHeaders = new Metadata(); + StatsTraceContext statsTraceCtx = + transportListener.methodDetermined("Waiter/serve", requestHeaders); + assertNotNull(statsTraceCtx); + when(stream.statsTraceContext()).thenReturn(statsTraceCtx); ServerStreamListener streamListener - = transportListener.streamCreated(stream, "Waiter/serve", new Metadata()); + = transportListener.streamCreated(stream, "Waiter/serve", requestHeaders); assertNotNull(streamListener); streamListener.onReady(); @@ -711,7 +826,8 @@ public class ServerImplTest { } }; ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, - SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS); + SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, censusCtxFactory, + GrpcUtil.STOPWATCH_SUPPLIER); server.start(); Truth.assertThat(server.getPort()).isEqualTo(65535); @@ -721,7 +837,8 @@ public class ServerImplTest { public void getPortBeforeStartedFails() { transportServer = new SimpleServer(); ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, - SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS); + SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, censusCtxFactory, + GrpcUtil.STOPWATCH_SUPPLIER); thrown.expect(IllegalStateException.class); thrown.expectMessage("started"); server.getPort(); @@ -731,7 +848,8 @@ public class ServerImplTest { public void getPortAfterTerminationFails() throws Exception { transportServer = new SimpleServer(); ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, - SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS); + SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, censusCtxFactory, + GrpcUtil.STOPWATCH_SUPPLIER); server.start(); server.shutdown(); server.awaitTermination(); @@ -751,16 +869,23 @@ public class ServerImplTest { .build(); transportServer = new SimpleServer(); ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, - SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS); + SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, censusCtxFactory, + GrpcUtil.STOPWATCH_SUPPLIER); server.start(); ServerTransportListener transportListener = transportServer.registerNewServerTransport(new SimpleServerTransport()); + Metadata requestHeaders = new Metadata(); + StatsTraceContext statsTraceCtx = + transportListener.methodDetermined("Waiter/serve", requestHeaders); + assertNotNull(statsTraceCtx); + when(stream.statsTraceContext()).thenReturn(statsTraceCtx); + // This call will be handled by callHandler from the internal registry - transportListener.streamCreated(stream, "Service1/Method1", new Metadata()); + transportListener.streamCreated(stream, "Service1/Method1", requestHeaders); // This call will be handled by the fallbackRegistry because it's not registred in the internal // registry. - transportListener.streamCreated(stream, "Service1/Method2", new Metadata()); + transportListener.streamCreated(stream, "Service1/Method2", requestHeaders); verify(callHandler, timeout(2000)).startCall(Matchers.>anyObject(), Matchers.anyObject()); diff --git a/core/src/test/java/io/grpc/internal/TestUtils.java b/core/src/test/java/io/grpc/internal/TestUtils.java index 9c887a13cf..b8c643e594 100644 --- a/core/src/test/java/io/grpc/internal/TestUtils.java +++ b/core/src/test/java/io/grpc/internal/TestUtils.java @@ -86,7 +86,8 @@ final class TestUtils { public ConnectionClientTransport answer(InvocationOnMock invocation) throws Throwable { final ConnectionClientTransport mockTransport = mock(ConnectionClientTransport.class); when(mockTransport.newStream(any(MethodDescriptor.class), any(Metadata.class), - any(CallOptions.class))).thenReturn(mock(ClientStream.class)); + any(CallOptions.class), any(StatsTraceContext.class))) + .thenReturn(mock(ClientStream.class)); // Save the listener doAnswer(new Answer() { @Override diff --git a/core/src/test/java/io/grpc/internal/TransportSetTest.java b/core/src/test/java/io/grpc/internal/TransportSetTest.java index ffc69ee803..e3088a6d64 100644 --- a/core/src/test/java/io/grpc/internal/TransportSetTest.java +++ b/core/src/test/java/io/grpc/internal/TransportSetTest.java @@ -101,6 +101,7 @@ public class TransportSetTest { private final Metadata headers = new Metadata(); private final CallOptions waitForReadyCallOptions = CallOptions.DEFAULT.withWaitForReady(); private final CallOptions failFastCallOptions = CallOptions.DEFAULT; + private final StatsTraceContext statsTraceCtx = StatsTraceContext.NOOP; private TransportSet transportSet; private EquivalentAddressGroup addressGroup; @@ -137,7 +138,8 @@ public class TransportSetTest { int onAllAddressesFailed = 0; // First attempt - transportSet.obtainActiveTransport().newStream(method, new Metadata(), waitForReadyCallOptions); + transportSet.obtainActiveTransport().newStream(method, new Metadata(), waitForReadyCallOptions, + statsTraceCtx); assertEquals(ConnectivityState.CONNECTING, transportSet.getState(false)); verify(mockTransportFactory, times(++transportsCreated)) .newClientTransport(addr, AUTHORITY, USER_AGENT); @@ -225,7 +227,7 @@ public class TransportSetTest { assertEquals(ConnectivityState.CONNECTING, transportSet.getState(false)); verify(mockTransportFactory, times(++transportsAddr1)) .newClientTransport(addr1, AUTHORITY, USER_AGENT); - delayedTransport1.newStream(method, new Metadata(), waitForReadyCallOptions); + delayedTransport1.newStream(method, new Metadata(), waitForReadyCallOptions, statsTraceCtx); // Let this one fail without success transports.poll().listener.transportShutdown(Status.UNAVAILABLE); assertEquals(ConnectivityState.CONNECTING, transportSet.getState(false)); @@ -320,7 +322,7 @@ public class TransportSetTest { (DelayedClientTransport) transportSet.obtainActiveTransport(); assertEquals(ConnectivityState.CONNECTING, transportSet.getState(false)); assertNotSame(delayedTransport5, delayedTransport6); - delayedTransport6.newStream(method, headers, waitForReadyCallOptions); + delayedTransport6.newStream(method, headers, waitForReadyCallOptions, statsTraceCtx); verify(mockBackoffPolicyProvider, times(backoffReset)).get(); verify(mockTransportFactory, times(++transportsAddr1)) .newClientTransport(addr1, AUTHORITY, USER_AGENT); @@ -387,13 +389,14 @@ public class TransportSetTest { assertFalse(delayedTransport.isInBackoffPeriod()); // Create a new fail fast stream. - ClientStream ffStream = delayedTransport.newStream(method, headers, failFastCallOptions); + ClientStream ffStream = delayedTransport.newStream(method, headers, failFastCallOptions, + statsTraceCtx); ffStream.start(mockStreamListener); // Verify it is queued. assertEquals(++pendingStreamsCount, delayedTransport.getPendingStreamsCount()); failFastPendingStreamsCount++; // Create a new non fail fast stream. - delayedTransport.newStream(method, headers, waitForReadyCallOptions); + delayedTransport.newStream(method, headers, waitForReadyCallOptions, statsTraceCtx); // Verify it is queued. assertEquals(++pendingStreamsCount, delayedTransport.getPendingStreamsCount()); @@ -405,12 +408,12 @@ public class TransportSetTest { assertEquals(pendingStreamsCount, delayedTransport.getPendingStreamsCount()); // Create a new fail fast stream. - delayedTransport.newStream(method, headers, failFastCallOptions); + delayedTransport.newStream(method, headers, failFastCallOptions, statsTraceCtx); // Verify it is queued. assertEquals(++pendingStreamsCount, delayedTransport.getPendingStreamsCount()); failFastPendingStreamsCount++; // Create a new non fail fast stream - delayedTransport.newStream(method, headers, waitForReadyCallOptions); + delayedTransport.newStream(method, headers, waitForReadyCallOptions, statsTraceCtx); // Verify it is queued. assertEquals(++pendingStreamsCount, delayedTransport.getPendingStreamsCount()); @@ -428,11 +431,11 @@ public class TransportSetTest { verify(mockStreamListener).closed(same(failureStatus), any(Metadata.class)); // Create a new fail fast stream. - delayedTransport.newStream(method, headers, failFastCallOptions); + delayedTransport.newStream(method, headers, failFastCallOptions, statsTraceCtx); // Verify it is not queued. assertEquals(pendingStreamsCount, delayedTransport.getPendingStreamsCount()); // Create a new non fail fast stream - delayedTransport.newStream(method, headers, waitForReadyCallOptions); + delayedTransport.newStream(method, headers, waitForReadyCallOptions, statsTraceCtx); // Verify it is queued. assertEquals(++pendingStreamsCount, delayedTransport.getPendingStreamsCount()); @@ -442,7 +445,7 @@ public class TransportSetTest { assertFalse(delayedTransport.isInBackoffPeriod()); // Create a new fail fast stream. - delayedTransport.newStream(method, headers, failFastCallOptions); + delayedTransport.newStream(method, headers, failFastCallOptions, statsTraceCtx); // Verify it is queued. assertEquals(++pendingStreamsCount, delayedTransport.getPendingStreamsCount()); failFastPendingStreamsCount++; @@ -487,7 +490,8 @@ public class TransportSetTest { assertEquals(ConnectivityState.IDLE, transportSet.getState(false)); // Request immediately - transportSet.obtainActiveTransport().newStream(method, new Metadata(), waitForReadyCallOptions); + transportSet.obtainActiveTransport().newStream(method, new Metadata(), waitForReadyCallOptions, + statsTraceCtx); assertEquals(ConnectivityState.CONNECTING, transportSet.getState(false)); verify(mockTransportFactory, times(++transportsCreated)) .newClientTransport(addr, AUTHORITY, USER_AGENT); @@ -514,7 +518,8 @@ public class TransportSetTest { pick = transportSet.obtainActiveTransport(); assertTrue(pick instanceof DelayedClientTransport); // Start a stream, which will be pending in the delayed transport - ClientStream pendingStream = pick.newStream(method, headers, waitForReadyCallOptions); + ClientStream pendingStream = pick.newStream(method, headers, waitForReadyCallOptions, + statsTraceCtx); pendingStream.start(mockStreamListener); // Shut down TransportSet before the transport is created. Further call to @@ -542,7 +547,7 @@ public class TransportSetTest { any(MethodDescriptor.class), any(Metadata.class)); assertEquals(1, fakeExecutor.runDueTasks()); verify(transportInfo.transport).newStream(same(method), same(headers), - same(waitForReadyCallOptions)); + same(waitForReadyCallOptions), any(StatsTraceContext.class)); verify(transportInfo.transport).shutdown(); transportInfo.listener.transportShutdown(Status.UNAVAILABLE); assertEquals(ConnectivityState.SHUTDOWN, transportSet.getState(false)); @@ -640,7 +645,8 @@ public class TransportSetTest { assertEquals(ConnectivityState.CONNECTING, transportSet.getState(true)); assertEquals(ConnectivityState.CONNECTING, transportSet.getState(true)); - transportSet.obtainActiveTransport().newStream(method, new Metadata(), waitForReadyCallOptions); + transportSet.obtainActiveTransport().newStream(method, new Metadata(), waitForReadyCallOptions, + statsTraceCtx); assertEquals(ConnectivityState.CONNECTING, transportSet.getState(true)); // Fail it @@ -698,7 +704,8 @@ public class TransportSetTest { int notInUse = 0; verify(mockTransportSetCallback, never()).onInUse(any(TransportSet.class)); - transportSet.obtainActiveTransport().newStream(method, new Metadata(), waitForReadyCallOptions); + transportSet.obtainActiveTransport().newStream(method, new Metadata(), waitForReadyCallOptions, + statsTraceCtx); verify(mockTransportSetCallback, times(++inUse)).onInUse(transportSet); MockClientTransportInfo t0 = transports.poll(); @@ -711,7 +718,8 @@ public class TransportSetTest { // Delayed transport calls newStream() on the real transport in the executor fakeExecutor.runDueTasks(); verify(t0.transport).newStream( - same(method), any(Metadata.class), same(waitForReadyCallOptions)); + same(method), any(Metadata.class), same(waitForReadyCallOptions), + any(StatsTraceContext.class)); verify(mockTransportSetCallback, times(inUse)).onInUse(transportSet); t0.listener.transportInUse(true); verify(mockTransportSetCallback, times(++inUse)).onInUse(transportSet); @@ -726,13 +734,15 @@ public class TransportSetTest { t0.listener.transportShutdown(Status.UNAVAILABLE); // Creates a new transport - transportSet.obtainActiveTransport().newStream(method, new Metadata(), waitForReadyCallOptions); + transportSet.obtainActiveTransport().newStream(method, new Metadata(), waitForReadyCallOptions, + statsTraceCtx); MockClientTransportInfo t1 = transports.poll(); t1.listener.transportReady(); // Delayed transport calls newStream() on the real transport in the executor fakeExecutor.runDueTasks(); verify(t1.transport).newStream( - same(method), any(Metadata.class), same(waitForReadyCallOptions)); + same(method), any(Metadata.class), same(waitForReadyCallOptions), + any(StatsTraceContext.class)); t1.listener.transportInUse(true); // No turbulance from the race mentioned eariler, because t0 has been in-use verify(mockTransportSetCallback, times(inUse)).onInUse(transportSet); @@ -769,7 +779,8 @@ public class TransportSetTest { // Attempt and fail, scheduleBackoff should be triggered, // and transportSet.shutdown should be triggered by setup - transportSet.obtainActiveTransport().newStream(method, new Metadata(), waitForReadyCallOptions); + transportSet.obtainActiveTransport().newStream(method, new Metadata(), waitForReadyCallOptions, + statsTraceCtx); transports.poll().listener.transportShutdown(Status.UNAVAILABLE); verify(mockTransportSetCallback, times(1)).onAllAddressesFailed(); assertTrue(startBackoffAndShutdownAreCalled[0]); diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java index 02dee58033..617c813847 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java @@ -46,12 +46,16 @@ import com.google.auth.oauth2.ComputeEngineCredentials; import com.google.auth.oauth2.GoogleCredentials; import com.google.auth.oauth2.OAuth2Credentials; import com.google.auth.oauth2.ServiceAccountCredentials; +import com.google.census.CensusContextFactory; +import com.google.census.RpcConstants; +import com.google.census.TagValue; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import com.google.common.net.HostAndPort; import com.google.protobuf.ByteString; import com.google.protobuf.EmptyProtos.Empty; +import com.google.protobuf.MessageLite; import io.grpc.CallOptions; import io.grpc.ClientCall; @@ -59,14 +63,16 @@ import io.grpc.Grpc; import io.grpc.ManagedChannel; import io.grpc.Metadata; import io.grpc.Server; -import io.grpc.ServerBuilder; import io.grpc.ServerCall; import io.grpc.ServerInterceptor; import io.grpc.ServerInterceptors; import io.grpc.Status; import io.grpc.StatusRuntimeException; import io.grpc.auth.MoreCallCredentials; +import io.grpc.internal.AbstractServerImplBuilder; import io.grpc.internal.GrpcUtil; +import io.grpc.internal.testing.CensusTestUtils.FakeCensusContextFactory; +import io.grpc.internal.testing.CensusTestUtils.MetricsRecord; import io.grpc.protobuf.ProtoUtils; import io.grpc.stub.MetadataUtils; import io.grpc.stub.StreamObserver; @@ -95,7 +101,10 @@ import java.io.IOException; import java.io.InputStream; import java.security.cert.Certificate; import java.security.cert.X509Certificate; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.Executors; @@ -120,9 +129,14 @@ public abstract class AbstractInteropTest { new AtomicReference(); private static ScheduledExecutorService testServiceExecutor; private static Server server; + private static final FakeCensusContextFactory clientCensusFactory = + new FakeCensusContextFactory(); + private static final FakeCensusContextFactory serverCensusFactory = + new FakeCensusContextFactory(); + protected static final Empty EMPTY = Empty.getDefaultInstance(); protected static void startStaticServer( - ServerBuilder builder, ServerInterceptor ... interceptors) { + AbstractServerImplBuilder builder, ServerInterceptor ... interceptors) { testServiceExecutor = Executors.newScheduledThreadPool(2); List allInterceptors = ImmutableList.builder() @@ -135,6 +149,7 @@ public abstract class AbstractInteropTest { builder.addService(ServerInterceptors.intercept( new TestServiceImpl(testServiceExecutor), allInterceptors)); + builder.censusContextFactory(serverCensusFactory); try { server = builder.build().start(); } catch (IOException ex) { @@ -165,6 +180,8 @@ public abstract class AbstractInteropTest { blockingStub = TestServiceGrpc.newBlockingStub(channel); asyncStub = TestServiceGrpc.newStub(channel); requestHeadersCapture.set(null); + clientCensusFactory.rolloverRecords(); + serverCensusFactory.rolloverRecords(); } /** Clean up. */ @@ -177,9 +194,17 @@ public abstract class AbstractInteropTest { protected abstract ManagedChannel createChannel(); + protected final CensusContextFactory getClientCensusFactory() { + return clientCensusFactory; + } + + protected boolean metricsExpected() { + return true; + } + @Test(timeout = 10000) public void emptyUnary() throws Exception { - assertEquals(Empty.getDefaultInstance(), blockingStub.emptyCall(Empty.getDefaultInstance())); + assertEquals(EMPTY, blockingStub.emptyCall(EMPTY)); } @Test(timeout = 10000) @@ -198,6 +223,11 @@ public abstract class AbstractInteropTest { .build(); assertEquals(goldenResponse, blockingStub.unaryCall(request)); + + if (metricsExpected()) { + assertMetrics("grpc.testing.TestService/UnaryCall", Status.Code.OK, + Collections.singleton(request), Collections.singleton(goldenResponse)); + } } @Test(timeout = 10000) @@ -273,6 +303,7 @@ public abstract class AbstractInteropTest { } requestObserver.onCompleted(); assertEquals(goldenResponse, responseObserver.firstValue().get()); + responseObserver.awaitCompletion(); } @Test(timeout = 10000) @@ -359,6 +390,11 @@ public abstract class AbstractInteropTest { assertEquals(Arrays.asList(), responseObserver.getValues()); assertEquals(Status.Code.CANCELLED, Status.fromThrowable(responseObserver.getError()).getCode()); + + if (metricsExpected()) { + assertClientMetrics("grpc.testing.TestService/StreamingInputCall", Status.Code.CANCELLED); + // Do not check server-side metrics, because the status on the server side is undetermined. + } } @Test(timeout = 10000) @@ -388,6 +424,10 @@ public abstract class AbstractInteropTest { verify(responseObserver, timeout(operationTimeoutMillis())).onError(captor.capture()); assertEquals(Status.Code.CANCELLED, Status.fromThrowable(captor.getValue()).getCode()); verifyNoMoreInteractions(responseObserver); + + if (metricsExpected()) { + assertMetrics("grpc.testing.TestService/FullDuplexCall", Status.Code.CANCELLED); + } } @Test(timeout = 10000) @@ -407,7 +447,10 @@ public abstract class AbstractInteropTest { asyncStub.fullDuplexCall(recorder); final int numRequests = 10; + List requests = + new ArrayList(numRequests); for (int ix = numRequests; ix > 0; --ix) { + requests.add(request); requestStream.onNext(request); } requestStream.onCompleted(); @@ -421,6 +464,11 @@ public abstract class AbstractInteropTest { int expectedSize = responseSizes.get(ix % responseSizes.size()); assertEquals("comparison failed at index " + ix, expectedSize, length); } + + if (metricsExpected()) { + assertMetrics("grpc.testing.TestService/FullDuplexCall", Status.Code.OK, requests, + recorder.getValues()); + } } @Test(timeout = 10000) @@ -439,7 +487,10 @@ public abstract class AbstractInteropTest { StreamObserver requestStream = asyncStub.halfDuplexCall(recorder); final int numRequests = 10; + List requests = + new ArrayList(numRequests); for (int ix = numRequests; ix > 0; --ix) { + requests.add(request); requestStream.onNext(request); } requestStream.onCompleted(); @@ -566,7 +617,7 @@ public abstract class AbstractInteropTest { AtomicReference headersCapture = new AtomicReference(); stub = MetadataUtils.captureMetadata(stub, headersCapture, trailersCapture); - assertNotNull(stub.emptyCall(Empty.getDefaultInstance())); + assertNotNull(stub.emptyCall(EMPTY)); // Assert that our side channel object is echoed back in both headers and trailers Assert.assertEquals(contextValue, headersCapture.get().get(METADATA_KEY)); @@ -603,7 +654,11 @@ public abstract class AbstractInteropTest { stub.fullDuplexCall(recorder); final int numRequests = 10; + List requests = + new ArrayList(numRequests); + for (int ix = numRequests; ix > 0; --ix) { + requests.add(request); requestStream.onNext(request); } requestStream.onCompleted(); @@ -621,7 +676,7 @@ public abstract class AbstractInteropTest { long configuredTimeoutMinutes = 100; TestServiceGrpc.TestServiceBlockingStub stub = TestServiceGrpc.newBlockingStub(channel) .withDeadlineAfter(configuredTimeoutMinutes, TimeUnit.MINUTES); - stub.emptyCall(Empty.getDefaultInstance()); + stub.emptyCall(EMPTY); long transferredTimeoutMinutes = TimeUnit.NANOSECONDS.toMinutes( requestHeadersCapture.get().get(GrpcUtil.TIMEOUT_KEY)); Assert.assertTrue( @@ -649,15 +704,22 @@ public abstract class AbstractInteropTest { blockingStub.emptyCall(Empty.getDefaultInstance()); TestServiceGrpc.TestServiceBlockingStub stub = TestServiceGrpc.newBlockingStub(channel) .withDeadlineAfter(10, TimeUnit.MILLISECONDS); + StreamingOutputCallRequest request = StreamingOutputCallRequest.newBuilder() + .addResponseParameters(ResponseParameters.newBuilder() + .setIntervalUs(20000)) + .build(); try { - stub.streamingOutputCall(StreamingOutputCallRequest.newBuilder() - .addResponseParameters(ResponseParameters.newBuilder() - .setIntervalUs(20000)) - .build()).next(); + stub.streamingOutputCall(request).next(); fail("Expected deadline to be exceeded"); } catch (StatusRuntimeException ex) { assertEquals(Status.DEADLINE_EXCEEDED.getCode(), ex.getStatus().getCode()); } + if (metricsExpected()) { + assertMetrics("grpc.testing.TestService/EmptyCall", Status.Code.OK); + assertClientMetrics("grpc.testing.TestService/StreamingOutputCall", + Status.Code.DEADLINE_EXCEEDED); + // Do not check server-side metrics, because the status on the server side is undetermined. + } } @Test(timeout = 10000) @@ -681,6 +743,12 @@ public abstract class AbstractInteropTest { recorder.awaitCompletion(); assertEquals(Status.DEADLINE_EXCEEDED.getCode(), Status.fromThrowable(recorder.getError()).getCode()); + if (metricsExpected()) { + assertMetrics("grpc.testing.TestService/EmptyCall", Status.Code.OK); + assertClientMetrics("grpc.testing.TestService/StreamingOutputCall", + Status.Code.DEADLINE_EXCEEDED); + // Do not check server-side metrics, because the status on the server side is undetermined. + } } @Test(timeout = 10000) @@ -690,9 +758,13 @@ public abstract class AbstractInteropTest { TestServiceGrpc.newBlockingStub(channel) .withDeadlineAfter(-10, TimeUnit.SECONDS) .emptyCall(Empty.getDefaultInstance()); + fail("Should have thrown"); } catch (StatusRuntimeException ex) { assertEquals(Status.Code.DEADLINE_EXCEEDED, ex.getStatus().getCode()); } + if (metricsExpected()) { + assertClientMetrics("grpc.testing.TestService/EmptyCall", Status.Code.DEADLINE_EXCEEDED); + } // warm up the channel blockingStub.emptyCall(Empty.getDefaultInstance()); @@ -700,9 +772,14 @@ public abstract class AbstractInteropTest { TestServiceGrpc.newBlockingStub(channel) .withDeadlineAfter(-10, TimeUnit.SECONDS) .emptyCall(Empty.getDefaultInstance()); + fail("Should have thrown"); } catch (StatusRuntimeException ex) { assertEquals(Status.Code.DEADLINE_EXCEEDED, ex.getStatus().getCode()); } + if (metricsExpected()) { + assertMetrics("grpc.testing.TestService/EmptyCall", Status.Code.OK); + assertClientMetrics("grpc.testing.TestService/EmptyCall", Status.Code.DEADLINE_EXCEEDED); + } } protected int unaryPayloadLength() { @@ -777,6 +854,11 @@ public abstract class AbstractInteropTest { } catch (StatusRuntimeException e) { assertEquals(Status.UNIMPLEMENTED.getCode(), e.getStatus().getCode()); } + + if (metricsExpected()) { + assertMetrics("grpc.testing.UnimplementedService/UnimplementedCall", + Status.Code.UNIMPLEMENTED); + } } /** Start a fullDuplexCall which the server will not respond, and verify the deadline expires. */ @@ -789,11 +871,12 @@ public abstract class AbstractInteropTest { StreamObserver requestObserver = stub.fullDuplexCall(responseObserver); + StreamingOutputCallRequest request = StreamingOutputCallRequest.newBuilder() + .setPayload(Payload.newBuilder() + .setBody(ByteString.copyFrom(new byte[27182]))) + .build(); try { - requestObserver.onNext(StreamingOutputCallRequest.newBuilder() - .setPayload(Payload.newBuilder() - .setBody(ByteString.copyFrom(new byte[27182]))) - .build()); + requestObserver.onNext(request); } catch (IllegalStateException expected) { // This can happen if the stream has already been terminated due to deadline exceeded. } @@ -803,6 +886,11 @@ public abstract class AbstractInteropTest { assertEquals(Status.DEADLINE_EXCEEDED.getCode(), Status.fromThrowable(captor.getValue()).getCode()); verifyNoMoreInteractions(responseObserver); + + if (metricsExpected()) { + assertClientMetrics("grpc.testing.TestService/FullDuplexCall", Status.Code.DEADLINE_EXCEEDED); + // Do not check server-side metrics, because the status on the server side is undetermined. + } } /** Sends a large unary rpc with service account credentials. */ @@ -1020,4 +1108,115 @@ public abstract class AbstractInteropTest { throw e; } } + + /** + * Poll the next metrics record and check it against the provided information, including the + * message sizes. + */ + private void assertMetrics(String method, Status.Code status, + Collection requests, + Collection responses) { + assertClientMetrics(method, status, requests, responses); + assertServerMetrics(method, status, requests, responses); + } + + /** + * Poll the next metrics record and check it against the provided information, without checking + * the message sizes. + */ + private void assertMetrics(String method, Status.Code status) { + assertMetrics(method, status, null, null); + } + + private void assertClientMetrics(String method, Status.Code status, + Collection requests, Collection responses) { + MetricsRecord clientRecord = clientCensusFactory.pollRecord(); + assertNotNull("clientRecord is not null", clientRecord); + checkTags(clientRecord, false, method, status); + if (requests != null && responses != null) { + checkMetrics(clientRecord, false, requests, responses); + } + } + + private void assertClientMetrics(String method, Status.Code status) { + assertClientMetrics(method, status, null, null); + } + + private void assertServerMetrics(String method, Status.Code status, + Collection requests, Collection responses) { + AssertionError checkFailure = null; + // Because the server doesn't restart between tests, it may still be processing the requests + // from the previous tests when a new test starts, thus the test may see metrics from previous + // tests. The best we can do here is to exhaust all records and find one that matches the given + // conditions. + while (true) { + MetricsRecord serverRecord; + try { + // On the server, the stats is finalized in ServerStreamListener.closed(), which can be run + // after the client receives the final status. So we use a timeout. + serverRecord = serverCensusFactory.pollRecord(1, TimeUnit.SECONDS); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + if (serverRecord == null) { + break; + } + try { + checkTags(serverRecord, true, method, status); + if (requests != null && responses != null) { + checkMetrics(serverRecord, true, requests, responses); + } + return; // passed + } catch (AssertionError e) { + // May be the fallout from a previous test, continue trying + checkFailure = e; + } + } + if (checkFailure == null) { + throw new AssertionError("No record found"); + } + throw checkFailure; + } + + private static void checkTags( + MetricsRecord record, boolean server, String methodName, Status.Code status) { + TagValue methodNameTag = record.tags.get( + server ? RpcConstants.RPC_SERVER_METHOD : RpcConstants.RPC_CLIENT_METHOD); + assertNotNull("method name tagged", methodNameTag); + assertEquals("method names match", methodName, methodNameTag.toString()); + TagValue statusTag = record.tags.get(RpcConstants.RPC_STATUS); + assertNotNull("status tagged", statusTag); + assertEquals(status.toString(), statusTag.toString()); + } + + private static void checkMetrics(MetricsRecord record, boolean server, + Collection requests, Collection responses) { + int uncompressedRequestsSize = 0; + for (MessageLite request : requests) { + uncompressedRequestsSize += request.getSerializedSize(); + } + int uncompressedResponsesSize = 0; + for (MessageLite response : responses) { + uncompressedResponsesSize += response.getSerializedSize(); + } + if (server) { + assertEquals(uncompressedRequestsSize, + record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_UNCOMPRESSED_REQUEST_BYTES)); + assertEquals(uncompressedResponsesSize, + record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES)); + // It's impossible to get the expected wire sizes because it may be compressed, so we just + // check if they are recorded. + assertNotNull(record.getMetric(RpcConstants.RPC_SERVER_REQUEST_BYTES)); + assertNotNull(record.getMetric(RpcConstants.RPC_SERVER_RESPONSE_BYTES)); + } else { + assertEquals(uncompressedRequestsSize, + record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES)); + assertEquals(uncompressedResponsesSize, + record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES)); + // It's impossible to get the expected wire sizes because it may be compressed, so we just + // check if they are recorded. + assertNotNull(record.getMetric(RpcConstants.RPC_CLIENT_REQUEST_BYTES)); + assertNotNull(record.getMetric(RpcConstants.RPC_CLIENT_RESPONSE_BYTES)); + } + } } diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/StressTestClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/StressTestClient.java index fa8af60168..c6f3711dcb 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/StressTestClient.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/StressTestClient.java @@ -484,6 +484,12 @@ public class StressTestClient { // Fixes https://github.com/grpc/grpc-java/issues/1812 return Integer.MAX_VALUE; } + + @Override + protected boolean metricsExpected() { + // TODO(zhangkun83): we may want to enable the real Census implementation in stress tests. + return false; + } } class WeightedTestCaseSelector { diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java index 4cfb052962..c9fbe7cd38 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java @@ -314,6 +314,7 @@ public class TestServiceClient { .flowControlWindow(65 * 1024) .negotiationType(useTls ? NegotiationType.TLS : NegotiationType.PLAINTEXT) .sslContext(sslContext) + .censusContextFactory(getClientCensusFactory()) .build(); } else { OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress(serverHost, serverPort); diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/AutoWindowSizingOnTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/AutoWindowSizingOnTest.java index c7fc8817fb..03477fed66 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/AutoWindowSizingOnTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/AutoWindowSizingOnTest.java @@ -63,6 +63,7 @@ public class AutoWindowSizingOnTest extends AbstractInteropTest { return NettyChannelBuilder.forAddress("localhost", getPort()) .negotiationType(NegotiationType.PLAINTEXT) .maxMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE) + .censusContextFactory(getClientCensusFactory()) .build(); } } diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/Http2NettyLocalChannelTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/Http2NettyLocalChannelTest.java index daa7289269..0504c44e89 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/Http2NettyLocalChannelTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/Http2NettyLocalChannelTest.java @@ -75,6 +75,7 @@ public class Http2NettyLocalChannelTest extends AbstractInteropTest { .channelType(LocalChannel.class) .flowControlWindow(65 * 1024) .maxMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE) + .censusContextFactory(getClientCensusFactory()) .build(); } } diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/Http2NettyTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/Http2NettyTest.java index 7e11b01099..8270ac6b86 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/Http2NettyTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/Http2NettyTest.java @@ -92,6 +92,7 @@ public class Http2NettyTest extends AbstractInteropTest { .ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE) .sslProvider(SslProvider.OPENSSL) .build()) + .censusContextFactory(getClientCensusFactory()) .build(); } catch (Exception ex) { throw new RuntimeException(ex); diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/Http2OkHttpTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/Http2OkHttpTest.java index 936cd1546a..bf750796be 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/Http2OkHttpTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/Http2OkHttpTest.java @@ -31,6 +31,7 @@ package io.grpc.testing.integration; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -107,6 +108,7 @@ public class Http2OkHttpTest extends AbstractInteropTest { .cipherSuites(TestUtils.preferredTestCiphers().toArray(new String[0])) .tlsVersions(ConnectionSpec.MODERN_TLS.tlsVersions().toArray(new TlsVersion[0])) .build()) + .censusContextFactory(getClientCensusFactory()) .overrideAuthority(GrpcUtil.authorityFromHostAndPort( TestUtils.TEST_SERVER_HOST, getPort())); try { @@ -133,12 +135,14 @@ public class Http2OkHttpTest extends AbstractInteropTest { StreamRecorder recorder = StreamRecorder.create(); StreamObserver requestStream = asyncStub.fullDuplexCall(recorder); - requestStream.onNext(requestBuilder.build()); + Messages.StreamingOutputCallRequest request = requestBuilder.build(); + requestStream.onNext(request); recorder.firstValue().get(); requestStream.onError(new Exception("failed")); recorder.awaitCompletion(); - emptyUnary(); + + assertEquals(EMPTY, blockingStub.emptyCall(EMPTY)); } @Test(timeout = 10000) diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/InProcessTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/InProcessTest.java index 7f4aa2ef53..2396288662 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/InProcessTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/InProcessTest.java @@ -58,6 +58,14 @@ public class InProcessTest extends AbstractInteropTest { @Override protected ManagedChannel createChannel() { - return InProcessChannelBuilder.forName(serverName).build(); + return InProcessChannelBuilder.forName(serverName) + .censusContextFactory(getClientCensusFactory()).build(); + } + + @Override + protected boolean metricsExpected() { + // TODO(zhangkun83): InProcessTransport by-passes framer and deframer, thus message sizses are + // not counted. (https://github.com/grpc/grpc-java/issues/2284) + return false; } } diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/TransportCompressionTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/TransportCompressionTest.java index 1330cf87ab..0efdcf999a 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/TransportCompressionTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/TransportCompressionTest.java @@ -152,6 +152,7 @@ public class TransportCompressionTest extends AbstractInteropTest { .maxMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE) .decompressorRegistry(decompressors) .compressorRegistry(compressors) + .censusContextFactory(getClientCensusFactory()) .intercept(new ClientInterceptor() { @Override public ClientCall interceptCall( diff --git a/netty/src/main/java/io/grpc/netty/NettyClientStream.java b/netty/src/main/java/io/grpc/netty/NettyClientStream.java index c16e39a668..c2f5637f11 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientStream.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientStream.java @@ -45,6 +45,7 @@ import io.grpc.internal.AbstractClientStream2; import io.grpc.internal.ClientStreamListener; import io.grpc.internal.GrpcUtil; import io.grpc.internal.Http2ClientStreamTransportState; +import io.grpc.internal.StatsTraceContext; import io.grpc.internal.WritableBuffer; import io.netty.buffer.ByteBuf; import io.netty.channel.Channel; @@ -77,8 +78,8 @@ class NettyClientStream extends AbstractClientStream2 { NettyClientStream(TransportState state, MethodDescriptor method, Metadata headers, Channel channel, AsciiString authority, AsciiString scheme, - AsciiString userAgent) { - super(new NettyWritableBufferAllocator(channel.alloc())); + AsciiString userAgent, StatsTraceContext statsTraceCtx) { + super(new NettyWritableBufferAllocator(channel.alloc()), statsTraceCtx); this.state = checkNotNull(state, "transportState"); this.writeQueue = state.handler.getWriteQueue(); this.method = checkNotNull(method, "method"); @@ -183,8 +184,9 @@ class NettyClientStream extends AbstractClientStream2 { private int id; private Http2Stream http2Stream; - public TransportState(NettyClientHandler handler, int maxMessageSize) { - super(maxMessageSize); + public TransportState(NettyClientHandler handler, int maxMessageSize, + StatsTraceContext statsTraceCtx) { + super(maxMessageSize, statsTraceCtx); this.handler = checkNotNull(handler, "handler"); } diff --git a/netty/src/main/java/io/grpc/netty/NettyClientTransport.java b/netty/src/main/java/io/grpc/netty/NettyClientTransport.java index 51c38f950f..a75ef18c6c 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientTransport.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientTransport.java @@ -45,6 +45,7 @@ import io.grpc.internal.ClientStream; import io.grpc.internal.ConnectionClientTransport; import io.grpc.internal.GrpcUtil; import io.grpc.internal.Http2Ping; +import io.grpc.internal.StatsTraceContext; import io.netty.bootstrap.Bootstrap; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; @@ -116,23 +117,25 @@ class NettyClientTransport implements ConnectionClientTransport { } @Override - public ClientStream newStream(MethodDescriptor method, Metadata headers, CallOptions - callOptions) { + public ClientStream newStream(MethodDescriptor method, Metadata headers, + CallOptions callOptions, StatsTraceContext statsTraceCtx) { Preconditions.checkNotNull(method, "method"); Preconditions.checkNotNull(headers, "headers"); + Preconditions.checkNotNull(statsTraceCtx, "statsTraceCtx"); return new NettyClientStream( - new NettyClientStream.TransportState(handler, maxMessageSize) { + new NettyClientStream.TransportState(handler, maxMessageSize, statsTraceCtx) { @Override protected Status statusFromFailedFuture(ChannelFuture f) { return NettyClientTransport.this.statusFromFailedFuture(f); } }, - method, headers, channel, authority, negotiationHandler.scheme(), userAgent); + method, headers, channel, authority, negotiationHandler.scheme(), userAgent, + statsTraceCtx); } @Override public ClientStream newStream(MethodDescriptor method, Metadata headers) { - return newStream(method, headers, CallOptions.DEFAULT); + return newStream(method, headers, CallOptions.DEFAULT, StatsTraceContext.NOOP); } @Override diff --git a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java index b3bae73296..9c3a37949c 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java @@ -49,6 +49,7 @@ import io.grpc.Status; import io.grpc.internal.GrpcUtil; import io.grpc.internal.ServerStreamListener; import io.grpc.internal.ServerTransportListener; +import io.grpc.internal.StatsTraceContext; import io.grpc.netty.GrpcHttp2HeadersDecoder.GrpcHttp2ServerHeadersDecoder; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelFuture; @@ -193,14 +194,14 @@ class NettyServerHandler extends AbstractNettyHandler { // method. Http2Stream http2Stream = requireHttp2Stream(streamId); - NettyServerStream.TransportState state = - new NettyServerStream.TransportState(this, http2Stream, maxMessageSize); - NettyServerStream stream = new NettyServerStream(ctx.channel(), state, attributes); - Metadata metadata = Utils.convertHeaders(headers); - - ServerStreamListener listener = - transportListener.streamCreated(stream, method, metadata); + StatsTraceContext statsTraceCtx = + checkNotNull(transportListener.methodDetermined(method, metadata), "statsTraceCtx"); + NettyServerStream.TransportState state = new NettyServerStream.TransportState( + this, http2Stream, maxMessageSize, statsTraceCtx); + NettyServerStream stream = new NettyServerStream(ctx.channel(), state, attributes, + statsTraceCtx); + ServerStreamListener listener = transportListener.streamCreated(stream, method, metadata); // TODO(ejona): this could be racy since stream could have been used before getting here. All // cases appear to be fine, but some are almost only by happenstance and it is difficult to // audit. It would be good to improve the API to be less prone to races. diff --git a/netty/src/main/java/io/grpc/netty/NettyServerStream.java b/netty/src/main/java/io/grpc/netty/NettyServerStream.java index 81bd69429e..cef513a09d 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerStream.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerStream.java @@ -37,6 +37,7 @@ import io.grpc.Attributes; import io.grpc.Metadata; import io.grpc.Status; import io.grpc.internal.AbstractServerStream; +import io.grpc.internal.StatsTraceContext; import io.grpc.internal.WritableBuffer; import io.netty.buffer.ByteBuf; import io.netty.channel.Channel; @@ -61,8 +62,9 @@ class NettyServerStream extends AbstractServerStream { private final WriteQueue writeQueue; private final Attributes attributes; - public NettyServerStream(Channel channel, TransportState state, Attributes transportAttrs) { - super(new NettyWritableBufferAllocator(channel.alloc())); + public NettyServerStream(Channel channel, TransportState state, Attributes transportAttrs, + StatsTraceContext statsTraceCtx) { + super(new NettyWritableBufferAllocator(channel.alloc()), statsTraceCtx); this.state = checkNotNull(state, "transportState"); this.channel = checkNotNull(channel, "channel"); this.writeQueue = state.handler.getWriteQueue(); @@ -142,8 +144,9 @@ class NettyServerStream extends AbstractServerStream { private final Http2Stream http2Stream; private final NettyServerHandler handler; - public TransportState(NettyServerHandler handler, Http2Stream http2Stream, int maxMessageSize) { - super(maxMessageSize); + public TransportState(NettyServerHandler handler, Http2Stream http2Stream, int maxMessageSize, + StatsTraceContext statsTraceCtx) { + super(maxMessageSize, statsTraceCtx); this.http2Stream = checkNotNull(http2Stream, "http2Stream"); this.handler = checkNotNull(handler, "handler"); } diff --git a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java index 3b1f5c9d02..fe6113305d 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java @@ -64,6 +64,7 @@ import io.grpc.internal.ClientStreamListener; import io.grpc.internal.ClientTransport; import io.grpc.internal.ClientTransport.PingCallback; import io.grpc.internal.GrpcUtil; +import io.grpc.internal.StatsTraceContext; import io.grpc.netty.GrpcHttp2HeadersDecoder.GrpcHttp2ClientHeadersDecoder; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; @@ -568,7 +569,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase cmdCap = ArgumentCaptor.forClass(CreateStreamCommand.class); @@ -404,7 +405,8 @@ public class NettyClientStreamTest extends NettyStreamTestBase transports = new ArrayList(); private final NioEventLoopGroup group = new NioEventLoopGroup(1); + private final StatsTraceContext statsTraceCtx = StatsTraceContext.NOOP; + private InetSocketAddress address; private String authority; private NettyServer server; @@ -109,6 +113,7 @@ public class NettyClientTransportTest { @After public void teardown() throws Exception { + Context.ROOT.attach(); for (NettyClientTransport transport : transports) { transport.shutdown(); } @@ -433,6 +438,10 @@ public class NettyClientTransportTest { public ServerTransportListener transportCreated(final ServerTransport transport) { transports.add((NettyServerTransport) transport); return new ServerTransportListener() { + @Override + public StatsTraceContext methodDetermined(String method, Metadata headers) { + return StatsTraceContext.NOOP; + } @Override public ServerStreamListener streamCreated(final ServerStream stream, String method, diff --git a/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java b/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java index 712a9929f4..fd3ef962f3 100644 --- a/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java +++ b/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java @@ -40,6 +40,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import io.grpc.internal.MessageFramer; +import io.grpc.internal.StatsTraceContext; import io.grpc.internal.WritableBuffer; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; @@ -159,7 +160,7 @@ public abstract class NettyHandlerTestBase { compressionFrame.writeBytes(bytebuf); } } - }, new NettyWritableBufferAllocator(ByteBufAllocator.DEFAULT)); + }, new NettyWritableBufferAllocator(ByteBufAllocator.DEFAULT), StatsTraceContext.NOOP); framer.writePayload(new ByteArrayInputStream(content)); framer.flush(); ChannelHandlerContext ctx = newMockContext(); diff --git a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java index cb81fb9fdb..f59e16ba99 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java @@ -62,6 +62,7 @@ import io.grpc.internal.GrpcUtil; import io.grpc.internal.ServerStream; import io.grpc.internal.ServerStreamListener; import io.grpc.internal.ServerTransportListener; +import io.grpc.internal.StatsTraceContext; import io.grpc.netty.GrpcHttp2HeadersDecoder.GrpcHttp2ServerHeadersDecoder; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; @@ -103,6 +104,8 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase method, final Metadata - headers, CallOptions callOptions) { + public OkHttpClientStream newStream(final MethodDescriptor method, + final Metadata headers, CallOptions callOptions, StatsTraceContext statsTraceCtx) { Preconditions.checkNotNull(method, "method"); Preconditions.checkNotNull(headers, "headers"); + Preconditions.checkNotNull(statsTraceCtx, "statsTraceCtx"); return new OkHttpClientStream(method, headers, frameWriter, OkHttpClientTransport.this, - outboundFlow, lock, maxMessageSize, defaultAuthority, userAgent); + outboundFlow, lock, maxMessageSize, defaultAuthority, userAgent, statsTraceCtx); } @Override public OkHttpClientStream newStream(final MethodDescriptor method, final Metadata headers) { - return newStream(method, headers, CallOptions.DEFAULT); + return newStream(method, headers, CallOptions.DEFAULT, StatsTraceContext.NOOP); } @GuardedBy("lock") diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java index 8b07ab650a..59b3943e53 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java @@ -44,6 +44,7 @@ import io.grpc.MethodDescriptor.MethodType; import io.grpc.Status; import io.grpc.internal.ClientStreamListener; import io.grpc.internal.GrpcUtil; +import io.grpc.internal.StatsTraceContext; import io.grpc.okhttp.internal.framed.ErrorCode; import io.grpc.okhttp.internal.framed.Header; @@ -84,7 +85,7 @@ public class OkHttpClientStreamTest { methodDescriptor = MethodDescriptor.create( MethodType.UNARY, "/testService/test", marshaller, marshaller); stream = new OkHttpClientStream(methodDescriptor, new Metadata(), frameWriter, transport, - flowController, lock, MAX_MESSAGE_SIZE, "localhost", "userAgent"); + flowController, lock, MAX_MESSAGE_SIZE, "localhost", "userAgent", StatsTraceContext.NOOP); } @Test @@ -140,7 +141,8 @@ public class OkHttpClientStreamTest { Metadata metaData = new Metadata(); metaData.put(GrpcUtil.USER_AGENT_KEY, "misbehaving-application"); stream = new OkHttpClientStream(methodDescriptor, metaData, frameWriter, transport, - flowController, lock, MAX_MESSAGE_SIZE, "localhost", "good-application"); + flowController, lock, MAX_MESSAGE_SIZE, "localhost", "good-application", + StatsTraceContext.NOOP); stream.start(new BaseClientStreamListener()); stream.start(3); @@ -154,7 +156,8 @@ public class OkHttpClientStreamTest { Metadata metaData = new Metadata(); metaData.put(GrpcUtil.USER_AGENT_KEY, "misbehaving-application"); stream = new OkHttpClientStream(methodDescriptor, metaData, frameWriter, transport, - flowController, lock, MAX_MESSAGE_SIZE, "localhost", "good-application"); + flowController, lock, MAX_MESSAGE_SIZE, "localhost", "good-application", + StatsTraceContext.NOOP); stream.start(new BaseClientStreamListener()); stream.start(3); diff --git a/testing/src/main/java/io/grpc/internal/testing/AbstractTransportTest.java b/testing/src/main/java/io/grpc/internal/testing/AbstractTransportTest.java index 159dde6e49..5345b36876 100644 --- a/testing/src/main/java/io/grpc/internal/testing/AbstractTransportTest.java +++ b/testing/src/main/java/io/grpc/internal/testing/AbstractTransportTest.java @@ -72,6 +72,7 @@ import io.grpc.internal.ServerStream; import io.grpc.internal.ServerStreamListener; import io.grpc.internal.ServerTransport; import io.grpc.internal.ServerTransportListener; +import io.grpc.internal.StatsTraceContext; import org.junit.After; import org.junit.Before; @@ -803,7 +804,7 @@ public abstract class AbstractTransportTest { verify(mockClientStreamListener, never()).closed(any(Status.class), any(Metadata.class)); } - @Test(timeout = 5000) + @Test public void clientCancelFromWithinMessageRead() throws Exception { server.start(serverListener); client = newClientTransport(server); @@ -852,7 +853,7 @@ public abstract class AbstractTransportTest { serverStream.flush(); // Block until closedCalled was set. - closedCalled.get(); + closedCalled.get(5, TimeUnit.SECONDS); serverStream.close(Status.OK, new Metadata()); } @@ -1156,6 +1157,11 @@ public abstract class AbstractTransportTest { this.transport = transport; } + @Override + public StatsTraceContext methodDetermined(String method, Metadata headers) { + return StatsTraceContext.NOOP; + } + @Override public ServerStreamListener streamCreated(ServerStream stream, String method, Metadata headers) { diff --git a/testing/src/main/java/io/grpc/internal/testing/CensusTestUtils.java b/testing/src/main/java/io/grpc/internal/testing/CensusTestUtils.java new file mode 100644 index 0000000000..bc3ea14321 --- /dev/null +++ b/testing/src/main/java/io/grpc/internal/testing/CensusTestUtils.java @@ -0,0 +1,246 @@ +/* + * Copyright 2016, Google Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package io.grpc.internal.testing; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.census.CensusContext; +import com.google.census.CensusContextFactory; +import com.google.census.Metric; +import com.google.census.MetricMap; +import com.google.census.MetricName; +import com.google.census.TagKey; +import com.google.census.TagValue; +import com.google.common.collect.ImmutableMap; + +import io.grpc.Context; + +import java.nio.ByteBuffer; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import javax.annotation.Nullable; + +public class CensusTestUtils { + private CensusTestUtils() { + } + + public static class MetricsRecord { + public final ImmutableMap tags; + public final MetricMap metrics; + + private MetricsRecord(ImmutableMap tags, MetricMap metrics) { + this.tags = tags; + this.metrics = metrics; + } + + /** + * Returns the value of a metric, or {@code null} if not found. + */ + @Nullable + public Double getMetric(MetricName metricName) { + for (Metric m : metrics) { + if (m.getName().equals(metricName)) { + return m.getValue(); + } + } + return null; + } + + /** + * Returns the value of a metric converted to long, or throw if not found. + */ + public long getMetricAsLongOrFail(MetricName metricName) { + Double doubleValue = getMetric(metricName); + checkNotNull(doubleValue, "Metric not found: %s", metricName.toString()); + long longValue = (long) (Math.abs(doubleValue) + 0.0001); + if (doubleValue < 0) { + longValue = -longValue; + } + return longValue; + } + } + + public static final TagKey EXTRA_TAG = new TagKey("/rpc/test/extratag"); + + private static final String EXTRA_TAG_HEADER_VALUE_PREFIX = "extratag:"; + private static final String NO_EXTRA_TAG_HEADER_VALUE_PREFIX = "noextratag"; + + /** + * A factory that makes fake {@link CensusContext}s and saves the created contexts to be + * accessible from {@link #pollContextOrFail}. The contexts it has created would save metrics + * records to be accessible from {@link #pollRecord()} and {@link #pollRecord(long, TimeUnit)}, + * until {@link #rolloverRecords} is called. + */ + public static final class FakeCensusContextFactory extends CensusContextFactory { + private BlockingQueue records; + public final BlockingQueue contexts = + new LinkedBlockingQueue(); + private static final Context.Key CONTEXT_KEY = + Context.key("fakeCensusContext"); + private final FakeCensusContext defaultContext; + + /** + * Constructor. + */ + public FakeCensusContextFactory() { + rolloverRecords(); + defaultContext = new FakeCensusContext(ImmutableMap.of(), this); + // The records on the default context is not visible from pollRecord(), just like it's + // not visible from pollContextOrFail() either. + rolloverRecords(); + } + + public CensusContext pollContextOrFail() { + CensusContext cc = contexts.poll(); + return checkNotNull(cc); + } + + public MetricsRecord pollRecord() { + return getCurrentRecordSink().poll(); + } + + public MetricsRecord pollRecord(long timeout, TimeUnit unit) throws InterruptedException { + return getCurrentRecordSink().poll(timeout, unit); + } + + @Override + public CensusContext deserialize(ByteBuffer buffer) { + String serializedString = new String(buffer.array()); + if (serializedString.startsWith(EXTRA_TAG_HEADER_VALUE_PREFIX)) { + return getDefault().with(EXTRA_TAG, + new TagValue(serializedString.substring(EXTRA_TAG_HEADER_VALUE_PREFIX.length()))); + } else if (serializedString.startsWith(NO_EXTRA_TAG_HEADER_VALUE_PREFIX)) { + return getDefault(); + } else { + return null; + } + } + + @Override + public FakeCensusContext getDefault() { + return defaultContext; + } + + /** + * Disconnect this factory with the contexts it has created so far. The records from those + * contexts will not show up in {@link #pollRecord}. Useful for isolating the records between + * test cases. + */ + // This needs to be synchronized with getCurrentRecordSink() which may run concurrently. + public synchronized void rolloverRecords() { + records = new LinkedBlockingQueue(); + } + + private synchronized BlockingQueue getCurrentRecordSink() { + return records; + } + } + + public static class FakeCensusContext extends CensusContext { + private final ImmutableMap tags; + private final FakeCensusContextFactory factory; + private final BlockingQueue recordSink; + + private FakeCensusContext(ImmutableMap tags, + FakeCensusContextFactory factory) { + this.tags = tags; + this.factory = factory; + this.recordSink = factory.getCurrentRecordSink(); + } + + @Override + public Builder builder() { + return new FakeCensusContextBuilder(this); + } + + @Override + public CensusContext record(MetricMap metrics) { + recordSink.add(new MetricsRecord(tags, metrics)); + return this; + } + + @Override + public ByteBuffer serialize() { + TagValue extraTagValue = tags.get(EXTRA_TAG); + if (extraTagValue == null) { + return ByteBuffer.wrap(NO_EXTRA_TAG_HEADER_VALUE_PREFIX.getBytes()); + } else { + return ByteBuffer.wrap( + (EXTRA_TAG_HEADER_VALUE_PREFIX + extraTagValue.toString()).getBytes()); + } + } + + @Override + public String toString() { + return "[tags=" + tags + "]"; + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof FakeCensusContext)) { + return false; + } + FakeCensusContext otherCtx = (FakeCensusContext) other; + return tags.equals(otherCtx.tags); + } + + @Override + public int hashCode() { + return tags.hashCode(); + } + } + + private static class FakeCensusContextBuilder extends CensusContext.Builder { + private final ImmutableMap.Builder tagsBuilder = ImmutableMap.builder(); + private final FakeCensusContext base; + + private FakeCensusContextBuilder(FakeCensusContext base) { + this.base = base; + tagsBuilder.putAll(base.tags); + } + + @Override + public CensusContext.Builder set(TagKey key, TagValue value) { + tagsBuilder.put(key, value); + return this; + } + + @Override + public CensusContext build() { + FakeCensusContext context = new FakeCensusContext(tagsBuilder.build(), base.factory); + base.factory.contexts.add(context); + return context; + } + } +}