binder: Helper class to allow in process servers to use peer uids in test (#11014)

This commit is contained in:
abtom 2024-03-22 18:57:29 -04:00 committed by GitHub
parent 10cb4a3bed
commit 537dbe826a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 186 additions and 0 deletions

View File

@ -0,0 +1,123 @@
package io.grpc.binder;
import static com.google.common.truth.Truth.assertThat;
import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.common.collect.ImmutableList;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientInterceptors;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.ServerInterceptors;
import io.grpc.ServerServiceDefinition;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.stub.ClientCalls;
import io.grpc.stub.MetadataUtils;
import io.grpc.stub.ServerCalls;
import io.grpc.testing.GrpcCleanupRule;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@RunWith(JUnit4.class)
public final class PeerUidTestHelperTest {
@Rule
public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
private static final int FAKE_UID = 12345;
private final AtomicReference<PeerUid> clientUidCapture = new AtomicReference<>();
@Test
public void keyPopulatedWithInterceptorAndHeader() throws Exception {
makeServiceCall(/* includeInterceptor= */ true, /* includeUidInHeader= */ true, FAKE_UID);
assertThat(clientUidCapture.get()).isEqualTo(new PeerUid(FAKE_UID));
}
@Test
public void keyNotPopulatedWithInterceptorAndNoHeader() throws Exception {
makeServiceCall(/* includeInterceptor= */ true, /* includeUidInHeader= */ false, /* uid= */ -1);
assertThat(clientUidCapture.get()).isNull();
}
@Test
public void keyNotPopulatedWithoutInterceptorAndWithHeader() throws Exception {
makeServiceCall(
/* includeInterceptor= */ false, /* includeUidInHeader= */ true, /* uid= */ FAKE_UID);
assertThat(clientUidCapture.get()).isNull();
}
private final MethodDescriptor<String, String> method =
MethodDescriptor.newBuilder(StringMarshaller.INSTANCE, StringMarshaller.INSTANCE)
.setFullMethodName("test/method")
.setType(MethodDescriptor.MethodType.UNARY)
.build();
private void makeServiceCall(boolean includeInterceptor, boolean includeUidInHeader, int uid)
throws Exception {
ServerCallHandler<String, String> callHandler =
ServerCalls.asyncUnaryCall(
(req, respObserver) -> {
clientUidCapture.set(PeerUids.REMOTE_PEER.get());
respObserver.onNext(req);
respObserver.onCompleted();
});
ImmutableList<ServerInterceptor> interceptors;
if (includeInterceptor) {
interceptors = ImmutableList.of(PeerUidTestHelper.newTestPeerIdentifyingServerInterceptor());
} else {
interceptors = ImmutableList.of();
}
ServerServiceDefinition serviceDef =
ServerInterceptors.intercept(
ServerServiceDefinition.builder("test").addMethod(method, callHandler).build(),
interceptors);
InProcessServerBuilder server =
InProcessServerBuilder.forName("test").directExecutor().addService(serviceDef);
grpcCleanup.register(server.build().start());
Channel channel = InProcessChannelBuilder.forName("test").directExecutor().build();
grpcCleanup.register((ManagedChannel) channel);
if (includeUidInHeader) {
Metadata header = new Metadata();
header.put(PeerUidTestHelper.UID_KEY, uid);
channel =
ClientInterceptors.intercept(channel, MetadataUtils.newAttachHeadersInterceptor(header));
}
ClientCalls.blockingUnaryCall(channel, method, CallOptions.DEFAULT, "hello");
}
private static class StringMarshaller implements MethodDescriptor.Marshaller<String> {
public static final StringMarshaller INSTANCE = new StringMarshaller();
@Override
public InputStream stream(String value) {
return new ByteArrayInputStream(value.getBytes(UTF_8));
}
@Override
public String parse(InputStream stream) {
try {
return new String(stream.readAllBytes(), UTF_8);
} catch (IOException ex) {
throw new RuntimeException(ex);
}
}
}
}

View File

@ -0,0 +1,63 @@
package io.grpc.binder;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
/**
* Class which helps set up {@link PeerUids} to be used in tests.
*/
public final class PeerUidTestHelper {
/**
* The UID of the calling package is set with the value of this key.
*/
public static final Metadata.Key<Integer> UID_KEY =
Metadata.Key.of("binder-remote-uid-for-unit-testing", PeerUidTestMarshaller.INSTANCE);
/**
* Creates an interceptor that associates the {@link PeerUids#REMOTE_PEER} key in the request
* {@link Context} with a UID provided by the client in the {@link #UID_KEY} request header, if
* present.
*
* <p>The returned interceptor works with any gRPC transport but is meant for in-process unit
* testing of gRPC/binder services that depend on {@link PeerUids}.
*/
public static ServerInterceptor newTestPeerIdentifyingServerInterceptor() {
return new ServerInterceptor() {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
if (headers.containsKey(UID_KEY)) {
Context context =
Context.current().withValue(PeerUids.REMOTE_PEER, new PeerUid(headers.get(UID_KEY)));
return Contexts.interceptCall(context, call, headers, next);
}
return next.startCall(call, headers);
}
};
}
private PeerUidTestHelper() {
}
private static class PeerUidTestMarshaller implements Metadata.AsciiMarshaller<Integer> {
public static final PeerUidTestMarshaller INSTANCE = new PeerUidTestMarshaller();
@Override
public String toAsciiString(Integer value) {
return value.toString();
}
@Override
public Integer parseAsciiString(String serialized) {
return Integer.parseInt(serialized);
}
}
;
}