core: use new OpenCensus stats/tagging API. (#3647)

This commit updates gRPC core to use io.opencensus:opencensus-api and
io.opencensus:opencensus-contrib-grpc-metrics instead of
com.google.instrumentation:instrumentation-api for stats and tagging. The gRPC
Monitoring Service continues to use instrumentation-api.

The main changes affecting gRPC:

- The StatsContextFactory is replaced by three objects, StatsRecorder, Tagger,
  and TagContextBinarySerializer.

- The StatsRecorder, Tagger, and TagContextBinarySerializer are never null,
  but the objects are no-ops when the OpenCensus implementation is not
  available.

This commit includes changes written by @songy23 and @sebright.
This commit is contained in:
sebright 2017-11-07 12:25:03 -08:00 committed by Kun Zhang
parent a7300150de
commit ef2ec94911
18 changed files with 603 additions and 452 deletions

View File

@ -192,6 +192,7 @@ subprojects {
okhttp: 'com.squareup.okhttp:okhttp:2.5.0', okhttp: 'com.squareup.okhttp:okhttp:2.5.0',
okio: 'com.squareup.okio:okio:1.6.0', okio: 'com.squareup.okio:okio:1.6.0',
opencensus_api: 'io.opencensus:opencensus-api:0.8.0', opencensus_api: 'io.opencensus:opencensus-api:0.8.0',
opencensus_contrib_grpc_metrics: 'io.opencensus:opencensus-contrib-grpc-metrics:0.8.0',
opencensus_impl: 'io.opencensus:opencensus-impl:0.8.0', opencensus_impl: 'io.opencensus:opencensus-impl:0.8.0',
instrumentation_api: 'com.google.instrumentation:instrumentation-api:0.4.3', instrumentation_api: 'com.google.instrumentation:instrumentation-api:0.4.3',
protobuf: "com.google.protobuf:protobuf-java:${protobufVersion}", protobuf: "com.google.protobuf:protobuf-java:${protobufVersion}",

View File

@ -19,6 +19,12 @@ dependencies {
// we'll always be more up-to-date // we'll always be more up-to-date
exclude group: 'io.grpc', module: 'grpc-context' exclude group: 'io.grpc', module: 'grpc-context'
} }
compile (libraries.opencensus_contrib_grpc_metrics) {
// prefer 3.0.0 from libraries instead of 3.0.1
exclude group: 'com.google.code.findbugs', module: 'jsr305'
// we'll always be more up-to-date
exclude group: 'io.grpc', module: 'grpc-context'
}
testCompile project(':grpc-testing') testCompile project(':grpc-testing')

View File

@ -21,8 +21,6 @@ import static com.google.common.base.Preconditions.checkArgument;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.MoreExecutors;
import com.google.instrumentation.stats.Stats;
import com.google.instrumentation.stats.StatsContextFactory;
import io.grpc.Attributes; import io.grpc.Attributes;
import io.grpc.ClientInterceptor; import io.grpc.ClientInterceptor;
import io.grpc.CompressorRegistry; import io.grpc.CompressorRegistry;
@ -150,7 +148,7 @@ public abstract class AbstractManagedChannelImplBuilder
private boolean tracingEnabled = true; private boolean tracingEnabled = true;
@Nullable @Nullable
private StatsContextFactory statsFactory; private CensusStatsModule censusStatsOverride;
protected AbstractManagedChannelImplBuilder(String target) { protected AbstractManagedChannelImplBuilder(String target) {
this.target = Preconditions.checkNotNull(target, "target"); this.target = Preconditions.checkNotNull(target, "target");
@ -285,8 +283,8 @@ public abstract class AbstractManagedChannelImplBuilder
* Override the default stats implementation. * Override the default stats implementation.
*/ */
@VisibleForTesting @VisibleForTesting
protected final T statsContextFactory(StatsContextFactory statsFactory) { protected final T overrideCensusStatsModule(CensusStatsModule censusStats) {
this.statsFactory = statsFactory; this.censusStatsOverride = censusStats;
return thisT(); return thisT();
} }
@ -344,15 +342,13 @@ public abstract class AbstractManagedChannelImplBuilder
List<ClientInterceptor> effectiveInterceptors = List<ClientInterceptor> effectiveInterceptors =
new ArrayList<ClientInterceptor>(this.interceptors); new ArrayList<ClientInterceptor>(this.interceptors);
if (statsEnabled) { if (statsEnabled) {
StatsContextFactory statsCtxFactory = CensusStatsModule censusStats = this.censusStatsOverride;
this.statsFactory != null ? this.statsFactory : Stats.getStatsContextFactory(); if (censusStats == null) {
if (statsCtxFactory != null) { censusStats = new CensusStatsModule(GrpcUtil.STOPWATCH_SUPPLIER, true);
CensusStatsModule censusStats = }
new CensusStatsModule(statsCtxFactory, GrpcUtil.STOPWATCH_SUPPLIER, true, recordStats);
// First interceptor runs last (see ClientInterceptors.intercept()), so that no // First interceptor runs last (see ClientInterceptors.intercept()), so that no
// other interceptor can override the tracer factory we set in CallOptions. // other interceptor can override the tracer factory we set in CallOptions.
effectiveInterceptors.add(0, censusStats.getClientInterceptor()); effectiveInterceptors.add(0, censusStats.getClientInterceptor(recordStats));
}
} }
if (tracingEnabled) { if (tracingEnabled) {
CensusTracingModule censusTracing = CensusTracingModule censusTracing =

View File

@ -20,8 +20,6 @@ import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.MoreExecutors;
import com.google.instrumentation.stats.Stats;
import com.google.instrumentation.stats.StatsContextFactory;
import io.grpc.BindableService; import io.grpc.BindableService;
import io.grpc.CompressorRegistry; import io.grpc.CompressorRegistry;
import io.grpc.Context; import io.grpc.Context;
@ -97,7 +95,7 @@ public abstract class AbstractServerImplBuilder<T extends AbstractServerImplBuil
CompressorRegistry compressorRegistry = DEFAULT_COMPRESSOR_REGISTRY; CompressorRegistry compressorRegistry = DEFAULT_COMPRESSOR_REGISTRY;
@Nullable @Nullable
private StatsContextFactory statsFactory; private CensusStatsModule censusStatsOverride;
private boolean statsEnabled = true; private boolean statsEnabled = true;
private boolean recordStats = true; private boolean recordStats = true;
@ -184,8 +182,8 @@ public abstract class AbstractServerImplBuilder<T extends AbstractServerImplBuil
* Override the default stats implementation. * Override the default stats implementation.
*/ */
@VisibleForTesting @VisibleForTesting
protected T statsContextFactory(StatsContextFactory statsFactory) { protected T overrideCensusStatsModule(CensusStatsModule censusStats) {
this.statsFactory = statsFactory; this.censusStatsOverride = censusStats;
return thisT(); return thisT();
} }
@ -228,13 +226,11 @@ public abstract class AbstractServerImplBuilder<T extends AbstractServerImplBuil
ArrayList<ServerStreamTracer.Factory> tracerFactories = ArrayList<ServerStreamTracer.Factory> tracerFactories =
new ArrayList<ServerStreamTracer.Factory>(); new ArrayList<ServerStreamTracer.Factory>();
if (statsEnabled) { if (statsEnabled) {
StatsContextFactory statsFactory = CensusStatsModule censusStats = this.censusStatsOverride;
this.statsFactory != null ? this.statsFactory : Stats.getStatsContextFactory(); if (censusStats == null) {
if (statsFactory != null) { censusStats = new CensusStatsModule(GrpcUtil.STOPWATCH_SUPPLIER, true);
CensusStatsModule censusStats =
new CensusStatsModule(statsFactory, GrpcUtil.STOPWATCH_SUPPLIER, true, recordStats);
tracerFactories.add(censusStats.getServerTracerFactory());
} }
tracerFactories.add(censusStats.getServerTracerFactory(recordStats));
} }
if (tracingEnabled) { if (tracingEnabled) {
CensusTracingModule censusTracing = CensusTracingModule censusTracing =

View File

@ -19,16 +19,11 @@ package io.grpc.internal;
import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Preconditions.checkState;
import static com.google.instrumentation.stats.ContextUtils.STATS_CONTEXT_KEY; import static io.opencensus.tags.unsafe.ContextUtils.TAG_CONTEXT_KEY;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Stopwatch; import com.google.common.base.Stopwatch;
import com.google.common.base.Supplier; import com.google.common.base.Supplier;
import com.google.instrumentation.stats.MeasurementMap;
import com.google.instrumentation.stats.RpcConstants;
import com.google.instrumentation.stats.StatsContext;
import com.google.instrumentation.stats.StatsContextFactory;
import com.google.instrumentation.stats.TagValue;
import io.grpc.CallOptions; import io.grpc.CallOptions;
import io.grpc.Channel; import io.grpc.Channel;
import io.grpc.ClientCall; import io.grpc.ClientCall;
@ -42,9 +37,16 @@ import io.grpc.MethodDescriptor;
import io.grpc.ServerStreamTracer; import io.grpc.ServerStreamTracer;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.StreamTracer; import io.grpc.StreamTracer;
import java.io.ByteArrayInputStream; import io.opencensus.contrib.grpc.metrics.RpcMeasureConstants;
import java.io.ByteArrayOutputStream; import io.opencensus.stats.MeasureMap;
import java.io.IOException; import io.opencensus.stats.Stats;
import io.opencensus.stats.StatsRecorder;
import io.opencensus.tags.TagContext;
import io.opencensus.tags.TagValue;
import io.opencensus.tags.Tagger;
import io.opencensus.tags.Tags;
import io.opencensus.tags.propagation.TagContextBinarySerializer;
import io.opencensus.tags.propagation.TagContextSerializationException;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import java.util.concurrent.atomic.AtomicLongFieldUpdater; import java.util.concurrent.atomic.AtomicLongFieldUpdater;
@ -64,49 +66,63 @@ import javax.annotation.Nullable;
* starts earlier than the ServerCall. Therefore, only one tracer is created per stream/call and * starts earlier than the ServerCall. Therefore, only one tracer is created per stream/call and
* it's the tracer that reports the summary to Census. * it's the tracer that reports the summary to Census.
*/ */
final class CensusStatsModule { public final class CensusStatsModule {
private static final Logger logger = Logger.getLogger(CensusStatsModule.class.getName()); private static final Logger logger = Logger.getLogger(CensusStatsModule.class.getName());
private static final double NANOS_PER_MILLI = TimeUnit.MILLISECONDS.toNanos(1); private static final double NANOS_PER_MILLI = TimeUnit.MILLISECONDS.toNanos(1);
private static final ClientTracer BLANK_CLIENT_TRACER = new ClientTracer(); private static final ClientTracer BLANK_CLIENT_TRACER = new ClientTracer();
private final StatsContextFactory statsCtxFactory; private final Tagger tagger;
private final StatsRecorder statsRecorder;
private final Supplier<Stopwatch> stopwatchSupplier; private final Supplier<Stopwatch> stopwatchSupplier;
@VisibleForTesting @VisibleForTesting
final Metadata.Key<StatsContext> statsHeader; final Metadata.Key<TagContext> statsHeader;
private final StatsClientInterceptor clientInterceptor = new StatsClientInterceptor();
private final ServerTracerFactory serverTracerFactory = new ServerTracerFactory();
private final boolean propagateTags; private final boolean propagateTags;
private final boolean recordStats;
CensusStatsModule( /**
final StatsContextFactory statsCtxFactory, Supplier<Stopwatch> stopwatchSupplier, * Creates a {@link CensusStatsModule} with the default OpenCensus implementation.
boolean propagateTags, boolean recordStats) { */
this.statsCtxFactory = checkNotNull(statsCtxFactory, "statsCtxFactory"); CensusStatsModule(Supplier<Stopwatch> stopwatchSupplier, boolean propagateTags) {
this(
Tags.getTagger(),
Tags.getTagPropagationComponent().getBinarySerializer(),
Stats.getStatsRecorder(),
stopwatchSupplier,
propagateTags);
}
/**
* Creates a {@link CensusStatsModule} with the given OpenCensus implementation.
*/
public CensusStatsModule(
final Tagger tagger,
final TagContextBinarySerializer tagCtxSerializer,
StatsRecorder statsRecorder, Supplier<Stopwatch> stopwatchSupplier,
boolean propagateTags) {
this.tagger = checkNotNull(tagger, "tagger");
this.statsRecorder = checkNotNull(statsRecorder, "statsRecorder");
checkNotNull(tagCtxSerializer, "tagCtxSerializer");
this.stopwatchSupplier = checkNotNull(stopwatchSupplier, "stopwatchSupplier"); this.stopwatchSupplier = checkNotNull(stopwatchSupplier, "stopwatchSupplier");
this.propagateTags = propagateTags; this.propagateTags = propagateTags;
this.recordStats = recordStats;
this.statsHeader = this.statsHeader =
Metadata.Key.of("grpc-tags-bin", new Metadata.BinaryMarshaller<StatsContext>() { Metadata.Key.of("grpc-tags-bin", new Metadata.BinaryMarshaller<TagContext>() {
@Override @Override
public byte[] toBytes(StatsContext context) { public byte[] toBytes(TagContext context) {
// TODO(carl-mastrangelo): currently we only make sure the correctness. We may need to // TODO(carl-mastrangelo): currently we only make sure the correctness. We may need to
// optimize out the allocation and copy in the future. // optimize out the allocation and copy in the future.
ByteArrayOutputStream buffer = new ByteArrayOutputStream();
try { try {
context.serialize(buffer); return tagCtxSerializer.toByteArray(context);
} catch (IOException e) { } catch (TagContextSerializationException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
return buffer.toByteArray();
} }
@Override @Override
public StatsContext parseBytes(byte[] serialized) { public TagContext parseBytes(byte[] serialized) {
try { try {
return statsCtxFactory.deserialize(new ByteArrayInputStream(serialized)); return tagCtxSerializer.fromByteArray(serialized);
} catch (Exception e) { } catch (Exception e) {
logger.log(Level.FINE, "Failed to parse stats header", e); logger.log(Level.FINE, "Failed to parse stats header", e);
return statsCtxFactory.getDefault(); return tagger.empty();
} }
} }
}); });
@ -116,22 +132,23 @@ final class CensusStatsModule {
* Creates a {@link ClientCallTracer} for a new call. * Creates a {@link ClientCallTracer} for a new call.
*/ */
@VisibleForTesting @VisibleForTesting
ClientCallTracer newClientCallTracer(StatsContext parentCtx, String fullMethodName) { ClientCallTracer newClientCallTracer(
return new ClientCallTracer(this, parentCtx, fullMethodName); TagContext parentCtx, String fullMethodName, boolean recordStats) {
return new ClientCallTracer(this, parentCtx, fullMethodName, recordStats);
} }
/** /**
* Returns the server tracer factory. * Returns the server tracer factory.
*/ */
ServerStreamTracer.Factory getServerTracerFactory() { ServerStreamTracer.Factory getServerTracerFactory(boolean recordStats) {
return serverTracerFactory; return new ServerTracerFactory(recordStats);
} }
/** /**
* Returns the client interceptor that facilitates Census-based stats reporting. * Returns the client interceptor that facilitates Census-based stats reporting.
*/ */
ClientInterceptor getClientInterceptor() { ClientInterceptor getClientInterceptor(boolean recordStats) {
return clientInterceptor; return new StatsClientInterceptor(recordStats);
} }
private static final class ClientTracer extends ClientStreamTracer { private static final class ClientTracer extends ClientStreamTracer {
@ -203,13 +220,19 @@ final class CensusStatsModule {
private final Stopwatch stopwatch; private final Stopwatch stopwatch;
private volatile ClientTracer streamTracer; private volatile ClientTracer streamTracer;
private volatile int callEnded; private volatile int callEnded;
private final StatsContext parentCtx; private final TagContext parentCtx;
private final boolean recordStats;
ClientCallTracer(CensusStatsModule module, StatsContext parentCtx, String fullMethodName) { ClientCallTracer(
CensusStatsModule module,
TagContext parentCtx,
String fullMethodName,
boolean recordStats) {
this.module = module; this.module = module;
this.parentCtx = checkNotNull(parentCtx, "parentCtx"); this.parentCtx = checkNotNull(parentCtx, "parentCtx");
this.fullMethodName = checkNotNull(fullMethodName, "fullMethodName"); this.fullMethodName = checkNotNull(fullMethodName, "fullMethodName");
this.stopwatch = module.stopwatchSupplier.get().start(); this.stopwatch = module.stopwatchSupplier.get().start();
this.recordStats = recordStats;
} }
@Override @Override
@ -222,7 +245,7 @@ final class CensusStatsModule {
"Are you creating multiple streams per call? This class doesn't yet support this case."); "Are you creating multiple streams per call? This class doesn't yet support this case.");
if (module.propagateTags) { if (module.propagateTags) {
headers.discardAll(module.statsHeader); headers.discardAll(module.statsHeader);
if (parentCtx != module.statsCtxFactory.getDefault()) { if (!module.tagger.empty().equals(parentCtx)) {
headers.put(module.statsHeader, parentCtx); headers.put(module.statsHeader, parentCtx);
} }
} }
@ -239,7 +262,7 @@ final class CensusStatsModule {
if (callEndedUpdater.getAndSet(this, 1) != 0) { if (callEndedUpdater.getAndSet(this, 1) != 0) {
return; return;
} }
if (!module.recordStats) { if (!recordStats) {
return; return;
} }
stopwatch.stop(); stopwatch.stop();
@ -248,27 +271,29 @@ final class CensusStatsModule {
if (tracer == null) { if (tracer == null) {
tracer = BLANK_CLIENT_TRACER; tracer = BLANK_CLIENT_TRACER;
} }
MeasurementMap.Builder builder = MeasurementMap.builder() MeasureMap measureMap = module.statsRecorder.newMeasureMap()
// The metrics are in double // The metrics are in double
.put(RpcConstants.RPC_CLIENT_ROUNDTRIP_LATENCY, roundtripNanos / NANOS_PER_MILLI) .put(RpcMeasureConstants.RPC_CLIENT_ROUNDTRIP_LATENCY, roundtripNanos / NANOS_PER_MILLI)
.put(RpcConstants.RPC_CLIENT_REQUEST_COUNT, tracer.outboundMessageCount) .put(RpcMeasureConstants.RPC_CLIENT_REQUEST_COUNT, tracer.outboundMessageCount)
.put(RpcConstants.RPC_CLIENT_RESPONSE_COUNT, tracer.inboundMessageCount) .put(RpcMeasureConstants.RPC_CLIENT_RESPONSE_COUNT, tracer.inboundMessageCount)
.put(RpcConstants.RPC_CLIENT_REQUEST_BYTES, tracer.outboundWireSize) .put(RpcMeasureConstants.RPC_CLIENT_REQUEST_BYTES, tracer.outboundWireSize)
.put(RpcConstants.RPC_CLIENT_RESPONSE_BYTES, tracer.inboundWireSize) .put(RpcMeasureConstants.RPC_CLIENT_RESPONSE_BYTES, tracer.inboundWireSize)
.put( .put(
RpcConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES, RpcMeasureConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES,
tracer.outboundUncompressedSize) tracer.outboundUncompressedSize)
.put( .put(
RpcConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES, RpcMeasureConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES,
tracer.inboundUncompressedSize); tracer.inboundUncompressedSize);
if (!status.isOk()) { if (!status.isOk()) {
builder.put(RpcConstants.RPC_CLIENT_ERROR_COUNT, 1.0); measureMap.put(RpcMeasureConstants.RPC_CLIENT_ERROR_COUNT, 1);
} }
parentCtx measureMap.record(
.with( module
RpcConstants.RPC_CLIENT_METHOD, TagValue.create(fullMethodName), .tagger
RpcConstants.RPC_STATUS, TagValue.create(status.getCode().toString())) .toBuilder(parentCtx)
.record(builder.build()); .put(RpcMeasureConstants.RPC_METHOD, TagValue.create(fullMethodName))
.put(RpcMeasureConstants.RPC_STATUS, TagValue.create(status.getCode().toString()))
.build());
} }
} }
@ -291,10 +316,11 @@ final class CensusStatsModule {
private final CensusStatsModule module; private final CensusStatsModule module;
private final String fullMethodName; private final String fullMethodName;
@Nullable @Nullable
private final StatsContext parentCtx; private final TagContext parentCtx;
private volatile int streamClosed; private volatile int streamClosed;
private final Stopwatch stopwatch; private final Stopwatch stopwatch;
private final StatsContextFactory statsCtxFactory; private final Tagger tagger;
private final boolean recordStats;
private volatile long outboundMessageCount; private volatile long outboundMessageCount;
private volatile long inboundMessageCount; private volatile long inboundMessageCount;
private volatile long outboundWireSize; private volatile long outboundWireSize;
@ -305,14 +331,16 @@ final class CensusStatsModule {
ServerTracer( ServerTracer(
CensusStatsModule module, CensusStatsModule module,
String fullMethodName, String fullMethodName,
StatsContext parentCtx, TagContext parentCtx,
Supplier<Stopwatch> stopwatchSupplier, Supplier<Stopwatch> stopwatchSupplier,
StatsContextFactory statsCtxFactory) { Tagger tagger,
boolean recordStats) {
this.module = module; this.module = module;
this.fullMethodName = checkNotNull(fullMethodName, "fullMethodName"); this.fullMethodName = checkNotNull(fullMethodName, "fullMethodName");
this.parentCtx = checkNotNull(parentCtx, "parentCtx"); this.parentCtx = checkNotNull(parentCtx, "parentCtx");
this.stopwatch = stopwatchSupplier.get().start(); this.stopwatch = stopwatchSupplier.get().start();
this.statsCtxFactory = statsCtxFactory; this.tagger = tagger;
this.recordStats = recordStats;
} }
@Override @Override
@ -356,33 +384,36 @@ final class CensusStatsModule {
if (streamClosedUpdater.getAndSet(this, 1) != 0) { if (streamClosedUpdater.getAndSet(this, 1) != 0) {
return; return;
} }
if (!module.recordStats) { if (!recordStats) {
return; return;
} }
stopwatch.stop(); stopwatch.stop();
long elapsedTimeNanos = stopwatch.elapsed(TimeUnit.NANOSECONDS); long elapsedTimeNanos = stopwatch.elapsed(TimeUnit.NANOSECONDS);
MeasurementMap.Builder builder = MeasurementMap.builder() MeasureMap measureMap = module.statsRecorder.newMeasureMap()
// The metrics are in double // The metrics are in double
.put(RpcConstants.RPC_SERVER_SERVER_LATENCY, elapsedTimeNanos / NANOS_PER_MILLI) .put(RpcMeasureConstants.RPC_SERVER_SERVER_LATENCY, elapsedTimeNanos / NANOS_PER_MILLI)
.put(RpcConstants.RPC_SERVER_RESPONSE_COUNT, outboundMessageCount) .put(RpcMeasureConstants.RPC_SERVER_RESPONSE_COUNT, outboundMessageCount)
.put(RpcConstants.RPC_SERVER_REQUEST_COUNT, inboundMessageCount) .put(RpcMeasureConstants.RPC_SERVER_REQUEST_COUNT, inboundMessageCount)
.put(RpcConstants.RPC_SERVER_RESPONSE_BYTES, outboundWireSize) .put(RpcMeasureConstants.RPC_SERVER_RESPONSE_BYTES, outboundWireSize)
.put(RpcConstants.RPC_SERVER_REQUEST_BYTES, inboundWireSize) .put(RpcMeasureConstants.RPC_SERVER_REQUEST_BYTES, inboundWireSize)
.put(RpcConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES, outboundUncompressedSize) .put(RpcMeasureConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES, outboundUncompressedSize)
.put(RpcConstants.RPC_SERVER_UNCOMPRESSED_REQUEST_BYTES, inboundUncompressedSize); .put(RpcMeasureConstants.RPC_SERVER_UNCOMPRESSED_REQUEST_BYTES, inboundUncompressedSize);
if (!status.isOk()) { if (!status.isOk()) {
builder.put(RpcConstants.RPC_SERVER_ERROR_COUNT, 1.0); measureMap.put(RpcMeasureConstants.RPC_SERVER_ERROR_COUNT, 1);
} }
StatsContext ctx = firstNonNull(parentCtx, statsCtxFactory.getDefault()); TagContext ctx = firstNonNull(parentCtx, tagger.empty());
ctx measureMap.record(
.with(RpcConstants.RPC_STATUS, TagValue.create(status.getCode().toString())) module
.record(builder.build()); .tagger
.toBuilder(ctx)
.put(RpcMeasureConstants.RPC_STATUS, TagValue.create(status.getCode().toString()))
.build());
} }
@Override @Override
public Context filterContext(Context context) { public Context filterContext(Context context) {
if (parentCtx != statsCtxFactory.getDefault()) { if (!tagger.empty().equals(parentCtx)) {
return context.withValue(STATS_CONTEXT_KEY, parentCtx); return context.withValue(TAG_CONTEXT_KEY, parentCtx);
} }
return context; return context;
} }
@ -390,27 +421,48 @@ final class CensusStatsModule {
@VisibleForTesting @VisibleForTesting
final class ServerTracerFactory extends ServerStreamTracer.Factory { final class ServerTracerFactory extends ServerStreamTracer.Factory {
private final boolean recordStats;
ServerTracerFactory(boolean recordStats) {
this.recordStats = recordStats;
}
@Override @Override
public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata headers) { public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata headers) {
StatsContext parentCtx = headers.get(statsHeader); TagContext parentCtx = headers.get(statsHeader);
if (parentCtx == null) { if (parentCtx == null) {
parentCtx = statsCtxFactory.getDefault(); parentCtx = tagger.empty();
} }
parentCtx = parentCtx.with(RpcConstants.RPC_SERVER_METHOD, TagValue.create(fullMethodName)); parentCtx =
tagger
.toBuilder(parentCtx)
.put(RpcMeasureConstants.RPC_METHOD, TagValue.create(fullMethodName))
.build();
return new ServerTracer( return new ServerTracer(
CensusStatsModule.this, fullMethodName, parentCtx, stopwatchSupplier, statsCtxFactory); CensusStatsModule.this,
fullMethodName,
parentCtx,
stopwatchSupplier,
tagger,
recordStats);
} }
} }
@VisibleForTesting @VisibleForTesting
final class StatsClientInterceptor implements ClientInterceptor { final class StatsClientInterceptor implements ClientInterceptor {
private final boolean recordStats;
StatsClientInterceptor(boolean recordStats) {
this.recordStats = recordStats;
}
@Override @Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall( public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) { MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
// New RPCs on client-side inherit the stats context from the current Context. // New RPCs on client-side inherit the tag context from the current Context.
StatsContext parentCtx = statsCtxFactory.getCurrentStatsContext(); TagContext parentCtx = tagger.getCurrentTagContext();
final ClientCallTracer tracerFactory = final ClientCallTracer tracerFactory =
newClientCallTracer(parentCtx, method.getFullMethodName()); newClientCallTracer(parentCtx, method.getFullMethodName(), recordStats);
ClientCall<ReqT, RespT> call = ClientCall<ReqT, RespT> call =
next.newCall(method, callOptions.withStreamTracerFactory(tracerFactory)); next.newCall(method, callOptions.withStreamTracerFactory(tracerFactory));
return new SimpleForwardingClientCall<ReqT, RespT>(call) { return new SimpleForwardingClientCall<ReqT, RespT>(call) {

View File

@ -27,8 +27,6 @@ import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.MoreExecutors;
import com.google.instrumentation.stats.StatsContext;
import com.google.instrumentation.stats.StatsContextFactory;
import io.grpc.CallOptions; import io.grpc.CallOptions;
import io.grpc.Channel; import io.grpc.Channel;
import io.grpc.ClientCall; import io.grpc.ClientCall;
@ -38,7 +36,9 @@ import io.grpc.DecompressorRegistry;
import io.grpc.LoadBalancer; import io.grpc.LoadBalancer;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.NameResolver; import io.grpc.NameResolver;
import java.io.InputStream; import io.grpc.internal.testing.StatsTestUtils.FakeStatsRecorder;
import io.grpc.internal.testing.StatsTestUtils.FakeTagContextBinarySerializer;
import io.grpc.internal.testing.StatsTestUtils.FakeTagger;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.SocketAddress; import java.net.SocketAddress;
import java.net.URI; import java.net.URI;
@ -61,19 +61,6 @@ public class AbstractManagedChannelImplBuilderTest {
} }
}; };
private static final StatsContextFactory DUMMY_STATS_FACTORY =
new StatsContextFactory() {
@Override
public StatsContext deserialize(InputStream input) {
throw new UnsupportedOperationException();
}
@Override
public StatsContext getDefault() {
throw new UnsupportedOperationException();
}
};
private Builder builder = new Builder("fake"); private Builder builder = new Builder("fake");
private Builder directAddressBuilder = new Builder(new SocketAddress(){}, "fake"); private Builder directAddressBuilder = new Builder(new SocketAddress(){}, "fake");
@ -346,12 +333,24 @@ public class AbstractManagedChannelImplBuilderTest {
static class Builder extends AbstractManagedChannelImplBuilder<Builder> { static class Builder extends AbstractManagedChannelImplBuilder<Builder> {
Builder(String target) { Builder(String target) {
super(target); super(target);
statsContextFactory(DUMMY_STATS_FACTORY); overrideCensusStatsModule(
new CensusStatsModule(
new FakeTagger(),
new FakeTagContextBinarySerializer(),
new FakeStatsRecorder(),
GrpcUtil.STOPWATCH_SUPPLIER,
true));
} }
Builder(SocketAddress directServerAddress, String authority) { Builder(SocketAddress directServerAddress, String authority) {
super(directServerAddress, authority); super(directServerAddress, authority);
statsContextFactory(DUMMY_STATS_FACTORY); overrideCensusStatsModule(
new CensusStatsModule(
new FakeTagger(),
new FakeTagContextBinarySerializer(),
new FakeStatsRecorder(),
GrpcUtil.STOPWATCH_SUPPLIER,
true));
} }
@Override @Override

View File

@ -19,12 +19,12 @@ package io.grpc.internal;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import com.google.instrumentation.stats.StatsContext;
import com.google.instrumentation.stats.StatsContextFactory;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.ServerStreamTracer; import io.grpc.ServerStreamTracer;
import io.grpc.internal.testing.StatsTestUtils.FakeStatsRecorder;
import io.grpc.internal.testing.StatsTestUtils.FakeTagContextBinarySerializer;
import io.grpc.internal.testing.StatsTestUtils.FakeTagger;
import java.io.File; import java.io.File;
import java.io.InputStream;
import java.util.List; import java.util.List;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
@ -33,18 +33,6 @@ import org.junit.runners.JUnit4;
/** Unit tests for {@link AbstractServerImplBuilder}. */ /** Unit tests for {@link AbstractServerImplBuilder}. */
@RunWith(JUnit4.class) @RunWith(JUnit4.class)
public class AbstractServerImplBuilderTest { public class AbstractServerImplBuilderTest {
private static final StatsContextFactory DUMMY_STATS_FACTORY =
new StatsContextFactory() {
@Override
public StatsContext deserialize(InputStream input) {
throw new UnsupportedOperationException();
}
@Override
public StatsContext getDefault() {
throw new UnsupportedOperationException();
}
};
private static final ServerStreamTracer.Factory DUMMY_USER_TRACER = private static final ServerStreamTracer.Factory DUMMY_USER_TRACER =
new ServerStreamTracer.Factory() { new ServerStreamTracer.Factory() {
@ -97,7 +85,13 @@ public class AbstractServerImplBuilderTest {
static class Builder extends AbstractServerImplBuilder<Builder> { static class Builder extends AbstractServerImplBuilder<Builder> {
Builder() { Builder() {
statsContextFactory(DUMMY_STATS_FACTORY); overrideCensusStatsModule(
new CensusStatsModule(
new FakeTagger(),
new FakeTagContextBinarySerializer(),
new FakeStatsRecorder(),
GrpcUtil.STOPWATCH_SUPPLIER,
true));
} }
@Override @Override

View File

@ -16,7 +16,7 @@
package io.grpc.internal; package io.grpc.internal;
import static com.google.instrumentation.stats.ContextUtils.STATS_CONTEXT_KEY; import static io.opencensus.tags.unsafe.ContextUtils.TAG_CONTEXT_KEY;
import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
@ -40,9 +40,6 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.instrumentation.stats.RpcConstants;
import com.google.instrumentation.stats.StatsContext;
import com.google.instrumentation.stats.TagValue;
import io.grpc.CallOptions; import io.grpc.CallOptions;
import io.grpc.Channel; import io.grpc.Channel;
import io.grpc.ClientCall; import io.grpc.ClientCall;
@ -58,9 +55,15 @@ import io.grpc.ServerServiceDefinition;
import io.grpc.ServerStreamTracer; import io.grpc.ServerStreamTracer;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.internal.testing.StatsTestUtils; import io.grpc.internal.testing.StatsTestUtils;
import io.grpc.internal.testing.StatsTestUtils.FakeStatsContextFactory; import io.grpc.internal.testing.StatsTestUtils.FakeStatsRecorder;
import io.grpc.internal.testing.StatsTestUtils.FakeTagContextBinarySerializer;
import io.grpc.internal.testing.StatsTestUtils.FakeTagger;
import io.grpc.internal.testing.StatsTestUtils.MockableSpan; import io.grpc.internal.testing.StatsTestUtils.MockableSpan;
import io.grpc.testing.GrpcServerRule; import io.grpc.testing.GrpcServerRule;
import io.opencensus.contrib.grpc.metrics.RpcMeasureConstants;
import io.opencensus.tags.TagContext;
import io.opencensus.tags.TagValue;
import io.opencensus.tags.Tags;
import io.opencensus.trace.EndSpanOptions; import io.opencensus.trace.EndSpanOptions;
import io.opencensus.trace.NetworkEvent; import io.opencensus.trace.NetworkEvent;
import io.opencensus.trace.NetworkEvent.Type; import io.opencensus.trace.NetworkEvent.Type;
@ -71,7 +74,6 @@ import io.opencensus.trace.Tracer;
import io.opencensus.trace.propagation.BinaryFormat; import io.opencensus.trace.propagation.BinaryFormat;
import io.opencensus.trace.propagation.SpanContextParseException; import io.opencensus.trace.propagation.SpanContextParseException;
import io.opencensus.trace.unsafe.ContextUtils; import io.opencensus.trace.unsafe.ContextUtils;
import java.io.ByteArrayInputStream;
import java.io.InputStream; import java.io.InputStream;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
@ -137,7 +139,10 @@ public class CensusModulesTest {
method.toBuilder().setSampledToLocalTracing(true).build(); method.toBuilder().setSampledToLocalTracing(true).build();
private final FakeClock fakeClock = new FakeClock(); private final FakeClock fakeClock = new FakeClock();
private final FakeStatsContextFactory statsCtxFactory = new FakeStatsContextFactory(); private final FakeTagger tagger = new FakeTagger();
private final FakeTagContextBinarySerializer tagCtxSerializer =
new FakeTagContextBinarySerializer();
private final FakeStatsRecorder statsRecorder = new FakeStatsRecorder();
private final Random random = new Random(1234); private final Random random = new Random(1234);
private final Span fakeClientParentSpan = MockableSpan.generateRandomSpan(random); private final Span fakeClientParentSpan = MockableSpan.generateRandomSpan(random);
private final Span spyClientSpan = spy(MockableSpan.generateRandomSpan(random)); private final Span spyClientSpan = spy(MockableSpan.generateRandomSpan(random));
@ -185,13 +190,14 @@ public class CensusModulesTest {
when(mockTracingPropagationHandler.fromByteArray(any(byte[].class))) when(mockTracingPropagationHandler.fromByteArray(any(byte[].class)))
.thenReturn(fakeClientSpanContext); .thenReturn(fakeClientSpanContext);
censusStats = censusStats =
new CensusStatsModule(statsCtxFactory, fakeClock.getStopwatchSupplier(), true, true); new CensusStatsModule(
tagger, tagCtxSerializer, statsRecorder, fakeClock.getStopwatchSupplier(), true);
censusTracing = new CensusTracingModule(tracer, mockTracingPropagationHandler); censusTracing = new CensusTracingModule(tracer, mockTracingPropagationHandler);
} }
@After @After
public void wrapUp() { public void wrapUp() {
assertNull(statsCtxFactory.pollRecord()); assertNull(statsRecorder.pollRecord());
} }
@Test @Test
@ -204,7 +210,7 @@ public class CensusModulesTest {
testClientInterceptors(true); testClientInterceptors(true);
} }
// Test that Census ClientInterceptors uses the StatsContext and Span out of the current Context // Test that Census ClientInterceptors uses the TagContext and Span out of the current Context
// to create the ClientCallTracer, and that it intercepts ClientCall.Listener.onClose() to call // to create the ClientCallTracer, and that it intercepts ClientCall.Listener.onClose() to call
// ClientCallTracer.callEnded(). // ClientCallTracer.callEnded().
private void testClientInterceptors(boolean nonDefaultContext) { private void testClientInterceptors(boolean nonDefaultContext) {
@ -234,14 +240,14 @@ public class CensusModulesTest {
Channel interceptedChannel = Channel interceptedChannel =
ClientInterceptors.intercept( ClientInterceptors.intercept(
grpcServerRule.getChannel(), callOptionsCaptureInterceptor, grpcServerRule.getChannel(), callOptionsCaptureInterceptor,
censusStats.getClientInterceptor(), censusTracing.getClientInterceptor()); censusStats.getClientInterceptor(true), censusTracing.getClientInterceptor());
ClientCall<String, String> call; ClientCall<String, String> call;
if (nonDefaultContext) { if (nonDefaultContext) {
Context ctx = Context ctx =
Context.ROOT.withValues( Context.ROOT.withValues(
STATS_CONTEXT_KEY, TAG_CONTEXT_KEY,
statsCtxFactory.getDefault().with( tagger.emptyBuilder().put(
StatsTestUtils.EXTRA_TAG, TagValue.create("extra value")), StatsTestUtils.EXTRA_TAG, TagValue.create("extra value")).build(),
ContextUtils.CONTEXT_SPAN_KEY, ContextUtils.CONTEXT_SPAN_KEY,
fakeClientParentSpan); fakeClientParentSpan);
Context origCtx = ctx.attach(); Context origCtx = ctx.attach();
@ -251,7 +257,7 @@ public class CensusModulesTest {
ctx.detach(origCtx); ctx.detach(origCtx);
} }
} else { } else {
assertNull(STATS_CONTEXT_KEY.get()); assertEquals(Tags.getTagger().empty(), TAG_CONTEXT_KEY.get());
assertNull(ContextUtils.CONTEXT_SPAN_KEY.get()); assertNull(ContextUtils.CONTEXT_SPAN_KEY.get());
call = interceptedChannel.newCall(method, CALL_OPTIONS); call = interceptedChannel.newCall(method, CALL_OPTIONS);
} }
@ -269,7 +275,7 @@ public class CensusModulesTest {
// Make the call // Make the call
Metadata headers = new Metadata(); Metadata headers = new Metadata();
call.start(mockClientCallListener, headers); call.start(mockClientCallListener, headers);
assertNull(statsCtxFactory.pollRecord()); assertNull(statsRecorder.pollRecord());
if (nonDefaultContext) { if (nonDefaultContext) {
verify(tracer).spanBuilderWithExplicitParent( verify(tracer).spanBuilderWithExplicitParent(
eq("Sent.package1.service2.method3"), same(fakeClientParentSpan)); eq("Sent.package1.service2.method3"), same(fakeClientParentSpan));
@ -291,15 +297,15 @@ public class CensusModulesTest {
assertEquals("No you don't", status.getDescription()); assertEquals("No you don't", status.getDescription());
// The intercepting listener calls callEnded() on ClientCallTracer, which records to Census. // The intercepting listener calls callEnded() on ClientCallTracer, which records to Census.
StatsTestUtils.MetricsRecord record = statsCtxFactory.pollRecord(); StatsTestUtils.MetricsRecord record = statsRecorder.pollRecord();
assertNotNull(record); assertNotNull(record);
TagValue methodTag = record.tags.get(RpcConstants.RPC_CLIENT_METHOD); TagValue methodTag = record.tags.get(RpcMeasureConstants.RPC_METHOD);
assertEquals(method.getFullMethodName(), methodTag.toString()); assertEquals(method.getFullMethodName(), methodTag.asString());
TagValue statusTag = record.tags.get(RpcConstants.RPC_STATUS); TagValue statusTag = record.tags.get(RpcMeasureConstants.RPC_STATUS);
assertEquals(Status.Code.PERMISSION_DENIED.toString(), statusTag.toString()); assertEquals(Status.Code.PERMISSION_DENIED.toString(), statusTag.asString());
if (nonDefaultContext) { if (nonDefaultContext) {
TagValue extraTag = record.tags.get(StatsTestUtils.EXTRA_TAG); TagValue extraTag = record.tags.get(StatsTestUtils.EXTRA_TAG);
assertEquals("extra value", extraTag.toString()); assertEquals("extra value", extraTag.asString());
} else { } else {
assertNull(record.tags.get(StatsTestUtils.EXTRA_TAG)); assertNull(record.tags.get(StatsTestUtils.EXTRA_TAG));
} }
@ -316,7 +322,7 @@ public class CensusModulesTest {
@Test @Test
public void clientBasicStatsDefaultContext() { public void clientBasicStatsDefaultContext() {
CensusStatsModule.ClientCallTracer callTracer = CensusStatsModule.ClientCallTracer callTracer =
censusStats.newClientCallTracer(statsCtxFactory.getDefault(), method.getFullMethodName()); censusStats.newClientCallTracer(tagger.empty(), method.getFullMethodName(), true);
Metadata headers = new Metadata(); Metadata headers = new Metadata();
ClientStreamTracer tracer = callTracer.newClientStreamTracer(CallOptions.DEFAULT, headers); ClientStreamTracer tracer = callTracer.newClientStreamTracer(CallOptions.DEFAULT, headers);
@ -343,24 +349,27 @@ public class CensusModulesTest {
tracer.streamClosed(Status.OK); tracer.streamClosed(Status.OK);
callTracer.callEnded(Status.OK); callTracer.callEnded(Status.OK);
StatsTestUtils.MetricsRecord record = statsCtxFactory.pollRecord(); StatsTestUtils.MetricsRecord record = statsRecorder.pollRecord();
assertNotNull(record); assertNotNull(record);
assertNoServerContent(record); assertNoServerContent(record);
TagValue methodTag = record.tags.get(RpcConstants.RPC_CLIENT_METHOD); TagValue methodTag = record.tags.get(RpcMeasureConstants.RPC_METHOD);
assertEquals(method.getFullMethodName(), methodTag.toString()); assertEquals(method.getFullMethodName(), methodTag.asString());
TagValue statusTag = record.tags.get(RpcConstants.RPC_STATUS); TagValue statusTag = record.tags.get(RpcMeasureConstants.RPC_STATUS);
assertEquals(Status.Code.OK.toString(), statusTag.toString()); assertEquals(Status.Code.OK.toString(), statusTag.asString());
assertNull(record.getMetric(RpcConstants.RPC_CLIENT_ERROR_COUNT)); assertNull(record.getMetric(RpcMeasureConstants.RPC_CLIENT_ERROR_COUNT));
assertEquals(2, record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_REQUEST_COUNT)); assertEquals(2, record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_CLIENT_REQUEST_COUNT));
assertEquals(1028 + 99, record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_REQUEST_BYTES)); assertEquals(
assertEquals(1128 + 865, 1028 + 99, record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_CLIENT_REQUEST_BYTES));
record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES)); assertEquals(
assertEquals(2, record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_RESPONSE_COUNT)); 1128 + 865,
assertEquals(33 + 154, record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_RESPONSE_BYTES)); record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES));
assertEquals(2, record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_CLIENT_RESPONSE_COUNT));
assertEquals(
33 + 154, record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_CLIENT_RESPONSE_BYTES));
assertEquals(67 + 552, assertEquals(67 + 552,
record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES)); record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES));
assertEquals(30 + 100 + 16 + 24, assertEquals(30 + 100 + 16 + 24,
record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_ROUNDTRIP_LATENCY)); record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_CLIENT_ROUNDTRIP_LATENCY));
} }
@Test @Test
@ -424,29 +433,30 @@ public class CensusModulesTest {
public void clientStreamNeverCreatedStillRecordStats() { public void clientStreamNeverCreatedStillRecordStats() {
CensusStatsModule.ClientCallTracer callTracer = CensusStatsModule.ClientCallTracer callTracer =
censusStats.newClientCallTracer( censusStats.newClientCallTracer(
statsCtxFactory.getDefault(), method.getFullMethodName()); tagger.empty(), method.getFullMethodName(), true);
fakeClock.forwardTime(3000, MILLISECONDS); fakeClock.forwardTime(3000, MILLISECONDS);
callTracer.callEnded(Status.DEADLINE_EXCEEDED.withDescription("3 seconds")); callTracer.callEnded(Status.DEADLINE_EXCEEDED.withDescription("3 seconds"));
StatsTestUtils.MetricsRecord record = statsCtxFactory.pollRecord(); StatsTestUtils.MetricsRecord record = statsRecorder.pollRecord();
assertNotNull(record); assertNotNull(record);
assertNoServerContent(record); assertNoServerContent(record);
TagValue methodTag = record.tags.get(RpcConstants.RPC_CLIENT_METHOD); TagValue methodTag = record.tags.get(RpcMeasureConstants.RPC_METHOD);
assertEquals(method.getFullMethodName(), methodTag.toString()); assertEquals(method.getFullMethodName(), methodTag.asString());
TagValue statusTag = record.tags.get(RpcConstants.RPC_STATUS); TagValue statusTag = record.tags.get(RpcMeasureConstants.RPC_STATUS);
assertEquals(Status.Code.DEADLINE_EXCEEDED.toString(), statusTag.toString()); assertEquals(Status.Code.DEADLINE_EXCEEDED.toString(), statusTag.asString());
assertEquals(1, record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_ERROR_COUNT)); assertEquals(1, record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_CLIENT_ERROR_COUNT));
assertEquals(0, record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_REQUEST_COUNT)); assertEquals(0, record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_CLIENT_REQUEST_COUNT));
assertEquals(0, record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_REQUEST_BYTES)); assertEquals(0, record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_CLIENT_REQUEST_BYTES));
assertEquals(0, assertEquals(0,
record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES)); record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES));
assertEquals(0, record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_RESPONSE_COUNT)); assertEquals(0, record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_CLIENT_RESPONSE_COUNT));
assertEquals(0, record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_RESPONSE_BYTES)); assertEquals(0, record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_CLIENT_RESPONSE_BYTES));
assertEquals(0, assertEquals(0,
record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES)); record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES));
assertEquals(3000, record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_ROUNDTRIP_LATENCY)); assertEquals(
assertNull(record.getMetric(RpcConstants.RPC_CLIENT_SERVER_ELAPSED_TIME)); 3000, record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_CLIENT_ROUNDTRIP_LATENCY));
assertNull(record.getMetric(RpcMeasureConstants.RPC_CLIENT_SERVER_ELAPSED_TIME));
} }
@Test @Test
@ -492,13 +502,18 @@ public class CensusModulesTest {
// EXTRA_TAG is propagated by the FakeStatsContextFactory. Note that not all tags are // EXTRA_TAG is propagated by the FakeStatsContextFactory. Note that not all tags are
// propagated. The StatsContextFactory decides which tags are to propagated. gRPC facilitates // propagated. The StatsContextFactory decides which tags are to propagated. gRPC facilitates
// the propagation by putting them in the headers. // the propagation by putting them in the headers.
StatsContext clientCtx = statsCtxFactory.getDefault().with( TagContext clientCtx = tagger.emptyBuilder().put(
StatsTestUtils.EXTRA_TAG, TagValue.create("extra-tag-value-897")); StatsTestUtils.EXTRA_TAG, TagValue.create("extra-tag-value-897")).build();
CensusStatsModule census = new CensusStatsModule( CensusStatsModule census =
statsCtxFactory, fakeClock.getStopwatchSupplier(), propagate, recordStats); new CensusStatsModule(
tagger,
tagCtxSerializer,
statsRecorder,
fakeClock.getStopwatchSupplier(),
propagate);
Metadata headers = new Metadata(); Metadata headers = new Metadata();
CensusStatsModule.ClientCallTracer callTracer = CensusStatsModule.ClientCallTracer callTracer =
census.newClientCallTracer(clientCtx, method.getFullMethodName()); census.newClientCallTracer(clientCtx, method.getFullMethodName(), recordStats);
// This propagates clientCtx to headers if propagates==true // This propagates clientCtx to headers if propagates==true
callTracer.newClientStreamTracer(CallOptions.DEFAULT, headers); callTracer.newClientStreamTracer(CallOptions.DEFAULT, headers);
if (propagate) { if (propagate) {
@ -509,30 +524,32 @@ public class CensusModulesTest {
} }
ServerStreamTracer serverTracer = ServerStreamTracer serverTracer =
census.getServerTracerFactory().newServerStreamTracer( census.getServerTracerFactory(recordStats).newServerStreamTracer(
method.getFullMethodName(), headers); method.getFullMethodName(), headers);
// Server tracer deserializes clientCtx from the headers, so that it records stats with the // Server tracer deserializes clientCtx from the headers, so that it records stats with the
// propagated tags. // propagated tags.
Context serverContext = serverTracer.filterContext(Context.ROOT); Context serverContext = serverTracer.filterContext(Context.ROOT);
// It also put clientCtx in the Context seen by the call handler // It also put clientCtx in the Context seen by the call handler
assertEquals( assertEquals(
clientCtx.with(RpcConstants.RPC_SERVER_METHOD, TagValue.create(method.getFullMethodName())), tagger.toBuilder(clientCtx).put(
STATS_CONTEXT_KEY.get(serverContext)); RpcMeasureConstants.RPC_METHOD,
TagValue.create(method.getFullMethodName())).build(),
TAG_CONTEXT_KEY.get(serverContext));
// Verifies that the server tracer records the status with the propagated tag // Verifies that the server tracer records the status with the propagated tag
serverTracer.streamClosed(Status.OK); serverTracer.streamClosed(Status.OK);
if (recordStats) { if (recordStats) {
StatsTestUtils.MetricsRecord serverRecord = statsCtxFactory.pollRecord(); StatsTestUtils.MetricsRecord serverRecord = statsRecorder.pollRecord();
assertNotNull(serverRecord); assertNotNull(serverRecord);
assertNoClientContent(serverRecord); assertNoClientContent(serverRecord);
TagValue serverMethodTag = serverRecord.tags.get(RpcConstants.RPC_SERVER_METHOD); TagValue serverMethodTag = serverRecord.tags.get(RpcMeasureConstants.RPC_METHOD);
assertEquals(method.getFullMethodName(), serverMethodTag.toString()); assertEquals(method.getFullMethodName(), serverMethodTag.asString());
TagValue serverStatusTag = serverRecord.tags.get(RpcConstants.RPC_STATUS); TagValue serverStatusTag = serverRecord.tags.get(RpcMeasureConstants.RPC_STATUS);
assertEquals(Status.Code.OK.toString(), serverStatusTag.toString()); assertEquals(Status.Code.OK.toString(), serverStatusTag.asString());
assertNull(serverRecord.getMetric(RpcConstants.RPC_SERVER_ERROR_COUNT)); assertNull(serverRecord.getMetric(RpcMeasureConstants.RPC_SERVER_ERROR_COUNT));
TagValue serverPropagatedTag = serverRecord.tags.get(StatsTestUtils.EXTRA_TAG); TagValue serverPropagatedTag = serverRecord.tags.get(StatsTestUtils.EXTRA_TAG);
assertEquals("extra-tag-value-897", serverPropagatedTag.toString()); assertEquals("extra-tag-value-897", serverPropagatedTag.asString());
} }
// Verifies that the client tracer factory uses clientCtx, which includes the custom tags, to // Verifies that the client tracer factory uses clientCtx, which includes the custom tags, to
@ -540,27 +557,27 @@ public class CensusModulesTest {
callTracer.callEnded(Status.OK); callTracer.callEnded(Status.OK);
if (recordStats) { if (recordStats) {
StatsTestUtils.MetricsRecord clientRecord = statsCtxFactory.pollRecord(); StatsTestUtils.MetricsRecord clientRecord = statsRecorder.pollRecord();
assertNotNull(clientRecord); assertNotNull(clientRecord);
assertNoServerContent(clientRecord); assertNoServerContent(clientRecord);
TagValue clientMethodTag = clientRecord.tags.get(RpcConstants.RPC_CLIENT_METHOD); TagValue clientMethodTag = clientRecord.tags.get(RpcMeasureConstants.RPC_METHOD);
assertEquals(method.getFullMethodName(), clientMethodTag.toString()); assertEquals(method.getFullMethodName(), clientMethodTag.asString());
TagValue clientStatusTag = clientRecord.tags.get(RpcConstants.RPC_STATUS); TagValue clientStatusTag = clientRecord.tags.get(RpcMeasureConstants.RPC_STATUS);
assertEquals(Status.Code.OK.toString(), clientStatusTag.toString()); assertEquals(Status.Code.OK.toString(), clientStatusTag.asString());
assertNull(clientRecord.getMetric(RpcConstants.RPC_CLIENT_ERROR_COUNT)); assertNull(clientRecord.getMetric(RpcMeasureConstants.RPC_CLIENT_ERROR_COUNT));
TagValue clientPropagatedTag = clientRecord.tags.get(StatsTestUtils.EXTRA_TAG); TagValue clientPropagatedTag = clientRecord.tags.get(StatsTestUtils.EXTRA_TAG);
assertEquals("extra-tag-value-897", clientPropagatedTag.toString()); assertEquals("extra-tag-value-897", clientPropagatedTag.asString());
} }
if (!recordStats) { if (!recordStats) {
assertNull(statsCtxFactory.pollRecord()); assertNull(statsRecorder.pollRecord());
} }
} }
@Test @Test
public void statsHeadersNotPropagateDefaultContext() { public void statsHeadersNotPropagateDefaultContext() {
CensusStatsModule.ClientCallTracer callTracer = CensusStatsModule.ClientCallTracer callTracer =
censusStats.newClientCallTracer(statsCtxFactory.getDefault(), method.getFullMethodName()); censusStats.newClientCallTracer(tagger.empty(), method.getFullMethodName(), true);
Metadata headers = new Metadata(); Metadata headers = new Metadata();
callTracer.newClientStreamTracer(CallOptions.DEFAULT, headers); callTracer.newClientStreamTracer(CallOptions.DEFAULT, headers);
assertFalse(headers.containsKey(censusStats.statsHeader)); assertFalse(headers.containsKey(censusStats.statsHeader));
@ -573,7 +590,7 @@ public class CensusModulesTest {
Metadata.Key<byte[]> arbitraryStatsHeader = Metadata.Key<byte[]> arbitraryStatsHeader =
Metadata.Key.of("grpc-tags-bin", Metadata.BINARY_BYTE_MARSHALLER); Metadata.Key.of("grpc-tags-bin", Metadata.BINARY_BYTE_MARSHALLER);
try { try {
statsCtxFactory.deserialize(new ByteArrayInputStream(statsHeaderValue)); tagCtxSerializer.fromByteArray(statsHeaderValue);
fail("Should have thrown"); fail("Should have thrown");
} catch (Exception e) { } catch (Exception e) {
// Expected // Expected
@ -583,7 +600,7 @@ public class CensusModulesTest {
Metadata headers = new Metadata(); Metadata headers = new Metadata();
assertNull(headers.get(censusStats.statsHeader)); assertNull(headers.get(censusStats.statsHeader));
headers.put(arbitraryStatsHeader, statsHeaderValue); headers.put(arbitraryStatsHeader, statsHeaderValue);
assertSame(statsCtxFactory.getDefault(), headers.get(censusStats.statsHeader)); assertSame(tagger.empty(), headers.get(censusStats.statsHeader));
} }
@Test @Test
@ -641,15 +658,19 @@ public class CensusModulesTest {
@Test @Test
public void serverBasicStatsNoHeaders() { public void serverBasicStatsNoHeaders() {
ServerStreamTracer.Factory tracerFactory = censusStats.getServerTracerFactory(); ServerStreamTracer.Factory tracerFactory = censusStats.getServerTracerFactory(true);
ServerStreamTracer tracer = ServerStreamTracer tracer =
tracerFactory.newServerStreamTracer(method.getFullMethodName(), new Metadata()); tracerFactory.newServerStreamTracer(method.getFullMethodName(), new Metadata());
Context filteredContext = tracer.filterContext(Context.ROOT); Context filteredContext = tracer.filterContext(Context.ROOT);
StatsContext statsCtx = STATS_CONTEXT_KEY.get(filteredContext); TagContext statsCtx = TAG_CONTEXT_KEY.get(filteredContext);
assertEquals( assertEquals(
statsCtxFactory.getDefault() tagger
.with(RpcConstants.RPC_SERVER_METHOD, TagValue.create(method.getFullMethodName())), .emptyBuilder()
.put(
RpcMeasureConstants.RPC_METHOD,
TagValue.create(method.getFullMethodName()))
.build(),
statsCtx); statsCtx);
tracer.inboundMessage(0); tracer.inboundMessage(0);
@ -673,24 +694,27 @@ public class CensusModulesTest {
tracer.streamClosed(Status.CANCELLED); tracer.streamClosed(Status.CANCELLED);
StatsTestUtils.MetricsRecord record = statsCtxFactory.pollRecord(); StatsTestUtils.MetricsRecord record = statsRecorder.pollRecord();
assertNotNull(record); assertNotNull(record);
assertNoClientContent(record); assertNoClientContent(record);
TagValue methodTag = record.tags.get(RpcConstants.RPC_SERVER_METHOD); TagValue methodTag = record.tags.get(RpcMeasureConstants.RPC_METHOD);
assertEquals(method.getFullMethodName(), methodTag.toString()); assertEquals(method.getFullMethodName(), methodTag.asString());
TagValue statusTag = record.tags.get(RpcConstants.RPC_STATUS); TagValue statusTag = record.tags.get(RpcMeasureConstants.RPC_STATUS);
assertEquals(Status.Code.CANCELLED.toString(), statusTag.toString()); assertEquals(Status.Code.CANCELLED.toString(), statusTag.asString());
assertEquals(1, record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_ERROR_COUNT)); assertEquals(1, record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_SERVER_ERROR_COUNT));
assertEquals(2, record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_RESPONSE_COUNT)); assertEquals(2, record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_SERVER_RESPONSE_COUNT));
assertEquals(1028 + 99, record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_RESPONSE_BYTES)); assertEquals(
assertEquals(1128 + 865, 1028 + 99, record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_SERVER_RESPONSE_BYTES));
record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES)); assertEquals(
assertEquals(2, record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_REQUEST_COUNT)); 1128 + 865,
assertEquals(34 + 154, record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_REQUEST_BYTES)); record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES));
assertEquals(2, record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_SERVER_REQUEST_COUNT));
assertEquals(
34 + 154, record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_SERVER_REQUEST_BYTES));
assertEquals(67 + 552, assertEquals(67 + 552,
record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_UNCOMPRESSED_REQUEST_BYTES)); record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_SERVER_UNCOMPRESSED_REQUEST_BYTES));
assertEquals(100 + 16 + 24, assertEquals(100 + 16 + 24,
record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_SERVER_LATENCY)); record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_SERVER_SERVER_LATENCY));
} }
@Test @Test
@ -803,27 +827,27 @@ public class CensusModulesTest {
} }
private static void assertNoServerContent(StatsTestUtils.MetricsRecord record) { private static void assertNoServerContent(StatsTestUtils.MetricsRecord record) {
assertNull(record.getMetric(RpcConstants.RPC_SERVER_ERROR_COUNT)); assertNull(record.getMetric(RpcMeasureConstants.RPC_SERVER_ERROR_COUNT));
assertNull(record.getMetric(RpcConstants.RPC_SERVER_REQUEST_COUNT)); assertNull(record.getMetric(RpcMeasureConstants.RPC_SERVER_REQUEST_COUNT));
assertNull(record.getMetric(RpcConstants.RPC_SERVER_RESPONSE_COUNT)); assertNull(record.getMetric(RpcMeasureConstants.RPC_SERVER_RESPONSE_COUNT));
assertNull(record.getMetric(RpcConstants.RPC_SERVER_REQUEST_BYTES)); assertNull(record.getMetric(RpcMeasureConstants.RPC_SERVER_REQUEST_BYTES));
assertNull(record.getMetric(RpcConstants.RPC_SERVER_RESPONSE_BYTES)); assertNull(record.getMetric(RpcMeasureConstants.RPC_SERVER_RESPONSE_BYTES));
assertNull(record.getMetric(RpcConstants.RPC_SERVER_SERVER_ELAPSED_TIME)); assertNull(record.getMetric(RpcMeasureConstants.RPC_SERVER_SERVER_ELAPSED_TIME));
assertNull(record.getMetric(RpcConstants.RPC_SERVER_SERVER_LATENCY)); assertNull(record.getMetric(RpcMeasureConstants.RPC_SERVER_SERVER_LATENCY));
assertNull(record.getMetric(RpcConstants.RPC_SERVER_UNCOMPRESSED_REQUEST_BYTES)); assertNull(record.getMetric(RpcMeasureConstants.RPC_SERVER_UNCOMPRESSED_REQUEST_BYTES));
assertNull(record.getMetric(RpcConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES)); assertNull(record.getMetric(RpcMeasureConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES));
} }
private static void assertNoClientContent(StatsTestUtils.MetricsRecord record) { private static void assertNoClientContent(StatsTestUtils.MetricsRecord record) {
assertNull(record.getMetric(RpcConstants.RPC_CLIENT_ERROR_COUNT)); assertNull(record.getMetric(RpcMeasureConstants.RPC_CLIENT_ERROR_COUNT));
assertNull(record.getMetric(RpcConstants.RPC_CLIENT_REQUEST_COUNT)); assertNull(record.getMetric(RpcMeasureConstants.RPC_CLIENT_REQUEST_COUNT));
assertNull(record.getMetric(RpcConstants.RPC_CLIENT_RESPONSE_COUNT)); assertNull(record.getMetric(RpcMeasureConstants.RPC_CLIENT_RESPONSE_COUNT));
assertNull(record.getMetric(RpcConstants.RPC_CLIENT_REQUEST_BYTES)); assertNull(record.getMetric(RpcMeasureConstants.RPC_CLIENT_REQUEST_BYTES));
assertNull(record.getMetric(RpcConstants.RPC_CLIENT_RESPONSE_BYTES)); assertNull(record.getMetric(RpcMeasureConstants.RPC_CLIENT_RESPONSE_BYTES));
assertNull(record.getMetric(RpcConstants.RPC_CLIENT_ROUNDTRIP_LATENCY)); assertNull(record.getMetric(RpcMeasureConstants.RPC_CLIENT_ROUNDTRIP_LATENCY));
assertNull(record.getMetric(RpcConstants.RPC_CLIENT_SERVER_ELAPSED_TIME)); assertNull(record.getMetric(RpcMeasureConstants.RPC_CLIENT_SERVER_ELAPSED_TIME));
assertNull(record.getMetric(RpcConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES)); assertNull(record.getMetric(RpcMeasureConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES));
assertNull(record.getMetric(RpcConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES)); assertNull(record.getMetric(RpcMeasureConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES));
} }
private static class FakeServerCall<ReqT, RespT> extends ServerCall<ReqT, RespT> { private static class FakeServerCall<ReqT, RespT> extends ServerCall<ReqT, RespT> {

View File

@ -17,9 +17,9 @@
package io.grpc.testing.integration; package io.grpc.testing.integration;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static com.google.instrumentation.stats.ContextUtils.STATS_CONTEXT_KEY;
import static io.grpc.stub.ClientCalls.blockingServerStreamingCall; import static io.grpc.stub.ClientCalls.blockingServerStreamingCall;
import static io.grpc.testing.integration.Messages.PayloadType.COMPRESSABLE; import static io.grpc.testing.integration.Messages.PayloadType.COMPRESSABLE;
import static io.opencensus.tags.unsafe.ContextUtils.TAG_CONTEXT_KEY;
import static io.opencensus.trace.unsafe.ContextUtils.CONTEXT_SPAN_KEY; import static io.opencensus.trace.unsafe.ContextUtils.CONTEXT_SPAN_KEY;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
@ -42,10 +42,6 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.io.ByteStreams; import com.google.common.io.ByteStreams;
import com.google.common.util.concurrent.SettableFuture; import com.google.common.util.concurrent.SettableFuture;
import com.google.instrumentation.stats.RpcConstants;
import com.google.instrumentation.stats.StatsContextFactory;
import com.google.instrumentation.stats.TagKey;
import com.google.instrumentation.stats.TagValue;
import com.google.protobuf.BoolValue; import com.google.protobuf.BoolValue;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import com.google.protobuf.EmptyProtos.Empty; import com.google.protobuf.EmptyProtos.Empty;
@ -71,12 +67,15 @@ import io.grpc.Status;
import io.grpc.StatusRuntimeException; import io.grpc.StatusRuntimeException;
import io.grpc.auth.MoreCallCredentials; import io.grpc.auth.MoreCallCredentials;
import io.grpc.internal.AbstractServerImplBuilder; import io.grpc.internal.AbstractServerImplBuilder;
import io.grpc.internal.CensusStatsModule;
import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil;
import io.grpc.internal.testing.StatsTestUtils.FakeStatsContext; import io.grpc.internal.testing.StatsTestUtils;
import io.grpc.internal.testing.StatsTestUtils.FakeStatsContextFactory; import io.grpc.internal.testing.StatsTestUtils.FakeStatsRecorder;
import io.grpc.internal.testing.StatsTestUtils.FakeTagContext;
import io.grpc.internal.testing.StatsTestUtils.FakeTagContextBinarySerializer;
import io.grpc.internal.testing.StatsTestUtils.FakeTagger;
import io.grpc.internal.testing.StatsTestUtils.MetricsRecord; import io.grpc.internal.testing.StatsTestUtils.MetricsRecord;
import io.grpc.internal.testing.StatsTestUtils.MockableSpan; import io.grpc.internal.testing.StatsTestUtils.MockableSpan;
import io.grpc.internal.testing.StatsTestUtils;
import io.grpc.internal.testing.StreamRecorder; import io.grpc.internal.testing.StreamRecorder;
import io.grpc.internal.testing.TestClientStreamTracer; import io.grpc.internal.testing.TestClientStreamTracer;
import io.grpc.internal.testing.TestServerStreamTracer; import io.grpc.internal.testing.TestServerStreamTracer;
@ -97,6 +96,9 @@ import io.grpc.testing.integration.Messages.StreamingInputCallRequest;
import io.grpc.testing.integration.Messages.StreamingInputCallResponse; import io.grpc.testing.integration.Messages.StreamingInputCallResponse;
import io.grpc.testing.integration.Messages.StreamingOutputCallRequest; import io.grpc.testing.integration.Messages.StreamingOutputCallRequest;
import io.grpc.testing.integration.Messages.StreamingOutputCallResponse; import io.grpc.testing.integration.Messages.StreamingOutputCallResponse;
import io.opencensus.contrib.grpc.metrics.RpcMeasureConstants;
import io.opencensus.tags.TagKey;
import io.opencensus.tags.TagValue;
import io.opencensus.trace.Span; import io.opencensus.trace.Span;
import io.opencensus.trace.unsafe.ContextUtils; import io.opencensus.trace.unsafe.ContextUtils;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
@ -153,10 +155,11 @@ public abstract class AbstractInteropTest {
new AtomicReference<Context>(); new AtomicReference<Context>();
private static ScheduledExecutorService testServiceExecutor; private static ScheduledExecutorService testServiceExecutor;
private static Server server; private static Server server;
private static final FakeStatsContextFactory clientStatsCtxFactory = private static final FakeTagger tagger = new FakeTagger();
new FakeStatsContextFactory(); private static final FakeTagContextBinarySerializer tagContextBinarySerializer =
private static final FakeStatsContextFactory serverStatsCtxFactory = new FakeTagContextBinarySerializer();
new FakeStatsContextFactory(); private static final FakeStatsRecorder clientStatsRecorder = new FakeStatsRecorder();
private static final FakeStatsRecorder serverStatsRecorder = new FakeStatsRecorder();
private static final LinkedBlockingQueue<ServerStreamTracerInfo> serverStreamTracers = private static final LinkedBlockingQueue<ServerStreamTracerInfo> serverStreamTracers =
new LinkedBlockingQueue<ServerStreamTracerInfo>(); new LinkedBlockingQueue<ServerStreamTracerInfo>();
@ -212,7 +215,14 @@ public abstract class AbstractInteropTest {
new TestServiceImpl(testServiceExecutor), new TestServiceImpl(testServiceExecutor),
allInterceptors)) allInterceptors))
.addStreamTracerFactory(serverStreamTracerFactory); .addStreamTracerFactory(serverStreamTracerFactory);
io.grpc.internal.TestingAccessor.setStatsContextFactory(builder, serverStatsCtxFactory); io.grpc.internal.TestingAccessor.setStatsImplementation(
builder,
new CensusStatsModule(
tagger,
tagContextBinarySerializer,
serverStatsRecorder,
GrpcUtil.STOPWATCH_SUPPLIER,
true));
try { try {
server = builder.build().start(); server = builder.build().start();
} catch (IOException ex) { } catch (IOException ex) {
@ -266,8 +276,8 @@ public abstract class AbstractInteropTest {
TestServiceGrpc.newBlockingStub(channel).withInterceptors(tracerSetupInterceptor); TestServiceGrpc.newBlockingStub(channel).withInterceptors(tracerSetupInterceptor);
asyncStub = TestServiceGrpc.newStub(channel).withInterceptors(tracerSetupInterceptor); asyncStub = TestServiceGrpc.newStub(channel).withInterceptors(tracerSetupInterceptor);
requestHeadersCapture.set(null); requestHeadersCapture.set(null);
clientStatsCtxFactory.rolloverRecords(); clientStatsRecorder.rolloverRecords();
serverStatsCtxFactory.rolloverRecords(); serverStatsRecorder.rolloverRecords();
serverStreamTracers.clear(); serverStreamTracers.clear();
} }
@ -281,8 +291,9 @@ public abstract class AbstractInteropTest {
protected abstract ManagedChannel createChannel(); protected abstract ManagedChannel createChannel();
protected final StatsContextFactory getClientStatsFactory() { protected final CensusStatsModule createClientCensusStatsModule() {
return clientStatsCtxFactory; return new CensusStatsModule(
tagger, tagContextBinarySerializer, clientStatsRecorder, GrpcUtil.STOPWATCH_SUPPLIER, true);
} }
/** /**
@ -711,10 +722,9 @@ public abstract class AbstractInteropTest {
// CensusStreamTracerModule record final status in the interceptor, thus is guaranteed to be // CensusStreamTracerModule record final status in the interceptor, thus is guaranteed to be
// recorded. The tracer stats rely on the stream being created, which is not always the case // recorded. The tracer stats rely on the stream being created, which is not always the case
// in this test. Therefore we don't check the tracer stats. // in this test. Therefore we don't check the tracer stats.
MetricsRecord clientRecord = clientStatsCtxFactory.pollRecord(5, TimeUnit.SECONDS); MetricsRecord clientRecord = clientStatsRecorder.pollRecord(5, TimeUnit.SECONDS);
checkTags( checkTags(
clientRecord, false, "grpc.testing.TestService/StreamingInputCall", clientRecord, "grpc.testing.TestService/StreamingInputCall", Status.CANCELLED.getCode());
Status.CANCELLED.getCode());
// Do not check server-side metrics, because the status on the server side is undetermined. // Do not check server-side metrics, because the status on the server side is undetermined.
} }
} }
@ -1034,9 +1044,10 @@ public abstract class AbstractInteropTest {
if (metricsExpected()) { if (metricsExpected()) {
// Stream may not have been created before deadline is exceeded, thus we don't test the tracer // Stream may not have been created before deadline is exceeded, thus we don't test the tracer
// stats. // stats.
MetricsRecord clientRecord = clientStatsCtxFactory.pollRecord(5, TimeUnit.SECONDS); MetricsRecord clientRecord = clientStatsRecorder.pollRecord(5, TimeUnit.SECONDS);
checkTags( checkTags(
clientRecord, false, "grpc.testing.TestService/StreamingOutputCall", clientRecord,
"grpc.testing.TestService/StreamingOutputCall",
Status.Code.DEADLINE_EXCEEDED); Status.Code.DEADLINE_EXCEEDED);
// Do not check server-side metrics, because the status on the server side is undetermined. // Do not check server-side metrics, because the status on the server side is undetermined.
} }
@ -1067,9 +1078,10 @@ public abstract class AbstractInteropTest {
if (metricsExpected()) { if (metricsExpected()) {
// Stream may not have been created when deadline is exceeded, thus we don't check tracer // Stream may not have been created when deadline is exceeded, thus we don't check tracer
// stats. // stats.
MetricsRecord clientRecord = clientStatsCtxFactory.pollRecord(5, TimeUnit.SECONDS); MetricsRecord clientRecord = clientStatsRecorder.pollRecord(5, TimeUnit.SECONDS);
checkTags( checkTags(
clientRecord, false, "grpc.testing.TestService/StreamingOutputCall", clientRecord,
"grpc.testing.TestService/StreamingOutputCall",
Status.Code.DEADLINE_EXCEEDED); Status.Code.DEADLINE_EXCEEDED);
// Do not check server-side metrics, because the status on the server side is undetermined. // Do not check server-side metrics, because the status on the server side is undetermined.
} }
@ -1091,10 +1103,9 @@ public abstract class AbstractInteropTest {
// recorded. The tracer stats rely on the stream being created, which is not the case if // recorded. The tracer stats rely on the stream being created, which is not the case if
// deadline is exceeded before the call is created. Therefore we don't check the tracer stats. // deadline is exceeded before the call is created. Therefore we don't check the tracer stats.
if (metricsExpected()) { if (metricsExpected()) {
MetricsRecord clientRecord = clientStatsCtxFactory.pollRecord(5, TimeUnit.SECONDS); MetricsRecord clientRecord = clientStatsRecorder.pollRecord(5, TimeUnit.SECONDS);
checkTags( checkTags(
clientRecord, false, "grpc.testing.TestService/EmptyCall", clientRecord, "grpc.testing.TestService/EmptyCall", Status.DEADLINE_EXCEEDED.getCode());
Status.DEADLINE_EXCEEDED.getCode());
} }
// warm up the channel // warm up the channel
@ -1109,10 +1120,9 @@ public abstract class AbstractInteropTest {
} }
assertStatsTrace("grpc.testing.TestService/EmptyCall", Status.Code.OK); assertStatsTrace("grpc.testing.TestService/EmptyCall", Status.Code.OK);
if (metricsExpected()) { if (metricsExpected()) {
MetricsRecord clientRecord = clientStatsCtxFactory.pollRecord(5, TimeUnit.SECONDS); MetricsRecord clientRecord = clientStatsRecorder.pollRecord(5, TimeUnit.SECONDS);
checkTags( checkTags(
clientRecord, false, "grpc.testing.TestService/EmptyCall", clientRecord, "grpc.testing.TestService/EmptyCall", Status.DEADLINE_EXCEEDED.getCode());
Status.DEADLINE_EXCEEDED.getCode());
} }
} }
@ -1397,9 +1407,9 @@ public abstract class AbstractInteropTest {
Span clientParentSpan = MockableSpan.generateRandomSpan(new Random()); Span clientParentSpan = MockableSpan.generateRandomSpan(new Random());
Context ctx = Context ctx =
Context.ROOT.withValues( Context.ROOT.withValues(
STATS_CONTEXT_KEY, TAG_CONTEXT_KEY,
clientStatsCtxFactory.getDefault().with( tagger.emptyBuilder().put(
StatsTestUtils.EXTRA_TAG, TagValue.create("extra value")), StatsTestUtils.EXTRA_TAG, TagValue.create("extra value")).build(),
ContextUtils.CONTEXT_SPAN_KEY, ContextUtils.CONTEXT_SPAN_KEY,
clientParentSpan); clientParentSpan);
Context origCtx = ctx.attach(); Context origCtx = ctx.attach();
@ -1408,7 +1418,7 @@ public abstract class AbstractInteropTest {
Context serverCtx = contextCapture.get(); Context serverCtx = contextCapture.get();
assertNotNull(serverCtx); assertNotNull(serverCtx);
FakeStatsContext statsCtx = (FakeStatsContext) STATS_CONTEXT_KEY.get(serverCtx); FakeTagContext statsCtx = (FakeTagContext) TAG_CONTEXT_KEY.get(serverCtx);
assertNotNull(statsCtx); assertNotNull(statsCtx);
Map<TagKey, TagValue> tags = statsCtx.getTags(); Map<TagKey, TagValue> tags = statsCtx.getTags();
boolean tagFound = false; boolean tagFound = false;
@ -1530,9 +1540,10 @@ public abstract class AbstractInteropTest {
// CensusStreamTracerModule record final status in the interceptor, thus is guaranteed to be // CensusStreamTracerModule record final status in the interceptor, thus is guaranteed to be
// recorded. The tracer stats rely on the stream being created, which is not always the case // recorded. The tracer stats rely on the stream being created, which is not always the case
// in this test, thus we will not check that. // in this test, thus we will not check that.
MetricsRecord clientRecord = clientStatsCtxFactory.pollRecord(5, TimeUnit.SECONDS); MetricsRecord clientRecord = clientStatsRecorder.pollRecord(5, TimeUnit.SECONDS);
checkTags( checkTags(
clientRecord, false, "grpc.testing.TestService/FullDuplexCall", clientRecord,
"grpc.testing.TestService/FullDuplexCall",
Status.DEADLINE_EXCEEDED.getCode()); Status.DEADLINE_EXCEEDED.getCode());
} }
} }
@ -1792,8 +1803,8 @@ public abstract class AbstractInteropTest {
if (metricsExpected()) { if (metricsExpected()) {
// CensusStreamTracerModule records final status in interceptor, which is guaranteed to be // CensusStreamTracerModule records final status in interceptor, which is guaranteed to be
// done before application receives status. // done before application receives status.
MetricsRecord clientRecord = clientStatsCtxFactory.pollRecord(); MetricsRecord clientRecord = clientStatsRecorder.pollRecord();
checkTags(clientRecord, false, method, code); checkTags(clientRecord, method, code);
if (requests != null && responses != null) { if (requests != null && responses != null) {
checkCensus(clientRecord, false, requests, responses); checkCensus(clientRecord, false, requests, responses);
@ -1824,7 +1835,7 @@ public abstract class AbstractInteropTest {
try { try {
// On the server, the stats is finalized in ServerStreamListener.closed(), which can be // On the server, the stats is finalized in ServerStreamListener.closed(), which can be
// run after the client receives the final status. So we use a timeout. // run after the client receives the final status. So we use a timeout.
serverRecord = serverStatsCtxFactory.pollRecord(5, TimeUnit.SECONDS); serverRecord = serverStatsRecorder.pollRecord(5, TimeUnit.SECONDS);
} catch (InterruptedException e) { } catch (InterruptedException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@ -1832,7 +1843,7 @@ public abstract class AbstractInteropTest {
break; break;
} }
try { try {
checkTags(serverRecord, true, method, code); checkTags(serverRecord, method, code);
if (requests != null && responses != null) { if (requests != null && responses != null) {
checkCensus(serverRecord, true, requests, responses); checkCensus(serverRecord, true, requests, responses);
} }
@ -1889,15 +1900,14 @@ public abstract class AbstractInteropTest {
} }
private static void checkTags( private static void checkTags(
MetricsRecord record, boolean server, String methodName, Status.Code status) { MetricsRecord record, String methodName, Status.Code status) {
assertNotNull("record is not null", record); assertNotNull("record is not null", record);
TagValue methodNameTag = record.tags.get( TagValue methodNameTag = record.tags.get(RpcMeasureConstants.RPC_METHOD);
server ? RpcConstants.RPC_SERVER_METHOD : RpcConstants.RPC_CLIENT_METHOD);
assertNotNull("method name tagged", methodNameTag); assertNotNull("method name tagged", methodNameTag);
assertEquals("method names match", methodName, methodNameTag.toString()); assertEquals("method names match", methodName, methodNameTag.asString());
TagValue statusTag = record.tags.get(RpcConstants.RPC_STATUS); TagValue statusTag = record.tags.get(RpcMeasureConstants.RPC_STATUS);
assertNotNull("status tagged", statusTag); assertNotNull("status tagged", statusTag);
assertEquals(status.toString(), statusTag.toString()); assertEquals(status.toString(), statusTag.asString());
} }
/** /**
@ -1950,33 +1960,41 @@ public abstract class AbstractInteropTest {
} }
if (server && serverInProcess()) { if (server && serverInProcess()) {
assertEquals( assertEquals(
requests.size(), record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_REQUEST_COUNT)); requests.size(),
record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_SERVER_REQUEST_COUNT));
assertEquals( assertEquals(
responses.size(), record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_RESPONSE_COUNT)); responses.size(),
assertEquals(uncompressedRequestsSize, record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_SERVER_RESPONSE_COUNT));
record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_UNCOMPRESSED_REQUEST_BYTES)); assertEquals(
assertEquals(uncompressedResponsesSize, uncompressedRequestsSize,
record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES)); record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_SERVER_UNCOMPRESSED_REQUEST_BYTES));
assertNotNull(record.getMetric(RpcConstants.RPC_SERVER_SERVER_LATENCY)); assertEquals(
uncompressedResponsesSize,
record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES));
assertNotNull(record.getMetric(RpcMeasureConstants.RPC_SERVER_SERVER_LATENCY));
// It's impossible to get the expected wire sizes because it may be compressed, so we just // It's impossible to get the expected wire sizes because it may be compressed, so we just
// check if they are recorded. // check if they are recorded.
assertNotNull(record.getMetric(RpcConstants.RPC_SERVER_REQUEST_BYTES)); assertNotNull(record.getMetric(RpcMeasureConstants.RPC_SERVER_REQUEST_BYTES));
assertNotNull(record.getMetric(RpcConstants.RPC_SERVER_RESPONSE_BYTES)); assertNotNull(record.getMetric(RpcMeasureConstants.RPC_SERVER_RESPONSE_BYTES));
} }
if (!server) { if (!server) {
assertEquals( assertEquals(
requests.size(), record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_REQUEST_COUNT)); requests.size(),
record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_CLIENT_REQUEST_COUNT));
assertEquals( assertEquals(
responses.size(), record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_RESPONSE_COUNT)); responses.size(),
assertEquals(uncompressedRequestsSize, record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_CLIENT_RESPONSE_COUNT));
record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES)); assertEquals(
assertEquals(uncompressedResponsesSize, uncompressedRequestsSize,
record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES)); record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES));
assertNotNull(record.getMetric(RpcConstants.RPC_CLIENT_ROUNDTRIP_LATENCY)); assertEquals(
uncompressedResponsesSize,
record.getMetricAsLongOrFail(RpcMeasureConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES));
assertNotNull(record.getMetric(RpcMeasureConstants.RPC_CLIENT_ROUNDTRIP_LATENCY));
// It's impossible to get the expected wire sizes because it may be compressed, so we just // It's impossible to get the expected wire sizes because it may be compressed, so we just
// check if they are recorded. // check if they are recorded.
assertNotNull(record.getMetric(RpcConstants.RPC_CLIENT_REQUEST_BYTES)); assertNotNull(record.getMetric(RpcMeasureConstants.RPC_CLIENT_REQUEST_BYTES));
assertNotNull(record.getMetric(RpcConstants.RPC_CLIENT_RESPONSE_BYTES)); assertNotNull(record.getMetric(RpcMeasureConstants.RPC_CLIENT_RESPONSE_BYTES));
} }
} }

View File

@ -365,7 +365,8 @@ public class TestServiceClient {
} }
builder = okBuilder; builder = okBuilder;
} }
io.grpc.internal.TestingAccessor.setStatsContextFactory(builder, getClientStatsFactory()); io.grpc.internal.TestingAccessor.setStatsImplementation(
builder, createClientCensusStatsModule());
return builder.build(); return builder.build();
} }

View File

@ -48,7 +48,8 @@ public class AutoWindowSizingOnTest extends AbstractInteropTest {
NettyChannelBuilder builder = NettyChannelBuilder.forAddress("localhost", getPort()) NettyChannelBuilder builder = NettyChannelBuilder.forAddress("localhost", getPort())
.negotiationType(NegotiationType.PLAINTEXT) .negotiationType(NegotiationType.PLAINTEXT)
.maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE); .maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE);
io.grpc.internal.TestingAccessor.setStatsContextFactory(builder, getClientStatsFactory()); io.grpc.internal.TestingAccessor.setStatsImplementation(
builder, createClientCensusStatsModule());
return builder.build(); return builder.build();
} }
} }

View File

@ -59,7 +59,8 @@ public class Http2NettyLocalChannelTest extends AbstractInteropTest {
.channelType(LocalChannel.class) .channelType(LocalChannel.class)
.flowControlWindow(65 * 1024) .flowControlWindow(65 * 1024)
.maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE); .maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE);
io.grpc.internal.TestingAccessor.setStatsContextFactory(builder, getClientStatsFactory()); io.grpc.internal.TestingAccessor.setStatsImplementation(
builder, createClientCensusStatsModule());
return builder.build(); return builder.build();
} }
} }

View File

@ -80,7 +80,8 @@ public class Http2NettyTest extends AbstractInteropTest {
.ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE) .ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE)
.sslProvider(SslProvider.OPENSSL) .sslProvider(SslProvider.OPENSSL)
.build()); .build());
io.grpc.internal.TestingAccessor.setStatsContextFactory(builder, getClientStatsFactory()); io.grpc.internal.TestingAccessor.setStatsImplementation(
builder, createClientCensusStatsModule());
return builder.build(); return builder.build();
} catch (Exception ex) { } catch (Exception ex) {
throw new RuntimeException(ex); throw new RuntimeException(ex);

View File

@ -105,7 +105,8 @@ public class Http2OkHttpTest extends AbstractInteropTest {
.build()) .build())
.overrideAuthority(GrpcUtil.authorityFromHostAndPort( .overrideAuthority(GrpcUtil.authorityFromHostAndPort(
TestUtils.TEST_SERVER_HOST, getPort())); TestUtils.TEST_SERVER_HOST, getPort()));
io.grpc.internal.TestingAccessor.setStatsContextFactory(builder, getClientStatsFactory()); io.grpc.internal.TestingAccessor.setStatsImplementation(
builder, createClientCensusStatsModule());
try { try {
builder.sslSocketFactory(TestUtils.newSslSocketFactoryForCa(Platform.get().getProvider(), builder.sslSocketFactory(TestUtils.newSslSocketFactoryForCa(Platform.get().getProvider(),
TestUtils.loadCert("ca.pem"))); TestUtils.loadCert("ca.pem")));

View File

@ -44,7 +44,8 @@ public class InProcessTest extends AbstractInteropTest {
@Override @Override
protected ManagedChannel createChannel() { protected ManagedChannel createChannel() {
InProcessChannelBuilder builder = InProcessChannelBuilder.forName(SERVER_NAME); InProcessChannelBuilder builder = InProcessChannelBuilder.forName(SERVER_NAME);
io.grpc.internal.TestingAccessor.setStatsContextFactory(builder, getClientStatsFactory()); io.grpc.internal.TestingAccessor.setStatsImplementation(
builder, createClientCensusStatsModule());
return builder.build(); return builder.build();
} }

View File

@ -172,7 +172,8 @@ public class TransportCompressionTest extends AbstractInteropTest {
} }
}) })
.usePlaintext(true); .usePlaintext(true);
io.grpc.internal.TestingAccessor.setStatsContextFactory(builder, getClientStatsFactory()); io.grpc.internal.TestingAccessor.setStatsImplementation(
builder, createClientCensusStatsModule());
return builder.build(); return builder.build();
} }

View File

@ -16,26 +16,24 @@
package io.grpc.internal; package io.grpc.internal;
import com.google.instrumentation.stats.StatsContextFactory;
/** /**
* Test helper that allows accessing package-private stuff. * Test helper that allows accessing package-private stuff.
*/ */
public final class TestingAccessor { public final class TestingAccessor {
/** /**
* Sets a custom {@link StatsContextFactory} for tests. * Sets a custom stats implementation for tests.
*/ */
public static void setStatsContextFactory( public static void setStatsImplementation(
AbstractManagedChannelImplBuilder<?> builder, StatsContextFactory factory) { AbstractManagedChannelImplBuilder<?> builder, CensusStatsModule censusStats) {
builder.statsContextFactory(factory); builder.overrideCensusStatsModule(censusStats);
} }
/** /**
* Sets a custom {@link StatsContextFactory} for tests. * Sets a custom stats implementation for tests.
*/ */
public static void setStatsContextFactory( public static void setStatsImplementation(
AbstractServerImplBuilder<?> builder, StatsContextFactory factory) { AbstractServerImplBuilder<?> builder, CensusStatsModule censusStats) {
builder.statsContextFactory(factory); builder.overrideCensusStatsModule(censusStats);
} }
private TestingAccessor() { private TestingAccessor() {

View File

@ -19,15 +19,23 @@ package io.grpc.internal.testing;
import static com.google.common.base.Charsets.UTF_8; import static com.google.common.base.Charsets.UTF_8;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.base.Function;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import com.google.instrumentation.stats.MeasurementDescriptor; import com.google.common.collect.Iterators;
import com.google.instrumentation.stats.MeasurementMap; import com.google.common.collect.Maps;
import com.google.instrumentation.stats.MeasurementValue; import io.opencensus.common.Scope;
import com.google.instrumentation.stats.StatsContext; import io.opencensus.stats.Measure;
import com.google.instrumentation.stats.StatsContextFactory; import io.opencensus.stats.MeasureMap;
import com.google.instrumentation.stats.TagKey; import io.opencensus.stats.StatsRecorder;
import com.google.instrumentation.stats.TagValue; import io.opencensus.tags.Tag;
import io.grpc.internal.IoUtils; import io.opencensus.tags.TagContext;
import io.opencensus.tags.TagContextBuilder;
import io.opencensus.tags.TagKey;
import io.opencensus.tags.TagValue;
import io.opencensus.tags.Tagger;
import io.opencensus.tags.propagation.TagContextBinarySerializer;
import io.opencensus.tags.propagation.TagContextDeserializationException;
import io.opencensus.tags.unsafe.ContextUtils;
import io.opencensus.trace.Annotation; import io.opencensus.trace.Annotation;
import io.opencensus.trace.AttributeValue; import io.opencensus.trace.AttributeValue;
import io.opencensus.trace.EndSpanOptions; import io.opencensus.trace.EndSpanOptions;
@ -40,10 +48,8 @@ import io.opencensus.trace.SpanContext;
import io.opencensus.trace.SpanId; import io.opencensus.trace.SpanId;
import io.opencensus.trace.TraceId; import io.opencensus.trace.TraceId;
import io.opencensus.trace.TraceOptions; import io.opencensus.trace.TraceOptions;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.EnumSet; import java.util.EnumSet;
import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Random; import java.util.Random;
@ -57,10 +63,12 @@ public class StatsTestUtils {
} }
public static class MetricsRecord { public static class MetricsRecord {
public final ImmutableMap<TagKey, TagValue> tags;
public final MeasurementMap metrics;
private MetricsRecord(ImmutableMap<TagKey, TagValue> tags, MeasurementMap metrics) { public final ImmutableMap<TagKey, TagValue> tags;
public final ImmutableMap<Measure, Number> metrics;
private MetricsRecord(
ImmutableMap<TagKey, TagValue> tags, ImmutableMap<Measure, Number> metrics) {
this.tags = tags; this.tags = tags;
this.metrics = metrics; this.metrics = metrics;
} }
@ -69,10 +77,16 @@ public class StatsTestUtils {
* Returns the value of a metric, or {@code null} if not found. * Returns the value of a metric, or {@code null} if not found.
*/ */
@Nullable @Nullable
public Double getMetric(MeasurementDescriptor metricName) { public Double getMetric(Measure measure) {
for (MeasurementValue m : metrics) { for (Map.Entry<Measure, Number> m : metrics.entrySet()) {
if (m.getMeasurement().equals(metricName)) { if (m.getKey().equals(measure)) {
return m.getValue(); Number value = m.getValue();
if (value instanceof Double) {
return (Double) value;
} else if (value instanceof Long) {
return (double) (Long) value;
}
throw new AssertionError("Unexpected measure value type: " + value.getClass().getName());
} }
} }
return null; return null;
@ -81,9 +95,9 @@ public class StatsTestUtils {
/** /**
* Returns the value of a metric converted to long, or throw if not found. * Returns the value of a metric converted to long, or throw if not found.
*/ */
public long getMetricAsLongOrFail(MeasurementDescriptor metricName) { public long getMetricAsLongOrFail(Measure measure) {
Double doubleValue = getMetric(metricName); Double doubleValue = getMetric(measure);
checkNotNull(doubleValue, "Metric not found: %s", metricName.toString()); checkNotNull(doubleValue, "Measure not found: %s", measure.getName());
long longValue = (long) (Math.abs(doubleValue) + 0.0001); long longValue = (long) (Math.abs(doubleValue) + 0.0001);
if (doubleValue < 0) { if (doubleValue < 0) {
longValue = -longValue; longValue = -longValue;
@ -93,39 +107,28 @@ public class StatsTestUtils {
} }
/** /**
* This tag will be propagated by {@link FakeStatsContextFactory} on the wire. * This tag will be propagated by {@link FakeTagger} on the wire.
*/ */
public static final TagKey EXTRA_TAG = TagKey.create("/rpc/test/extratag"); public static final TagKey EXTRA_TAG = TagKey.create("/rpc/test/extratag");
private static final String EXTRA_TAG_HEADER_VALUE_PREFIX = "extratag:"; private static final String EXTRA_TAG_HEADER_VALUE_PREFIX = "extratag:";
private static final String NO_EXTRA_TAG_HEADER_VALUE_PREFIX = "noextratag";
/** /**
* A factory that makes fake {@link StatsContext}s and saves the created contexts to be * A {@link Tagger} implementation that saves metrics records to be accessible from {@link
* accessible from {@link #pollContextOrFail}. The contexts it has created would save metrics * #pollRecord()} and {@link #pollRecord(long, TimeUnit)}, until {@link #rolloverRecords} is
* records to be accessible from {@link #pollRecord()} and {@link #pollRecord(long, TimeUnit)}, * called.
* until {@link #rolloverRecords} is called.
*/ */
public static final class FakeStatsContextFactory extends StatsContextFactory { public static final class FakeStatsRecorder extends StatsRecorder {
private BlockingQueue<MetricsRecord> records; private BlockingQueue<MetricsRecord> records;
public final BlockingQueue<FakeStatsContext> contexts =
new LinkedBlockingQueue<FakeStatsContext>();
private final FakeStatsContext defaultContext;
/** public FakeStatsRecorder() {
* Constructor.
*/
public FakeStatsContextFactory() {
rolloverRecords();
defaultContext = new FakeStatsContext(ImmutableMap.<TagKey, TagValue>of(), this);
// The records on the default context is not visible from pollRecord(), just like it's
// not visible from pollContextOrFail() either.
rolloverRecords(); rolloverRecords();
} }
public StatsContext pollContextOrFail() { @Override
StatsContext cc = contexts.poll(); public MeasureMap newMeasureMap() {
return checkNotNull(cc); return new FakeStatsRecord(this);
} }
public MetricsRecord pollRecord() { public MetricsRecord pollRecord() {
@ -136,31 +139,8 @@ public class StatsTestUtils {
return getCurrentRecordSink().poll(timeout, unit); return getCurrentRecordSink().poll(timeout, unit);
} }
@Override
public StatsContext deserialize(InputStream buffer) throws IOException {
String serializedString;
try {
serializedString = new String(IoUtils.toByteArray(buffer), UTF_8);
} catch (IOException e) {
throw new RuntimeException(e);
}
if (serializedString.startsWith(EXTRA_TAG_HEADER_VALUE_PREFIX)) {
return getDefault().with(EXTRA_TAG,
TagValue.create(serializedString.substring(EXTRA_TAG_HEADER_VALUE_PREFIX.length())));
} else if (serializedString.startsWith(NO_EXTRA_TAG_HEADER_VALUE_PREFIX)) {
return getDefault();
} else {
throw new IOException("Malformed value");
}
}
@Override
public FakeStatsContext getDefault() {
return defaultContext;
}
/** /**
* Disconnect this factory with the contexts it has created so far. The records from those * Disconnect this tagger with the contexts it has created so far. The records from those
* contexts will not show up in {@link #pollRecord}. Useful for isolating the records between * contexts will not show up in {@link #pollRecord}. Useful for isolating the records between
* test cases. * test cases.
*/ */
@ -174,45 +154,111 @@ public class StatsTestUtils {
} }
} }
public static final class FakeStatsContext extends StatsContext { public static final class FakeTagger extends Tagger {
private final ImmutableMap<TagKey, TagValue> tags;
private final FakeStatsContextFactory factory; @Override
public FakeTagContext empty() {
return FakeTagContext.EMPTY;
}
@Override
public TagContext getCurrentTagContext() {
return ContextUtils.TAG_CONTEXT_KEY.get();
}
@Override
public TagContextBuilder emptyBuilder() {
return new FakeTagContextBuilder(ImmutableMap.<TagKey, TagValue>of());
}
@Override
public FakeTagContextBuilder toBuilder(TagContext tags) {
return new FakeTagContextBuilder(getTags(tags));
}
@Override
public TagContextBuilder currentBuilder() {
throw new UnsupportedOperationException();
}
@Override
public Scope withTagContext(TagContext tags) {
throw new UnsupportedOperationException();
}
}
public static final class FakeTagContextBinarySerializer extends TagContextBinarySerializer {
private final FakeTagger tagger = new FakeTagger();
@Override
public TagContext fromByteArray(byte[] bytes) throws TagContextDeserializationException {
String serializedString = new String(bytes, UTF_8);
if (serializedString.startsWith(EXTRA_TAG_HEADER_VALUE_PREFIX)) {
return tagger.emptyBuilder()
.put(EXTRA_TAG,
TagValue.create(serializedString.substring(EXTRA_TAG_HEADER_VALUE_PREFIX.length())))
.build();
} else {
throw new TagContextDeserializationException("Malformed value");
}
}
@Override
public byte[] toByteArray(TagContext tags) {
TagValue extraTagValue = getTags(tags).get(EXTRA_TAG);
if (extraTagValue == null) {
throw new UnsupportedOperationException("TagContext must contain EXTRA_TAG");
}
return (EXTRA_TAG_HEADER_VALUE_PREFIX + extraTagValue.asString()).getBytes(UTF_8);
}
}
public static final class FakeStatsRecord extends MeasureMap {
private final BlockingQueue<MetricsRecord> recordSink; private final BlockingQueue<MetricsRecord> recordSink;
public final Map<Measure, Number> metrics = Maps.newHashMap();
private FakeStatsContext(ImmutableMap<TagKey, TagValue> tags, private FakeStatsRecord(FakeStatsRecorder statsRecorder) {
FakeStatsContextFactory factory) { this.recordSink = statsRecorder.getCurrentRecordSink();
this.tags = tags;
this.factory = factory;
this.recordSink = factory.getCurrentRecordSink();
}
public Map<TagKey, TagValue> getTags() {
return tags;
} }
@Override @Override
public Builder builder() { public MeasureMap put(Measure.MeasureDouble measure, double value) {
return new FakeStatsContextBuilder(this); metrics.put(measure, value);
}
@Override
public StatsContext record(MeasurementMap metrics) {
recordSink.add(new MetricsRecord(tags, metrics));
return this; return this;
} }
@Override @Override
public void serialize(OutputStream os) { public MeasureMap put(Measure.MeasureLong measure, long value) {
TagValue extraTagValue = tags.get(EXTRA_TAG); metrics.put(measure, value);
try { return this;
if (extraTagValue == null) {
os.write(NO_EXTRA_TAG_HEADER_VALUE_PREFIX.getBytes(UTF_8));
} else {
os.write((EXTRA_TAG_HEADER_VALUE_PREFIX + extraTagValue.toString()).getBytes(UTF_8));
} }
} catch (IOException e) {
throw new RuntimeException(e); @Override
public void record(TagContext tags) {
recordSink.add(new MetricsRecord(getTags(tags), ImmutableMap.copyOf(metrics)));
} }
@Override
public void record() {
throw new UnsupportedOperationException();
}
}
public static final class FakeTagContext extends TagContext {
private static final FakeTagContext EMPTY =
new FakeTagContext(ImmutableMap.<TagKey, TagValue>of());
private final ImmutableMap<TagKey, TagValue> tags;
private FakeTagContext(ImmutableMap<TagKey, TagValue> tags) {
this.tags = tags;
}
public ImmutableMap<TagKey, TagValue> getTags() {
return tags;
} }
@Override @Override
@ -221,41 +267,55 @@ public class StatsTestUtils {
} }
@Override @Override
public boolean equals(Object other) { protected Iterator<Tag> getIterator() {
if (!(other instanceof FakeStatsContext)) { return Iterators.transform(
return false; tags.entrySet().iterator(),
new Function<Map.Entry<TagKey, TagValue>, Tag>() {
@Override
public Tag apply(@Nullable Map.Entry<TagKey, TagValue> entry) {
return Tag.create(entry.getKey(), entry.getValue());
} }
FakeStatsContext otherCtx = (FakeStatsContext) other; });
return tags.equals(otherCtx.tags); }
}
public static class FakeTagContextBuilder extends TagContextBuilder {
private final Map<TagKey, TagValue> tagsBuilder = Maps.newHashMap();
private FakeTagContextBuilder(Map<TagKey, TagValue> tags) {
tagsBuilder.putAll(tags);
} }
@Override @Override
public int hashCode() { public TagContextBuilder put(TagKey key, TagValue value) {
return tags.hashCode();
}
}
private static class FakeStatsContextBuilder extends StatsContext.Builder {
private final ImmutableMap.Builder<TagKey, TagValue> tagsBuilder = ImmutableMap.builder();
private final FakeStatsContext base;
private FakeStatsContextBuilder(FakeStatsContext base) {
this.base = base;
tagsBuilder.putAll(base.tags);
}
@Override
public StatsContext.Builder set(TagKey key, TagValue value) {
tagsBuilder.put(key, value); tagsBuilder.put(key, value);
return this; return this;
} }
@Override @Override
public StatsContext build() { public TagContextBuilder remove(TagKey key) {
FakeStatsContext context = new FakeStatsContext(tagsBuilder.build(), base.factory); tagsBuilder.remove(key);
base.factory.contexts.add(context); return this;
}
@Override
public TagContext build() {
FakeTagContext context = new FakeTagContext(ImmutableMap.copyOf(tagsBuilder));
return context; return context;
} }
@Override
public Scope buildScoped() {
throw new UnsupportedOperationException();
}
}
// This method handles the default TagContext, which isn't an instance of FakeTagContext.
private static ImmutableMap<TagKey, TagValue> getTags(TagContext tags) {
return tags instanceof FakeTagContext
? ((FakeTagContext) tags).getTags()
: ImmutableMap.<TagKey, TagValue>of();
} }
// TODO(bdrutu): Remove this class after OpenCensus releases support for this class. // TODO(bdrutu): Remove this class after OpenCensus releases support for this class.