xds: implement MeshCACertificateProvider (#7274)

This commit is contained in:
sanjaypujare 2020-08-07 16:16:22 -07:00 committed by GitHub
parent 0a99a20b70
commit 65e7ffc788
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 1107 additions and 62 deletions

View File

@ -177,6 +177,8 @@ subprojects {
conscrypt: 'org.conscrypt:conscrypt-openjdk-uber:2.2.1',
re2j: 'com.google.re2j:re2j:1.2',
bouncycastle: 'org.bouncycastle:bcpkix-jdk15on:1.61',
// Test dependencies.
junit: 'junit:junit:4.12',
mockito: 'org.mockito:mockito-core:3.3.3',

View File

@ -26,7 +26,8 @@ dependencies {
project(':grpc-auth'),
project(path: ':grpc-alts', configuration: 'shadow'),
libraries.gson,
libraries.re2j
libraries.re2j,
libraries.bouncycastle
def nettyDependency = implementation project(':grpc-netty')
implementation (libraries.opencensus_proto) {

View File

@ -49,42 +49,42 @@ public abstract class CertificateProvider implements Closeable {
@VisibleForTesting
static final class DistributorWatcher implements Watcher {
private PrivateKey lastKey;
private List<X509Certificate> lastCertChain;
private List<X509Certificate> lastTrustedRoots;
private PrivateKey privateKey;
private List<X509Certificate> certChain;
private List<X509Certificate> trustedRoots;
@VisibleForTesting
final Set<Watcher> downsstreamWatchers = new HashSet<>();
final Set<Watcher> downstreamWatchers = new HashSet<>();
synchronized void addWatcher(Watcher watcher) {
downsstreamWatchers.add(watcher);
if (lastKey != null && lastCertChain != null) {
downstreamWatchers.add(watcher);
if (privateKey != null && certChain != null) {
sendLastCertificateUpdate(watcher);
}
if (lastTrustedRoots != null) {
if (trustedRoots != null) {
sendLastTrustedRootsUpdate(watcher);
}
}
synchronized void removeWatcher(Watcher watcher) {
downsstreamWatchers.remove(watcher);
downstreamWatchers.remove(watcher);
}
private void sendLastCertificateUpdate(Watcher watcher) {
watcher.updateCertificate(lastKey, lastCertChain);
watcher.updateCertificate(privateKey, certChain);
}
private void sendLastTrustedRootsUpdate(Watcher watcher) {
watcher.updateTrustedRoots(lastTrustedRoots);
watcher.updateTrustedRoots(trustedRoots);
}
@Override
public synchronized void updateCertificate(PrivateKey key, List<X509Certificate> certChain) {
checkNotNull(key, "key");
checkNotNull(certChain, "certChain");
lastKey = key;
lastCertChain = certChain;
for (Watcher watcher : downsstreamWatchers) {
privateKey = key;
this.certChain = certChain;
for (Watcher watcher : downstreamWatchers) {
sendLastCertificateUpdate(watcher);
}
}
@ -92,18 +92,36 @@ public abstract class CertificateProvider implements Closeable {
@Override
public synchronized void updateTrustedRoots(List<X509Certificate> trustedRoots) {
checkNotNull(trustedRoots, "trustedRoots");
lastTrustedRoots = trustedRoots;
for (Watcher watcher : downsstreamWatchers) {
this.trustedRoots = trustedRoots;
for (Watcher watcher : downstreamWatchers) {
sendLastTrustedRootsUpdate(watcher);
}
}
@Override
public synchronized void onError(Status errorStatus) {
for (Watcher watcher : downsstreamWatchers) {
for (Watcher watcher : downstreamWatchers) {
watcher.onError(errorStatus);
}
}
X509Certificate getLastIdentityCert() {
if (certChain != null && !certChain.isEmpty()) {
return certChain.get(0);
}
return null;
}
void close() {
downstreamWatchers.clear();
clearValues();
}
void clearValues() {
privateKey = null;
certChain = null;
trustedRoots = null;
}
}
/**

View File

@ -17,35 +17,336 @@
package io.grpc.xds.internal.certprovider;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.Status.Code.ABORTED;
import static io.grpc.Status.Code.CANCELLED;
import static io.grpc.Status.Code.DEADLINE_EXCEEDED;
import static io.grpc.Status.Code.INTERNAL;
import static io.grpc.Status.Code.RESOURCE_EXHAUSTED;
import static io.grpc.Status.Code.UNAVAILABLE;
import static io.grpc.Status.Code.UNKNOWN;
import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.Duration;
import google.security.meshca.v1.MeshCertificateServiceGrpc;
import google.security.meshca.v1.Meshca;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.ForwardingClientCall;
import io.grpc.InternalLogId;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
import io.grpc.auth.MoreCallCredentials;
import io.grpc.internal.BackoffPolicy;
import io.grpc.internal.TimeProvider;
import io.grpc.xds.internal.sds.trust.CertificateUtils;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.StringWriter;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.security.auth.x500.X500Principal;
import org.bouncycastle.openssl.jcajce.JcaPEMWriter;
import org.bouncycastle.operator.ContentSigner;
import org.bouncycastle.operator.OperatorCreationException;
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder;
import org.bouncycastle.pkcs.PKCS10CertificationRequest;
import org.bouncycastle.pkcs.PKCS10CertificationRequestBuilder;
import org.bouncycastle.pkcs.jcajce.JcaPKCS10CertificationRequestBuilder;
import org.bouncycastle.util.io.pem.PemObject;
/** Implementation of {@link CertificateProvider} for the Google Mesh CA. */
final class MeshCaCertificateProvider extends CertificateProvider {
private static final Logger logger = Logger.getLogger(MeshCaCertificateProvider.class.getName());
MeshCaCertificateProvider(DistributorWatcher watcher, boolean notifyCertUpdates,
String meshCaUrl, String zone, long validitySeconds,
int keySize, String alg, String signatureAlg, MeshCaChannelFactory meshCaChannelFactory,
BackoffPolicy.Provider backoffPolicyProvider, long renewalGracePeriodSeconds,
int maxRetryAttempts, GoogleCredentials oauth2Creds) {
MeshCaCertificateProvider(
DistributorWatcher watcher,
boolean notifyCertUpdates,
String meshCaUrl,
String zone,
long validitySeconds,
int keySize,
String alg,
String signatureAlg, MeshCaChannelFactory meshCaChannelFactory,
BackoffPolicy.Provider backoffPolicyProvider,
long renewalGracePeriodSeconds,
int maxRetryAttempts,
GoogleCredentials oauth2Creds,
ScheduledExecutorService scheduledExecutorService,
TimeProvider timeProvider,
long rpcTimeoutMillis) {
super(watcher, notifyCertUpdates);
this.meshCaUrl = checkNotNull(meshCaUrl, "meshCaUrl");
checkArgument(
validitySeconds > INITIAL_DELAY_SECONDS,
"validitySeconds must be greater than " + INITIAL_DELAY_SECONDS);
this.validitySeconds = validitySeconds;
this.keySize = keySize;
this.alg = checkNotNull(alg, "alg");
this.signatureAlg = checkNotNull(signatureAlg, "signatureAlg");
this.meshCaChannelFactory = checkNotNull(meshCaChannelFactory, "meshCaChannelFactory");
this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider");
checkArgument(
renewalGracePeriodSeconds > 0L && renewalGracePeriodSeconds < validitySeconds,
"renewalGracePeriodSeconds should be between 0 and " + validitySeconds);
this.renewalGracePeriodSeconds = renewalGracePeriodSeconds;
checkArgument(maxRetryAttempts >= 0, "maxRetryAttempts must be >= 0");
this.maxRetryAttempts = maxRetryAttempts;
this.oauth2Creds = checkNotNull(oauth2Creds, "oauth2Creds");
this.scheduledExecutorService =
checkNotNull(scheduledExecutorService, "scheduledExecutorService");
this.timeProvider = checkNotNull(timeProvider, "timeProvider");
this.headerInterceptor = new ZoneInfoClientInterceptor(checkNotNull(zone, "zone"));
this.syncContext = createSynchronizationContext(meshCaUrl);
this.rpcTimeoutMillis = rpcTimeoutMillis;
}
private SynchronizationContext createSynchronizationContext(String details) {
final InternalLogId logId = InternalLogId.allocate("MeshCaCertificateProvider", details);
return new SynchronizationContext(
new Thread.UncaughtExceptionHandler() {
private boolean panicMode;
@Override
public void uncaughtException(Thread t, Throwable e) {
logger.log(
Level.SEVERE,
"[" + logId + "] Uncaught exception in the SynchronizationContext. Panic!",
e);
panic(e);
}
void panic(final Throwable t) {
if (panicMode) {
// Preserve the first panic information
return;
}
panicMode = true;
close();
}
});
}
@Override
public void start() {
// TODO implement
scheduleNextRefreshCertificate(INITIAL_DELAY_SECONDS);
}
@Override
public void close() {
// TODO implement
if (scheduledHandle != null) {
scheduledHandle.cancel();
scheduledHandle = null;
}
getWatcher().close();
}
private void scheduleNextRefreshCertificate(long delayInSeconds) {
if (scheduledHandle != null && scheduledHandle.isPending()) {
logger.log(Level.SEVERE, "Pending task found: inconsistent state in scheduledHandle!");
scheduledHandle.cancel();
}
RefreshCertificateTask runnable = new RefreshCertificateTask();
scheduledHandle = syncContext.schedule(
runnable, delayInSeconds, TimeUnit.SECONDS, scheduledExecutorService);
}
@VisibleForTesting
void refreshCertificate()
throws NoSuchAlgorithmException, IOException, OperatorCreationException {
long refreshDelaySeconds = computeRefreshSecondsFromCurrentCertExpiry();
ManagedChannel channel = meshCaChannelFactory.createChannel(meshCaUrl);
try {
String uniqueReqIdForAllRetries = UUID.randomUUID().toString();
Duration duration = Duration.newBuilder().setSeconds(validitySeconds).build();
KeyPair keyPair = generateKeyPair();
String csr = generateCsr(keyPair);
MeshCertificateServiceGrpc.MeshCertificateServiceBlockingStub stub =
createStubToMeshCa(channel);
List<X509Certificate> x509Chain = makeRequestWithRetries(stub, uniqueReqIdForAllRetries,
duration, csr);
if (x509Chain != null) {
refreshDelaySeconds =
computeDelaySecondsToCertExpiry(x509Chain.get(0)) - renewalGracePeriodSeconds;
getWatcher().updateCertificate(keyPair.getPrivate(), x509Chain);
getWatcher().updateTrustedRoots(ImmutableList.of(x509Chain.get(x509Chain.size() - 1)));
}
} finally {
shutdownChannel(channel);
scheduleNextRefreshCertificate(refreshDelaySeconds);
}
}
private MeshCertificateServiceGrpc.MeshCertificateServiceBlockingStub createStubToMeshCa(
ManagedChannel channel) {
return MeshCertificateServiceGrpc
.newBlockingStub(channel)
.withCallCredentials(MoreCallCredentials.from(oauth2Creds))
.withInterceptors(headerInterceptor);
}
private List<X509Certificate> makeRequestWithRetries(
MeshCertificateServiceGrpc.MeshCertificateServiceBlockingStub stub,
String reqId,
Duration duration,
String csr) {
Meshca.MeshCertificateRequest request =
Meshca.MeshCertificateRequest.newBuilder()
.setValidity(duration)
.setCsr(csr)
.setRequestId(reqId)
.build();
BackoffPolicy backoffPolicy = backoffPolicyProvider.get();
Throwable lastException = null;
for (int i = 0; i <= maxRetryAttempts; i++) {
try {
Meshca.MeshCertificateResponse response =
stub.withDeadlineAfter(rpcTimeoutMillis, TimeUnit.MILLISECONDS)
.createCertificate(request);
return getX509CertificatesFromResponse(response);
} catch (Throwable t) {
if (!retriable(t)) {
generateErrorIfCurrentCertExpired(t);
return null;
}
lastException = t;
sleepForNanos(backoffPolicy.nextBackoffNanos());
}
}
generateErrorIfCurrentCertExpired(lastException);
return null;
}
private void sleepForNanos(long nanos) {
ScheduledFuture<?> future = scheduledExecutorService.schedule(new Runnable() {
@Override
public void run() {
// do nothing
}
}, nanos, TimeUnit.NANOSECONDS);
try {
future.get(nanos, TimeUnit.NANOSECONDS);
} catch (InterruptedException ie) {
logger.log(Level.SEVERE, "Inside sleep", ie);
Thread.currentThread().interrupt();
} catch (ExecutionException | TimeoutException ex) {
logger.log(Level.SEVERE, "Inside sleep", ex);
}
}
private static boolean retriable(Throwable t) {
return RETRIABLE_CODES.contains(Status.fromThrowable(t).getCode());
}
private void generateErrorIfCurrentCertExpired(Throwable t) {
X509Certificate currentCert = getWatcher().getLastIdentityCert();
if (currentCert != null) {
long delaySeconds = computeDelaySecondsToCertExpiry(currentCert);
if (delaySeconds > INITIAL_DELAY_SECONDS) {
return;
}
getWatcher().clearValues();
}
getWatcher().onError(Status.fromThrowable(t));
}
private KeyPair generateKeyPair() throws NoSuchAlgorithmException {
KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance(alg);
keyPairGenerator.initialize(keySize);
return keyPairGenerator.generateKeyPair();
}
private String generateCsr(KeyPair pair) throws IOException, OperatorCreationException {
PKCS10CertificationRequestBuilder p10Builder =
new JcaPKCS10CertificationRequestBuilder(
new X500Principal("CN=EXAMPLE.COM"), pair.getPublic());
JcaContentSignerBuilder csBuilder = new JcaContentSignerBuilder(signatureAlg);
ContentSigner signer = csBuilder.build(pair.getPrivate());
PKCS10CertificationRequest csr = p10Builder.build(signer);
PemObject pemObject = new PemObject("NEW CERTIFICATE REQUEST", csr.getEncoded());
try (StringWriter str = new StringWriter()) {
try (JcaPEMWriter pemWriter = new JcaPEMWriter(str)) {
pemWriter.writeObject(pemObject);
}
return str.toString();
}
}
/** Compute refresh interval as half of interval to current cert expiry. */
private long computeRefreshSecondsFromCurrentCertExpiry() {
X509Certificate lastCert = getWatcher().getLastIdentityCert();
if (lastCert == null) {
return INITIAL_DELAY_SECONDS;
}
long delayToCertExpirySeconds = computeDelaySecondsToCertExpiry(lastCert) / 2;
return Math.max(delayToCertExpirySeconds, INITIAL_DELAY_SECONDS);
}
@SuppressWarnings("JdkObsolete")
private long computeDelaySecondsToCertExpiry(X509Certificate lastCert) {
checkNotNull(lastCert, "lastCert");
return TimeUnit.NANOSECONDS.toSeconds(
TimeUnit.MILLISECONDS.toNanos(lastCert.getNotAfter().getTime()) - timeProvider
.currentTimeNanos());
}
private static void shutdownChannel(ManagedChannel channel) {
channel.shutdown();
try {
channel.awaitTermination(10, TimeUnit.SECONDS);
} catch (InterruptedException ex) {
logger.log(Level.SEVERE, "awaiting channel Termination", ex);
channel.shutdownNow();
Thread.currentThread().interrupt();
}
}
private List<X509Certificate> getX509CertificatesFromResponse(
Meshca.MeshCertificateResponse response) throws CertificateException, IOException {
List<String> certChain = response.getCertChainList();
List<X509Certificate> x509Chain = new ArrayList<>(certChain.size());
for (String certString : certChain) {
try (ByteArrayInputStream bais = new ByteArrayInputStream(certString.getBytes(UTF_8))) {
x509Chain.add(CertificateUtils.toX509Certificate(bais));
}
}
return x509Chain;
}
@VisibleForTesting
class RefreshCertificateTask implements Runnable {
@Override
public void run() {
try {
refreshCertificate();
} catch (NoSuchAlgorithmException | OperatorCreationException | IOException ex) {
logger.log(Level.SEVERE, "refreshing certificate", ex);
}
}
}
/** Factory for creating channels to MeshCA sever. */
@ -94,7 +395,10 @@ final class MeshCaCertificateProvider extends CertificateProvider {
BackoffPolicy.Provider backoffPolicyProvider,
long renewalGracePeriodSeconds,
int maxRetryAttempts,
GoogleCredentials oauth2Creds) {
GoogleCredentials oauth2Creds,
ScheduledExecutorService scheduledExecutorService,
TimeProvider timeProvider,
long rpcTimeoutMillis) {
return new MeshCaCertificateProvider(
watcher,
notifyCertUpdates,
@ -108,7 +412,10 @@ final class MeshCaCertificateProvider extends CertificateProvider {
backoffPolicyProvider,
renewalGracePeriodSeconds,
maxRetryAttempts,
oauth2Creds);
oauth2Creds,
scheduledExecutorService,
timeProvider,
rpcTimeoutMillis);
}
};
@ -129,6 +436,64 @@ final class MeshCaCertificateProvider extends CertificateProvider {
BackoffPolicy.Provider backoffPolicyProvider,
long renewalGracePeriodSeconds,
int maxRetryAttempts,
GoogleCredentials oauth2Creds);
GoogleCredentials oauth2Creds,
ScheduledExecutorService scheduledExecutorService,
TimeProvider timeProvider,
long rpcTimeoutMillis);
}
private class ZoneInfoClientInterceptor implements ClientInterceptor {
private final String zone;
ZoneInfoClientInterceptor(String zone) {
this.zone = zone;
}
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
return new ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>(
next.newCall(method, callOptions)) {
@Override
public void start(Listener<RespT> responseListener, Metadata headers) {
headers.put(KEY_FOR_ZONE_INFO, zone);
super.start(responseListener, headers);
}
};
}
}
@VisibleForTesting
static final Metadata.Key<String> KEY_FOR_ZONE_INFO =
Metadata.Key.of("x-goog-request-params", Metadata.ASCII_STRING_MARSHALLER);
@VisibleForTesting
static final long INITIAL_DELAY_SECONDS = 4L;
private static final EnumSet<Status.Code> RETRIABLE_CODES =
EnumSet.of(
CANCELLED,
UNKNOWN,
DEADLINE_EXCEEDED,
RESOURCE_EXHAUSTED,
ABORTED,
INTERNAL,
UNAVAILABLE);
private final SynchronizationContext syncContext;
private final ScheduledExecutorService scheduledExecutorService;
private final int maxRetryAttempts;
private final ZoneInfoClientInterceptor headerInterceptor;
private final BackoffPolicy.Provider backoffPolicyProvider;
private final String meshCaUrl;
private final long validitySeconds;
private final long renewalGracePeriodSeconds;
private final int keySize;
private final String alg;
private final String signatureAlg;
private final GoogleCredentials oauth2Creds;
private final TimeProvider timeProvider;
private final MeshCaChannelFactory meshCaChannelFactory;
@VisibleForTesting SynchronizationContext.ScheduledHandle scheduledHandle;
private final long rpcTimeoutMillis;
}

View File

@ -21,10 +21,15 @@ import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.grpc.internal.BackoffPolicy;
import io.grpc.internal.ExponentialBackoffPolicy;
import io.grpc.internal.TimeProvider;
import io.grpc.xds.internal.sts.StsCredentials;
import java.util.Map;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@ -46,16 +51,20 @@ final class MeshCaCertificateProviderProvider implements CertificateProviderProv
private static final String STS_URL_KEY = "stsUrl";
private static final String GKE_SA_JWT_LOCATION_KEY = "gkeSaJwtLocation";
static final String MESHCA_URL_DEFAULT = "meshca.googleapis.com";
static final long RPC_TIMEOUT_SECONDS_DEFAULT = 5L;
static final long CERT_VALIDITY_SECONDS_DEFAULT = 9L * 3600L; // 9 hours
static final long RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT = 1L * 3600L; // 1 hour
static final String KEY_ALGO_DEFAULT = "RSA"; // aka keyType
static final int KEY_SIZE_DEFAULT = 2048;
static final String SIGNATURE_ALGO_DEFAULT = "SHA256withRSA";
static final int MAX_RETRY_ATTEMPTS_DEFAULT = 3;
@VisibleForTesting static final String MESHCA_URL_DEFAULT = "meshca.googleapis.com";
@VisibleForTesting static final long RPC_TIMEOUT_SECONDS_DEFAULT = 5L;
@VisibleForTesting static final long CERT_VALIDITY_SECONDS_DEFAULT = 9L * 3600L;
@VisibleForTesting static final long RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT = 1L * 3600L;
@VisibleForTesting static final String KEY_ALGO_DEFAULT = "RSA"; // aka keyType
@VisibleForTesting static final int KEY_SIZE_DEFAULT = 2048;
@VisibleForTesting static final String SIGNATURE_ALGO_DEFAULT = "SHA256withRSA";
@VisibleForTesting static final int MAX_RETRY_ATTEMPTS_DEFAULT = 3;
@VisibleForTesting
static final String STS_URL_DEFAULT = "https://securetoken.googleapis.com/v1/identitybindingtoken";
@VisibleForTesting
static final long RPC_TIMEOUT_SECONDS = 10L;
private static final Pattern CLUSTER_URL_PATTERN = Pattern
.compile(".*/projects/(.*)/locations/(.*)/clusters/.*");
@ -70,23 +79,32 @@ final class MeshCaCertificateProviderProvider implements CertificateProviderProv
StsCredentials.Factory.getInstance(),
MeshCaCertificateProvider.MeshCaChannelFactory.getInstance(),
new ExponentialBackoffPolicy.Provider(),
MeshCaCertificateProvider.Factory.getInstance()));
MeshCaCertificateProvider.Factory.getInstance(),
ScheduledExecutorServiceFactory.DEFAULT_INSTANCE,
TimeProvider.SYSTEM_TIME_PROVIDER));
}
final StsCredentials.Factory stsCredentialsFactory;
final MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory;
final BackoffPolicy.Provider backoffPolicyProvider;
final MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory;
final ScheduledExecutorServiceFactory scheduledExecutorServiceFactory;
final TimeProvider timeProvider;
@VisibleForTesting
MeshCaCertificateProviderProvider(StsCredentials.Factory stsCredentialsFactory,
MeshCaCertificateProviderProvider(
StsCredentials.Factory stsCredentialsFactory,
MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory,
BackoffPolicy.Provider backoffPolicyProvider,
MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory) {
MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory,
ScheduledExecutorServiceFactory scheduledExecutorServiceFactory,
TimeProvider timeProvider) {
this.stsCredentialsFactory = stsCredentialsFactory;
this.meshCaChannelFactory = meshCaChannelFactory;
this.backoffPolicyProvider = backoffPolicyProvider;
this.meshCaCertificateProviderFactory = meshCaCertificateProviderFactory;
this.scheduledExecutorServiceFactory = scheduledExecutorServiceFactory;
this.timeProvider = timeProvider;
}
@Override
@ -106,12 +124,23 @@ final class MeshCaCertificateProviderProvider implements CertificateProviderProv
StsCredentials stsCredentials = stsCredentialsFactory
.create(configObj.stsUrl, audience, configObj.gkeSaJwtLocation);
return meshCaCertificateProviderFactory.create(watcher, notifyCertUpdates, configObj.meshCaUrl,
return meshCaCertificateProviderFactory.create(
watcher,
notifyCertUpdates,
configObj.meshCaUrl,
configObj.zone,
configObj.certValiditySeconds, configObj.keySize, configObj.keyAlgo,
configObj.certValiditySeconds,
configObj.keySize,
configObj.keyAlgo,
configObj.signatureAlgo,
meshCaChannelFactory, backoffPolicyProvider,
configObj.renewalGracePeriodSeconds, configObj.maxRetryAttempts, stsCredentials);
meshCaChannelFactory,
backoffPolicyProvider,
configObj.renewalGracePeriodSeconds,
configObj.maxRetryAttempts,
stsCredentials,
scheduledExecutorServiceFactory.create(configObj.meshCaUrl),
timeProvider,
TimeUnit.SECONDS.toMillis(RPC_TIMEOUT_SECONDS));
}
private static Config validateAndTranslateConfig(Object config) {
@ -177,6 +206,28 @@ final class MeshCaCertificateProviderProvider implements CertificateProviderProv
configObj.zone = matcher.group(2);
}
abstract static class ScheduledExecutorServiceFactory {
private static final ScheduledExecutorServiceFactory DEFAULT_INSTANCE =
new ScheduledExecutorServiceFactory() {
@Override
ScheduledExecutorService create(String serverUri) {
return Executors.newSingleThreadScheduledExecutor(
new ThreadFactoryBuilder()
.setNameFormat("meshca-" + serverUri + "-%d")
.setDaemon(true)
.build());
}
};
static ScheduledExecutorServiceFactory getInstance() {
return DEFAULT_INSTANCE;
}
abstract ScheduledExecutorService create(String serverUri);
}
/** POJO class for storing various config values. */
@VisibleForTesting
static class Config {

View File

@ -30,7 +30,7 @@ import java.util.Collection;
/**
* Contains certificate utility method(s).
*/
final class CertificateUtils {
public final class CertificateUtils {
private static CertificateFactory factory;
@ -46,19 +46,26 @@ final class CertificateUtils {
* @param file a {@link File} containing the cert data
*/
static X509Certificate[] toX509Certificates(File file) throws CertificateException, IOException {
FileInputStream fis = new FileInputStream(file);
return toX509Certificates(new BufferedInputStream(fis));
try (FileInputStream fis = new FileInputStream(file);
BufferedInputStream bis = new BufferedInputStream(fis)) {
return toX509Certificates(bis);
}
}
static synchronized X509Certificate[] toX509Certificates(InputStream inputStream)
throws CertificateException, IOException {
initInstance();
try {
Collection<? extends Certificate> certs = factory.generateCertificates(inputStream);
return certs.toArray(new X509Certificate[0]);
} finally {
inputStream.close();
}
Collection<? extends Certificate> certs = factory.generateCertificates(inputStream);
return certs.toArray(new X509Certificate[0]);
}
/** See {@link CertificateFactory#generateCertificate(InputStream)}. */
public static synchronized X509Certificate toX509Certificate(InputStream inputStream)
throws CertificateException, IOException {
initInstance();
Certificate cert = factory.generateCertificate(inputStream);
return (X509Certificate) cert;
}
private CertificateUtils() {}

View File

@ -27,6 +27,7 @@ import io.grpc.xds.internal.sds.TlsContextManagerImpl;
import io.netty.handler.ssl.util.SimpleTrustManagerFactory;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
@ -69,8 +70,10 @@ public final class SdsTrustManagerFactory extends SimpleTrustManagerFactory {
"trustedCa.file-name in certificateValidationContext cannot be empty");
return CertificateUtils.toX509Certificates(new File(certsFile));
} else if (specifierCase == SpecifierCase.INLINE_BYTES) {
return CertificateUtils.toX509Certificates(
certificateValidationContext.getTrustedCa().getInlineBytes().newInput());
try (InputStream is =
certificateValidationContext.getTrustedCa().getInlineBytes().newInput()) {
return CertificateUtils.toX509Certificates(is);
}
} else {
throw new IllegalArgumentException("Not supported: " + specifierCase);
}

View File

@ -169,7 +169,7 @@ public class CertificateProviderStoreTest {
(TestCertificateProvider) handle1.certProvider;
assertThat(testCertificateProvider.startCalled).isEqualTo(1);
CertificateProvider.DistributorWatcher distWatcher = testCertificateProvider.getWatcher();
assertThat(distWatcher.downsstreamWatchers).hasSize(2);
assertThat(distWatcher.downstreamWatchers).hasSize(2);
PrivateKey testKey = mock(PrivateKey.class);
X509Certificate cert = mock(X509Certificate.class);
List<X509Certificate> testList = ImmutableList.of(cert);
@ -185,7 +185,7 @@ public class CertificateProviderStoreTest {
reset(mockWatcher2);
handle1.close();
assertThat(testCertificateProvider.closeCalled).isEqualTo(0);
assertThat(distWatcher.downsstreamWatchers).hasSize(1);
assertThat(distWatcher.downstreamWatchers).hasSize(1);
testCertificateProvider.getWatcher().updateCertificate(testKey, testList);
verify(mockWatcher1, never())
.updateCertificate(any(PrivateKey.class), anyListOf(X509Certificate.class));
@ -221,7 +221,7 @@ public class CertificateProviderStoreTest {
CertificateProvider.Watcher mockWatcher2 = mock(CertificateProvider.Watcher.class);
CertificateProviderStore.Handle unused = certificateProviderStore.createOrGetProvider(
"cert-name1", "plugin1", "config", mockWatcher2, true);
assertThat(distWatcher.downsstreamWatchers).hasSize(2);
assertThat(distWatcher.downstreamWatchers).hasSize(2);
// updates sent to the second watcher
verify(mockWatcher2, times(1)).updateCertificate(eq(testKey), eq(testList));
verify(mockWatcher2, times(1)).updateTrustedRoots(eq(testList));
@ -323,9 +323,9 @@ public class CertificateProviderStoreTest {
assertThat(testCertificateProvider2.certProviderProvider)
.isSameInstanceAs(certProviderProvider2);
CertificateProvider.DistributorWatcher distWatcher1 = testCertificateProvider1.getWatcher();
assertThat(distWatcher1.downsstreamWatchers).hasSize(1);
assertThat(distWatcher1.downstreamWatchers).hasSize(1);
CertificateProvider.DistributorWatcher distWatcher2 = testCertificateProvider2.getWatcher();
assertThat(distWatcher2.downsstreamWatchers).hasSize(1);
assertThat(distWatcher2.downstreamWatchers).hasSize(1);
PrivateKey testKey1 = mock(PrivateKey.class);
X509Certificate cert1 = mock(X509Certificate.class);
List<X509Certificate> testList1 = ImmutableList.of(cert1);

View File

@ -17,19 +17,25 @@
package io.grpc.xds.internal.certprovider;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.certprovider.MeshCaCertificateProviderProvider.RPC_TIMEOUT_SECONDS;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.auth.oauth2.GoogleCredentials;
import io.grpc.internal.BackoffPolicy;
import io.grpc.internal.ExponentialBackoffPolicy;
import io.grpc.internal.TimeProvider;
import io.grpc.xds.internal.sts.StsCredentials;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -59,6 +65,14 @@ public class MeshCaCertificateProviderProviderTest {
@Mock
MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory;
@Mock
private MeshCaCertificateProviderProvider.ScheduledExecutorServiceFactory
scheduledExecutorServiceFactory;
@Mock
private TimeProvider timeProvider;
private MeshCaCertificateProviderProvider provider;
@Before
@ -69,7 +83,9 @@ public class MeshCaCertificateProviderProviderTest {
stsCredentialsFactory,
meshCaChannelFactory,
backoffPolicyProvider,
meshCaCertificateProviderFactory);
meshCaCertificateProviderFactory,
scheduledExecutorServiceFactory,
timeProvider);
}
@Test
@ -94,6 +110,10 @@ public class MeshCaCertificateProviderProviderTest {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
Map<String, String> map = buildMinimalMap();
ScheduledExecutorService mockService = mock(ScheduledExecutorService.class);
when(scheduledExecutorServiceFactory.create(
eq(MeshCaCertificateProviderProvider.MESHCA_URL_DEFAULT)))
.thenReturn(mockService);
provider.createCertificateProvider(map, distWatcher, true);
verify(stsCredentialsFactory, times(1))
.create(
@ -114,7 +134,10 @@ public class MeshCaCertificateProviderProviderTest {
eq(backoffPolicyProvider),
eq(MeshCaCertificateProviderProvider.RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT),
eq(MeshCaCertificateProviderProvider.MAX_RETRY_ATTEMPTS_DEFAULT),
(GoogleCredentials) isNull());
(GoogleCredentials) isNull(),
eq(mockService),
eq(timeProvider),
eq(TimeUnit.SECONDS.toMillis(RPC_TIMEOUT_SECONDS)));
}
@Test
@ -150,7 +173,9 @@ public class MeshCaCertificateProviderProviderTest {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
Map<String, String> map = buildMinimalMap();
map.put("gkeClusterUrl", "https://container.googleapis.com/v1/project/test-project1/locations/test-zone2/clusters/test-cluster3");
map.put(
"gkeClusterUrl",
"https://container.googleapis.com/v1/project/test-project1/locations/test-zone2/clusters/test-cluster3");
try {
provider.createCertificateProvider(map, distWatcher, true);
fail("exception expected");
@ -164,6 +189,9 @@ public class MeshCaCertificateProviderProviderTest {
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
Map<String, String> map = buildFullMap();
ScheduledExecutorService mockService = mock(ScheduledExecutorService.class);
when(scheduledExecutorServiceFactory.create(eq(NON_DEFAULT_MESH_CA_URL)))
.thenReturn(mockService);
provider.createCertificateProvider(map, distWatcher, true);
verify(stsCredentialsFactory, times(1))
.create(
@ -184,7 +212,10 @@ public class MeshCaCertificateProviderProviderTest {
eq(backoffPolicyProvider),
eq(4321L),
eq(9),
(GoogleCredentials) isNull());
(GoogleCredentials) isNull(),
eq(mockService),
eq(timeProvider),
eq(TimeUnit.SECONDS.toMillis(RPC_TIMEOUT_SECONDS)));
}
private Map<String, String> buildFullMap() {

View File

@ -0,0 +1,538 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.certprovider;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.auth.http.AuthHttpConstants;
import com.google.auth.oauth2.AccessToken;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.MoreExecutors;
import google.security.meshca.v1.MeshCertificateServiceGrpc;
import google.security.meshca.v1.Meshca;
import io.grpc.Context;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.SynchronizationContext;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.internal.BackoffPolicy;
import io.grpc.internal.TimeProvider;
import io.grpc.testing.GrpcCleanupRule;
import io.grpc.xds.internal.certprovider.CertificateProvider.DistributorWatcher;
import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil;
import java.io.IOException;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.ArrayDeque;
import java.util.Date;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import org.bouncycastle.operator.OperatorCreationException;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.Spy;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
/** Unit tests for {@link MeshCaCertificateProvider}. */
@RunWith(JUnit4.class)
public class MeshCaCertificateProviderTest {
private static final String TEST_STS_TOKEN = "test-stsToken";
private static final long RENEWAL_GRACE_PERIOD_SECONDS = TimeUnit.HOURS.toSeconds(1L);
private static final Metadata.Key<String> KEY_FOR_AUTHORIZATION =
Metadata.Key.of(AuthHttpConstants.AUTHORIZATION, Metadata.ASCII_STRING_MARSHALLER);
private static final String ZONE = "us-west2-a";
private static final long START_DELAY = 200_000_000L; // 0.2 seconds
private static final long[] DELAY_VALUES = {START_DELAY, START_DELAY * 2, START_DELAY * 4};
private static final long RPC_TIMEOUT_MILLIS = 100L;
/**
* Expire time of cert SERVER_0_PEM_FILE.
*/
private static final long CERT0_EXPIRY_TIME_MILLIS = 1899853658000L;
/**
* Cert validity of 12 hours for the above cert.
*/
private static final long CERT0_VALIDITY_MILLIS = TimeUnit.MILLISECONDS
.convert(12, TimeUnit.HOURS);
/**
* Compute current time based on cert expiry and cert validity.
*/
private static final long CURRENT_TIME_NANOS =
TimeUnit.MILLISECONDS.toNanos(CERT0_EXPIRY_TIME_MILLIS - CERT0_VALIDITY_MILLIS);
@Rule
public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule();
private static class ResponseToSend {
Throwable getThrowable() {
throw new UnsupportedOperationException("Called on " + getClass().getCanonicalName());
}
List<String> getList() {
throw new UnsupportedOperationException("Called on " + getClass().getCanonicalName());
}
}
private static class ResponseThrowable extends ResponseToSend {
final Throwable throwableToSend;
ResponseThrowable(Throwable throwable) {
throwableToSend = throwable;
}
@Override
Throwable getThrowable() {
return throwableToSend;
}
}
private static class ResponseList extends ResponseToSend {
final List<String> listToSend;
ResponseList(List<String> list) {
listToSend = list;
}
@Override
List<String> getList() {
return listToSend;
}
}
private final Queue<Meshca.MeshCertificateRequest> receivedRequests = new ArrayDeque<>();
private final Queue<String> receivedStsCreds = new ArrayDeque<>();
private final Queue<String> receivedZoneValues = new ArrayDeque<>();
private final Queue<ResponseToSend> responsesToSend = new ArrayDeque<>();
private final Queue<String> oauth2Tokens = new ArrayDeque<>();
private final AtomicBoolean callEnded = new AtomicBoolean(true);
@Mock private MeshCertificateServiceGrpc.MeshCertificateServiceImplBase mockedMeshCaService;
@Mock private CertificateProvider.Watcher mockWatcher;
@Mock private BackoffPolicy.Provider backoffPolicyProvider;
@Mock private BackoffPolicy backoffPolicy;
@Spy private GoogleCredentials oauth2Creds;
@Mock private ScheduledExecutorService timeService;
@Mock private TimeProvider timeProvider;
private ManagedChannel channel;
private MeshCaCertificateProvider provider;
@Before
public void setUp() throws IOException {
MockitoAnnotations.initMocks(this);
when(backoffPolicyProvider.get()).thenReturn(backoffPolicy);
when(backoffPolicy.nextBackoffNanos())
.thenReturn(DELAY_VALUES[0], DELAY_VALUES[1], DELAY_VALUES[2]);
doAnswer(
new Answer<AccessToken>() {
@Override
public AccessToken answer(InvocationOnMock invocation) throws Throwable {
return new AccessToken(
oauth2Tokens.poll(), new Date(System.currentTimeMillis() + 1000L));
}
})
.when(oauth2Creds)
.refreshAccessToken();
final String meshCaUri = InProcessServerBuilder.generateName();
MeshCertificateServiceGrpc.MeshCertificateServiceImplBase meshCaServiceImpl =
new MeshCertificateServiceGrpc.MeshCertificateServiceImplBase() {
@Override
public void createCertificate(
google.security.meshca.v1.Meshca.MeshCertificateRequest request,
io.grpc.stub.StreamObserver<google.security.meshca.v1.Meshca.MeshCertificateResponse>
responseObserver) {
assertThat(callEnded.get()).isTrue(); // ensure previous call was ended
callEnded.set(false);
Context.current()
.addListener(
new Context.CancellationListener() {
@Override
public void cancelled(Context context) {
callEnded.set(true);
}
},
MoreExecutors.directExecutor());
receivedRequests.offer(request);
ResponseToSend response = responsesToSend.poll();
if (response instanceof ResponseThrowable) {
responseObserver.onError(response.getThrowable());
} else if (response instanceof ResponseList) {
List<String> certChainInResponse = response.getList();
Meshca.MeshCertificateResponse responseToSend =
Meshca.MeshCertificateResponse.newBuilder()
.addAllCertChain(certChainInResponse)
.build();
responseObserver.onNext(responseToSend);
responseObserver.onCompleted();
} else {
callEnded.set(true);
}
}
};
mockedMeshCaService =
mock(
MeshCertificateServiceGrpc.MeshCertificateServiceImplBase.class,
delegatesTo(meshCaServiceImpl));
ServerInterceptor interceptor =
new ServerInterceptor() {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
receivedStsCreds.offer(headers.get(KEY_FOR_AUTHORIZATION));
receivedZoneValues.offer(headers.get(MeshCaCertificateProvider.KEY_FOR_ZONE_INFO));
return next.startCall(call, headers);
}
};
cleanupRule.register(
InProcessServerBuilder.forName(meshCaUri)
.addService(mockedMeshCaService)
.intercept(interceptor)
.directExecutor()
.build()
.start());
channel =
cleanupRule.register(InProcessChannelBuilder.forName(meshCaUri).directExecutor().build());
MeshCaCertificateProvider.MeshCaChannelFactory channelFactory =
new MeshCaCertificateProvider.MeshCaChannelFactory() {
@Override
ManagedChannel createChannel(String serverUri) {
assertThat(serverUri).isEqualTo(meshCaUri);
return channel;
}
};
CertificateProvider.DistributorWatcher watcher = new CertificateProvider.DistributorWatcher();
watcher.addWatcher(mockWatcher); //
provider =
new MeshCaCertificateProvider(
watcher,
true,
meshCaUri,
ZONE,
TimeUnit.HOURS.toSeconds(9L),
2048,
"RSA",
"SHA256withRSA",
channelFactory,
backoffPolicyProvider,
RENEWAL_GRACE_PERIOD_SECONDS,
MeshCaCertificateProviderProvider.MAX_RETRY_ATTEMPTS_DEFAULT,
oauth2Creds,
timeService,
timeProvider,
RPC_TIMEOUT_MILLIS);
}
@Test
public void startAndClose() {
ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
doReturn(scheduledFuture)
.when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
provider.start();
SynchronizationContext.ScheduledHandle savedScheduledHandle = provider.scheduledHandle;
assertThat(savedScheduledHandle).isNotNull();
assertThat(savedScheduledHandle.isPending()).isTrue();
verify(timeService, times(1))
.schedule(
any(Runnable.class),
eq(MeshCaCertificateProvider.INITIAL_DELAY_SECONDS),
eq(TimeUnit.SECONDS));
DistributorWatcher distWatcher = provider.getWatcher();
assertThat(distWatcher.downstreamWatchers).hasSize(1);
PrivateKey mockKey = mock(PrivateKey.class);
X509Certificate mockCert = mock(X509Certificate.class);
distWatcher.updateCertificate(mockKey, ImmutableList.of(mockCert));
distWatcher.updateTrustedRoots(ImmutableList.of(mockCert));
provider.close();
assertThat(provider.scheduledHandle).isNull();
assertThat(savedScheduledHandle.isPending()).isFalse();
assertThat(distWatcher.downstreamWatchers).isEmpty();
assertThat(distWatcher.getLastIdentityCert()).isNull();
}
@Test
public void startTwice_noException() {
ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
doReturn(scheduledFuture)
.when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
provider.start();
SynchronizationContext.ScheduledHandle savedScheduledHandle1 = provider.scheduledHandle;
provider.start();
SynchronizationContext.ScheduledHandle savedScheduledHandle2 = provider.scheduledHandle;
assertThat(savedScheduledHandle2).isNotSameInstanceAs(savedScheduledHandle1);
assertThat(savedScheduledHandle2.isPending()).isTrue();
}
@Test
public void getCertificate()
throws IOException, CertificateException, OperatorCreationException,
NoSuchAlgorithmException {
oauth2Tokens.offer(TEST_STS_TOKEN + "0");
responsesToSend.offer(
new ResponseList(ImmutableList.of(
CommonTlsContextTestsUtil.getResourceContents(SERVER_0_PEM_FILE),
CommonTlsContextTestsUtil.getResourceContents(SERVER_1_PEM_FILE),
CommonTlsContextTestsUtil.getResourceContents(CA_PEM_FILE))));
when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS);
ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
doReturn(scheduledFuture)
.when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
provider.refreshCertificate();
Meshca.MeshCertificateRequest receivedReq = receivedRequests.poll();
assertThat(receivedReq.getValidity().getSeconds()).isEqualTo(TimeUnit.HOURS.toSeconds(9L));
// cannot decode CSR: just check the PEM format delimiters
String csr = receivedReq.getCsr();
assertThat(csr).startsWith("-----BEGIN NEW CERTIFICATE REQUEST-----");
verifyReceivedMetadataValues(1);
verify(timeService, times(1))
.schedule(
any(Runnable.class),
eq(
TimeUnit.MILLISECONDS.toSeconds(
CERT0_VALIDITY_MILLIS
- TimeUnit.SECONDS.toMillis(RENEWAL_GRACE_PERIOD_SECONDS))),
eq(TimeUnit.SECONDS));
verifyMockWatcher();
}
@Test
public void getCertificate_withError()
throws IOException, OperatorCreationException, NoSuchAlgorithmException {
oauth2Tokens.offer(TEST_STS_TOKEN + "0");
responsesToSend
.offer(new ResponseThrowable(new StatusRuntimeException(Status.FAILED_PRECONDITION)));
ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
doReturn(scheduledFuture).when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
provider.refreshCertificate();
verify(mockWatcher, never())
.updateCertificate(any(PrivateKey.class), ArgumentMatchers.<X509Certificate>anyList());
verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList());
verify(mockWatcher, times(1)).onError(Status.FAILED_PRECONDITION);
verify(timeService, times(1)).schedule(any(Runnable.class),
eq(MeshCaCertificateProvider.INITIAL_DELAY_SECONDS),
eq(TimeUnit.SECONDS));
verifyReceivedMetadataValues(1);
}
@Test
public void getCertificate_withError_withExistingCert()
throws IOException, OperatorCreationException, NoSuchAlgorithmException {
PrivateKey mockKey = mock(PrivateKey.class);
X509Certificate mockCert = mock(X509Certificate.class);
// have current cert expire in 3 hours from current time
long threeHoursFromNowMillis = TimeUnit.NANOSECONDS
.toMillis(CURRENT_TIME_NANOS + TimeUnit.HOURS.toNanos(3));
when(mockCert.getNotAfter()).thenReturn(new Date(threeHoursFromNowMillis));
provider.getWatcher().updateCertificate(mockKey, ImmutableList.of(mockCert));
reset(mockWatcher);
oauth2Tokens.offer(TEST_STS_TOKEN + "0");
responsesToSend
.offer(new ResponseThrowable(new StatusRuntimeException(Status.FAILED_PRECONDITION)));
when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS);
ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
doReturn(scheduledFuture).when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
provider.refreshCertificate();
verify(mockWatcher, never())
.updateCertificate(any(PrivateKey.class), ArgumentMatchers.<X509Certificate>anyList());
verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList());
verify(mockWatcher, never()).onError(any(Status.class));
verify(timeService, times(1)).schedule(any(Runnable.class),
eq(5400L),
eq(TimeUnit.SECONDS));
assertThat(provider.getWatcher().getLastIdentityCert()).isNotNull();
verifyReceivedMetadataValues(1);
}
@Test
public void getCertificate_withError_withExistingExpiredCert()
throws IOException, OperatorCreationException, NoSuchAlgorithmException {
PrivateKey mockKey = mock(PrivateKey.class);
X509Certificate mockCert = mock(X509Certificate.class);
// have current cert expire in 3 seconds from current time
long threeSecondsFromNowMillis = TimeUnit.NANOSECONDS
.toMillis(CURRENT_TIME_NANOS + TimeUnit.SECONDS.toNanos(3));
when(mockCert.getNotAfter()).thenReturn(new Date(threeSecondsFromNowMillis));
provider.getWatcher().updateCertificate(mockKey, ImmutableList.of(mockCert));
reset(mockWatcher);
oauth2Tokens.offer(TEST_STS_TOKEN + "0");
responsesToSend
.offer(new ResponseThrowable(new StatusRuntimeException(Status.FAILED_PRECONDITION)));
when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS);
ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
doReturn(scheduledFuture).when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
provider.refreshCertificate();
verify(mockWatcher, never())
.updateCertificate(any(PrivateKey.class), ArgumentMatchers.<X509Certificate>anyList());
verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList());
verify(mockWatcher, times(1)).onError(Status.FAILED_PRECONDITION);
verify(timeService, times(1)).schedule(any(Runnable.class),
eq(MeshCaCertificateProvider.INITIAL_DELAY_SECONDS),
eq(TimeUnit.SECONDS));
assertThat(provider.getWatcher().getLastIdentityCert()).isNull();
verifyReceivedMetadataValues(1);
}
@Test
public void getCertificate_retriesWithErrors()
throws IOException, CertificateException, OperatorCreationException,
NoSuchAlgorithmException, InterruptedException, ExecutionException, TimeoutException {
oauth2Tokens.offer(TEST_STS_TOKEN + "0");
oauth2Tokens.offer(TEST_STS_TOKEN + "1");
oauth2Tokens.offer(TEST_STS_TOKEN + "2");
responsesToSend.offer(new ResponseThrowable(new StatusRuntimeException(Status.UNKNOWN)));
responsesToSend.offer(
new ResponseThrowable(
new Exception(new StatusRuntimeException(Status.RESOURCE_EXHAUSTED))));
responsesToSend.offer(new ResponseList(ImmutableList.of(
CommonTlsContextTestsUtil.getResourceContents(SERVER_0_PEM_FILE),
CommonTlsContextTestsUtil.getResourceContents(SERVER_1_PEM_FILE),
CommonTlsContextTestsUtil.getResourceContents(CA_PEM_FILE))));
when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS);
ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
doReturn(scheduledFuture).when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
ScheduledFuture<?> scheduledFutureSleep = mock(ScheduledFuture.class);
doReturn(scheduledFutureSleep).when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.NANOSECONDS));
provider.refreshCertificate();
assertThat(receivedRequests.size()).isEqualTo(3);
verify(timeService, times(1)).schedule(any(Runnable.class),
eq(TimeUnit.MILLISECONDS.toSeconds(
CERT0_VALIDITY_MILLIS - TimeUnit.SECONDS.toMillis(RENEWAL_GRACE_PERIOD_SECONDS))),
eq(TimeUnit.SECONDS));
verifyRetriesWithBackoff(scheduledFutureSleep, 2);
verifyMockWatcher();
verifyReceivedMetadataValues(3);
}
@Test
public void getCertificate_retriesWithTimeouts()
throws IOException, CertificateException, OperatorCreationException,
NoSuchAlgorithmException, InterruptedException, ExecutionException, TimeoutException {
oauth2Tokens.offer(TEST_STS_TOKEN + "0");
oauth2Tokens.offer(TEST_STS_TOKEN + "1");
oauth2Tokens.offer(TEST_STS_TOKEN + "2");
oauth2Tokens.offer(TEST_STS_TOKEN + "3");
responsesToSend.offer(new ResponseToSend());
responsesToSend.offer(new ResponseToSend());
responsesToSend.offer(new ResponseToSend());
responsesToSend.offer(new ResponseList(ImmutableList.of(
CommonTlsContextTestsUtil.getResourceContents(SERVER_0_PEM_FILE),
CommonTlsContextTestsUtil.getResourceContents(SERVER_1_PEM_FILE),
CommonTlsContextTestsUtil.getResourceContents(CA_PEM_FILE))));
when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS);
ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
doReturn(scheduledFuture).when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
ScheduledFuture<?> scheduledFutureSleep = mock(ScheduledFuture.class);
doReturn(scheduledFutureSleep).when(timeService)
.schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.NANOSECONDS));
provider.refreshCertificate();
assertThat(receivedRequests.size()).isEqualTo(4);
verify(timeService, times(1)).schedule(any(Runnable.class),
eq(TimeUnit.MILLISECONDS.toSeconds(
CERT0_VALIDITY_MILLIS - TimeUnit.SECONDS.toMillis(RENEWAL_GRACE_PERIOD_SECONDS))),
eq(TimeUnit.SECONDS));
verifyRetriesWithBackoff(scheduledFutureSleep, 3);
verifyMockWatcher();
verifyReceivedMetadataValues(4);
}
private void verifyRetriesWithBackoff(ScheduledFuture<?> scheduledFutureSleep, int numOfRetries)
throws InterruptedException, ExecutionException, TimeoutException {
for (int i = 0; i < numOfRetries; i++) {
long delayValue = DELAY_VALUES[i];
verify(timeService, times(1)).schedule(any(Runnable.class),
eq(delayValue),
eq(TimeUnit.NANOSECONDS));
verify(scheduledFutureSleep, times(1)).get(eq(delayValue), eq(TimeUnit.NANOSECONDS));
}
}
private void verifyMockWatcher() throws IOException, CertificateException {
ArgumentCaptor<List<X509Certificate>> certChainCaptor = ArgumentCaptor.forClass(null);
verify(mockWatcher, times(1))
.updateCertificate(any(PrivateKey.class), certChainCaptor.capture());
List<X509Certificate> certChain = certChainCaptor.getValue();
assertThat(certChain).hasSize(3);
assertThat(certChain.get(0))
.isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(SERVER_0_PEM_FILE));
assertThat(certChain.get(1))
.isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(SERVER_1_PEM_FILE));
assertThat(certChain.get(2))
.isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(CA_PEM_FILE));
ArgumentCaptor<List<X509Certificate>> rootsCaptor = ArgumentCaptor.forClass(null);
verify(mockWatcher, times(1)).updateTrustedRoots(rootsCaptor.capture());
List<X509Certificate> roots = rootsCaptor.getValue();
assertThat(roots).hasSize(1);
assertThat(roots.get(0))
.isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(CA_PEM_FILE));
verify(mockWatcher, never()).onError(any(Status.class));
}
private void verifyReceivedMetadataValues(int count) {
assertThat(receivedStsCreds).hasSize(count);
assertThat(receivedZoneValues).hasSize(count);
for (int i = 0; i < count; i++) {
assertThat(receivedStsCreds.poll()).isEqualTo("Bearer " + TEST_STS_TOKEN + i);
assertThat(receivedZoneValues.poll()).isEqualTo("us-west2-a");
}
}
}

View File

@ -16,7 +16,10 @@
package io.grpc.xds.internal.sds;
import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.common.base.Strings;
import com.google.common.io.CharStreams;
import com.google.protobuf.BoolValue;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
@ -36,7 +39,14 @@ import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContex
import io.envoyproxy.envoy.type.matcher.v3.StringMatcher;
import io.grpc.internal.testing.TestUtils;
import io.grpc.xds.EnvoyServerProtoData;
import io.grpc.xds.internal.sds.trust.CertificateUtils;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import javax.annotation.Nullable;
@ -432,4 +442,23 @@ public class CommonTlsContextTestsUtil {
return EnvoyServerProtoData.UpstreamTlsContext.fromEnvoyProtoUpstreamTlsContext(
upstreamTlsContext);
}
/** Gets a cert from contents of a resource. */
public static X509Certificate getCertFromResourceName(String resourceName)
throws IOException, CertificateException {
try (ByteArrayInputStream bais =
new ByteArrayInputStream(getResourceContents(resourceName).getBytes(UTF_8))) {
return CertificateUtils.toX509Certificate(bais);
}
}
/** Gets contents of a resource from TestUtils.class loader. */
public static String getResourceContents(String resourceName) throws IOException {
InputStream inputStream = TestUtils.class.getResourceAsStream("/certs/" + resourceName);
String text = null;
try (Reader reader = new InputStreamReader(inputStream, UTF_8)) {
text = CharStreams.toString(reader);
}
return text;
}
}