mirror of https://github.com/grpc/grpc-java.git
xds: implement MeshCACertificateProvider (#7274)
This commit is contained in:
parent
0a99a20b70
commit
65e7ffc788
|
@ -177,6 +177,8 @@ subprojects {
|
|||
conscrypt: 'org.conscrypt:conscrypt-openjdk-uber:2.2.1',
|
||||
re2j: 'com.google.re2j:re2j:1.2',
|
||||
|
||||
bouncycastle: 'org.bouncycastle:bcpkix-jdk15on:1.61',
|
||||
|
||||
// Test dependencies.
|
||||
junit: 'junit:junit:4.12',
|
||||
mockito: 'org.mockito:mockito-core:3.3.3',
|
||||
|
|
|
@ -26,7 +26,8 @@ dependencies {
|
|||
project(':grpc-auth'),
|
||||
project(path: ':grpc-alts', configuration: 'shadow'),
|
||||
libraries.gson,
|
||||
libraries.re2j
|
||||
libraries.re2j,
|
||||
libraries.bouncycastle
|
||||
def nettyDependency = implementation project(':grpc-netty')
|
||||
|
||||
implementation (libraries.opencensus_proto) {
|
||||
|
|
|
@ -49,42 +49,42 @@ public abstract class CertificateProvider implements Closeable {
|
|||
|
||||
@VisibleForTesting
|
||||
static final class DistributorWatcher implements Watcher {
|
||||
private PrivateKey lastKey;
|
||||
private List<X509Certificate> lastCertChain;
|
||||
private List<X509Certificate> lastTrustedRoots;
|
||||
private PrivateKey privateKey;
|
||||
private List<X509Certificate> certChain;
|
||||
private List<X509Certificate> trustedRoots;
|
||||
|
||||
@VisibleForTesting
|
||||
final Set<Watcher> downsstreamWatchers = new HashSet<>();
|
||||
final Set<Watcher> downstreamWatchers = new HashSet<>();
|
||||
|
||||
synchronized void addWatcher(Watcher watcher) {
|
||||
downsstreamWatchers.add(watcher);
|
||||
if (lastKey != null && lastCertChain != null) {
|
||||
downstreamWatchers.add(watcher);
|
||||
if (privateKey != null && certChain != null) {
|
||||
sendLastCertificateUpdate(watcher);
|
||||
}
|
||||
if (lastTrustedRoots != null) {
|
||||
if (trustedRoots != null) {
|
||||
sendLastTrustedRootsUpdate(watcher);
|
||||
}
|
||||
}
|
||||
|
||||
synchronized void removeWatcher(Watcher watcher) {
|
||||
downsstreamWatchers.remove(watcher);
|
||||
downstreamWatchers.remove(watcher);
|
||||
}
|
||||
|
||||
private void sendLastCertificateUpdate(Watcher watcher) {
|
||||
watcher.updateCertificate(lastKey, lastCertChain);
|
||||
watcher.updateCertificate(privateKey, certChain);
|
||||
}
|
||||
|
||||
private void sendLastTrustedRootsUpdate(Watcher watcher) {
|
||||
watcher.updateTrustedRoots(lastTrustedRoots);
|
||||
watcher.updateTrustedRoots(trustedRoots);
|
||||
}
|
||||
|
||||
@Override
|
||||
public synchronized void updateCertificate(PrivateKey key, List<X509Certificate> certChain) {
|
||||
checkNotNull(key, "key");
|
||||
checkNotNull(certChain, "certChain");
|
||||
lastKey = key;
|
||||
lastCertChain = certChain;
|
||||
for (Watcher watcher : downsstreamWatchers) {
|
||||
privateKey = key;
|
||||
this.certChain = certChain;
|
||||
for (Watcher watcher : downstreamWatchers) {
|
||||
sendLastCertificateUpdate(watcher);
|
||||
}
|
||||
}
|
||||
|
@ -92,18 +92,36 @@ public abstract class CertificateProvider implements Closeable {
|
|||
@Override
|
||||
public synchronized void updateTrustedRoots(List<X509Certificate> trustedRoots) {
|
||||
checkNotNull(trustedRoots, "trustedRoots");
|
||||
lastTrustedRoots = trustedRoots;
|
||||
for (Watcher watcher : downsstreamWatchers) {
|
||||
this.trustedRoots = trustedRoots;
|
||||
for (Watcher watcher : downstreamWatchers) {
|
||||
sendLastTrustedRootsUpdate(watcher);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public synchronized void onError(Status errorStatus) {
|
||||
for (Watcher watcher : downsstreamWatchers) {
|
||||
for (Watcher watcher : downstreamWatchers) {
|
||||
watcher.onError(errorStatus);
|
||||
}
|
||||
}
|
||||
|
||||
X509Certificate getLastIdentityCert() {
|
||||
if (certChain != null && !certChain.isEmpty()) {
|
||||
return certChain.get(0);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
void close() {
|
||||
downstreamWatchers.clear();
|
||||
clearValues();
|
||||
}
|
||||
|
||||
void clearValues() {
|
||||
privateKey = null;
|
||||
certChain = null;
|
||||
trustedRoots = null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -17,35 +17,336 @@
|
|||
package io.grpc.xds.internal.certprovider;
|
||||
|
||||
import static com.google.common.base.Preconditions.checkArgument;
|
||||
import static com.google.common.base.Preconditions.checkNotNull;
|
||||
import static io.grpc.Status.Code.ABORTED;
|
||||
import static io.grpc.Status.Code.CANCELLED;
|
||||
import static io.grpc.Status.Code.DEADLINE_EXCEEDED;
|
||||
import static io.grpc.Status.Code.INTERNAL;
|
||||
import static io.grpc.Status.Code.RESOURCE_EXHAUSTED;
|
||||
import static io.grpc.Status.Code.UNAVAILABLE;
|
||||
import static io.grpc.Status.Code.UNKNOWN;
|
||||
import static java.nio.charset.StandardCharsets.UTF_8;
|
||||
|
||||
import com.google.auth.oauth2.GoogleCredentials;
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.protobuf.Duration;
|
||||
import google.security.meshca.v1.MeshCertificateServiceGrpc;
|
||||
import google.security.meshca.v1.Meshca;
|
||||
import io.grpc.CallOptions;
|
||||
import io.grpc.Channel;
|
||||
import io.grpc.ClientCall;
|
||||
import io.grpc.ClientInterceptor;
|
||||
import io.grpc.ForwardingClientCall;
|
||||
import io.grpc.InternalLogId;
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.ManagedChannelBuilder;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.MethodDescriptor;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.SynchronizationContext;
|
||||
import io.grpc.auth.MoreCallCredentials;
|
||||
import io.grpc.internal.BackoffPolicy;
|
||||
import io.grpc.internal.TimeProvider;
|
||||
import io.grpc.xds.internal.sds.trust.CertificateUtils;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.StringWriter;
|
||||
import java.security.KeyPair;
|
||||
import java.security.KeyPairGenerator;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.security.cert.CertificateException;
|
||||
import java.security.cert.X509Certificate;
|
||||
import java.util.ArrayList;
|
||||
import java.util.EnumSet;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.ScheduledExecutorService;
|
||||
import java.util.concurrent.ScheduledFuture;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.TimeoutException;
|
||||
import java.util.logging.Level;
|
||||
import java.util.logging.Logger;
|
||||
import javax.security.auth.x500.X500Principal;
|
||||
import org.bouncycastle.openssl.jcajce.JcaPEMWriter;
|
||||
import org.bouncycastle.operator.ContentSigner;
|
||||
import org.bouncycastle.operator.OperatorCreationException;
|
||||
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder;
|
||||
import org.bouncycastle.pkcs.PKCS10CertificationRequest;
|
||||
import org.bouncycastle.pkcs.PKCS10CertificationRequestBuilder;
|
||||
import org.bouncycastle.pkcs.jcajce.JcaPKCS10CertificationRequestBuilder;
|
||||
import org.bouncycastle.util.io.pem.PemObject;
|
||||
|
||||
/** Implementation of {@link CertificateProvider} for the Google Mesh CA. */
|
||||
final class MeshCaCertificateProvider extends CertificateProvider {
|
||||
private static final Logger logger = Logger.getLogger(MeshCaCertificateProvider.class.getName());
|
||||
|
||||
MeshCaCertificateProvider(DistributorWatcher watcher, boolean notifyCertUpdates,
|
||||
String meshCaUrl, String zone, long validitySeconds,
|
||||
int keySize, String alg, String signatureAlg, MeshCaChannelFactory meshCaChannelFactory,
|
||||
BackoffPolicy.Provider backoffPolicyProvider, long renewalGracePeriodSeconds,
|
||||
int maxRetryAttempts, GoogleCredentials oauth2Creds) {
|
||||
MeshCaCertificateProvider(
|
||||
DistributorWatcher watcher,
|
||||
boolean notifyCertUpdates,
|
||||
String meshCaUrl,
|
||||
String zone,
|
||||
long validitySeconds,
|
||||
int keySize,
|
||||
String alg,
|
||||
String signatureAlg, MeshCaChannelFactory meshCaChannelFactory,
|
||||
BackoffPolicy.Provider backoffPolicyProvider,
|
||||
long renewalGracePeriodSeconds,
|
||||
int maxRetryAttempts,
|
||||
GoogleCredentials oauth2Creds,
|
||||
ScheduledExecutorService scheduledExecutorService,
|
||||
TimeProvider timeProvider,
|
||||
long rpcTimeoutMillis) {
|
||||
super(watcher, notifyCertUpdates);
|
||||
this.meshCaUrl = checkNotNull(meshCaUrl, "meshCaUrl");
|
||||
checkArgument(
|
||||
validitySeconds > INITIAL_DELAY_SECONDS,
|
||||
"validitySeconds must be greater than " + INITIAL_DELAY_SECONDS);
|
||||
this.validitySeconds = validitySeconds;
|
||||
this.keySize = keySize;
|
||||
this.alg = checkNotNull(alg, "alg");
|
||||
this.signatureAlg = checkNotNull(signatureAlg, "signatureAlg");
|
||||
this.meshCaChannelFactory = checkNotNull(meshCaChannelFactory, "meshCaChannelFactory");
|
||||
this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider");
|
||||
checkArgument(
|
||||
renewalGracePeriodSeconds > 0L && renewalGracePeriodSeconds < validitySeconds,
|
||||
"renewalGracePeriodSeconds should be between 0 and " + validitySeconds);
|
||||
this.renewalGracePeriodSeconds = renewalGracePeriodSeconds;
|
||||
checkArgument(maxRetryAttempts >= 0, "maxRetryAttempts must be >= 0");
|
||||
this.maxRetryAttempts = maxRetryAttempts;
|
||||
this.oauth2Creds = checkNotNull(oauth2Creds, "oauth2Creds");
|
||||
this.scheduledExecutorService =
|
||||
checkNotNull(scheduledExecutorService, "scheduledExecutorService");
|
||||
this.timeProvider = checkNotNull(timeProvider, "timeProvider");
|
||||
this.headerInterceptor = new ZoneInfoClientInterceptor(checkNotNull(zone, "zone"));
|
||||
this.syncContext = createSynchronizationContext(meshCaUrl);
|
||||
this.rpcTimeoutMillis = rpcTimeoutMillis;
|
||||
}
|
||||
|
||||
private SynchronizationContext createSynchronizationContext(String details) {
|
||||
final InternalLogId logId = InternalLogId.allocate("MeshCaCertificateProvider", details);
|
||||
return new SynchronizationContext(
|
||||
new Thread.UncaughtExceptionHandler() {
|
||||
private boolean panicMode;
|
||||
|
||||
@Override
|
||||
public void uncaughtException(Thread t, Throwable e) {
|
||||
logger.log(
|
||||
Level.SEVERE,
|
||||
"[" + logId + "] Uncaught exception in the SynchronizationContext. Panic!",
|
||||
e);
|
||||
panic(e);
|
||||
}
|
||||
|
||||
void panic(final Throwable t) {
|
||||
if (panicMode) {
|
||||
// Preserve the first panic information
|
||||
return;
|
||||
}
|
||||
panicMode = true;
|
||||
close();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start() {
|
||||
// TODO implement
|
||||
scheduleNextRefreshCertificate(INITIAL_DELAY_SECONDS);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
// TODO implement
|
||||
if (scheduledHandle != null) {
|
||||
scheduledHandle.cancel();
|
||||
scheduledHandle = null;
|
||||
}
|
||||
getWatcher().close();
|
||||
}
|
||||
|
||||
private void scheduleNextRefreshCertificate(long delayInSeconds) {
|
||||
if (scheduledHandle != null && scheduledHandle.isPending()) {
|
||||
logger.log(Level.SEVERE, "Pending task found: inconsistent state in scheduledHandle!");
|
||||
scheduledHandle.cancel();
|
||||
}
|
||||
RefreshCertificateTask runnable = new RefreshCertificateTask();
|
||||
scheduledHandle = syncContext.schedule(
|
||||
runnable, delayInSeconds, TimeUnit.SECONDS, scheduledExecutorService);
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
void refreshCertificate()
|
||||
throws NoSuchAlgorithmException, IOException, OperatorCreationException {
|
||||
long refreshDelaySeconds = computeRefreshSecondsFromCurrentCertExpiry();
|
||||
ManagedChannel channel = meshCaChannelFactory.createChannel(meshCaUrl);
|
||||
try {
|
||||
String uniqueReqIdForAllRetries = UUID.randomUUID().toString();
|
||||
Duration duration = Duration.newBuilder().setSeconds(validitySeconds).build();
|
||||
KeyPair keyPair = generateKeyPair();
|
||||
String csr = generateCsr(keyPair);
|
||||
MeshCertificateServiceGrpc.MeshCertificateServiceBlockingStub stub =
|
||||
createStubToMeshCa(channel);
|
||||
List<X509Certificate> x509Chain = makeRequestWithRetries(stub, uniqueReqIdForAllRetries,
|
||||
duration, csr);
|
||||
if (x509Chain != null) {
|
||||
refreshDelaySeconds =
|
||||
computeDelaySecondsToCertExpiry(x509Chain.get(0)) - renewalGracePeriodSeconds;
|
||||
getWatcher().updateCertificate(keyPair.getPrivate(), x509Chain);
|
||||
getWatcher().updateTrustedRoots(ImmutableList.of(x509Chain.get(x509Chain.size() - 1)));
|
||||
}
|
||||
} finally {
|
||||
shutdownChannel(channel);
|
||||
scheduleNextRefreshCertificate(refreshDelaySeconds);
|
||||
}
|
||||
}
|
||||
|
||||
private MeshCertificateServiceGrpc.MeshCertificateServiceBlockingStub createStubToMeshCa(
|
||||
ManagedChannel channel) {
|
||||
return MeshCertificateServiceGrpc
|
||||
.newBlockingStub(channel)
|
||||
.withCallCredentials(MoreCallCredentials.from(oauth2Creds))
|
||||
.withInterceptors(headerInterceptor);
|
||||
}
|
||||
|
||||
private List<X509Certificate> makeRequestWithRetries(
|
||||
MeshCertificateServiceGrpc.MeshCertificateServiceBlockingStub stub,
|
||||
String reqId,
|
||||
Duration duration,
|
||||
String csr) {
|
||||
Meshca.MeshCertificateRequest request =
|
||||
Meshca.MeshCertificateRequest.newBuilder()
|
||||
.setValidity(duration)
|
||||
.setCsr(csr)
|
||||
.setRequestId(reqId)
|
||||
.build();
|
||||
|
||||
BackoffPolicy backoffPolicy = backoffPolicyProvider.get();
|
||||
Throwable lastException = null;
|
||||
for (int i = 0; i <= maxRetryAttempts; i++) {
|
||||
try {
|
||||
Meshca.MeshCertificateResponse response =
|
||||
stub.withDeadlineAfter(rpcTimeoutMillis, TimeUnit.MILLISECONDS)
|
||||
.createCertificate(request);
|
||||
return getX509CertificatesFromResponse(response);
|
||||
} catch (Throwable t) {
|
||||
if (!retriable(t)) {
|
||||
generateErrorIfCurrentCertExpired(t);
|
||||
return null;
|
||||
}
|
||||
lastException = t;
|
||||
sleepForNanos(backoffPolicy.nextBackoffNanos());
|
||||
}
|
||||
}
|
||||
generateErrorIfCurrentCertExpired(lastException);
|
||||
return null;
|
||||
}
|
||||
|
||||
private void sleepForNanos(long nanos) {
|
||||
ScheduledFuture<?> future = scheduledExecutorService.schedule(new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
// do nothing
|
||||
}
|
||||
}, nanos, TimeUnit.NANOSECONDS);
|
||||
try {
|
||||
future.get(nanos, TimeUnit.NANOSECONDS);
|
||||
} catch (InterruptedException ie) {
|
||||
logger.log(Level.SEVERE, "Inside sleep", ie);
|
||||
Thread.currentThread().interrupt();
|
||||
} catch (ExecutionException | TimeoutException ex) {
|
||||
logger.log(Level.SEVERE, "Inside sleep", ex);
|
||||
}
|
||||
}
|
||||
|
||||
private static boolean retriable(Throwable t) {
|
||||
return RETRIABLE_CODES.contains(Status.fromThrowable(t).getCode());
|
||||
}
|
||||
|
||||
private void generateErrorIfCurrentCertExpired(Throwable t) {
|
||||
X509Certificate currentCert = getWatcher().getLastIdentityCert();
|
||||
if (currentCert != null) {
|
||||
long delaySeconds = computeDelaySecondsToCertExpiry(currentCert);
|
||||
if (delaySeconds > INITIAL_DELAY_SECONDS) {
|
||||
return;
|
||||
}
|
||||
getWatcher().clearValues();
|
||||
}
|
||||
getWatcher().onError(Status.fromThrowable(t));
|
||||
}
|
||||
|
||||
private KeyPair generateKeyPair() throws NoSuchAlgorithmException {
|
||||
KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance(alg);
|
||||
keyPairGenerator.initialize(keySize);
|
||||
return keyPairGenerator.generateKeyPair();
|
||||
}
|
||||
|
||||
private String generateCsr(KeyPair pair) throws IOException, OperatorCreationException {
|
||||
PKCS10CertificationRequestBuilder p10Builder =
|
||||
new JcaPKCS10CertificationRequestBuilder(
|
||||
new X500Principal("CN=EXAMPLE.COM"), pair.getPublic());
|
||||
JcaContentSignerBuilder csBuilder = new JcaContentSignerBuilder(signatureAlg);
|
||||
ContentSigner signer = csBuilder.build(pair.getPrivate());
|
||||
PKCS10CertificationRequest csr = p10Builder.build(signer);
|
||||
PemObject pemObject = new PemObject("NEW CERTIFICATE REQUEST", csr.getEncoded());
|
||||
try (StringWriter str = new StringWriter()) {
|
||||
try (JcaPEMWriter pemWriter = new JcaPEMWriter(str)) {
|
||||
pemWriter.writeObject(pemObject);
|
||||
}
|
||||
return str.toString();
|
||||
}
|
||||
}
|
||||
|
||||
/** Compute refresh interval as half of interval to current cert expiry. */
|
||||
private long computeRefreshSecondsFromCurrentCertExpiry() {
|
||||
X509Certificate lastCert = getWatcher().getLastIdentityCert();
|
||||
if (lastCert == null) {
|
||||
return INITIAL_DELAY_SECONDS;
|
||||
}
|
||||
long delayToCertExpirySeconds = computeDelaySecondsToCertExpiry(lastCert) / 2;
|
||||
return Math.max(delayToCertExpirySeconds, INITIAL_DELAY_SECONDS);
|
||||
}
|
||||
|
||||
@SuppressWarnings("JdkObsolete")
|
||||
private long computeDelaySecondsToCertExpiry(X509Certificate lastCert) {
|
||||
checkNotNull(lastCert, "lastCert");
|
||||
return TimeUnit.NANOSECONDS.toSeconds(
|
||||
TimeUnit.MILLISECONDS.toNanos(lastCert.getNotAfter().getTime()) - timeProvider
|
||||
.currentTimeNanos());
|
||||
}
|
||||
|
||||
private static void shutdownChannel(ManagedChannel channel) {
|
||||
channel.shutdown();
|
||||
try {
|
||||
channel.awaitTermination(10, TimeUnit.SECONDS);
|
||||
} catch (InterruptedException ex) {
|
||||
logger.log(Level.SEVERE, "awaiting channel Termination", ex);
|
||||
channel.shutdownNow();
|
||||
Thread.currentThread().interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
private List<X509Certificate> getX509CertificatesFromResponse(
|
||||
Meshca.MeshCertificateResponse response) throws CertificateException, IOException {
|
||||
List<String> certChain = response.getCertChainList();
|
||||
List<X509Certificate> x509Chain = new ArrayList<>(certChain.size());
|
||||
for (String certString : certChain) {
|
||||
try (ByteArrayInputStream bais = new ByteArrayInputStream(certString.getBytes(UTF_8))) {
|
||||
x509Chain.add(CertificateUtils.toX509Certificate(bais));
|
||||
}
|
||||
}
|
||||
return x509Chain;
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
class RefreshCertificateTask implements Runnable {
|
||||
@Override
|
||||
public void run() {
|
||||
try {
|
||||
refreshCertificate();
|
||||
} catch (NoSuchAlgorithmException | OperatorCreationException | IOException ex) {
|
||||
logger.log(Level.SEVERE, "refreshing certificate", ex);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Factory for creating channels to MeshCA sever. */
|
||||
|
@ -94,7 +395,10 @@ final class MeshCaCertificateProvider extends CertificateProvider {
|
|||
BackoffPolicy.Provider backoffPolicyProvider,
|
||||
long renewalGracePeriodSeconds,
|
||||
int maxRetryAttempts,
|
||||
GoogleCredentials oauth2Creds) {
|
||||
GoogleCredentials oauth2Creds,
|
||||
ScheduledExecutorService scheduledExecutorService,
|
||||
TimeProvider timeProvider,
|
||||
long rpcTimeoutMillis) {
|
||||
return new MeshCaCertificateProvider(
|
||||
watcher,
|
||||
notifyCertUpdates,
|
||||
|
@ -108,7 +412,10 @@ final class MeshCaCertificateProvider extends CertificateProvider {
|
|||
backoffPolicyProvider,
|
||||
renewalGracePeriodSeconds,
|
||||
maxRetryAttempts,
|
||||
oauth2Creds);
|
||||
oauth2Creds,
|
||||
scheduledExecutorService,
|
||||
timeProvider,
|
||||
rpcTimeoutMillis);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -129,6 +436,64 @@ final class MeshCaCertificateProvider extends CertificateProvider {
|
|||
BackoffPolicy.Provider backoffPolicyProvider,
|
||||
long renewalGracePeriodSeconds,
|
||||
int maxRetryAttempts,
|
||||
GoogleCredentials oauth2Creds);
|
||||
GoogleCredentials oauth2Creds,
|
||||
ScheduledExecutorService scheduledExecutorService,
|
||||
TimeProvider timeProvider,
|
||||
long rpcTimeoutMillis);
|
||||
}
|
||||
|
||||
private class ZoneInfoClientInterceptor implements ClientInterceptor {
|
||||
private final String zone;
|
||||
|
||||
ZoneInfoClientInterceptor(String zone) {
|
||||
this.zone = zone;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
|
||||
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
|
||||
return new ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>(
|
||||
next.newCall(method, callOptions)) {
|
||||
|
||||
@Override
|
||||
public void start(Listener<RespT> responseListener, Metadata headers) {
|
||||
headers.put(KEY_FOR_ZONE_INFO, zone);
|
||||
super.start(responseListener, headers);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
static final Metadata.Key<String> KEY_FOR_ZONE_INFO =
|
||||
Metadata.Key.of("x-goog-request-params", Metadata.ASCII_STRING_MARSHALLER);
|
||||
@VisibleForTesting
|
||||
static final long INITIAL_DELAY_SECONDS = 4L;
|
||||
|
||||
private static final EnumSet<Status.Code> RETRIABLE_CODES =
|
||||
EnumSet.of(
|
||||
CANCELLED,
|
||||
UNKNOWN,
|
||||
DEADLINE_EXCEEDED,
|
||||
RESOURCE_EXHAUSTED,
|
||||
ABORTED,
|
||||
INTERNAL,
|
||||
UNAVAILABLE);
|
||||
|
||||
private final SynchronizationContext syncContext;
|
||||
private final ScheduledExecutorService scheduledExecutorService;
|
||||
private final int maxRetryAttempts;
|
||||
private final ZoneInfoClientInterceptor headerInterceptor;
|
||||
private final BackoffPolicy.Provider backoffPolicyProvider;
|
||||
private final String meshCaUrl;
|
||||
private final long validitySeconds;
|
||||
private final long renewalGracePeriodSeconds;
|
||||
private final int keySize;
|
||||
private final String alg;
|
||||
private final String signatureAlg;
|
||||
private final GoogleCredentials oauth2Creds;
|
||||
private final TimeProvider timeProvider;
|
||||
private final MeshCaChannelFactory meshCaChannelFactory;
|
||||
@VisibleForTesting SynchronizationContext.ScheduledHandle scheduledHandle;
|
||||
private final long rpcTimeoutMillis;
|
||||
}
|
||||
|
|
|
@ -21,10 +21,15 @@ import static com.google.common.base.Preconditions.checkNotNull;
|
|||
import static com.google.common.base.Preconditions.checkState;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.util.concurrent.ThreadFactoryBuilder;
|
||||
import io.grpc.internal.BackoffPolicy;
|
||||
import io.grpc.internal.ExponentialBackoffPolicy;
|
||||
import io.grpc.internal.TimeProvider;
|
||||
import io.grpc.xds.internal.sts.StsCredentials;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.ScheduledExecutorService;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
|
@ -46,16 +51,20 @@ final class MeshCaCertificateProviderProvider implements CertificateProviderProv
|
|||
private static final String STS_URL_KEY = "stsUrl";
|
||||
private static final String GKE_SA_JWT_LOCATION_KEY = "gkeSaJwtLocation";
|
||||
|
||||
static final String MESHCA_URL_DEFAULT = "meshca.googleapis.com";
|
||||
static final long RPC_TIMEOUT_SECONDS_DEFAULT = 5L;
|
||||
static final long CERT_VALIDITY_SECONDS_DEFAULT = 9L * 3600L; // 9 hours
|
||||
static final long RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT = 1L * 3600L; // 1 hour
|
||||
static final String KEY_ALGO_DEFAULT = "RSA"; // aka keyType
|
||||
static final int KEY_SIZE_DEFAULT = 2048;
|
||||
static final String SIGNATURE_ALGO_DEFAULT = "SHA256withRSA";
|
||||
static final int MAX_RETRY_ATTEMPTS_DEFAULT = 3;
|
||||
@VisibleForTesting static final String MESHCA_URL_DEFAULT = "meshca.googleapis.com";
|
||||
@VisibleForTesting static final long RPC_TIMEOUT_SECONDS_DEFAULT = 5L;
|
||||
@VisibleForTesting static final long CERT_VALIDITY_SECONDS_DEFAULT = 9L * 3600L;
|
||||
@VisibleForTesting static final long RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT = 1L * 3600L;
|
||||
@VisibleForTesting static final String KEY_ALGO_DEFAULT = "RSA"; // aka keyType
|
||||
@VisibleForTesting static final int KEY_SIZE_DEFAULT = 2048;
|
||||
@VisibleForTesting static final String SIGNATURE_ALGO_DEFAULT = "SHA256withRSA";
|
||||
@VisibleForTesting static final int MAX_RETRY_ATTEMPTS_DEFAULT = 3;
|
||||
@VisibleForTesting
|
||||
static final String STS_URL_DEFAULT = "https://securetoken.googleapis.com/v1/identitybindingtoken";
|
||||
|
||||
@VisibleForTesting
|
||||
static final long RPC_TIMEOUT_SECONDS = 10L;
|
||||
|
||||
private static final Pattern CLUSTER_URL_PATTERN = Pattern
|
||||
.compile(".*/projects/(.*)/locations/(.*)/clusters/.*");
|
||||
|
||||
|
@ -70,23 +79,32 @@ final class MeshCaCertificateProviderProvider implements CertificateProviderProv
|
|||
StsCredentials.Factory.getInstance(),
|
||||
MeshCaCertificateProvider.MeshCaChannelFactory.getInstance(),
|
||||
new ExponentialBackoffPolicy.Provider(),
|
||||
MeshCaCertificateProvider.Factory.getInstance()));
|
||||
MeshCaCertificateProvider.Factory.getInstance(),
|
||||
ScheduledExecutorServiceFactory.DEFAULT_INSTANCE,
|
||||
TimeProvider.SYSTEM_TIME_PROVIDER));
|
||||
}
|
||||
|
||||
final StsCredentials.Factory stsCredentialsFactory;
|
||||
final MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory;
|
||||
final BackoffPolicy.Provider backoffPolicyProvider;
|
||||
final MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory;
|
||||
final ScheduledExecutorServiceFactory scheduledExecutorServiceFactory;
|
||||
final TimeProvider timeProvider;
|
||||
|
||||
@VisibleForTesting
|
||||
MeshCaCertificateProviderProvider(StsCredentials.Factory stsCredentialsFactory,
|
||||
MeshCaCertificateProviderProvider(
|
||||
StsCredentials.Factory stsCredentialsFactory,
|
||||
MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory,
|
||||
BackoffPolicy.Provider backoffPolicyProvider,
|
||||
MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory) {
|
||||
MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory,
|
||||
ScheduledExecutorServiceFactory scheduledExecutorServiceFactory,
|
||||
TimeProvider timeProvider) {
|
||||
this.stsCredentialsFactory = stsCredentialsFactory;
|
||||
this.meshCaChannelFactory = meshCaChannelFactory;
|
||||
this.backoffPolicyProvider = backoffPolicyProvider;
|
||||
this.meshCaCertificateProviderFactory = meshCaCertificateProviderFactory;
|
||||
this.scheduledExecutorServiceFactory = scheduledExecutorServiceFactory;
|
||||
this.timeProvider = timeProvider;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -106,12 +124,23 @@ final class MeshCaCertificateProviderProvider implements CertificateProviderProv
|
|||
StsCredentials stsCredentials = stsCredentialsFactory
|
||||
.create(configObj.stsUrl, audience, configObj.gkeSaJwtLocation);
|
||||
|
||||
return meshCaCertificateProviderFactory.create(watcher, notifyCertUpdates, configObj.meshCaUrl,
|
||||
return meshCaCertificateProviderFactory.create(
|
||||
watcher,
|
||||
notifyCertUpdates,
|
||||
configObj.meshCaUrl,
|
||||
configObj.zone,
|
||||
configObj.certValiditySeconds, configObj.keySize, configObj.keyAlgo,
|
||||
configObj.certValiditySeconds,
|
||||
configObj.keySize,
|
||||
configObj.keyAlgo,
|
||||
configObj.signatureAlgo,
|
||||
meshCaChannelFactory, backoffPolicyProvider,
|
||||
configObj.renewalGracePeriodSeconds, configObj.maxRetryAttempts, stsCredentials);
|
||||
meshCaChannelFactory,
|
||||
backoffPolicyProvider,
|
||||
configObj.renewalGracePeriodSeconds,
|
||||
configObj.maxRetryAttempts,
|
||||
stsCredentials,
|
||||
scheduledExecutorServiceFactory.create(configObj.meshCaUrl),
|
||||
timeProvider,
|
||||
TimeUnit.SECONDS.toMillis(RPC_TIMEOUT_SECONDS));
|
||||
}
|
||||
|
||||
private static Config validateAndTranslateConfig(Object config) {
|
||||
|
@ -177,6 +206,28 @@ final class MeshCaCertificateProviderProvider implements CertificateProviderProv
|
|||
configObj.zone = matcher.group(2);
|
||||
}
|
||||
|
||||
abstract static class ScheduledExecutorServiceFactory {
|
||||
|
||||
private static final ScheduledExecutorServiceFactory DEFAULT_INSTANCE =
|
||||
new ScheduledExecutorServiceFactory() {
|
||||
|
||||
@Override
|
||||
ScheduledExecutorService create(String serverUri) {
|
||||
return Executors.newSingleThreadScheduledExecutor(
|
||||
new ThreadFactoryBuilder()
|
||||
.setNameFormat("meshca-" + serverUri + "-%d")
|
||||
.setDaemon(true)
|
||||
.build());
|
||||
}
|
||||
};
|
||||
|
||||
static ScheduledExecutorServiceFactory getInstance() {
|
||||
return DEFAULT_INSTANCE;
|
||||
}
|
||||
|
||||
abstract ScheduledExecutorService create(String serverUri);
|
||||
}
|
||||
|
||||
/** POJO class for storing various config values. */
|
||||
@VisibleForTesting
|
||||
static class Config {
|
||||
|
|
|
@ -30,7 +30,7 @@ import java.util.Collection;
|
|||
/**
|
||||
* Contains certificate utility method(s).
|
||||
*/
|
||||
final class CertificateUtils {
|
||||
public final class CertificateUtils {
|
||||
|
||||
private static CertificateFactory factory;
|
||||
|
||||
|
@ -46,19 +46,26 @@ final class CertificateUtils {
|
|||
* @param file a {@link File} containing the cert data
|
||||
*/
|
||||
static X509Certificate[] toX509Certificates(File file) throws CertificateException, IOException {
|
||||
FileInputStream fis = new FileInputStream(file);
|
||||
return toX509Certificates(new BufferedInputStream(fis));
|
||||
try (FileInputStream fis = new FileInputStream(file);
|
||||
BufferedInputStream bis = new BufferedInputStream(fis)) {
|
||||
return toX509Certificates(bis);
|
||||
}
|
||||
}
|
||||
|
||||
static synchronized X509Certificate[] toX509Certificates(InputStream inputStream)
|
||||
throws CertificateException, IOException {
|
||||
initInstance();
|
||||
try {
|
||||
Collection<? extends Certificate> certs = factory.generateCertificates(inputStream);
|
||||
return certs.toArray(new X509Certificate[0]);
|
||||
} finally {
|
||||
inputStream.close();
|
||||
}
|
||||
Collection<? extends Certificate> certs = factory.generateCertificates(inputStream);
|
||||
return certs.toArray(new X509Certificate[0]);
|
||||
|
||||
}
|
||||
|
||||
/** See {@link CertificateFactory#generateCertificate(InputStream)}. */
|
||||
public static synchronized X509Certificate toX509Certificate(InputStream inputStream)
|
||||
throws CertificateException, IOException {
|
||||
initInstance();
|
||||
Certificate cert = factory.generateCertificate(inputStream);
|
||||
return (X509Certificate) cert;
|
||||
}
|
||||
|
||||
private CertificateUtils() {}
|
||||
|
|
|
@ -27,6 +27,7 @@ import io.grpc.xds.internal.sds.TlsContextManagerImpl;
|
|||
import io.netty.handler.ssl.util.SimpleTrustManagerFactory;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.security.KeyStore;
|
||||
import java.security.KeyStoreException;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
|
@ -69,8 +70,10 @@ public final class SdsTrustManagerFactory extends SimpleTrustManagerFactory {
|
|||
"trustedCa.file-name in certificateValidationContext cannot be empty");
|
||||
return CertificateUtils.toX509Certificates(new File(certsFile));
|
||||
} else if (specifierCase == SpecifierCase.INLINE_BYTES) {
|
||||
return CertificateUtils.toX509Certificates(
|
||||
certificateValidationContext.getTrustedCa().getInlineBytes().newInput());
|
||||
try (InputStream is =
|
||||
certificateValidationContext.getTrustedCa().getInlineBytes().newInput()) {
|
||||
return CertificateUtils.toX509Certificates(is);
|
||||
}
|
||||
} else {
|
||||
throw new IllegalArgumentException("Not supported: " + specifierCase);
|
||||
}
|
||||
|
|
|
@ -169,7 +169,7 @@ public class CertificateProviderStoreTest {
|
|||
(TestCertificateProvider) handle1.certProvider;
|
||||
assertThat(testCertificateProvider.startCalled).isEqualTo(1);
|
||||
CertificateProvider.DistributorWatcher distWatcher = testCertificateProvider.getWatcher();
|
||||
assertThat(distWatcher.downsstreamWatchers).hasSize(2);
|
||||
assertThat(distWatcher.downstreamWatchers).hasSize(2);
|
||||
PrivateKey testKey = mock(PrivateKey.class);
|
||||
X509Certificate cert = mock(X509Certificate.class);
|
||||
List<X509Certificate> testList = ImmutableList.of(cert);
|
||||
|
@ -185,7 +185,7 @@ public class CertificateProviderStoreTest {
|
|||
reset(mockWatcher2);
|
||||
handle1.close();
|
||||
assertThat(testCertificateProvider.closeCalled).isEqualTo(0);
|
||||
assertThat(distWatcher.downsstreamWatchers).hasSize(1);
|
||||
assertThat(distWatcher.downstreamWatchers).hasSize(1);
|
||||
testCertificateProvider.getWatcher().updateCertificate(testKey, testList);
|
||||
verify(mockWatcher1, never())
|
||||
.updateCertificate(any(PrivateKey.class), anyListOf(X509Certificate.class));
|
||||
|
@ -221,7 +221,7 @@ public class CertificateProviderStoreTest {
|
|||
CertificateProvider.Watcher mockWatcher2 = mock(CertificateProvider.Watcher.class);
|
||||
CertificateProviderStore.Handle unused = certificateProviderStore.createOrGetProvider(
|
||||
"cert-name1", "plugin1", "config", mockWatcher2, true);
|
||||
assertThat(distWatcher.downsstreamWatchers).hasSize(2);
|
||||
assertThat(distWatcher.downstreamWatchers).hasSize(2);
|
||||
// updates sent to the second watcher
|
||||
verify(mockWatcher2, times(1)).updateCertificate(eq(testKey), eq(testList));
|
||||
verify(mockWatcher2, times(1)).updateTrustedRoots(eq(testList));
|
||||
|
@ -323,9 +323,9 @@ public class CertificateProviderStoreTest {
|
|||
assertThat(testCertificateProvider2.certProviderProvider)
|
||||
.isSameInstanceAs(certProviderProvider2);
|
||||
CertificateProvider.DistributorWatcher distWatcher1 = testCertificateProvider1.getWatcher();
|
||||
assertThat(distWatcher1.downsstreamWatchers).hasSize(1);
|
||||
assertThat(distWatcher1.downstreamWatchers).hasSize(1);
|
||||
CertificateProvider.DistributorWatcher distWatcher2 = testCertificateProvider2.getWatcher();
|
||||
assertThat(distWatcher2.downsstreamWatchers).hasSize(1);
|
||||
assertThat(distWatcher2.downstreamWatchers).hasSize(1);
|
||||
PrivateKey testKey1 = mock(PrivateKey.class);
|
||||
X509Certificate cert1 = mock(X509Certificate.class);
|
||||
List<X509Certificate> testList1 = ImmutableList.of(cert1);
|
||||
|
|
|
@ -17,19 +17,25 @@
|
|||
package io.grpc.xds.internal.certprovider;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static io.grpc.xds.internal.certprovider.MeshCaCertificateProviderProvider.RPC_TIMEOUT_SECONDS;
|
||||
import static org.junit.Assert.fail;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.ArgumentMatchers.isNull;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.google.auth.oauth2.GoogleCredentials;
|
||||
import io.grpc.internal.BackoffPolicy;
|
||||
import io.grpc.internal.ExponentialBackoffPolicy;
|
||||
import io.grpc.internal.TimeProvider;
|
||||
import io.grpc.xds.internal.sts.StsCredentials;
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ScheduledExecutorService;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
|
@ -59,6 +65,14 @@ public class MeshCaCertificateProviderProviderTest {
|
|||
|
||||
@Mock
|
||||
MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory;
|
||||
|
||||
@Mock
|
||||
private MeshCaCertificateProviderProvider.ScheduledExecutorServiceFactory
|
||||
scheduledExecutorServiceFactory;
|
||||
|
||||
@Mock
|
||||
private TimeProvider timeProvider;
|
||||
|
||||
private MeshCaCertificateProviderProvider provider;
|
||||
|
||||
@Before
|
||||
|
@ -69,7 +83,9 @@ public class MeshCaCertificateProviderProviderTest {
|
|||
stsCredentialsFactory,
|
||||
meshCaChannelFactory,
|
||||
backoffPolicyProvider,
|
||||
meshCaCertificateProviderFactory);
|
||||
meshCaCertificateProviderFactory,
|
||||
scheduledExecutorServiceFactory,
|
||||
timeProvider);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -94,6 +110,10 @@ public class MeshCaCertificateProviderProviderTest {
|
|||
CertificateProvider.DistributorWatcher distWatcher =
|
||||
new CertificateProvider.DistributorWatcher();
|
||||
Map<String, String> map = buildMinimalMap();
|
||||
ScheduledExecutorService mockService = mock(ScheduledExecutorService.class);
|
||||
when(scheduledExecutorServiceFactory.create(
|
||||
eq(MeshCaCertificateProviderProvider.MESHCA_URL_DEFAULT)))
|
||||
.thenReturn(mockService);
|
||||
provider.createCertificateProvider(map, distWatcher, true);
|
||||
verify(stsCredentialsFactory, times(1))
|
||||
.create(
|
||||
|
@ -114,7 +134,10 @@ public class MeshCaCertificateProviderProviderTest {
|
|||
eq(backoffPolicyProvider),
|
||||
eq(MeshCaCertificateProviderProvider.RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT),
|
||||
eq(MeshCaCertificateProviderProvider.MAX_RETRY_ATTEMPTS_DEFAULT),
|
||||
(GoogleCredentials) isNull());
|
||||
(GoogleCredentials) isNull(),
|
||||
eq(mockService),
|
||||
eq(timeProvider),
|
||||
eq(TimeUnit.SECONDS.toMillis(RPC_TIMEOUT_SECONDS)));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -150,7 +173,9 @@ public class MeshCaCertificateProviderProviderTest {
|
|||
CertificateProvider.DistributorWatcher distWatcher =
|
||||
new CertificateProvider.DistributorWatcher();
|
||||
Map<String, String> map = buildMinimalMap();
|
||||
map.put("gkeClusterUrl", "https://container.googleapis.com/v1/project/test-project1/locations/test-zone2/clusters/test-cluster3");
|
||||
map.put(
|
||||
"gkeClusterUrl",
|
||||
"https://container.googleapis.com/v1/project/test-project1/locations/test-zone2/clusters/test-cluster3");
|
||||
try {
|
||||
provider.createCertificateProvider(map, distWatcher, true);
|
||||
fail("exception expected");
|
||||
|
@ -164,6 +189,9 @@ public class MeshCaCertificateProviderProviderTest {
|
|||
CertificateProvider.DistributorWatcher distWatcher =
|
||||
new CertificateProvider.DistributorWatcher();
|
||||
Map<String, String> map = buildFullMap();
|
||||
ScheduledExecutorService mockService = mock(ScheduledExecutorService.class);
|
||||
when(scheduledExecutorServiceFactory.create(eq(NON_DEFAULT_MESH_CA_URL)))
|
||||
.thenReturn(mockService);
|
||||
provider.createCertificateProvider(map, distWatcher, true);
|
||||
verify(stsCredentialsFactory, times(1))
|
||||
.create(
|
||||
|
@ -184,7 +212,10 @@ public class MeshCaCertificateProviderProviderTest {
|
|||
eq(backoffPolicyProvider),
|
||||
eq(4321L),
|
||||
eq(9),
|
||||
(GoogleCredentials) isNull());
|
||||
(GoogleCredentials) isNull(),
|
||||
eq(mockService),
|
||||
eq(timeProvider),
|
||||
eq(TimeUnit.SECONDS.toMillis(RPC_TIMEOUT_SECONDS)));
|
||||
}
|
||||
|
||||
private Map<String, String> buildFullMap() {
|
||||
|
|
|
@ -0,0 +1,538 @@
|
|||
/*
|
||||
* Copyright 2020 The gRPC Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.xds.internal.certprovider;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
|
||||
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE;
|
||||
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
|
||||
import static org.mockito.AdditionalAnswers.delegatesTo;
|
||||
import static org.mockito.Mockito.any;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
import static org.mockito.Mockito.doReturn;
|
||||
import static org.mockito.Mockito.eq;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.reset;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.google.auth.http.AuthHttpConstants;
|
||||
import com.google.auth.oauth2.AccessToken;
|
||||
import com.google.auth.oauth2.GoogleCredentials;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.util.concurrent.MoreExecutors;
|
||||
import google.security.meshca.v1.MeshCertificateServiceGrpc;
|
||||
import google.security.meshca.v1.Meshca;
|
||||
import io.grpc.Context;
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.StatusRuntimeException;
|
||||
import io.grpc.SynchronizationContext;
|
||||
import io.grpc.inprocess.InProcessChannelBuilder;
|
||||
import io.grpc.inprocess.InProcessServerBuilder;
|
||||
import io.grpc.internal.BackoffPolicy;
|
||||
import io.grpc.internal.TimeProvider;
|
||||
import io.grpc.testing.GrpcCleanupRule;
|
||||
import io.grpc.xds.internal.certprovider.CertificateProvider.DistributorWatcher;
|
||||
import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil;
|
||||
import java.io.IOException;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.security.PrivateKey;
|
||||
import java.security.cert.CertificateException;
|
||||
import java.security.cert.X509Certificate;
|
||||
import java.util.ArrayDeque;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
import java.util.Queue;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.ScheduledExecutorService;
|
||||
import java.util.concurrent.ScheduledFuture;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.TimeoutException;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import org.bouncycastle.operator.OperatorCreationException;
|
||||
import org.junit.Before;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.ArgumentMatchers;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.MockitoAnnotations;
|
||||
import org.mockito.Spy;
|
||||
import org.mockito.invocation.InvocationOnMock;
|
||||
import org.mockito.stubbing.Answer;
|
||||
|
||||
/** Unit tests for {@link MeshCaCertificateProvider}. */
|
||||
@RunWith(JUnit4.class)
|
||||
public class MeshCaCertificateProviderTest {
|
||||
|
||||
private static final String TEST_STS_TOKEN = "test-stsToken";
|
||||
private static final long RENEWAL_GRACE_PERIOD_SECONDS = TimeUnit.HOURS.toSeconds(1L);
|
||||
private static final Metadata.Key<String> KEY_FOR_AUTHORIZATION =
|
||||
Metadata.Key.of(AuthHttpConstants.AUTHORIZATION, Metadata.ASCII_STRING_MARSHALLER);
|
||||
private static final String ZONE = "us-west2-a";
|
||||
private static final long START_DELAY = 200_000_000L; // 0.2 seconds
|
||||
private static final long[] DELAY_VALUES = {START_DELAY, START_DELAY * 2, START_DELAY * 4};
|
||||
private static final long RPC_TIMEOUT_MILLIS = 100L;
|
||||
/**
|
||||
* Expire time of cert SERVER_0_PEM_FILE.
|
||||
*/
|
||||
private static final long CERT0_EXPIRY_TIME_MILLIS = 1899853658000L;
|
||||
/**
|
||||
* Cert validity of 12 hours for the above cert.
|
||||
*/
|
||||
private static final long CERT0_VALIDITY_MILLIS = TimeUnit.MILLISECONDS
|
||||
.convert(12, TimeUnit.HOURS);
|
||||
/**
|
||||
* Compute current time based on cert expiry and cert validity.
|
||||
*/
|
||||
private static final long CURRENT_TIME_NANOS =
|
||||
TimeUnit.MILLISECONDS.toNanos(CERT0_EXPIRY_TIME_MILLIS - CERT0_VALIDITY_MILLIS);
|
||||
@Rule
|
||||
public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule();
|
||||
|
||||
private static class ResponseToSend {
|
||||
Throwable getThrowable() {
|
||||
throw new UnsupportedOperationException("Called on " + getClass().getCanonicalName());
|
||||
}
|
||||
|
||||
List<String> getList() {
|
||||
throw new UnsupportedOperationException("Called on " + getClass().getCanonicalName());
|
||||
}
|
||||
}
|
||||
|
||||
private static class ResponseThrowable extends ResponseToSend {
|
||||
final Throwable throwableToSend;
|
||||
|
||||
ResponseThrowable(Throwable throwable) {
|
||||
throwableToSend = throwable;
|
||||
}
|
||||
|
||||
@Override
|
||||
Throwable getThrowable() {
|
||||
return throwableToSend;
|
||||
}
|
||||
}
|
||||
|
||||
private static class ResponseList extends ResponseToSend {
|
||||
final List<String> listToSend;
|
||||
|
||||
ResponseList(List<String> list) {
|
||||
listToSend = list;
|
||||
}
|
||||
|
||||
@Override
|
||||
List<String> getList() {
|
||||
return listToSend;
|
||||
}
|
||||
}
|
||||
|
||||
private final Queue<Meshca.MeshCertificateRequest> receivedRequests = new ArrayDeque<>();
|
||||
private final Queue<String> receivedStsCreds = new ArrayDeque<>();
|
||||
private final Queue<String> receivedZoneValues = new ArrayDeque<>();
|
||||
private final Queue<ResponseToSend> responsesToSend = new ArrayDeque<>();
|
||||
private final Queue<String> oauth2Tokens = new ArrayDeque<>();
|
||||
private final AtomicBoolean callEnded = new AtomicBoolean(true);
|
||||
|
||||
@Mock private MeshCertificateServiceGrpc.MeshCertificateServiceImplBase mockedMeshCaService;
|
||||
@Mock private CertificateProvider.Watcher mockWatcher;
|
||||
@Mock private BackoffPolicy.Provider backoffPolicyProvider;
|
||||
@Mock private BackoffPolicy backoffPolicy;
|
||||
@Spy private GoogleCredentials oauth2Creds;
|
||||
@Mock private ScheduledExecutorService timeService;
|
||||
@Mock private TimeProvider timeProvider;
|
||||
|
||||
private ManagedChannel channel;
|
||||
private MeshCaCertificateProvider provider;
|
||||
|
||||
@Before
|
||||
public void setUp() throws IOException {
|
||||
MockitoAnnotations.initMocks(this);
|
||||
when(backoffPolicyProvider.get()).thenReturn(backoffPolicy);
|
||||
when(backoffPolicy.nextBackoffNanos())
|
||||
.thenReturn(DELAY_VALUES[0], DELAY_VALUES[1], DELAY_VALUES[2]);
|
||||
doAnswer(
|
||||
new Answer<AccessToken>() {
|
||||
@Override
|
||||
public AccessToken answer(InvocationOnMock invocation) throws Throwable {
|
||||
return new AccessToken(
|
||||
oauth2Tokens.poll(), new Date(System.currentTimeMillis() + 1000L));
|
||||
}
|
||||
})
|
||||
.when(oauth2Creds)
|
||||
.refreshAccessToken();
|
||||
final String meshCaUri = InProcessServerBuilder.generateName();
|
||||
MeshCertificateServiceGrpc.MeshCertificateServiceImplBase meshCaServiceImpl =
|
||||
new MeshCertificateServiceGrpc.MeshCertificateServiceImplBase() {
|
||||
|
||||
@Override
|
||||
public void createCertificate(
|
||||
google.security.meshca.v1.Meshca.MeshCertificateRequest request,
|
||||
io.grpc.stub.StreamObserver<google.security.meshca.v1.Meshca.MeshCertificateResponse>
|
||||
responseObserver) {
|
||||
assertThat(callEnded.get()).isTrue(); // ensure previous call was ended
|
||||
callEnded.set(false);
|
||||
Context.current()
|
||||
.addListener(
|
||||
new Context.CancellationListener() {
|
||||
@Override
|
||||
public void cancelled(Context context) {
|
||||
callEnded.set(true);
|
||||
}
|
||||
},
|
||||
MoreExecutors.directExecutor());
|
||||
receivedRequests.offer(request);
|
||||
ResponseToSend response = responsesToSend.poll();
|
||||
if (response instanceof ResponseThrowable) {
|
||||
responseObserver.onError(response.getThrowable());
|
||||
} else if (response instanceof ResponseList) {
|
||||
List<String> certChainInResponse = response.getList();
|
||||
Meshca.MeshCertificateResponse responseToSend =
|
||||
Meshca.MeshCertificateResponse.newBuilder()
|
||||
.addAllCertChain(certChainInResponse)
|
||||
.build();
|
||||
responseObserver.onNext(responseToSend);
|
||||
responseObserver.onCompleted();
|
||||
} else {
|
||||
callEnded.set(true);
|
||||
}
|
||||
}
|
||||
};
|
||||
mockedMeshCaService =
|
||||
mock(
|
||||
MeshCertificateServiceGrpc.MeshCertificateServiceImplBase.class,
|
||||
delegatesTo(meshCaServiceImpl));
|
||||
ServerInterceptor interceptor =
|
||||
new ServerInterceptor() {
|
||||
@Override
|
||||
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
|
||||
ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
|
||||
receivedStsCreds.offer(headers.get(KEY_FOR_AUTHORIZATION));
|
||||
receivedZoneValues.offer(headers.get(MeshCaCertificateProvider.KEY_FOR_ZONE_INFO));
|
||||
return next.startCall(call, headers);
|
||||
}
|
||||
};
|
||||
cleanupRule.register(
|
||||
InProcessServerBuilder.forName(meshCaUri)
|
||||
.addService(mockedMeshCaService)
|
||||
.intercept(interceptor)
|
||||
.directExecutor()
|
||||
.build()
|
||||
.start());
|
||||
channel =
|
||||
cleanupRule.register(InProcessChannelBuilder.forName(meshCaUri).directExecutor().build());
|
||||
MeshCaCertificateProvider.MeshCaChannelFactory channelFactory =
|
||||
new MeshCaCertificateProvider.MeshCaChannelFactory() {
|
||||
@Override
|
||||
ManagedChannel createChannel(String serverUri) {
|
||||
assertThat(serverUri).isEqualTo(meshCaUri);
|
||||
return channel;
|
||||
}
|
||||
};
|
||||
CertificateProvider.DistributorWatcher watcher = new CertificateProvider.DistributorWatcher();
|
||||
watcher.addWatcher(mockWatcher); //
|
||||
provider =
|
||||
new MeshCaCertificateProvider(
|
||||
watcher,
|
||||
true,
|
||||
meshCaUri,
|
||||
ZONE,
|
||||
TimeUnit.HOURS.toSeconds(9L),
|
||||
2048,
|
||||
"RSA",
|
||||
"SHA256withRSA",
|
||||
channelFactory,
|
||||
backoffPolicyProvider,
|
||||
RENEWAL_GRACE_PERIOD_SECONDS,
|
||||
MeshCaCertificateProviderProvider.MAX_RETRY_ATTEMPTS_DEFAULT,
|
||||
oauth2Creds,
|
||||
timeService,
|
||||
timeProvider,
|
||||
RPC_TIMEOUT_MILLIS);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void startAndClose() {
|
||||
ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
|
||||
doReturn(scheduledFuture)
|
||||
.when(timeService)
|
||||
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
|
||||
provider.start();
|
||||
SynchronizationContext.ScheduledHandle savedScheduledHandle = provider.scheduledHandle;
|
||||
assertThat(savedScheduledHandle).isNotNull();
|
||||
assertThat(savedScheduledHandle.isPending()).isTrue();
|
||||
verify(timeService, times(1))
|
||||
.schedule(
|
||||
any(Runnable.class),
|
||||
eq(MeshCaCertificateProvider.INITIAL_DELAY_SECONDS),
|
||||
eq(TimeUnit.SECONDS));
|
||||
DistributorWatcher distWatcher = provider.getWatcher();
|
||||
assertThat(distWatcher.downstreamWatchers).hasSize(1);
|
||||
PrivateKey mockKey = mock(PrivateKey.class);
|
||||
X509Certificate mockCert = mock(X509Certificate.class);
|
||||
distWatcher.updateCertificate(mockKey, ImmutableList.of(mockCert));
|
||||
distWatcher.updateTrustedRoots(ImmutableList.of(mockCert));
|
||||
provider.close();
|
||||
assertThat(provider.scheduledHandle).isNull();
|
||||
assertThat(savedScheduledHandle.isPending()).isFalse();
|
||||
assertThat(distWatcher.downstreamWatchers).isEmpty();
|
||||
assertThat(distWatcher.getLastIdentityCert()).isNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void startTwice_noException() {
|
||||
ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
|
||||
doReturn(scheduledFuture)
|
||||
.when(timeService)
|
||||
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
|
||||
provider.start();
|
||||
SynchronizationContext.ScheduledHandle savedScheduledHandle1 = provider.scheduledHandle;
|
||||
provider.start();
|
||||
SynchronizationContext.ScheduledHandle savedScheduledHandle2 = provider.scheduledHandle;
|
||||
assertThat(savedScheduledHandle2).isNotSameInstanceAs(savedScheduledHandle1);
|
||||
assertThat(savedScheduledHandle2.isPending()).isTrue();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getCertificate()
|
||||
throws IOException, CertificateException, OperatorCreationException,
|
||||
NoSuchAlgorithmException {
|
||||
oauth2Tokens.offer(TEST_STS_TOKEN + "0");
|
||||
responsesToSend.offer(
|
||||
new ResponseList(ImmutableList.of(
|
||||
CommonTlsContextTestsUtil.getResourceContents(SERVER_0_PEM_FILE),
|
||||
CommonTlsContextTestsUtil.getResourceContents(SERVER_1_PEM_FILE),
|
||||
CommonTlsContextTestsUtil.getResourceContents(CA_PEM_FILE))));
|
||||
when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS);
|
||||
ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
|
||||
doReturn(scheduledFuture)
|
||||
.when(timeService)
|
||||
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
|
||||
provider.refreshCertificate();
|
||||
Meshca.MeshCertificateRequest receivedReq = receivedRequests.poll();
|
||||
assertThat(receivedReq.getValidity().getSeconds()).isEqualTo(TimeUnit.HOURS.toSeconds(9L));
|
||||
// cannot decode CSR: just check the PEM format delimiters
|
||||
String csr = receivedReq.getCsr();
|
||||
assertThat(csr).startsWith("-----BEGIN NEW CERTIFICATE REQUEST-----");
|
||||
verifyReceivedMetadataValues(1);
|
||||
verify(timeService, times(1))
|
||||
.schedule(
|
||||
any(Runnable.class),
|
||||
eq(
|
||||
TimeUnit.MILLISECONDS.toSeconds(
|
||||
CERT0_VALIDITY_MILLIS
|
||||
- TimeUnit.SECONDS.toMillis(RENEWAL_GRACE_PERIOD_SECONDS))),
|
||||
eq(TimeUnit.SECONDS));
|
||||
verifyMockWatcher();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getCertificate_withError()
|
||||
throws IOException, OperatorCreationException, NoSuchAlgorithmException {
|
||||
oauth2Tokens.offer(TEST_STS_TOKEN + "0");
|
||||
responsesToSend
|
||||
.offer(new ResponseThrowable(new StatusRuntimeException(Status.FAILED_PRECONDITION)));
|
||||
ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
|
||||
doReturn(scheduledFuture).when(timeService)
|
||||
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
|
||||
provider.refreshCertificate();
|
||||
verify(mockWatcher, never())
|
||||
.updateCertificate(any(PrivateKey.class), ArgumentMatchers.<X509Certificate>anyList());
|
||||
verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList());
|
||||
verify(mockWatcher, times(1)).onError(Status.FAILED_PRECONDITION);
|
||||
verify(timeService, times(1)).schedule(any(Runnable.class),
|
||||
eq(MeshCaCertificateProvider.INITIAL_DELAY_SECONDS),
|
||||
eq(TimeUnit.SECONDS));
|
||||
verifyReceivedMetadataValues(1);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getCertificate_withError_withExistingCert()
|
||||
throws IOException, OperatorCreationException, NoSuchAlgorithmException {
|
||||
PrivateKey mockKey = mock(PrivateKey.class);
|
||||
X509Certificate mockCert = mock(X509Certificate.class);
|
||||
// have current cert expire in 3 hours from current time
|
||||
long threeHoursFromNowMillis = TimeUnit.NANOSECONDS
|
||||
.toMillis(CURRENT_TIME_NANOS + TimeUnit.HOURS.toNanos(3));
|
||||
when(mockCert.getNotAfter()).thenReturn(new Date(threeHoursFromNowMillis));
|
||||
provider.getWatcher().updateCertificate(mockKey, ImmutableList.of(mockCert));
|
||||
reset(mockWatcher);
|
||||
oauth2Tokens.offer(TEST_STS_TOKEN + "0");
|
||||
responsesToSend
|
||||
.offer(new ResponseThrowable(new StatusRuntimeException(Status.FAILED_PRECONDITION)));
|
||||
when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS);
|
||||
ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
|
||||
doReturn(scheduledFuture).when(timeService)
|
||||
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
|
||||
provider.refreshCertificate();
|
||||
verify(mockWatcher, never())
|
||||
.updateCertificate(any(PrivateKey.class), ArgumentMatchers.<X509Certificate>anyList());
|
||||
verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList());
|
||||
verify(mockWatcher, never()).onError(any(Status.class));
|
||||
verify(timeService, times(1)).schedule(any(Runnable.class),
|
||||
eq(5400L),
|
||||
eq(TimeUnit.SECONDS));
|
||||
assertThat(provider.getWatcher().getLastIdentityCert()).isNotNull();
|
||||
verifyReceivedMetadataValues(1);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getCertificate_withError_withExistingExpiredCert()
|
||||
throws IOException, OperatorCreationException, NoSuchAlgorithmException {
|
||||
PrivateKey mockKey = mock(PrivateKey.class);
|
||||
X509Certificate mockCert = mock(X509Certificate.class);
|
||||
// have current cert expire in 3 seconds from current time
|
||||
long threeSecondsFromNowMillis = TimeUnit.NANOSECONDS
|
||||
.toMillis(CURRENT_TIME_NANOS + TimeUnit.SECONDS.toNanos(3));
|
||||
when(mockCert.getNotAfter()).thenReturn(new Date(threeSecondsFromNowMillis));
|
||||
provider.getWatcher().updateCertificate(mockKey, ImmutableList.of(mockCert));
|
||||
reset(mockWatcher);
|
||||
oauth2Tokens.offer(TEST_STS_TOKEN + "0");
|
||||
responsesToSend
|
||||
.offer(new ResponseThrowable(new StatusRuntimeException(Status.FAILED_PRECONDITION)));
|
||||
when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS);
|
||||
ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
|
||||
doReturn(scheduledFuture).when(timeService)
|
||||
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
|
||||
provider.refreshCertificate();
|
||||
verify(mockWatcher, never())
|
||||
.updateCertificate(any(PrivateKey.class), ArgumentMatchers.<X509Certificate>anyList());
|
||||
verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList());
|
||||
verify(mockWatcher, times(1)).onError(Status.FAILED_PRECONDITION);
|
||||
verify(timeService, times(1)).schedule(any(Runnable.class),
|
||||
eq(MeshCaCertificateProvider.INITIAL_DELAY_SECONDS),
|
||||
eq(TimeUnit.SECONDS));
|
||||
assertThat(provider.getWatcher().getLastIdentityCert()).isNull();
|
||||
verifyReceivedMetadataValues(1);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getCertificate_retriesWithErrors()
|
||||
throws IOException, CertificateException, OperatorCreationException,
|
||||
NoSuchAlgorithmException, InterruptedException, ExecutionException, TimeoutException {
|
||||
oauth2Tokens.offer(TEST_STS_TOKEN + "0");
|
||||
oauth2Tokens.offer(TEST_STS_TOKEN + "1");
|
||||
oauth2Tokens.offer(TEST_STS_TOKEN + "2");
|
||||
responsesToSend.offer(new ResponseThrowable(new StatusRuntimeException(Status.UNKNOWN)));
|
||||
responsesToSend.offer(
|
||||
new ResponseThrowable(
|
||||
new Exception(new StatusRuntimeException(Status.RESOURCE_EXHAUSTED))));
|
||||
responsesToSend.offer(new ResponseList(ImmutableList.of(
|
||||
CommonTlsContextTestsUtil.getResourceContents(SERVER_0_PEM_FILE),
|
||||
CommonTlsContextTestsUtil.getResourceContents(SERVER_1_PEM_FILE),
|
||||
CommonTlsContextTestsUtil.getResourceContents(CA_PEM_FILE))));
|
||||
when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS);
|
||||
ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
|
||||
doReturn(scheduledFuture).when(timeService)
|
||||
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
|
||||
ScheduledFuture<?> scheduledFutureSleep = mock(ScheduledFuture.class);
|
||||
doReturn(scheduledFutureSleep).when(timeService)
|
||||
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.NANOSECONDS));
|
||||
provider.refreshCertificate();
|
||||
assertThat(receivedRequests.size()).isEqualTo(3);
|
||||
verify(timeService, times(1)).schedule(any(Runnable.class),
|
||||
eq(TimeUnit.MILLISECONDS.toSeconds(
|
||||
CERT0_VALIDITY_MILLIS - TimeUnit.SECONDS.toMillis(RENEWAL_GRACE_PERIOD_SECONDS))),
|
||||
eq(TimeUnit.SECONDS));
|
||||
verifyRetriesWithBackoff(scheduledFutureSleep, 2);
|
||||
verifyMockWatcher();
|
||||
verifyReceivedMetadataValues(3);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getCertificate_retriesWithTimeouts()
|
||||
throws IOException, CertificateException, OperatorCreationException,
|
||||
NoSuchAlgorithmException, InterruptedException, ExecutionException, TimeoutException {
|
||||
oauth2Tokens.offer(TEST_STS_TOKEN + "0");
|
||||
oauth2Tokens.offer(TEST_STS_TOKEN + "1");
|
||||
oauth2Tokens.offer(TEST_STS_TOKEN + "2");
|
||||
oauth2Tokens.offer(TEST_STS_TOKEN + "3");
|
||||
responsesToSend.offer(new ResponseToSend());
|
||||
responsesToSend.offer(new ResponseToSend());
|
||||
responsesToSend.offer(new ResponseToSend());
|
||||
responsesToSend.offer(new ResponseList(ImmutableList.of(
|
||||
CommonTlsContextTestsUtil.getResourceContents(SERVER_0_PEM_FILE),
|
||||
CommonTlsContextTestsUtil.getResourceContents(SERVER_1_PEM_FILE),
|
||||
CommonTlsContextTestsUtil.getResourceContents(CA_PEM_FILE))));
|
||||
when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS);
|
||||
ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
|
||||
doReturn(scheduledFuture).when(timeService)
|
||||
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
|
||||
ScheduledFuture<?> scheduledFutureSleep = mock(ScheduledFuture.class);
|
||||
doReturn(scheduledFutureSleep).when(timeService)
|
||||
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.NANOSECONDS));
|
||||
provider.refreshCertificate();
|
||||
assertThat(receivedRequests.size()).isEqualTo(4);
|
||||
verify(timeService, times(1)).schedule(any(Runnable.class),
|
||||
eq(TimeUnit.MILLISECONDS.toSeconds(
|
||||
CERT0_VALIDITY_MILLIS - TimeUnit.SECONDS.toMillis(RENEWAL_GRACE_PERIOD_SECONDS))),
|
||||
eq(TimeUnit.SECONDS));
|
||||
verifyRetriesWithBackoff(scheduledFutureSleep, 3);
|
||||
verifyMockWatcher();
|
||||
verifyReceivedMetadataValues(4);
|
||||
}
|
||||
|
||||
private void verifyRetriesWithBackoff(ScheduledFuture<?> scheduledFutureSleep, int numOfRetries)
|
||||
throws InterruptedException, ExecutionException, TimeoutException {
|
||||
for (int i = 0; i < numOfRetries; i++) {
|
||||
long delayValue = DELAY_VALUES[i];
|
||||
verify(timeService, times(1)).schedule(any(Runnable.class),
|
||||
eq(delayValue),
|
||||
eq(TimeUnit.NANOSECONDS));
|
||||
verify(scheduledFutureSleep, times(1)).get(eq(delayValue), eq(TimeUnit.NANOSECONDS));
|
||||
}
|
||||
}
|
||||
|
||||
private void verifyMockWatcher() throws IOException, CertificateException {
|
||||
ArgumentCaptor<List<X509Certificate>> certChainCaptor = ArgumentCaptor.forClass(null);
|
||||
verify(mockWatcher, times(1))
|
||||
.updateCertificate(any(PrivateKey.class), certChainCaptor.capture());
|
||||
List<X509Certificate> certChain = certChainCaptor.getValue();
|
||||
assertThat(certChain).hasSize(3);
|
||||
assertThat(certChain.get(0))
|
||||
.isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(SERVER_0_PEM_FILE));
|
||||
assertThat(certChain.get(1))
|
||||
.isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(SERVER_1_PEM_FILE));
|
||||
assertThat(certChain.get(2))
|
||||
.isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(CA_PEM_FILE));
|
||||
|
||||
ArgumentCaptor<List<X509Certificate>> rootsCaptor = ArgumentCaptor.forClass(null);
|
||||
verify(mockWatcher, times(1)).updateTrustedRoots(rootsCaptor.capture());
|
||||
List<X509Certificate> roots = rootsCaptor.getValue();
|
||||
assertThat(roots).hasSize(1);
|
||||
assertThat(roots.get(0))
|
||||
.isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(CA_PEM_FILE));
|
||||
verify(mockWatcher, never()).onError(any(Status.class));
|
||||
}
|
||||
|
||||
private void verifyReceivedMetadataValues(int count) {
|
||||
assertThat(receivedStsCreds).hasSize(count);
|
||||
assertThat(receivedZoneValues).hasSize(count);
|
||||
for (int i = 0; i < count; i++) {
|
||||
assertThat(receivedStsCreds.poll()).isEqualTo("Bearer " + TEST_STS_TOKEN + i);
|
||||
assertThat(receivedZoneValues.poll()).isEqualTo("us-west2-a");
|
||||
}
|
||||
}
|
||||
}
|
|
@ -16,7 +16,10 @@
|
|||
|
||||
package io.grpc.xds.internal.sds;
|
||||
|
||||
import static java.nio.charset.StandardCharsets.UTF_8;
|
||||
|
||||
import com.google.common.base.Strings;
|
||||
import com.google.common.io.CharStreams;
|
||||
import com.google.protobuf.BoolValue;
|
||||
import com.google.protobuf.Struct;
|
||||
import com.google.protobuf.Value;
|
||||
|
@ -36,7 +39,14 @@ import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContex
|
|||
import io.envoyproxy.envoy.type.matcher.v3.StringMatcher;
|
||||
import io.grpc.internal.testing.TestUtils;
|
||||
import io.grpc.xds.EnvoyServerProtoData;
|
||||
import io.grpc.xds.internal.sds.trust.CertificateUtils;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.InputStreamReader;
|
||||
import java.io.Reader;
|
||||
import java.security.cert.CertificateException;
|
||||
import java.security.cert.X509Certificate;
|
||||
import java.util.Arrays;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
|
@ -432,4 +442,23 @@ public class CommonTlsContextTestsUtil {
|
|||
return EnvoyServerProtoData.UpstreamTlsContext.fromEnvoyProtoUpstreamTlsContext(
|
||||
upstreamTlsContext);
|
||||
}
|
||||
|
||||
/** Gets a cert from contents of a resource. */
|
||||
public static X509Certificate getCertFromResourceName(String resourceName)
|
||||
throws IOException, CertificateException {
|
||||
try (ByteArrayInputStream bais =
|
||||
new ByteArrayInputStream(getResourceContents(resourceName).getBytes(UTF_8))) {
|
||||
return CertificateUtils.toX509Certificate(bais);
|
||||
}
|
||||
}
|
||||
|
||||
/** Gets contents of a resource from TestUtils.class loader. */
|
||||
public static String getResourceContents(String resourceName) throws IOException {
|
||||
InputStream inputStream = TestUtils.class.getResourceAsStream("/certs/" + resourceName);
|
||||
String text = null;
|
||||
try (Reader reader = new InputStreamReader(inputStream, UTF_8)) {
|
||||
text = CharStreams.toString(reader);
|
||||
}
|
||||
return text;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue