core: Census integration for stats (#2262)

Highlights
==========

StatsTraceContext
-----------------

The bridge between gRPC library and Census. It keeps track of the total
payload sizes and the elapsed time of a Call. The rest of the gRPC code
doesn't invoke Census directly.

Context propagation
-------------------

StatsTraceContext carries CensusContext (and the upcoming TraceContext)
and is attached to the gRPC Context.

1. StatsTraceContext is created by ManagedChannelImpl, by calling
createClientContext(), which inherits the current CensusContext if available.

2. ManagedChannelImpl passes StatsTraceContext to ClientCallImpl, then
to the stream, then to the framer and deframer explicitly.

3. ClientCallImpl propagates the CensusContext to the headers.

1. ServerImpl creates a StatsTraceContext by implementing a new callback
method StatsTraceContext methodDetermined(MethodDescriptor, Metadata) on
ServerTransportListener.

2. NettyServerHandler calls methodDetermined() before creating the
stream, and passes the StatsTraceContext to the stream.

3. When ServerImpl creates the gRPC Context for the new ServerCall, it
calls the new method statsTraceContext() on ServerStream and puts the
StatsTraceContext in the Context.

Metrics recording
-----------------

1. Client-side start time: when ClientCallImpl is created

2. Server-side start time: when methodDetermined() is called

3. Server-side end time: in ServerStreamListener.closed(), but before
calling onComplete() or onCancel() on ServerCall.Listener.

4. Client-side end time: in ClientStreamListener.closed(), but before
calling onClonse() on ClientCall.Listener

Message sizes are recorded in MessageFramer and MessageDeframer. Both
the uncompressed and wire (possibly compressed) payload sizes are
counted.

TODOs
=====

The CensusContext created from headers on the server side should be
attached to the gRPC Context for the call.  It's not done at this moment
because Census lacks the proper API to do it. It only affects tracing
and resource accounting, but doesn't affect stats functionality
This commit is contained in:
Kun Zhang 2016-10-06 17:15:24 -07:00 committed by GitHub
parent 4e5765a93f
commit 132f7a9a33
73 changed files with 1778 additions and 308 deletions

View File

@ -151,6 +151,7 @@ subprojects {
google_auth_credentials: 'com.google.auth:google-auth-library-credentials:0.4.0', google_auth_credentials: 'com.google.auth:google-auth-library-credentials:0.4.0',
okhttp: 'com.squareup.okhttp:okhttp:2.5.0', okhttp: 'com.squareup.okhttp:okhttp:2.5.0',
okio: 'com.squareup.okio:okio:1.6.0', okio: 'com.squareup.okio:okio:1.6.0',
census_api: 'com.google.census:census-api:0.2.0',
protobuf: "com.google.protobuf:protobuf-java:${protobufVersion}", protobuf: "com.google.protobuf:protobuf-java:${protobufVersion}",
// swap to ${protobufVersion} after versions align again // swap to ${protobufVersion} after versions align again
protobuf_lite: "com.google.protobuf:protobuf-lite:3.0.1", protobuf_lite: "com.google.protobuf:protobuf-lite:3.0.1",

View File

@ -8,7 +8,8 @@ dependencies {
compile libraries.guava, compile libraries.guava,
libraries.errorprone, libraries.errorprone,
libraries.jsr305, libraries.jsr305,
project(':grpc-context') project(':grpc-context'),
libraries.census_api
testCompile project(':grpc-testing') testCompile project(':grpc-testing')
} }

View File

@ -31,6 +31,7 @@
package io.grpc.inprocess; package io.grpc.inprocess;
import com.google.census.CensusContextFactory;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import io.grpc.ExperimentalApi; import io.grpc.ExperimentalApi;
@ -38,6 +39,7 @@ import io.grpc.Internal;
import io.grpc.internal.AbstractManagedChannelImplBuilder; import io.grpc.internal.AbstractManagedChannelImplBuilder;
import io.grpc.internal.ClientTransportFactory; import io.grpc.internal.ClientTransportFactory;
import io.grpc.internal.ConnectionClientTransport; import io.grpc.internal.ConnectionClientTransport;
import io.grpc.internal.NoopCensusContextFactory;
import java.net.SocketAddress; import java.net.SocketAddress;
@ -65,6 +67,10 @@ public class InProcessChannelBuilder extends
private InProcessChannelBuilder(String name) { private InProcessChannelBuilder(String name) {
super(new InProcessSocketAddress(name), "localhost"); super(new InProcessSocketAddress(name), "localhost");
this.name = Preconditions.checkNotNull(name, "name"); this.name = Preconditions.checkNotNull(name, "name");
// TODO(zhangkun83): InProcessTransport by-passes framer and deframer, thus message sizses are
// not counted. Therefore, we disable Census for now.
// (https://github.com/grpc/grpc-java/issues/2284)
super.censusContextFactory(NoopCensusContextFactory.INSTANCE);
} }
/** /**
@ -80,6 +86,16 @@ public class InProcessChannelBuilder extends
return new InProcessClientTransportFactory(name); return new InProcessClientTransportFactory(name);
} }
@Internal
@Override
public InProcessChannelBuilder censusContextFactory(CensusContextFactory censusFactory) {
// TODO(zhangkun83): InProcessTransport by-passes framer and deframer, thus message sizses are
// not counted. Census is disabled by using a NOOP Census factory in the constructor, and here
// we prevent the user from overriding it.
// (https://github.com/grpc/grpc-java/issues/2284)
return this;
}
/** /**
* Creates InProcess transports. Exposed for internal use, as it should be private. * Creates InProcess transports. Exposed for internal use, as it should be private.
*/ */

View File

@ -31,10 +31,13 @@
package io.grpc.inprocess; package io.grpc.inprocess;
import com.google.census.CensusContextFactory;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import io.grpc.ExperimentalApi; import io.grpc.ExperimentalApi;
import io.grpc.Internal;
import io.grpc.internal.AbstractServerImplBuilder; import io.grpc.internal.AbstractServerImplBuilder;
import io.grpc.internal.NoopCensusContextFactory;
import java.io.File; import java.io.File;
@ -61,6 +64,10 @@ public final class InProcessServerBuilder
private InProcessServerBuilder(String name) { private InProcessServerBuilder(String name) {
this.name = Preconditions.checkNotNull(name, "name"); this.name = Preconditions.checkNotNull(name, "name");
// TODO(zhangkun83): InProcessTransport by-passes framer and deframer, thus message sizses are
// not counted. Therefore, we disable Census for now.
// (https://github.com/grpc/grpc-java/issues/2284)
super.censusContextFactory(NoopCensusContextFactory.INSTANCE);
} }
@Override @Override
@ -72,4 +79,14 @@ public final class InProcessServerBuilder
public InProcessServerBuilder useTransportSecurity(File certChain, File privateKey) { public InProcessServerBuilder useTransportSecurity(File certChain, File privateKey) {
throw new UnsupportedOperationException("TLS not supported in InProcessServer"); throw new UnsupportedOperationException("TLS not supported in InProcessServer");
} }
@Internal
@Override
public InProcessServerBuilder censusContextFactory(CensusContextFactory censusFactory) {
// TODO(zhangkun83): InProcessTransport by-passes framer and deframer, thus message sizses are
// not counted. Census is disabled by using a NOOP Census factory in the constructor, and here
// we prevent the user from overriding it.
// (https://github.com/grpc/grpc-java/issues/2284)
return this;
}
} }

View File

@ -51,6 +51,7 @@ import io.grpc.internal.ServerStream;
import io.grpc.internal.ServerStreamListener; import io.grpc.internal.ServerStreamListener;
import io.grpc.internal.ServerTransport; import io.grpc.internal.ServerTransport;
import io.grpc.internal.ServerTransportListener; import io.grpc.internal.ServerTransportListener;
import io.grpc.internal.StatsTraceContext;
import java.io.InputStream; import java.io.InputStream;
import java.util.ArrayDeque; import java.util.ArrayDeque;
@ -125,7 +126,8 @@ class InProcessTransport implements ServerTransport, ConnectionClientTransport {
@Override @Override
public synchronized ClientStream newStream( public synchronized ClientStream newStream(
final MethodDescriptor<?, ?> method, final Metadata headers, final CallOptions callOptions) { final MethodDescriptor<?, ?> method, final Metadata headers, final CallOptions callOptions,
StatsTraceContext clientStatsTraceContext) {
if (shutdownStatus != null) { if (shutdownStatus != null) {
final Status capturedStatus = shutdownStatus; final Status capturedStatus = shutdownStatus;
return new NoopClientStream() { return new NoopClientStream() {
@ -135,14 +137,15 @@ class InProcessTransport implements ServerTransport, ConnectionClientTransport {
} }
}; };
} }
StatsTraceContext serverStatsTraceContext = serverTransportListener.methodDetermined(
return new InProcessStream(method, headers).clientStream; method.getFullMethodName(), headers);
return new InProcessStream(method, headers, serverStatsTraceContext).clientStream;
} }
@Override @Override
public synchronized ClientStream newStream( public synchronized ClientStream newStream(
final MethodDescriptor<?, ?> method, final Metadata headers) { final MethodDescriptor<?, ?> method, final Metadata headers) {
return newStream(method, headers, CallOptions.DEFAULT); return newStream(method, headers, CallOptions.DEFAULT, StatsTraceContext.NOOP);
} }
@Override @Override
@ -231,13 +234,16 @@ class InProcessTransport implements ServerTransport, ConnectionClientTransport {
private class InProcessStream { private class InProcessStream {
private final InProcessServerStream serverStream = new InProcessServerStream(); private final InProcessServerStream serverStream = new InProcessServerStream();
private final InProcessClientStream clientStream = new InProcessClientStream(); private final InProcessClientStream clientStream = new InProcessClientStream();
private final StatsTraceContext serverStatsTraceContext;
private final Metadata headers; private final Metadata headers;
private MethodDescriptor<?, ?> method; private final MethodDescriptor<?, ?> method;
private InProcessStream(MethodDescriptor<?, ?> method, Metadata headers) { private InProcessStream(MethodDescriptor<?, ?> method, Metadata headers,
StatsTraceContext serverStatsTraceContext) {
this.method = checkNotNull(method, "method"); this.method = checkNotNull(method, "method");
this.headers = checkNotNull(headers, "headers"); this.headers = checkNotNull(headers, "headers");
this.serverStatsTraceContext =
checkNotNull(serverStatsTraceContext, "serverStatsTraceContext");
} }
// Can be called multiple times due to races on both client and server closing at same time. // Can be called multiple times due to races on both client and server closing at same time.
@ -408,6 +414,11 @@ class InProcessTransport implements ServerTransport, ConnectionClientTransport {
@Override public Attributes attributes() { @Override public Attributes attributes() {
return serverStreamAttributes; return serverStreamAttributes;
} }
@Override
public StatsTraceContext statsTraceContext() {
return serverStatsTraceContext;
}
} }
private class InProcessClientStream implements ClientStream { private class InProcessClientStream implements ClientStream {

View File

@ -63,9 +63,9 @@ public abstract class AbstractClientStream extends AbstractStream
private Runnable closeListenerTask; private Runnable closeListenerTask;
private volatile boolean cancelled; private volatile boolean cancelled;
protected AbstractClientStream(WritableBufferAllocator bufferAllocator, protected AbstractClientStream(WritableBufferAllocator bufferAllocator, int maxMessageSize,
int maxMessageSize) { StatsTraceContext statsTraceCtx) {
super(bufferAllocator, maxMessageSize); super(bufferAllocator, maxMessageSize, statsTraceCtx);
} }
@Override @Override

View File

@ -94,8 +94,9 @@ public abstract class AbstractClientStream2 extends AbstractStream2
*/ */
private volatile boolean cancelled; private volatile boolean cancelled;
protected AbstractClientStream2(WritableBufferAllocator bufferAllocator) { protected AbstractClientStream2(WritableBufferAllocator bufferAllocator,
framer = new MessageFramer(this, bufferAllocator); StatsTraceContext statsTraceCtx) {
framer = new MessageFramer(this, bufferAllocator, statsTraceCtx);
} }
/** {@inheritDoc} */ /** {@inheritDoc} */
@ -164,8 +165,8 @@ public abstract class AbstractClientStream2 extends AbstractStream2
*/ */
private boolean statusReported; private boolean statusReported;
protected TransportState(int maxMessageSize) { protected TransportState(int maxMessageSize, StatsTraceContext statsTraceCtx) {
super(maxMessageSize); super(maxMessageSize, statsTraceCtx);
} }
@VisibleForTesting @VisibleForTesting

View File

@ -34,6 +34,8 @@ package io.grpc.internal;
import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkArgument;
import com.google.census.Census;
import com.google.census.CensusContextFactory;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.MoreExecutors;
@ -42,6 +44,7 @@ import io.grpc.Attributes;
import io.grpc.ClientInterceptor; import io.grpc.ClientInterceptor;
import io.grpc.CompressorRegistry; import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry; import io.grpc.DecompressorRegistry;
import io.grpc.Internal;
import io.grpc.LoadBalancer; import io.grpc.LoadBalancer;
import io.grpc.ManagedChannelBuilder; import io.grpc.ManagedChannelBuilder;
import io.grpc.NameResolver; import io.grpc.NameResolver;
@ -118,6 +121,9 @@ public abstract class AbstractManagedChannelImplBuilder
private long idleTimeoutMillis = IDLE_MODE_DEFAULT_TIMEOUT_MILLIS; private long idleTimeoutMillis = IDLE_MODE_DEFAULT_TIMEOUT_MILLIS;
@Nullable
private CensusContextFactory censusFactory;
protected AbstractManagedChannelImplBuilder(String target) { protected AbstractManagedChannelImplBuilder(String target) {
this.target = Preconditions.checkNotNull(target, "target"); this.target = Preconditions.checkNotNull(target, "target");
this.directServerAddress = null; this.directServerAddress = null;
@ -227,6 +233,16 @@ public abstract class AbstractManagedChannelImplBuilder
return thisT(); return thisT();
} }
/**
* Override the default Census implementation. This is meant to be used in tests.
*/
@VisibleForTesting
@Internal
public T censusContextFactory(CensusContextFactory censusFactory) {
this.censusFactory = censusFactory;
return thisT();
}
@VisibleForTesting @VisibleForTesting
final long getIdleTimeoutMillis() { final long getIdleTimeoutMillis() {
return idleTimeoutMillis; return idleTimeoutMillis;
@ -266,7 +282,9 @@ public abstract class AbstractManagedChannelImplBuilder
firstNonNull(decompressorRegistry, DecompressorRegistry.getDefaultInstance()), firstNonNull(decompressorRegistry, DecompressorRegistry.getDefaultInstance()),
firstNonNull(compressorRegistry, CompressorRegistry.getDefaultInstance()), firstNonNull(compressorRegistry, CompressorRegistry.getDefaultInstance()),
GrpcUtil.TIMER_SERVICE, GrpcUtil.STOPWATCH_SUPPLIER, idleTimeoutMillis, GrpcUtil.TIMER_SERVICE, GrpcUtil.STOPWATCH_SUPPLIER, idleTimeoutMillis,
executor, userAgent, interceptors); executor, userAgent, interceptors,
firstNonNull(censusFactory,
firstNonNull(Census.getCensusContextFactory(), NoopCensusContextFactory.INSTANCE)));
} }
/** /**

View File

@ -34,6 +34,9 @@ package io.grpc.internal;
import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import com.google.census.Census;
import com.google.census.CensusContextFactory;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.BindableService; import io.grpc.BindableService;
@ -87,6 +90,9 @@ public abstract class AbstractServerImplBuilder<T extends AbstractServerImplBuil
@Nullable @Nullable
private CompressorRegistry compressorRegistry; private CompressorRegistry compressorRegistry;
@Nullable
private CensusContextFactory censusFactory;
@Override @Override
public final T directExecutor() { public final T directExecutor() {
return executor(MoreExecutors.directExecutor()); return executor(MoreExecutors.directExecutor());
@ -133,6 +139,16 @@ public abstract class AbstractServerImplBuilder<T extends AbstractServerImplBuil
return thisT(); return thisT();
} }
/**
* Override the default Census implementation. This is meant to be used in tests.
*/
@VisibleForTesting
@Internal
public T censusContextFactory(CensusContextFactory censusFactory) {
this.censusFactory = censusFactory;
return thisT();
}
@Override @Override
public ServerImpl build() { public ServerImpl build() {
io.grpc.internal.InternalServer transportServer = buildTransportServer(); io.grpc.internal.InternalServer transportServer = buildTransportServer();
@ -140,7 +156,10 @@ public abstract class AbstractServerImplBuilder<T extends AbstractServerImplBuil
firstNonNull(fallbackRegistry, EMPTY_FALLBACK_REGISTRY), transportServer, firstNonNull(fallbackRegistry, EMPTY_FALLBACK_REGISTRY), transportServer,
Context.ROOT, firstNonNull(decompressorRegistry, DecompressorRegistry.getDefaultInstance()), Context.ROOT, firstNonNull(decompressorRegistry, DecompressorRegistry.getDefaultInstance()),
firstNonNull(compressorRegistry, CompressorRegistry.getDefaultInstance()), firstNonNull(compressorRegistry, CompressorRegistry.getDefaultInstance()),
transportFilters); transportFilters,
firstNonNull(censusFactory,
firstNonNull(Census.getCensusContextFactory(), NoopCensusContextFactory.INSTANCE)),
GrpcUtil.STOPWATCH_SUPPLIER);
} }
/** /**

View File

@ -90,11 +90,14 @@ public abstract class AbstractServerStream extends AbstractStream2
} }
private final MessageFramer framer; private final MessageFramer framer;
private final StatsTraceContext statsTraceCtx;
private boolean outboundClosed; private boolean outboundClosed;
private boolean headersSent; private boolean headersSent;
protected AbstractServerStream(WritableBufferAllocator bufferAllocator) { protected AbstractServerStream(WritableBufferAllocator bufferAllocator,
framer = new MessageFramer(this, bufferAllocator); StatsTraceContext statsTraceCtx) {
this.statsTraceCtx = Preconditions.checkNotNull(statsTraceCtx, "statsTraceCtx");
framer = new MessageFramer(this, bufferAllocator, statsTraceCtx);
} }
@Override @Override
@ -166,14 +169,19 @@ public abstract class AbstractServerStream extends AbstractStream2
return Attributes.EMPTY; return Attributes.EMPTY;
} }
@Override
public StatsTraceContext statsTraceContext() {
return statsTraceCtx;
}
/** This should only called from the transport thread. */ /** This should only called from the transport thread. */
protected abstract static class TransportState extends AbstractStream2.TransportState { protected abstract static class TransportState extends AbstractStream2.TransportState {
/** Whether listener.closed() has been called. */ /** Whether listener.closed() has been called. */
private boolean listenerClosed; private boolean listenerClosed;
private ServerStreamListener listener; private ServerStreamListener listener;
protected TransportState(int maxMessageSize) { protected TransportState(int maxMessageSize, StatsTraceContext statsTraceCtx) {
super(maxMessageSize); super(maxMessageSize, statsTraceCtx);
} }
/** /**

View File

@ -131,9 +131,11 @@ public abstract class AbstractStream implements Stream {
} }
} }
AbstractStream(WritableBufferAllocator bufferAllocator, int maxMessageSize) { AbstractStream(WritableBufferAllocator bufferAllocator, int maxMessageSize,
framer = new MessageFramer(new FramerSink(), bufferAllocator); StatsTraceContext statsTraceCtx) {
deframer = new MessageDeframer(new DeframerListener(), Codec.Identity.NONE, maxMessageSize); framer = new MessageFramer(new FramerSink(), bufferAllocator, statsTraceCtx);
deframer = new MessageDeframer(new DeframerListener(), Codec.Identity.NONE, maxMessageSize,
statsTraceCtx);
} }
@VisibleForTesting @VisibleForTesting

View File

@ -146,8 +146,8 @@ public abstract class AbstractStream2 implements Stream {
@GuardedBy("onReadyLock") @GuardedBy("onReadyLock")
private boolean deallocated; private boolean deallocated;
protected TransportState(int maxMessageSize) { protected TransportState(int maxMessageSize, StatsTraceContext statsTraceCtx) {
deframer = new MessageDeframer(this, Codec.Identity.NONE, maxMessageSize); deframer = new MessageDeframer(this, Codec.Identity.NONE, maxMessageSize, statsTraceCtx);
} }
@VisibleForTesting @VisibleForTesting

View File

@ -83,11 +83,12 @@ final class CallCredentialsApplyingTransportFactory implements ClientTransportFa
@Override @Override
public ClientStream newStream( public ClientStream newStream(
MethodDescriptor<?, ?> method, Metadata headers, CallOptions callOptions) { MethodDescriptor<?, ?> method, Metadata headers, CallOptions callOptions,
StatsTraceContext statsTraceCtx) {
CallCredentials creds = callOptions.getCredentials(); CallCredentials creds = callOptions.getCredentials();
if (creds != null) { if (creds != null) {
MetadataApplierImpl applier = new MetadataApplierImpl( MetadataApplierImpl applier = new MetadataApplierImpl(
delegate, method, headers, callOptions); delegate, method, headers, callOptions, statsTraceCtx);
Attributes.Builder effectiveAttrsBuilder = Attributes.newBuilder() Attributes.Builder effectiveAttrsBuilder = Attributes.newBuilder()
.set(CallCredentials.ATTR_AUTHORITY, authority) .set(CallCredentials.ATTR_AUTHORITY, authority)
.set(CallCredentials.ATTR_SECURITY_LEVEL, SecurityLevel.NONE) .set(CallCredentials.ATTR_SECURITY_LEVEL, SecurityLevel.NONE)
@ -99,7 +100,7 @@ final class CallCredentialsApplyingTransportFactory implements ClientTransportFa
firstNonNull(callOptions.getExecutor(), appExecutor), applier); firstNonNull(callOptions.getExecutor(), appExecutor), applier);
return applier.returnStream(); return applier.returnStream();
} else { } else {
return delegate.newStream(method, headers, callOptions); return delegate.newStream(method, headers, callOptions, statsTraceCtx);
} }
} }
} }

View File

@ -84,6 +84,7 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
private volatile ScheduledFuture<?> deadlineCancellationFuture; private volatile ScheduledFuture<?> deadlineCancellationFuture;
private final boolean unaryRequest; private final boolean unaryRequest;
private final CallOptions callOptions; private final CallOptions callOptions;
private final StatsTraceContext statsTraceCtx;
private ClientStream stream; private ClientStream stream;
private volatile boolean cancelListenersShouldBeRemoved; private volatile boolean cancelListenersShouldBeRemoved;
private boolean cancelCalled; private boolean cancelCalled;
@ -94,7 +95,8 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
private CompressorRegistry compressorRegistry = CompressorRegistry.getDefaultInstance(); private CompressorRegistry compressorRegistry = CompressorRegistry.getDefaultInstance();
ClientCallImpl(MethodDescriptor<ReqT, RespT> method, Executor executor, ClientCallImpl(MethodDescriptor<ReqT, RespT> method, Executor executor,
CallOptions callOptions, ClientTransportProvider clientTransportProvider, CallOptions callOptions, StatsTraceContext statsTraceCtx,
ClientTransportProvider clientTransportProvider,
ScheduledExecutorService deadlineCancellationExecutor) { ScheduledExecutorService deadlineCancellationExecutor) {
this.method = method; this.method = method;
// If we know that the executor is a direct executor, we don't need to wrap it with a // If we know that the executor is a direct executor, we don't need to wrap it with a
@ -105,6 +107,7 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
: new SerializingExecutor(executor); : new SerializingExecutor(executor);
// Propagate the context from the thread which initiated the call to all callbacks. // Propagate the context from the thread which initiated the call to all callbacks.
this.context = Context.current(); this.context = Context.current();
this.statsTraceCtx = Preconditions.checkNotNull(statsTraceCtx, "statsTraceCtx");
this.unaryRequest = method.getType() == MethodType.UNARY this.unaryRequest = method.getType() == MethodType.UNARY
|| method.getType() == MethodType.SERVER_STREAMING; || method.getType() == MethodType.SERVER_STREAMING;
this.callOptions = callOptions; this.callOptions = callOptions;
@ -139,7 +142,7 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
@VisibleForTesting @VisibleForTesting
static void prepareHeaders(Metadata headers, DecompressorRegistry decompressorRegistry, static void prepareHeaders(Metadata headers, DecompressorRegistry decompressorRegistry,
Compressor compressor) { Compressor compressor, StatsTraceContext statsTraceCtx) {
headers.discardAll(MESSAGE_ENCODING_KEY); headers.discardAll(MESSAGE_ENCODING_KEY);
if (compressor != Codec.Identity.NONE) { if (compressor != Codec.Identity.NONE) {
headers.put(MESSAGE_ENCODING_KEY, compressor.getMessageEncoding()); headers.put(MESSAGE_ENCODING_KEY, compressor.getMessageEncoding());
@ -150,6 +153,7 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
if (!advertisedEncodings.isEmpty()) { if (!advertisedEncodings.isEmpty()) {
headers.put(MESSAGE_ACCEPT_ENCODING_KEY, advertisedEncodings); headers.put(MESSAGE_ACCEPT_ENCODING_KEY, advertisedEncodings);
} }
statsTraceCtx.propagateToHeaders(headers);
} }
@Override @Override
@ -169,7 +173,7 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
@Override @Override
public void runInContext() { public void runInContext() {
observer.onClose(statusFromCancelled(context), new Metadata()); closeObserver(observer, statusFromCancelled(context), new Metadata());
} }
} }
@ -189,7 +193,8 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
@Override @Override
public void runInContext() { public void runInContext() {
observer.onClose( closeObserver(
observer,
Status.INTERNAL.withDescription( Status.INTERNAL.withDescription(
String.format("Unable to find compressor by name %s", compressorName)), String.format("Unable to find compressor by name %s", compressorName)),
new Metadata()); new Metadata());
@ -203,7 +208,7 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
compressor = Codec.Identity.NONE; compressor = Codec.Identity.NONE;
} }
prepareHeaders(headers, decompressorRegistry, compressor); prepareHeaders(headers, decompressorRegistry, compressor, statsTraceCtx);
Deadline effectiveDeadline = effectiveDeadline(); Deadline effectiveDeadline = effectiveDeadline();
boolean deadlineExceeded = effectiveDeadline != null && effectiveDeadline.isExpired(); boolean deadlineExceeded = effectiveDeadline != null && effectiveDeadline.isExpired();
@ -213,7 +218,7 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
ClientTransport transport = clientTransportProvider.get(callOptions); ClientTransport transport = clientTransportProvider.get(callOptions);
Context origContext = context.attach(); Context origContext = context.attach();
try { try {
stream = transport.newStream(method, headers, callOptions); stream = transport.newStream(method, headers, callOptions, statsTraceCtx);
} finally { } finally {
context.detach(origContext); context.detach(origContext);
} }
@ -400,6 +405,11 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
return stream.isReady(); return stream.isReady();
} }
private void closeObserver(Listener<RespT> observer, Status status, Metadata trailers) {
statsTraceCtx.callEnded(status);
observer.onClose(status, trailers);
}
private class ClientStreamListenerImpl implements ClientStreamListener { private class ClientStreamListenerImpl implements ClientStreamListener {
private final Listener<RespT> observer; private final Listener<RespT> observer;
private boolean closed; private boolean closed;
@ -483,7 +493,7 @@ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
closed = true; closed = true;
cancelListenersShouldBeRemoved = true; cancelListenersShouldBeRemoved = true;
try { try {
observer.onClose(status, trailers); closeObserver(observer, status, trailers);
} finally { } finally {
removeContextListenerAndCancelDeadlineFuture(); removeContextListenerAndCancelDeadlineFuture();
} }

View File

@ -60,12 +60,14 @@ public interface ClientTransport {
* @param method the descriptor of the remote method to be called for this stream. * @param method the descriptor of the remote method to be called for this stream.
* @param headers to send at the beginning of the call * @param headers to send at the beginning of the call
* @param callOptions runtime options of the call * @param callOptions runtime options of the call
* @param statsTraceCtx carries stats and tracing information
* @return the newly created stream. * @return the newly created stream.
*/ */
// TODO(nmittler): Consider also throwing for stopping. // TODO(nmittler): Consider also throwing for stopping.
ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers, CallOptions callOptions); ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers, CallOptions callOptions,
StatsTraceContext statsTraceCtx);
// TODO(zdapeng): Remove tow-argument version in favor of three-argument overload. // TODO(zdapeng): Remove two-argument version in favor of four-argument overload.
ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers); ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers);
/** /**

View File

@ -113,8 +113,8 @@ class DelayedClientTransport implements ManagedClientTransport {
* {@link FailingClientStream} is returned. * {@link FailingClientStream} is returned.
*/ */
@Override @Override
public ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers, CallOptions public ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers,
callOptions) { CallOptions callOptions, StatsTraceContext statsTraceCtx) {
Supplier<ClientTransport> supplier = transportSupplier; Supplier<ClientTransport> supplier = transportSupplier;
if (supplier == null) { if (supplier == null) {
synchronized (lock) { synchronized (lock) {
@ -124,7 +124,8 @@ class DelayedClientTransport implements ManagedClientTransport {
if (backoffStatus != null && !callOptions.isWaitForReady()) { if (backoffStatus != null && !callOptions.isWaitForReady()) {
return new FailingClientStream(backoffStatus); return new FailingClientStream(backoffStatus);
} }
PendingStream pendingStream = new PendingStream(method, headers, callOptions); PendingStream pendingStream = new PendingStream(method, headers, callOptions,
statsTraceCtx);
pendingStreams.add(pendingStream); pendingStreams.add(pendingStream);
if (pendingStreams.size() == 1) { if (pendingStreams.size() == 1) {
listener.transportInUse(true); listener.transportInUse(true);
@ -134,14 +135,14 @@ class DelayedClientTransport implements ManagedClientTransport {
} }
} }
if (supplier != null) { if (supplier != null) {
return supplier.get().newStream(method, headers, callOptions); return supplier.get().newStream(method, headers, callOptions, statsTraceCtx);
} }
return new FailingClientStream(Status.UNAVAILABLE.withDescription("transport shutdown")); return new FailingClientStream(Status.UNAVAILABLE.withDescription("transport shutdown"));
} }
@Override @Override
public ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers) { public ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers) {
return newStream(method, headers, CallOptions.DEFAULT); return newStream(method, headers, CallOptions.DEFAULT, StatsTraceContext.NOOP);
} }
@Override @Override
@ -382,20 +383,22 @@ class DelayedClientTransport implements ManagedClientTransport {
private final Metadata headers; private final Metadata headers;
private final CallOptions callOptions; private final CallOptions callOptions;
private final Context context; private final Context context;
private final StatsTraceContext statsTraceCtx;
private PendingStream(MethodDescriptor<?, ?> method, Metadata headers, private PendingStream(MethodDescriptor<?, ?> method, Metadata headers,
CallOptions callOptions) { CallOptions callOptions, StatsTraceContext statsTraceCtx) {
this.method = method; this.method = method;
this.headers = headers; this.headers = headers;
this.callOptions = callOptions; this.callOptions = callOptions;
this.context = Context.current(); this.context = Context.current();
this.statsTraceCtx = statsTraceCtx;
} }
private void createRealStream(ClientTransport transport) { private void createRealStream(ClientTransport transport) {
ClientStream realStream; ClientStream realStream;
Context origContext = context.attach(); Context origContext = context.attach();
try { try {
realStream = transport.newStream(method, headers, callOptions); realStream = transport.newStream(method, headers, callOptions, statsTraceCtx);
} finally { } finally {
context.detach(origContext); context.detach(origContext);
} }

View File

@ -55,14 +55,14 @@ class FailingClientTransport implements ClientTransport {
} }
@Override @Override
public ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers, CallOptions public ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers,
callOptions) { CallOptions callOptions, StatsTraceContext statsTraceCtx) {
return new FailingClientStream(error); return new FailingClientStream(error);
} }
@Override @Override
public ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers) { public ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers) {
return newStream(method, headers, CallOptions.DEFAULT); return newStream(method, headers, CallOptions.DEFAULT, StatsTraceContext.NOOP);
} }
@Override @Override

View File

@ -57,8 +57,9 @@ abstract class ForwardingConnectionClientTransport implements ConnectionClientTr
@Override @Override
public ClientStream newStream( public ClientStream newStream(
MethodDescriptor<?, ?> method, Metadata headers, CallOptions callOptions) { MethodDescriptor<?, ?> method, Metadata headers, CallOptions callOptions,
return delegate().newStream(method, headers, callOptions); StatsTraceContext statsTraceCtx) {
return delegate().newStream(method, headers, callOptions, statsTraceCtx);
} }
@Override @Override

View File

@ -456,7 +456,7 @@ public final class GrpcUtil {
/** /**
* The factory of default Stopwatches. * The factory of default Stopwatches.
*/ */
static final Supplier<Stopwatch> STOPWATCH_SUPPLIER = new Supplier<Stopwatch>() { public static final Supplier<Stopwatch> STOPWATCH_SUPPLIER = new Supplier<Stopwatch>() {
@Override @Override
public Stopwatch get() { public Stopwatch get() {
return Stopwatch.createUnstarted(); return Stopwatch.createUnstarted();

View File

@ -81,9 +81,9 @@ public abstract class Http2ClientStream extends AbstractClientStream {
private Charset errorCharset = Charsets.UTF_8; private Charset errorCharset = Charsets.UTF_8;
private boolean contentTypeChecked; private boolean contentTypeChecked;
protected Http2ClientStream(WritableBufferAllocator bufferAllocator, protected Http2ClientStream(WritableBufferAllocator bufferAllocator, int maxMessageSize,
int maxMessageSize) { StatsTraceContext statsTraceCtx) {
super(bufferAllocator, maxMessageSize); super(bufferAllocator, maxMessageSize, statsTraceCtx);
} }
/** /**

View File

@ -81,8 +81,8 @@ public abstract class Http2ClientStreamTransportState extends AbstractClientStre
private Charset errorCharset = Charsets.UTF_8; private Charset errorCharset = Charsets.UTF_8;
private boolean contentTypeChecked; private boolean contentTypeChecked;
protected Http2ClientStreamTransportState(int maxMessageSize) { protected Http2ClientStreamTransportState(int maxMessageSize, StatsTraceContext statsTraceCtx) {
super(maxMessageSize); super(maxMessageSize, statsTraceCtx);
} }
/** /**

View File

@ -35,6 +35,7 @@ import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Preconditions.checkState;
import com.google.census.CensusContextFactory;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Stopwatch; import com.google.common.base.Stopwatch;
import com.google.common.base.Supplier; import com.google.common.base.Supplier;
@ -115,6 +116,7 @@ public final class ManagedChannelImpl extends ManagedChannel implements WithLogI
private final SharedResourceHolder.Resource<ScheduledExecutorService> timerService; private final SharedResourceHolder.Resource<ScheduledExecutorService> timerService;
private final Supplier<Stopwatch> stopwatchSupplier; private final Supplier<Stopwatch> stopwatchSupplier;
private final long idleTimeoutMillis; private final long idleTimeoutMillis;
private final CensusContextFactory censusFactory;
/** /**
* Executor that runs deadline timers for requests. * Executor that runs deadline timers for requests.
@ -325,7 +327,7 @@ public final class ManagedChannelImpl extends ManagedChannel implements WithLogI
SharedResourceHolder.Resource<ScheduledExecutorService> timerService, SharedResourceHolder.Resource<ScheduledExecutorService> timerService,
Supplier<Stopwatch> stopwatchSupplier, long idleTimeoutMillis, Supplier<Stopwatch> stopwatchSupplier, long idleTimeoutMillis,
@Nullable Executor executor, @Nullable String userAgent, @Nullable Executor executor, @Nullable String userAgent,
List<ClientInterceptor> interceptors) { List<ClientInterceptor> interceptors, CensusContextFactory censusFactory) {
this.target = checkNotNull(target, "target"); this.target = checkNotNull(target, "target");
this.nameResolverFactory = checkNotNull(nameResolverFactory, "nameResolverFactory"); this.nameResolverFactory = checkNotNull(nameResolverFactory, "nameResolverFactory");
this.nameResolverParams = checkNotNull(nameResolverParams, "nameResolverParams"); this.nameResolverParams = checkNotNull(nameResolverParams, "nameResolverParams");
@ -351,6 +353,7 @@ public final class ManagedChannelImpl extends ManagedChannel implements WithLogI
this.decompressorRegistry = decompressorRegistry; this.decompressorRegistry = decompressorRegistry;
this.compressorRegistry = compressorRegistry; this.compressorRegistry = compressorRegistry;
this.userAgent = userAgent; this.userAgent = userAgent;
this.censusFactory = checkNotNull(censusFactory, "censusFactory");
if (log.isLoggable(Level.INFO)) { if (log.isLoggable(Level.INFO)) {
log.log(Level.INFO, "[{0}] Created with target {1}", new Object[] {getLogId(), target}); log.log(Level.INFO, "[{0}] Created with target {1}", new Object[] {getLogId(), target});
@ -544,10 +547,13 @@ public final class ManagedChannelImpl extends ManagedChannel implements WithLogI
if (executor == null) { if (executor == null) {
executor = ManagedChannelImpl.this.executor; executor = ManagedChannelImpl.this.executor;
} }
StatsTraceContext statsTraceCtx = StatsTraceContext.newClientContext(
method.getFullMethodName(), censusFactory, stopwatchSupplier);
return new ClientCallImpl<ReqT, RespT>( return new ClientCallImpl<ReqT, RespT>(
method, method,
executor, executor,
callOptions, callOptions,
statsTraceCtx,
transportProvider, transportProvider,
scheduledExecutor) scheduledExecutor)
.setDecompressorRegistry(decompressorRegistry) .setDecompressorRegistry(decompressorRegistry)
@ -652,7 +658,7 @@ public final class ManagedChannelImpl extends ManagedChannel implements WithLogI
@Override @Override
public Channel makeChannel(ClientTransport transport) { public Channel makeChannel(ClientTransport transport) {
return new SingleTransportChannel( return new SingleTransportChannel(
transport, executor, scheduledExecutor, authority()); censusFactory, transport, executor, scheduledExecutor, authority(), stopwatchSupplier);
} }
@Override @Override

View File

@ -98,6 +98,7 @@ public class MessageDeframer implements Closeable {
private final Listener listener; private final Listener listener;
private final int maxMessageSize; private final int maxMessageSize;
private final StatsTraceContext statsTraceCtx;
private Decompressor decompressor; private Decompressor decompressor;
private State state = State.HEADER; private State state = State.HEADER;
private int requiredLength = HEADER_LENGTH; private int requiredLength = HEADER_LENGTH;
@ -117,10 +118,12 @@ public class MessageDeframer implements Closeable {
* {@code NONE} meaning unsupported * {@code NONE} meaning unsupported
* @param maxMessageSize the maximum allowed size for received messages. * @param maxMessageSize the maximum allowed size for received messages.
*/ */
public MessageDeframer(Listener listener, Decompressor decompressor, int maxMessageSize) { public MessageDeframer(Listener listener, Decompressor decompressor, int maxMessageSize,
StatsTraceContext statsTraceCtx) {
this.listener = Preconditions.checkNotNull(listener, "sink"); this.listener = Preconditions.checkNotNull(listener, "sink");
this.decompressor = Preconditions.checkNotNull(decompressor, "decompressor"); this.decompressor = Preconditions.checkNotNull(decompressor, "decompressor");
this.maxMessageSize = maxMessageSize; this.maxMessageSize = maxMessageSize;
this.statsTraceCtx = checkNotNull(statsTraceCtx, "statsTraceCtx");
} }
/** /**
@ -314,6 +317,9 @@ public class MessageDeframer implements Closeable {
} finally { } finally {
if (totalBytesRead > 0) { if (totalBytesRead > 0) {
listener.bytesRead(totalBytesRead); listener.bytesRead(totalBytesRead);
if (state == State.BODY) {
statsTraceCtx.wireBytesReceived(totalBytesRead);
}
} }
} }
} }
@ -357,6 +363,7 @@ public class MessageDeframer implements Closeable {
} }
private InputStream getUncompressedBody() { private InputStream getUncompressedBody() {
statsTraceCtx.uncompressedBytesReceived(nextFrame.readableBytes());
return ReadableBuffers.openStream(nextFrame, true); return ReadableBuffers.openStream(nextFrame, true);
} }
@ -370,7 +377,7 @@ public class MessageDeframer implements Closeable {
// Enforce the maxMessageSize limit on the returned stream. // Enforce the maxMessageSize limit on the returned stream.
InputStream unlimitedStream = InputStream unlimitedStream =
decompressor.decompress(ReadableBuffers.openStream(nextFrame, true)); decompressor.decompress(ReadableBuffers.openStream(nextFrame, true));
return new SizeEnforcingInputStream(unlimitedStream, maxMessageSize); return new SizeEnforcingInputStream(unlimitedStream, maxMessageSize, statsTraceCtx);
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@ -382,12 +389,15 @@ public class MessageDeframer implements Closeable {
@VisibleForTesting @VisibleForTesting
static final class SizeEnforcingInputStream extends FilterInputStream { static final class SizeEnforcingInputStream extends FilterInputStream {
private final int maxMessageSize; private final int maxMessageSize;
private final StatsTraceContext statsTraceCtx;
private long maxCount;
private long count; private long count;
private long mark = -1; private long mark = -1;
SizeEnforcingInputStream(InputStream in, int maxMessageSize) { SizeEnforcingInputStream(InputStream in, int maxMessageSize, StatsTraceContext statsTraceCtx) {
super(in); super(in);
this.maxMessageSize = maxMessageSize; this.maxMessageSize = maxMessageSize;
this.statsTraceCtx = statsTraceCtx;
} }
@Override @Override
@ -397,6 +407,7 @@ public class MessageDeframer implements Closeable {
count++; count++;
} }
verifySize(); verifySize();
reportCount();
return result; return result;
} }
@ -407,6 +418,7 @@ public class MessageDeframer implements Closeable {
count += result; count += result;
} }
verifySize(); verifySize();
reportCount();
return result; return result;
} }
@ -415,6 +427,7 @@ public class MessageDeframer implements Closeable {
long result = in.skip(n); long result = in.skip(n);
count += result; count += result;
verifySize(); verifySize();
reportCount();
return result; return result;
} }
@ -438,6 +451,13 @@ public class MessageDeframer implements Closeable {
count = mark; count = mark;
} }
private void reportCount() {
if (count > maxCount) {
statsTraceCtx.uncompressedBytesReceived(count - maxCount);
maxCount = count;
}
}
private void verifySize() { private void verifySize() {
if (count > maxMessageSize) { if (count > maxMessageSize) {
throw Status.INTERNAL.withDescription(String.format( throw Status.INTERNAL.withDescription(String.format(

View File

@ -85,6 +85,7 @@ public class MessageFramer {
private final OutputStreamAdapter outputStreamAdapter = new OutputStreamAdapter(); private final OutputStreamAdapter outputStreamAdapter = new OutputStreamAdapter();
private final byte[] headerScratch = new byte[HEADER_LENGTH]; private final byte[] headerScratch = new byte[HEADER_LENGTH];
private final WritableBufferAllocator bufferAllocator; private final WritableBufferAllocator bufferAllocator;
private final StatsTraceContext statsTraceCtx;
private boolean closed; private boolean closed;
/** /**
@ -93,9 +94,11 @@ public class MessageFramer {
* @param sink the sink used to deliver frames to the transport * @param sink the sink used to deliver frames to the transport
* @param bufferAllocator allocates buffers that the transport can commit to the wire. * @param bufferAllocator allocates buffers that the transport can commit to the wire.
*/ */
public MessageFramer(Sink sink, WritableBufferAllocator bufferAllocator) { public MessageFramer(Sink sink, WritableBufferAllocator bufferAllocator,
StatsTraceContext statsTraceCtx) {
this.sink = checkNotNull(sink, "sink"); this.sink = checkNotNull(sink, "sink");
this.bufferAllocator = checkNotNull(bufferAllocator, "bufferAllocator"); this.bufferAllocator = checkNotNull(bufferAllocator, "bufferAllocator");
this.statsTraceCtx = checkNotNull(statsTraceCtx, "statsTraceCtx");
} }
MessageFramer setCompressor(Compressor compressor) { MessageFramer setCompressor(Compressor compressor) {
@ -142,10 +145,12 @@ public class MessageFramer {
String err = String.format("Message length inaccurate %s != %s", written, messageLength); String err = String.format("Message length inaccurate %s != %s", written, messageLength);
throw Status.INTERNAL.withDescription(err).asRuntimeException(); throw Status.INTERNAL.withDescription(err).asRuntimeException();
} }
statsTraceCtx.uncompressedBytesSent(written);
} }
private int writeUncompressed(InputStream message, int messageLength) throws IOException { private int writeUncompressed(InputStream message, int messageLength) throws IOException {
if (messageLength != -1) { if (messageLength != -1) {
statsTraceCtx.wireBytesSent(messageLength);
return writeKnownLengthUncompressed(message, messageLength); return writeKnownLengthUncompressed(message, messageLength);
} }
BufferChainOutputStream bufferChain = new BufferChainOutputStream(); BufferChainOutputStream bufferChain = new BufferChainOutputStream();
@ -220,6 +225,7 @@ public class MessageFramer {
// Assign the current buffer to the last in the chain so it can be used // Assign the current buffer to the last in the chain so it can be used
// for future writes or written with end-of-stream=true on close. // for future writes or written with end-of-stream=true on close.
buffer = bufferList.get(bufferList.size() - 1); buffer = bufferList.get(bufferList.size() - 1);
statsTraceCtx.wireBytesSent(messageLength);
} }
private static int writeToOutputStream(InputStream message, OutputStream outputStream) private static int writeToOutputStream(InputStream message, OutputStream outputStream)

View File

@ -51,6 +51,7 @@ final class MetadataApplierImpl implements MetadataApplier {
private final Metadata origHeaders; private final Metadata origHeaders;
private final CallOptions callOptions; private final CallOptions callOptions;
private final Context ctx; private final Context ctx;
private final StatsTraceContext statsTraceCtx;
private final Object lock = new Object(); private final Object lock = new Object();
@ -66,12 +67,13 @@ final class MetadataApplierImpl implements MetadataApplier {
DelayedStream delayedStream; DelayedStream delayedStream;
MetadataApplierImpl(ClientTransport transport, MethodDescriptor<?, ?> method, MetadataApplierImpl(ClientTransport transport, MethodDescriptor<?, ?> method,
Metadata origHeaders, CallOptions callOptions) { Metadata origHeaders, CallOptions callOptions, StatsTraceContext statsTraceCtx) {
this.transport = transport; this.transport = transport;
this.method = method; this.method = method;
this.origHeaders = origHeaders; this.origHeaders = origHeaders;
this.callOptions = callOptions; this.callOptions = callOptions;
this.ctx = Context.current(); this.ctx = Context.current();
this.statsTraceCtx = statsTraceCtx;
} }
@Override @Override
@ -82,7 +84,7 @@ final class MetadataApplierImpl implements MetadataApplier {
ClientStream realStream; ClientStream realStream;
Context origCtx = ctx.attach(); Context origCtx = ctx.attach();
try { try {
realStream = transport.newStream(method, origHeaders, callOptions); realStream = transport.newStream(method, origHeaders, callOptions, statsTraceCtx);
} finally { } finally {
ctx.detach(origCtx); ctx.detach(origCtx);
} }

View File

@ -0,0 +1,90 @@
/*
* Copyright 2016, Google Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
*
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package io.grpc.internal;
import com.google.census.CensusContext;
import com.google.census.CensusContextFactory;
import com.google.census.MetricMap;
import com.google.census.TagKey;
import com.google.census.TagValue;
import java.nio.ByteBuffer;
public final class NoopCensusContextFactory extends CensusContextFactory {
private static final ByteBuffer SERIALIZED_BYTES = ByteBuffer.allocate(0).asReadOnlyBuffer();
private static final CensusContext DEFAULT_CONTEXT = new NoopCensusContext();
private static final CensusContext.Builder BUILDER = new NoopContextBuilder();
public static final CensusContextFactory INSTANCE = new NoopCensusContextFactory();
private NoopCensusContextFactory() {
}
@Override
public CensusContext deserialize(ByteBuffer buffer) {
return DEFAULT_CONTEXT;
}
@Override
public CensusContext getDefault() {
return DEFAULT_CONTEXT;
}
private static class NoopCensusContext extends CensusContext {
@Override
public Builder builder() {
return BUILDER;
}
@Override
public CensusContext record(MetricMap metrics) {
return DEFAULT_CONTEXT;
}
@Override
public ByteBuffer serialize() {
return SERIALIZED_BYTES;
}
}
private static class NoopContextBuilder extends CensusContext.Builder {
@Override
public CensusContext.Builder set(TagKey key, TagValue value) {
return this;
}
@Override
public CensusContext build() {
return DEFAULT_CONTEXT;
}
}
}

View File

@ -65,6 +65,7 @@ final class ServerCallImpl<ReqT, RespT> extends ServerCall<ReqT, RespT> {
private final String messageAcceptEncoding; private final String messageAcceptEncoding;
private final DecompressorRegistry decompressorRegistry; private final DecompressorRegistry decompressorRegistry;
private final CompressorRegistry compressorRegistry; private final CompressorRegistry compressorRegistry;
private final StatsTraceContext statsTraceCtx;
// state // state
private volatile boolean cancelled; private volatile boolean cancelled;
@ -73,7 +74,7 @@ final class ServerCallImpl<ReqT, RespT> extends ServerCall<ReqT, RespT> {
private Compressor compressor; private Compressor compressor;
ServerCallImpl(ServerStream stream, MethodDescriptor<ReqT, RespT> method, ServerCallImpl(ServerStream stream, MethodDescriptor<ReqT, RespT> method,
Metadata inboundHeaders, Context.CancellableContext context, Metadata inboundHeaders, Context.CancellableContext context, StatsTraceContext statsTraceCtx,
DecompressorRegistry decompressorRegistry, CompressorRegistry compressorRegistry) { DecompressorRegistry decompressorRegistry, CompressorRegistry compressorRegistry) {
this.stream = stream; this.stream = stream;
this.method = method; this.method = method;
@ -81,6 +82,7 @@ final class ServerCallImpl<ReqT, RespT> extends ServerCall<ReqT, RespT> {
this.messageAcceptEncoding = inboundHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY); this.messageAcceptEncoding = inboundHeaders.get(MESSAGE_ACCEPT_ENCODING_KEY);
this.decompressorRegistry = decompressorRegistry; this.decompressorRegistry = decompressorRegistry;
this.compressorRegistry = compressorRegistry; this.compressorRegistry = compressorRegistry;
this.statsTraceCtx = checkNotNull(statsTraceCtx, "statsTraceCtx");
if (inboundHeaders.containsKey(MESSAGE_ENCODING_KEY)) { if (inboundHeaders.containsKey(MESSAGE_ENCODING_KEY)) {
String encoding = inboundHeaders.get(MESSAGE_ENCODING_KEY); String encoding = inboundHeaders.get(MESSAGE_ENCODING_KEY);
@ -186,7 +188,7 @@ final class ServerCallImpl<ReqT, RespT> extends ServerCall<ReqT, RespT> {
} }
ServerStreamListener newServerStreamListener(ServerCall.Listener<ReqT> listener) { ServerStreamListener newServerStreamListener(ServerCall.Listener<ReqT> listener) {
return new ServerStreamListenerImpl<ReqT>(this, listener, context); return new ServerStreamListenerImpl<ReqT>(this, listener, context, statsTraceCtx);
} }
@Override @Override
@ -208,14 +210,16 @@ final class ServerCallImpl<ReqT, RespT> extends ServerCall<ReqT, RespT> {
private final ServerCallImpl<ReqT, ?> call; private final ServerCallImpl<ReqT, ?> call;
private final ServerCall.Listener<ReqT> listener; private final ServerCall.Listener<ReqT> listener;
private final Context.CancellableContext context; private final Context.CancellableContext context;
private final StatsTraceContext statsTraceCtx;
private boolean messageReceived; private boolean messageReceived;
public ServerStreamListenerImpl( public ServerStreamListenerImpl(
ServerCallImpl<ReqT, ?> call, ServerCall.Listener<ReqT> listener, ServerCallImpl<ReqT, ?> call, ServerCall.Listener<ReqT> listener,
Context.CancellableContext context) { Context.CancellableContext context, StatsTraceContext statsTraceCtx) {
this.call = checkNotNull(call, "call"); this.call = checkNotNull(call, "call");
this.listener = checkNotNull(listener, "listener must not be null"); this.listener = checkNotNull(listener, "listener must not be null");
this.context = checkNotNull(context, "context"); this.context = checkNotNull(context, "context");
this.statsTraceCtx = checkNotNull(statsTraceCtx, "statsTraceCtx");
} }
@Override @Override
@ -263,6 +267,7 @@ final class ServerCallImpl<ReqT, RespT> extends ServerCall<ReqT, RespT> {
@Override @Override
public void closed(Status status) { public void closed(Status status) {
try { try {
statsTraceCtx.callEnded(status);
if (status.isOk()) { if (status.isOk()) {
listener.onComplete(); listener.onComplete();
} else { } else {

View File

@ -39,7 +39,10 @@ import static io.grpc.internal.GrpcUtil.TIMEOUT_KEY;
import static io.grpc.internal.GrpcUtil.TIMER_SERVICE; import static io.grpc.internal.GrpcUtil.TIMER_SERVICE;
import static java.util.concurrent.TimeUnit.NANOSECONDS; import static java.util.concurrent.TimeUnit.NANOSECONDS;
import com.google.census.CensusContextFactory;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.base.Stopwatch;
import com.google.common.base.Supplier;
import io.grpc.Attributes; import io.grpc.Attributes;
import io.grpc.CompressorRegistry; import io.grpc.CompressorRegistry;
@ -90,6 +93,7 @@ public final class ServerImpl extends io.grpc.Server {
private final InternalHandlerRegistry registry; private final InternalHandlerRegistry registry;
private final HandlerRegistry fallbackRegistry; private final HandlerRegistry fallbackRegistry;
private final List<ServerTransportFilter> transportFilters; private final List<ServerTransportFilter> transportFilters;
private final CensusContextFactory censusFactory;
@GuardedBy("lock") private boolean started; @GuardedBy("lock") private boolean started;
@GuardedBy("lock") private boolean shutdown; @GuardedBy("lock") private boolean shutdown;
/** non-{@code null} if immediate shutdown has been requested. */ /** non-{@code null} if immediate shutdown has been requested. */
@ -110,6 +114,7 @@ public final class ServerImpl extends io.grpc.Server {
private final DecompressorRegistry decompressorRegistry; private final DecompressorRegistry decompressorRegistry;
private final CompressorRegistry compressorRegistry; private final CompressorRegistry compressorRegistry;
private final Supplier<Stopwatch> stopwatchSupplier;
/** /**
* Construct a server. * Construct a server.
@ -122,7 +127,8 @@ public final class ServerImpl extends io.grpc.Server {
ServerImpl(Executor executor, InternalHandlerRegistry registry, HandlerRegistry fallbackRegistry, ServerImpl(Executor executor, InternalHandlerRegistry registry, HandlerRegistry fallbackRegistry,
InternalServer transportServer, Context rootContext, InternalServer transportServer, Context rootContext,
DecompressorRegistry decompressorRegistry, CompressorRegistry compressorRegistry, DecompressorRegistry decompressorRegistry, CompressorRegistry compressorRegistry,
List<ServerTransportFilter> transportFilters) { List<ServerTransportFilter> transportFilters, CensusContextFactory censusFactory,
Supplier<Stopwatch> stopwatchSupplier) {
this.executor = executor; this.executor = executor;
this.registry = Preconditions.checkNotNull(registry, "registry"); this.registry = Preconditions.checkNotNull(registry, "registry");
this.fallbackRegistry = Preconditions.checkNotNull(fallbackRegistry, "fallbackRegistry"); this.fallbackRegistry = Preconditions.checkNotNull(fallbackRegistry, "fallbackRegistry");
@ -134,6 +140,8 @@ public final class ServerImpl extends io.grpc.Server {
this.compressorRegistry = compressorRegistry; this.compressorRegistry = compressorRegistry;
this.transportFilters = Collections.unmodifiableList( this.transportFilters = Collections.unmodifiableList(
new ArrayList<ServerTransportFilter>(transportFilters)); new ArrayList<ServerTransportFilter>(transportFilters));
this.censusFactory = Preconditions.checkNotNull(censusFactory, "censusFactory");
this.stopwatchSupplier = Preconditions.checkNotNull(stopwatchSupplier, "stopwatchSupplier");
} }
/** /**
@ -347,10 +355,20 @@ public final class ServerImpl extends io.grpc.Server {
transportClosed(transport); transportClosed(transport);
} }
@Override
public StatsTraceContext methodDetermined(String methodName, Metadata headers) {
return StatsTraceContext.newServerContext(
methodName, censusFactory, headers, stopwatchSupplier);
}
@Override @Override
public ServerStreamListener streamCreated(final ServerStream stream, final String methodName, public ServerStreamListener streamCreated(final ServerStream stream, final String methodName,
final Metadata headers) { final Metadata headers) {
final Context.CancellableContext context = createContext(stream, headers);
final StatsTraceContext statsTraceCtx = Preconditions.checkNotNull(
stream.statsTraceContext(), "statsTraceCtx not present from stream");
final Context.CancellableContext context = createContext(stream, headers, statsTraceCtx);
final Executor wrappedExecutor; final Executor wrappedExecutor;
// This is a performance optimization that avoids the synchronization and queuing overhead // This is a performance optimization that avoids the synchronization and queuing overhead
// that comes with SerializingExecutor. // that comes with SerializingExecutor.
@ -375,9 +393,13 @@ public final class ServerImpl extends io.grpc.Server {
method = fallbackRegistry.lookupMethod(methodName); method = fallbackRegistry.lookupMethod(methodName);
} }
if (method == null) { if (method == null) {
stream.close( Status status = Status.UNIMPLEMENTED.withDescription(
Status.UNIMPLEMENTED.withDescription("Method not found: " + methodName), "Method not found: " + methodName);
new Metadata()); stream.close(status, new Metadata());
// TODO(zhangkun83): this would allow a misbehaving client to blow up the server
// in-memory stats storage by sending large number of distinct unimplemented method
// names. (https://github.com/grpc/grpc-java/issues/2285)
statsTraceCtx.callEnded(status);
context.cancel(null); context.cancel(null);
return; return;
} }
@ -398,15 +420,19 @@ public final class ServerImpl extends io.grpc.Server {
return jumpListener; return jumpListener;
} }
private Context.CancellableContext createContext(final ServerStream stream, Metadata headers) { private Context.CancellableContext createContext(
final ServerStream stream, Metadata headers, StatsTraceContext statsTraceCtx) {
Long timeoutNanos = headers.get(TIMEOUT_KEY); Long timeoutNanos = headers.get(TIMEOUT_KEY);
// TODO(zhangkun83): attach the CensusContext from StatsTraceContext to baseContext
Context baseContext = rootContext;
if (timeoutNanos == null) { if (timeoutNanos == null) {
return rootContext.withCancellation(); return baseContext.withCancellation();
} }
Context.CancellableContext context = Context.CancellableContext context =
rootContext.withDeadlineAfter(timeoutNanos, NANOSECONDS, timeoutService); baseContext.withDeadlineAfter(timeoutNanos, NANOSECONDS, timeoutService);
context.addListener(new Context.CancellationListener() { context.addListener(new Context.CancellationListener() {
@Override @Override
public void cancelled(Context context) { public void cancelled(Context context) {
@ -428,8 +454,8 @@ public final class ServerImpl extends io.grpc.Server {
Context.CancellableContext context) { Context.CancellableContext context) {
// TODO(ejona86): should we update fullMethodName to have the canonical path of the method? // TODO(ejona86): should we update fullMethodName to have the canonical path of the method?
ServerCallImpl<ReqT, RespT> call = new ServerCallImpl<ReqT, RespT>( ServerCallImpl<ReqT, RespT> call = new ServerCallImpl<ReqT, RespT>(
stream, methodDef.getMethodDescriptor(), headers, context, decompressorRegistry, stream, methodDef.getMethodDescriptor(), headers, context, stream.statsTraceContext(),
compressorRegistry); decompressorRegistry, compressorRegistry);
ServerCall.Listener<ReqT> listener = ServerCall.Listener<ReqT> listener =
methodDef.getServerCallHandler().startCall(call, headers); methodDef.getServerCallHandler().startCall(call, headers);
if (listener == null) { if (listener == null) {

View File

@ -75,4 +75,9 @@ public interface ServerStream extends Stream {
* @return Attributes container * @return Attributes container
*/ */
Attributes attributes(); Attributes attributes();
/**
* The context for recording stats and traces for this stream.
*/
StatsTraceContext statsTraceContext();
} }

View File

@ -40,6 +40,14 @@ import io.grpc.Metadata;
*/ */
public interface ServerTransportListener { public interface ServerTransportListener {
/**
* Called when the method name for a new stream has been determined, which happens before the
* stream is actually created and {@link #streamCreated} is called.
*
* @return a context object for recording stats and tracing for the new stream.
*/
StatsTraceContext methodDetermined(String methodName, Metadata headers);
/** /**
* Called when a new stream was created by the remote client. * Called when a new stream was created by the remote client.
* *

View File

@ -31,7 +31,10 @@
package io.grpc.internal; package io.grpc.internal;
import com.google.census.CensusContextFactory;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.base.Stopwatch;
import com.google.common.base.Supplier;
import io.grpc.CallOptions; import io.grpc.CallOptions;
import io.grpc.Channel; import io.grpc.Channel;
@ -47,10 +50,12 @@ import java.util.concurrent.ScheduledExecutorService;
*/ */
final class SingleTransportChannel extends Channel { final class SingleTransportChannel extends Channel {
private final CensusContextFactory censusFactory;
private final ClientTransport transport; private final ClientTransport transport;
private final Executor executor; private final Executor executor;
private final String authority; private final String authority;
private final ScheduledExecutorService deadlineCancellationExecutor; private final ScheduledExecutorService deadlineCancellationExecutor;
private final Supplier<Stopwatch> stopwatchSupplier;
private final ClientTransportProvider transportProvider = new ClientTransportProvider() { private final ClientTransportProvider transportProvider = new ClientTransportProvider() {
@Override @Override
@ -62,20 +67,25 @@ final class SingleTransportChannel extends Channel {
/** /**
* Creates a new channel with a connected transport. * Creates a new channel with a connected transport.
*/ */
public SingleTransportChannel(ClientTransport transport, Executor executor, public SingleTransportChannel(CensusContextFactory censusFactory, ClientTransport transport,
ScheduledExecutorService deadlineCancellationExecutor, String authority) { Executor executor, ScheduledExecutorService deadlineCancellationExecutor, String authority,
Supplier<Stopwatch> stopwatchSupplier) {
this.censusFactory = Preconditions.checkNotNull(censusFactory, "censusFactory");
this.transport = Preconditions.checkNotNull(transport, "transport"); this.transport = Preconditions.checkNotNull(transport, "transport");
this.executor = Preconditions.checkNotNull(executor, "executor"); this.executor = Preconditions.checkNotNull(executor, "executor");
this.deadlineCancellationExecutor = Preconditions.checkNotNull( this.deadlineCancellationExecutor = Preconditions.checkNotNull(
deadlineCancellationExecutor, "deadlineCancellationExecutor"); deadlineCancellationExecutor, "deadlineCancellationExecutor");
this.authority = Preconditions.checkNotNull(authority, "authority"); this.authority = Preconditions.checkNotNull(authority, "authority");
this.stopwatchSupplier = Preconditions.checkNotNull(stopwatchSupplier, "stopwatchSupplier");
} }
@Override @Override
public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall( public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(
MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions) { MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions) {
StatsTraceContext statsTraceCtx = StatsTraceContext.newClientContext(
methodDescriptor.getFullMethodName(), censusFactory, stopwatchSupplier);
return new ClientCallImpl<RequestT, ResponseT>(methodDescriptor, return new ClientCallImpl<RequestT, ResponseT>(methodDescriptor,
new SerializingExecutor(executor), callOptions, transportProvider, new SerializingExecutor(executor), callOptions, statsTraceCtx, transportProvider,
deadlineCancellationExecutor); deadlineCancellationExecutor);
} }

View File

@ -0,0 +1,235 @@
/*
* Copyright 2016, Google Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
*
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package io.grpc.internal;
import static com.google.common.base.Preconditions.checkState;
import com.google.census.CensusContext;
import com.google.census.CensusContextFactory;
import com.google.census.MetricMap;
import com.google.census.MetricName;
import com.google.census.RpcConstants;
import com.google.census.TagKey;
import com.google.census.TagValue;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Stopwatch;
import com.google.common.base.Supplier;
import io.grpc.Metadata;
import io.grpc.Status;
import java.nio.ByteBuffer;
import java.util.concurrent.TimeUnit;
/**
* The stats and tracing information for a call.
*/
public final class StatsTraceContext {
public static final StatsTraceContext NOOP = StatsTraceContext.newClientContext(
"noopservice/noopmethod", NoopCensusContextFactory.INSTANCE,
GrpcUtil.STOPWATCH_SUPPLIER);
private enum Side {
CLIENT, SERVER
}
private final CensusContext censusCtx;
private final Stopwatch stopwatch;
private final Side side;
private final Metadata.Key<CensusContext> censusHeader;
private long wireBytesSent;
private long wireBytesReceived;
private long uncompressedBytesSent;
private long uncompressedBytesReceived;
private boolean callEnded;
private StatsTraceContext(Side side, String fullMethodName, CensusContext parentCtx,
Supplier<Stopwatch> stopwatchSupplier, Metadata.Key<CensusContext> censusHeader) {
this.side = side;
TagKey methodTagKey =
side == Side.CLIENT ? RpcConstants.RPC_CLIENT_METHOD : RpcConstants.RPC_SERVER_METHOD;
// TODO(carl-mastrangelo): maybe cache TagValue in MethodDescriptor
this.censusCtx = parentCtx.with(methodTagKey, new TagValue(fullMethodName));
this.stopwatch = stopwatchSupplier.get().start();
this.censusHeader = censusHeader;
}
/**
* Creates a {@code StatsTraceContext} for an outgoing RPC, using the current CensusContext.
*
* <p>The current time is used as the start time of the RPC.
*/
public static StatsTraceContext newClientContext(String methodName,
CensusContextFactory censusFactory, Supplier<Stopwatch> stopwatchSupplier) {
return new StatsTraceContext(Side.CLIENT, methodName,
// TODO(zhangkun83): use the CensusContext out of the current Context
censusFactory.getDefault(),
stopwatchSupplier, createCensusHeader(censusFactory));
}
@VisibleForTesting
static StatsTraceContext newClientContextForTesting(String methodName,
CensusContextFactory censusFactory, CensusContext parent,
Supplier<Stopwatch> stopwatchSupplier) {
return new StatsTraceContext(Side.CLIENT, methodName, parent, stopwatchSupplier,
createCensusHeader(censusFactory));
}
/**
* Creates a {@code StatsTraceContext} for an incoming RPC, using the CensusContext deserialized
* from the headers.
*
* <p>The current time is used as the start time of the RPC.
*/
public static StatsTraceContext newServerContext(String methodName,
CensusContextFactory censusFactory, Metadata headers,
Supplier<Stopwatch> stopwatchSupplier) {
Metadata.Key<CensusContext> censusHeader = createCensusHeader(censusFactory);
CensusContext parentCtx = headers.get(censusHeader);
if (parentCtx == null) {
parentCtx = censusFactory.getDefault();
}
return new StatsTraceContext(Side.SERVER, methodName, parentCtx, stopwatchSupplier,
censusHeader);
}
/**
* Propagate the context to the outgoing headers.
*/
void propagateToHeaders(Metadata headers) {
headers.discardAll(censusHeader);
headers.put(censusHeader, censusCtx);
}
Metadata.Key<CensusContext> getCensusHeader() {
return censusHeader;
}
@VisibleForTesting
CensusContext getCensusContext() {
return censusCtx;
}
@VisibleForTesting
static Metadata.Key<CensusContext> createCensusHeader(
final CensusContextFactory censusCtxFactory) {
return Metadata.Key.of("grpc-census-bin", new Metadata.BinaryMarshaller<CensusContext>() {
@Override
public byte[] toBytes(CensusContext context) {
ByteBuffer buffer = context.serialize();
// TODO(carl-mastrangelo): currently we only make sure the correctness. We may need to
// optimize out the allocation and copy in the future.
byte[] bytes = new byte[buffer.remaining()];
buffer.get(bytes);
return bytes;
}
@Override
public CensusContext parseBytes(byte[] serialized) {
return censusCtxFactory.deserialize(ByteBuffer.wrap(serialized));
}
});
}
/**
* Record the outgoing number of payload bytes as on the wire.
*/
void wireBytesSent(long bytes) {
// TODO(zhangkun83): maybe change of the checkState() to assert after this class is stabilized.
checkState(!callEnded, "already eneded");
wireBytesSent += bytes;
}
/**
* Record the incoming number of payload bytes as on the wire.
*/
void wireBytesReceived(long bytes) {
checkState(!callEnded, "already eneded");
wireBytesReceived += bytes;
}
/**
* Record the outgoing number of payload bytes in uncompressed form.
*
* <p>The time this method is called is unrelated to the actual time when those byte are sent.
*/
void uncompressedBytesSent(long bytes) {
checkState(!callEnded, "already ended");
uncompressedBytesSent += bytes;
}
/**
* Record the incoming number of payload bytes in uncompressed form.
*
* <p>The time this method is called is unrelated to the actual time when those byte are received.
*/
void uncompressedBytesReceived(long bytes) {
checkState(!callEnded, "already ended");
uncompressedBytesReceived += bytes;
}
/**
* Record a finished all and mark the current time as the end time.
*/
void callEnded(Status status) {
checkState(!callEnded, "already ended");
callEnded = true;
stopwatch.stop();
MetricName latencyMetric;
MetricName wireBytesSentMetric;
MetricName wireBytesReceivedMetric;
MetricName uncompressedBytesSentMetric;
MetricName uncompressedBytesReceivedMetric;
if (side == Side.CLIENT) {
latencyMetric = RpcConstants.RPC_CLIENT_ROUNDTRIP_LATENCY;
wireBytesSentMetric = RpcConstants.RPC_CLIENT_REQUEST_BYTES;
wireBytesReceivedMetric = RpcConstants.RPC_CLIENT_RESPONSE_BYTES;
uncompressedBytesSentMetric = RpcConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES;
uncompressedBytesReceivedMetric = RpcConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES;
} else {
latencyMetric = RpcConstants.RPC_SERVER_SERVER_LATENCY;
wireBytesSentMetric = RpcConstants.RPC_SERVER_RESPONSE_BYTES;
wireBytesReceivedMetric = RpcConstants.RPC_SERVER_REQUEST_BYTES;
uncompressedBytesSentMetric = RpcConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES;
uncompressedBytesReceivedMetric = RpcConstants.RPC_SERVER_UNCOMPRESSED_REQUEST_BYTES;
}
censusCtx
.with(RpcConstants.RPC_STATUS, new TagValue(status.getCode().toString()))
.record(MetricMap.builder()
.put(latencyMetric, stopwatch.elapsed(TimeUnit.MILLISECONDS))
.put(wireBytesSentMetric, wireBytesSent)
.put(wireBytesReceivedMetric, wireBytesReceived)
.put(uncompressedBytesSentMetric, uncompressedBytesSent)
.put(uncompressedBytesReceivedMetric, uncompressedBytesReceived)
.build());
}
}

View File

@ -360,7 +360,7 @@ final class TransportSet extends ManagedChannel implements WithLogId {
public final <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall( public final <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(
MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions) { MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions) {
return new ClientCallImpl<RequestT, ResponseT>(methodDescriptor, return new ClientCallImpl<RequestT, ResponseT>(methodDescriptor,
new SerializingExecutor(appExecutor), callOptions, new SerializingExecutor(appExecutor), callOptions, StatsTraceContext.NOOP,
new ClientTransportProvider() { new ClientTransportProvider() {
@Override @Override
public ClientTransport get(CallOptions callOptions) { public ClientTransport get(CallOptions callOptions) {

View File

@ -63,6 +63,7 @@ public class AbstractClientStream2Test {
@Rule public final ExpectedException thrown = ExpectedException.none(); @Rule public final ExpectedException thrown = ExpectedException.none();
private final StatsTraceContext statsTraceCtx = StatsTraceContext.NOOP;
@Mock private ClientStreamListener mockListener; @Mock private ClientStreamListener mockListener;
@Captor private ArgumentCaptor<Status> statusCaptor; @Captor private ArgumentCaptor<Status> statusCaptor;
@ -82,7 +83,7 @@ public class AbstractClientStream2Test {
public void cancel_doNotAcceptOk() { public void cancel_doNotAcceptOk() {
for (Code code : Code.values()) { for (Code code : Code.values()) {
ClientStreamListener listener = new NoopClientStreamListener(); ClientStreamListener listener = new NoopClientStreamListener();
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator); AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, statsTraceCtx);
stream.start(listener); stream.start(listener);
if (code != Code.OK) { if (code != Code.OK) {
stream.cancel(Status.fromCodeValue(code.value())); stream.cancel(Status.fromCodeValue(code.value()));
@ -100,7 +101,7 @@ public class AbstractClientStream2Test {
@Test @Test
public void cancel_failsOnNull() { public void cancel_failsOnNull() {
ClientStreamListener listener = new NoopClientStreamListener(); ClientStreamListener listener = new NoopClientStreamListener();
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator); AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, statsTraceCtx);
stream.start(listener); stream.start(listener);
thrown.expect(NullPointerException.class); thrown.expect(NullPointerException.class);
@ -109,14 +110,14 @@ public class AbstractClientStream2Test {
@Test @Test
public void cancel_notifiesOnlyOnce() { public void cancel_notifiesOnlyOnce() {
final BaseTransportState state = new BaseTransportState(); final BaseTransportState state = new BaseTransportState(statsTraceCtx);
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, state, new BaseSink() { AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, state, new BaseSink() {
@Override @Override
public void cancel(Status errorStatus) { public void cancel(Status errorStatus) {
// Cancel should eventually result in a transportReportStatus on the transport thread // Cancel should eventually result in a transportReportStatus on the transport thread
state.transportReportStatus(errorStatus, true/*stop delivery*/, new Metadata()); state.transportReportStatus(errorStatus, true/*stop delivery*/, new Metadata());
} }
}); }, statsTraceCtx);
stream.start(mockListener); stream.start(mockListener);
stream.cancel(Status.DEADLINE_EXCEEDED); stream.cancel(Status.DEADLINE_EXCEEDED);
@ -127,7 +128,7 @@ public class AbstractClientStream2Test {
@Test @Test
public void startFailsOnNullListener() { public void startFailsOnNullListener() {
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator); AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, statsTraceCtx);
thrown.expect(NullPointerException.class); thrown.expect(NullPointerException.class);
@ -136,7 +137,7 @@ public class AbstractClientStream2Test {
@Test @Test
public void cantCallStartTwice() { public void cantCallStartTwice() {
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator); AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, statsTraceCtx);
stream.start(mockListener); stream.start(mockListener);
thrown.expect(IllegalStateException.class); thrown.expect(IllegalStateException.class);
@ -146,7 +147,7 @@ public class AbstractClientStream2Test {
@Test @Test
public void inboundDataReceived_failsOnNullFrame() { public void inboundDataReceived_failsOnNullFrame() {
ClientStreamListener listener = new NoopClientStreamListener(); ClientStreamListener listener = new NoopClientStreamListener();
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator); AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, statsTraceCtx);
stream.start(listener); stream.start(listener);
thrown.expect(NullPointerException.class); thrown.expect(NullPointerException.class);
@ -155,7 +156,7 @@ public class AbstractClientStream2Test {
@Test @Test
public void inboundDataReceived_failsOnNoHeaders() { public void inboundDataReceived_failsOnNoHeaders() {
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator); AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, statsTraceCtx);
stream.start(mockListener); stream.start(mockListener);
stream.transportState().inboundDataReceived(ReadableBuffers.empty()); stream.transportState().inboundDataReceived(ReadableBuffers.empty());
@ -166,7 +167,7 @@ public class AbstractClientStream2Test {
@Test @Test
public void inboundHeadersReceived_notifiesListener() { public void inboundHeadersReceived_notifiesListener() {
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator); AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, statsTraceCtx);
stream.start(mockListener); stream.start(mockListener);
Metadata headers = new Metadata(); Metadata headers = new Metadata();
@ -176,7 +177,7 @@ public class AbstractClientStream2Test {
@Test @Test
public void inboundHeadersReceived_failsIfStatusReported() { public void inboundHeadersReceived_failsIfStatusReported() {
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator); AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, statsTraceCtx);
stream.start(mockListener); stream.start(mockListener);
stream.transportState().transportReportStatus(Status.CANCELLED, false, new Metadata()); stream.transportState().transportReportStatus(Status.CANCELLED, false, new Metadata());
@ -186,7 +187,7 @@ public class AbstractClientStream2Test {
@Test @Test
public void inboundHeadersReceived_acceptsGzipEncoding() { public void inboundHeadersReceived_acceptsGzipEncoding() {
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator); AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, statsTraceCtx);
stream.start(mockListener); stream.start(mockListener);
Metadata headers = new Metadata(); Metadata headers = new Metadata();
headers.put(GrpcUtil.MESSAGE_ENCODING_KEY, new Codec.Gzip().getMessageEncoding()); headers.put(GrpcUtil.MESSAGE_ENCODING_KEY, new Codec.Gzip().getMessageEncoding());
@ -197,7 +198,7 @@ public class AbstractClientStream2Test {
@Test @Test
public void inboundHeadersReceived_acceptsIdentityEncoding() { public void inboundHeadersReceived_acceptsIdentityEncoding() {
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator); AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, statsTraceCtx);
stream.start(mockListener); stream.start(mockListener);
Metadata headers = new Metadata(); Metadata headers = new Metadata();
headers.put(GrpcUtil.MESSAGE_ENCODING_KEY, Codec.Identity.NONE.getMessageEncoding()); headers.put(GrpcUtil.MESSAGE_ENCODING_KEY, Codec.Identity.NONE.getMessageEncoding());
@ -208,7 +209,7 @@ public class AbstractClientStream2Test {
@Test @Test
public void rstStreamClosesStream() { public void rstStreamClosesStream() {
AbstractClientStream2 stream = new BaseAbstractClientStream(allocator); AbstractClientStream2 stream = new BaseAbstractClientStream(allocator, statsTraceCtx);
stream.start(mockListener); stream.start(mockListener);
// The application will call request when waiting for a message, which will in turn call this // The application will call request when waiting for a message, which will in turn call this
// on the transport thread. // on the transport thread.
@ -229,13 +230,14 @@ public class AbstractClientStream2Test {
private final TransportState state; private final TransportState state;
private final Sink sink; private final Sink sink;
public BaseAbstractClientStream(WritableBufferAllocator allocator) { public BaseAbstractClientStream(WritableBufferAllocator allocator,
this(allocator, new BaseTransportState(), new BaseSink()); StatsTraceContext statsTraceCtx) {
this(allocator, new BaseTransportState(statsTraceCtx), new BaseSink(), statsTraceCtx);
} }
public BaseAbstractClientStream(WritableBufferAllocator allocator, TransportState state, public BaseAbstractClientStream(WritableBufferAllocator allocator, TransportState state,
Sink sink) { Sink sink, StatsTraceContext statsTraceCtx) {
super(allocator); super(allocator, statsTraceCtx);
this.state = state; this.state = state;
this.sink = sink; this.sink = sink;
} }
@ -266,8 +268,8 @@ public class AbstractClientStream2Test {
} }
private static class BaseTransportState extends AbstractClientStream2.TransportState { private static class BaseTransportState extends AbstractClientStream2.TransportState {
public BaseTransportState() { public BaseTransportState(StatsTraceContext statsTraceCtx) {
super(DEFAULT_MAX_MESSAGE_SIZE); super(DEFAULT_MAX_MESSAGE_SIZE, statsTraceCtx);
} }
@Override @Override

View File

@ -257,7 +257,7 @@ public class AbstractClientStreamTest {
*/ */
private static class BaseAbstractClientStream extends AbstractClientStream { private static class BaseAbstractClientStream extends AbstractClientStream {
protected BaseAbstractClientStream(WritableBufferAllocator allocator) { protected BaseAbstractClientStream(WritableBufferAllocator allocator) {
super(allocator, DEFAULT_MAX_MESSAGE_SIZE); super(allocator, DEFAULT_MAX_MESSAGE_SIZE, StatsTraceContext.NOOP);
} }
@Override @Override

View File

@ -240,7 +240,7 @@ public class AbstractServerStreamTest {
protected AbstractServerStreamBase(WritableBufferAllocator bufferAllocator, Sink sink, protected AbstractServerStreamBase(WritableBufferAllocator bufferAllocator, Sink sink,
AbstractServerStream.TransportState state) { AbstractServerStream.TransportState state) {
super(bufferAllocator); super(bufferAllocator, StatsTraceContext.NOOP);
this.sink = sink; this.sink = sink;
this.state = state; this.state = state;
} }
@ -257,7 +257,7 @@ public class AbstractServerStreamTest {
static class TransportState extends AbstractServerStream.TransportState { static class TransportState extends AbstractServerStream.TransportState {
protected TransportState(int maxMessageSize) { protected TransportState(int maxMessageSize) {
super(maxMessageSize); super(maxMessageSize, StatsTraceContext.NOOP);
} }
@Override @Override

View File

@ -120,7 +120,7 @@ public class AbstractStreamTest {
*/ */
private class AbstractStreamBase extends AbstractStream { private class AbstractStreamBase extends AbstractStream {
private AbstractStreamBase(WritableBufferAllocator bufferAllocator) { private AbstractStreamBase(WritableBufferAllocator bufferAllocator) {
super(allocator, DEFAULT_MAX_MESSAGE_SIZE); super(allocator, DEFAULT_MAX_MESSAGE_SIZE, StatsTraceContext.NOOP);
} }
private AbstractStreamBase(MessageFramer framer, MessageDeframer deframer) { private AbstractStreamBase(MessageFramer framer, MessageDeframer deframer) {

View File

@ -105,6 +105,8 @@ public class CallCredentialsApplyingTest {
private static final String CREDS_VALUE = "some credentials"; private static final String CREDS_VALUE = "some credentials";
private final Metadata origHeaders = new Metadata(); private final Metadata origHeaders = new Metadata();
private final StatsTraceContext statsTraceCtx = StatsTraceContext.newClientContext(
method.getFullMethodName(), NoopCensusContextFactory.INSTANCE, GrpcUtil.STOPWATCH_SUPPLIER);
private ForwardingConnectionClientTransport transport; private ForwardingConnectionClientTransport transport;
private CallOptions callOptions; private CallOptions callOptions;
@ -114,7 +116,8 @@ public class CallCredentialsApplyingTest {
origHeaders.put(ORIG_HEADER_KEY, ORIG_HEADER_VALUE); origHeaders.put(ORIG_HEADER_KEY, ORIG_HEADER_VALUE);
when(mockTransportFactory.newClientTransport(address, AUTHORITY, USER_AGENT)) when(mockTransportFactory.newClientTransport(address, AUTHORITY, USER_AGENT))
.thenReturn(mockTransport); .thenReturn(mockTransport);
when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class),
any(StatsTraceContext.class)))
.thenReturn(mockStream); .thenReturn(mockStream);
ClientTransportFactory transportFactory = new CallCredentialsApplyingTransportFactory( ClientTransportFactory transportFactory = new CallCredentialsApplyingTransportFactory(
mockTransportFactory, mockExecutor); mockTransportFactory, mockExecutor);
@ -130,7 +133,7 @@ public class CallCredentialsApplyingTest {
Attributes transportAttrs = Attributes.newBuilder().set(ATTR_KEY, ATTR_VALUE).build(); Attributes transportAttrs = Attributes.newBuilder().set(ATTR_KEY, ATTR_VALUE).build();
when(mockTransport.getAttrs()).thenReturn(transportAttrs); when(mockTransport.getAttrs()).thenReturn(transportAttrs);
transport.newStream(method, origHeaders, callOptions); transport.newStream(method, origHeaders, callOptions, statsTraceCtx);
ArgumentCaptor<Attributes> attrsCaptor = ArgumentCaptor.forClass(null); ArgumentCaptor<Attributes> attrsCaptor = ArgumentCaptor.forClass(null);
verify(mockCreds).applyRequestMetadata(same(method), attrsCaptor.capture(), same(mockExecutor), verify(mockCreds).applyRequestMetadata(same(method), attrsCaptor.capture(), same(mockExecutor),
@ -150,7 +153,7 @@ public class CallCredentialsApplyingTest {
.build(); .build();
when(mockTransport.getAttrs()).thenReturn(transportAttrs); when(mockTransport.getAttrs()).thenReturn(transportAttrs);
transport.newStream(method, origHeaders, callOptions); transport.newStream(method, origHeaders, callOptions, statsTraceCtx);
ArgumentCaptor<Attributes> attrsCaptor = ArgumentCaptor.forClass(null); ArgumentCaptor<Attributes> attrsCaptor = ArgumentCaptor.forClass(null);
verify(mockCreds).applyRequestMetadata(same(method), attrsCaptor.capture(), same(mockExecutor), verify(mockCreds).applyRequestMetadata(same(method), attrsCaptor.capture(), same(mockExecutor),
@ -172,7 +175,8 @@ public class CallCredentialsApplyingTest {
Executor anotherExecutor = mock(Executor.class); Executor anotherExecutor = mock(Executor.class);
transport.newStream(method, origHeaders, transport.newStream(method, origHeaders,
callOptions.withAuthority("calloptions-authority").withExecutor(anotherExecutor)); callOptions.withAuthority("calloptions-authority").withExecutor(anotherExecutor),
statsTraceCtx);
ArgumentCaptor<Attributes> attrsCaptor = ArgumentCaptor.forClass(null); ArgumentCaptor<Attributes> attrsCaptor = ArgumentCaptor.forClass(null);
verify(mockCreds).applyRequestMetadata(same(method), attrsCaptor.capture(), verify(mockCreds).applyRequestMetadata(same(method), attrsCaptor.capture(),
@ -198,9 +202,9 @@ public class CallCredentialsApplyingTest {
}).when(mockCreds).applyRequestMetadata(same(method), any(Attributes.class), }).when(mockCreds).applyRequestMetadata(same(method), any(Attributes.class),
same(mockExecutor), any(MetadataApplier.class)); same(mockExecutor), any(MetadataApplier.class));
ClientStream stream = transport.newStream(method, origHeaders, callOptions); ClientStream stream = transport.newStream(method, origHeaders, callOptions, statsTraceCtx);
verify(mockTransport).newStream(method, origHeaders, callOptions); verify(mockTransport).newStream(method, origHeaders, callOptions, statsTraceCtx);
assertSame(mockStream, stream); assertSame(mockStream, stream);
assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
@ -221,9 +225,9 @@ public class CallCredentialsApplyingTest {
same(mockExecutor), any(MetadataApplier.class)); same(mockExecutor), any(MetadataApplier.class));
FailingClientStream stream = FailingClientStream stream =
(FailingClientStream) transport.newStream(method, origHeaders, callOptions); (FailingClientStream) transport.newStream(method, origHeaders, callOptions, statsTraceCtx);
verify(mockTransport, never()).newStream(method, origHeaders, callOptions); verify(mockTransport, never()).newStream(method, origHeaders, callOptions, statsTraceCtx);
assertSame(error, stream.getError()); assertSame(error, stream.getError());
} }
@ -232,18 +236,19 @@ public class CallCredentialsApplyingTest {
when(mockTransport.getAttrs()).thenReturn(Attributes.EMPTY); when(mockTransport.getAttrs()).thenReturn(Attributes.EMPTY);
// Will call applyRequestMetadata(), which is no-op. // Will call applyRequestMetadata(), which is no-op.
DelayedStream stream = (DelayedStream) transport.newStream(method, origHeaders, callOptions); DelayedStream stream = (DelayedStream) transport.newStream(method, origHeaders, callOptions,
statsTraceCtx);
ArgumentCaptor<MetadataApplier> applierCaptor = ArgumentCaptor.forClass(null); ArgumentCaptor<MetadataApplier> applierCaptor = ArgumentCaptor.forClass(null);
verify(mockCreds).applyRequestMetadata(same(method), any(Attributes.class), verify(mockCreds).applyRequestMetadata(same(method), any(Attributes.class),
same(mockExecutor), applierCaptor.capture()); same(mockExecutor), applierCaptor.capture());
verify(mockTransport, never()).newStream(method, origHeaders, callOptions); verify(mockTransport, never()).newStream(method, origHeaders, callOptions, statsTraceCtx);
Metadata headers = new Metadata(); Metadata headers = new Metadata();
headers.put(CREDS_KEY, CREDS_VALUE); headers.put(CREDS_KEY, CREDS_VALUE);
applierCaptor.getValue().apply(headers); applierCaptor.getValue().apply(headers);
verify(mockTransport).newStream(method, origHeaders, callOptions); verify(mockTransport).newStream(method, origHeaders, callOptions, statsTraceCtx);
assertSame(mockStream, stream.getRealStream()); assertSame(mockStream, stream.getRealStream());
assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
@ -254,7 +259,8 @@ public class CallCredentialsApplyingTest {
when(mockTransport.getAttrs()).thenReturn(Attributes.EMPTY); when(mockTransport.getAttrs()).thenReturn(Attributes.EMPTY);
// Will call applyRequestMetadata(), which is no-op. // Will call applyRequestMetadata(), which is no-op.
DelayedStream stream = (DelayedStream) transport.newStream(method, origHeaders, callOptions); DelayedStream stream = (DelayedStream) transport.newStream(method, origHeaders, callOptions,
statsTraceCtx);
ArgumentCaptor<MetadataApplier> applierCaptor = ArgumentCaptor.forClass(null); ArgumentCaptor<MetadataApplier> applierCaptor = ArgumentCaptor.forClass(null);
verify(mockCreds).applyRequestMetadata(same(method), any(Attributes.class), verify(mockCreds).applyRequestMetadata(same(method), any(Attributes.class),
@ -263,7 +269,7 @@ public class CallCredentialsApplyingTest {
Status error = Status.FAILED_PRECONDITION.withDescription("channel not secure for creds"); Status error = Status.FAILED_PRECONDITION.withDescription("channel not secure for creds");
applierCaptor.getValue().fail(error); applierCaptor.getValue().fail(error);
verify(mockTransport, never()).newStream(method, origHeaders, callOptions); verify(mockTransport, never()).newStream(method, origHeaders, callOptions, statsTraceCtx);
FailingClientStream failingStream = (FailingClientStream) stream.getRealStream(); FailingClientStream failingStream = (FailingClientStream) stream.getRealStream();
assertSame(error, failingStream.getError()); assertSame(error, failingStream.getError());
} }
@ -271,9 +277,9 @@ public class CallCredentialsApplyingTest {
@Test @Test
public void noCreds() { public void noCreds() {
callOptions = callOptions.withCallCredentials(null); callOptions = callOptions.withCallCredentials(null);
ClientStream stream = transport.newStream(method, origHeaders, callOptions); ClientStream stream = transport.newStream(method, origHeaders, callOptions, statsTraceCtx);
verify(mockTransport).newStream(method, origHeaders, callOptions); verify(mockTransport).newStream(method, origHeaders, callOptions, statsTraceCtx);
assertSame(mockStream, stream); assertSame(mockStream, stream);
assertNull(origHeaders.get(CREDS_KEY)); assertNull(origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));

View File

@ -51,6 +51,9 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.census.CensusContext;
import com.google.census.RpcConstants;
import com.google.census.TagValue;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.SettableFuture; import com.google.common.util.concurrent.SettableFuture;
@ -68,6 +71,8 @@ import io.grpc.MethodDescriptor.Marshaller;
import io.grpc.MethodDescriptor.MethodType; import io.grpc.MethodDescriptor.MethodType;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.internal.ClientCallImpl.ClientTransportProvider; import io.grpc.internal.ClientCallImpl.ClientTransportProvider;
import io.grpc.internal.testing.CensusTestUtils.FakeCensusContextFactory;
import io.grpc.internal.testing.CensusTestUtils;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
@ -116,6 +121,14 @@ public class ClientCallImplTest {
new TestMarshaller<Void>(), new TestMarshaller<Void>(),
new TestMarshaller<Void>()); new TestMarshaller<Void>());
private final FakeCensusContextFactory censusCtxFactory = new FakeCensusContextFactory();
private final CensusContext parentCensusContext = censusCtxFactory.getDefault().with(
CensusTestUtils.EXTRA_TAG, new TagValue("extra-tag-value"));
private final StatsTraceContext statsTraceCtx = StatsTraceContext.newClientContextForTesting(
method.getFullMethodName(), censusCtxFactory, parentCensusContext,
fakeClock.getStopwatchSupplier());
private final CensusContext censusCtx = censusCtxFactory.contexts.poll();
@Mock private ClientStreamListener streamListener; @Mock private ClientStreamListener streamListener;
@Mock private ClientTransport clientTransport; @Mock private ClientTransport clientTransport;
@Captor private ArgumentCaptor<Status> statusCaptor; @Captor private ArgumentCaptor<Status> statusCaptor;
@ -141,9 +154,10 @@ public class ClientCallImplTest {
@Before @Before
public void setUp() { public void setUp() {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
assertNotNull(censusCtx);
when(provider.get(any(CallOptions.class))).thenReturn(transport); when(provider.get(any(CallOptions.class))).thenReturn(transport);
when(transport.newStream(any(MethodDescriptor.class), any(Metadata.class), when(transport.newStream(any(MethodDescriptor.class), any(Metadata.class),
any(CallOptions.class))).thenReturn(stream); any(CallOptions.class), any(StatsTraceContext.class))).thenReturn(stream);
} }
@After @After
@ -151,6 +165,29 @@ public class ClientCallImplTest {
Context.ROOT.attach(); Context.ROOT.attach();
} }
@Test
public void statusPropagatedFromStreamToCallListener() {
DelayedExecutor executor = new DelayedExecutor();
ClientCallImpl<Void, Void> call = new ClientCallImpl<Void, Void>(
method,
executor,
CallOptions.DEFAULT,
statsTraceCtx,
provider,
deadlineCancellationExecutor);
call.start(callListener, new Metadata());
verify(stream).start(listenerArgumentCaptor.capture());
final ClientStreamListener streamListener = listenerArgumentCaptor.getValue();
streamListener.headersRead(new Metadata());
Status status = Status.RESOURCE_EXHAUSTED.withDescription("simulated");
streamListener.closed(status , new Metadata());
executor.release();
verify(callListener).onClose(statusArgumentCaptor.capture(), Matchers.isA(Metadata.class));
assertThat(statusArgumentCaptor.getValue()).isSameAs(status);
assertStatusInStats(status.getCode());
}
@Test @Test
public void exceptionInOnMessageTakesPrecedenceOverServer() { public void exceptionInOnMessageTakesPrecedenceOverServer() {
DelayedExecutor executor = new DelayedExecutor(); DelayedExecutor executor = new DelayedExecutor();
@ -158,6 +195,7 @@ public class ClientCallImplTest {
method, method,
executor, executor,
CallOptions.DEFAULT, CallOptions.DEFAULT,
statsTraceCtx,
provider, provider,
deadlineCancellationExecutor); deadlineCancellationExecutor);
call.start(callListener, new Metadata()); call.start(callListener, new Metadata());
@ -182,6 +220,7 @@ public class ClientCallImplTest {
assertThat(statusArgumentCaptor.getValue().getCode()).isEqualTo(Status.Code.CANCELLED); assertThat(statusArgumentCaptor.getValue().getCode()).isEqualTo(Status.Code.CANCELLED);
assertThat(statusArgumentCaptor.getValue().getCause()).isSameAs(failure); assertThat(statusArgumentCaptor.getValue().getCause()).isSameAs(failure);
verify(stream).cancel(statusArgumentCaptor.getValue()); verify(stream).cancel(statusArgumentCaptor.getValue());
assertStatusInStats(Status.Code.CANCELLED);
} }
@Test @Test
@ -191,6 +230,7 @@ public class ClientCallImplTest {
method, method,
executor, executor,
CallOptions.DEFAULT, CallOptions.DEFAULT,
statsTraceCtx,
provider, provider,
deadlineCancellationExecutor); deadlineCancellationExecutor);
call.start(callListener, new Metadata()); call.start(callListener, new Metadata());
@ -214,6 +254,7 @@ public class ClientCallImplTest {
assertThat(statusArgumentCaptor.getValue().getCode()).isEqualTo(Status.Code.CANCELLED); assertThat(statusArgumentCaptor.getValue().getCode()).isEqualTo(Status.Code.CANCELLED);
assertThat(statusArgumentCaptor.getValue().getCause()).isSameAs(failure); assertThat(statusArgumentCaptor.getValue().getCause()).isSameAs(failure);
verify(stream).cancel(statusArgumentCaptor.getValue()); verify(stream).cancel(statusArgumentCaptor.getValue());
assertStatusInStats(Status.Code.CANCELLED);
} }
@Test @Test
@ -223,6 +264,7 @@ public class ClientCallImplTest {
method, method,
executor, executor,
CallOptions.DEFAULT, CallOptions.DEFAULT,
statsTraceCtx,
provider, provider,
deadlineCancellationExecutor); deadlineCancellationExecutor);
call.start(callListener, new Metadata()); call.start(callListener, new Metadata());
@ -246,6 +288,7 @@ public class ClientCallImplTest {
assertThat(statusArgumentCaptor.getValue().getCode()).isEqualTo(Status.Code.CANCELLED); assertThat(statusArgumentCaptor.getValue().getCode()).isEqualTo(Status.Code.CANCELLED);
assertThat(statusArgumentCaptor.getValue().getCause()).isSameAs(failure); assertThat(statusArgumentCaptor.getValue().getCause()).isSameAs(failure);
verify(stream).cancel(statusArgumentCaptor.getValue()); verify(stream).cancel(statusArgumentCaptor.getValue());
assertStatusInStats(Status.Code.CANCELLED);
} }
@Test @Test
@ -254,6 +297,7 @@ public class ClientCallImplTest {
method, method,
MoreExecutors.directExecutor(), MoreExecutors.directExecutor(),
CallOptions.DEFAULT, CallOptions.DEFAULT,
statsTraceCtx,
provider, provider,
deadlineCancellationExecutor) deadlineCancellationExecutor)
.setDecompressorRegistry(decompressorRegistry); .setDecompressorRegistry(decompressorRegistry);
@ -261,7 +305,8 @@ public class ClientCallImplTest {
call.start(callListener, new Metadata()); call.start(callListener, new Metadata());
ArgumentCaptor<Metadata> metadataCaptor = ArgumentCaptor.forClass(Metadata.class); ArgumentCaptor<Metadata> metadataCaptor = ArgumentCaptor.forClass(Metadata.class);
verify(transport).newStream(eq(method), metadataCaptor.capture(), same(CallOptions.DEFAULT)); verify(transport).newStream(eq(method), metadataCaptor.capture(), same(CallOptions.DEFAULT),
same(statsTraceCtx));
Metadata actual = metadataCaptor.getValue(); Metadata actual = metadataCaptor.getValue();
Set<String> acceptedEncodings = Set<String> acceptedEncodings =
@ -275,6 +320,7 @@ public class ClientCallImplTest {
method, method,
MoreExecutors.directExecutor(), MoreExecutors.directExecutor(),
CallOptions.DEFAULT.withAuthority("overridden-authority"), CallOptions.DEFAULT.withAuthority("overridden-authority"),
statsTraceCtx,
provider, provider,
deadlineCancellationExecutor) deadlineCancellationExecutor)
.setDecompressorRegistry(decompressorRegistry); .setDecompressorRegistry(decompressorRegistry);
@ -290,6 +336,7 @@ public class ClientCallImplTest {
method, method,
MoreExecutors.directExecutor(), MoreExecutors.directExecutor(),
callOptions, callOptions,
statsTraceCtx,
provider, provider,
deadlineCancellationExecutor) deadlineCancellationExecutor)
.setDecompressorRegistry(decompressorRegistry); .setDecompressorRegistry(decompressorRegistry);
@ -297,7 +344,8 @@ public class ClientCallImplTest {
call.start(callListener, metadata); call.start(callListener, metadata);
verify(transport).newStream(same(method), same(metadata), same(callOptions)); verify(transport).newStream(same(method), same(metadata), same(callOptions),
same(statsTraceCtx));
} }
@Test @Test
@ -307,6 +355,7 @@ public class ClientCallImplTest {
MoreExecutors.directExecutor(), MoreExecutors.directExecutor(),
// Don't provide an authority // Don't provide an authority
CallOptions.DEFAULT, CallOptions.DEFAULT,
statsTraceCtx,
provider, provider,
deadlineCancellationExecutor) deadlineCancellationExecutor)
.setDecompressorRegistry(decompressorRegistry); .setDecompressorRegistry(decompressorRegistry);
@ -319,7 +368,7 @@ public class ClientCallImplTest {
public void prepareHeaders_userAgentIgnored() { public void prepareHeaders_userAgentIgnored() {
Metadata m = new Metadata(); Metadata m = new Metadata();
m.put(GrpcUtil.USER_AGENT_KEY, "batmobile"); m.put(GrpcUtil.USER_AGENT_KEY, "batmobile");
ClientCallImpl.prepareHeaders(m, decompressorRegistry, Codec.Identity.NONE); ClientCallImpl.prepareHeaders(m, decompressorRegistry, Codec.Identity.NONE, statsTraceCtx);
// User Agent is removed and set by the transport // User Agent is removed and set by the transport
assertThat(m.get(GrpcUtil.USER_AGENT_KEY)).isNotNull(); assertThat(m.get(GrpcUtil.USER_AGENT_KEY)).isNotNull();
@ -328,7 +377,7 @@ public class ClientCallImplTest {
@Test @Test
public void prepareHeaders_ignoreIdentityEncoding() { public void prepareHeaders_ignoreIdentityEncoding() {
Metadata m = new Metadata(); Metadata m = new Metadata();
ClientCallImpl.prepareHeaders(m, decompressorRegistry, Codec.Identity.NONE); ClientCallImpl.prepareHeaders(m, decompressorRegistry, Codec.Identity.NONE, statsTraceCtx);
assertNull(m.get(GrpcUtil.MESSAGE_ENCODING_KEY)); assertNull(m.get(GrpcUtil.MESSAGE_ENCODING_KEY));
} }
@ -371,7 +420,7 @@ public class ClientCallImplTest {
} }
}, false); // not advertised }, false); // not advertised
ClientCallImpl.prepareHeaders(m, customRegistry, Codec.Identity.NONE); ClientCallImpl.prepareHeaders(m, customRegistry, Codec.Identity.NONE, statsTraceCtx);
Iterable<String> acceptedEncodings = Iterable<String> acceptedEncodings =
ACCEPT_ENCODING_SPLITER.split(m.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY)); ACCEPT_ENCODING_SPLITER.split(m.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY));
@ -386,12 +435,20 @@ public class ClientCallImplTest {
m.put(GrpcUtil.MESSAGE_ENCODING_KEY, "gzip"); m.put(GrpcUtil.MESSAGE_ENCODING_KEY, "gzip");
m.put(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY, "gzip"); m.put(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY, "gzip");
ClientCallImpl.prepareHeaders(m, DecompressorRegistry.emptyInstance(), Codec.Identity.NONE); ClientCallImpl.prepareHeaders(m, DecompressorRegistry.emptyInstance(), Codec.Identity.NONE,
statsTraceCtx);
assertNull(m.get(GrpcUtil.MESSAGE_ENCODING_KEY)); assertNull(m.get(GrpcUtil.MESSAGE_ENCODING_KEY));
assertNull(m.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY)); assertNull(m.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY));
} }
@Test
public void prepareHeaders_censusCtxAdded() {
Metadata m = new Metadata();
ClientCallImpl.prepareHeaders(m, decompressorRegistry, Codec.Identity.NONE, statsTraceCtx);
assertEquals(parentCensusContext, m.get(statsTraceCtx.getCensusHeader()));
}
@Test @Test
public void callerContextPropagatedToListener() throws Exception { public void callerContextPropagatedToListener() throws Exception {
// Attach the context which is recorded when the call is created // Attach the context which is recorded when the call is created
@ -402,6 +459,7 @@ public class ClientCallImplTest {
DESCRIPTOR, DESCRIPTOR,
new SerializingExecutor(Executors.newSingleThreadExecutor()), new SerializingExecutor(Executors.newSingleThreadExecutor()),
CallOptions.DEFAULT, CallOptions.DEFAULT,
statsTraceCtx,
provider, provider,
deadlineCancellationExecutor) deadlineCancellationExecutor)
.setDecompressorRegistry(decompressorRegistry); .setDecompressorRegistry(decompressorRegistry);
@ -475,6 +533,7 @@ public class ClientCallImplTest {
DESCRIPTOR, DESCRIPTOR,
new SerializingExecutor(Executors.newSingleThreadExecutor()), new SerializingExecutor(Executors.newSingleThreadExecutor()),
CallOptions.DEFAULT, CallOptions.DEFAULT,
statsTraceCtx,
provider, provider,
deadlineCancellationExecutor) deadlineCancellationExecutor)
.setDecompressorRegistry(decompressorRegistry); .setDecompressorRegistry(decompressorRegistry);
@ -504,6 +563,7 @@ public class ClientCallImplTest {
DESCRIPTOR, DESCRIPTOR,
new SerializingExecutor(Executors.newSingleThreadExecutor()), new SerializingExecutor(Executors.newSingleThreadExecutor()),
CallOptions.DEFAULT, CallOptions.DEFAULT,
statsTraceCtx,
provider, provider,
deadlineCancellationExecutor) deadlineCancellationExecutor)
.setDecompressorRegistry(decompressorRegistry); .setDecompressorRegistry(decompressorRegistry);
@ -522,6 +582,7 @@ public class ClientCallImplTest {
Status status = statusFuture.get(5, TimeUnit.SECONDS); Status status = statusFuture.get(5, TimeUnit.SECONDS);
assertEquals(Status.Code.CANCELLED, status.getCode()); assertEquals(Status.Code.CANCELLED, status.getCode());
assertSame(cause, status.getCause()); assertSame(cause, status.getCause());
assertStatusInStats(Status.Code.CANCELLED);
// Following operations should be no-op. // Following operations should be no-op.
call.request(1); call.request(1);
@ -547,6 +608,7 @@ public class ClientCallImplTest {
DESCRIPTOR, DESCRIPTOR,
new SerializingExecutor(Executors.newSingleThreadExecutor()), new SerializingExecutor(Executors.newSingleThreadExecutor()),
callOptions, callOptions,
statsTraceCtx,
provider, provider,
deadlineCancellationExecutor) deadlineCancellationExecutor)
.setDecompressorRegistry(decompressorRegistry); .setDecompressorRegistry(decompressorRegistry);
@ -554,6 +616,7 @@ public class ClientCallImplTest {
verify(transport, times(0)).newStream(any(MethodDescriptor.class), any(Metadata.class)); verify(transport, times(0)).newStream(any(MethodDescriptor.class), any(Metadata.class));
verify(callListener, timeout(1000)).onClose(statusCaptor.capture(), any(Metadata.class)); verify(callListener, timeout(1000)).onClose(statusCaptor.capture(), any(Metadata.class));
assertEquals(Status.Code.DEADLINE_EXCEEDED, statusCaptor.getValue().getCode()); assertEquals(Status.Code.DEADLINE_EXCEEDED, statusCaptor.getValue().getCode());
assertStatusInStats(Status.Code.DEADLINE_EXCEEDED);
verifyZeroInteractions(provider); verifyZeroInteractions(provider);
} }
@ -568,6 +631,7 @@ public class ClientCallImplTest {
DESCRIPTOR, DESCRIPTOR,
MoreExecutors.directExecutor(), MoreExecutors.directExecutor(),
CallOptions.DEFAULT, CallOptions.DEFAULT,
statsTraceCtx,
provider, provider,
deadlineCancellationExecutor); deadlineCancellationExecutor);
@ -595,6 +659,7 @@ public class ClientCallImplTest {
DESCRIPTOR, DESCRIPTOR,
MoreExecutors.directExecutor(), MoreExecutors.directExecutor(),
callOpts, callOpts,
statsTraceCtx,
provider, provider,
deadlineCancellationExecutor); deadlineCancellationExecutor);
@ -622,6 +687,7 @@ public class ClientCallImplTest {
DESCRIPTOR, DESCRIPTOR,
MoreExecutors.directExecutor(), MoreExecutors.directExecutor(),
callOpts, callOpts,
statsTraceCtx,
provider, provider,
deadlineCancellationExecutor); deadlineCancellationExecutor);
@ -645,6 +711,7 @@ public class ClientCallImplTest {
DESCRIPTOR, DESCRIPTOR,
MoreExecutors.directExecutor(), MoreExecutors.directExecutor(),
CallOptions.DEFAULT.withDeadline(Deadline.after(1000, TimeUnit.MILLISECONDS)), CallOptions.DEFAULT.withDeadline(Deadline.after(1000, TimeUnit.MILLISECONDS)),
statsTraceCtx,
provider, provider,
deadlineCancellationExecutor); deadlineCancellationExecutor);
@ -668,6 +735,7 @@ public class ClientCallImplTest {
DESCRIPTOR, DESCRIPTOR,
MoreExecutors.directExecutor(), MoreExecutors.directExecutor(),
CallOptions.DEFAULT, CallOptions.DEFAULT,
statsTraceCtx,
provider, provider,
deadlineCancellationExecutor); deadlineCancellationExecutor);
@ -687,6 +755,7 @@ public class ClientCallImplTest {
DESCRIPTOR, DESCRIPTOR,
MoreExecutors.directExecutor(), MoreExecutors.directExecutor(),
CallOptions.DEFAULT.withDeadline(Deadline.after(1000, TimeUnit.MILLISECONDS)), CallOptions.DEFAULT.withDeadline(Deadline.after(1000, TimeUnit.MILLISECONDS)),
statsTraceCtx,
provider, provider,
deadlineCancellationExecutor); deadlineCancellationExecutor);
call.start(callListener, new Metadata()); call.start(callListener, new Metadata());
@ -710,6 +779,7 @@ public class ClientCallImplTest {
DESCRIPTOR, DESCRIPTOR,
MoreExecutors.directExecutor(), MoreExecutors.directExecutor(),
CallOptions.DEFAULT, CallOptions.DEFAULT,
statsTraceCtx,
provider, provider,
deadlineCancellationExecutor); deadlineCancellationExecutor);
@ -726,6 +796,7 @@ public class ClientCallImplTest {
DESCRIPTOR, DESCRIPTOR,
MoreExecutors.directExecutor(), MoreExecutors.directExecutor(),
CallOptions.DEFAULT, CallOptions.DEFAULT,
statsTraceCtx,
provider, provider,
deadlineCancellationExecutor); deadlineCancellationExecutor);
final Exception cause = new Exception(); final Exception cause = new Exception();
@ -753,6 +824,14 @@ public class ClientCallImplTest {
assertSame(cause, status.getCause()); assertSame(cause, status.getCause());
} }
private void assertStatusInStats(Status.Code statusCode) {
CensusTestUtils.MetricsRecord record = censusCtxFactory.pollRecord();
assertNotNull(record);
TagValue statusTag = record.tags.get(RpcConstants.RPC_STATUS);
assertNotNull(statusTag);
assertEquals(statusCode.toString(), statusTag.toString());
}
private static class TestMarshaller<T> implements Marshaller<T> { private static class TestMarshaller<T> implements Marshaller<T> {
@Override @Override
public InputStream stream(T value) { public InputStream stream(T value) {

View File

@ -94,6 +94,12 @@ public class DelayedClientTransportTest {
private final CallOptions callOptions = CallOptions.DEFAULT.withAuthority("dummy_value"); private final CallOptions callOptions = CallOptions.DEFAULT.withAuthority("dummy_value");
private final CallOptions callOptions2 = CallOptions.DEFAULT.withAuthority("dummy_value2"); private final CallOptions callOptions2 = CallOptions.DEFAULT.withAuthority("dummy_value2");
private final StatsTraceContext statsTraceCtx = StatsTraceContext.newClientContext(
method.getFullMethodName(), NoopCensusContextFactory.INSTANCE,
GrpcUtil.STOPWATCH_SUPPLIER);
private final StatsTraceContext statsTraceCtx2 = StatsTraceContext.newClientContext(
method2.getFullMethodName(), NoopCensusContextFactory.INSTANCE,
GrpcUtil.STOPWATCH_SUPPLIER);
private final FakeClock fakeExecutor = new FakeClock(); private final FakeClock fakeExecutor = new FakeClock();
private final DelayedClientTransport delayedTransport = new DelayedClientTransport( private final DelayedClientTransport delayedTransport = new DelayedClientTransport(
@ -101,9 +107,11 @@ public class DelayedClientTransportTest {
@Before public void setUp() { @Before public void setUp() {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
when(mockRealTransport.newStream(same(method), same(headers), same(callOptions))) when(mockRealTransport.newStream(same(method), same(headers), same(callOptions),
same(statsTraceCtx)))
.thenReturn(mockRealStream); .thenReturn(mockRealStream);
when(mockRealTransport2.newStream(same(method2), same(headers2), same(callOptions2))) when(mockRealTransport2.newStream(same(method2), same(headers2), same(callOptions2),
same(statsTraceCtx2)))
.thenReturn(mockRealStream2); .thenReturn(mockRealStream2);
delayedTransport.start(transportListener); delayedTransport.start(transportListener);
} }
@ -113,8 +121,8 @@ public class DelayedClientTransportTest {
} }
@Test public void transportsAreUsedInOrder() { @Test public void transportsAreUsedInOrder() {
delayedTransport.newStream(method, headers, callOptions); delayedTransport.newStream(method, headers, callOptions, statsTraceCtx);
delayedTransport.newStream(method2, headers2, callOptions2); delayedTransport.newStream(method2, headers2, callOptions2, statsTraceCtx2);
assertEquals(0, fakeExecutor.numPendingTasks()); assertEquals(0, fakeExecutor.numPendingTasks());
delayedTransport.setTransportSupplier(new Supplier<ClientTransport>() { delayedTransport.setTransportSupplier(new Supplier<ClientTransport>() {
final Iterator<ClientTransport> it = final Iterator<ClientTransport> it =
@ -125,13 +133,15 @@ public class DelayedClientTransportTest {
} }
}); });
assertEquals(1, fakeExecutor.runDueTasks()); assertEquals(1, fakeExecutor.runDueTasks());
verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions)); verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions),
verify(mockRealTransport2).newStream(same(method2), same(headers2), same(callOptions2)); same(statsTraceCtx));
verify(mockRealTransport2).newStream(same(method2), same(headers2), same(callOptions2),
same(statsTraceCtx2));
} }
@Test public void streamStartThenSetTransport() { @Test public void streamStartThenSetTransport() {
assertFalse(delayedTransport.hasPendingStreams()); assertFalse(delayedTransport.hasPendingStreams());
ClientStream stream = delayedTransport.newStream(method, headers, callOptions); ClientStream stream = delayedTransport.newStream(method, headers, callOptions, statsTraceCtx);
stream.start(streamListener); stream.start(streamListener);
assertEquals(1, delayedTransport.getPendingStreamsCount()); assertEquals(1, delayedTransport.getPendingStreamsCount());
assertTrue(delayedTransport.hasPendingStreams()); assertTrue(delayedTransport.hasPendingStreams());
@ -141,7 +151,8 @@ public class DelayedClientTransportTest {
assertEquals(0, delayedTransport.getPendingStreamsCount()); assertEquals(0, delayedTransport.getPendingStreamsCount());
assertFalse(delayedTransport.hasPendingStreams()); assertFalse(delayedTransport.hasPendingStreams());
assertEquals(1, fakeExecutor.runDueTasks()); assertEquals(1, fakeExecutor.runDueTasks());
verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions)); verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions),
same(statsTraceCtx));
verify(mockRealStream).start(listenerCaptor.capture()); verify(mockRealStream).start(listenerCaptor.capture());
verifyNoMoreInteractions(streamListener); verifyNoMoreInteractions(streamListener);
listenerCaptor.getValue().onReady(); listenerCaptor.getValue().onReady();
@ -150,7 +161,7 @@ public class DelayedClientTransportTest {
} }
@Test public void newStreamThenSetTransportThenShutdown() { @Test public void newStreamThenSetTransportThenShutdown() {
ClientStream stream = delayedTransport.newStream(method, headers, callOptions); ClientStream stream = delayedTransport.newStream(method, headers, callOptions, statsTraceCtx);
assertEquals(1, delayedTransport.getPendingStreamsCount()); assertEquals(1, delayedTransport.getPendingStreamsCount());
assertTrue(stream instanceof DelayedStream); assertTrue(stream instanceof DelayedStream);
delayedTransport.setTransport(mockRealTransport); delayedTransport.setTransport(mockRealTransport);
@ -159,7 +170,8 @@ public class DelayedClientTransportTest {
verify(transportListener).transportShutdown(any(Status.class)); verify(transportListener).transportShutdown(any(Status.class));
verify(transportListener).transportTerminated(); verify(transportListener).transportTerminated();
assertEquals(1, fakeExecutor.runDueTasks()); assertEquals(1, fakeExecutor.runDueTasks());
verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions)); verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions),
same(statsTraceCtx));
stream.start(streamListener); stream.start(streamListener);
verify(mockRealStream).start(same(streamListener)); verify(mockRealStream).start(same(streamListener));
} }
@ -177,11 +189,12 @@ public class DelayedClientTransportTest {
delayedTransport.shutdown(); delayedTransport.shutdown();
verify(transportListener).transportShutdown(any(Status.class)); verify(transportListener).transportShutdown(any(Status.class));
verify(transportListener).transportTerminated(); verify(transportListener).transportTerminated();
ClientStream stream = delayedTransport.newStream(method, headers, callOptions); ClientStream stream = delayedTransport.newStream(method, headers, callOptions, statsTraceCtx);
assertEquals(0, delayedTransport.getPendingStreamsCount()); assertEquals(0, delayedTransport.getPendingStreamsCount());
stream.start(streamListener); stream.start(streamListener);
assertFalse(stream instanceof DelayedStream); assertFalse(stream instanceof DelayedStream);
verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions)); verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions),
same(statsTraceCtx));
verify(mockRealStream).start(same(streamListener)); verify(mockRealStream).start(same(streamListener));
} }
@ -190,11 +203,12 @@ public class DelayedClientTransportTest {
delayedTransport.shutdownNow(Status.UNAVAILABLE); delayedTransport.shutdownNow(Status.UNAVAILABLE);
verify(transportListener).transportShutdown(any(Status.class)); verify(transportListener).transportShutdown(any(Status.class));
verify(transportListener).transportTerminated(); verify(transportListener).transportTerminated();
ClientStream stream = delayedTransport.newStream(method, headers, callOptions); ClientStream stream = delayedTransport.newStream(method, headers, callOptions, statsTraceCtx);
assertEquals(0, delayedTransport.getPendingStreamsCount()); assertEquals(0, delayedTransport.getPendingStreamsCount());
stream.start(streamListener); stream.start(streamListener);
assertFalse(stream instanceof DelayedStream); assertFalse(stream instanceof DelayedStream);
verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions)); verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions),
same(statsTraceCtx));
verify(mockRealStream).start(same(streamListener)); verify(mockRealStream).start(same(streamListener));
} }
@ -290,10 +304,11 @@ public class DelayedClientTransportTest {
final Status cause = Status.UNAVAILABLE.withDescription("some error when connecting"); final Status cause = Status.UNAVAILABLE.withDescription("some error when connecting");
final CallOptions failFastCallOptions = CallOptions.DEFAULT; final CallOptions failFastCallOptions = CallOptions.DEFAULT;
final CallOptions waitForReadyCallOptions = CallOptions.DEFAULT.withWaitForReady(); final CallOptions waitForReadyCallOptions = CallOptions.DEFAULT.withWaitForReady();
final ClientStream ffStream = delayedTransport.newStream(method, headers, failFastCallOptions); final ClientStream ffStream = delayedTransport.newStream(method, headers, failFastCallOptions,
statsTraceCtx);
ffStream.start(streamListener); ffStream.start(streamListener);
delayedTransport.newStream(method, headers, waitForReadyCallOptions); delayedTransport.newStream(method, headers, waitForReadyCallOptions, statsTraceCtx);
delayedTransport.newStream(method, headers, failFastCallOptions); delayedTransport.newStream(method, headers, failFastCallOptions, statsTraceCtx);
assertEquals(3, delayedTransport.getPendingStreamsCount()); assertEquals(3, delayedTransport.getPendingStreamsCount());
delayedTransport.startBackoff(cause); delayedTransport.startBackoff(cause);
@ -315,13 +330,14 @@ public class DelayedClientTransportTest {
delayedTransport.startBackoff(cause); delayedTransport.startBackoff(cause);
assertTrue(delayedTransport.isInBackoffPeriod()); assertTrue(delayedTransport.isInBackoffPeriod());
final ClientStream ffStream = delayedTransport.newStream(method, headers, failFastCallOptions); final ClientStream ffStream = delayedTransport.newStream(method, headers, failFastCallOptions,
statsTraceCtx);
ffStream.start(streamListener); ffStream.start(streamListener);
assertEquals(0, delayedTransport.getPendingStreamsCount()); assertEquals(0, delayedTransport.getPendingStreamsCount());
verify(streamListener).closed(statusCaptor.capture(), any(Metadata.class)); verify(streamListener).closed(statusCaptor.capture(), any(Metadata.class));
assertEquals(cause, Status.fromThrowable(statusCaptor.getValue().getCause())); assertEquals(cause, Status.fromThrowable(statusCaptor.getValue().getCause()));
delayedTransport.newStream(method, headers, waitForReadyCallOptions); delayedTransport.newStream(method, headers, waitForReadyCallOptions, statsTraceCtx);
assertEquals(1, delayedTransport.getPendingStreamsCount()); assertEquals(1, delayedTransport.getPendingStreamsCount());
} }

View File

@ -138,7 +138,8 @@ public class ManagedChannelImplIdlenessTest {
CompressorRegistry.getDefaultInstance(), timerService, timer.getStopwatchSupplier(), CompressorRegistry.getDefaultInstance(), timerService, timer.getStopwatchSupplier(),
TimeUnit.SECONDS.toMillis(IDLE_TIMEOUT_SECONDS), TimeUnit.SECONDS.toMillis(IDLE_TIMEOUT_SECONDS),
executor.getScheduledExecutorService(), USER_AGENT, executor.getScheduledExecutorService(), USER_AGENT,
Collections.<ClientInterceptor>emptyList()); Collections.<ClientInterceptor>emptyList(),
NoopCensusContextFactory.INSTANCE);
newTransports = TestUtils.captureTransports(mockTransportFactory); newTransports = TestUtils.captureTransports(mockTransportFactory);
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {

View File

@ -78,6 +78,7 @@ import io.grpc.SecurityLevel;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.StringMarshaller; import io.grpc.StringMarshaller;
import io.grpc.TransportManager; import io.grpc.TransportManager;
import io.grpc.internal.testing.CensusTestUtils.FakeCensusContextFactory;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
@ -126,6 +127,7 @@ public class ManagedChannelImplTest {
private final ResolvedServerInfo server = new ResolvedServerInfo(socketAddress, Attributes.EMPTY); private final ResolvedServerInfo server = new ResolvedServerInfo(socketAddress, Attributes.EMPTY);
private final FakeClock timer = new FakeClock(); private final FakeClock timer = new FakeClock();
private final FakeClock executor = new FakeClock(); private final FakeClock executor = new FakeClock();
private final FakeCensusContextFactory censusCtxFactory = new FakeCensusContextFactory();
private SpyingLoadBalancerFactory loadBalancerFactory = private SpyingLoadBalancerFactory loadBalancerFactory =
new SpyingLoadBalancerFactory(PickFirstBalancerFactory.getInstance()); new SpyingLoadBalancerFactory(PickFirstBalancerFactory.getInstance());
@ -134,6 +136,8 @@ public class ManagedChannelImplTest {
private ManagedChannelImpl channel; private ManagedChannelImpl channel;
@Captor @Captor
private ArgumentCaptor<Status> statusCaptor; private ArgumentCaptor<Status> statusCaptor;
@Captor
private ArgumentCaptor<StatsTraceContext> statsTraceCtxCaptor;
@Mock @Mock
private ConnectionClientTransport mockTransport; private ConnectionClientTransport mockTransport;
@Mock @Mock
@ -161,7 +165,7 @@ public class ManagedChannelImplTest {
mockTransportFactory, DecompressorRegistry.getDefaultInstance(), mockTransportFactory, DecompressorRegistry.getDefaultInstance(),
CompressorRegistry.getDefaultInstance(), timerService, timer.getStopwatchSupplier(), CompressorRegistry.getDefaultInstance(), timerService, timer.getStopwatchSupplier(),
ManagedChannelImpl.IDLE_TIMEOUT_MILLIS_DISABLE, ManagedChannelImpl.IDLE_TIMEOUT_MILLIS_DISABLE,
executor.getScheduledExecutorService(), userAgent, interceptors); executor.getScheduledExecutorService(), userAgent, interceptors, censusCtxFactory);
// Force-exit the initial idle-mode // Force-exit the initial idle-mode
channel.exitIdleMode(); channel.exitIdleMode();
// Will start NameResolver in the scheduled executor // Will start NameResolver in the scheduled executor
@ -237,7 +241,8 @@ public class ManagedChannelImplTest {
when(mockTransportFactory.newClientTransport( when(mockTransportFactory.newClientTransport(
any(SocketAddress.class), any(String.class), any(String.class))) any(SocketAddress.class), any(String.class), any(String.class)))
.thenReturn(mockTransport); .thenReturn(mockTransport);
when(mockTransport.newStream(same(method), same(headers), same(CallOptions.DEFAULT))) when(mockTransport.newStream(same(method), same(headers), same(CallOptions.DEFAULT),
any(StatsTraceContext.class)))
.thenReturn(mockStream); .thenReturn(mockStream);
call.start(mockCallListener, headers); call.start(mockCallListener, headers);
timer.runDueTasks(); timer.runDueTasks();
@ -250,7 +255,10 @@ public class ManagedChannelImplTest {
transportListener.transportReady(); transportListener.transportReady();
executor.runDueTasks(); executor.runDueTasks();
verify(mockTransport).newStream(same(method), same(headers), same(CallOptions.DEFAULT)); verify(mockTransport).newStream(same(method), same(headers), same(CallOptions.DEFAULT),
statsTraceCtxCaptor.capture());
assertEquals(censusCtxFactory.pollContextOrFail(),
statsTraceCtxCaptor.getValue().getCensusContext());
verify(mockStream).start(streamListenerCaptor.capture()); verify(mockStream).start(streamListenerCaptor.capture());
verify(mockStream).setCompressor(isA(Compressor.class)); verify(mockStream).setCompressor(isA(Compressor.class));
ClientStreamListener streamListener = streamListenerCaptor.getValue(); ClientStreamListener streamListener = streamListenerCaptor.getValue();
@ -259,10 +267,15 @@ public class ManagedChannelImplTest {
ClientCall<String, Integer> call2 = channel.newCall(method, CallOptions.DEFAULT); ClientCall<String, Integer> call2 = channel.newCall(method, CallOptions.DEFAULT);
ClientStream mockStream2 = mock(ClientStream.class); ClientStream mockStream2 = mock(ClientStream.class);
Metadata headers2 = new Metadata(); Metadata headers2 = new Metadata();
when(mockTransport.newStream(same(method), same(headers2), same(CallOptions.DEFAULT))) when(mockTransport.newStream(same(method), same(headers2), same(CallOptions.DEFAULT),
any(StatsTraceContext.class)))
.thenReturn(mockStream2); .thenReturn(mockStream2);
call2.start(mockCallListener2, headers2); call2.start(mockCallListener2, headers2);
verify(mockTransport).newStream(same(method), same(headers2), same(CallOptions.DEFAULT)); verify(mockTransport).newStream(same(method), same(headers2), same(CallOptions.DEFAULT),
statsTraceCtxCaptor.capture());
assertEquals(censusCtxFactory.pollContextOrFail(),
statsTraceCtxCaptor.getValue().getCensusContext());
verify(mockStream2).start(streamListenerCaptor.capture()); verify(mockStream2).start(streamListenerCaptor.capture());
ClientStreamListener streamListener2 = streamListenerCaptor.getValue(); ClientStreamListener streamListener2 = streamListenerCaptor.getValue();
Metadata trailers = new Metadata(); Metadata trailers = new Metadata();
@ -323,7 +336,8 @@ public class ManagedChannelImplTest {
when(mockTransportFactory.newClientTransport( when(mockTransportFactory.newClientTransport(
any(SocketAddress.class), any(String.class), any(String.class))) any(SocketAddress.class), any(String.class), any(String.class)))
.thenReturn(mockTransport); .thenReturn(mockTransport);
when(mockTransport.newStream(same(method), same(headers), same(CallOptions.DEFAULT))) when(mockTransport.newStream(same(method), same(headers), same(CallOptions.DEFAULT),
any(StatsTraceContext.class)))
.thenReturn(mockStream); .thenReturn(mockStream);
call.start(mockCallListener, headers); call.start(mockCallListener, headers);
timer.runDueTasks(); timer.runDueTasks();
@ -336,7 +350,9 @@ public class ManagedChannelImplTest {
transportListener.transportReady(); transportListener.transportReady();
executor.runDueTasks(); executor.runDueTasks();
verify(mockTransport).newStream(same(method), same(headers), same(CallOptions.DEFAULT)); verify(mockTransport).newStream(same(method), same(headers), same(CallOptions.DEFAULT),
any(StatsTraceContext.class));
verify(mockStream).start(streamListenerCaptor.capture()); verify(mockStream).start(streamListenerCaptor.capture());
verify(mockStream).setCompressor(isA(Compressor.class)); verify(mockStream).setCompressor(isA(Compressor.class));
ClientStreamListener streamListener = streamListenerCaptor.getValue(); ClientStreamListener streamListener = streamListenerCaptor.getValue();
@ -391,7 +407,8 @@ public class ManagedChannelImplTest {
// Create transport and call // Create transport and call
ClientStream mockStream = mock(ClientStream.class); ClientStream mockStream = mock(ClientStream.class);
Metadata headers = new Metadata(); Metadata headers = new Metadata();
when(mockTransport.newStream(same(method), same(headers), same(CallOptions.DEFAULT))) when(mockTransport.newStream(same(method), same(headers), same(CallOptions.DEFAULT),
any(StatsTraceContext.class)))
.thenReturn(mockStream); .thenReturn(mockStream);
call.start(mockCallListener, headers); call.start(mockCallListener, headers);
timer.runDueTasks(); timer.runDueTasks();
@ -502,7 +519,8 @@ public class ManagedChannelImplTest {
public void callOptionsExecutor() { public void callOptionsExecutor() {
Metadata headers = new Metadata(); Metadata headers = new Metadata();
ClientStream mockStream = mock(ClientStream.class); ClientStream mockStream = mock(ClientStream.class);
when(mockTransport.newStream(same(method), same(headers), any(CallOptions.class))) when(mockTransport.newStream(same(method), same(headers), any(CallOptions.class),
any(StatsTraceContext.class)))
.thenReturn(mockStream); .thenReturn(mockStream);
FakeClock callExecutor = new FakeClock(); FakeClock callExecutor = new FakeClock();
createChannel(new FakeNameResolverFactory(true), NO_INTERCEPTOR); createChannel(new FakeNameResolverFactory(true), NO_INTERCEPTOR);
@ -520,7 +538,8 @@ public class ManagedChannelImplTest {
// Real streams are started in the channel's executor // Real streams are started in the channel's executor
assertEquals(1, executor.runDueTasks()); assertEquals(1, executor.runDueTasks());
verify(mockTransport).newStream(same(method), same(headers), same(options)); verify(mockTransport).newStream(same(method), same(headers), same(options),
any(StatsTraceContext.class));
verify(mockStream).start(streamListenerCaptor.capture()); verify(mockStream).start(streamListenerCaptor.capture());
ClientStreamListener streamListener = streamListenerCaptor.getValue(); ClientStreamListener streamListener = streamListenerCaptor.getValue();
Metadata trailers = new Metadata(); Metadata trailers = new Metadata();
@ -653,7 +672,8 @@ public class ManagedChannelImplTest {
final ConnectionClientTransport goodTransport = mock(ConnectionClientTransport.class); final ConnectionClientTransport goodTransport = mock(ConnectionClientTransport.class);
final ConnectionClientTransport badTransport = mock(ConnectionClientTransport.class); final ConnectionClientTransport badTransport = mock(ConnectionClientTransport.class);
when(goodTransport.newStream( when(goodTransport.newStream(
any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class))) any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class),
any(StatsTraceContext.class)))
.thenReturn(mock(ClientStream.class)); .thenReturn(mock(ClientStream.class));
when(mockTransportFactory.newClientTransport( when(mockTransportFactory.newClientTransport(
same(goodAddress), any(String.class), any(String.class))) same(goodAddress), any(String.class), any(String.class)))
@ -691,7 +711,8 @@ public class ManagedChannelImplTest {
goodTransportListenerCaptor.getValue().transportReady(); goodTransportListenerCaptor.getValue().transportReady();
executor.runDueTasks(); executor.runDueTasks();
verify(goodTransport).newStream(same(method), same(headers), same(CallOptions.DEFAULT)); verify(goodTransport).newStream(same(method), same(headers), same(CallOptions.DEFAULT),
any(StatsTraceContext.class));
// The bad transport was never used. // The bad transport was never used.
verify(badTransport, times(0)).newStream(any(MethodDescriptor.class), any(Metadata.class)); verify(badTransport, times(0)).newStream(any(MethodDescriptor.class), any(Metadata.class));
} }
@ -776,10 +797,12 @@ public class ManagedChannelImplTest {
final ConnectionClientTransport transport1 = mock(ConnectionClientTransport.class); final ConnectionClientTransport transport1 = mock(ConnectionClientTransport.class);
final ConnectionClientTransport transport2 = mock(ConnectionClientTransport.class); final ConnectionClientTransport transport2 = mock(ConnectionClientTransport.class);
when(transport1.newStream( when(transport1.newStream(
any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class))) any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class),
any(StatsTraceContext.class)))
.thenReturn(mock(ClientStream.class)); .thenReturn(mock(ClientStream.class));
when(transport2.newStream( when(transport2.newStream(
any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class))) any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class),
any(StatsTraceContext.class)))
.thenReturn(mock(ClientStream.class)); .thenReturn(mock(ClientStream.class));
when(mockTransportFactory.newClientTransport(same(addr1), any(String.class), any(String.class))) when(mockTransportFactory.newClientTransport(same(addr1), any(String.class), any(String.class)))
.thenReturn(transport1, transport2); .thenReturn(transport1, transport2);
@ -801,7 +824,8 @@ public class ManagedChannelImplTest {
transportListenerCaptor.getValue().transportReady(); transportListenerCaptor.getValue().transportReady();
executor.runDueTasks(); executor.runDueTasks();
verify(transport1).newStream(same(method), same(headers), same(CallOptions.DEFAULT)); verify(transport1).newStream(same(method), same(headers), same(CallOptions.DEFAULT),
any(StatsTraceContext.class));
transportListenerCaptor.getValue().transportShutdown(Status.UNAVAILABLE); transportListenerCaptor.getValue().transportShutdown(Status.UNAVAILABLE);
// Second call still use the first address, since it was successfully connected. // Second call still use the first address, since it was successfully connected.
@ -813,7 +837,8 @@ public class ManagedChannelImplTest {
transportListenerCaptor.getValue().transportReady(); transportListenerCaptor.getValue().transportReady();
executor.runDueTasks(); executor.runDueTasks();
verify(transport2).newStream(same(method), same(headers), same(CallOptions.DEFAULT)); verify(transport2).newStream(same(method), same(headers), same(CallOptions.DEFAULT),
any(StatsTraceContext.class));
} }
@Test @Test
@ -859,7 +884,8 @@ public class ManagedChannelImplTest {
return mock(ClientStream.class); return mock(ClientStream.class);
} }
}).when(transport).newStream( }).when(transport).newStream(
any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class),
any(StatsTraceContext.class));
// First call will be on delayed transport. Only newCall() is run within the expected context, // First call will be on delayed transport. Only newCall() is run within the expected context,
// so that we can verify that the context is explicitly attached before calling newStream() and // so that we can verify that the context is explicitly attached before calling newStream() and
@ -892,11 +918,13 @@ public class ManagedChannelImplTest {
assertEquals(SecurityLevel.NONE, assertEquals(SecurityLevel.NONE,
attrsCaptor.getValue().get(CallCredentials.ATTR_SECURITY_LEVEL)); attrsCaptor.getValue().get(CallCredentials.ATTR_SECURITY_LEVEL));
verify(transport, never()).newStream( verify(transport, never()).newStream(
any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class),
any(StatsTraceContext.class));
// newStream() is called after apply() is called // newStream() is called after apply() is called
applierCaptor.getValue().apply(new Metadata()); applierCaptor.getValue().apply(new Metadata());
verify(transport).newStream(same(method), any(Metadata.class), same(callOptions)); verify(transport).newStream(same(method), any(Metadata.class), same(callOptions),
any(StatsTraceContext.class));
assertEquals("testValue", testKey.get(newStreamContexts.poll())); assertEquals("testValue", testKey.get(newStreamContexts.poll()));
// The context should not live beyond the scope of newStream() and applyRequestMetadata() // The context should not live beyond the scope of newStream() and applyRequestMetadata()
assertNull(testKey.get()); assertNull(testKey.get());
@ -916,11 +944,13 @@ public class ManagedChannelImplTest {
attrsCaptor.getValue().get(CallCredentials.ATTR_SECURITY_LEVEL)); attrsCaptor.getValue().get(CallCredentials.ATTR_SECURITY_LEVEL));
// This is from the first call // This is from the first call
verify(transport).newStream( verify(transport).newStream(
any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class),
any(StatsTraceContext.class));
// Still, newStream() is called after apply() is called // Still, newStream() is called after apply() is called
applierCaptor.getValue().apply(new Metadata()); applierCaptor.getValue().apply(new Metadata());
verify(transport, times(2)).newStream(same(method), any(Metadata.class), same(callOptions)); verify(transport, times(2)).newStream(same(method), any(Metadata.class), same(callOptions),
any(StatsTraceContext.class));
assertEquals("testValue", testKey.get(newStreamContexts.poll())); assertEquals("testValue", testKey.get(newStreamContexts.poll()));
assertNull(testKey.get()); assertNull(testKey.get());

View File

@ -101,6 +101,12 @@ public class ManagedChannelImplTransportManagerTest {
new StringMarshaller(), new StringMarshaller()); new StringMarshaller(), new StringMarshaller());
private final CallOptions callOptions = CallOptions.DEFAULT.withAuthority("dummy_value"); private final CallOptions callOptions = CallOptions.DEFAULT.withAuthority("dummy_value");
private final CallOptions callOptions2 = CallOptions.DEFAULT.withAuthority("dummy_value2"); private final CallOptions callOptions2 = CallOptions.DEFAULT.withAuthority("dummy_value2");
private final StatsTraceContext statsTraceCtx = StatsTraceContext.newClientContext(
method.getFullMethodName(), NoopCensusContextFactory.INSTANCE,
GrpcUtil.STOPWATCH_SUPPLIER);
private final StatsTraceContext statsTraceCtx2 = StatsTraceContext.newClientContext(
method2.getFullMethodName(), NoopCensusContextFactory.INSTANCE,
GrpcUtil.STOPWATCH_SUPPLIER);
private ManagedChannelImpl channel; private ManagedChannelImpl channel;
@ -135,7 +141,8 @@ public class ManagedChannelImplTransportManagerTest {
mockTransportFactory, DecompressorRegistry.getDefaultInstance(), mockTransportFactory, DecompressorRegistry.getDefaultInstance(),
CompressorRegistry.getDefaultInstance(), GrpcUtil.TIMER_SERVICE, CompressorRegistry.getDefaultInstance(), GrpcUtil.TIMER_SERVICE,
GrpcUtil.STOPWATCH_SUPPLIER, ManagedChannelImpl.IDLE_TIMEOUT_MILLIS_DISABLE, GrpcUtil.STOPWATCH_SUPPLIER, ManagedChannelImpl.IDLE_TIMEOUT_MILLIS_DISABLE,
executor, USER_AGENT, Collections.<ClientInterceptor>emptyList()); executor, USER_AGENT, Collections.<ClientInterceptor>emptyList(),
NoopCensusContextFactory.INSTANCE);
ArgumentCaptor<TransportManager<ClientTransport>> tmCaptor ArgumentCaptor<TransportManager<ClientTransport>> tmCaptor
= ArgumentCaptor.forClass(null); = ArgumentCaptor.forClass(null);
@ -195,7 +202,7 @@ public class ManagedChannelImplTransportManagerTest {
// Subsequent getTransport() will use the next address // Subsequent getTransport() will use the next address
ClientTransport t2 = tm.getTransport(addressGroup); ClientTransport t2 = tm.getTransport(addressGroup);
assertNotNull(t2); assertNotNull(t2);
t2.newStream(method, new Metadata(), callOptions); t2.newStream(method, new Metadata(), callOptions, statsTraceCtx);
// Will keep the previous back-off policy, and not consult back-off policy // Will keep the previous back-off policy, and not consult back-off policy
verify(mockTransportFactory, timeout(1000)).newClientTransport(addr2, AUTHORITY, USER_AGENT); verify(mockTransportFactory, timeout(1000)).newClientTransport(addr2, AUTHORITY, USER_AGENT);
verify(mockBackoffPolicyProvider, times(backoffReset)).get(); verify(mockBackoffPolicyProvider, times(backoffReset)).get();
@ -203,8 +210,8 @@ public class ManagedChannelImplTransportManagerTest {
ClientTransport rt2 = transportInfo.transport; ClientTransport rt2 = transportInfo.transport;
// Make the second transport ready // Make the second transport ready
transportInfo.listener.transportReady(); transportInfo.listener.transportReady();
verify(rt2, timeout(1000)).newStream(same(method), any(Metadata.class), verify(rt2, timeout(1000)).newStream(
same(callOptions)); same(method), any(Metadata.class), same(callOptions), same(statsTraceCtx));
verify(mockNameResolver, times(0)).refresh(); verify(mockNameResolver, times(0)).refresh();
// Disconnect the second transport // Disconnect the second transport
transportInfo.listener.transportShutdown(Status.UNAVAILABLE); transportInfo.listener.transportShutdown(Status.UNAVAILABLE);
@ -213,7 +220,7 @@ public class ManagedChannelImplTransportManagerTest {
// Subsequent getTransport() will use the first address, since last attempt was successful. // Subsequent getTransport() will use the first address, since last attempt was successful.
ClientTransport t3 = tm.getTransport(addressGroup); ClientTransport t3 = tm.getTransport(addressGroup);
t3.newStream(method2, new Metadata(), callOptions2); t3.newStream(method2, new Metadata(), callOptions2, statsTraceCtx2);
verify(mockTransportFactory, timeout(1000).times(2)) verify(mockTransportFactory, timeout(1000).times(2))
.newClientTransport(addr1, AUTHORITY, USER_AGENT); .newClientTransport(addr1, AUTHORITY, USER_AGENT);
// Still no back-off policy creation, because an address succeeded. // Still no back-off policy creation, because an address succeeded.
@ -221,8 +228,8 @@ public class ManagedChannelImplTransportManagerTest {
transportInfo = transports.poll(1, TimeUnit.SECONDS); transportInfo = transports.poll(1, TimeUnit.SECONDS);
ClientTransport rt3 = transportInfo.transport; ClientTransport rt3 = transportInfo.transport;
transportInfo.listener.transportReady(); transportInfo.listener.transportReady();
verify(rt3, timeout(1000)).newStream(same(method2), any(Metadata.class), verify(rt3, timeout(1000)).newStream(
same(callOptions2)); same(method2), any(Metadata.class), same(callOptions2), same(statsTraceCtx2));
verify(rt1, times(0)).newStream(any(MethodDescriptor.class), any(Metadata.class)); verify(rt1, times(0)).newStream(any(MethodDescriptor.class), any(Metadata.class));
// Back-off policy was never consulted. // Back-off policy was never consulted.
@ -283,7 +290,7 @@ public class ManagedChannelImplTransportManagerTest {
ClientTransport t4 = tm.getTransport(addressGroup); ClientTransport t4 = tm.getTransport(addressGroup);
assertNotNull(t4); assertNotNull(t4);
// If backoff's DelayedTransport is still active, this is necessary. Otherwise it would be racy. // If backoff's DelayedTransport is still active, this is necessary. Otherwise it would be racy.
t4.newStream(method, new Metadata(), CallOptions.DEFAULT.withWaitForReady()); t4.newStream(method, new Metadata(), CallOptions.DEFAULT.withWaitForReady(), statsTraceCtx);
verify(mockTransportFactory, timeout(1000).times(++transportsAddr1)) verify(mockTransportFactory, timeout(1000).times(++transportsAddr1))
.newClientTransport(addr1, AUTHORITY, USER_AGENT); .newClientTransport(addr1, AUTHORITY, USER_AGENT);
// Back-off policy was reset and consulted. // Back-off policy was reset and consulted.

View File

@ -33,6 +33,7 @@ package io.grpc.internal;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.anyInt;
import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.atLeastOnce;
@ -42,14 +43,18 @@ import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import com.google.census.RpcConstants;
import com.google.common.base.Charsets; import com.google.common.base.Charsets;
import com.google.common.io.ByteStreams; import com.google.common.io.ByteStreams;
import com.google.common.primitives.Bytes; import com.google.common.primitives.Bytes;
import io.grpc.Codec; import io.grpc.Codec;
import io.grpc.Status;
import io.grpc.StatusRuntimeException; import io.grpc.StatusRuntimeException;
import io.grpc.internal.MessageDeframer.Listener; import io.grpc.internal.MessageDeframer.Listener;
import io.grpc.internal.MessageDeframer.SizeEnforcingInputStream; import io.grpc.internal.MessageDeframer.SizeEnforcingInputStream;
import io.grpc.internal.testing.CensusTestUtils.FakeCensusContextFactory;
import io.grpc.internal.testing.CensusTestUtils.MetricsRecord;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
@ -76,8 +81,15 @@ public class MessageDeframerTest {
@Rule public final ExpectedException thrown = ExpectedException.none(); @Rule public final ExpectedException thrown = ExpectedException.none();
private Listener listener = mock(Listener.class); private Listener listener = mock(Listener.class);
private final FakeCensusContextFactory censusCtxFactory = new FakeCensusContextFactory();
// MessageFramerTest tests with a server-side StatsTraceContext, so here we test with a
// client-side StatsTraceContext.
private StatsTraceContext statsTraceCtx = StatsTraceContext.newClientContext(
"service/method", censusCtxFactory, GrpcUtil.STOPWATCH_SUPPLIER);
private MessageDeframer deframer = new MessageDeframer(listener, Codec.Identity.NONE, private MessageDeframer deframer = new MessageDeframer(listener, Codec.Identity.NONE,
DEFAULT_MAX_MESSAGE_SIZE); DEFAULT_MAX_MESSAGE_SIZE, statsTraceCtx);
private ArgumentCaptor<InputStream> messages = ArgumentCaptor.forClass(InputStream.class); private ArgumentCaptor<InputStream> messages = ArgumentCaptor.forClass(InputStream.class);
@Test @Test
@ -88,6 +100,7 @@ public class MessageDeframerTest {
assertEquals(Bytes.asList(new byte[]{3, 14}), bytes(messages)); assertEquals(Bytes.asList(new byte[]{3, 14}), bytes(messages));
verify(listener, atLeastOnce()).bytesRead(anyInt()); verify(listener, atLeastOnce()).bytesRead(anyInt());
verifyNoMoreInteractions(listener); verifyNoMoreInteractions(listener);
checkStats(2, 2);
} }
@Test @Test
@ -101,6 +114,7 @@ public class MessageDeframerTest {
verify(listener, atLeastOnce()).bytesRead(anyInt()); verify(listener, atLeastOnce()).bytesRead(anyInt());
assertEquals(Bytes.asList(new byte[] {14, 15}), bytes(streams.get(1))); assertEquals(Bytes.asList(new byte[] {14, 15}), bytes(streams.get(1)));
verifyNoMoreInteractions(listener); verifyNoMoreInteractions(listener);
checkStats(3, 3);
} }
@Test @Test
@ -112,6 +126,7 @@ public class MessageDeframerTest {
verify(listener).endOfStream(); verify(listener).endOfStream();
verify(listener, atLeastOnce()).bytesRead(anyInt()); verify(listener, atLeastOnce()).bytesRead(anyInt());
verifyNoMoreInteractions(listener); verifyNoMoreInteractions(listener);
checkStats(1, 1);
} }
@Test @Test
@ -119,6 +134,7 @@ public class MessageDeframerTest {
deframer.deframe(buffer(new byte[0]), true); deframer.deframe(buffer(new byte[0]), true);
verify(listener).endOfStream(); verify(listener).endOfStream();
verifyNoMoreInteractions(listener); verifyNoMoreInteractions(listener);
checkStats(0, 0);
} }
@Test @Test
@ -133,6 +149,7 @@ public class MessageDeframerTest {
verify(listener, atLeastOnce()).bytesRead(anyInt()); verify(listener, atLeastOnce()).bytesRead(anyInt());
assertTrue(deframer.isStalled()); assertTrue(deframer.isStalled());
verifyNoMoreInteractions(listener); verifyNoMoreInteractions(listener);
checkStats(7, 7);
} }
@Test @Test
@ -148,6 +165,7 @@ public class MessageDeframerTest {
verify(listener, atLeastOnce()).bytesRead(anyInt()); verify(listener, atLeastOnce()).bytesRead(anyInt());
assertTrue(deframer.isStalled()); assertTrue(deframer.isStalled());
verifyNoMoreInteractions(listener); verifyNoMoreInteractions(listener);
checkStats(1, 1);
} }
@Test @Test
@ -158,6 +176,7 @@ public class MessageDeframerTest {
assertEquals(Bytes.asList(), bytes(messages)); assertEquals(Bytes.asList(), bytes(messages));
verify(listener, atLeastOnce()).bytesRead(anyInt()); verify(listener, atLeastOnce()).bytesRead(anyInt());
verifyNoMoreInteractions(listener); verifyNoMoreInteractions(listener);
checkStats(0, 0);
} }
@Test @Test
@ -169,6 +188,7 @@ public class MessageDeframerTest {
assertEquals(Bytes.asList(new byte[1000]), bytes(messages)); assertEquals(Bytes.asList(new byte[1000]), bytes(messages));
verify(listener, atLeastOnce()).bytesRead(anyInt()); verify(listener, atLeastOnce()).bytesRead(anyInt());
verifyNoMoreInteractions(listener); verifyNoMoreInteractions(listener);
checkStats(1000, 1000);
} }
@Test @Test
@ -182,11 +202,13 @@ public class MessageDeframerTest {
verify(listener).endOfStream(); verify(listener).endOfStream();
verify(listener, atLeastOnce()).bytesRead(anyInt()); verify(listener, atLeastOnce()).bytesRead(anyInt());
verifyNoMoreInteractions(listener); verifyNoMoreInteractions(listener);
checkStats(1, 1);
} }
@Test @Test
public void compressed() { public void compressed() {
deframer = new MessageDeframer(listener, new Codec.Gzip(), DEFAULT_MAX_MESSAGE_SIZE); deframer = new MessageDeframer(listener, new Codec.Gzip(), DEFAULT_MAX_MESSAGE_SIZE,
statsTraceCtx);
deframer.request(1); deframer.request(1);
byte[] payload = compress(new byte[1000]); byte[] payload = compress(new byte[1000]);
@ -197,6 +219,7 @@ public class MessageDeframerTest {
assertEquals(Bytes.asList(new byte[1000]), bytes(messages)); assertEquals(Bytes.asList(new byte[1000]), bytes(messages));
verify(listener, atLeastOnce()).bytesRead(anyInt()); verify(listener, atLeastOnce()).bytesRead(anyInt());
verifyNoMoreInteractions(listener); verifyNoMoreInteractions(listener);
checkStats(payload.length, 1000);
} }
@Test @Test
@ -222,27 +245,34 @@ public class MessageDeframerTest {
@Test @Test
public void sizeEnforcingInputStream_readByteBelowLimit() throws IOException { public void sizeEnforcingInputStream_readByteBelowLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 4); SizeEnforcingInputStream stream =
new MessageDeframer.SizeEnforcingInputStream(in, 4, statsTraceCtx);
while (stream.read() != -1) {} while (stream.read() != -1) {}
stream.close(); stream.close();
// SizeEnforcingInputStream only reports uncompressed bytes
checkStats(0, 3);
} }
@Test @Test
public void sizeEnforcingInputStream_readByteAtLimit() throws IOException { public void sizeEnforcingInputStream_readByteAtLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 3); SizeEnforcingInputStream stream =
new MessageDeframer.SizeEnforcingInputStream(in, 3, statsTraceCtx);
while (stream.read() != -1) {} while (stream.read() != -1) {}
stream.close(); stream.close();
// SizeEnforcingInputStream only reports uncompressed bytes
checkStats(0, 3);
} }
@Test @Test
public void sizeEnforcingInputStream_readByteAboveLimit() throws IOException { public void sizeEnforcingInputStream_readByteAboveLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 2); SizeEnforcingInputStream stream =
new MessageDeframer.SizeEnforcingInputStream(in, 2, statsTraceCtx);
thrown.expect(StatusRuntimeException.class); thrown.expect(StatusRuntimeException.class);
thrown.expectMessage("INTERNAL: Compressed frame exceeds"); thrown.expectMessage("INTERNAL: Compressed frame exceeds");
@ -256,31 +286,38 @@ public class MessageDeframerTest {
@Test @Test
public void sizeEnforcingInputStream_readBelowLimit() throws IOException { public void sizeEnforcingInputStream_readBelowLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 4); SizeEnforcingInputStream stream =
new MessageDeframer.SizeEnforcingInputStream(in, 4, statsTraceCtx);
byte[] buf = new byte[10]; byte[] buf = new byte[10];
int read = stream.read(buf, 0, buf.length); int read = stream.read(buf, 0, buf.length);
assertEquals(3, read); assertEquals(3, read);
stream.close(); stream.close();
// SizeEnforcingInputStream only reports uncompressed bytes
checkStats(0, 3);
} }
@Test @Test
public void sizeEnforcingInputStream_readAtLimit() throws IOException { public void sizeEnforcingInputStream_readAtLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 3); SizeEnforcingInputStream stream =
new MessageDeframer.SizeEnforcingInputStream(in, 3, statsTraceCtx);
byte[] buf = new byte[10]; byte[] buf = new byte[10];
int read = stream.read(buf, 0, buf.length); int read = stream.read(buf, 0, buf.length);
assertEquals(3, read); assertEquals(3, read);
stream.close(); stream.close();
// SizeEnforcingInputStream only reports uncompressed bytes
checkStats(0, 3);
} }
@Test @Test
public void sizeEnforcingInputStream_readAboveLimit() throws IOException { public void sizeEnforcingInputStream_readAboveLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 2); SizeEnforcingInputStream stream =
new MessageDeframer.SizeEnforcingInputStream(in, 2, statsTraceCtx);
byte[] buf = new byte[10]; byte[] buf = new byte[10];
thrown.expect(StatusRuntimeException.class); thrown.expect(StatusRuntimeException.class);
@ -295,30 +332,37 @@ public class MessageDeframerTest {
@Test @Test
public void sizeEnforcingInputStream_skipBelowLimit() throws IOException { public void sizeEnforcingInputStream_skipBelowLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 4); SizeEnforcingInputStream stream =
new MessageDeframer.SizeEnforcingInputStream(in, 4, statsTraceCtx);
long skipped = stream.skip(4); long skipped = stream.skip(4);
assertEquals(3, skipped); assertEquals(3, skipped);
stream.close(); stream.close();
// SizeEnforcingInputStream only reports uncompressed bytes
checkStats(0, 3);
} }
@Test @Test
public void sizeEnforcingInputStream_skipAtLimit() throws IOException { public void sizeEnforcingInputStream_skipAtLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 3); SizeEnforcingInputStream stream =
new MessageDeframer.SizeEnforcingInputStream(in, 3, statsTraceCtx);
long skipped = stream.skip(4); long skipped = stream.skip(4);
assertEquals(3, skipped); assertEquals(3, skipped);
stream.close(); stream.close();
// SizeEnforcingInputStream only reports uncompressed bytes
checkStats(0, 3);
} }
@Test @Test
public void sizeEnforcingInputStream_skipAboveLimit() throws IOException { public void sizeEnforcingInputStream_skipAboveLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 2); SizeEnforcingInputStream stream =
new MessageDeframer.SizeEnforcingInputStream(in, 2, statsTraceCtx);
thrown.expect(StatusRuntimeException.class); thrown.expect(StatusRuntimeException.class);
thrown.expectMessage("INTERNAL: Compressed frame exceeds"); thrown.expectMessage("INTERNAL: Compressed frame exceeds");
@ -332,7 +376,8 @@ public class MessageDeframerTest {
@Test @Test
public void sizeEnforcingInputStream_markReset() throws IOException { public void sizeEnforcingInputStream_markReset() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 3); SizeEnforcingInputStream stream =
new MessageDeframer.SizeEnforcingInputStream(in, 3, statsTraceCtx);
// stream currently looks like: |foo // stream currently looks like: |foo
stream.skip(1); // f|oo stream.skip(1); // f|oo
stream.mark(10); // any large number will work. stream.mark(10); // any large number will work.
@ -342,6 +387,25 @@ public class MessageDeframerTest {
assertEquals(2, skipped); assertEquals(2, skipped);
stream.close(); stream.close();
// SizeEnforcingInputStream only reports uncompressed bytes
checkStats(0, 3);
}
private void checkStats(long wireBytesReceived, long uncompressedBytesReceived) {
statsTraceCtx.callEnded(Status.OK);
MetricsRecord record = censusCtxFactory.pollRecord();
assertEquals(0, record.getMetricAsLongOrFail(
RpcConstants.RPC_CLIENT_REQUEST_BYTES));
assertEquals(0, record.getMetricAsLongOrFail(
RpcConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES));
assertEquals(wireBytesReceived,
record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_RESPONSE_BYTES));
assertEquals(uncompressedBytesReceived,
record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES));
assertNull(record.getMetric(RpcConstants.RPC_SERVER_REQUEST_BYTES));
assertNull(record.getMetric(RpcConstants.RPC_SERVER_RESPONSE_BYTES));
assertNull(record.getMetric(RpcConstants.RPC_SERVER_UNCOMPRESSED_REQUEST_BYTES));
assertNull(record.getMetric(RpcConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES));
} }
private static List<Byte> bytes(ArgumentCaptor<InputStream> captor) { private static List<Byte> bytes(ArgumentCaptor<InputStream> captor) {

View File

@ -32,6 +32,7 @@
package io.grpc.internal; package io.grpc.internal;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
@ -40,7 +41,14 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.verifyZeroInteractions;
import com.google.census.RpcConstants;
import io.grpc.Codec; import io.grpc.Codec;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.internal.StatsTraceContext;
import io.grpc.internal.testing.CensusTestUtils.FakeCensusContextFactory;
import io.grpc.internal.testing.CensusTestUtils.MetricsRecord;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
@ -70,12 +78,19 @@ public class MessageFramerTest {
private ArgumentCaptor<ByteWritableBuffer> frameCaptor; private ArgumentCaptor<ByteWritableBuffer> frameCaptor;
private BytesWritableBufferAllocator allocator = private BytesWritableBufferAllocator allocator =
new BytesWritableBufferAllocator(1000, 1000); new BytesWritableBufferAllocator(1000, 1000);
private FakeCensusContextFactory censusCtxFactory;
private StatsTraceContext statsTraceCtx;
/** Set up for test. */ /** Set up for test. */
@Before @Before
public void setUp() { public void setUp() {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
framer = new MessageFramer(sink, allocator); censusCtxFactory = new FakeCensusContextFactory();
// MessageDeframerTest tests with a client-side StatsTraceContext, so here we test with a
// server-side StatsTraceContext.
statsTraceCtx = StatsTraceContext.newServerContext(
"service/method", censusCtxFactory, new Metadata(), GrpcUtil.STOPWATCH_SUPPLIER);
framer = new MessageFramer(sink, allocator, statsTraceCtx);
} }
@Test @Test
@ -83,9 +98,11 @@ public class MessageFramerTest {
writeKnownLength(framer, new byte[]{3, 14}); writeKnownLength(framer, new byte[]{3, 14});
verifyNoMoreInteractions(sink); verifyNoMoreInteractions(sink);
framer.flush(); framer.flush();
verify(sink).deliverFrame(toWriteBuffer(new byte[] {0, 0, 0, 0, 2, 3, 14}), false, true); verify(sink).deliverFrame(toWriteBuffer(new byte[] {0, 0, 0, 0, 2, 3, 14}), false, true);
assertEquals(1, allocator.allocCount); assertEquals(1, allocator.allocCount);
verifyNoMoreInteractions(sink); verifyNoMoreInteractions(sink);
checkStats(2, 2);
} }
@Test @Test
@ -97,6 +114,7 @@ public class MessageFramerTest {
verify(sink).deliverFrame(toWriteBuffer(new byte[] {3, 14}), false, true); verify(sink).deliverFrame(toWriteBuffer(new byte[] {3, 14}), false, true);
assertEquals(2, allocator.allocCount); assertEquals(2, allocator.allocCount);
verifyNoMoreInteractions(sink); verifyNoMoreInteractions(sink);
checkStats(2, 2);
} }
@Test @Test
@ -110,6 +128,7 @@ public class MessageFramerTest {
toWriteBuffer(new byte[] {0, 0, 0, 0, 1, 3, 0, 0, 0, 0, 1, 14}), false, true); toWriteBuffer(new byte[] {0, 0, 0, 0, 1, 3, 0, 0, 0, 0, 1, 14}), false, true);
verifyNoMoreInteractions(sink); verifyNoMoreInteractions(sink);
assertEquals(1, allocator.allocCount); assertEquals(1, allocator.allocCount);
checkStats(2, 2);
} }
@Test @Test
@ -121,6 +140,7 @@ public class MessageFramerTest {
toWriteBuffer(new byte[] {0, 0, 0, 0, 7, 3, 14, 1, 5, 9, 2, 6}), true, true); toWriteBuffer(new byte[] {0, 0, 0, 0, 7, 3, 14, 1, 5, 9, 2, 6}), true, true);
verifyNoMoreInteractions(sink); verifyNoMoreInteractions(sink);
assertEquals(1, allocator.allocCount); assertEquals(1, allocator.allocCount);
checkStats(7, 7);
} }
@Test @Test
@ -129,12 +149,13 @@ public class MessageFramerTest {
verify(sink).deliverFrame(null, true, true); verify(sink).deliverFrame(null, true, true);
verifyNoMoreInteractions(sink); verifyNoMoreInteractions(sink);
assertEquals(0, allocator.allocCount); assertEquals(0, allocator.allocCount);
checkStats(0, 0);
} }
@Test @Test
public void payloadSplitBetweenSinks() { public void payloadSplitBetweenSinks() {
allocator = new BytesWritableBufferAllocator(12, 12); allocator = new BytesWritableBufferAllocator(12, 12);
framer = new MessageFramer(sink, allocator); framer = new MessageFramer(sink, allocator, statsTraceCtx);
writeKnownLength(framer, new byte[]{3, 14, 1, 5, 9, 2, 6, 5}); writeKnownLength(framer, new byte[]{3, 14, 1, 5, 9, 2, 6, 5});
verify(sink).deliverFrame( verify(sink).deliverFrame(
toWriteBuffer(new byte[] {0, 0, 0, 0, 8, 3, 14, 1, 5, 9, 2, 6}), false, false); toWriteBuffer(new byte[] {0, 0, 0, 0, 8, 3, 14, 1, 5, 9, 2, 6}), false, false);
@ -144,12 +165,13 @@ public class MessageFramerTest {
verify(sink).deliverFrame(toWriteBuffer(new byte[] {5}), false, true); verify(sink).deliverFrame(toWriteBuffer(new byte[] {5}), false, true);
verifyNoMoreInteractions(sink); verifyNoMoreInteractions(sink);
assertEquals(2, allocator.allocCount); assertEquals(2, allocator.allocCount);
checkStats(8, 8);
} }
@Test @Test
public void frameHeaderSplitBetweenSinks() { public void frameHeaderSplitBetweenSinks() {
allocator = new BytesWritableBufferAllocator(12, 12); allocator = new BytesWritableBufferAllocator(12, 12);
framer = new MessageFramer(sink, allocator); framer = new MessageFramer(sink, allocator, statsTraceCtx);
writeKnownLength(framer, new byte[]{3, 14, 1}); writeKnownLength(framer, new byte[]{3, 14, 1});
writeKnownLength(framer, new byte[]{3}); writeKnownLength(framer, new byte[]{3});
verify(sink).deliverFrame( verify(sink).deliverFrame(
@ -160,6 +182,7 @@ public class MessageFramerTest {
verify(sink).deliverFrame(toWriteBufferWithMinSize(new byte[] {1, 3}, 12), false, true); verify(sink).deliverFrame(toWriteBufferWithMinSize(new byte[] {1, 3}, 12), false, true);
verifyNoMoreInteractions(sink); verifyNoMoreInteractions(sink);
assertEquals(2, allocator.allocCount); assertEquals(2, allocator.allocCount);
checkStats(4, 4);
} }
@Test @Test
@ -168,6 +191,7 @@ public class MessageFramerTest {
framer.flush(); framer.flush();
verify(sink).deliverFrame(toWriteBuffer(new byte[] {0, 0, 0, 0, 0}), false, true); verify(sink).deliverFrame(toWriteBuffer(new byte[] {0, 0, 0, 0, 0}), false, true);
assertEquals(1, allocator.allocCount); assertEquals(1, allocator.allocCount);
checkStats(0, 0);
} }
@Test @Test
@ -178,6 +202,7 @@ public class MessageFramerTest {
verify(sink).deliverFrame(toWriteBuffer(new byte[] {0, 0, 0, 0, 0}), false, true); verify(sink).deliverFrame(toWriteBuffer(new byte[] {0, 0, 0, 0, 0}), false, true);
// One alloc for the header // One alloc for the header
assertEquals(1, allocator.allocCount); assertEquals(1, allocator.allocCount);
checkStats(0, 0);
} }
@Test @Test
@ -188,12 +213,13 @@ public class MessageFramerTest {
verify(sink).deliverFrame(toWriteBuffer(new byte[] {0, 0, 0, 0, 2, 3, 14}), false, true); verify(sink).deliverFrame(toWriteBuffer(new byte[] {0, 0, 0, 0, 2, 3, 14}), false, true);
verifyNoMoreInteractions(sink); verifyNoMoreInteractions(sink);
assertEquals(1, allocator.allocCount); assertEquals(1, allocator.allocCount);
checkStats(2, 2);
} }
@Test @Test
public void largerFrameSize() throws Exception { public void largerFrameSize() throws Exception {
allocator = new BytesWritableBufferAllocator(0, 10000); allocator = new BytesWritableBufferAllocator(0, 10000);
framer = new MessageFramer(sink, allocator); framer = new MessageFramer(sink, allocator, statsTraceCtx);
writeKnownLength(framer, new byte[1000]); writeKnownLength(framer, new byte[1000]);
framer.flush(); framer.flush();
verify(sink).deliverFrame(frameCaptor.capture(), eq(false), eq(true)); verify(sink).deliverFrame(frameCaptor.capture(), eq(false), eq(true));
@ -207,13 +233,14 @@ public class MessageFramerTest {
assertEquals(toWriteBuffer(data), buffer); assertEquals(toWriteBuffer(data), buffer);
verifyNoMoreInteractions(sink); verifyNoMoreInteractions(sink);
assertEquals(1, allocator.allocCount); assertEquals(1, allocator.allocCount);
checkStats(1000, 1000);
} }
@Test @Test
public void largerFrameSizeUnknownLength() throws Exception { public void largerFrameSizeUnknownLength() throws Exception {
// Force payload to be split into two chunks // Force payload to be split into two chunks
allocator = new BytesWritableBufferAllocator(500, 500); allocator = new BytesWritableBufferAllocator(500, 500);
framer = new MessageFramer(sink, allocator); framer = new MessageFramer(sink, allocator, statsTraceCtx);
writeUnknownLength(framer, new byte[1000]); writeUnknownLength(framer, new byte[1000]);
framer.flush(); framer.flush();
// Header and first chunk written with flush = false // Header and first chunk written with flush = false
@ -233,13 +260,14 @@ public class MessageFramerTest {
verifyNoMoreInteractions(sink); verifyNoMoreInteractions(sink);
assertEquals(3, allocator.allocCount); assertEquals(3, allocator.allocCount);
checkStats(1000, 1000);
} }
@Test @Test
public void compressed() throws Exception { public void compressed() throws Exception {
allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE); allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE);
// setMessageCompression should default to true // setMessageCompression should default to true
framer = new MessageFramer(sink, allocator).setCompressor(new Codec.Gzip()); framer = new MessageFramer(sink, allocator, statsTraceCtx).setCompressor(new Codec.Gzip());
writeKnownLength(framer, new byte[1000]); writeKnownLength(framer, new byte[1000]);
framer.flush(); framer.flush();
// The GRPC header is written first as a separate frame. // The GRPC header is written first as a separate frame.
@ -257,12 +285,13 @@ public class MessageFramerTest {
assertTrue(length < 1000); assertTrue(length < 1000);
assertEquals(frameCaptor.getAllValues().get(1).size(), length); assertEquals(frameCaptor.getAllValues().get(1).size(), length);
checkStats(length, 1000);
} }
@Test @Test
public void dontCompressIfNoEncoding() throws Exception { public void dontCompressIfNoEncoding() throws Exception {
allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE); allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE);
framer = new MessageFramer(sink, allocator) framer = new MessageFramer(sink, allocator, statsTraceCtx)
.setMessageCompression(true); .setMessageCompression(true);
writeKnownLength(framer, new byte[1000]); writeKnownLength(framer, new byte[1000]);
framer.flush(); framer.flush();
@ -281,12 +310,13 @@ public class MessageFramerTest {
assertEquals(1000, length); assertEquals(1000, length);
assertEquals(buffer.data.length - 5 , length); assertEquals(buffer.data.length - 5 , length);
checkStats(1000, 1000);
} }
@Test @Test
public void dontCompressIfNotRequested() throws Exception { public void dontCompressIfNotRequested() throws Exception {
allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE); allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE);
framer = new MessageFramer(sink, allocator) framer = new MessageFramer(sink, allocator, statsTraceCtx)
.setCompressor(new Codec.Gzip()) .setCompressor(new Codec.Gzip())
.setMessageCompression(false); .setMessageCompression(false);
writeKnownLength(framer, new byte[1000]); writeKnownLength(framer, new byte[1000]);
@ -306,6 +336,7 @@ public class MessageFramerTest {
assertEquals(1000, length); assertEquals(1000, length);
assertEquals(buffer.data.length - 5 , length); assertEquals(buffer.data.length - 5 , length);
checkStats(1000, 1000);
} }
@Test @Test
@ -322,7 +353,7 @@ public class MessageFramerTest {
} }
} }
}; };
framer = new MessageFramer(reentrant, allocator); framer = new MessageFramer(reentrant, allocator, statsTraceCtx);
writeKnownLength(framer, new byte[]{3, 14}); writeKnownLength(framer, new byte[]{3, 14});
framer.close(); framer.close();
} }
@ -334,6 +365,7 @@ public class MessageFramerTest {
writeKnownLength(framer, new byte[]{}); writeKnownLength(framer, new byte[]{});
framer.flush(); framer.flush();
verify(sink).deliverFrame(toWriteBuffer(new byte[] {0, 0, 0, 0, 0}), false, true); verify(sink).deliverFrame(toWriteBuffer(new byte[] {0, 0, 0, 0, 0}), false, true);
checkStats(0, 0);
} }
private static WritableBuffer toWriteBuffer(byte[] data) { private static WritableBuffer toWriteBuffer(byte[] data) {
@ -355,6 +387,23 @@ public class MessageFramerTest {
// TODO(carl-mastrangelo): add framer.flush() here. // TODO(carl-mastrangelo): add framer.flush() here.
} }
private void checkStats(long wireBytesSent, long uncompressedBytesSent) {
statsTraceCtx.callEnded(Status.OK);
MetricsRecord record = censusCtxFactory.pollRecord();
assertEquals(0, record.getMetricAsLongOrFail(
RpcConstants.RPC_SERVER_REQUEST_BYTES));
assertEquals(0, record.getMetricAsLongOrFail(
RpcConstants.RPC_SERVER_UNCOMPRESSED_REQUEST_BYTES));
assertEquals(wireBytesSent,
record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_RESPONSE_BYTES));
assertEquals(uncompressedBytesSent,
record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES));
assertNull(record.getMetric(RpcConstants.RPC_CLIENT_REQUEST_BYTES));
assertNull(record.getMetric(RpcConstants.RPC_CLIENT_RESPONSE_BYTES));
assertNull(record.getMetric(RpcConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES));
assertNull(record.getMetric(RpcConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES));
}
static class ByteWritableBuffer implements WritableBuffer { static class ByteWritableBuffer implements WritableBuffer {
byte[] data; byte[] data;
private int writeIdx; private int writeIdx;

View File

@ -32,6 +32,7 @@
package io.grpc.internal; package io.grpc.internal;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
@ -40,6 +41,8 @@ import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.census.RpcConstants;
import com.google.census.TagValue;
import com.google.common.io.CharStreams; import com.google.common.io.CharStreams;
import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.Futures;
@ -53,6 +56,8 @@ import io.grpc.MethodDescriptor.MethodType;
import io.grpc.ServerCall; import io.grpc.ServerCall;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.internal.ServerCallImpl.ServerStreamListenerImpl; import io.grpc.internal.ServerCallImpl.ServerStreamListenerImpl;
import io.grpc.internal.testing.CensusTestUtils.FakeCensusContextFactory;
import io.grpc.internal.testing.CensusTestUtils;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule; import org.junit.Rule;
@ -86,12 +91,18 @@ public class ServerCallImplTest {
private final MethodDescriptor<Long, Long> method = MethodDescriptor.create( private final MethodDescriptor<Long, Long> method = MethodDescriptor.create(
MethodType.UNARY, "/service/method", new LongMarshaller(), new LongMarshaller()); MethodType.UNARY, "/service/method", new LongMarshaller(), new LongMarshaller());
private final Metadata requestHeaders = new Metadata();
private final FakeCensusContextFactory censusCtxFactory = new FakeCensusContextFactory();
private final StatsTraceContext statsTraceCtx = StatsTraceContext.newServerContext(
method.getFullMethodName(), censusCtxFactory, requestHeaders, GrpcUtil.STOPWATCH_SUPPLIER);
@Before @Before
public void setUp() { public void setUp() {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
context = Context.ROOT.withCancellation(); context = Context.ROOT.withCancellation();
call = new ServerCallImpl<Long, Long>(stream, method, new Metadata(), context, call = new ServerCallImpl<Long, Long>(stream, method, requestHeaders, context,
DecompressorRegistry.getDefaultInstance(), CompressorRegistry.getDefaultInstance()); statsTraceCtx, DecompressorRegistry.getDefaultInstance(),
CompressorRegistry.getDefaultInstance());
} }
@Test @Test
@ -189,7 +200,8 @@ public class ServerCallImplTest {
@Test @Test
public void streamListener_halfClosed() { public void streamListener_halfClosed() {
ServerStreamListenerImpl<Long> streamListener = ServerStreamListenerImpl<Long> streamListener =
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, context); new ServerCallImpl.ServerStreamListenerImpl<Long>(
call, callListener, context, statsTraceCtx);
streamListener.halfClosed(); streamListener.halfClosed();
@ -199,7 +211,8 @@ public class ServerCallImplTest {
@Test @Test
public void streamListener_halfClosed_onlyOnce() { public void streamListener_halfClosed_onlyOnce() {
ServerStreamListenerImpl<Long> streamListener = ServerStreamListenerImpl<Long> streamListener =
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, context); new ServerCallImpl.ServerStreamListenerImpl<Long>(
call, callListener, context, statsTraceCtx);
streamListener.halfClosed(); streamListener.halfClosed();
// canceling the call should short circuit future halfClosed() calls. // canceling the call should short circuit future halfClosed() calls.
streamListener.closed(Status.CANCELLED); streamListener.closed(Status.CANCELLED);
@ -212,31 +225,36 @@ public class ServerCallImplTest {
@Test @Test
public void streamListener_closedOk() { public void streamListener_closedOk() {
ServerStreamListenerImpl<Long> streamListener = ServerStreamListenerImpl<Long> streamListener =
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, context); new ServerCallImpl.ServerStreamListenerImpl<Long>(
call, callListener, context, statsTraceCtx);
streamListener.closed(Status.OK); streamListener.closed(Status.OK);
verify(callListener).onComplete(); verify(callListener).onComplete();
assertTrue(context.isCancelled()); assertTrue(context.isCancelled());
assertNull(context.cancellationCause()); assertNull(context.cancellationCause());
checkStats(Status.Code.OK);
} }
@Test @Test
public void streamListener_closedCancelled() { public void streamListener_closedCancelled() {
ServerStreamListenerImpl<Long> streamListener = ServerStreamListenerImpl<Long> streamListener =
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, context); new ServerCallImpl.ServerStreamListenerImpl<Long>(
call, callListener, context, statsTraceCtx);
streamListener.closed(Status.CANCELLED); streamListener.closed(Status.CANCELLED);
verify(callListener).onCancel(); verify(callListener).onCancel();
assertTrue(context.isCancelled()); assertTrue(context.isCancelled());
assertNull(context.cancellationCause()); assertNull(context.cancellationCause());
checkStats(Status.Code.CANCELLED);
} }
@Test @Test
public void streamListener_onReady() { public void streamListener_onReady() {
ServerStreamListenerImpl<Long> streamListener = ServerStreamListenerImpl<Long> streamListener =
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, context); new ServerCallImpl.ServerStreamListenerImpl<Long>(
call, callListener, context, statsTraceCtx);
streamListener.onReady(); streamListener.onReady();
@ -246,7 +264,8 @@ public class ServerCallImplTest {
@Test @Test
public void streamListener_onReady_onlyOnce() { public void streamListener_onReady_onlyOnce() {
ServerStreamListenerImpl<Long> streamListener = ServerStreamListenerImpl<Long> streamListener =
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, context); new ServerCallImpl.ServerStreamListenerImpl<Long>(
call, callListener, context, statsTraceCtx);
streamListener.onReady(); streamListener.onReady();
// canceling the call should short circuit future halfClosed() calls. // canceling the call should short circuit future halfClosed() calls.
streamListener.closed(Status.CANCELLED); streamListener.closed(Status.CANCELLED);
@ -259,7 +278,8 @@ public class ServerCallImplTest {
@Test @Test
public void streamListener_messageRead() { public void streamListener_messageRead() {
ServerStreamListenerImpl<Long> streamListener = ServerStreamListenerImpl<Long> streamListener =
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, context); new ServerCallImpl.ServerStreamListenerImpl<Long>(
call, callListener, context, statsTraceCtx);
streamListener.messageRead(method.streamRequest(1234L)); streamListener.messageRead(method.streamRequest(1234L));
verify(callListener).onMessage(1234L); verify(callListener).onMessage(1234L);
@ -268,7 +288,8 @@ public class ServerCallImplTest {
@Test @Test
public void streamListener_messageRead_unaryFailsOnMultiple() { public void streamListener_messageRead_unaryFailsOnMultiple() {
ServerStreamListenerImpl<Long> streamListener = ServerStreamListenerImpl<Long> streamListener =
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, context); new ServerCallImpl.ServerStreamListenerImpl<Long>(
call, callListener, context, statsTraceCtx);
streamListener.messageRead(method.streamRequest(1234L)); streamListener.messageRead(method.streamRequest(1234L));
streamListener.messageRead(method.streamRequest(1234L)); streamListener.messageRead(method.streamRequest(1234L));
@ -282,7 +303,8 @@ public class ServerCallImplTest {
@Test @Test
public void streamListener_messageRead_onlyOnce() { public void streamListener_messageRead_onlyOnce() {
ServerStreamListenerImpl<Long> streamListener = ServerStreamListenerImpl<Long> streamListener =
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, context); new ServerCallImpl.ServerStreamListenerImpl<Long>(
call, callListener, context, statsTraceCtx);
streamListener.messageRead(method.streamRequest(1234L)); streamListener.messageRead(method.streamRequest(1234L));
// canceling the call should short circuit future halfClosed() calls. // canceling the call should short circuit future halfClosed() calls.
streamListener.closed(Status.CANCELLED); streamListener.closed(Status.CANCELLED);
@ -292,6 +314,28 @@ public class ServerCallImplTest {
verify(callListener).onMessage(1234L); verify(callListener).onMessage(1234L);
} }
private void checkStats(Status.Code statusCode) {
CensusTestUtils.MetricsRecord record = censusCtxFactory.pollRecord();
assertNotNull(record);
TagValue statusTag = record.tags.get(RpcConstants.RPC_STATUS);
assertNotNull(statusTag);
assertEquals(statusCode.toString(), statusTag.toString());
assertNull(record.getMetric(RpcConstants.RPC_CLIENT_REQUEST_BYTES));
assertNull(record.getMetric(RpcConstants.RPC_CLIENT_RESPONSE_BYTES));
assertNull(record.getMetric(RpcConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES));
assertNull(record.getMetric(RpcConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES));
// The test doesn't invoke MessageFramer and MessageDeframer which keep the sizes.
// Thus the sizes reported to stats would be zero.
assertEquals(0,
record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_REQUEST_BYTES));
assertEquals(0,
record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_RESPONSE_BYTES));
assertEquals(0,
record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_UNCOMPRESSED_REQUEST_BYTES));
assertEquals(0,
record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES));
}
private static class LongMarshaller implements Marshaller<Long> { private static class LongMarshaller implements Marshaller<Long> {
@Override @Override
public InputStream stream(Long value) { public InputStream stream(Long value) {

View File

@ -35,19 +35,26 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame; import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.isA; import static org.mockito.Matchers.isA;
import static org.mockito.Matchers.isNotNull; import static org.mockito.Matchers.isNotNull;
import static org.mockito.Matchers.notNull; import static org.mockito.Matchers.notNull;
import static org.mockito.Matchers.same; import static org.mockito.Matchers.same;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.google.census.CensusContext;
import com.google.census.RpcConstants;
import com.google.census.TagValue;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.truth.Truth; import com.google.common.truth.Truth;
import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.MoreExecutors;
@ -70,6 +77,8 @@ import io.grpc.ServerTransportFilter;
import io.grpc.ServiceDescriptor; import io.grpc.ServiceDescriptor;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.StringMarshaller; import io.grpc.StringMarshaller;
import io.grpc.internal.testing.CensusTestUtils.FakeCensusContextFactory;
import io.grpc.internal.testing.CensusTestUtils;
import io.grpc.util.MutableHandlerRegistry; import io.grpc.util.MutableHandlerRegistry;
import org.junit.After; import org.junit.After;
@ -81,6 +90,7 @@ import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Matchers; import org.mockito.Matchers;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.MockitoAnnotations; import org.mockito.MockitoAnnotations;
@ -108,6 +118,8 @@ public class ServerImplTest {
private static final Context.CancellableContext SERVER_CONTEXT = private static final Context.CancellableContext SERVER_CONTEXT =
Context.ROOT.withValue(SERVER_ONLY, "yes").withCancellation(); Context.ROOT.withValue(SERVER_ONLY, "yes").withCancellation();
private static final ImmutableList<ServerTransportFilter> NO_FILTERS = ImmutableList.of(); private static final ImmutableList<ServerTransportFilter> NO_FILTERS = ImmutableList.of();
private final FakeCensusContextFactory censusCtxFactory = new FakeCensusContextFactory();
private final CompressorRegistry compressorRegistry = CompressorRegistry.getDefaultInstance(); private final CompressorRegistry compressorRegistry = CompressorRegistry.getDefaultInstance();
private final DecompressorRegistry decompressorRegistry = private final DecompressorRegistry decompressorRegistry =
DecompressorRegistry.getDefaultInstance(); DecompressorRegistry.getDefaultInstance();
@ -126,7 +138,11 @@ public class ServerImplTest {
private MutableHandlerRegistry fallbackRegistry = new MutableHandlerRegistry(); private MutableHandlerRegistry fallbackRegistry = new MutableHandlerRegistry();
private SimpleServer transportServer = new SimpleServer(); private SimpleServer transportServer = new SimpleServer();
private ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, private ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer,
SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS); SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, censusCtxFactory,
GrpcUtil.STOPWATCH_SUPPLIER);
@Captor
private ArgumentCaptor<Status> statusCaptor;
@Mock @Mock
private ServerStream stream; private ServerStream stream;
@ -158,7 +174,8 @@ public class ServerImplTest {
public void shutdown() {} public void shutdown() {}
}; };
ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer,
SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS); SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, censusCtxFactory,
GrpcUtil.STOPWATCH_SUPPLIER);
server.start(); server.start();
server.shutdown(); server.shutdown();
assertTrue(server.isShutdown()); assertTrue(server.isShutdown());
@ -176,7 +193,8 @@ public class ServerImplTest {
} }
}; };
ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer,
SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS); SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, censusCtxFactory,
GrpcUtil.STOPWATCH_SUPPLIER);
server.shutdown(); server.shutdown();
assertTrue(server.isShutdown()); assertTrue(server.isShutdown());
assertTrue(server.isTerminated()); assertTrue(server.isTerminated());
@ -185,7 +203,8 @@ public class ServerImplTest {
@Test @Test
public void startStopImmediateWithChildTransport() throws IOException { public void startStopImmediateWithChildTransport() throws IOException {
ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer,
SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS); SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, censusCtxFactory,
GrpcUtil.STOPWATCH_SUPPLIER);
server.start(); server.start();
class DelayedShutdownServerTransport extends SimpleServerTransport { class DelayedShutdownServerTransport extends SimpleServerTransport {
boolean shutdown; boolean shutdown;
@ -209,7 +228,8 @@ public class ServerImplTest {
@Test @Test
public void startShutdownNowImmediateWithChildTransport() throws IOException { public void startShutdownNowImmediateWithChildTransport() throws IOException {
ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer,
SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS); SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, censusCtxFactory,
GrpcUtil.STOPWATCH_SUPPLIER);
server.start(); server.start();
class DelayedShutdownServerTransport extends SimpleServerTransport { class DelayedShutdownServerTransport extends SimpleServerTransport {
boolean shutdown; boolean shutdown;
@ -236,7 +256,8 @@ public class ServerImplTest {
@Test @Test
public void shutdownNowAfterShutdown() throws IOException { public void shutdownNowAfterShutdown() throws IOException {
ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer,
SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS); SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, censusCtxFactory,
GrpcUtil.STOPWATCH_SUPPLIER);
server.start(); server.start();
class DelayedShutdownServerTransport extends SimpleServerTransport { class DelayedShutdownServerTransport extends SimpleServerTransport {
boolean shutdown; boolean shutdown;
@ -270,7 +291,8 @@ public class ServerImplTest {
} }
}; };
ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer,
SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS); SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, censusCtxFactory,
GrpcUtil.STOPWATCH_SUPPLIER);
server.start(); server.start();
class DelayedShutdownServerTransport extends SimpleServerTransport { class DelayedShutdownServerTransport extends SimpleServerTransport {
boolean shutdown; boolean shutdown;
@ -307,7 +329,7 @@ public class ServerImplTest {
ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry,
new FailingStartupServer(), SERVER_CONTEXT, decompressorRegistry, compressorRegistry, new FailingStartupServer(), SERVER_CONTEXT, decompressorRegistry, compressorRegistry,
NO_FILTERS); NO_FILTERS, censusCtxFactory, GrpcUtil.STOPWATCH_SUPPLIER);
try { try {
server.start(); server.start();
fail("expected exception"); fail("expected exception");
@ -316,10 +338,41 @@ public class ServerImplTest {
} }
} }
@Test
public void methodNotFound() throws Exception {
ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport());
Metadata requestHeaders = new Metadata();
StatsTraceContext statsTraceCtx =
transportListener.methodDetermined("Waiter/nonexist", requestHeaders);
when(stream.statsTraceContext()).thenReturn(statsTraceCtx);
ServerStreamListener streamListener
= transportListener.streamCreated(stream, "Waiter/nonexist", requestHeaders);
assertNotNull(streamListener);
verify(stream, atLeast(1)).statsTraceContext();
executeBarrier(executor).await();
verify(stream).close(statusCaptor.capture(), any(Metadata.class));
Status status = statusCaptor.getValue();
assertEquals(Status.Code.UNIMPLEMENTED, status.getCode());
assertEquals("Method not found: Waiter/nonexist", status.getDescription());
CensusTestUtils.MetricsRecord record = censusCtxFactory.pollRecord();
assertNotNull(record);
TagValue methodTag = record.tags.get(RpcConstants.RPC_SERVER_METHOD);
assertNotNull(methodTag);
assertEquals("Waiter/nonexist", methodTag.toString());
TagValue statusTag = record.tags.get(RpcConstants.RPC_STATUS);
assertNotNull(statusTag);
assertEquals(Status.Code.UNIMPLEMENTED.toString(), statusTag.toString());
}
@Test @Test
public void basicExchangeSuccessful() throws Exception { public void basicExchangeSuccessful() throws Exception {
final Metadata.Key<String> metadataKey final Metadata.Key<String> metadataKey
= Metadata.Key.of("inception", Metadata.ASCII_STRING_MARSHALLER); = Metadata.Key.of("inception", Metadata.ASCII_STRING_MARSHALLER);
final Metadata.Key<CensusContext> censusHeaderKey
= StatsTraceContext.createCensusHeader(censusCtxFactory);
final AtomicReference<ServerCall<String, Integer>> callReference final AtomicReference<ServerCall<String, Integer>> callReference
= new AtomicReference<ServerCall<String, Integer>>(); = new AtomicReference<ServerCall<String, Integer>>();
MethodDescriptor<String, Integer> method = MethodDescriptor.create( MethodDescriptor<String, Integer> method = MethodDescriptor.create(
@ -346,9 +399,18 @@ public class ServerImplTest {
Metadata requestHeaders = new Metadata(); Metadata requestHeaders = new Metadata();
requestHeaders.put(metadataKey, "value"); requestHeaders.put(metadataKey, "value");
CensusContext censusContextOnClient = censusCtxFactory.getDefault().with(
CensusTestUtils.EXTRA_TAG, new TagValue("extraTagValue"));
requestHeaders.put(censusHeaderKey, censusContextOnClient);
StatsTraceContext statsTraceCtx =
transportListener.methodDetermined("Waiter/serve", requestHeaders);
assertNotNull(statsTraceCtx);
when(stream.statsTraceContext()).thenReturn(statsTraceCtx);
ServerStreamListener streamListener ServerStreamListener streamListener
= transportListener.streamCreated(stream, "Waiter/serve", requestHeaders); = transportListener.streamCreated(stream, "Waiter/serve", requestHeaders);
assertNotNull(streamListener); assertNotNull(streamListener);
verify(stream, atLeast(1)).statsTraceContext();
executeBarrier(executor).await(); executeBarrier(executor).await();
ServerCall<String, Integer> call = callReference.get(); ServerCall<String, Integer> call = callReference.get();
@ -389,8 +451,34 @@ public class ServerImplTest {
executeBarrier(executor).await(); executeBarrier(executor).await();
verify(callListener).onComplete(); verify(callListener).onComplete();
verify(stream, atLeast(1)).statsTraceContext();
verifyNoMoreInteractions(stream); verifyNoMoreInteractions(stream);
verifyNoMoreInteractions(callListener); verifyNoMoreInteractions(callListener);
// Check stats
CensusTestUtils.MetricsRecord record = censusCtxFactory.pollRecord();
assertNotNull(record);
TagValue methodTag = record.tags.get(RpcConstants.RPC_SERVER_METHOD);
assertNotNull(methodTag);
assertEquals("Waiter/serve", methodTag.toString());
TagValue statusTag = record.tags.get(RpcConstants.RPC_STATUS);
assertNotNull(statusTag);
assertEquals(Status.Code.OK.toString(), statusTag.toString());
TagValue extraTag = record.tags.get(CensusTestUtils.EXTRA_TAG);
assertNotNull(extraTag);
assertEquals("extraTagValue", extraTag.toString());
assertNull(record.getMetric(RpcConstants.RPC_CLIENT_REQUEST_BYTES));
assertNull(record.getMetric(RpcConstants.RPC_CLIENT_RESPONSE_BYTES));
assertNull(record.getMetric(RpcConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES));
assertNull(record.getMetric(RpcConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES));
// The test doesn't invoke MessageFramer and MessageDeframer which keep the sizes.
// Thus the sizes reported to stats would be zero.
assertEquals(0, record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_REQUEST_BYTES));
assertEquals(0, record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_RESPONSE_BYTES));
assertEquals(0,
record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_UNCOMPRESSED_REQUEST_BYTES));
assertEquals(0,
record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES));
} }
@Test @Test
@ -454,7 +542,7 @@ public class ServerImplTest {
ServerImpl server = new ServerImpl(MoreExecutors.directExecutor(), registry, fallbackRegistry, ServerImpl server = new ServerImpl(MoreExecutors.directExecutor(), registry, fallbackRegistry,
transportServer, SERVER_CONTEXT, decompressorRegistry, compressorRegistry, transportServer, SERVER_CONTEXT, decompressorRegistry, compressorRegistry,
ImmutableList.of(filter1, filter2)); ImmutableList.of(filter1, filter2), censusCtxFactory, GrpcUtil.STOPWATCH_SUPPLIER);
server.start(); server.start();
ServerTransportListener transportListener ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport()); = transportServer.registerNewServerTransport(new SimpleServerTransport());
@ -493,14 +581,22 @@ public class ServerImplTest {
ServerTransportListener transportListener ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport()); = transportServer.registerNewServerTransport(new SimpleServerTransport());
Metadata requestHeaders = new Metadata();
StatsTraceContext statsTraceCtx =
transportListener.methodDetermined("Waiter/serve", requestHeaders);
assertNotNull(statsTraceCtx);
when(stream.statsTraceContext()).thenReturn(statsTraceCtx);
ServerStreamListener streamListener ServerStreamListener streamListener
= transportListener.streamCreated(stream, "Waiter/serve", new Metadata()); = transportListener.streamCreated(stream, "Waiter/serve", requestHeaders);
assertNotNull(streamListener); assertNotNull(streamListener);
verify(stream, atLeast(1)).statsTraceContext();
verifyNoMoreInteractions(stream); verifyNoMoreInteractions(stream);
barrier.await(); barrier.await();
executeBarrier(executor).await(); executeBarrier(executor).await();
verify(stream).close(same(status), notNull(Metadata.class)); verify(stream).close(same(status), notNull(Metadata.class));
verify(stream, atLeast(1)).statsTraceContext();
verifyNoMoreInteractions(stream); verifyNoMoreInteractions(stream);
} }
@ -526,7 +622,8 @@ public class ServerImplTest {
transportServer = new MaybeDeadlockingServer(); transportServer = new MaybeDeadlockingServer();
ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer,
SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS); SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, censusCtxFactory,
GrpcUtil.STOPWATCH_SUPPLIER);
server.start(); server.start();
new Thread() { new Thread() {
@Override @Override
@ -589,6 +686,10 @@ public class ServerImplTest {
public void testCallContextIsBoundInListenerCallbacks() throws Exception { public void testCallContextIsBoundInListenerCallbacks() throws Exception {
MethodDescriptor<String, Integer> method = MethodDescriptor.create( MethodDescriptor<String, Integer> method = MethodDescriptor.create(
MethodType.UNKNOWN, "Waiter/serve", STRING_MARSHALLER, INTEGER_MARSHALLER); MethodType.UNKNOWN, "Waiter/serve", STRING_MARSHALLER, INTEGER_MARSHALLER);
final CountDownLatch onReadyCalled = new CountDownLatch(1);
final CountDownLatch onMessageCalled = new CountDownLatch(1);
final CountDownLatch onHalfCloseCalled = new CountDownLatch(1);
final CountDownLatch onCancelCalled = new CountDownLatch(1);
fallbackRegistry.addService(ServerServiceDefinition.builder( fallbackRegistry.addService(ServerServiceDefinition.builder(
new ServiceDescriptor("Waiter", method)) new ServiceDescriptor("Waiter", method))
.addMethod( .addMethod(
@ -608,31 +709,30 @@ public class ServerImplTest {
@Override @Override
public void onReady() { public void onReady() {
checkContext(); checkContext();
super.onReady(); onReadyCalled.countDown();
} }
@Override @Override
public void onMessage(String message) { public void onMessage(String message) {
checkContext(); checkContext();
super.onMessage(message); onMessageCalled.countDown();
} }
@Override @Override
public void onHalfClose() { public void onHalfClose() {
checkContext(); checkContext();
super.onHalfClose(); onHalfCloseCalled.countDown();
} }
@Override @Override
public void onCancel() { public void onCancel() {
checkContext(); checkContext();
super.onCancel(); onCancelCalled.countDown();
} }
@Override @Override
public void onComplete() { public void onComplete() {
checkContext(); checkContext();
super.onComplete();
} }
private void checkContext() { private void checkContext() {
@ -645,8 +745,14 @@ public class ServerImplTest {
ServerTransportListener transportListener ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport()); = transportServer.registerNewServerTransport(new SimpleServerTransport());
Metadata requestHeaders = new Metadata();
StatsTraceContext statsTraceCtx =
transportListener.methodDetermined("Waiter/serve", requestHeaders);
assertNotNull(statsTraceCtx);
when(stream.statsTraceContext()).thenReturn(statsTraceCtx);
ServerStreamListener streamListener ServerStreamListener streamListener
= transportListener.streamCreated(stream, "Waiter/serve", new Metadata()); = transportListener.streamCreated(stream, "Waiter/serve", requestHeaders);
assertNotNull(streamListener); assertNotNull(streamListener);
streamListener.onReady(); streamListener.onReady();
@ -654,6 +760,11 @@ public class ServerImplTest {
streamListener.halfClosed(); streamListener.halfClosed();
streamListener.closed(Status.CANCELLED); streamListener.closed(Status.CANCELLED);
assertTrue(onReadyCalled.await(5, TimeUnit.SECONDS));
assertTrue(onMessageCalled.await(5, TimeUnit.SECONDS));
assertTrue(onHalfCloseCalled.await(5, TimeUnit.SECONDS));
assertTrue(onCancelCalled.await(5, TimeUnit.SECONDS));
// Close should never be called if asserts in listener pass. // Close should never be called if asserts in listener pass.
verify(stream, times(0)).close(isA(Status.class), isNotNull(Metadata.class)); verify(stream, times(0)).close(isA(Status.class), isNotNull(Metadata.class));
} }
@ -691,9 +802,13 @@ public class ServerImplTest {
}).build()); }).build());
ServerTransportListener transportListener ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport()); = transportServer.registerNewServerTransport(new SimpleServerTransport());
Metadata requestHeaders = new Metadata();
StatsTraceContext statsTraceCtx =
transportListener.methodDetermined("Waiter/serve", requestHeaders);
assertNotNull(statsTraceCtx);
when(stream.statsTraceContext()).thenReturn(statsTraceCtx);
ServerStreamListener streamListener ServerStreamListener streamListener
= transportListener.streamCreated(stream, "Waiter/serve", new Metadata()); = transportListener.streamCreated(stream, "Waiter/serve", requestHeaders);
assertNotNull(streamListener); assertNotNull(streamListener);
streamListener.onReady(); streamListener.onReady();
@ -711,7 +826,8 @@ public class ServerImplTest {
} }
}; };
ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer,
SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS); SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, censusCtxFactory,
GrpcUtil.STOPWATCH_SUPPLIER);
server.start(); server.start();
Truth.assertThat(server.getPort()).isEqualTo(65535); Truth.assertThat(server.getPort()).isEqualTo(65535);
@ -721,7 +837,8 @@ public class ServerImplTest {
public void getPortBeforeStartedFails() { public void getPortBeforeStartedFails() {
transportServer = new SimpleServer(); transportServer = new SimpleServer();
ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer,
SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS); SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, censusCtxFactory,
GrpcUtil.STOPWATCH_SUPPLIER);
thrown.expect(IllegalStateException.class); thrown.expect(IllegalStateException.class);
thrown.expectMessage("started"); thrown.expectMessage("started");
server.getPort(); server.getPort();
@ -731,7 +848,8 @@ public class ServerImplTest {
public void getPortAfterTerminationFails() throws Exception { public void getPortAfterTerminationFails() throws Exception {
transportServer = new SimpleServer(); transportServer = new SimpleServer();
ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer,
SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS); SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, censusCtxFactory,
GrpcUtil.STOPWATCH_SUPPLIER);
server.start(); server.start();
server.shutdown(); server.shutdown();
server.awaitTermination(); server.awaitTermination();
@ -751,16 +869,23 @@ public class ServerImplTest {
.build(); .build();
transportServer = new SimpleServer(); transportServer = new SimpleServer();
ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer, ServerImpl server = new ServerImpl(executor, registry, fallbackRegistry, transportServer,
SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS); SERVER_CONTEXT, decompressorRegistry, compressorRegistry, NO_FILTERS, censusCtxFactory,
GrpcUtil.STOPWATCH_SUPPLIER);
server.start(); server.start();
ServerTransportListener transportListener ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport()); = transportServer.registerNewServerTransport(new SimpleServerTransport());
Metadata requestHeaders = new Metadata();
StatsTraceContext statsTraceCtx =
transportListener.methodDetermined("Waiter/serve", requestHeaders);
assertNotNull(statsTraceCtx);
when(stream.statsTraceContext()).thenReturn(statsTraceCtx);
// This call will be handled by callHandler from the internal registry // This call will be handled by callHandler from the internal registry
transportListener.streamCreated(stream, "Service1/Method1", new Metadata()); transportListener.streamCreated(stream, "Service1/Method1", requestHeaders);
// This call will be handled by the fallbackRegistry because it's not registred in the internal // This call will be handled by the fallbackRegistry because it's not registred in the internal
// registry. // registry.
transportListener.streamCreated(stream, "Service1/Method2", new Metadata()); transportListener.streamCreated(stream, "Service1/Method2", requestHeaders);
verify(callHandler, timeout(2000)).startCall(Matchers.<ServerCall<String, Integer>>anyObject(), verify(callHandler, timeout(2000)).startCall(Matchers.<ServerCall<String, Integer>>anyObject(),
Matchers.<Metadata>anyObject()); Matchers.<Metadata>anyObject());

View File

@ -86,7 +86,8 @@ final class TestUtils {
public ConnectionClientTransport answer(InvocationOnMock invocation) throws Throwable { public ConnectionClientTransport answer(InvocationOnMock invocation) throws Throwable {
final ConnectionClientTransport mockTransport = mock(ConnectionClientTransport.class); final ConnectionClientTransport mockTransport = mock(ConnectionClientTransport.class);
when(mockTransport.newStream(any(MethodDescriptor.class), any(Metadata.class), when(mockTransport.newStream(any(MethodDescriptor.class), any(Metadata.class),
any(CallOptions.class))).thenReturn(mock(ClientStream.class)); any(CallOptions.class), any(StatsTraceContext.class)))
.thenReturn(mock(ClientStream.class));
// Save the listener // Save the listener
doAnswer(new Answer<Void>() { doAnswer(new Answer<Void>() {
@Override @Override

View File

@ -101,6 +101,7 @@ public class TransportSetTest {
private final Metadata headers = new Metadata(); private final Metadata headers = new Metadata();
private final CallOptions waitForReadyCallOptions = CallOptions.DEFAULT.withWaitForReady(); private final CallOptions waitForReadyCallOptions = CallOptions.DEFAULT.withWaitForReady();
private final CallOptions failFastCallOptions = CallOptions.DEFAULT; private final CallOptions failFastCallOptions = CallOptions.DEFAULT;
private final StatsTraceContext statsTraceCtx = StatsTraceContext.NOOP;
private TransportSet transportSet; private TransportSet transportSet;
private EquivalentAddressGroup addressGroup; private EquivalentAddressGroup addressGroup;
@ -137,7 +138,8 @@ public class TransportSetTest {
int onAllAddressesFailed = 0; int onAllAddressesFailed = 0;
// First attempt // First attempt
transportSet.obtainActiveTransport().newStream(method, new Metadata(), waitForReadyCallOptions); transportSet.obtainActiveTransport().newStream(method, new Metadata(), waitForReadyCallOptions,
statsTraceCtx);
assertEquals(ConnectivityState.CONNECTING, transportSet.getState(false)); assertEquals(ConnectivityState.CONNECTING, transportSet.getState(false));
verify(mockTransportFactory, times(++transportsCreated)) verify(mockTransportFactory, times(++transportsCreated))
.newClientTransport(addr, AUTHORITY, USER_AGENT); .newClientTransport(addr, AUTHORITY, USER_AGENT);
@ -225,7 +227,7 @@ public class TransportSetTest {
assertEquals(ConnectivityState.CONNECTING, transportSet.getState(false)); assertEquals(ConnectivityState.CONNECTING, transportSet.getState(false));
verify(mockTransportFactory, times(++transportsAddr1)) verify(mockTransportFactory, times(++transportsAddr1))
.newClientTransport(addr1, AUTHORITY, USER_AGENT); .newClientTransport(addr1, AUTHORITY, USER_AGENT);
delayedTransport1.newStream(method, new Metadata(), waitForReadyCallOptions); delayedTransport1.newStream(method, new Metadata(), waitForReadyCallOptions, statsTraceCtx);
// Let this one fail without success // Let this one fail without success
transports.poll().listener.transportShutdown(Status.UNAVAILABLE); transports.poll().listener.transportShutdown(Status.UNAVAILABLE);
assertEquals(ConnectivityState.CONNECTING, transportSet.getState(false)); assertEquals(ConnectivityState.CONNECTING, transportSet.getState(false));
@ -320,7 +322,7 @@ public class TransportSetTest {
(DelayedClientTransport) transportSet.obtainActiveTransport(); (DelayedClientTransport) transportSet.obtainActiveTransport();
assertEquals(ConnectivityState.CONNECTING, transportSet.getState(false)); assertEquals(ConnectivityState.CONNECTING, transportSet.getState(false));
assertNotSame(delayedTransport5, delayedTransport6); assertNotSame(delayedTransport5, delayedTransport6);
delayedTransport6.newStream(method, headers, waitForReadyCallOptions); delayedTransport6.newStream(method, headers, waitForReadyCallOptions, statsTraceCtx);
verify(mockBackoffPolicyProvider, times(backoffReset)).get(); verify(mockBackoffPolicyProvider, times(backoffReset)).get();
verify(mockTransportFactory, times(++transportsAddr1)) verify(mockTransportFactory, times(++transportsAddr1))
.newClientTransport(addr1, AUTHORITY, USER_AGENT); .newClientTransport(addr1, AUTHORITY, USER_AGENT);
@ -387,13 +389,14 @@ public class TransportSetTest {
assertFalse(delayedTransport.isInBackoffPeriod()); assertFalse(delayedTransport.isInBackoffPeriod());
// Create a new fail fast stream. // Create a new fail fast stream.
ClientStream ffStream = delayedTransport.newStream(method, headers, failFastCallOptions); ClientStream ffStream = delayedTransport.newStream(method, headers, failFastCallOptions,
statsTraceCtx);
ffStream.start(mockStreamListener); ffStream.start(mockStreamListener);
// Verify it is queued. // Verify it is queued.
assertEquals(++pendingStreamsCount, delayedTransport.getPendingStreamsCount()); assertEquals(++pendingStreamsCount, delayedTransport.getPendingStreamsCount());
failFastPendingStreamsCount++; failFastPendingStreamsCount++;
// Create a new non fail fast stream. // Create a new non fail fast stream.
delayedTransport.newStream(method, headers, waitForReadyCallOptions); delayedTransport.newStream(method, headers, waitForReadyCallOptions, statsTraceCtx);
// Verify it is queued. // Verify it is queued.
assertEquals(++pendingStreamsCount, delayedTransport.getPendingStreamsCount()); assertEquals(++pendingStreamsCount, delayedTransport.getPendingStreamsCount());
@ -405,12 +408,12 @@ public class TransportSetTest {
assertEquals(pendingStreamsCount, delayedTransport.getPendingStreamsCount()); assertEquals(pendingStreamsCount, delayedTransport.getPendingStreamsCount());
// Create a new fail fast stream. // Create a new fail fast stream.
delayedTransport.newStream(method, headers, failFastCallOptions); delayedTransport.newStream(method, headers, failFastCallOptions, statsTraceCtx);
// Verify it is queued. // Verify it is queued.
assertEquals(++pendingStreamsCount, delayedTransport.getPendingStreamsCount()); assertEquals(++pendingStreamsCount, delayedTransport.getPendingStreamsCount());
failFastPendingStreamsCount++; failFastPendingStreamsCount++;
// Create a new non fail fast stream // Create a new non fail fast stream
delayedTransport.newStream(method, headers, waitForReadyCallOptions); delayedTransport.newStream(method, headers, waitForReadyCallOptions, statsTraceCtx);
// Verify it is queued. // Verify it is queued.
assertEquals(++pendingStreamsCount, delayedTransport.getPendingStreamsCount()); assertEquals(++pendingStreamsCount, delayedTransport.getPendingStreamsCount());
@ -428,11 +431,11 @@ public class TransportSetTest {
verify(mockStreamListener).closed(same(failureStatus), any(Metadata.class)); verify(mockStreamListener).closed(same(failureStatus), any(Metadata.class));
// Create a new fail fast stream. // Create a new fail fast stream.
delayedTransport.newStream(method, headers, failFastCallOptions); delayedTransport.newStream(method, headers, failFastCallOptions, statsTraceCtx);
// Verify it is not queued. // Verify it is not queued.
assertEquals(pendingStreamsCount, delayedTransport.getPendingStreamsCount()); assertEquals(pendingStreamsCount, delayedTransport.getPendingStreamsCount());
// Create a new non fail fast stream // Create a new non fail fast stream
delayedTransport.newStream(method, headers, waitForReadyCallOptions); delayedTransport.newStream(method, headers, waitForReadyCallOptions, statsTraceCtx);
// Verify it is queued. // Verify it is queued.
assertEquals(++pendingStreamsCount, delayedTransport.getPendingStreamsCount()); assertEquals(++pendingStreamsCount, delayedTransport.getPendingStreamsCount());
@ -442,7 +445,7 @@ public class TransportSetTest {
assertFalse(delayedTransport.isInBackoffPeriod()); assertFalse(delayedTransport.isInBackoffPeriod());
// Create a new fail fast stream. // Create a new fail fast stream.
delayedTransport.newStream(method, headers, failFastCallOptions); delayedTransport.newStream(method, headers, failFastCallOptions, statsTraceCtx);
// Verify it is queued. // Verify it is queued.
assertEquals(++pendingStreamsCount, delayedTransport.getPendingStreamsCount()); assertEquals(++pendingStreamsCount, delayedTransport.getPendingStreamsCount());
failFastPendingStreamsCount++; failFastPendingStreamsCount++;
@ -487,7 +490,8 @@ public class TransportSetTest {
assertEquals(ConnectivityState.IDLE, transportSet.getState(false)); assertEquals(ConnectivityState.IDLE, transportSet.getState(false));
// Request immediately // Request immediately
transportSet.obtainActiveTransport().newStream(method, new Metadata(), waitForReadyCallOptions); transportSet.obtainActiveTransport().newStream(method, new Metadata(), waitForReadyCallOptions,
statsTraceCtx);
assertEquals(ConnectivityState.CONNECTING, transportSet.getState(false)); assertEquals(ConnectivityState.CONNECTING, transportSet.getState(false));
verify(mockTransportFactory, times(++transportsCreated)) verify(mockTransportFactory, times(++transportsCreated))
.newClientTransport(addr, AUTHORITY, USER_AGENT); .newClientTransport(addr, AUTHORITY, USER_AGENT);
@ -514,7 +518,8 @@ public class TransportSetTest {
pick = transportSet.obtainActiveTransport(); pick = transportSet.obtainActiveTransport();
assertTrue(pick instanceof DelayedClientTransport); assertTrue(pick instanceof DelayedClientTransport);
// Start a stream, which will be pending in the delayed transport // Start a stream, which will be pending in the delayed transport
ClientStream pendingStream = pick.newStream(method, headers, waitForReadyCallOptions); ClientStream pendingStream = pick.newStream(method, headers, waitForReadyCallOptions,
statsTraceCtx);
pendingStream.start(mockStreamListener); pendingStream.start(mockStreamListener);
// Shut down TransportSet before the transport is created. Further call to // Shut down TransportSet before the transport is created. Further call to
@ -542,7 +547,7 @@ public class TransportSetTest {
any(MethodDescriptor.class), any(Metadata.class)); any(MethodDescriptor.class), any(Metadata.class));
assertEquals(1, fakeExecutor.runDueTasks()); assertEquals(1, fakeExecutor.runDueTasks());
verify(transportInfo.transport).newStream(same(method), same(headers), verify(transportInfo.transport).newStream(same(method), same(headers),
same(waitForReadyCallOptions)); same(waitForReadyCallOptions), any(StatsTraceContext.class));
verify(transportInfo.transport).shutdown(); verify(transportInfo.transport).shutdown();
transportInfo.listener.transportShutdown(Status.UNAVAILABLE); transportInfo.listener.transportShutdown(Status.UNAVAILABLE);
assertEquals(ConnectivityState.SHUTDOWN, transportSet.getState(false)); assertEquals(ConnectivityState.SHUTDOWN, transportSet.getState(false));
@ -640,7 +645,8 @@ public class TransportSetTest {
assertEquals(ConnectivityState.CONNECTING, transportSet.getState(true)); assertEquals(ConnectivityState.CONNECTING, transportSet.getState(true));
assertEquals(ConnectivityState.CONNECTING, transportSet.getState(true)); assertEquals(ConnectivityState.CONNECTING, transportSet.getState(true));
transportSet.obtainActiveTransport().newStream(method, new Metadata(), waitForReadyCallOptions); transportSet.obtainActiveTransport().newStream(method, new Metadata(), waitForReadyCallOptions,
statsTraceCtx);
assertEquals(ConnectivityState.CONNECTING, transportSet.getState(true)); assertEquals(ConnectivityState.CONNECTING, transportSet.getState(true));
// Fail it // Fail it
@ -698,7 +704,8 @@ public class TransportSetTest {
int notInUse = 0; int notInUse = 0;
verify(mockTransportSetCallback, never()).onInUse(any(TransportSet.class)); verify(mockTransportSetCallback, never()).onInUse(any(TransportSet.class));
transportSet.obtainActiveTransport().newStream(method, new Metadata(), waitForReadyCallOptions); transportSet.obtainActiveTransport().newStream(method, new Metadata(), waitForReadyCallOptions,
statsTraceCtx);
verify(mockTransportSetCallback, times(++inUse)).onInUse(transportSet); verify(mockTransportSetCallback, times(++inUse)).onInUse(transportSet);
MockClientTransportInfo t0 = transports.poll(); MockClientTransportInfo t0 = transports.poll();
@ -711,7 +718,8 @@ public class TransportSetTest {
// Delayed transport calls newStream() on the real transport in the executor // Delayed transport calls newStream() on the real transport in the executor
fakeExecutor.runDueTasks(); fakeExecutor.runDueTasks();
verify(t0.transport).newStream( verify(t0.transport).newStream(
same(method), any(Metadata.class), same(waitForReadyCallOptions)); same(method), any(Metadata.class), same(waitForReadyCallOptions),
any(StatsTraceContext.class));
verify(mockTransportSetCallback, times(inUse)).onInUse(transportSet); verify(mockTransportSetCallback, times(inUse)).onInUse(transportSet);
t0.listener.transportInUse(true); t0.listener.transportInUse(true);
verify(mockTransportSetCallback, times(++inUse)).onInUse(transportSet); verify(mockTransportSetCallback, times(++inUse)).onInUse(transportSet);
@ -726,13 +734,15 @@ public class TransportSetTest {
t0.listener.transportShutdown(Status.UNAVAILABLE); t0.listener.transportShutdown(Status.UNAVAILABLE);
// Creates a new transport // Creates a new transport
transportSet.obtainActiveTransport().newStream(method, new Metadata(), waitForReadyCallOptions); transportSet.obtainActiveTransport().newStream(method, new Metadata(), waitForReadyCallOptions,
statsTraceCtx);
MockClientTransportInfo t1 = transports.poll(); MockClientTransportInfo t1 = transports.poll();
t1.listener.transportReady(); t1.listener.transportReady();
// Delayed transport calls newStream() on the real transport in the executor // Delayed transport calls newStream() on the real transport in the executor
fakeExecutor.runDueTasks(); fakeExecutor.runDueTasks();
verify(t1.transport).newStream( verify(t1.transport).newStream(
same(method), any(Metadata.class), same(waitForReadyCallOptions)); same(method), any(Metadata.class), same(waitForReadyCallOptions),
any(StatsTraceContext.class));
t1.listener.transportInUse(true); t1.listener.transportInUse(true);
// No turbulance from the race mentioned eariler, because t0 has been in-use // No turbulance from the race mentioned eariler, because t0 has been in-use
verify(mockTransportSetCallback, times(inUse)).onInUse(transportSet); verify(mockTransportSetCallback, times(inUse)).onInUse(transportSet);
@ -769,7 +779,8 @@ public class TransportSetTest {
// Attempt and fail, scheduleBackoff should be triggered, // Attempt and fail, scheduleBackoff should be triggered,
// and transportSet.shutdown should be triggered by setup // and transportSet.shutdown should be triggered by setup
transportSet.obtainActiveTransport().newStream(method, new Metadata(), waitForReadyCallOptions); transportSet.obtainActiveTransport().newStream(method, new Metadata(), waitForReadyCallOptions,
statsTraceCtx);
transports.poll().listener.transportShutdown(Status.UNAVAILABLE); transports.poll().listener.transportShutdown(Status.UNAVAILABLE);
verify(mockTransportSetCallback, times(1)).onAllAddressesFailed(); verify(mockTransportSetCallback, times(1)).onAllAddressesFailed();
assertTrue(startBackoffAndShutdownAreCalled[0]); assertTrue(startBackoffAndShutdownAreCalled[0]);

View File

@ -46,12 +46,16 @@ import com.google.auth.oauth2.ComputeEngineCredentials;
import com.google.auth.oauth2.GoogleCredentials; import com.google.auth.oauth2.GoogleCredentials;
import com.google.auth.oauth2.OAuth2Credentials; import com.google.auth.oauth2.OAuth2Credentials;
import com.google.auth.oauth2.ServiceAccountCredentials; import com.google.auth.oauth2.ServiceAccountCredentials;
import com.google.census.CensusContextFactory;
import com.google.census.RpcConstants;
import com.google.census.TagValue;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.net.HostAndPort; import com.google.common.net.HostAndPort;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import com.google.protobuf.EmptyProtos.Empty; import com.google.protobuf.EmptyProtos.Empty;
import com.google.protobuf.MessageLite;
import io.grpc.CallOptions; import io.grpc.CallOptions;
import io.grpc.ClientCall; import io.grpc.ClientCall;
@ -59,14 +63,16 @@ import io.grpc.Grpc;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.Server; import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.ServerCall; import io.grpc.ServerCall;
import io.grpc.ServerInterceptor; import io.grpc.ServerInterceptor;
import io.grpc.ServerInterceptors; import io.grpc.ServerInterceptors;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.StatusRuntimeException; import io.grpc.StatusRuntimeException;
import io.grpc.auth.MoreCallCredentials; import io.grpc.auth.MoreCallCredentials;
import io.grpc.internal.AbstractServerImplBuilder;
import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil;
import io.grpc.internal.testing.CensusTestUtils.FakeCensusContextFactory;
import io.grpc.internal.testing.CensusTestUtils.MetricsRecord;
import io.grpc.protobuf.ProtoUtils; import io.grpc.protobuf.ProtoUtils;
import io.grpc.stub.MetadataUtils; import io.grpc.stub.MetadataUtils;
import io.grpc.stub.StreamObserver; import io.grpc.stub.StreamObserver;
@ -95,7 +101,10 @@ import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.security.cert.Certificate; import java.security.cert.Certificate;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
@ -120,9 +129,14 @@ public abstract class AbstractInteropTest {
new AtomicReference<Metadata>(); new AtomicReference<Metadata>();
private static ScheduledExecutorService testServiceExecutor; private static ScheduledExecutorService testServiceExecutor;
private static Server server; private static Server server;
private static final FakeCensusContextFactory clientCensusFactory =
new FakeCensusContextFactory();
private static final FakeCensusContextFactory serverCensusFactory =
new FakeCensusContextFactory();
protected static final Empty EMPTY = Empty.getDefaultInstance();
protected static void startStaticServer( protected static void startStaticServer(
ServerBuilder<?> builder, ServerInterceptor ... interceptors) { AbstractServerImplBuilder<?> builder, ServerInterceptor ... interceptors) {
testServiceExecutor = Executors.newScheduledThreadPool(2); testServiceExecutor = Executors.newScheduledThreadPool(2);
List<ServerInterceptor> allInterceptors = ImmutableList.<ServerInterceptor>builder() List<ServerInterceptor> allInterceptors = ImmutableList.<ServerInterceptor>builder()
@ -135,6 +149,7 @@ public abstract class AbstractInteropTest {
builder.addService(ServerInterceptors.intercept( builder.addService(ServerInterceptors.intercept(
new TestServiceImpl(testServiceExecutor), new TestServiceImpl(testServiceExecutor),
allInterceptors)); allInterceptors));
builder.censusContextFactory(serverCensusFactory);
try { try {
server = builder.build().start(); server = builder.build().start();
} catch (IOException ex) { } catch (IOException ex) {
@ -165,6 +180,8 @@ public abstract class AbstractInteropTest {
blockingStub = TestServiceGrpc.newBlockingStub(channel); blockingStub = TestServiceGrpc.newBlockingStub(channel);
asyncStub = TestServiceGrpc.newStub(channel); asyncStub = TestServiceGrpc.newStub(channel);
requestHeadersCapture.set(null); requestHeadersCapture.set(null);
clientCensusFactory.rolloverRecords();
serverCensusFactory.rolloverRecords();
} }
/** Clean up. */ /** Clean up. */
@ -177,9 +194,17 @@ public abstract class AbstractInteropTest {
protected abstract ManagedChannel createChannel(); protected abstract ManagedChannel createChannel();
protected final CensusContextFactory getClientCensusFactory() {
return clientCensusFactory;
}
protected boolean metricsExpected() {
return true;
}
@Test(timeout = 10000) @Test(timeout = 10000)
public void emptyUnary() throws Exception { public void emptyUnary() throws Exception {
assertEquals(Empty.getDefaultInstance(), blockingStub.emptyCall(Empty.getDefaultInstance())); assertEquals(EMPTY, blockingStub.emptyCall(EMPTY));
} }
@Test(timeout = 10000) @Test(timeout = 10000)
@ -198,6 +223,11 @@ public abstract class AbstractInteropTest {
.build(); .build();
assertEquals(goldenResponse, blockingStub.unaryCall(request)); assertEquals(goldenResponse, blockingStub.unaryCall(request));
if (metricsExpected()) {
assertMetrics("grpc.testing.TestService/UnaryCall", Status.Code.OK,
Collections.singleton(request), Collections.singleton(goldenResponse));
}
} }
@Test(timeout = 10000) @Test(timeout = 10000)
@ -273,6 +303,7 @@ public abstract class AbstractInteropTest {
} }
requestObserver.onCompleted(); requestObserver.onCompleted();
assertEquals(goldenResponse, responseObserver.firstValue().get()); assertEquals(goldenResponse, responseObserver.firstValue().get());
responseObserver.awaitCompletion();
} }
@Test(timeout = 10000) @Test(timeout = 10000)
@ -359,6 +390,11 @@ public abstract class AbstractInteropTest {
assertEquals(Arrays.<StreamingInputCallResponse>asList(), responseObserver.getValues()); assertEquals(Arrays.<StreamingInputCallResponse>asList(), responseObserver.getValues());
assertEquals(Status.Code.CANCELLED, assertEquals(Status.Code.CANCELLED,
Status.fromThrowable(responseObserver.getError()).getCode()); Status.fromThrowable(responseObserver.getError()).getCode());
if (metricsExpected()) {
assertClientMetrics("grpc.testing.TestService/StreamingInputCall", Status.Code.CANCELLED);
// Do not check server-side metrics, because the status on the server side is undetermined.
}
} }
@Test(timeout = 10000) @Test(timeout = 10000)
@ -388,6 +424,10 @@ public abstract class AbstractInteropTest {
verify(responseObserver, timeout(operationTimeoutMillis())).onError(captor.capture()); verify(responseObserver, timeout(operationTimeoutMillis())).onError(captor.capture());
assertEquals(Status.Code.CANCELLED, Status.fromThrowable(captor.getValue()).getCode()); assertEquals(Status.Code.CANCELLED, Status.fromThrowable(captor.getValue()).getCode());
verifyNoMoreInteractions(responseObserver); verifyNoMoreInteractions(responseObserver);
if (metricsExpected()) {
assertMetrics("grpc.testing.TestService/FullDuplexCall", Status.Code.CANCELLED);
}
} }
@Test(timeout = 10000) @Test(timeout = 10000)
@ -407,7 +447,10 @@ public abstract class AbstractInteropTest {
asyncStub.fullDuplexCall(recorder); asyncStub.fullDuplexCall(recorder);
final int numRequests = 10; final int numRequests = 10;
List<StreamingOutputCallRequest> requests =
new ArrayList<StreamingOutputCallRequest>(numRequests);
for (int ix = numRequests; ix > 0; --ix) { for (int ix = numRequests; ix > 0; --ix) {
requests.add(request);
requestStream.onNext(request); requestStream.onNext(request);
} }
requestStream.onCompleted(); requestStream.onCompleted();
@ -421,6 +464,11 @@ public abstract class AbstractInteropTest {
int expectedSize = responseSizes.get(ix % responseSizes.size()); int expectedSize = responseSizes.get(ix % responseSizes.size());
assertEquals("comparison failed at index " + ix, expectedSize, length); assertEquals("comparison failed at index " + ix, expectedSize, length);
} }
if (metricsExpected()) {
assertMetrics("grpc.testing.TestService/FullDuplexCall", Status.Code.OK, requests,
recorder.getValues());
}
} }
@Test(timeout = 10000) @Test(timeout = 10000)
@ -439,7 +487,10 @@ public abstract class AbstractInteropTest {
StreamObserver<StreamingOutputCallRequest> requestStream = asyncStub.halfDuplexCall(recorder); StreamObserver<StreamingOutputCallRequest> requestStream = asyncStub.halfDuplexCall(recorder);
final int numRequests = 10; final int numRequests = 10;
List<StreamingOutputCallRequest> requests =
new ArrayList<StreamingOutputCallRequest>(numRequests);
for (int ix = numRequests; ix > 0; --ix) { for (int ix = numRequests; ix > 0; --ix) {
requests.add(request);
requestStream.onNext(request); requestStream.onNext(request);
} }
requestStream.onCompleted(); requestStream.onCompleted();
@ -566,7 +617,7 @@ public abstract class AbstractInteropTest {
AtomicReference<Metadata> headersCapture = new AtomicReference<Metadata>(); AtomicReference<Metadata> headersCapture = new AtomicReference<Metadata>();
stub = MetadataUtils.captureMetadata(stub, headersCapture, trailersCapture); stub = MetadataUtils.captureMetadata(stub, headersCapture, trailersCapture);
assertNotNull(stub.emptyCall(Empty.getDefaultInstance())); assertNotNull(stub.emptyCall(EMPTY));
// Assert that our side channel object is echoed back in both headers and trailers // Assert that our side channel object is echoed back in both headers and trailers
Assert.assertEquals(contextValue, headersCapture.get().get(METADATA_KEY)); Assert.assertEquals(contextValue, headersCapture.get().get(METADATA_KEY));
@ -603,7 +654,11 @@ public abstract class AbstractInteropTest {
stub.fullDuplexCall(recorder); stub.fullDuplexCall(recorder);
final int numRequests = 10; final int numRequests = 10;
List<StreamingOutputCallRequest> requests =
new ArrayList<StreamingOutputCallRequest>(numRequests);
for (int ix = numRequests; ix > 0; --ix) { for (int ix = numRequests; ix > 0; --ix) {
requests.add(request);
requestStream.onNext(request); requestStream.onNext(request);
} }
requestStream.onCompleted(); requestStream.onCompleted();
@ -621,7 +676,7 @@ public abstract class AbstractInteropTest {
long configuredTimeoutMinutes = 100; long configuredTimeoutMinutes = 100;
TestServiceGrpc.TestServiceBlockingStub stub = TestServiceGrpc.newBlockingStub(channel) TestServiceGrpc.TestServiceBlockingStub stub = TestServiceGrpc.newBlockingStub(channel)
.withDeadlineAfter(configuredTimeoutMinutes, TimeUnit.MINUTES); .withDeadlineAfter(configuredTimeoutMinutes, TimeUnit.MINUTES);
stub.emptyCall(Empty.getDefaultInstance()); stub.emptyCall(EMPTY);
long transferredTimeoutMinutes = TimeUnit.NANOSECONDS.toMinutes( long transferredTimeoutMinutes = TimeUnit.NANOSECONDS.toMinutes(
requestHeadersCapture.get().get(GrpcUtil.TIMEOUT_KEY)); requestHeadersCapture.get().get(GrpcUtil.TIMEOUT_KEY));
Assert.assertTrue( Assert.assertTrue(
@ -649,15 +704,22 @@ public abstract class AbstractInteropTest {
blockingStub.emptyCall(Empty.getDefaultInstance()); blockingStub.emptyCall(Empty.getDefaultInstance());
TestServiceGrpc.TestServiceBlockingStub stub = TestServiceGrpc.newBlockingStub(channel) TestServiceGrpc.TestServiceBlockingStub stub = TestServiceGrpc.newBlockingStub(channel)
.withDeadlineAfter(10, TimeUnit.MILLISECONDS); .withDeadlineAfter(10, TimeUnit.MILLISECONDS);
StreamingOutputCallRequest request = StreamingOutputCallRequest.newBuilder()
.addResponseParameters(ResponseParameters.newBuilder()
.setIntervalUs(20000))
.build();
try { try {
stub.streamingOutputCall(StreamingOutputCallRequest.newBuilder() stub.streamingOutputCall(request).next();
.addResponseParameters(ResponseParameters.newBuilder()
.setIntervalUs(20000))
.build()).next();
fail("Expected deadline to be exceeded"); fail("Expected deadline to be exceeded");
} catch (StatusRuntimeException ex) { } catch (StatusRuntimeException ex) {
assertEquals(Status.DEADLINE_EXCEEDED.getCode(), ex.getStatus().getCode()); assertEquals(Status.DEADLINE_EXCEEDED.getCode(), ex.getStatus().getCode());
} }
if (metricsExpected()) {
assertMetrics("grpc.testing.TestService/EmptyCall", Status.Code.OK);
assertClientMetrics("grpc.testing.TestService/StreamingOutputCall",
Status.Code.DEADLINE_EXCEEDED);
// Do not check server-side metrics, because the status on the server side is undetermined.
}
} }
@Test(timeout = 10000) @Test(timeout = 10000)
@ -681,6 +743,12 @@ public abstract class AbstractInteropTest {
recorder.awaitCompletion(); recorder.awaitCompletion();
assertEquals(Status.DEADLINE_EXCEEDED.getCode(), assertEquals(Status.DEADLINE_EXCEEDED.getCode(),
Status.fromThrowable(recorder.getError()).getCode()); Status.fromThrowable(recorder.getError()).getCode());
if (metricsExpected()) {
assertMetrics("grpc.testing.TestService/EmptyCall", Status.Code.OK);
assertClientMetrics("grpc.testing.TestService/StreamingOutputCall",
Status.Code.DEADLINE_EXCEEDED);
// Do not check server-side metrics, because the status on the server side is undetermined.
}
} }
@Test(timeout = 10000) @Test(timeout = 10000)
@ -690,9 +758,13 @@ public abstract class AbstractInteropTest {
TestServiceGrpc.newBlockingStub(channel) TestServiceGrpc.newBlockingStub(channel)
.withDeadlineAfter(-10, TimeUnit.SECONDS) .withDeadlineAfter(-10, TimeUnit.SECONDS)
.emptyCall(Empty.getDefaultInstance()); .emptyCall(Empty.getDefaultInstance());
fail("Should have thrown");
} catch (StatusRuntimeException ex) { } catch (StatusRuntimeException ex) {
assertEquals(Status.Code.DEADLINE_EXCEEDED, ex.getStatus().getCode()); assertEquals(Status.Code.DEADLINE_EXCEEDED, ex.getStatus().getCode());
} }
if (metricsExpected()) {
assertClientMetrics("grpc.testing.TestService/EmptyCall", Status.Code.DEADLINE_EXCEEDED);
}
// warm up the channel // warm up the channel
blockingStub.emptyCall(Empty.getDefaultInstance()); blockingStub.emptyCall(Empty.getDefaultInstance());
@ -700,9 +772,14 @@ public abstract class AbstractInteropTest {
TestServiceGrpc.newBlockingStub(channel) TestServiceGrpc.newBlockingStub(channel)
.withDeadlineAfter(-10, TimeUnit.SECONDS) .withDeadlineAfter(-10, TimeUnit.SECONDS)
.emptyCall(Empty.getDefaultInstance()); .emptyCall(Empty.getDefaultInstance());
fail("Should have thrown");
} catch (StatusRuntimeException ex) { } catch (StatusRuntimeException ex) {
assertEquals(Status.Code.DEADLINE_EXCEEDED, ex.getStatus().getCode()); assertEquals(Status.Code.DEADLINE_EXCEEDED, ex.getStatus().getCode());
} }
if (metricsExpected()) {
assertMetrics("grpc.testing.TestService/EmptyCall", Status.Code.OK);
assertClientMetrics("grpc.testing.TestService/EmptyCall", Status.Code.DEADLINE_EXCEEDED);
}
} }
protected int unaryPayloadLength() { protected int unaryPayloadLength() {
@ -777,6 +854,11 @@ public abstract class AbstractInteropTest {
} catch (StatusRuntimeException e) { } catch (StatusRuntimeException e) {
assertEquals(Status.UNIMPLEMENTED.getCode(), e.getStatus().getCode()); assertEquals(Status.UNIMPLEMENTED.getCode(), e.getStatus().getCode());
} }
if (metricsExpected()) {
assertMetrics("grpc.testing.UnimplementedService/UnimplementedCall",
Status.Code.UNIMPLEMENTED);
}
} }
/** Start a fullDuplexCall which the server will not respond, and verify the deadline expires. */ /** Start a fullDuplexCall which the server will not respond, and verify the deadline expires. */
@ -789,11 +871,12 @@ public abstract class AbstractInteropTest {
StreamObserver<StreamingOutputCallRequest> requestObserver StreamObserver<StreamingOutputCallRequest> requestObserver
= stub.fullDuplexCall(responseObserver); = stub.fullDuplexCall(responseObserver);
StreamingOutputCallRequest request = StreamingOutputCallRequest.newBuilder()
.setPayload(Payload.newBuilder()
.setBody(ByteString.copyFrom(new byte[27182])))
.build();
try { try {
requestObserver.onNext(StreamingOutputCallRequest.newBuilder() requestObserver.onNext(request);
.setPayload(Payload.newBuilder()
.setBody(ByteString.copyFrom(new byte[27182])))
.build());
} catch (IllegalStateException expected) { } catch (IllegalStateException expected) {
// This can happen if the stream has already been terminated due to deadline exceeded. // This can happen if the stream has already been terminated due to deadline exceeded.
} }
@ -803,6 +886,11 @@ public abstract class AbstractInteropTest {
assertEquals(Status.DEADLINE_EXCEEDED.getCode(), assertEquals(Status.DEADLINE_EXCEEDED.getCode(),
Status.fromThrowable(captor.getValue()).getCode()); Status.fromThrowable(captor.getValue()).getCode());
verifyNoMoreInteractions(responseObserver); verifyNoMoreInteractions(responseObserver);
if (metricsExpected()) {
assertClientMetrics("grpc.testing.TestService/FullDuplexCall", Status.Code.DEADLINE_EXCEEDED);
// Do not check server-side metrics, because the status on the server side is undetermined.
}
} }
/** Sends a large unary rpc with service account credentials. */ /** Sends a large unary rpc with service account credentials. */
@ -1020,4 +1108,115 @@ public abstract class AbstractInteropTest {
throw e; throw e;
} }
} }
/**
* Poll the next metrics record and check it against the provided information, including the
* message sizes.
*/
private void assertMetrics(String method, Status.Code status,
Collection<? extends MessageLite> requests,
Collection<? extends MessageLite> responses) {
assertClientMetrics(method, status, requests, responses);
assertServerMetrics(method, status, requests, responses);
}
/**
* Poll the next metrics record and check it against the provided information, without checking
* the message sizes.
*/
private void assertMetrics(String method, Status.Code status) {
assertMetrics(method, status, null, null);
}
private void assertClientMetrics(String method, Status.Code status,
Collection<? extends MessageLite> requests, Collection<? extends MessageLite> responses) {
MetricsRecord clientRecord = clientCensusFactory.pollRecord();
assertNotNull("clientRecord is not null", clientRecord);
checkTags(clientRecord, false, method, status);
if (requests != null && responses != null) {
checkMetrics(clientRecord, false, requests, responses);
}
}
private void assertClientMetrics(String method, Status.Code status) {
assertClientMetrics(method, status, null, null);
}
private void assertServerMetrics(String method, Status.Code status,
Collection<? extends MessageLite> requests, Collection<? extends MessageLite> responses) {
AssertionError checkFailure = null;
// Because the server doesn't restart between tests, it may still be processing the requests
// from the previous tests when a new test starts, thus the test may see metrics from previous
// tests. The best we can do here is to exhaust all records and find one that matches the given
// conditions.
while (true) {
MetricsRecord serverRecord;
try {
// On the server, the stats is finalized in ServerStreamListener.closed(), which can be run
// after the client receives the final status. So we use a timeout.
serverRecord = serverCensusFactory.pollRecord(1, TimeUnit.SECONDS);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
if (serverRecord == null) {
break;
}
try {
checkTags(serverRecord, true, method, status);
if (requests != null && responses != null) {
checkMetrics(serverRecord, true, requests, responses);
}
return; // passed
} catch (AssertionError e) {
// May be the fallout from a previous test, continue trying
checkFailure = e;
}
}
if (checkFailure == null) {
throw new AssertionError("No record found");
}
throw checkFailure;
}
private static void checkTags(
MetricsRecord record, boolean server, String methodName, Status.Code status) {
TagValue methodNameTag = record.tags.get(
server ? RpcConstants.RPC_SERVER_METHOD : RpcConstants.RPC_CLIENT_METHOD);
assertNotNull("method name tagged", methodNameTag);
assertEquals("method names match", methodName, methodNameTag.toString());
TagValue statusTag = record.tags.get(RpcConstants.RPC_STATUS);
assertNotNull("status tagged", statusTag);
assertEquals(status.toString(), statusTag.toString());
}
private static void checkMetrics(MetricsRecord record, boolean server,
Collection<? extends MessageLite> requests, Collection<? extends MessageLite> responses) {
int uncompressedRequestsSize = 0;
for (MessageLite request : requests) {
uncompressedRequestsSize += request.getSerializedSize();
}
int uncompressedResponsesSize = 0;
for (MessageLite response : responses) {
uncompressedResponsesSize += response.getSerializedSize();
}
if (server) {
assertEquals(uncompressedRequestsSize,
record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_UNCOMPRESSED_REQUEST_BYTES));
assertEquals(uncompressedResponsesSize,
record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES));
// It's impossible to get the expected wire sizes because it may be compressed, so we just
// check if they are recorded.
assertNotNull(record.getMetric(RpcConstants.RPC_SERVER_REQUEST_BYTES));
assertNotNull(record.getMetric(RpcConstants.RPC_SERVER_RESPONSE_BYTES));
} else {
assertEquals(uncompressedRequestsSize,
record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES));
assertEquals(uncompressedResponsesSize,
record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES));
// It's impossible to get the expected wire sizes because it may be compressed, so we just
// check if they are recorded.
assertNotNull(record.getMetric(RpcConstants.RPC_CLIENT_REQUEST_BYTES));
assertNotNull(record.getMetric(RpcConstants.RPC_CLIENT_RESPONSE_BYTES));
}
}
} }

View File

@ -484,6 +484,12 @@ public class StressTestClient {
// Fixes https://github.com/grpc/grpc-java/issues/1812 // Fixes https://github.com/grpc/grpc-java/issues/1812
return Integer.MAX_VALUE; return Integer.MAX_VALUE;
} }
@Override
protected boolean metricsExpected() {
// TODO(zhangkun83): we may want to enable the real Census implementation in stress tests.
return false;
}
} }
class WeightedTestCaseSelector { class WeightedTestCaseSelector {

View File

@ -314,6 +314,7 @@ public class TestServiceClient {
.flowControlWindow(65 * 1024) .flowControlWindow(65 * 1024)
.negotiationType(useTls ? NegotiationType.TLS : NegotiationType.PLAINTEXT) .negotiationType(useTls ? NegotiationType.TLS : NegotiationType.PLAINTEXT)
.sslContext(sslContext) .sslContext(sslContext)
.censusContextFactory(getClientCensusFactory())
.build(); .build();
} else { } else {
OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress(serverHost, serverPort); OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress(serverHost, serverPort);

View File

@ -63,6 +63,7 @@ public class AutoWindowSizingOnTest extends AbstractInteropTest {
return NettyChannelBuilder.forAddress("localhost", getPort()) return NettyChannelBuilder.forAddress("localhost", getPort())
.negotiationType(NegotiationType.PLAINTEXT) .negotiationType(NegotiationType.PLAINTEXT)
.maxMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE) .maxMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE)
.censusContextFactory(getClientCensusFactory())
.build(); .build();
} }
} }

View File

@ -75,6 +75,7 @@ public class Http2NettyLocalChannelTest extends AbstractInteropTest {
.channelType(LocalChannel.class) .channelType(LocalChannel.class)
.flowControlWindow(65 * 1024) .flowControlWindow(65 * 1024)
.maxMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE) .maxMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE)
.censusContextFactory(getClientCensusFactory())
.build(); .build();
} }
} }

View File

@ -92,6 +92,7 @@ public class Http2NettyTest extends AbstractInteropTest {
.ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE) .ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE)
.sslProvider(SslProvider.OPENSSL) .sslProvider(SslProvider.OPENSSL)
.build()) .build())
.censusContextFactory(getClientCensusFactory())
.build(); .build();
} catch (Exception ex) { } catch (Exception ex) {
throw new RuntimeException(ex); throw new RuntimeException(ex);

View File

@ -31,6 +31,7 @@
package io.grpc.testing.integration; package io.grpc.testing.integration;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
@ -107,6 +108,7 @@ public class Http2OkHttpTest extends AbstractInteropTest {
.cipherSuites(TestUtils.preferredTestCiphers().toArray(new String[0])) .cipherSuites(TestUtils.preferredTestCiphers().toArray(new String[0]))
.tlsVersions(ConnectionSpec.MODERN_TLS.tlsVersions().toArray(new TlsVersion[0])) .tlsVersions(ConnectionSpec.MODERN_TLS.tlsVersions().toArray(new TlsVersion[0]))
.build()) .build())
.censusContextFactory(getClientCensusFactory())
.overrideAuthority(GrpcUtil.authorityFromHostAndPort( .overrideAuthority(GrpcUtil.authorityFromHostAndPort(
TestUtils.TEST_SERVER_HOST, getPort())); TestUtils.TEST_SERVER_HOST, getPort()));
try { try {
@ -133,12 +135,14 @@ public class Http2OkHttpTest extends AbstractInteropTest {
StreamRecorder<Messages.StreamingOutputCallResponse> recorder = StreamRecorder.create(); StreamRecorder<Messages.StreamingOutputCallResponse> recorder = StreamRecorder.create();
StreamObserver<Messages.StreamingOutputCallRequest> requestStream = StreamObserver<Messages.StreamingOutputCallRequest> requestStream =
asyncStub.fullDuplexCall(recorder); asyncStub.fullDuplexCall(recorder);
requestStream.onNext(requestBuilder.build()); Messages.StreamingOutputCallRequest request = requestBuilder.build();
requestStream.onNext(request);
recorder.firstValue().get(); recorder.firstValue().get();
requestStream.onError(new Exception("failed")); requestStream.onError(new Exception("failed"));
recorder.awaitCompletion(); recorder.awaitCompletion();
emptyUnary();
assertEquals(EMPTY, blockingStub.emptyCall(EMPTY));
} }
@Test(timeout = 10000) @Test(timeout = 10000)

View File

@ -58,6 +58,14 @@ public class InProcessTest extends AbstractInteropTest {
@Override @Override
protected ManagedChannel createChannel() { protected ManagedChannel createChannel() {
return InProcessChannelBuilder.forName(serverName).build(); return InProcessChannelBuilder.forName(serverName)
.censusContextFactory(getClientCensusFactory()).build();
}
@Override
protected boolean metricsExpected() {
// TODO(zhangkun83): InProcessTransport by-passes framer and deframer, thus message sizses are
// not counted. (https://github.com/grpc/grpc-java/issues/2284)
return false;
} }
} }

View File

@ -152,6 +152,7 @@ public class TransportCompressionTest extends AbstractInteropTest {
.maxMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE) .maxMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE)
.decompressorRegistry(decompressors) .decompressorRegistry(decompressors)
.compressorRegistry(compressors) .compressorRegistry(compressors)
.censusContextFactory(getClientCensusFactory())
.intercept(new ClientInterceptor() { .intercept(new ClientInterceptor() {
@Override @Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall( public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(

View File

@ -45,6 +45,7 @@ import io.grpc.internal.AbstractClientStream2;
import io.grpc.internal.ClientStreamListener; import io.grpc.internal.ClientStreamListener;
import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil;
import io.grpc.internal.Http2ClientStreamTransportState; import io.grpc.internal.Http2ClientStreamTransportState;
import io.grpc.internal.StatsTraceContext;
import io.grpc.internal.WritableBuffer; import io.grpc.internal.WritableBuffer;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel; import io.netty.channel.Channel;
@ -77,8 +78,8 @@ class NettyClientStream extends AbstractClientStream2 {
NettyClientStream(TransportState state, MethodDescriptor<?, ?> method, Metadata headers, NettyClientStream(TransportState state, MethodDescriptor<?, ?> method, Metadata headers,
Channel channel, AsciiString authority, AsciiString scheme, Channel channel, AsciiString authority, AsciiString scheme,
AsciiString userAgent) { AsciiString userAgent, StatsTraceContext statsTraceCtx) {
super(new NettyWritableBufferAllocator(channel.alloc())); super(new NettyWritableBufferAllocator(channel.alloc()), statsTraceCtx);
this.state = checkNotNull(state, "transportState"); this.state = checkNotNull(state, "transportState");
this.writeQueue = state.handler.getWriteQueue(); this.writeQueue = state.handler.getWriteQueue();
this.method = checkNotNull(method, "method"); this.method = checkNotNull(method, "method");
@ -183,8 +184,9 @@ class NettyClientStream extends AbstractClientStream2 {
private int id; private int id;
private Http2Stream http2Stream; private Http2Stream http2Stream;
public TransportState(NettyClientHandler handler, int maxMessageSize) { public TransportState(NettyClientHandler handler, int maxMessageSize,
super(maxMessageSize); StatsTraceContext statsTraceCtx) {
super(maxMessageSize, statsTraceCtx);
this.handler = checkNotNull(handler, "handler"); this.handler = checkNotNull(handler, "handler");
} }

View File

@ -45,6 +45,7 @@ import io.grpc.internal.ClientStream;
import io.grpc.internal.ConnectionClientTransport; import io.grpc.internal.ConnectionClientTransport;
import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil;
import io.grpc.internal.Http2Ping; import io.grpc.internal.Http2Ping;
import io.grpc.internal.StatsTraceContext;
import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
@ -116,23 +117,25 @@ class NettyClientTransport implements ConnectionClientTransport {
} }
@Override @Override
public ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers, CallOptions public ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers,
callOptions) { CallOptions callOptions, StatsTraceContext statsTraceCtx) {
Preconditions.checkNotNull(method, "method"); Preconditions.checkNotNull(method, "method");
Preconditions.checkNotNull(headers, "headers"); Preconditions.checkNotNull(headers, "headers");
Preconditions.checkNotNull(statsTraceCtx, "statsTraceCtx");
return new NettyClientStream( return new NettyClientStream(
new NettyClientStream.TransportState(handler, maxMessageSize) { new NettyClientStream.TransportState(handler, maxMessageSize, statsTraceCtx) {
@Override @Override
protected Status statusFromFailedFuture(ChannelFuture f) { protected Status statusFromFailedFuture(ChannelFuture f) {
return NettyClientTransport.this.statusFromFailedFuture(f); return NettyClientTransport.this.statusFromFailedFuture(f);
} }
}, },
method, headers, channel, authority, negotiationHandler.scheme(), userAgent); method, headers, channel, authority, negotiationHandler.scheme(), userAgent,
statsTraceCtx);
} }
@Override @Override
public ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers) { public ClientStream newStream(MethodDescriptor<?, ?> method, Metadata headers) {
return newStream(method, headers, CallOptions.DEFAULT); return newStream(method, headers, CallOptions.DEFAULT, StatsTraceContext.NOOP);
} }
@Override @Override

View File

@ -49,6 +49,7 @@ import io.grpc.Status;
import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil;
import io.grpc.internal.ServerStreamListener; import io.grpc.internal.ServerStreamListener;
import io.grpc.internal.ServerTransportListener; import io.grpc.internal.ServerTransportListener;
import io.grpc.internal.StatsTraceContext;
import io.grpc.netty.GrpcHttp2HeadersDecoder.GrpcHttp2ServerHeadersDecoder; import io.grpc.netty.GrpcHttp2HeadersDecoder.GrpcHttp2ServerHeadersDecoder;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
@ -193,14 +194,14 @@ class NettyServerHandler extends AbstractNettyHandler {
// method. // method.
Http2Stream http2Stream = requireHttp2Stream(streamId); Http2Stream http2Stream = requireHttp2Stream(streamId);
NettyServerStream.TransportState state =
new NettyServerStream.TransportState(this, http2Stream, maxMessageSize);
NettyServerStream stream = new NettyServerStream(ctx.channel(), state, attributes);
Metadata metadata = Utils.convertHeaders(headers); Metadata metadata = Utils.convertHeaders(headers);
StatsTraceContext statsTraceCtx =
ServerStreamListener listener = checkNotNull(transportListener.methodDetermined(method, metadata), "statsTraceCtx");
transportListener.streamCreated(stream, method, metadata); NettyServerStream.TransportState state = new NettyServerStream.TransportState(
this, http2Stream, maxMessageSize, statsTraceCtx);
NettyServerStream stream = new NettyServerStream(ctx.channel(), state, attributes,
statsTraceCtx);
ServerStreamListener listener = transportListener.streamCreated(stream, method, metadata);
// TODO(ejona): this could be racy since stream could have been used before getting here. All // TODO(ejona): this could be racy since stream could have been used before getting here. All
// cases appear to be fine, but some are almost only by happenstance and it is difficult to // cases appear to be fine, but some are almost only by happenstance and it is difficult to
// audit. It would be good to improve the API to be less prone to races. // audit. It would be good to improve the API to be less prone to races.

View File

@ -37,6 +37,7 @@ import io.grpc.Attributes;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.internal.AbstractServerStream; import io.grpc.internal.AbstractServerStream;
import io.grpc.internal.StatsTraceContext;
import io.grpc.internal.WritableBuffer; import io.grpc.internal.WritableBuffer;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel; import io.netty.channel.Channel;
@ -61,8 +62,9 @@ class NettyServerStream extends AbstractServerStream {
private final WriteQueue writeQueue; private final WriteQueue writeQueue;
private final Attributes attributes; private final Attributes attributes;
public NettyServerStream(Channel channel, TransportState state, Attributes transportAttrs) { public NettyServerStream(Channel channel, TransportState state, Attributes transportAttrs,
super(new NettyWritableBufferAllocator(channel.alloc())); StatsTraceContext statsTraceCtx) {
super(new NettyWritableBufferAllocator(channel.alloc()), statsTraceCtx);
this.state = checkNotNull(state, "transportState"); this.state = checkNotNull(state, "transportState");
this.channel = checkNotNull(channel, "channel"); this.channel = checkNotNull(channel, "channel");
this.writeQueue = state.handler.getWriteQueue(); this.writeQueue = state.handler.getWriteQueue();
@ -142,8 +144,9 @@ class NettyServerStream extends AbstractServerStream {
private final Http2Stream http2Stream; private final Http2Stream http2Stream;
private final NettyServerHandler handler; private final NettyServerHandler handler;
public TransportState(NettyServerHandler handler, Http2Stream http2Stream, int maxMessageSize) { public TransportState(NettyServerHandler handler, Http2Stream http2Stream, int maxMessageSize,
super(maxMessageSize); StatsTraceContext statsTraceCtx) {
super(maxMessageSize, statsTraceCtx);
this.http2Stream = checkNotNull(http2Stream, "http2Stream"); this.http2Stream = checkNotNull(http2Stream, "http2Stream");
this.handler = checkNotNull(handler, "handler"); this.handler = checkNotNull(handler, "handler");
} }

View File

@ -64,6 +64,7 @@ import io.grpc.internal.ClientStreamListener;
import io.grpc.internal.ClientTransport; import io.grpc.internal.ClientTransport;
import io.grpc.internal.ClientTransport.PingCallback; import io.grpc.internal.ClientTransport.PingCallback;
import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil;
import io.grpc.internal.StatsTraceContext;
import io.grpc.netty.GrpcHttp2HeadersDecoder.GrpcHttp2ClientHeadersDecoder; import io.grpc.netty.GrpcHttp2HeadersDecoder.GrpcHttp2ClientHeadersDecoder;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil; import io.netty.buffer.ByteBufUtil;
@ -568,7 +569,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
private static class TransportStateImpl extends NettyClientStream.TransportState { private static class TransportStateImpl extends NettyClientStream.TransportState {
public TransportStateImpl(NettyClientHandler handler, int maxMessageSize) { public TransportStateImpl(NettyClientHandler handler, int maxMessageSize) {
super(handler, maxMessageSize); super(handler, maxMessageSize, StatsTraceContext.NOOP);
} }
@Override @Override

View File

@ -61,6 +61,7 @@ import io.grpc.MethodDescriptor;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.internal.ClientStreamListener; import io.grpc.internal.ClientStreamListener;
import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil;
import io.grpc.internal.StatsTraceContext;
import io.grpc.netty.WriteQueue.QueuedCommand; import io.grpc.netty.WriteQueue.QueuedCommand;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
@ -360,7 +361,7 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
stream = new NettyClientStream(new TransportStateImpl(handler, DEFAULT_MAX_MESSAGE_SIZE), stream = new NettyClientStream(new TransportStateImpl(handler, DEFAULT_MAX_MESSAGE_SIZE),
methodDescriptor, new Metadata(), channel, AsciiString.of("localhost"), methodDescriptor, new Metadata(), channel, AsciiString.of("localhost"),
AsciiString.of("http"), AsciiString.of("agent")); AsciiString.of("http"), AsciiString.of("agent"), StatsTraceContext.NOOP);
stream.start(listener); stream.start(listener);
stream().transportState().setId(STREAM_ID); stream().transportState().setId(STREAM_ID);
verify(listener, never()).onReady(); verify(listener, never()).onReady();
@ -380,7 +381,7 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
stream = new NettyClientStream(new TransportStateImpl(handler, DEFAULT_MAX_MESSAGE_SIZE), stream = new NettyClientStream(new TransportStateImpl(handler, DEFAULT_MAX_MESSAGE_SIZE),
methodDescriptor, new Metadata(), channel, AsciiString.of("localhost"), methodDescriptor, new Metadata(), channel, AsciiString.of("localhost"),
AsciiString.of("http"), AsciiString.of("good agent")); AsciiString.of("http"), AsciiString.of("good agent"), StatsTraceContext.NOOP);
stream.start(listener); stream.start(listener);
ArgumentCaptor<CreateStreamCommand> cmdCap = ArgumentCaptor.forClass(CreateStreamCommand.class); ArgumentCaptor<CreateStreamCommand> cmdCap = ArgumentCaptor.forClass(CreateStreamCommand.class);
@ -404,7 +405,8 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
when(writeQueue.enqueue(any(QueuedCommand.class), anyBoolean())).thenReturn(future); when(writeQueue.enqueue(any(QueuedCommand.class), anyBoolean())).thenReturn(future);
NettyClientStream stream = new NettyClientStream( NettyClientStream stream = new NettyClientStream(
new TransportStateImpl(handler, DEFAULT_MAX_MESSAGE_SIZE), methodDescriptor, new Metadata(), new TransportStateImpl(handler, DEFAULT_MAX_MESSAGE_SIZE), methodDescriptor, new Metadata(),
channel, AsciiString.of("localhost"), AsciiString.of("http"), AsciiString.of("agent")); channel, AsciiString.of("localhost"), AsciiString.of("http"), AsciiString.of("agent"),
StatsTraceContext.NOOP);
stream.start(listener); stream.start(listener);
stream.transportState().setId(STREAM_ID); stream.transportState().setId(STREAM_ID);
stream.transportState().setHttp2Stream(http2Stream); stream.transportState().setHttp2Stream(http2Stream);
@ -442,7 +444,7 @@ public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream
private static class TransportStateImpl extends NettyClientStream.TransportState { private static class TransportStateImpl extends NettyClientStream.TransportState {
public TransportStateImpl(NettyClientHandler handler, int maxMessageSize) { public TransportStateImpl(NettyClientHandler handler, int maxMessageSize) {
super(handler, maxMessageSize); super(handler, maxMessageSize, StatsTraceContext.NOOP);
} }
@Override @Override

View File

@ -44,6 +44,7 @@ import com.google.common.io.ByteStreams;
import com.google.common.util.concurrent.SettableFuture; import com.google.common.util.concurrent.SettableFuture;
import io.grpc.Attributes; import io.grpc.Attributes;
import io.grpc.Context;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.Marshaller; import io.grpc.MethodDescriptor.Marshaller;
@ -58,6 +59,7 @@ import io.grpc.internal.ServerStream;
import io.grpc.internal.ServerStreamListener; import io.grpc.internal.ServerStreamListener;
import io.grpc.internal.ServerTransport; import io.grpc.internal.ServerTransport;
import io.grpc.internal.ServerTransportListener; import io.grpc.internal.ServerTransportListener;
import io.grpc.internal.StatsTraceContext;
import io.grpc.testing.TestUtils; import io.grpc.testing.TestUtils;
import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel;
@ -97,6 +99,8 @@ public class NettyClientTransportTest {
private final List<NettyClientTransport> transports = new ArrayList<NettyClientTransport>(); private final List<NettyClientTransport> transports = new ArrayList<NettyClientTransport>();
private final NioEventLoopGroup group = new NioEventLoopGroup(1); private final NioEventLoopGroup group = new NioEventLoopGroup(1);
private final StatsTraceContext statsTraceCtx = StatsTraceContext.NOOP;
private InetSocketAddress address; private InetSocketAddress address;
private String authority; private String authority;
private NettyServer server; private NettyServer server;
@ -109,6 +113,7 @@ public class NettyClientTransportTest {
@After @After
public void teardown() throws Exception { public void teardown() throws Exception {
Context.ROOT.attach();
for (NettyClientTransport transport : transports) { for (NettyClientTransport transport : transports) {
transport.shutdown(); transport.shutdown();
} }
@ -433,6 +438,10 @@ public class NettyClientTransportTest {
public ServerTransportListener transportCreated(final ServerTransport transport) { public ServerTransportListener transportCreated(final ServerTransport transport) {
transports.add((NettyServerTransport) transport); transports.add((NettyServerTransport) transport);
return new ServerTransportListener() { return new ServerTransportListener() {
@Override
public StatsTraceContext methodDetermined(String method, Metadata headers) {
return StatsTraceContext.NOOP;
}
@Override @Override
public ServerStreamListener streamCreated(final ServerStream stream, String method, public ServerStreamListener streamCreated(final ServerStream stream, String method,

View File

@ -40,6 +40,7 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import io.grpc.internal.MessageFramer; import io.grpc.internal.MessageFramer;
import io.grpc.internal.StatsTraceContext;
import io.grpc.internal.WritableBuffer; import io.grpc.internal.WritableBuffer;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufAllocator;
@ -159,7 +160,7 @@ public abstract class NettyHandlerTestBase<T extends Http2ConnectionHandler> {
compressionFrame.writeBytes(bytebuf); compressionFrame.writeBytes(bytebuf);
} }
} }
}, new NettyWritableBufferAllocator(ByteBufAllocator.DEFAULT)); }, new NettyWritableBufferAllocator(ByteBufAllocator.DEFAULT), StatsTraceContext.NOOP);
framer.writePayload(new ByteArrayInputStream(content)); framer.writePayload(new ByteArrayInputStream(content));
framer.flush(); framer.flush();
ChannelHandlerContext ctx = newMockContext(); ChannelHandlerContext ctx = newMockContext();

View File

@ -62,6 +62,7 @@ import io.grpc.internal.GrpcUtil;
import io.grpc.internal.ServerStream; import io.grpc.internal.ServerStream;
import io.grpc.internal.ServerStreamListener; import io.grpc.internal.ServerStreamListener;
import io.grpc.internal.ServerTransportListener; import io.grpc.internal.ServerTransportListener;
import io.grpc.internal.StatsTraceContext;
import io.grpc.netty.GrpcHttp2HeadersDecoder.GrpcHttp2ServerHeadersDecoder; import io.grpc.netty.GrpcHttp2HeadersDecoder.GrpcHttp2ServerHeadersDecoder;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil; import io.netty.buffer.ByteBufUtil;
@ -103,6 +104,8 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
@Mock @Mock
private ServerStreamListener streamListener; private ServerStreamListener streamListener;
private final StatsTraceContext statsTraceCtx = StatsTraceContext.NOOP;
private NettyServerStream stream; private NettyServerStream stream;
private int flowControlWindow = DEFAULT_WINDOW_SIZE; private int flowControlWindow = DEFAULT_WINDOW_SIZE;
@ -112,6 +115,8 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
public void setUp() throws Exception { public void setUp() throws Exception {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
when(transportListener.methodDetermined(any(String.class), any(Metadata.class)))
.thenReturn(statsTraceCtx);
when(transportListener.streamCreated(any(ServerStream.class), when(transportListener.streamCreated(any(ServerStream.class),
any(String.class), any(String.class),
any(Metadata.class))) any(Metadata.class)))

View File

@ -55,6 +55,7 @@ import io.grpc.Attributes;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.internal.ServerStreamListener; import io.grpc.internal.ServerStreamListener;
import io.grpc.internal.StatsTraceContext;
import io.grpc.netty.WriteQueue.QueuedCommand; import io.grpc.netty.WriteQueue.QueuedCommand;
import io.netty.buffer.EmptyByteBuf; import io.netty.buffer.EmptyByteBuf;
import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.buffer.UnpooledByteBufAllocator;
@ -291,9 +292,11 @@ public class NettyServerStreamTest extends NettyStreamTestBase<NettyServerStream
} }
}).when(writeQueue).enqueue(any(QueuedCommand.class), any(ChannelPromise.class), anyBoolean()); }).when(writeQueue).enqueue(any(QueuedCommand.class), any(ChannelPromise.class), anyBoolean());
when(writeQueue.enqueue(any(QueuedCommand.class), anyBoolean())).thenReturn(future); when(writeQueue.enqueue(any(QueuedCommand.class), anyBoolean())).thenReturn(future);
NettyServerStream.TransportState state = StatsTraceContext statsTraceCtx = StatsTraceContext.NOOP;
new NettyServerStream.TransportState(handler, http2Stream, DEFAULT_MAX_MESSAGE_SIZE); NettyServerStream.TransportState state = new NettyServerStream.TransportState(
NettyServerStream stream = new NettyServerStream(channel, state, Attributes.EMPTY); handler, http2Stream, DEFAULT_MAX_MESSAGE_SIZE, statsTraceCtx);
NettyServerStream stream = new NettyServerStream(channel, state, Attributes.EMPTY,
statsTraceCtx);
stream.transportState().setListener(serverListener); stream.transportState().setListener(serverListener);
verify(serverListener, atLeastOnce()).onReady(); verify(serverListener, atLeastOnce()).onReady();
verifyNoMoreInteractions(serverListener); verifyNoMoreInteractions(serverListener);

View File

@ -40,6 +40,7 @@ import io.grpc.Status;
import io.grpc.internal.ClientStreamListener; import io.grpc.internal.ClientStreamListener;
import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil;
import io.grpc.internal.Http2ClientStream; import io.grpc.internal.Http2ClientStream;
import io.grpc.internal.StatsTraceContext;
import io.grpc.internal.WritableBuffer; import io.grpc.internal.WritableBuffer;
import io.grpc.okhttp.internal.framed.ErrorCode; import io.grpc.okhttp.internal.framed.ErrorCode;
import io.grpc.okhttp.internal.framed.Header; import io.grpc.okhttp.internal.framed.Header;
@ -96,8 +97,9 @@ class OkHttpClientStream extends Http2ClientStream {
Object lock, Object lock,
int maxMessageSize, int maxMessageSize,
String authority, String authority,
String userAgent) { String userAgent,
super(new OkHttpWritableBufferAllocator(), maxMessageSize); StatsTraceContext statsTraceCtx) {
super(new OkHttpWritableBufferAllocator(), maxMessageSize, statsTraceCtx);
this.method = method; this.method = method;
this.headers = headers; this.headers = headers;
this.frameWriter = frameWriter; this.frameWriter = frameWriter;

View File

@ -53,6 +53,7 @@ import io.grpc.internal.Http2Ping;
import io.grpc.internal.KeepAliveManager; import io.grpc.internal.KeepAliveManager;
import io.grpc.internal.SerializingExecutor; import io.grpc.internal.SerializingExecutor;
import io.grpc.internal.SharedResourceHolder; import io.grpc.internal.SharedResourceHolder;
import io.grpc.internal.StatsTraceContext;
import io.grpc.okhttp.internal.ConnectionSpec; import io.grpc.okhttp.internal.ConnectionSpec;
import io.grpc.okhttp.internal.framed.ErrorCode; import io.grpc.okhttp.internal.framed.ErrorCode;
import io.grpc.okhttp.internal.framed.FrameReader; import io.grpc.okhttp.internal.framed.FrameReader;
@ -269,18 +270,19 @@ class OkHttpClientTransport implements ConnectionClientTransport {
} }
@Override @Override
public OkHttpClientStream newStream(final MethodDescriptor<?, ?> method, final Metadata public OkHttpClientStream newStream(final MethodDescriptor<?, ?> method,
headers, CallOptions callOptions) { final Metadata headers, CallOptions callOptions, StatsTraceContext statsTraceCtx) {
Preconditions.checkNotNull(method, "method"); Preconditions.checkNotNull(method, "method");
Preconditions.checkNotNull(headers, "headers"); Preconditions.checkNotNull(headers, "headers");
Preconditions.checkNotNull(statsTraceCtx, "statsTraceCtx");
return new OkHttpClientStream(method, headers, frameWriter, OkHttpClientTransport.this, return new OkHttpClientStream(method, headers, frameWriter, OkHttpClientTransport.this,
outboundFlow, lock, maxMessageSize, defaultAuthority, userAgent); outboundFlow, lock, maxMessageSize, defaultAuthority, userAgent, statsTraceCtx);
} }
@Override @Override
public OkHttpClientStream newStream(final MethodDescriptor<?, ?> method, final Metadata public OkHttpClientStream newStream(final MethodDescriptor<?, ?> method, final Metadata
headers) { headers) {
return newStream(method, headers, CallOptions.DEFAULT); return newStream(method, headers, CallOptions.DEFAULT, StatsTraceContext.NOOP);
} }
@GuardedBy("lock") @GuardedBy("lock")

View File

@ -44,6 +44,7 @@ import io.grpc.MethodDescriptor.MethodType;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.internal.ClientStreamListener; import io.grpc.internal.ClientStreamListener;
import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil;
import io.grpc.internal.StatsTraceContext;
import io.grpc.okhttp.internal.framed.ErrorCode; import io.grpc.okhttp.internal.framed.ErrorCode;
import io.grpc.okhttp.internal.framed.Header; import io.grpc.okhttp.internal.framed.Header;
@ -84,7 +85,7 @@ public class OkHttpClientStreamTest {
methodDescriptor = MethodDescriptor.create( methodDescriptor = MethodDescriptor.create(
MethodType.UNARY, "/testService/test", marshaller, marshaller); MethodType.UNARY, "/testService/test", marshaller, marshaller);
stream = new OkHttpClientStream(methodDescriptor, new Metadata(), frameWriter, transport, stream = new OkHttpClientStream(methodDescriptor, new Metadata(), frameWriter, transport,
flowController, lock, MAX_MESSAGE_SIZE, "localhost", "userAgent"); flowController, lock, MAX_MESSAGE_SIZE, "localhost", "userAgent", StatsTraceContext.NOOP);
} }
@Test @Test
@ -140,7 +141,8 @@ public class OkHttpClientStreamTest {
Metadata metaData = new Metadata(); Metadata metaData = new Metadata();
metaData.put(GrpcUtil.USER_AGENT_KEY, "misbehaving-application"); metaData.put(GrpcUtil.USER_AGENT_KEY, "misbehaving-application");
stream = new OkHttpClientStream(methodDescriptor, metaData, frameWriter, transport, stream = new OkHttpClientStream(methodDescriptor, metaData, frameWriter, transport,
flowController, lock, MAX_MESSAGE_SIZE, "localhost", "good-application"); flowController, lock, MAX_MESSAGE_SIZE, "localhost", "good-application",
StatsTraceContext.NOOP);
stream.start(new BaseClientStreamListener()); stream.start(new BaseClientStreamListener());
stream.start(3); stream.start(3);
@ -154,7 +156,8 @@ public class OkHttpClientStreamTest {
Metadata metaData = new Metadata(); Metadata metaData = new Metadata();
metaData.put(GrpcUtil.USER_AGENT_KEY, "misbehaving-application"); metaData.put(GrpcUtil.USER_AGENT_KEY, "misbehaving-application");
stream = new OkHttpClientStream(methodDescriptor, metaData, frameWriter, transport, stream = new OkHttpClientStream(methodDescriptor, metaData, frameWriter, transport,
flowController, lock, MAX_MESSAGE_SIZE, "localhost", "good-application"); flowController, lock, MAX_MESSAGE_SIZE, "localhost", "good-application",
StatsTraceContext.NOOP);
stream.start(new BaseClientStreamListener()); stream.start(new BaseClientStreamListener());
stream.start(3); stream.start(3);

View File

@ -72,6 +72,7 @@ import io.grpc.internal.ServerStream;
import io.grpc.internal.ServerStreamListener; import io.grpc.internal.ServerStreamListener;
import io.grpc.internal.ServerTransport; import io.grpc.internal.ServerTransport;
import io.grpc.internal.ServerTransportListener; import io.grpc.internal.ServerTransportListener;
import io.grpc.internal.StatsTraceContext;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
@ -803,7 +804,7 @@ public abstract class AbstractTransportTest {
verify(mockClientStreamListener, never()).closed(any(Status.class), any(Metadata.class)); verify(mockClientStreamListener, never()).closed(any(Status.class), any(Metadata.class));
} }
@Test(timeout = 5000) @Test
public void clientCancelFromWithinMessageRead() throws Exception { public void clientCancelFromWithinMessageRead() throws Exception {
server.start(serverListener); server.start(serverListener);
client = newClientTransport(server); client = newClientTransport(server);
@ -852,7 +853,7 @@ public abstract class AbstractTransportTest {
serverStream.flush(); serverStream.flush();
// Block until closedCalled was set. // Block until closedCalled was set.
closedCalled.get(); closedCalled.get(5, TimeUnit.SECONDS);
serverStream.close(Status.OK, new Metadata()); serverStream.close(Status.OK, new Metadata());
} }
@ -1156,6 +1157,11 @@ public abstract class AbstractTransportTest {
this.transport = transport; this.transport = transport;
} }
@Override
public StatsTraceContext methodDetermined(String method, Metadata headers) {
return StatsTraceContext.NOOP;
}
@Override @Override
public ServerStreamListener streamCreated(ServerStream stream, String method, public ServerStreamListener streamCreated(ServerStream stream, String method,
Metadata headers) { Metadata headers) {

View File

@ -0,0 +1,246 @@
/*
* Copyright 2016, Google Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
*
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package io.grpc.internal.testing;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.census.CensusContext;
import com.google.census.CensusContextFactory;
import com.google.census.Metric;
import com.google.census.MetricMap;
import com.google.census.MetricName;
import com.google.census.TagKey;
import com.google.census.TagValue;
import com.google.common.collect.ImmutableMap;
import io.grpc.Context;
import java.nio.ByteBuffer;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
public class CensusTestUtils {
private CensusTestUtils() {
}
public static class MetricsRecord {
public final ImmutableMap<TagKey, TagValue> tags;
public final MetricMap metrics;
private MetricsRecord(ImmutableMap<TagKey, TagValue> tags, MetricMap metrics) {
this.tags = tags;
this.metrics = metrics;
}
/**
* Returns the value of a metric, or {@code null} if not found.
*/
@Nullable
public Double getMetric(MetricName metricName) {
for (Metric m : metrics) {
if (m.getName().equals(metricName)) {
return m.getValue();
}
}
return null;
}
/**
* Returns the value of a metric converted to long, or throw if not found.
*/
public long getMetricAsLongOrFail(MetricName metricName) {
Double doubleValue = getMetric(metricName);
checkNotNull(doubleValue, "Metric not found: %s", metricName.toString());
long longValue = (long) (Math.abs(doubleValue) + 0.0001);
if (doubleValue < 0) {
longValue = -longValue;
}
return longValue;
}
}
public static final TagKey EXTRA_TAG = new TagKey("/rpc/test/extratag");
private static final String EXTRA_TAG_HEADER_VALUE_PREFIX = "extratag:";
private static final String NO_EXTRA_TAG_HEADER_VALUE_PREFIX = "noextratag";
/**
* A factory that makes fake {@link CensusContext}s and saves the created contexts to be
* accessible from {@link #pollContextOrFail}. The contexts it has created would save metrics
* records to be accessible from {@link #pollRecord()} and {@link #pollRecord(long, TimeUnit)},
* until {@link #rolloverRecords} is called.
*/
public static final class FakeCensusContextFactory extends CensusContextFactory {
private BlockingQueue<MetricsRecord> records;
public final BlockingQueue<FakeCensusContext> contexts =
new LinkedBlockingQueue<FakeCensusContext>();
private static final Context.Key<FakeCensusContext> CONTEXT_KEY =
Context.key("fakeCensusContext");
private final FakeCensusContext defaultContext;
/**
* Constructor.
*/
public FakeCensusContextFactory() {
rolloverRecords();
defaultContext = new FakeCensusContext(ImmutableMap.<TagKey, TagValue>of(), this);
// The records on the default context is not visible from pollRecord(), just like it's
// not visible from pollContextOrFail() either.
rolloverRecords();
}
public CensusContext pollContextOrFail() {
CensusContext cc = contexts.poll();
return checkNotNull(cc);
}
public MetricsRecord pollRecord() {
return getCurrentRecordSink().poll();
}
public MetricsRecord pollRecord(long timeout, TimeUnit unit) throws InterruptedException {
return getCurrentRecordSink().poll(timeout, unit);
}
@Override
public CensusContext deserialize(ByteBuffer buffer) {
String serializedString = new String(buffer.array());
if (serializedString.startsWith(EXTRA_TAG_HEADER_VALUE_PREFIX)) {
return getDefault().with(EXTRA_TAG,
new TagValue(serializedString.substring(EXTRA_TAG_HEADER_VALUE_PREFIX.length())));
} else if (serializedString.startsWith(NO_EXTRA_TAG_HEADER_VALUE_PREFIX)) {
return getDefault();
} else {
return null;
}
}
@Override
public FakeCensusContext getDefault() {
return defaultContext;
}
/**
* Disconnect this factory with the contexts it has created so far. The records from those
* contexts will not show up in {@link #pollRecord}. Useful for isolating the records between
* test cases.
*/
// This needs to be synchronized with getCurrentRecordSink() which may run concurrently.
public synchronized void rolloverRecords() {
records = new LinkedBlockingQueue<MetricsRecord>();
}
private synchronized BlockingQueue<MetricsRecord> getCurrentRecordSink() {
return records;
}
}
public static class FakeCensusContext extends CensusContext {
private final ImmutableMap<TagKey, TagValue> tags;
private final FakeCensusContextFactory factory;
private final BlockingQueue<MetricsRecord> recordSink;
private FakeCensusContext(ImmutableMap<TagKey, TagValue> tags,
FakeCensusContextFactory factory) {
this.tags = tags;
this.factory = factory;
this.recordSink = factory.getCurrentRecordSink();
}
@Override
public Builder builder() {
return new FakeCensusContextBuilder(this);
}
@Override
public CensusContext record(MetricMap metrics) {
recordSink.add(new MetricsRecord(tags, metrics));
return this;
}
@Override
public ByteBuffer serialize() {
TagValue extraTagValue = tags.get(EXTRA_TAG);
if (extraTagValue == null) {
return ByteBuffer.wrap(NO_EXTRA_TAG_HEADER_VALUE_PREFIX.getBytes());
} else {
return ByteBuffer.wrap(
(EXTRA_TAG_HEADER_VALUE_PREFIX + extraTagValue.toString()).getBytes());
}
}
@Override
public String toString() {
return "[tags=" + tags + "]";
}
@Override
public boolean equals(Object other) {
if (!(other instanceof FakeCensusContext)) {
return false;
}
FakeCensusContext otherCtx = (FakeCensusContext) other;
return tags.equals(otherCtx.tags);
}
@Override
public int hashCode() {
return tags.hashCode();
}
}
private static class FakeCensusContextBuilder extends CensusContext.Builder {
private final ImmutableMap.Builder<TagKey, TagValue> tagsBuilder = ImmutableMap.builder();
private final FakeCensusContext base;
private FakeCensusContextBuilder(FakeCensusContext base) {
this.base = base;
tagsBuilder.putAll(base.tags);
}
@Override
public CensusContext.Builder set(TagKey key, TagValue value) {
tagsBuilder.put(key, value);
return this;
}
@Override
public CensusContext build() {
FakeCensusContext context = new FakeCensusContext(tagsBuilder.build(), base.factory);
base.factory.contexts.add(context);
return context;
}
}
}