Update EncryptBaseCipher cache to be index using {baseCipherId, domainId} (#7183)

Description

Major changes proposed in the patch includes:
1. Update EncryptKeyProxy EncyrptBaseCipherKeyId cache to be indexed
   using {encryptDomainId, baseCipherId} instead of only 'baseCipherId'
2. Enhance RESTKmsConnector 'error' tag to encapsulte: errorMessage
   and errorCode information

Testing

1. Updated EncyrptKeyProxy test
2. Updated RESTKmsConnector unit test
This commit is contained in:
Ata E Husain Bohra 2022-05-18 06:16:40 -07:00 committed by GitHub
parent 5205b565ab
commit 728869466d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 56 additions and 23 deletions

View File

@ -80,7 +80,12 @@ struct EncryptBaseCipherKey {
};
using EncryptBaseDomainIdCache = std::unordered_map<EncryptCipherDomainId, EncryptBaseCipherKey>;
using EncryptBaseCipherKeyIdCache = std::unordered_map<EncryptCipherBaseKeyId, EncryptBaseCipherKey>;
using EncryptBaseCipherDomainIdKeyIdCacheKey = std::pair<EncryptCipherDomainId, EncryptCipherBaseKeyId>;
using EncryptBaseCipherDomainIdKeyIdCacheKeyHash = boost::hash<EncryptBaseCipherDomainIdKeyIdCacheKey>;
using EncryptBaseCipherDomainIdKeyIdCache = std::unordered_map<EncryptBaseCipherDomainIdKeyIdCacheKey,
EncryptBaseCipherKey,
EncryptBaseCipherDomainIdKeyIdCacheKeyHash>;
struct EncryptKeyProxyData : NonCopyable, ReferenceCounted<EncryptKeyProxyData> {
public:
@ -89,7 +94,7 @@ public:
Future<Void> encryptionKeyRefresher;
EncryptBaseDomainIdCache baseCipherDomainIdCache;
EncryptBaseCipherKeyIdCache baseCipherKeyIdCache;
EncryptBaseCipherDomainIdKeyIdCache baseCipherDomainIdKeyIdCache;
std::unique_ptr<KmsConnector> kmsConnector;
@ -113,6 +118,12 @@ public:
numResponseWithErrors("EKPNumResponseWithErrors", ekpCacheMetrics),
numEncryptionKeyRefreshErrors("EKPNumEncryptionKeyRefreshErrors", ekpCacheMetrics) {}
EncryptBaseCipherDomainIdKeyIdCacheKey getBaseCipherDomainIdKeyIdCacheKey(
const EncryptCipherDomainId domainId,
const EncryptCipherBaseKeyId baseCipherId) {
return std::make_pair(domainId, baseCipherId);
}
void insertIntoBaseDomainIdCache(const EncryptCipherDomainId domainId,
const EncryptCipherBaseKeyId baseCipherId,
const StringRef baseCipherKey) {
@ -131,7 +142,8 @@ public:
// Given an cipherKey is immutable, it is OK to NOT expire cached information.
// TODO: Update cache to support LRU eviction policy to limit the total cache size.
baseCipherKeyIdCache[baseCipherId] = EncryptBaseCipherKey(domainId, baseCipherId, baseCipherKey, true);
EncryptBaseCipherDomainIdKeyIdCacheKey cacheKey = getBaseCipherDomainIdKeyIdCacheKey(domainId, baseCipherId);
baseCipherDomainIdKeyIdCache[cacheKey] = EncryptBaseCipherKey(domainId, baseCipherId, baseCipherKey, true);
}
template <class Reply>
@ -193,8 +205,10 @@ ACTOR Future<Void> getCipherKeysByBaseCipherKeyIds(Reference<EncryptKeyProxyData
}
for (const auto& item : dedupedCipherIds) {
const auto itr = ekpProxyData->baseCipherKeyIdCache.find(item.first);
if (itr != ekpProxyData->baseCipherKeyIdCache.end()) {
const EncryptBaseCipherDomainIdKeyIdCacheKey cacheKey =
ekpProxyData->getBaseCipherDomainIdKeyIdCacheKey(item.second, item.first);
const auto itr = ekpProxyData->baseCipherDomainIdKeyIdCache.find(cacheKey);
if (itr != ekpProxyData->baseCipherDomainIdKeyIdCache.end()) {
ASSERT(itr->second.isValid());
cachedCipherDetails.emplace_back(
itr->second.domainId, itr->second.baseCipherId, itr->second.baseCipherKey, keyIdsReply.arena);

View File

@ -54,7 +54,8 @@ const char* BASE_CIPHER_TAG = "baseCipher";
const char* CIPHER_KEY_DETAILS_TAG = "cipher_key_details";
const char* ENCRYPT_DOMAIN_ID_TAG = "encrypt_domain_id";
const char* ERROR_TAG = "error";
const char* ERROR_DETAIL_TAG = "details";
const char* ERROR_MSG_TAG = "errMsg";
const char* ERROR_CODE_TAG = "errCode";
const char* KMS_URLS_TAG = "kms_urls";
const char* QUERY_MODE_TAG = "query_mode";
const char* REFRESH_KMS_URLS_TAG = "refresh_kms_urls";
@ -282,7 +283,8 @@ void parseKmsResponse(Reference<RESTKmsConnectorCtx> ctx,
// "url1", "url2", ...
// ],
// "error" : { // Optional, populated by the KMS, if present, rest of payload is ignored.
// "details": <details>
// "errMsg" : <message>
// "errCode": <code>
// }
// }
@ -296,12 +298,26 @@ void parseKmsResponse(Reference<RESTKmsConnectorCtx> ctx,
// Check if response has error
if (doc.HasMember(ERROR_TAG)) {
if (doc[ERROR_TAG].HasMember(ERROR_DETAIL_TAG) && doc[ERROR_TAG][ERROR_DETAIL_TAG].IsString()) {
Standalone<StringRef> errRef = makeString(doc[ERROR_TAG][ERROR_DETAIL_TAG].GetStringLength());
memcpy(mutateString(errRef),
doc[ERROR_TAG][ERROR_DETAIL_TAG].GetString(),
doc[ERROR_TAG][ERROR_DETAIL_TAG].GetStringLength());
TraceEvent("KMSErrorResponse", ctx->uid).detail("ErrorDetails", errRef.toString());
Standalone<StringRef> errMsgRef;
Standalone<StringRef> errCodeRef;
if (doc[ERROR_TAG].HasMember(ERROR_MSG_TAG) && doc[ERROR_TAG][ERROR_MSG_TAG].IsString()) {
errMsgRef = makeString(doc[ERROR_TAG][ERROR_MSG_TAG].GetStringLength());
memcpy(mutateString(errMsgRef),
doc[ERROR_TAG][ERROR_MSG_TAG].GetString(),
doc[ERROR_TAG][ERROR_MSG_TAG].GetStringLength());
}
if (doc[ERROR_TAG].HasMember(ERROR_CODE_TAG) && doc[ERROR_TAG][ERROR_CODE_TAG].IsString()) {
errMsgRef = makeString(doc[ERROR_TAG][ERROR_CODE_TAG].GetStringLength());
memcpy(mutateString(errMsgRef),
doc[ERROR_TAG][ERROR_CODE_TAG].GetString(),
doc[ERROR_TAG][ERROR_CODE_TAG].GetStringLength());
}
if (!errCodeRef.empty() || !errMsgRef.empty()) {
TraceEvent("KMSErrorResponse", ctx->uid)
.detail("ErrorMsg", errMsgRef.empty() ? "" : errMsgRef.toString())
.detail("ErrorCode", errCodeRef.empty() ? "" : errCodeRef.toString());
} else {
TraceEvent("KMSErrorResponse_EmptyDetails", ctx->uid).log();
}
@ -1194,7 +1210,7 @@ void testKMSErrorResponse(Reference<RESTKmsConnectorCtx> ctx) {
rapidjson::Value errorTag(rapidjson::kObjectType);
// Add 'error_detail'
rapidjson::Value eKey(ERROR_DETAIL_TAG, doc.GetAllocator());
rapidjson::Value eKey(ERROR_MSG_TAG, doc.GetAllocator());
rapidjson::Value detailInfo;
detailInfo.SetString("Foo is always bad", doc.GetAllocator());
errorTag.AddMember(eKey, detailInfo, doc.GetAllocator());

View File

@ -43,8 +43,9 @@ struct EncryptKeyProxyTestWorkload : TestWorkload {
Arena arena;
uint64_t minDomainId;
uint64_t maxDomainId;
std::unordered_map<uint64_t, StringRef> cipherIdMap;
std::vector<uint64_t> cipherIds;
using CacheKey = std::pair<int64_t, uint64_t>;
std::unordered_map<CacheKey, StringRef, boost::hash<CacheKey>> cipherIdMap;
std::vector<CacheKey> cipherIds;
int numDomains;
std::vector<uint64_t> domainIds;
static std::atomic<int> seed;
@ -207,8 +208,9 @@ struct EncryptKeyProxyTestWorkload : TestWorkload {
self->cipherIdMap.clear();
self->cipherIds.clear();
for (auto& item : rep.baseCipherDetails) {
self->cipherIdMap.emplace(item.baseCipherId, StringRef(self->arena, item.baseCipherKey));
self->cipherIds.emplace_back(item.baseCipherId);
CacheKey cacheKey = std::make_pair(item.encryptDomainId, item.baseCipherId);
self->cipherIdMap.emplace(cacheKey, StringRef(self->arena, item.baseCipherKey));
self->cipherIds.emplace_back(cacheKey);
}
state int numIterations = deterministicRandom()->randomInt(512, 786);
@ -221,7 +223,7 @@ struct EncryptKeyProxyTestWorkload : TestWorkload {
req.debugId = deterministicRandom()->randomUniqueID();
}
for (int i = idx; i < nIds && i < self->cipherIds.size(); i++) {
req.baseCipherIds.emplace_back(std::make_pair(self->cipherIds[i], 1));
req.baseCipherIds.emplace_back(std::make_pair(self->cipherIds[i].second, self->cipherIds[i].first));
}
if (req.baseCipherIds.empty()) {
// No keys to query; continue
@ -238,9 +240,10 @@ struct EncryptKeyProxyTestWorkload : TestWorkload {
ASSERT_EQ(rep.numHits, expectedHits);
// Valdiate the 'cipherKey' content against the one read while querying by domainIds
for (auto& item : rep.baseCipherDetails) {
const auto itr = self->cipherIdMap.find(item.baseCipherId);
CacheKey cacheKey = std::make_pair(item.encryptDomainId, item.baseCipherId);
const auto itr = self->cipherIdMap.find(cacheKey);
ASSERT(itr != self->cipherIdMap.end());
Standalone<StringRef> toCompare = self->cipherIdMap[item.baseCipherId];
Standalone<StringRef> toCompare = self->cipherIdMap[cacheKey];
if (toCompare.compare(item.baseCipherKey) != 0) {
TraceEvent("Mismatch")
.detail("Id", item.baseCipherId)
@ -264,8 +267,8 @@ struct EncryptKeyProxyTestWorkload : TestWorkload {
// Prepare a lookup with valid and invalid keyIds - SimEncryptKmsProxy should throw encrypt_key_not_found()
std::vector<std::pair<uint64_t, int64_t>> baseCipherIds;
for (auto id : self->cipherIds) {
baseCipherIds.emplace_back(std::make_pair(id, 1));
for (auto item : self->cipherIds) {
baseCipherIds.emplace_back(std::make_pair(item.second, item.first));
}
baseCipherIds.emplace_back(std::make_pair(SERVER_KNOBS->SIM_KMS_MAX_KEYS + 10, 1));
EKPGetBaseCipherKeysByIdsRequest req(deterministicRandom()->randomUniqueID(), baseCipherIds);