Merge branch 'main' into feature-metacluster

# Conflicts:
#	fdbclient/include/fdbclient/Tenant.h
This commit is contained in:
A.J. Beamon 2022-07-28 16:53:29 -07:00
commit e8e4f3ad3a
67 changed files with 1345 additions and 380 deletions

View File

@ -789,7 +789,7 @@ namespace SummarizeTest
int stderrSeverity = (int)Magnesium.Severity.SevError;
Dictionary<KeyValuePair<string, Magnesium.Severity>, Magnesium.Severity> severityMap = new Dictionary<KeyValuePair<string, Magnesium.Severity>, Magnesium.Severity>();
Dictionary<Tuple<string, string>, bool> codeCoverage = new Dictionary<Tuple<string, string>, bool>();
var codeCoverage = new Dictionary<Tuple<string, string, string>, bool>();
foreach (var traceFileName in traceFiles)
{
@ -902,12 +902,17 @@ namespace SummarizeTest
if (ev.Type == "CodeCoverage" && !willRestart)
{
bool covered = true;
if(ev.DDetails.ContainsKey("Covered"))
if (ev.DDetails.ContainsKey("Covered"))
{
covered = int.Parse(ev.Details.Covered) != 0;
}
var key = new Tuple<string, string>(ev.Details.File, ev.Details.Line);
var comment = "";
if (ev.DDetails.ContainsKey("Comment"))
{
comment = ev.Details.Comment;
}
var key = new Tuple<string, string, string>(ev.Details.File, ev.Details.Line, comment);
if (covered || !codeCoverage.ContainsKey(key))
{
codeCoverage[key] = covered;
@ -961,6 +966,9 @@ namespace SummarizeTest
{
element.Add(new XAttribute("Covered", "0"));
}
if (kv.Key.Item3.Length > 0) {
element.Add(new XAttribute("Comment", kv.Key.Item3));
}
xout.Add(element);
}

View File

@ -1393,7 +1393,7 @@ const EncryptCipherRandomSalt encryptSalt = deterministicRandom()->randomUInt64(
Standalone<StringRef> getBaseCipher() {
Standalone<StringRef> baseCipher = makeString(AES_256_KEY_LENGTH);
generateRandomData(mutateString(baseCipher), baseCipher.size());
deterministicRandom()->randomBytes(mutateString(baseCipher), baseCipher.size());
return baseCipher;
}
@ -1413,7 +1413,7 @@ BlobGranuleCipherKeysCtx getCipherKeysCtx(Arena& arena) {
cipherKeysCtx.headerCipherKey.baseCipher = StringRef(arena, encryptBaseCipher);
cipherKeysCtx.ivRef = makeString(AES_256_IV_LENGTH, arena);
generateRandomData(mutateString(cipherKeysCtx.ivRef), AES_256_IV_LENGTH);
deterministicRandom()->randomBytes(mutateString(cipherKeysCtx.ivRef), AES_256_IV_LENGTH);
return cipherKeysCtx;
}
@ -2222,4 +2222,4 @@ TEST_CASE("/blobgranule/files/granuleReadUnitTest") {
}
return Void();
}
}

View File

@ -242,7 +242,7 @@ void DatabaseContext::getLatestCommitVersions(const Reference<LocationInfo>& loc
return;
}
if (ssVersionVectorCache.getMaxVersion() != invalidVersion && readVersion > ssVersionVectorCache.getMaxVersion()) {
if (readVersion > ssVersionVectorCache.getMaxVersion()) {
if (!CLIENT_KNOBS->FORCE_GRV_CACHE_OFF && !info->options.skipGrvCache && info->options.useGrvCache) {
return;
} else {
@ -255,16 +255,32 @@ void DatabaseContext::getLatestCommitVersions(const Reference<LocationInfo>& loc
std::map<Version, std::set<Tag>> versionMap; // order the versions to be returned
for (int i = 0; i < locationInfo->locations()->size(); i++) {
UID uid = locationInfo->locations()->getId(i);
if (ssidTagMapping.find(uid) != ssidTagMapping.end()) {
Tag tag = ssidTagMapping[uid];
bool updatedVersionMap = false;
Version commitVersion = invalidVersion;
Tag tag = invalidTag;
auto iter = ssidTagMapping.find(locationInfo->locations()->getId(i));
if (iter != ssidTagMapping.end()) {
tag = iter->second;
if (ssVersionVectorCache.hasVersion(tag)) {
Version commitVersion = ssVersionVectorCache.getVersion(tag); // latest commit version
commitVersion = ssVersionVectorCache.getVersion(tag); // latest commit version
if (commitVersion < readVersion) {
updatedVersionMap = true;
versionMap[commitVersion].insert(tag);
}
}
}
// commitVersion == readVersion is common, do not log.
if (!updatedVersionMap && commitVersion != readVersion) {
TraceEvent(SevDebug, "CommitVersionNotFoundForSS")
.detail("InSSIDMap", iter != ssidTagMapping.end() ? 1 : 0)
.detail("Tag", tag)
.detail("CommitVersion", commitVersion)
.detail("ReadVersion", readVersion)
.detail("VersionVector", ssVersionVectorCache.toString())
.setMaxEventLength(11000)
.setMaxFieldLength(10000);
++transactionCommitVersionNotFoundForSS;
}
}
// insert the commit versions in the version vector.
@ -710,6 +726,7 @@ ACTOR static Future<Void> delExcessClntTxnEntriesActor(Transaction* tr, int64_t
tr->clear(KeyRangeRef(txEntries[0].key, strinc(endKey)));
TraceEvent(SevInfo, "DeletingExcessCntTxnEntries").detail("BytesToBeDeleted", numBytesToDel);
int64_t bytesDel = -numBytesToDel;
tr->atomicOp(clientLatencyAtomicCtr, StringRef((uint8_t*)&bytesDel, 8), MutationRef::AddValue);
wait(tr->commit());
}
@ -1466,13 +1483,13 @@ DatabaseContext::DatabaseContext(Reference<AsyncVar<Reference<IClusterConnection
transactionsProcessBehind("ProcessBehind", cc), transactionsThrottled("Throttled", cc),
transactionsExpensiveClearCostEstCount("ExpensiveClearCostEstCount", cc),
transactionGrvFullBatches("NumGrvFullBatches", cc), transactionGrvTimedOutBatches("NumGrvTimedOutBatches", cc),
latencies(1000), readLatencies(1000), commitLatencies(1000), GRVLatencies(1000), mutationsPerCommit(1000),
bytesPerCommit(1000), bgLatencies(1000), bgGranulesPerRequest(1000), outstandingWatches(0), sharedStatePtr(nullptr),
lastGrvTime(0.0), cachedReadVersion(0), lastRkBatchThrottleTime(0.0), lastRkDefaultThrottleTime(0.0),
lastProxyRequestTime(0.0), transactionTracingSample(false), taskID(taskID), clientInfo(clientInfo),
clientInfoMonitor(clientInfoMonitor), coordinator(coordinator), apiVersion(apiVersion), mvCacheInsertLocation(0),
healthMetricsLastUpdated(0), detailedHealthMetricsLastUpdated(0),
smoothMidShardSize(CLIENT_KNOBS->SHARD_STAT_SMOOTH_AMOUNT),
transactionCommitVersionNotFoundForSS("CommitVersionNotFoundForSS", cc), latencies(1000), readLatencies(1000),
commitLatencies(1000), GRVLatencies(1000), mutationsPerCommit(1000), bytesPerCommit(1000), bgLatencies(1000),
bgGranulesPerRequest(1000), outstandingWatches(0), sharedStatePtr(nullptr), lastGrvTime(0.0), cachedReadVersion(0),
lastRkBatchThrottleTime(0.0), lastRkDefaultThrottleTime(0.0), lastProxyRequestTime(0.0),
transactionTracingSample(false), taskID(taskID), clientInfo(clientInfo), clientInfoMonitor(clientInfoMonitor),
coordinator(coordinator), apiVersion(apiVersion), mvCacheInsertLocation(0), healthMetricsLastUpdated(0),
detailedHealthMetricsLastUpdated(0), smoothMidShardSize(CLIENT_KNOBS->SHARD_STAT_SMOOTH_AMOUNT),
specialKeySpace(std::make_unique<SpecialKeySpace>(specialKeys.begin, specialKeys.end, /* test */ false)),
connectToDatabaseEventCacheHolder(format("ConnectToDatabase/%s", dbId.toString().c_str())) {
@ -1765,8 +1782,9 @@ DatabaseContext::DatabaseContext(const Error& err)
transactionsProcessBehind("ProcessBehind", cc), transactionsThrottled("Throttled", cc),
transactionsExpensiveClearCostEstCount("ExpensiveClearCostEstCount", cc),
transactionGrvFullBatches("NumGrvFullBatches", cc), transactionGrvTimedOutBatches("NumGrvTimedOutBatches", cc),
latencies(1000), readLatencies(1000), commitLatencies(1000), GRVLatencies(1000), mutationsPerCommit(1000),
bytesPerCommit(1000), bgLatencies(1000), bgGranulesPerRequest(1000), transactionTracingSample(false),
transactionCommitVersionNotFoundForSS("CommitVersionNotFoundForSS", cc), latencies(1000), readLatencies(1000),
commitLatencies(1000), GRVLatencies(1000), mutationsPerCommit(1000), bytesPerCommit(1000), bgLatencies(1000),
bgGranulesPerRequest(1000), transactionTracingSample(false),
smoothMidShardSize(CLIENT_KNOBS->SHARD_STAT_SMOOTH_AMOUNT),
connectToDatabaseEventCacheHolder(format("ConnectToDatabase/%s", dbId.toString().c_str())) {}
@ -1812,7 +1830,7 @@ DatabaseContext::~DatabaseContext() {
TraceEvent("DatabaseContextDestructed", dbId).backtrace();
}
Optional<KeyRangeLocationInfo> DatabaseContext::getCachedLocation(const Optional<TenantName>& tenantName,
Optional<KeyRangeLocationInfo> DatabaseContext::getCachedLocation(const Optional<TenantNameRef>& tenantName,
const KeyRef& key,
Reverse isBackward) {
TenantMapEntry tenantEntry;
@ -1838,7 +1856,7 @@ Optional<KeyRangeLocationInfo> DatabaseContext::getCachedLocation(const Optional
return Optional<KeyRangeLocationInfo>();
}
bool DatabaseContext::getCachedLocations(const Optional<TenantName>& tenantName,
bool DatabaseContext::getCachedLocations(const Optional<TenantNameRef>& tenantName,
const KeyRangeRef& range,
std::vector<KeyRangeLocationInfo>& result,
int limit,
@ -1895,7 +1913,7 @@ void DatabaseContext::cacheTenant(const TenantName& tenant, const TenantMapEntry
}
}
Reference<LocationInfo> DatabaseContext::setCachedLocation(const Optional<TenantName>& tenant,
Reference<LocationInfo> DatabaseContext::setCachedLocation(const Optional<TenantNameRef>& tenant,
const TenantMapEntry& tenantEntry,
const KeyRangeRef& absoluteKeys,
const std::vector<StorageServerInterface>& servers) {
@ -2836,7 +2854,7 @@ void updateTagMappings(Database cx, const GetKeyServerLocationsReply& reply) {
// If isBackward == true, returns the shard containing the key before 'key' (an infinitely long, inexpressible key).
// Otherwise returns the shard containing key
ACTOR Future<KeyRangeLocationInfo> getKeyLocation_internal(Database cx,
Optional<TenantName> tenant,
TenantInfo tenant,
Key key,
SpanContext spanContext,
Optional<UID> debugID,
@ -2859,26 +2877,20 @@ ACTOR Future<KeyRangeLocationInfo> getKeyLocation_internal(Database cx,
++cx->transactionKeyServerLocationRequests;
choose {
when(wait(cx->onProxiesChanged())) {}
when(GetKeyServerLocationsReply rep =
wait(basicLoadBalance(cx->getCommitProxies(useProvisionalProxies),
&CommitProxyInterface::getKeyServersLocations,
GetKeyServerLocationsRequest(span.context,
tenant.castTo<TenantNameRef>(),
key,
Optional<KeyRef>(),
100,
isBackward,
version,
key.arena()),
TaskPriority::DefaultPromiseEndpoint))) {
when(GetKeyServerLocationsReply rep = wait(basicLoadBalance(
cx->getCommitProxies(useProvisionalProxies),
&CommitProxyInterface::getKeyServersLocations,
GetKeyServerLocationsRequest(
span.context, tenant, key, Optional<KeyRef>(), 100, isBackward, version, key.arena()),
TaskPriority::DefaultPromiseEndpoint))) {
++cx->transactionKeyServerLocationRequestsCompleted;
if (debugID.present())
g_traceBatch.addEvent(
"TransactionDebug", debugID.get().first(), "NativeAPI.getKeyLocation.After");
ASSERT(rep.results.size() == 1);
auto locationInfo =
cx->setCachedLocation(tenant, rep.tenantEntry, rep.results[0].first, rep.results[0].second);
auto locationInfo = cx->setCachedLocation(
tenant.name, rep.tenantEntry, rep.results[0].first, rep.results[0].second);
updateTssMappings(cx, rep);
updateTagMappings(cx, rep);
@ -2891,8 +2903,8 @@ ACTOR Future<KeyRangeLocationInfo> getKeyLocation_internal(Database cx,
}
} catch (Error& e) {
if (e.code() == error_code_tenant_not_found) {
ASSERT(tenant.present());
cx->invalidateCachedTenant(tenant.get());
ASSERT(tenant.name.present());
cx->invalidateCachedTenant(tenant.name.get());
}
throw;
@ -2930,7 +2942,7 @@ bool checkOnlyEndpointFailed(const Database& cx, const Endpoint& endpoint) {
template <class F>
Future<KeyRangeLocationInfo> getKeyLocation(Database const& cx,
Optional<TenantName> const& tenant,
TenantInfo const& tenant,
Key const& key,
F StorageServerInterface::*member,
SpanContext spanContext,
@ -2939,7 +2951,7 @@ Future<KeyRangeLocationInfo> getKeyLocation(Database const& cx,
Reverse isBackward,
Version version) {
// we first check whether this range is cached
Optional<KeyRangeLocationInfo> locationInfo = cx->getCachedLocation(tenant, key, isBackward);
Optional<KeyRangeLocationInfo> locationInfo = cx->getCachedLocation(tenant.name, key, isBackward);
if (!locationInfo.present()) {
return getKeyLocation_internal(
cx, tenant, key, spanContext, debugID, useProvisionalProxies, isBackward, version);
@ -2971,7 +2983,7 @@ Future<KeyRangeLocationInfo> getKeyLocation(Reference<TransactionState> trState,
UseTenant useTenant,
Version version) {
auto f = getKeyLocation(trState->cx,
useTenant ? trState->tenant() : Optional<TenantName>(),
useTenant ? trState->getTenantInfo(AllowInvalidTenantID::True) : TenantInfo(),
key,
member,
trState->spanContext,
@ -2992,7 +3004,7 @@ Future<KeyRangeLocationInfo> getKeyLocation(Reference<TransactionState> trState,
ACTOR Future<std::vector<KeyRangeLocationInfo>> getKeyRangeLocations_internal(
Database cx,
Optional<TenantName> tenant,
TenantInfo tenant,
KeyRange keys,
int limit,
Reverse reverse,
@ -3009,18 +3021,12 @@ ACTOR Future<std::vector<KeyRangeLocationInfo>> getKeyRangeLocations_internal(
++cx->transactionKeyServerLocationRequests;
choose {
when(wait(cx->onProxiesChanged())) {}
when(GetKeyServerLocationsReply _rep =
wait(basicLoadBalance(cx->getCommitProxies(useProvisionalProxies),
&CommitProxyInterface::getKeyServersLocations,
GetKeyServerLocationsRequest(span.context,
tenant.castTo<TenantNameRef>(),
keys.begin,
keys.end,
limit,
reverse,
version,
keys.arena()),
TaskPriority::DefaultPromiseEndpoint))) {
when(GetKeyServerLocationsReply _rep = wait(basicLoadBalance(
cx->getCommitProxies(useProvisionalProxies),
&CommitProxyInterface::getKeyServersLocations,
GetKeyServerLocationsRequest(
span.context, tenant, keys.begin, keys.end, limit, reverse, version, keys.arena()),
TaskPriority::DefaultPromiseEndpoint))) {
++cx->transactionKeyServerLocationRequestsCompleted;
state GetKeyServerLocationsReply rep = _rep;
if (debugID.present())
@ -3037,7 +3043,7 @@ ACTOR Future<std::vector<KeyRangeLocationInfo>> getKeyRangeLocations_internal(
rep.tenantEntry,
(toRelativeRange(rep.results[shard].first, rep.tenantEntry.prefix) & keys),
cx->setCachedLocation(
tenant, rep.tenantEntry, rep.results[shard].first, rep.results[shard].second));
tenant.name, rep.tenantEntry, rep.results[shard].first, rep.results[shard].second));
wait(yield());
}
updateTssMappings(cx, rep);
@ -3049,8 +3055,8 @@ ACTOR Future<std::vector<KeyRangeLocationInfo>> getKeyRangeLocations_internal(
}
} catch (Error& e) {
if (e.code() == error_code_tenant_not_found) {
ASSERT(tenant.present());
cx->invalidateCachedTenant(tenant.get());
ASSERT(tenant.name.present());
cx->invalidateCachedTenant(tenant.name.get());
}
throw;
@ -3065,7 +3071,7 @@ ACTOR Future<std::vector<KeyRangeLocationInfo>> getKeyRangeLocations_internal(
// [([a, b1), locationInfo), ([b1, c), locationInfo), ([c, d1), locationInfo)].
template <class F>
Future<std::vector<KeyRangeLocationInfo>> getKeyRangeLocations(Database const& cx,
Optional<TenantName> tenant,
TenantInfo const& tenant,
KeyRange const& keys,
int limit,
Reverse reverse,
@ -3078,7 +3084,7 @@ Future<std::vector<KeyRangeLocationInfo>> getKeyRangeLocations(Database const& c
ASSERT(!keys.empty());
std::vector<KeyRangeLocationInfo> locations;
if (!cx->getCachedLocations(tenant, keys, locations, limit, reverse)) {
if (!cx->getCachedLocations(tenant.name, keys, locations, limit, reverse)) {
return getKeyRangeLocations_internal(
cx, tenant, keys, limit, reverse, spanContext, debugID, useProvisionalProxies, version);
}
@ -3116,7 +3122,7 @@ Future<std::vector<KeyRangeLocationInfo>> getKeyRangeLocations(Reference<Transac
UseTenant useTenant,
Version version) {
auto f = getKeyRangeLocations(trState->cx,
useTenant ? trState->tenant() : Optional<TenantName>(),
useTenant ? trState->getTenantInfo(AllowInvalidTenantID::True) : TenantInfo(),
keys,
limit,
reverse,
@ -3146,7 +3152,7 @@ ACTOR Future<Void> warmRange_impl(Reference<TransactionState> trState, KeyRange
loop {
std::vector<KeyRangeLocationInfo> locations =
wait(getKeyRangeLocations_internal(trState->cx,
trState->tenant(),
trState->getTenantInfo(),
keys,
CLIENT_KNOBS->WARM_RANGE_SHARD_LIMIT,
Reverse::False,
@ -3196,6 +3202,8 @@ SpanContext generateSpanID(bool transactionTracingSample, SpanContext parentCont
deterministicRandom()->randomUniqueID(), deterministicRandom()->randomUInt64(), TraceFlags::unsampled);
}
FDB_DEFINE_BOOLEAN_PARAM(AllowInvalidTenantID);
TransactionState::TransactionState(Database cx,
Optional<TenantName> tenant,
TaskPriority taskID,
@ -3219,12 +3227,13 @@ Reference<TransactionState> TransactionState::cloneAndReset(Reference<Transactio
newState->startTime = startTime;
newState->committedVersion = committedVersion;
newState->conflictingKeys = conflictingKeys;
newState->authToken = authToken;
newState->tenantSet = tenantSet;
return newState;
}
TenantInfo TransactionState::getTenantInfo() {
TenantInfo TransactionState::getTenantInfo(AllowInvalidTenantID allowInvalidId /* = false */) {
Optional<TenantName> const& t = tenant();
if (options.rawAccess) {
@ -3246,8 +3255,8 @@ TenantInfo TransactionState::getTenantInfo() {
}
}
ASSERT(tenantId != TenantInfo::INVALID_TENANT);
return TenantInfo(t.get(), tenantId);
ASSERT(allowInvalidId || tenantId != TenantInfo::INVALID_TENANT);
return TenantInfo(t, authToken, tenantId);
}
// Returns the tenant used in this transaction. If the tenant is unset and raw access isn't specified, then the default
@ -3590,7 +3599,7 @@ ACTOR Future<Version> watchValue(Database cx, Reference<const WatchParameters> p
loop {
state KeyRangeLocationInfo locationInfo = wait(getKeyLocation(cx,
parameters->tenant.name,
parameters->tenant,
parameters->key,
&StorageServerInterface::watchValue,
parameters->spanContext,
@ -3721,7 +3730,7 @@ ACTOR Future<Void> watchStorageServerResp(int64_t tenantId, Key key, Database cx
}
ACTOR Future<Void> sameVersionDiffValue(Database cx, Reference<WatchParameters> parameters) {
state ReadYourWritesTransaction tr(cx, parameters->tenant.name);
state ReadYourWritesTransaction tr(cx, parameters->tenant.name.castTo<TenantName>());
loop {
try {
if (!parameters->tenant.name.present()) {
@ -5950,8 +5959,12 @@ ACTOR void checkWrites(Reference<TransactionState> trState,
}
}
ACTOR static Future<Void> commitDummyTransaction(Reference<TransactionState> trState, KeyRange range) {
state Transaction tr(trState->cx);
FDB_BOOLEAN_PARAM(TenantPrefixPrepended);
ACTOR static Future<Void> commitDummyTransaction(Reference<TransactionState> trState,
KeyRange range,
TenantPrefixPrepended tenantPrefixPrepended) {
state Transaction tr(trState->cx, trState->tenant());
state int retries = 0;
state Span span("NAPI:dummyTransaction"_loc, trState->spanContext);
tr.span.setParent(span.context);
@ -5960,7 +5973,13 @@ ACTOR static Future<Void> commitDummyTransaction(Reference<TransactionState> trS
TraceEvent("CommitDummyTransaction").detail("Key", range.begin).detail("Retries", retries);
tr.trState->options = trState->options;
tr.trState->taskID = trState->taskID;
tr.setOption(FDBTransactionOptions::ACCESS_SYSTEM_KEYS);
tr.trState->authToken = trState->authToken;
if (!trState->hasTenant()) {
tr.setOption(FDBTransactionOptions::RAW_ACCESS);
} else {
tr.trState->skipApplyTenantPrefix = tenantPrefixPrepended;
CODE_PROBE(true, "Commit of a dummy transaction in tenant keyspace");
}
tr.setOption(FDBTransactionOptions::CAUSAL_WRITE_RISKY);
tr.setOption(FDBTransactionOptions::LOCK_AWARE);
tr.addReadConflictRange(range);
@ -6122,6 +6141,7 @@ ACTOR static Future<Void> tryCommit(Reference<TransactionState> trState,
state double startTime = now();
state Span span("NAPI:tryCommit"_loc, trState->spanContext);
state Optional<UID> debugID = trState->debugID;
state TenantPrefixPrepended tenantPrefixPrepended = TenantPrefixPrepended::False;
if (debugID.present()) {
TraceEvent(interval.begin()).detail("Parent", debugID.get());
}
@ -6146,12 +6166,16 @@ ACTOR static Future<Void> tryCommit(Reference<TransactionState> trState,
Reverse::False,
UseTenant::True,
req.transaction.read_snapshot));
applyTenantPrefix(req, locationInfo.tenantEntry.prefix);
// skipApplyTenantPrefix is set only in the context of a commitDummyTransaction()
// (see member declaration)
if (!trState->skipApplyTenantPrefix) {
applyTenantPrefix(req, locationInfo.tenantEntry.prefix);
tenantPrefixPrepended = TenantPrefixPrepended::True;
}
tenantPrefix = locationInfo.tenantEntry.prefix;
}
CODE_PROBE(trState->skipApplyTenantPrefix, "Tenant prefix prepend skipped for dummy transaction");
req.tenantInfo = trState->getTenantInfo();
startTime = now();
state Optional<UID> commitID = Optional<UID>();
@ -6277,7 +6301,8 @@ ACTOR static Future<Void> tryCommit(Reference<TransactionState> trState,
CODE_PROBE(true, "Waiting for dummy transaction to report commit_unknown_result");
wait(commitDummyTransaction(trState, singleKeyRange(selfConflictingRange.begin)));
wait(
commitDummyTransaction(trState, singleKeyRange(selfConflictingRange.begin), tenantPrefixPrepended));
}
// The user needs to be informed that we aren't sure whether the commit happened. Standard retry loops
@ -6657,6 +6682,13 @@ void Transaction::setOption(FDBTransactionOptions::Option option, Optional<Strin
trState->options.rawAccess = true;
break;
case FDBTransactionOptions::AUTHORIZATION_TOKEN:
if (value.present())
trState->authToken = Standalone<StringRef>(value.get());
else
trState->authToken.reset();
break;
default:
break;
}
@ -7236,7 +7268,7 @@ ACTOR Future<StorageMetrics> doGetStorageMetrics(Database cx, KeyRange keys, Ref
ACTOR Future<StorageMetrics> getStorageMetricsLargeKeyRange(Database cx, KeyRange keys) {
state Span span("NAPI:GetStorageMetricsLargeKeyRange"_loc);
std::vector<KeyRangeLocationInfo> locations = wait(getKeyRangeLocations(cx,
Optional<TenantName>(),
TenantInfo(),
keys,
std::numeric_limits<int>::max(),
Reverse::False,
@ -7338,7 +7370,7 @@ ACTOR Future<Standalone<VectorRef<ReadHotRangeWithMetrics>>> getReadHotRanges(Da
// to find the read-hot sub ranges within a read-hot shard.
std::vector<KeyRangeLocationInfo> locations =
wait(getKeyRangeLocations(cx,
Optional<TenantName>(),
TenantInfo(),
keys,
shardLimit,
Reverse::False,
@ -7409,7 +7441,7 @@ ACTOR Future<std::pair<Optional<StorageMetrics>, int>> waitStorageMetrics(Databa
state Span span("NAPI:WaitStorageMetrics"_loc, generateSpanID(cx->transactionTracingSample));
loop {
std::vector<KeyRangeLocationInfo> locations = wait(getKeyRangeLocations(cx,
Optional<TenantName>(),
TenantInfo(),
keys,
shardLimit,
Reverse::False,
@ -7584,7 +7616,7 @@ ACTOR Future<TenantMapEntry> blobGranuleGetTenantEntry(Transaction* self, Key ra
self->trState->cx->getCachedLocation(self->getTenant().get(), rangeStartKey, Reverse::False);
if (!cachedLocationInfo.present()) {
KeyRangeLocationInfo l = wait(getKeyLocation_internal(self->trState->cx,
self->getTenant().get(),
self->trState->getTenantInfo(AllowInvalidTenantID::True),
rangeStartKey,
self->trState->spanContext,
self->trState->debugID,
@ -8014,7 +8046,7 @@ ACTOR Future<Void> splitStorageMetricsStream(PromiseStream<Key> resultStream,
loop {
state std::vector<KeyRangeLocationInfo> locations =
wait(getKeyRangeLocations(cx,
Optional<TenantName>(),
TenantInfo(),
KeyRangeRef(beginKey, keys.end),
CLIENT_KNOBS->STORAGE_METRICS_SHARD_LIMIT,
Reverse::False,
@ -8114,7 +8146,7 @@ ACTOR Future<Standalone<VectorRef<KeyRef>>> splitStorageMetrics(Database cx,
loop {
state std::vector<KeyRangeLocationInfo> locations =
wait(getKeyRangeLocations(cx,
Optional<TenantName>(),
TenantInfo(),
keys,
CLIENT_KNOBS->STORAGE_METRICS_SHARD_LIMIT,
Reverse::False,
@ -8359,7 +8391,7 @@ ACTOR Future<std::vector<CheckpointMetaData>> getCheckpointMetaData(Database cx,
try {
state std::vector<KeyRangeLocationInfo> locations =
wait(getKeyRangeLocations(cx,
Optional<TenantName>(),
TenantInfo(),
keys,
CLIENT_KNOBS->TOO_MANY,
Reverse::False,
@ -9274,7 +9306,7 @@ ACTOR Future<Void> getChangeFeedStreamActor(Reference<DatabaseContext> db,
keys = fullRange & range;
state std::vector<KeyRangeLocationInfo> locations =
wait(getKeyRangeLocations(cx,
Optional<TenantName>(),
TenantInfo(),
keys,
CLIENT_KNOBS->CHANGE_FEED_LOCATION_LIMIT,
Reverse::False,
@ -9453,7 +9485,7 @@ ACTOR Future<OverlappingChangeFeedsInfo> getOverlappingChangeFeedsActor(Referenc
try {
state std::vector<KeyRangeLocationInfo> locations =
wait(getKeyRangeLocations(cx,
Optional<TenantName>(),
TenantInfo(),
range,
CLIENT_KNOBS->CHANGE_FEED_LOCATION_LIMIT,
Reverse::False,
@ -9555,7 +9587,7 @@ ACTOR Future<Void> popChangeFeedMutationsActor(Reference<DatabaseContext> db, Ke
state std::vector<KeyRangeLocationInfo> locations =
wait(getKeyRangeLocations(cx,
Optional<TenantName>(),
TenantInfo(),
keys,
3,
Reverse::False,

View File

@ -129,7 +129,7 @@ const char* TSS_mismatchTraceName(const GetKeyValuesRequest& req) {
static void traceKeyValuesSummary(TraceEvent& event,
const KeySelectorRef& begin,
const KeySelectorRef& end,
Optional<TenantName> tenant,
Optional<TenantNameRef> tenant,
Version version,
int limit,
int limitBytes,
@ -152,7 +152,7 @@ static void traceKeyValuesSummary(TraceEvent& event,
static void traceKeyValuesDiff(TraceEvent& event,
const KeySelectorRef& begin,
const KeySelectorRef& end,
Optional<TenantName> tenant,
Optional<TenantNameRef> tenant,
Version version,
int limit,
int limitBytes,

View File

@ -179,6 +179,8 @@ struct CommitTransactionRequest : TimedRequest {
CommitTransactionRequest() : CommitTransactionRequest(SpanContext()) {}
CommitTransactionRequest(SpanContext const& context) : spanContext(context), flags(0) {}
bool verify() const { return tenantInfo.isAuthorized(); }
template <class Ar>
void serialize(Ar& ar) {
serializer(
@ -284,6 +286,8 @@ struct GetReadVersionRequest : TimedRequest {
}
}
bool verify() const { return true; }
bool operator<(GetReadVersionRequest const& rhs) const { return priority < rhs.priority; }
template <class Ar>
@ -330,7 +334,7 @@ struct GetKeyServerLocationsRequest {
constexpr static FileIdentifier file_identifier = 9144680;
Arena arena;
SpanContext spanContext;
Optional<TenantNameRef> tenant;
TenantInfo tenant;
KeyRef begin;
Optional<KeyRef> end;
int limit;
@ -345,7 +349,7 @@ struct GetKeyServerLocationsRequest {
GetKeyServerLocationsRequest() : limit(0), reverse(false), minTenantVersion(latestVersion) {}
GetKeyServerLocationsRequest(SpanContext spanContext,
Optional<TenantNameRef> const& tenant,
TenantInfo const& tenant,
KeyRef const& begin,
Optional<KeyRef> const& end,
int limit,
@ -355,6 +359,8 @@ struct GetKeyServerLocationsRequest {
: arena(arena), spanContext(spanContext), tenant(tenant), begin(begin), end(end), limit(limit), reverse(reverse),
minTenantVersion(minTenantVersion) {}
bool verify() const { return tenant.isAuthorized(); }
template <class Ar>
void serialize(Ar& ar) {
serializer(ar, begin, end, limit, reverse, reply, spanContext, tenant, minTenantVersion, arena);

View File

@ -242,6 +242,8 @@ struct GetLeaderRequest {
GetLeaderRequest() {}
explicit GetLeaderRequest(Key key, UID kl) : key(key), knownLeader(kl) {}
bool verify() const { return true; }
template <class Ar>
void serialize(Ar& ar) {
serializer(ar, key, knownLeader, reply);
@ -262,6 +264,8 @@ struct OpenDatabaseCoordRequest {
std::vector<NetworkAddress> coordinators;
ReplyPromise<CachedSerialization<struct ClientDBInfo>> reply;
bool verify() const { return true; }
template <class Ar>
void serialize(Ar& ar) {
serializer(ar,

View File

@ -255,16 +255,16 @@ public:
return cx;
}
Optional<KeyRangeLocationInfo> getCachedLocation(const Optional<TenantName>& tenant,
Optional<KeyRangeLocationInfo> getCachedLocation(const Optional<TenantNameRef>& tenant,
const KeyRef&,
Reverse isBackward = Reverse::False);
bool getCachedLocations(const Optional<TenantName>& tenant,
bool getCachedLocations(const Optional<TenantNameRef>& tenant,
const KeyRangeRef&,
std::vector<KeyRangeLocationInfo>&,
int limit,
Reverse reverse);
void cacheTenant(const TenantName& tenant, const TenantMapEntry& tenantEntry);
Reference<LocationInfo> setCachedLocation(const Optional<TenantName>& tenant,
Reference<LocationInfo> setCachedLocation(const Optional<TenantNameRef>& tenant,
const TenantMapEntry& tenantEntry,
const KeyRangeRef&,
const std::vector<struct StorageServerInterface>&);
@ -527,6 +527,7 @@ public:
Counter transactionsExpensiveClearCostEstCount;
Counter transactionGrvFullBatches;
Counter transactionGrvTimedOutBatches;
Counter transactionCommitVersionNotFoundForSS;
ContinuousSample<double> latencies, readLatencies, commitLatencies, GRVLatencies, mutationsPerCommit,
bytesPerCommit, bgLatencies, bgGranulesPerRequest;

View File

@ -199,8 +199,7 @@ struct MetaclusterOperationContext {
Optional<DataClusterMetadata> dataClusterMetadata;
MetaclusterOperationContext(Reference<DB> managementDb, Optional<ClusterName> clusterName = {})
: managementDb(managementDb), clusterName(clusterName) {
}
: managementDb(managementDb), clusterName(clusterName) {}
// Run a transaction on the management cluster. This verifies that the cluster is a management cluster and matches
// the same metacluster that we've run any previous transactions on. If a clusterName is set, it also verifies that

View File

@ -235,9 +235,12 @@ struct Watch : public ReferenceCounted<Watch>, NonCopyable {
void setWatch(Future<Void> watchFuture);
};
FDB_DECLARE_BOOLEAN_PARAM(AllowInvalidTenantID);
struct TransactionState : ReferenceCounted<TransactionState> {
Database cx;
int64_t tenantId = TenantInfo::INVALID_TENANT;
Optional<Standalone<StringRef>> authToken;
Reference<TransactionLogInfo> trLogInfo;
TransactionOptions options;
@ -247,6 +250,13 @@ struct TransactionState : ReferenceCounted<TransactionState> {
UseProvisionalProxies useProvisionalProxies = UseProvisionalProxies::False;
bool readVersionObtainedFromGrvProxy;
// Special flag to skip prepending tenant prefix to mutations and conflict ranges
// when a dummy, internal transaction gets commited. The sole purpose of commitDummyTransaction() is to
// resolve the state of earlier transaction that returned commit_unknown_result or request_maybe_delivered.
// Therefore, the dummy transaction can simply reuse one conflict range of the earlier commit, if it already has
// been prefixed.
bool skipApplyTenantPrefix = false;
int numErrors = 0;
double startTime = 0;
Promise<Standalone<StringRef>> versionstampPromise;
@ -270,7 +280,7 @@ struct TransactionState : ReferenceCounted<TransactionState> {
Reference<TransactionLogInfo> trLogInfo);
Reference<TransactionState> cloneAndReset(Reference<TransactionLogInfo> newTrLogInfo, bool generateNewSpan) const;
TenantInfo getTenantInfo();
TenantInfo getTenantInfo(AllowInvalidTenantID allowInvalidId = AllowInvalidTenantID::False);
Optional<TenantName> const& tenant();
bool hasTenant() const;

View File

@ -32,6 +32,7 @@
#include "fdbrpc/LoadBalance.actor.h"
#include "fdbrpc/Stats.h"
#include "fdbrpc/TimedRequest.h"
#include "fdbrpc/TenantInfo.h"
#include "fdbrpc/TSSComparison.h"
#include "fdbclient/CommitTransaction.h"
#include "fdbclient/TagThrottle.actor.h"
@ -85,13 +86,13 @@ struct StorageServerInterface {
RequestStream<struct ReadHotSubRangeRequest> getReadHotRanges;
RequestStream<struct SplitRangeRequest> getRangeSplitPoints;
PublicRequestStream<struct GetKeyValuesStreamRequest> getKeyValuesStream;
PublicRequestStream<struct ChangeFeedStreamRequest> changeFeedStream;
PublicRequestStream<struct OverlappingChangeFeedsRequest> overlappingChangeFeeds;
PublicRequestStream<struct ChangeFeedPopRequest> changeFeedPop;
PublicRequestStream<struct ChangeFeedVersionUpdateRequest> changeFeedVersionUpdate;
PublicRequestStream<struct GetCheckpointRequest> checkpoint;
PublicRequestStream<struct FetchCheckpointRequest> fetchCheckpoint;
PublicRequestStream<struct FetchCheckpointKeyValuesRequest> fetchCheckpointKeyValues;
RequestStream<struct ChangeFeedStreamRequest> changeFeedStream;
RequestStream<struct OverlappingChangeFeedsRequest> overlappingChangeFeeds;
RequestStream<struct ChangeFeedPopRequest> changeFeedPop;
RequestStream<struct ChangeFeedVersionUpdateRequest> changeFeedVersionUpdate;
RequestStream<struct GetCheckpointRequest> checkpoint;
RequestStream<struct FetchCheckpointRequest> fetchCheckpoint;
RequestStream<struct FetchCheckpointKeyValuesRequest> fetchCheckpointKeyValues;
private:
bool acceptingRequests;
@ -150,18 +151,17 @@ public:
getMappedKeyValues = PublicRequestStream<struct GetMappedKeyValuesRequest>(
getValue.getEndpoint().getAdjustedEndpoint(14));
changeFeedStream =
PublicRequestStream<struct ChangeFeedStreamRequest>(getValue.getEndpoint().getAdjustedEndpoint(15));
overlappingChangeFeeds = PublicRequestStream<struct OverlappingChangeFeedsRequest>(
getValue.getEndpoint().getAdjustedEndpoint(16));
RequestStream<struct ChangeFeedStreamRequest>(getValue.getEndpoint().getAdjustedEndpoint(15));
overlappingChangeFeeds =
RequestStream<struct OverlappingChangeFeedsRequest>(getValue.getEndpoint().getAdjustedEndpoint(16));
changeFeedPop =
PublicRequestStream<struct ChangeFeedPopRequest>(getValue.getEndpoint().getAdjustedEndpoint(17));
changeFeedVersionUpdate = PublicRequestStream<struct ChangeFeedVersionUpdateRequest>(
RequestStream<struct ChangeFeedPopRequest>(getValue.getEndpoint().getAdjustedEndpoint(17));
changeFeedVersionUpdate = RequestStream<struct ChangeFeedVersionUpdateRequest>(
getValue.getEndpoint().getAdjustedEndpoint(18));
checkpoint =
PublicRequestStream<struct GetCheckpointRequest>(getValue.getEndpoint().getAdjustedEndpoint(19));
checkpoint = RequestStream<struct GetCheckpointRequest>(getValue.getEndpoint().getAdjustedEndpoint(19));
fetchCheckpoint =
PublicRequestStream<struct FetchCheckpointRequest>(getValue.getEndpoint().getAdjustedEndpoint(20));
fetchCheckpointKeyValues = PublicRequestStream<struct FetchCheckpointKeyValuesRequest>(
RequestStream<struct FetchCheckpointRequest>(getValue.getEndpoint().getAdjustedEndpoint(20));
fetchCheckpointKeyValues = RequestStream<struct FetchCheckpointKeyValuesRequest>(
getValue.getEndpoint().getAdjustedEndpoint(21));
}
} else {
@ -242,21 +242,6 @@ struct ServerCacheInfo {
}
};
struct TenantInfo {
static const int64_t INVALID_TENANT = -1;
Optional<TenantName> name;
int64_t tenantId;
TenantInfo() : tenantId(INVALID_TENANT) {}
TenantInfo(TenantName name, int64_t tenantId) : name(name), tenantId(tenantId) {}
template <class Ar>
void serialize(Ar& ar) {
serializer(ar, name, tenantId);
}
};
struct GetValueReply : public LoadBalancedReply {
constexpr static FileIdentifier file_identifier = 1378929;
Optional<Value> value;
@ -284,6 +269,8 @@ struct GetValueRequest : TimedRequest {
// to this client, of all storage replicas that
// serve the given key
bool verify() const { return tenantInfo.isAuthorized(); }
GetValueRequest() {}
GetValueRequest(SpanContext spanContext,
const TenantInfo& tenantInfo,
@ -338,6 +325,8 @@ struct WatchValueRequest {
: spanContext(spanContext), tenantInfo(tenantInfo), key(key), value(value), version(ver), tags(tags),
debugID(debugID) {}
bool verify() const { return tenantInfo.isAuthorized(); }
template <class Ar>
void serialize(Ar& ar) {
serializer(ar, key, value, version, tags, debugID, reply, spanContext, tenantInfo);
@ -381,6 +370,8 @@ struct GetKeyValuesRequest : TimedRequest {
GetKeyValuesRequest() : isFetchKeys(false) {}
bool verify() const { return tenantInfo.isAuthorized(); }
template <class Ar>
void serialize(Ar& ar) {
serializer(ar,
@ -437,6 +428,9 @@ struct GetMappedKeyValuesRequest : TimedRequest {
// serve the given key range
GetMappedKeyValuesRequest() : isFetchKeys(false) {}
bool verify() const { return tenantInfo.isAuthorized(); }
template <class Ar>
void serialize(Ar& ar) {
serializer(ar,
@ -503,6 +497,8 @@ struct GetKeyValuesStreamRequest {
GetKeyValuesStreamRequest() : isFetchKeys(false) {}
bool verify() const { return tenantInfo.isAuthorized(); }
template <class Ar>
void serialize(Ar& ar) {
serializer(ar,
@ -550,6 +546,8 @@ struct GetKeyRequest : TimedRequest {
// to this client, of all storage replicas that
// serve the given key
bool verify() const { return tenantInfo.isAuthorized(); }
GetKeyRequest() {}
GetKeyRequest(SpanContext spanContext,

View File

@ -26,6 +26,7 @@
#include "fdbclient/KeyBackedTypes.h"
#include "fdbclient/VersionedMap.h"
#include "fdbclient/KeyBackedTypes.h"
#include "fdbrpc/TenantInfo.h"
#include "flow/flat_buffers.h"
typedef StringRef TenantNameRef;
@ -174,4 +175,4 @@ public:
typedef VersionedMap<TenantName, TenantMapEntry> TenantMap;
typedef VersionedMap<Key, TenantName> TenantPrefixIndex;
#endif
#endif

View File

@ -302,6 +302,10 @@ description is not currently required but encouraged.
<Option name="skip_grv_cache" code="1102"
description="Specifically instruct this transaction to NOT use cached GRV. Primarily used for the read version cache's background updater to avoid attempting to read a cached entry in specific situations."
hidden="true"/>
<Option name="authorization_token" code="2000"
description="Add a given authorization token to the network thread so that future requests are authorized"
paramType="String" paramDescription="A signed token serialized using flatbuffers"
hidden="true" />
</Scope>
<!-- The enumeration values matter - do not change them without

View File

@ -258,7 +258,7 @@ Optional<Standalone<StringRef>> AsyncFileEncrypted::RandomCache::get(uint32_t bl
TEST_CASE("fdbrpc/AsyncFileEncrypted") {
state const int bytes = FLOW_KNOBS->ENCRYPTION_BLOCK_SIZE * deterministicRandom()->randomInt(0, 1000);
state std::vector<unsigned char> writeBuffer(bytes, 0);
generateRandomData(&writeBuffer.front(), bytes);
deterministicRandom()->randomBytes(&writeBuffer.front(), bytes);
state std::vector<unsigned char> readBuffer(bytes, 0);
ASSERT(g_network->isSimulated());
StreamCipherKey::initializeGlobalRandomTestKey();

View File

@ -117,6 +117,9 @@ Optional<StringRef> decode(Arena& arena, StringRef base64UrlStr) {
}
auto out = new (arena) uint8_t[decodedLen];
auto actualLen = decode(base64UrlStr.begin(), base64UrlStr.size(), out);
if (actualLen == -1) {
return {};
}
ASSERT_EQ(decodedLen, actualLen);
return StringRef(out, decodedLen);
}

View File

@ -28,12 +28,15 @@
#include <memcheck.h>
#endif
#include "fdbrpc/TenantInfo.h"
#include <boost/unordered_map.hpp>
#include "fdbrpc/TokenSign.h"
#include "fdbrpc/fdbrpc.h"
#include "fdbrpc/FailureMonitor.h"
#include "fdbrpc/HealthMonitor.h"
#include "fdbrpc/genericactors.actor.h"
#include "fdbrpc/IPAllowList.h"
#include "fdbrpc/TokenCache.h"
#include "fdbrpc/simulator.h"
#include "flow/ActorCollection.h"
#include "flow/Error.h"
@ -47,8 +50,13 @@
#include "flow/xxhash.h"
#include "flow/actorcompiler.h" // This must be the last #include.
static NetworkAddressList g_currentDeliveryPeerAddress = NetworkAddressList();
static Future<Void> g_currentDeliveryPeerDisconnect;
namespace {
NetworkAddressList g_currentDeliveryPeerAddress = NetworkAddressList();
bool g_currentDeliverPeerAddressTrusted = false;
Future<Void> g_currentDeliveryPeerDisconnect;
} // namespace
constexpr int PACKET_LEN_WIDTH = sizeof(uint32_t);
const uint64_t TOKEN_STREAM_FLAG = 1;
@ -239,31 +247,6 @@ struct PingReceiver final : NetworkMessageReceiver {
bool isPublic() const override { return true; }
};
struct TenantAuthorizer final : NetworkMessageReceiver {
TenantAuthorizer(EndpointMap& endpoints) {
endpoints.insertWellKnown(this, Endpoint::wellKnownToken(WLTOKEN_AUTH_TENANT), TaskPriority::ReadSocket);
}
void receive(ArenaObjectReader& reader) override {
AuthorizationRequest req;
try {
reader.deserialize(req);
// TODO: verify that token is valid
AuthorizedTenants& auth = reader.variable<AuthorizedTenants>("AuthorizedTenants");
for (auto const& t : req.tenants) {
auth.authorizedTenants.insert(TenantInfoRef(auth.arena, t));
}
req.reply.send(Void());
} catch (Error& e) {
if (e.code() == error_code_permission_denied) {
req.reply.sendError(e);
} else {
throw;
}
}
}
bool isPublic() const override { return true; }
};
struct UnauthorizedEndpointReceiver final : NetworkMessageReceiver {
UnauthorizedEndpointReceiver(EndpointMap& endpoints) {
endpoints.insertWellKnown(
@ -339,7 +322,6 @@ public:
EndpointMap endpoints;
EndpointNotFoundReceiver endpointNotFoundReceiver{ endpoints };
PingReceiver pingReceiver{ endpoints };
TenantAuthorizer tenantReceiver{ endpoints };
UnauthorizedEndpointReceiver unauthorizedEndpointReceiver{ endpoints };
Int64MetricHandle bytesSent;
@ -356,10 +338,11 @@ public:
double lastIncompatibleMessage;
uint64_t transportId;
IPAllowList allowList;
std::shared_ptr<ContextVariableMap> localCVM = std::make_shared<ContextVariableMap>(); // for local delivery
Future<Void> multiVersionCleanup;
Future<Void> pingLogger;
std::unordered_map<Standalone<StringRef>, PublicKey> publicKeys;
};
ACTOR Future<Void> pingLatencyLogger(TransportData* self) {
@ -926,10 +909,20 @@ void Peer::prependConnectPacket() {
pkt.protocolVersion.addObjectSerializerFlag();
pkt.connectionId = transport->transportId;
PacketBuffer* pb_first = PacketBuffer::create();
PacketBuffer *pb_first = PacketBuffer::create(), *pb_end = nullptr;
PacketWriter wr(pb_first, nullptr, Unversioned());
pkt.serialize(wr);
unsent.prependWriteBuffer(pb_first, wr.finish());
pb_end = wr.finish();
#if VALGRIND
SendBuffer* checkbuf = pb_first;
while (checkbuf) {
int size = checkbuf->bytes_written;
const uint8_t* data = checkbuf->data();
VALGRIND_CHECK_MEM_IS_DEFINED(data, size);
checkbuf = checkbuf->next;
}
#endif
unsent.prependWriteBuffer(pb_first, pb_end);
}
void Peer::discardUnreliablePackets() {
@ -1013,8 +1006,7 @@ ACTOR static void deliver(TransportData* self,
TaskPriority priority,
ArenaReader reader,
NetworkAddress peerAddress,
Reference<AuthorizedTenants> authorizedTenants,
std::shared_ptr<ContextVariableMap> cvm,
bool isTrustedPeer,
InReadSocket inReadSocket,
Future<Void> disconnect) {
// We want to run the task at the right priority. If the priority is higher than the current priority (which is
@ -1029,22 +1021,26 @@ ACTOR static void deliver(TransportData* self,
}
auto receiver = self->endpoints.get(destination.token);
if (receiver && (authorizedTenants->trusted || receiver->isPublic())) {
if (receiver && (isTrustedPeer || receiver->isPublic())) {
if (!checkCompatible(receiver->peerCompatibilityPolicy(), reader.protocolVersion())) {
return;
}
try {
ASSERT(g_currentDeliveryPeerAddress == NetworkAddressList());
ASSERT(!g_currentDeliverPeerAddressTrusted);
g_currentDeliveryPeerAddress = destination.addresses;
g_currentDeliverPeerAddressTrusted = isTrustedPeer;
g_currentDeliveryPeerDisconnect = disconnect;
StringRef data = reader.arenaReadAll();
ASSERT(data.size() > 8);
ArenaObjectReader objReader(reader.arena(), reader.arenaReadAll(), AssumeVersion(reader.protocolVersion()));
objReader.setContextVariableMap(cvm);
receiver->receive(objReader);
g_currentDeliveryPeerAddress = { NetworkAddress() };
g_currentDeliveryPeerAddress = NetworkAddressList();
g_currentDeliverPeerAddressTrusted = false;
g_currentDeliveryPeerDisconnect = Future<Void>();
} catch (Error& e) {
g_currentDeliveryPeerAddress = { NetworkAddress() };
g_currentDeliveryPeerAddress = NetworkAddressList();
g_currentDeliverPeerAddressTrusted = false;
g_currentDeliveryPeerDisconnect = Future<Void>();
TraceEvent(SevError, "ReceiverError")
.error(e)
@ -1092,8 +1088,7 @@ static void scanPackets(TransportData* transport,
const uint8_t* e,
Arena& arena,
NetworkAddress const& peerAddress,
Reference<AuthorizedTenants> const& authorizedTenants,
std::shared_ptr<ContextVariableMap> cvm,
bool isTrustedPeer,
ProtocolVersion peerProtocolVersion,
Future<Void> disconnect,
IsStableConnection isStableConnection) {
@ -1215,8 +1210,7 @@ static void scanPackets(TransportData* transport,
priority,
std::move(reader),
peerAddress,
authorizedTenants,
cvm,
isTrustedPeer,
InReadSocket::True,
disconnect);
}
@ -1263,14 +1257,9 @@ ACTOR static Future<Void> connectionReader(TransportData* transport,
state bool incompatiblePeerCounted = false;
state NetworkAddress peerAddress;
state ProtocolVersion peerProtocolVersion;
state Reference<AuthorizedTenants> authorizedTenants = makeReference<AuthorizedTenants>();
state std::shared_ptr<ContextVariableMap> cvm = std::make_shared<ContextVariableMap>();
state bool trusted = transport->allowList(conn->getPeerAddress().ip);
peerAddress = conn->getPeerAddress();
authorizedTenants->trusted = transport->allowList(conn->getPeerAddress().ip);
(*cvm)["AuthorizedTenants"] = &authorizedTenants;
(*cvm)["PeerAddress"] = &peerAddress;
authorizedTenants->trusted = transport->allowList(peerAddress.ip);
if (!peer) {
ASSERT(!peerAddress.isPublic());
}
@ -1420,8 +1409,7 @@ ACTOR static Future<Void> connectionReader(TransportData* transport,
unprocessed_end,
arena,
peerAddress,
authorizedTenants,
cvm,
trusted,
peerProtocolVersion,
peer->disconnect.getFuture(),
IsStableConnection(g_network->isSimulated() && conn->isStableConnection()));
@ -1572,6 +1560,11 @@ ACTOR static Future<Void> multiVersionCleanupWorker(TransportData* self) {
FlowTransport::FlowTransport(uint64_t transportId, int maxWellKnownEndpoints, IPAllowList const* allowList)
: self(new TransportData(transportId, maxWellKnownEndpoints, allowList)) {
self->multiVersionCleanup = multiVersionCleanupWorker(self);
if (g_network->isSimulated()) {
for (auto const& p : g_simulator.authKeys) {
self->publicKeys.emplace(p.first, p.second.toPublic());
}
}
}
FlowTransport::~FlowTransport() {
@ -1717,15 +1710,12 @@ static void sendLocal(TransportData* self, ISerializeSource const& what, const E
ASSERT(copy.size() > 0);
TaskPriority priority = self->endpoints.getPriority(destination.token);
if (priority != TaskPriority::UnknownEndpoint || (destination.token.first() & TOKEN_STREAM_FLAG) != 0) {
Reference<AuthorizedTenants> authorizedTenants = makeReference<AuthorizedTenants>();
authorizedTenants->trusted = true;
deliver(self,
destination,
priority,
ArenaReader(copy.arena(), copy, AssumeVersion(currentProtocolVersion)),
NetworkAddress(),
authorizedTenants,
self->localCVM,
true,
InReadSocket::False,
Never());
}
@ -1936,6 +1926,7 @@ void FlowTransport::createInstance(bool isClient,
uint64_t transportId,
int maxWellKnownEndpoints,
IPAllowList const* allowList) {
TokenCache::createInstance();
g_network->setGlobal(INetwork::enFlowTransport,
(flowGlobalType) new FlowTransport(transportId, maxWellKnownEndpoints, allowList));
g_network->setGlobal(INetwork::enNetworkAddressFunc, (flowGlobalType)&FlowTransport::getGlobalLocalAddress);
@ -1947,3 +1938,31 @@ void FlowTransport::createInstance(bool isClient,
HealthMonitor* FlowTransport::healthMonitor() {
return &self->healthMonitor;
}
Optional<PublicKey> FlowTransport::getPublicKeyByName(StringRef name) const {
auto iter = self->publicKeys.find(name);
if (iter != self->publicKeys.end()) {
return iter->second;
}
return {};
}
NetworkAddress FlowTransport::currentDeliveryPeerAddress() const {
return g_currentDeliveryPeerAddress.address;
}
bool FlowTransport::currentDeliveryPeerIsTrusted() const {
return g_currentDeliverPeerAddressTrusted;
}
void FlowTransport::addPublicKey(StringRef name, PublicKey key) {
self->publicKeys[name] = key;
}
void FlowTransport::removePublicKey(StringRef name) {
self->publicKeys.erase(name);
}
void FlowTransport::removeAllPublicKeys() {
self->publicKeys.clear();
}

353
fdbrpc/TokenCache.cpp Normal file
View File

@ -0,0 +1,353 @@
#include "fdbrpc/FlowTransport.h"
#include "fdbrpc/TokenCache.h"
#include "fdbrpc/TokenSign.h"
#include "fdbrpc/TenantInfo.h"
#include "flow/MkCert.h"
#include "flow/ScopeExit.h"
#include "flow/UnitTest.h"
#include "flow/network.h"
#include <boost/unordered_map.hpp>
#include <fmt/format.h>
#include <list>
#include <deque>
template <class Key, class Value>
class LRUCache {
public:
using key_type = Key;
using list_type = std::list<key_type>;
using mapped_type = Value;
using map_type = boost::unordered_map<key_type, std::pair<mapped_type, typename list_type::iterator>>;
using size_type = unsigned;
explicit LRUCache(size_type capacity) : _capacity(capacity) { _map.reserve(capacity); }
size_type size() const { return _map.size(); }
size_type capacity() const { return _capacity; }
bool empty() const { return _map.empty(); }
Optional<mapped_type*> get(key_type const& key) {
auto i = _map.find(key);
if (i == _map.end()) {
return Optional<mapped_type*>();
}
auto j = i->second.second;
if (j != _list.begin()) {
_list.erase(j);
_list.push_front(i->first);
i->second.second = _list.begin();
}
return &i->second.first;
}
template <class K, class V>
mapped_type* insert(K&& key, V&& value) {
auto iter = _map.find(key);
if (iter != _map.end()) {
return &iter->second.first;
}
if (size() == capacity()) {
auto i = --_list.end();
_map.erase(*i);
_list.erase(i);
}
_list.push_front(std::forward<K>(key));
std::tie(iter, std::ignore) =
_map.insert(std::make_pair(*_list.begin(), std::make_pair(std::forward<V>(value), _list.begin())));
return &iter->second.first;
}
private:
const size_type _capacity;
map_type _map;
list_type _list;
};
TEST_CASE("/fdbrpc/authz/LRUCache") {
auto& rng = *deterministicRandom();
{
// test very small LRU cache
LRUCache<int, StringRef> cache(rng.randomInt(2, 10));
for (int i = 0; i < 200; ++i) {
cache.insert(i, "val"_sr);
if (i >= cache.capacity()) {
for (auto j = 0; j <= i - cache.capacity(); j++)
ASSERT(!cache.get(j).present());
// ordering is important so as not to disrupt the LRU order
for (auto j = i - cache.capacity() + 1; j <= i; j++)
ASSERT(cache.get(j).present());
}
}
}
{
// Test larger cache
LRUCache<int, StringRef> cache(1000);
for (auto i = 0; i < 1000; ++i) {
cache.insert(i, "value"_sr);
}
cache.insert(1000, "value"_sr); // should evict 0
ASSERT(!cache.get(0).present());
}
{
// memory test -- this is what the boost implementation didn't do correctly
LRUCache<StringRef, Standalone<StringRef>> cache(10);
std::deque<std::string> cachedStrings;
std::deque<std::string> evictedStrings;
for (int i = 0; i < 10; ++i) {
auto str = rng.randomAlphaNumeric(rng.randomInt(100, 1024));
Standalone<StringRef> sref(str);
cache.insert(sref, sref);
cachedStrings.push_back(str);
}
for (int i = 0; i < 10; ++i) {
Standalone<StringRef> existingStr(cachedStrings.back());
auto cachedStr = cache.get(existingStr);
ASSERT(cachedStr.present());
ASSERT(*cachedStr.get() == existingStr);
if (!evictedStrings.empty()) {
Standalone<StringRef> nonexisting(evictedStrings.at(rng.randomInt(0, evictedStrings.size())));
ASSERT(!cache.get(nonexisting).present());
}
auto str = rng.randomAlphaNumeric(rng.randomInt(100, 1024));
Standalone<StringRef> sref(str);
evictedStrings.push_back(cachedStrings.front());
cachedStrings.pop_front();
cachedStrings.push_back(str);
cache.insert(sref, sref);
}
}
return Void();
}
struct TokenCacheImpl {
struct CacheEntry {
Arena arena;
VectorRef<TenantNameRef> tenants;
double expirationTime = 0.0;
};
LRUCache<StringRef, CacheEntry> cache;
TokenCacheImpl() : cache(FLOW_KNOBS->TOKEN_CACHE_SIZE) {}
bool validate(TenantNameRef tenant, StringRef token);
bool validateAndAdd(double currentTime, StringRef token, NetworkAddress const& peer);
};
TokenCache::TokenCache() : impl(new TokenCacheImpl()) {}
TokenCache::~TokenCache() {
delete impl;
}
void TokenCache::createInstance() {
g_network->setGlobal(INetwork::enTokenCache, new TokenCache());
}
TokenCache& TokenCache::instance() {
return *reinterpret_cast<TokenCache*>(g_network->global(INetwork::enTokenCache));
}
bool TokenCache::validate(TenantNameRef name, StringRef token) {
return impl->validate(name, token);
}
#define TRACE_INVALID_PARSED_TOKEN(reason, token) \
TraceEvent(SevWarn, "InvalidToken") \
.detail("From", peer) \
.detail("Reason", reason) \
.detail("Token", token.toStringRef(arena).toStringView())
bool TokenCacheImpl::validateAndAdd(double currentTime, StringRef token, NetworkAddress const& peer) {
Arena arena;
authz::jwt::TokenRef t;
if (!authz::jwt::parseToken(arena, t, token)) {
CODE_PROBE(true, "Token can't be parsed");
TraceEvent(SevWarn, "InvalidToken")
.detail("From", peer)
.detail("Reason", "ParseError")
.detail("Token", token.toString());
return false;
}
auto key = FlowTransport::transport().getPublicKeyByName(t.keyId);
if (!key.present()) {
CODE_PROBE(true, "Token referencing non-existing key");
TRACE_INVALID_PARSED_TOKEN("UnknownKey", t);
return false;
} else if (!t.expiresAtUnixTime.present()) {
CODE_PROBE(true, "Token has no expiration time");
TRACE_INVALID_PARSED_TOKEN("NoExpirationTime", t);
return false;
} else if (double(t.expiresAtUnixTime.get()) <= currentTime) {
CODE_PROBE(true, "Expired token");
TRACE_INVALID_PARSED_TOKEN("Expired", t);
return false;
} else if (!t.notBeforeUnixTime.present()) {
CODE_PROBE(true, "Token has no not-before field");
TRACE_INVALID_PARSED_TOKEN("NoNotBefore", t);
return false;
} else if (double(t.notBeforeUnixTime.get()) > currentTime) {
CODE_PROBE(true, "Tokens not-before is in the future");
TRACE_INVALID_PARSED_TOKEN("TokenNotYetValid", t);
return false;
} else if (!t.tenants.present()) {
CODE_PROBE(true, "Token with no tenants");
TRACE_INVALID_PARSED_TOKEN("NoTenants", t);
return false;
} else if (!authz::jwt::verifyToken(token, key.get())) {
CODE_PROBE(true, "Token with invalid signature");
TRACE_INVALID_PARSED_TOKEN("InvalidSignature", t);
return false;
} else {
CacheEntry c;
c.expirationTime = double(t.expiresAtUnixTime.get());
c.tenants.reserve(c.arena, t.tenants.get().size());
for (auto tenant : t.tenants.get()) {
c.tenants.push_back_deep(c.arena, tenant);
}
cache.insert(StringRef(c.arena, token), c);
return true;
}
}
bool TokenCacheImpl::validate(TenantNameRef name, StringRef token) {
NetworkAddress peer = FlowTransport::transport().currentDeliveryPeerAddress();
auto cachedEntry = cache.get(token);
double currentTime = g_network->timer();
if (!cachedEntry.present()) {
if (validateAndAdd(currentTime, token, peer)) {
cachedEntry = cache.get(token);
} else {
return false;
}
}
ASSERT(cachedEntry.present());
auto& entry = cachedEntry.get();
if (entry->expirationTime < currentTime) {
CODE_PROBE(true, "Found expired token in cache");
TraceEvent(SevWarn, "InvalidToken").detail("From", peer).detail("Reason", "ExpiredInCache");
return false;
}
bool tenantFound = false;
for (auto const& t : entry->tenants) {
if (t == name) {
tenantFound = true;
break;
}
}
if (!tenantFound) {
CODE_PROBE(true, "Valid token doesn't reference tenant");
TraceEvent(SevWarn, "TenantTokenMismatch").detail("From", peer).detail("Tenant", name.toString());
return false;
}
return true;
}
namespace authz::jwt {
extern TokenRef makeRandomTokenSpec(Arena&, IRandom&, authz::Algorithm);
}
TEST_CASE("/fdbrpc/authz/TokenCache/BadTokens") {
std::pair<void (*)(Arena&, IRandom&, authz::jwt::TokenRef&), char const*> badMutations[]{
{
[](Arena&, IRandom&, authz::jwt::TokenRef&) { FlowTransport::transport().removeAllPublicKeys(); },
"NoKeyWithSuchName",
},
{
[](Arena&, IRandom&, authz::jwt::TokenRef& token) { token.expiresAtUnixTime.reset(); },
"NoExpirationTime",
},
{
[](Arena&, IRandom& rng, authz::jwt::TokenRef& token) {
token.expiresAtUnixTime = uint64_t(g_network->timer() - 10 - rng.random01() * 50);
},
"ExpiredToken",
},
{
[](Arena&, IRandom&, authz::jwt::TokenRef& token) { token.notBeforeUnixTime.reset(); },
"NoNotBefore",
},
{
[](Arena&, IRandom& rng, authz::jwt::TokenRef& token) {
token.notBeforeUnixTime = uint64_t(g_network->timer() + 10 + rng.random01() * 50);
},
"TokenNotYetValid",
},
{
[](Arena& arena, IRandom&, authz::jwt::TokenRef& token) { token.tenants.reset(); },
"NoTenants",
},
};
auto const pubKeyName = "somePublicKey"_sr;
auto privateKey = mkcert::makeEcP256();
auto const numBadMutations = sizeof(badMutations) / sizeof(badMutations[0]);
for (auto repeat = 0; repeat < 50; repeat++) {
auto arena = Arena();
auto& rng = *deterministicRandom();
auto validTokenSpec = authz::jwt::makeRandomTokenSpec(arena, rng, authz::Algorithm::ES256);
validTokenSpec.keyId = pubKeyName;
for (auto i = 0; i < numBadMutations; i++) {
FlowTransport::transport().addPublicKey(pubKeyName, privateKey.toPublic());
auto publicKeyClearGuard =
ScopeExit([pubKeyName]() { FlowTransport::transport().removePublicKey(pubKeyName); });
auto [mutationFn, mutationDesc] = badMutations[i];
auto tmpArena = Arena();
auto mutatedTokenSpec = validTokenSpec;
mutationFn(tmpArena, rng, mutatedTokenSpec);
auto signedToken = authz::jwt::signToken(tmpArena, mutatedTokenSpec, privateKey);
if (TokenCache::instance().validate(validTokenSpec.tenants.get()[0], signedToken)) {
fmt::print("Unexpected successful validation at mutation {}, token spec: {}\n",
mutationDesc,
mutatedTokenSpec.toStringRef(tmpArena).toStringView());
ASSERT(false);
}
}
}
if (TokenCache::instance().validate("TenantNameDontMatterHere"_sr, StringRef())) {
fmt::print("Unexpected successful validation of ill-formed token (no signature part)\n");
ASSERT(false);
}
if (TokenCache::instance().validate("TenantNameDontMatterHere"_sr, "1111.22"_sr)) {
fmt::print("Unexpected successful validation of ill-formed token (no signature part)\n");
ASSERT(false);
}
if (TokenCache::instance().validate("TenantNameDontMatterHere2"_sr, "////.////.////"_sr)) {
fmt::print("Unexpected successful validation of unparseable token\n");
ASSERT(false);
}
fmt::print("TEST OK\n");
return Void();
}
TEST_CASE("/fdbrpc/authz/TokenCache/GoodTokens") {
// Don't repeat because token expiry is at seconds granularity and sleeps are costly in unit tests
auto arena = Arena();
auto privateKey = mkcert::makeEcP256();
auto const pubKeyName = "somePublicKey"_sr;
FlowTransport::transport().addPublicKey(pubKeyName, privateKey.toPublic());
auto publicKeyClearGuard = ScopeExit([pubKeyName]() { FlowTransport::transport().removePublicKey(pubKeyName); });
auto& rng = *deterministicRandom();
auto tokenSpec = authz::jwt::makeRandomTokenSpec(arena, rng, authz::Algorithm::ES256);
tokenSpec.expiresAtUnixTime = static_cast<uint64_t>(g_network->timer() + 2.0);
tokenSpec.keyId = pubKeyName;
auto signedToken = authz::jwt::signToken(arena, tokenSpec, privateKey);
if (!TokenCache::instance().validate(tokenSpec.tenants.get()[0], signedToken)) {
fmt::print("Unexpected failed token validation, token spec: {}, now: {}\n",
tokenSpec.toStringRef(arena).toStringView(),
g_network->timer());
ASSERT(false);
}
threadSleep(3.5);
if (TokenCache::instance().validate(tokenSpec.tenants.get()[0], signedToken)) {
fmt::print(
"Unexpected successful token validation after supposedly expiring in cache, token spec: {}, now: {}\n",
tokenSpec.toStringRef(arena).toStringView(),
g_network->timer());
ASSERT(false);
}
fmt::print("TEST OK\n");
return Void();
}

View File

@ -30,6 +30,7 @@
#include "flow/Trace.h"
#include "flow/UnitTest.h"
#include <fmt/format.h>
#include <iterator>
#include <string_view>
#include <type_traits>
#include <utility>
@ -161,6 +162,43 @@ TokenRef makeRandomTokenSpec(Arena& arena, IRandom& rng) {
namespace authz::jwt {
template <class FieldType, size_t NameLen>
void appendField(fmt::memory_buffer& b, char const (&name)[NameLen], Optional<FieldType> const& field) {
if (!field.present())
return;
auto const& f = field.get();
auto bi = std::back_inserter(b);
if constexpr (std::is_same_v<FieldType, VectorRef<StringRef>>) {
fmt::format_to(bi, " {}=[", name);
for (auto i = 0; i < f.size(); i++) {
if (i)
fmt::format_to(bi, ",");
fmt::format_to(bi, f[i].toStringView());
}
fmt::format_to(bi, "]");
} else if constexpr (std::is_same_v<FieldType, StringRef>) {
fmt::format_to(bi, " {}={}", name, f.toStringView());
} else {
fmt::format_to(bi, " {}={}", name, f);
}
}
StringRef TokenRef::toStringRef(Arena& arena) {
auto buf = fmt::memory_buffer();
fmt::format_to(std::back_inserter(buf), "alg={} kid={}", getAlgorithmName(algorithm), keyId.toStringView());
appendField(buf, "iss", issuer);
appendField(buf, "sub", subject);
appendField(buf, "aud", audience);
appendField(buf, "iat", issuedAtUnixTime);
appendField(buf, "exp", expiresAtUnixTime);
appendField(buf, "nbf", notBeforeUnixTime);
appendField(buf, "jti", tokenId);
appendField(buf, "tenants", tenants);
auto str = new (arena) uint8_t[buf.size()];
memcpy(str, buf.data(), buf.size());
return StringRef(str, buf.size());
}
template <class FieldType, class Writer>
void putField(Optional<FieldType> const& field, Writer& wr, const char* fieldName) {
if (!field.present())
@ -192,9 +230,12 @@ StringRef makeTokenPart(Arena& arena, TokenRef tokenSpec) {
header.StartObject();
header.Key("typ");
header.String("JWT");
header.Key("alg");
auto algo = getAlgorithmName(tokenSpec.algorithm);
header.Key("alg");
header.String(algo.data(), algo.size());
auto kid = tokenSpec.keyId.toStringView();
header.Key("kid");
header.String(kid.data(), kid.size());
header.EndObject();
payload.StartObject();
putField(tokenSpec.issuer, payload, "iss");
@ -203,7 +244,6 @@ StringRef makeTokenPart(Arena& arena, TokenRef tokenSpec) {
putField(tokenSpec.issuedAtUnixTime, payload, "iat");
putField(tokenSpec.expiresAtUnixTime, payload, "exp");
putField(tokenSpec.notBeforeUnixTime, payload, "nbf");
putField(tokenSpec.keyId, payload, "kid");
putField(tokenSpec.tokenId, payload, "jti");
putField(tokenSpec.tenants, payload, "tenants");
payload.EndObject();
@ -240,7 +280,7 @@ StringRef signToken(Arena& arena, TokenRef tokenSpec, PrivateKey privateKey) {
return StringRef(out, totalLen);
}
bool parseHeaderPart(TokenRef& token, StringRef b64urlHeader) {
bool parseHeaderPart(Arena& arena, TokenRef& token, StringRef b64urlHeader) {
auto tmpArena = Arena();
auto optHeader = base64url::decode(tmpArena, b64urlHeader);
if (!optHeader.present())
@ -256,24 +296,30 @@ bool parseHeaderPart(TokenRef& token, StringRef b64urlHeader) {
.detail("Offset", d.GetErrorOffset());
return false;
}
auto algItr = d.FindMember("alg");
if (!d.IsObject())
return false;
auto typItr = d.FindMember("typ");
if (d.IsObject() && algItr != d.MemberEnd() && typItr != d.MemberEnd()) {
auto const& alg = algItr->value;
auto const& typ = typItr->value;
if (alg.IsString() && typ.IsString()) {
auto algValue = StringRef(reinterpret_cast<const uint8_t*>(alg.GetString()), alg.GetStringLength());
auto algType = algorithmFromString(algValue);
if (algType == Algorithm::UNKNOWN)
return false;
token.algorithm = algType;
auto typValue = StringRef(reinterpret_cast<const uint8_t*>(typ.GetString()), typ.GetStringLength());
if (typValue != "JWT"_sr)
return false;
return true;
}
}
return false;
if (typItr == d.MemberEnd() || !typItr->value.IsString())
return false;
auto algItr = d.FindMember("alg");
if (algItr == d.MemberEnd() || !algItr->value.IsString())
return false;
auto kidItr = d.FindMember("kid");
if (kidItr == d.MemberEnd() || !kidItr->value.IsString())
return false;
auto const& typ = typItr->value;
auto const& alg = algItr->value;
auto const& kid = kidItr->value;
auto typValue = StringRef(reinterpret_cast<const uint8_t*>(typ.GetString()), typ.GetStringLength());
if (typValue != "JWT"_sr)
return false;
auto algValue = StringRef(reinterpret_cast<const uint8_t*>(alg.GetString()), alg.GetStringLength());
auto algType = algorithmFromString(algValue);
if (algType == Algorithm::UNKNOWN)
return false;
token.algorithm = algType;
token.keyId = StringRef(arena, reinterpret_cast<const uint8_t*>(kid.GetString()), kid.GetStringLength());
return true;
}
template <class FieldType>
@ -343,8 +389,6 @@ bool parsePayloadPart(Arena& arena, TokenRef& token, StringRef b64urlPayload) {
return false;
if (!parseField(arena, token.notBeforeUnixTime, d, "nbf"))
return false;
if (!parseField(arena, token.keyId, d, "kid"))
return false;
if (!parseField(arena, token.tenants, d, "tenants"))
return false;
return true;
@ -358,13 +402,19 @@ bool parseSignaturePart(Arena& arena, TokenRef& token, StringRef b64urlSignature
return true;
}
StringRef signaturePart(StringRef token) {
token.eat("."_sr);
token.eat("."_sr);
return token;
}
bool parseToken(Arena& arena, TokenRef& token, StringRef signedToken) {
auto b64urlHeader = signedToken.eat("."_sr);
auto b64urlPayload = signedToken.eat("."_sr);
auto b64urlSignature = signedToken;
if (b64urlHeader.empty() || b64urlPayload.empty() || b64urlSignature.empty())
return false;
if (!parseHeaderPart(token, b64urlHeader))
if (!parseHeaderPart(arena, token, b64urlHeader))
return false;
if (!parsePayloadPart(arena, token, b64urlPayload))
return false;
@ -387,7 +437,7 @@ bool verifyToken(StringRef signedToken, PublicKey publicKey) {
return false;
auto sig = optSig.get();
auto parsedToken = TokenRef();
if (!parseHeaderPart(parsedToken, b64urlHeader))
if (!parseHeaderPart(arena, parsedToken, b64urlHeader))
return false;
auto [verifyAlgo, digest] = getMethod(parsedToken.algorithm);
if (!checkVerifyAlgorithm(verifyAlgo, publicKey))
@ -401,6 +451,7 @@ TokenRef makeRandomTokenSpec(Arena& arena, IRandom& rng, Algorithm alg) {
}
auto ret = TokenRef{};
ret.algorithm = alg;
ret.keyId = genRandomAlphanumStringRef(arena, rng, MaxKeyNameLenPlus1);
ret.issuer = genRandomAlphanumStringRef(arena, rng, MaxIssuerNameLenPlus1);
ret.subject = genRandomAlphanumStringRef(arena, rng, MaxIssuerNameLenPlus1);
ret.tokenId = genRandomAlphanumStringRef(arena, rng, 31);
@ -410,9 +461,8 @@ TokenRef makeRandomTokenSpec(Arena& arena, IRandom& rng, Algorithm alg) {
aud[i] = genRandomAlphanumStringRef(arena, rng, MaxTenantNameLenPlus1);
ret.audience = VectorRef<StringRef>(aud, numAudience);
ret.issuedAtUnixTime = timer_int() / 1'000'000'000ul;
ret.notBeforeUnixTime = timer_int() / 1'000'000'000ul;
ret.notBeforeUnixTime = ret.issuedAtUnixTime.get();
ret.expiresAtUnixTime = ret.issuedAtUnixTime.get() + rng.randomInt(360, 1080 + 1);
ret.keyId = genRandomAlphanumStringRef(arena, rng, MaxKeyNameLenPlus1);
auto numTenants = rng.randomInt(1, 3);
auto tenants = new (arena) StringRef[numTenants];
for (auto i = 0; i < numTenants; i++)
@ -491,6 +541,33 @@ TEST_CASE("/fdbrpc/TokenSign/JWT") {
return Void();
}
TEST_CASE("/fdbrpc/TokenSign/JWT/ToStringRef") {
auto t = authz::jwt::TokenRef();
t.algorithm = authz::Algorithm::ES256;
t.issuer = "issuer"_sr;
t.subject = "subject"_sr;
StringRef aud[3]{ "aud1"_sr, "aud2"_sr, "aud3"_sr };
t.audience = VectorRef<StringRef>(aud, 3);
t.issuedAtUnixTime = 123ul;
t.expiresAtUnixTime = 456ul;
t.notBeforeUnixTime = 789ul;
t.keyId = "keyId"_sr;
t.tokenId = "tokenId"_sr;
StringRef tenants[2]{ "tenant1"_sr, "tenant2"_sr };
t.tenants = VectorRef<StringRef>(tenants, 2);
auto arena = Arena();
auto tokenStr = t.toStringRef(arena);
auto tokenStrExpected =
"alg=ES256 kid=keyId iss=issuer sub=subject aud=[aud1,aud2,aud3] iat=123 exp=456 nbf=789 jti=tokenId tenants=[tenant1,tenant2]"_sr;
if (tokenStr != tokenStrExpected) {
fmt::print("Expected: {}\nGot : {}\n", tokenStrExpected.toStringView(), tokenStr.toStringView());
ASSERT(false);
} else {
fmt::print("TEST OK\n");
}
return Void();
}
TEST_CASE("/fdbrpc/TokenSign/bench") {
constexpr auto repeat = 5;
constexpr auto numSamples = 10000;

View File

@ -20,25 +20,21 @@
#ifndef FLOW_TRANSPORT_H
#define FLOW_TRANSPORT_H
#include "flow/Arena.h"
#pragma once
#include <algorithm>
#include "fdbrpc/ContinuousSample.h"
#include "fdbrpc/HealthMonitor.h"
#include "flow/genericactors.actor.h"
#include "flow/network.h"
#include "flow/FileIdentifier.h"
#include "flow/ProtocolVersion.h"
#include "flow/Net2Packet.h"
#include "fdbrpc/ContinuousSample.h"
#include "flow/Arena.h"
#include "flow/PKey.h"
enum {
WLTOKEN_ENDPOINT_NOT_FOUND = 0,
WLTOKEN_PING_PACKET,
WLTOKEN_AUTH_TENANT,
WLTOKEN_UNAUTHORIZED_ENDPOINT,
WLTOKEN_FIRST_AVAILABLE
};
enum { WLTOKEN_ENDPOINT_NOT_FOUND = 0, WLTOKEN_PING_PACKET, WLTOKEN_UNAUTHORIZED_ENDPOINT, WLTOKEN_FIRST_AVAILABLE };
#pragma pack(push, 4)
class Endpoint {
@ -191,7 +187,7 @@ struct Peer : public ReferenceCounted<Peer> {
class IPAllowList;
class FlowTransport {
class FlowTransport : NonCopyable {
public:
FlowTransport(uint64_t transportId, int maxWellKnownEndpoints, IPAllowList const* allowList);
~FlowTransport();
@ -293,6 +289,15 @@ public:
HealthMonitor* healthMonitor();
bool currentDeliveryPeerIsTrusted() const;
NetworkAddress currentDeliveryPeerAddress() const;
Optional<PublicKey> getPublicKeyByName(StringRef name) const;
// Adds or replaces a public key
void addPublicKey(StringRef name, PublicKey key);
void removePublicKey(StringRef name);
void removeAllPublicKeys();
private:
class TransportData* self;
};

View File

@ -21,52 +21,61 @@
#pragma once
#ifndef FDBRPC_TENANTINFO_H_
#define FDBRPC_TENANTINFO_H_
#include "fdbrpc/TenantName.h"
#include "fdbrpc/TokenSign.h"
#include "fdbrpc/TokenCache.h"
#include "fdbrpc/FlowTransport.h"
#include "flow/Arena.h"
#include "fdbrpc/fdbrpc.h"
#include <set>
struct TenantInfoRef {
TenantInfoRef() {}
TenantInfoRef(Arena& p, StringRef toCopy) : tenantName(StringRef(p, toCopy)) {}
TenantInfoRef(Arena& p, TenantInfoRef toCopy)
: tenantName(toCopy.tenantName.present() ? Optional<StringRef>(StringRef(p, toCopy.tenantName.get()))
: Optional<StringRef>()) {}
// Empty tenant name means that the peer is trusted
Optional<StringRef> tenantName;
struct TenantInfo {
static constexpr const int64_t INVALID_TENANT = -1;
bool operator<(TenantInfoRef const& other) const {
if (!other.tenantName.present()) {
return false;
}
if (!tenantName.present()) {
return true;
}
return tenantName.get() < other.tenantName.get();
}
template <class Ar>
void serialize(Ar& ar) {
serializer(ar, tenantName);
}
};
struct AuthorizedTenants : ReferenceCounted<AuthorizedTenants> {
Arena arena;
std::set<TenantInfoRef> authorizedTenants;
Optional<TenantNameRef> name;
Optional<StringRef> token;
int64_t tenantId;
// this field is not serialized and instead set by FlowTransport during
// deserialization. This field indicates whether the client is trusted.
// Untrusted clients are generally expected to set a TenantName
bool trusted = false;
// Is set during deserialization. It will be set to true if the tenant
// name is set and the client is authorized to use this tenant.
bool tenantAuthorized = false;
// Helper function for most endpoints that read/write data. This returns true iff
// the client is either a) a trusted peer or b) is accessing keyspace belonging to a tenant,
// for which it has a valid authorization token.
// NOTE: In a cluster where TenantMode is OPTIONAL or DISABLED, tenant name may be unset.
// In such case, the request containing such TenantInfo is valid iff the requesting peer is trusted.
bool isAuthorized() const { return trusted || tenantAuthorized; }
TenantInfo() : tenantId(INVALID_TENANT) {}
TenantInfo(Optional<TenantName> const& tenantName, Optional<Standalone<StringRef>> const& token, int64_t tenantId)
: tenantId(tenantId) {
if (tenantName.present()) {
arena.dependsOn(tenantName.get().arena());
name = tenantName.get();
}
if (token.present()) {
arena.dependsOn(token.get().arena());
this->token = token.get();
}
}
};
// TODO: receive and validate token instead
struct AuthorizationRequest {
constexpr static FileIdentifier file_identifier = 11499331;
Arena arena;
VectorRef<TenantInfoRef> tenants;
ReplyPromise<Void> reply;
template <class Ar>
void serialize(Ar& ar) {
serializer(ar, tenants, reply, arena);
template <>
struct serializable_traits<TenantInfo> : std::true_type {
template <class Archiver>
static void serialize(Archiver& ar, TenantInfo& v) {
serializer(ar, v.name, v.tenantId, v.token, v.arena);
if constexpr (Archiver::isDeserializing) {
bool tenantAuthorized = false;
if (v.name.present() && v.token.present()) {
tenantAuthorized = TokenCache::instance().validate(v.name.get(), v.token.get());
}
v.trusted = FlowTransport::transport().currentDeliveryPeerIsTrusted();
v.tenantAuthorized = tenantAuthorized;
}
}
};

View File

@ -0,0 +1,27 @@
/*
* TenantName.h
*
* This source file is part of the FoundationDB open source project
*
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifndef FDBRPC_TENANTNAME_H
#define FDBRPC_TENANTNAME_H
#include "flow/Arena.h"
typedef StringRef TenantNameRef;
typedef Standalone<TenantNameRef> TenantName;
#endif // FDBRPC_TENANTNAME_H

View File

@ -0,0 +1,37 @@
/*
* TokenCache.h
*
* This source file is part of the FoundationDB open source project
*
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef TOKENCACHE_H_
#define TOKENCACHE_H_
#include "fdbrpc/TenantName.h"
#include "flow/Arena.h"
class TokenCache : NonCopyable {
struct TokenCacheImpl* impl;
TokenCache();
public:
~TokenCache();
static void createInstance();
static TokenCache& instance();
bool validate(TenantNameRef tenant, StringRef token);
};
#endif // TOKENCACHE_H_

View File

@ -26,6 +26,7 @@
#include "flow/Arena.h"
#include "flow/FastRef.h"
#include "flow/FileIdentifier.h"
#include "fdbrpc/TenantInfo.h"
#include "flow/PKey.h"
namespace authz {
@ -63,6 +64,8 @@ struct SignedTokenRef {
void serialize(Ar& ar) {
serializer(ar, token, keyName, signature);
}
int expectedSize() const { return token.size() + keyName.size() + signature.size(); }
};
SignedTokenRef signToken(Arena& arena, TokenRef token, StringRef keyName, PrivateKey privateKey);
@ -82,6 +85,7 @@ namespace authz::jwt {
struct TokenRef {
// header part ("typ": "JWT" implicitly enforced)
Algorithm algorithm; // alg
StringRef keyId; // kid
// payload part
Optional<StringRef> issuer; // iss
Optional<StringRef> subject; // sub
@ -89,11 +93,13 @@ struct TokenRef {
Optional<uint64_t> issuedAtUnixTime; // iat
Optional<uint64_t> expiresAtUnixTime; // exp
Optional<uint64_t> notBeforeUnixTime; // nbf
Optional<StringRef> keyId; // kid
Optional<StringRef> tokenId; // jti
Optional<VectorRef<StringRef>> tenants; // tenants
// signature part
StringRef signature;
// print each non-signature field in non-JSON, human-readable format e.g. for trace
StringRef toStringRef(Arena& arena);
};
// Make plain JSON token string with fields (except signature) from passed spec
@ -107,7 +113,7 @@ StringRef signToken(Arena& arena, TokenRef tokenSpec, PrivateKey privateKey);
// Parse passed b64url-encoded header part and materialize its contents into tokenOut,
// using memory allocated from arena
bool parseHeaderPart(TokenRef& tokenOut, StringRef b64urlHeaderIn);
bool parseHeaderPart(Arena& arena, TokenRef& tokenOut, StringRef b64urlHeaderIn);
// Parse passed b64url-encoded payload part and materialize its contents into tokenOut,
// using memory allocated from arena
@ -117,6 +123,9 @@ bool parsePayloadPart(Arena& arena, TokenRef& tokenOut, StringRef b64urlPayloadI
// using memory allocated from arena
bool parseSignaturePart(Arena& arena, TokenRef& tokenOut, StringRef b64urlSignatureIn);
// Returns the base64 encoded signature of the token
StringRef signaturePart(StringRef token);
// Parse passed token string and materialize its contents into tokenOut,
// using memory allocated from arena
// Return whether the signed token string is well-formed

View File

@ -28,28 +28,28 @@
* All well-known endpoints of FDB must be listed here to guarantee their uniqueness
*/
enum WellKnownEndpoints {
WLTOKEN_CLIENTLEADERREG_GETLEADER = WLTOKEN_FIRST_AVAILABLE, // 4
WLTOKEN_CLIENTLEADERREG_OPENDATABASE, // 5
WLTOKEN_LEADERELECTIONREG_CANDIDACY, // 6
WLTOKEN_LEADERELECTIONREG_ELECTIONRESULT, // 7
WLTOKEN_LEADERELECTIONREG_LEADERHEARTBEAT, // 8
WLTOKEN_LEADERELECTIONREG_FORWARD, // 9
WLTOKEN_CLIENTLEADERREG_GETLEADER = WLTOKEN_FIRST_AVAILABLE, // 3
WLTOKEN_CLIENTLEADERREG_OPENDATABASE, // 4
WLTOKEN_LEADERELECTIONREG_CANDIDACY, // 5
WLTOKEN_LEADERELECTIONREG_ELECTIONRESULT, // 6
WLTOKEN_LEADERELECTIONREG_LEADERHEARTBEAT, // 7
WLTOKEN_LEADERELECTIONREG_FORWARD, // 8
WLTOKEN_GENERATIONREG_READ, // 9
WLTOKEN_PROTOCOL_INFO, // 10 : the value of this endpoint should be stable and not change.
WLTOKEN_GENERATIONREG_READ, // 11
WLTOKEN_GENERATIONREG_WRITE, // 12
WLTOKEN_CLIENTLEADERREG_DESCRIPTOR_MUTABLE, // 13
WLTOKEN_CONFIGTXN_GETGENERATION, // 14
WLTOKEN_CONFIGTXN_GET, // 15
WLTOKEN_CONFIGTXN_GETCLASSES, // 16
WLTOKEN_CONFIGTXN_GETKNOBS, // 17
WLTOKEN_CONFIGTXN_COMMIT, // 18
WLTOKEN_CONFIGFOLLOWER_GETSNAPSHOTANDCHANGES, // 19
WLTOKEN_CONFIGFOLLOWER_GETCHANGES, // 20
WLTOKEN_CONFIGFOLLOWER_COMPACT, // 21
WLTOKEN_CONFIGFOLLOWER_ROLLFORWARD, // 22
WLTOKEN_CONFIGFOLLOWER_GETCOMMITTEDVERSION, // 23
WLTOKEN_PROCESS, // 24
WLTOKEN_RESERVED_COUNT // 25
WLTOKEN_GENERATIONREG_WRITE, // 11
WLTOKEN_CLIENTLEADERREG_DESCRIPTOR_MUTABLE, // 12
WLTOKEN_CONFIGTXN_GETGENERATION, // 13
WLTOKEN_CONFIGTXN_GET, // 14
WLTOKEN_CONFIGTXN_GETCLASSES, // 15
WLTOKEN_CONFIGTXN_GETKNOBS, // 16
WLTOKEN_CONFIGTXN_COMMIT, // 17
WLTOKEN_CONFIGFOLLOWER_GETSNAPSHOTANDCHANGES, // 18
WLTOKEN_CONFIGFOLLOWER_GETCHANGES, // 19
WLTOKEN_CONFIGFOLLOWER_COMPACT, // 20
WLTOKEN_CONFIGFOLLOWER_ROLLFORWARD, // 21
WLTOKEN_CONFIGFOLLOWER_GETCOMMITTEDVERSION, // 22
WLTOKEN_PROCESS, // 23
WLTOKEN_RESERVED_COUNT // 24
};
static_assert(WLTOKEN_PROTOCOL_INFO ==

View File

@ -648,10 +648,29 @@ struct serializable_traits<ReplyPromiseStream<T>> : std::true_type {
}
};
template <class T, class = int>
struct HasReply_t : std::false_type {};
template <class T>
struct HasReply_t<T, decltype((void)T::reply, 0)> : std::true_type {};
template <class T>
constexpr bool HasReply = HasReply_t<T>::value;
template <class T, class = int>
struct HasVerify_t : std::false_type {};
template <class T>
struct HasVerify_t<T, decltype(void(std::declval<T>().verify()), 0)> : std::true_type {};
template <class T>
constexpr bool HasVerify = HasVerify_t<T>::value;
template <class T, bool IsPublic>
struct NetNotifiedQueue final : NotifiedQueue<T>, FlowReceiver, FastAllocated<NetNotifiedQueue<T, IsPublic>> {
using FastAllocated<NetNotifiedQueue<T, IsPublic>>::operator new;
using FastAllocated<NetNotifiedQueue<T, IsPublic>>::operator delete;
static_assert(!IsPublic || HasVerify<T>, "Public request stream objects need to implement bool T::verify()");
NetNotifiedQueue(int futures, int promises) : NotifiedQueue<T>(futures, promises) {}
NetNotifiedQueue(int futures, int promises, const Endpoint& remoteEndpoint)
@ -662,7 +681,17 @@ struct NetNotifiedQueue final : NotifiedQueue<T>, FlowReceiver, FastAllocated<Ne
this->addPromiseRef();
T message;
reader.deserialize(message);
this->send(std::move(message));
if constexpr (IsPublic) {
if (!message.verify()) {
if constexpr (HasReply<T>) {
message.reply.sendError(permission_denied());
}
} else {
this->send(std::move(message));
}
} else {
this->send(std::move(message));
}
this->delPromiseRef();
}
bool isStream() const override { return true; }

View File

@ -20,20 +20,23 @@
#ifndef FLOW_SIMULATOR_H
#define FLOW_SIMULATOR_H
#include "flow/ProtocolVersion.h"
#pragma once
#include <algorithm>
#include <string>
#include <random>
#include <limits>
#pragma once
#include "flow/flow.h"
#include "flow/Histogram.h"
#include "flow/ProtocolVersion.h"
#include "fdbrpc/FailureMonitor.h"
#include "fdbrpc/Locality.h"
#include "flow/IAsyncFile.h"
#include "flow/TDMetric.actor.h"
#include <random>
#include "fdbrpc/FailureMonitor.h"
#include "fdbrpc/Locality.h"
#include "fdbrpc/ReplicationPolicy.h"
#include "fdbrpc/TokenSign.h"
enum ClogMode { ClogDefault, ClogAll, ClogSend, ClogReceive };
@ -492,6 +495,8 @@ public:
double injectTargetedSSRestartTime = std::numeric_limits<double>::max();
double injectSSDelayTime = std::numeric_limits<double>::max();
std::unordered_map<Standalone<StringRef>, PrivateKey> authKeys;
flowGlobalType global(int id) const final { return getCurrentProcess()->global(id); };
void setGlobal(size_t id, flowGlobalType v) final { getCurrentProcess()->setGlobal(id, v); };

View File

@ -22,6 +22,7 @@
#include <memory>
#include <string>
#include "flow/MkCert.h"
#include "fmt/format.h"
#include "fdbrpc/simulator.h"
#include "flow/Arena.h"
@ -2178,6 +2179,9 @@ public:
this,
"",
"");
// create a key pair for AuthZ testing
auto key = mkcert::makeEcP256();
authKeys.insert(std::make_pair(Standalone<StringRef>("DefaultKey"_sr), key));
g_network = net2 = newNet2(TLSConfig(), false, true);
g_network->addStopCallback(Net2FileSystem::stop);
Net2FileSystem::newFileSystem();

View File

@ -314,7 +314,7 @@ ACTOR Future<BlobGranuleCipherKeysCtx> getLatestGranuleCipherKeys(Reference<Blob
cipherKeysCtx.headerCipherKey = BlobGranuleCipherKey::fromBlobCipherKey(systemCipherKeys.cipherHeaderKey, *arena);
cipherKeysCtx.ivRef = makeString(AES_256_IV_LENGTH, *arena);
generateRandomData(mutateString(cipherKeysCtx.ivRef), AES_256_IV_LENGTH);
deterministicRandom()->randomBytes(mutateString(cipherKeysCtx.ivRef), AES_256_IV_LENGTH);
if (BG_ENCRYPT_COMPRESS_DEBUG) {
TraceEvent(SevDebug, "GetLatestGranuleCipherKey")

View File

@ -913,8 +913,9 @@ ACTOR Future<Void> getResolution(CommitBatchContext* self) {
};
std::unordered_map<EncryptCipherDomainId, EncryptCipherDomainName> encryptDomains = defaultDomains;
for (int t = 0; t < trs.size(); t++) {
int64_t tenantId = trs[t].tenantInfo.tenantId;
Optional<TenantName> tenantName = trs[t].tenantInfo.name;
TenantInfo const& tenantInfo = trs[t].tenantInfo;
int64_t tenantId = tenantInfo.tenantId;
Optional<TenantNameRef> const& tenantName = tenantInfo.name;
// TODO(yiwu): In raw access mode, use tenant prefix to figure out tenant id for user data
if (tenantId != TenantInfo::INVALID_TENANT) {
ASSERT(tenantName.present());
@ -1845,7 +1846,7 @@ ACTOR static Future<Void> doKeyServerLocationRequest(GetKeyServerLocationsReques
while (tenantEntry.isError()) {
bool finalQuery = commitData->version.get() >= minTenantVersion;
ErrorOr<Optional<TenantMapEntry>> _tenantEntry =
getTenantEntry(commitData, req.tenant, Optional<int64_t>(), finalQuery);
getTenantEntry(commitData, req.tenant.name, Optional<int64_t>(), finalQuery);
tenantEntry = _tenantEntry;
if (tenantEntry.isError()) {

View File

@ -1081,6 +1081,9 @@ ShardsAffectedByTeamFailure::getTeamsFor(KeyRangeRef keys) {
}
void ShardsAffectedByTeamFailure::erase(Team team, KeyRange const& range) {
DisabledTraceEvent(SevDebug, "ShardsAffectedByTeamFailureErase")
.detail("Range", range)
.detail("Team", team.toString());
if (team_shards.erase(std::pair<Team, KeyRange>(team, range)) > 0) {
for (auto uid = team.servers.begin(); uid != team.servers.end(); ++uid) {
// Safeguard against going negative after eraseServer() sets value to 0
@ -1092,6 +1095,9 @@ void ShardsAffectedByTeamFailure::erase(Team team, KeyRange const& range) {
}
void ShardsAffectedByTeamFailure::insert(Team team, KeyRange const& range) {
DisabledTraceEvent(SevDebug, "ShardsAffectedByTeamFailureInsert")
.detail("Range", range)
.detail("Team", team.toString());
if (team_shards.insert(std::pair<Team, KeyRange>(team, range)).second) {
for (auto uid = team.servers.begin(); uid != team.servers.end(); ++uid)
storageServerShards[*uid]++;

View File

@ -303,6 +303,9 @@ rocksdb::Options getOptions() {
// TODO: enable rocksdb metrics.
options.db_log_dir = SERVER_KNOBS->LOG_DIRECTORY;
if (g_network->isSimulated()) {
options.OptimizeForSmallDb();
}
return options;
}

View File

@ -1712,15 +1712,19 @@ ACTOR static Future<Void> finishMoveShards(Database occ,
Void(),
TaskPriority::MoveKeys));
int count = 0;
std::vector<UID> readyServers;
for (int s = 0; s < serverReady.size(); ++s) {
count += serverReady[s].isReady() && !serverReady[s].isError();
if (serverReady[s].isReady() && !serverReady[s].isError()) {
readyServers.push_back(storageServerInterfaces[s].uniqueID);
}
}
TraceEvent(SevVerbose, "FinishMoveShardsWaitedServers", relocationIntervalId)
.detail("ReadyServers", count);
.detail("DataMoveID", dataMoveId)
.detail("ReadyServers", describe(readyServers));
if (readyServers.size() == newDestinations.size()) {
if (count == newDestinations.size()) {
std::vector<Future<Void>> actors;
actors.push_back(krmSetRangeCoalescing(
&tr, keyServersPrefix, range, allKeys, keyServersValue(destServers, {}, dataMoveId, UID())));

View File

@ -880,7 +880,7 @@ std::shared_ptr<platform::TmpFile> prepareTokenFile(const uint8_t* buff, const i
std::shared_ptr<platform::TmpFile> prepareTokenFile(const int tokenLen) {
Standalone<StringRef> buff = makeString(tokenLen);
generateRandomData(mutateString(buff), tokenLen);
deterministicRandom()->randomBytes(mutateString(buff), tokenLen);
return prepareTokenFile(buff.begin(), tokenLen);
}
@ -941,7 +941,7 @@ ACTOR Future<Void> testValidationFileTokenPayloadTooLarge(Reference<RESTKmsConne
SERVER_KNOBS->REST_KMS_CONNECTOR_VALIDATION_TOKEN_MAX_SIZE +
2;
Standalone<StringRef> buff = makeString(tokenLen);
generateRandomData(mutateString(buff), tokenLen);
deterministicRandom()->randomBytes(mutateString(buff), tokenLen);
std::string details;
state std::vector<std::shared_ptr<platform::TmpFile>> tokenfiles;
@ -972,7 +972,7 @@ ACTOR Future<Void> testMultiValidationFileTokenFiles(Reference<RESTKmsConnectorC
state std::unordered_map<std::string, std::string> tokenNameValueMap;
state std::string tokenDetailsStr;
generateRandomData(mutateString(buff), tokenLen);
deterministicRandom()->randomBytes(mutateString(buff), tokenLen);
for (int i = 1; i <= numFiles; i++) {
std::string tokenName = std::to_string(i);
@ -1350,7 +1350,7 @@ TEST_CASE("/KmsConnector/REST/ParseKmsDiscoveryUrls") {
state Arena arena;
// initialize cipher key used for testing
generateRandomData(&BASE_CIPHER_KEY_TEST[0], 32);
deterministicRandom()->randomBytes(&BASE_CIPHER_KEY_TEST[0], 32);
wait(testParseDiscoverKmsUrlFileNotFound(ctx));
wait(testParseDiscoverKmsUrlFile(ctx));
@ -1363,7 +1363,7 @@ TEST_CASE("/KmsConnector/REST/ParseValidationTokenFile") {
state Arena arena;
// initialize cipher key used for testing
generateRandomData(&BASE_CIPHER_KEY_TEST[0], 32);
deterministicRandom()->randomBytes(&BASE_CIPHER_KEY_TEST[0], 32);
wait(testEmptyValidationFileDetails(ctx));
wait(testMalformedFileValidationTokenDetails(ctx));
@ -1380,7 +1380,7 @@ TEST_CASE("/KmsConnector/REST/ParseKmsResponse") {
state Arena arena;
// initialize cipher key used for testing
generateRandomData(&BASE_CIPHER_KEY_TEST[0], 32);
deterministicRandom()->randomBytes(&BASE_CIPHER_KEY_TEST[0], 32);
testMissingCipherDetailsTag(ctx);
testMalformedCipherDetails(ctx);
@ -1394,7 +1394,7 @@ TEST_CASE("/KmsConnector/REST/GetEncryptionKeyOps") {
state Arena arena;
// initialize cipher key used for testing
generateRandomData(&BASE_CIPHER_KEY_TEST[0], 32);
deterministicRandom()->randomBytes(&BASE_CIPHER_KEY_TEST[0], 32);
// Prepare KmsConnector context details
wait(testParseDiscoverKmsUrlFile(ctx));

View File

@ -41,6 +41,7 @@
#include "fdbclient/NativeAPI.actor.h"
#include "fdbclient/BackupAgent.actor.h"
#include "fdbclient/versions.h"
#include "flow/MkCert.h"
#include "fdbrpc/WellKnownEndpoints.h"
#include "flow/ProtocolVersion.h"
#include "flow/network.h"
@ -614,6 +615,9 @@ ACTOR Future<ISimulator::KillType> simulatedFDBDRebooter(Reference<IClusterConne
1,
WLTOKEN_RESERVED_COUNT,
&allowList);
for (const auto& p : g_simulator.authKeys) {
FlowTransport::transport().addPublicKey(p.first, p.second.toPublic());
}
Sim2FileSystem::newFileSystem();
std::vector<Future<Void>> futures;
@ -1464,8 +1468,6 @@ void SimulationConfig::setStorageEngine(const TestConfig& testConfig) {
TraceEvent(SevWarnAlways, "RocksDBNonDeterminism")
.detail("Explanation", "The Sharded RocksDB storage engine is threaded and non-deterministic");
noUnseed = true;
auto& g_knobs = IKnobCollection::getMutableGlobalKnobCollection();
g_knobs.setKnob("shard_encode_location_metadata", KnobValueRef::create(bool{ true }));
break;
}
default:
@ -2422,6 +2424,10 @@ ACTOR void setupAndRun(std::string dataFolder,
state bool allowDisablingTenants = testConfig.allowDisablingTenants;
state bool allowCreatingTenants = testConfig.allowCreatingTenants;
if (!CLIENT_KNOBS->SHARD_ENCODE_LOCATION_METADATA) {
testConfig.storageEngineExcludeTypes.push_back(5);
}
// The RocksDB storage engine does not support the restarting tests because you cannot consistently get a clean
// snapshot of the storage engine without a snapshotting file system.
// https://github.com/apple/foundationdb/issues/5155

View File

@ -117,7 +117,7 @@ void TenantCache::insert(TenantName& tenantName, TenantMapEntry& tenant) {
KeyRef tenantPrefix(tenant.prefix.begin(), tenant.prefix.size());
ASSERT(tenantCache.find(tenantPrefix) == tenantCache.end());
TenantInfo tenantInfo(tenantName, tenant.id);
TenantInfo tenantInfo(tenantName, Optional<Standalone<StringRef>>(), tenant.id);
tenantCache[tenantPrefix] = makeReference<TCTenantInfo>(tenantInfo, tenant.prefix);
tenantCache[tenantPrefix]->updateCacheGeneration(generation);
}

View File

@ -10273,7 +10273,7 @@ TEST_CASE(":/redwood/performance/extentQueue") {
state int v;
state ExtentQueueEntry<16> e;
generateRandomData(e.entry, 16);
deterministicRandom()->randomBytes(e.entry, 16);
state int sinceYield = 0;
for (v = 1; v <= numEntries; ++v) {
// Sometimes do a commit

View File

@ -1857,8 +1857,10 @@ int main(int argc, char* argv[]) {
auto opts = CLIOptions::parseArgs(argc, argv);
const auto role = opts.role;
if (role == ServerRole::Simulation)
if (role == ServerRole::Simulation) {
printf("Random seed is %u...\n", opts.randomSeed);
bindDeterministicRandomToOpenssl();
}
if (opts.zoneId.present())
printf("ZoneId set to %s, dcId to %s\n", printable(opts.zoneId).c_str(), printable(opts.dcId).c_str());

View File

@ -66,7 +66,7 @@ struct EncryptedMutationMessage {
ASSERT(textCipherItr != cipherKeys.end() && textCipherItr->second.isValid());
ASSERT(headerCipherItr != cipherKeys.end() && headerCipherItr->second.isValid());
uint8_t iv[AES_256_IV_LENGTH];
generateRandomData(iv, AES_256_IV_LENGTH);
deterministicRandom()->randomBytes(iv, AES_256_IV_LENGTH);
BinaryWriter bw(AssumeVersion(g_network->protocolVersion()));
bw << mutation;
EncryptedMutationMessage encrypted_mutation;
@ -116,4 +116,4 @@ struct EncryptedMutationMessage {
return mutation;
}
};
#endif
#endif

View File

@ -39,10 +39,29 @@
#include "fdbrpc/simulator.h"
#include "flow/actorcompiler.h" // This must be the last #include.
template <class T>
struct sfinae_true : std::true_type {};
template <class T>
auto testAuthToken(int) -> sfinae_true<decltype(std::declval<T>().getAuthToken())>;
template <class>
auto testAuthToken(long) -> std::false_type;
template <class T>
struct hasAuthToken : decltype(testAuthToken<T>(0)) {};
template <class T>
void setAuthToken(T const& self, Transaction& tr) {
if constexpr (hasAuthToken<T>::value) {
tr.setOption(FDBTransactionOptions::AUTHORIZATION_TOKEN, self.getAuthToken());
}
}
ACTOR template <class T>
Future<bool> checkRangeSimpleValueSize(Database cx, T* workload, uint64_t begin, uint64_t end) {
loop {
state Transaction tr(cx);
setAuthToken(*workload, tr);
try {
state Standalone<KeyValueRef> firstKV = (*workload)(begin);
state Standalone<KeyValueRef> lastKV = (*workload)(end - 1);
@ -63,6 +82,7 @@ Future<uint64_t> setupRange(Database cx, T* workload, uint64_t begin, uint64_t e
state uint64_t bytesInserted = 0;
loop {
state Transaction tr(cx);
setAuthToken(*workload, tr);
try {
// if( deterministicRandom()->random01() < 0.001 )
// tr.debugTransaction( deterministicRandom()->randomUniqueID() );
@ -128,6 +148,7 @@ Future<uint64_t> setupRangeWorker(Database cx,
if (keysLoaded - lastStoredKeysLoaded >= keySaveIncrement || jobs->size() == 0) {
state Transaction tr(cx);
setAuthToken(*workload, tr);
try {
std::string countKey = format("keycount|%d|%d", workload->clientId, actorId);
std::string bytesKey = format("bytesstored|%d|%d", workload->clientId, actorId);

View File

@ -3222,7 +3222,7 @@ ACTOR Future<GetValueReqAndResultRef> quickGetValue(StorageServer* data,
++data->counters.quickGetValueMiss;
if (SERVER_KNOBS->QUICK_GET_VALUE_FALLBACK) {
state Transaction tr(data->cx, pOriginalReq->tenantInfo.name);
state Transaction tr(data->cx, pOriginalReq->tenantInfo.name.castTo<TenantName>());
tr.setVersion(version);
// TODO: is DefaultPromiseEndpoint the best priority for this?
tr.trState->taskID = TaskPriority::DefaultPromiseEndpoint;
@ -3857,7 +3857,7 @@ ACTOR Future<GetRangeReqAndResultRef> quickGetKeyValues(
++data->counters.quickGetKeyValuesMiss;
if (SERVER_KNOBS->QUICK_GET_KEY_VALUES_FALLBACK) {
state Transaction tr(data->cx, pOriginalReq->tenantInfo.name);
state Transaction tr(data->cx, pOriginalReq->tenantInfo.name.castTo<TenantName>());
tr.setVersion(version);
// TODO: is DefaultPromiseEndpoint the best priority for this?
tr.trState->taskID = TaskPriority::DefaultPromiseEndpoint;
@ -7084,6 +7084,10 @@ void changeServerKeysWithPhysicalShards(StorageServer* data,
for (int i = 0; i < ranges.size(); i++) {
const Reference<ShardInfo> currentShard = ranges[i].value;
const KeyRangeRef currentRange = static_cast<KeyRangeRef>(ranges[i]);
if (currentShard.isValid()) {
TraceEvent(SevVerbose, "OverlappingPhysicalShard", data->thisServerID)
.detail("PhysicalShard", currentShard->toStorageServerShard().toString());
}
if (!currentShard.isValid()) {
ASSERT(currentRange == keys); // there shouldn't be any nulls except for the range being inserted
} else if (currentShard->notAssigned()) {
@ -7105,7 +7109,7 @@ void changeServerKeysWithPhysicalShards(StorageServer* data,
.detail("NowAssigned", nowAssigned)
.detail("Version", cVer)
.detail("ResultingShard", newShard.toString());
} else if (ranges[i].value->adding) {
} else if (currentShard->adding) {
ASSERT(!nowAssigned);
StorageServerShard newShard = currentShard->toStorageServerShard();
newShard.range = currentRange;

View File

@ -710,6 +710,7 @@ ACTOR Future<Void> testerServerWorkload(WorkloadRequest work,
endRole(Role::TESTER, workIface.id(), "Complete");
} catch (Error& e) {
TraceEvent(SevDebug, "TesterWorkloadFailed").errorUnsuppressed(e);
if (!replied) {
if (e.code() == error_code_test_specification_invalid)
work.reply.sendError(e);

View File

@ -377,7 +377,7 @@ struct AsyncFileCorrectnessWorkload : public AsyncFileWorkload {
}
} else if (info.operation == WRITE) {
info.data = self->allocateBuffer(info.length);
generateRandomData(reinterpret_cast<uint8_t*>(info.data->buffer), info.length);
deterministicRandom()->randomBytes(reinterpret_cast<uint8_t*>(info.data->buffer), info.length);
memcpy(&self->memoryFile->buffer[info.offset], info.data->buffer, info.length);
memset(&self->fileValidityMask[info.offset], 0xFF, info.length);

View File

@ -119,6 +119,7 @@ public:
struct WorkloadProcess {
WorkloadProcessState* processState;
WorkloadContext childWorkloadContext;
UID id = deterministicRandom()->randomUniqueID();
Database cx;
Future<Void> databaseOpened;
@ -166,36 +167,56 @@ struct WorkloadProcess {
WorkloadProcess(ClientWorkload::CreateWorkload const& childCreator, WorkloadContext const& wcx)
: processState(WorkloadProcessState::instance(wcx.clientId)) {
TraceEvent("StartingClinetWorkload", id).detail("OnClientProcess", processState->id);
databaseOpened = openDatabase(this, childCreator, wcx);
childWorkloadContext.clientCount = wcx.clientCount;
childWorkloadContext.clientId = wcx.clientId;
childWorkloadContext.ccr = wcx.ccr;
childWorkloadContext.options = wcx.options;
childWorkloadContext.sharedRandomNumber = wcx.sharedRandomNumber;
databaseOpened = openDatabase(this, childCreator, childWorkloadContext);
}
ACTOR static void destroy(WorkloadProcess* self) {
state ISimulator::ProcessInfo* parent = g_simulator.getCurrentProcess();
wait(g_simulator.onProcess(self->childProcess(), TaskPriority::DefaultYield));
TraceEvent("DeleteWorkloadProcess").backtrace();
delete self;
wait(g_simulator.onProcess(parent, TaskPriority::DefaultYield));
}
std::string description() { return desc; }
// This actor will keep a reference to a future alive, switch to another process and then return. If the future
// count of `f` is 1, this will cause the future to be destroyed in the process `process`
ACTOR template <class T>
static void cancelChild(ISimulator::ProcessInfo* process, Future<T> f) {
wait(g_simulator.onProcess(process, TaskPriority::DefaultYield));
}
ACTOR template <class Ret, class Fun>
Future<Ret> runActor(WorkloadProcess* self, Optional<TenantName> defaultTenant, Fun f) {
state Optional<Error> err;
state Ret res;
state Future<Ret> fut;
state ISimulator::ProcessInfo* parent = g_simulator.getCurrentProcess();
wait(self->databaseOpened);
wait(g_simulator.onProcess(self->childProcess(), TaskPriority::DefaultYield));
self->cx->defaultTenant = defaultTenant;
try {
Ret r = wait(f(self->cx));
fut = f(self->cx);
Ret r = wait(fut);
res = r;
} catch (Error& e) {
// if we're getting cancelled, we could run in the scope of the parent process, but we're not allowed to
// cancel `fut` in any other process than the child process. So we're going to pass the future to an
// uncancellable actor (it has to be uncancellable because if we got cancelled here we can't wait on
// anything) which will then destroy the future on the child process.
cancelChild(self->childProcess(), fut);
if (e.code() == error_code_actor_cancelled) {
ASSERT(g_simulator.getCurrentProcess() == parent);
throw;
throw e;
}
err = e;
}
fut = Future<Ret>();
wait(g_simulator.onProcess(parent, TaskPriority::DefaultYield));
if (err.present()) {
throw err.get();
@ -208,6 +229,7 @@ ClientWorkload::ClientWorkload(CreateWorkload const& childCreator, WorkloadConte
: TestWorkload(wcx), impl(new WorkloadProcess(childCreator, wcx)) {}
ClientWorkload::~ClientWorkload() {
TraceEvent(SevDebug, "DestroyClientWorkload").backtrace();
WorkloadProcess::destroy(impl);
}

View File

@ -885,16 +885,11 @@ struct ConsistencyCheckWorkload : TestWorkload {
for (int i = 0; i < commitProxyInfo->size(); i++)
keyServerLocationFutures.push_back(
commitProxyInfo->get(i, &CommitProxyInterface::getKeyServersLocations)
.getReplyUnlessFailedFor(GetKeyServerLocationsRequest(span.context,
Optional<TenantNameRef>(),
begin,
end,
limitKeyServers,
false,
latestVersion,
Arena()),
2,
0));
.getReplyUnlessFailedFor(
GetKeyServerLocationsRequest(
span.context, TenantInfo(), begin, end, limitKeyServers, false, latestVersion, Arena()),
2,
0));
state bool keyServersInsertedForThisIteration = false;
choose {

View File

@ -0,0 +1,62 @@
/*
* CreateTenant.actor.cpp
*
* This source file is part of the FoundationDB open source project
*
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstdint>
#include "fdbclient/TenantManagement.actor.h"
#include "fdbserver/workloads/workloads.actor.h"
#include "flow/actorcompiler.h" // This must be the last #include.
struct CreateTenantWorkload : TestWorkload {
TenantName tenant;
CreateTenantWorkload(WorkloadContext const& wcx) : TestWorkload(wcx) {
tenant = getOption(options, "name"_sr, "DefaultTenant"_sr);
}
std::string description() const override { return "CreateTenant"; }
Future<Void> setup(Database const& cx) override {
if (clientId == 0) {
return _setup(this, cx);
}
return Void();
}
Future<Void> start(Database const& cx) override { return Void(); }
Future<bool> check(Database const& cx) override { return true; }
virtual void getMetrics(std::vector<PerfMetric>& m) override {}
ACTOR static Future<Void> _setup(CreateTenantWorkload* self, Database db) {
try {
Optional<TenantMapEntry> entry = wait(TenantAPI::createTenant(db.getReference(), self->tenant));
ASSERT(entry.present());
} catch (Error& e) {
TraceEvent(SevError, "TenantCreationFailed").error(e);
if (e.code() == error_code_actor_cancelled) {
throw;
}
ASSERT(false);
}
return Void();
}
};
WorkloadFactory<CreateTenantWorkload> CreateTenantWorkload("CreateTenant");

View File

@ -20,19 +20,33 @@
#include <cstring>
#include "flow/Arena.h"
#include "flow/IRandom.h"
#include "flow/Trace.h"
#include "flow/serialize.h"
#include "fdbrpc/simulator.h"
#include "fdbrpc/TokenSign.h"
#include "fdbclient/FDBOptions.g.h"
#include "fdbclient/NativeAPI.actor.h"
#include "fdbserver/TesterInterface.actor.h"
#include "fdbserver/workloads/workloads.actor.h"
#include "fdbserver/workloads/BulkSetup.actor.h"
#include "flow/Arena.h"
#include "flow/IRandom.h"
#include "flow/Trace.h"
#include "flow/serialize.h"
#include "flow/actorcompiler.h" // This must be the last #include.
struct CycleWorkload : TestWorkload {
template <bool MultiTenancy>
struct CycleMembers {};
template <>
struct CycleMembers<true> {
Arena arena;
TenantName tenant;
authz::jwt::TokenRef token;
StringRef signedToken;
};
template <bool MultiTenancy>
struct CycleWorkload : TestWorkload, CycleMembers<MultiTenancy> {
int actorCount, nodeCount;
double testDuration, transactionsPerSecond, minExpectedTransactionsPerSecond, traceParentProbability;
Key keyPrefix;
@ -51,17 +65,58 @@ struct CycleWorkload : TestWorkload {
keyPrefix = unprintable(getOption(options, "keyPrefix"_sr, LiteralStringRef("")).toString());
traceParentProbability = getOption(options, "traceParentProbability"_sr, 0.01);
minExpectedTransactionsPerSecond = transactionsPerSecond * getOption(options, "expectedRate"_sr, 0.7);
if constexpr (MultiTenancy) {
ASSERT(g_network->isSimulated());
auto k = g_simulator.authKeys.begin();
this->tenant = getOption(options, "tenant"_sr, "CycleTenant"_sr);
// make it comfortably longer than the timeout of the workload
auto currentTime = uint64_t(lround(g_network->timer()));
this->token.algorithm = authz::Algorithm::ES256;
this->token.issuedAtUnixTime = currentTime;
this->token.expiresAtUnixTime =
currentTime + uint64_t(std::lround(getCheckTimeout())) + uint64_t(std::lround(testDuration)) + 100;
this->token.keyId = k->first;
this->token.notBeforeUnixTime = currentTime - 10;
VectorRef<StringRef> tenants;
tenants.push_back_deep(this->arena, this->tenant);
this->token.tenants = tenants;
// we currently don't support this workload to be run outside of simulation
this->signedToken = authz::jwt::signToken(this->arena, this->token, k->second);
}
}
std::string description() const override { return "CycleWorkload"; }
Future<Void> setup(Database const& cx) override { return bulkSetup(cx, this, nodeCount, Promise<double>()); }
template <bool MT = MultiTenancy>
std::enable_if_t<MT, StringRef> getAuthToken() const {
return this->signedToken;
}
std::string description() const override {
if constexpr (MultiTenancy) {
return "TenantCycleWorkload";
} else {
return "CycleWorkload";
}
}
Future<Void> setup(Database const& cx) override {
if constexpr (MultiTenancy) {
cx->defaultTenant = this->tenant;
}
return bulkSetup(cx, this, nodeCount, Promise<double>());
}
Future<Void> start(Database const& cx) override {
if constexpr (MultiTenancy) {
cx->defaultTenant = this->tenant;
}
for (int c = 0; c < actorCount; c++)
clients.push_back(
timeout(cycleClient(cx->clone(), this, actorCount / transactionsPerSecond), testDuration, Void()));
return delay(testDuration);
}
Future<bool> check(Database const& cx) override {
if constexpr (MultiTenancy) {
cx->defaultTenant = this->tenant;
}
int errors = 0;
for (int c = 0; c < clients.size(); c++)
errors += clients[c].isError();
@ -95,6 +150,14 @@ struct CycleWorkload : TestWorkload {
.detailf("From", "%016llx", debug_lastLoadBalanceResultEndpointToken);
}
template <bool B = MultiTenancy>
std::enable_if_t<B> setAuthToken(Transaction& tr) {
tr.setOption(FDBTransactionOptions::AUTHORIZATION_TOKEN, this->signedToken);
}
template <bool B = MultiTenancy>
std::enable_if_t<!B> setAuthToken(Transaction& tr) {}
ACTOR Future<Void> cycleClient(Database cx, CycleWorkload* self, double delay) {
state double lastTime = now();
try {
@ -104,6 +167,7 @@ struct CycleWorkload : TestWorkload {
state double tstart = now();
state int r = deterministicRandom()->randomInt(0, self->nodeCount);
state Transaction tr(cx);
self->setAuthToken(tr);
if (deterministicRandom()->random01() <= self->traceParentProbability) {
state Span span("CycleClient"_loc);
TraceEvent("CycleTracingTransaction", span.context.traceID).log();
@ -231,6 +295,7 @@ struct CycleWorkload : TestWorkload {
}
return true;
}
ACTOR Future<bool> cycleCheck(Database cx, CycleWorkload* self, bool ok) {
if (self->transactions.getMetric().value() < self->testDuration * self->minExpectedTransactionsPerSecond) {
TraceEvent(SevWarnAlways, "TestFailure")
@ -249,6 +314,7 @@ struct CycleWorkload : TestWorkload {
// One client checks the validity of the cycle
state Transaction tr(cx);
state int retryCount = 0;
self->setAuthToken(tr);
loop {
try {
state Version v = wait(tr.getReadVersion());
@ -273,4 +339,5 @@ struct CycleWorkload : TestWorkload {
}
};
WorkloadFactory<CycleWorkload> CycleWorkloadFactory("Cycle", true);
WorkloadFactory<CycleWorkload<false>> CycleWorkloadFactory("Cycle", false);
WorkloadFactory<CycleWorkload<true>> TenantCycleWorkloadFactory("TenantCycle", true);

View File

@ -159,7 +159,7 @@ struct EncryptionOpsWorkload : TestWorkload {
void generateRandomBaseCipher(const int maxLen, uint8_t* buff, int* retLen) {
memset(buff, 0, maxLen);
*retLen = deterministicRandom()->randomInt(maxLen / 2, maxLen);
generateRandomData(buff, *retLen);
deterministicRandom()->randomBytes(buff, *retLen);
}
void setupCipherEssentials() {
@ -247,7 +247,7 @@ struct EncryptionOpsWorkload : TestWorkload {
const EncryptAuthTokenMode authMode,
BlobCipherEncryptHeader* header) {
uint8_t iv[AES_256_IV_LENGTH];
generateRandomData(&iv[0], AES_256_IV_LENGTH);
deterministicRandom()->randomBytes(&iv[0], AES_256_IV_LENGTH);
EncryptBlobCipherAes265Ctr encryptor(textCipherKey, headerCipherKey, &iv[0], AES_256_IV_LENGTH, authMode);
auto start = std::chrono::high_resolution_clock::now();
@ -341,7 +341,7 @@ struct EncryptionOpsWorkload : TestWorkload {
}
int dataLen = isFixedSizePayload() ? pageSize : deterministicRandom()->randomInt(100, maxBufSize);
generateRandomData(buff.get(), dataLen);
deterministicRandom()->randomBytes(buff.get(), dataLen);
// Encrypt the payload - generates BlobCipherEncryptHeader to assist decryption later
BlobCipherEncryptHeader header;

View File

@ -75,4 +75,4 @@ struct StorageQuotaWorkload : TestWorkload {
}
};
WorkloadFactory<StorageQuotaWorkload> StorageQuotaWorkloadFactory("StorageQuota", true);
WorkloadFactory<StorageQuotaWorkload> StorageQuotaWorkloadFactory("StorageQuota");

View File

@ -389,7 +389,7 @@ EncryptBlobCipherAes265Ctr::EncryptBlobCipherAes265Ctr(Reference<BlobCipherKey>
const EncryptAuthTokenMode mode)
: ctx(EVP_CIPHER_CTX_new()), textCipherKey(tCipherKey), headerCipherKey(hCipherKey), authTokenMode(mode) {
ASSERT(isEncryptHeaderAuthTokenModeValid(mode));
generateRandomData(iv, AES_256_IV_LENGTH);
deterministicRandom()->randomBytes(iv, AES_256_IV_LENGTH);
init();
}
@ -796,7 +796,7 @@ TEST_CASE("flow/BlobCipher") {
BaseCipher(const EncryptCipherDomainId& dId, const EncryptCipherBaseKeyId& kId)
: domainId(dId), len(deterministicRandom()->randomInt(AES_256_KEY_LENGTH / 2, AES_256_KEY_LENGTH + 1)),
keyId(kId), key(std::make_unique<uint8_t[]>(len)) {
generateRandomData(key.get(), len);
deterministicRandom()->randomBytes(key.get(), len);
}
};
@ -899,11 +899,11 @@ TEST_CASE("flow/BlobCipher") {
Reference<BlobCipherKey> headerCipherKey = cipherKeyCache->getLatestCipherKey(ENCRYPT_HEADER_DOMAIN_ID);
const int bufLen = deterministicRandom()->randomInt(786, 2127) + 512;
uint8_t orgData[bufLen];
generateRandomData(&orgData[0], bufLen);
deterministicRandom()->randomBytes(&orgData[0], bufLen);
Arena arena;
uint8_t iv[AES_256_IV_LENGTH];
generateRandomData(&iv[0], AES_256_IV_LENGTH);
deterministicRandom()->randomBytes(&iv[0], AES_256_IV_LENGTH);
BlobCipherEncryptHeader headerCopy;
// validate basic encrypt followed by decrypt operation for AUTH_MODE_NONE

View File

@ -98,7 +98,7 @@ TEST_CASE("/CompressionUtils/noCompression") {
Arena arena;
const int size = deterministicRandom()->randomInt(512, 1024);
Standalone<StringRef> uncompressed = makeString(size);
generateRandomData(mutateString(uncompressed), size);
deterministicRandom()->randomBytes(mutateString(uncompressed), size);
Standalone<StringRef> compressed = CompressionUtils::compress(CompressionFilter::NONE, uncompressed, arena);
ASSERT_EQ(compressed.compare(uncompressed), 0);
@ -116,7 +116,7 @@ TEST_CASE("/CompressionUtils/gzipCompression") {
Arena arena;
const int size = deterministicRandom()->randomInt(512, 1024);
Standalone<StringRef> uncompressed = makeString(size);
generateRandomData(mutateString(uncompressed), size);
deterministicRandom()->randomBytes(mutateString(uncompressed), size);
Standalone<StringRef> compressed = CompressionUtils::compress(CompressionFilter::GZIP, uncompressed, arena);
ASSERT_NE(compressed.compare(uncompressed), 0);

View File

@ -19,6 +19,7 @@
*/
#include "fmt/format.h"
#include "flow/Arena.h"
#include "flow/DeterministicRandom.h"
#include <cstring>
@ -124,6 +125,23 @@ std::string DeterministicRandom::randomAlphaNumeric(int length) {
return s;
}
void DeterministicRandom::randomBytes(uint8_t* buf, int length) {
constexpr const int unitLen = sizeof(decltype(gen64()));
for (int i = 0; i < length; i += unitLen) {
auto val = gen64();
memcpy(buf + i, &val, std::min(unitLen, length - i));
}
if (randLog && useRandLog) {
constexpr const int cutOff = 32;
bool tooLong = length > cutOff;
fmt::print(randLog,
"Rbytes[{}] {}{}\n",
length,
StringRef(buf, std::min(cutOff, length)).printable(),
tooLong ? "..." : "");
}
}
uint64_t DeterministicRandom::peek() const {
return next;
}
@ -134,10 +152,3 @@ void DeterministicRandom::addref() {
void DeterministicRandom::delref() {
ReferenceCounted<DeterministicRandom>::delref();
}
void generateRandomData(uint8_t* buffer, int length) {
for (int i = 0; i < length; i += sizeof(uint32_t)) {
uint32_t val = deterministicRandom()->randomUInt32();
memcpy(&buffer[i], &val, std::min(length - i, (int)sizeof(uint32_t)));
}
}

View File

@ -126,6 +126,7 @@ void FlowKnobs::initialize(Randomize randomize, IsSimulated isSimulated) {
init( NETWORK_TEST_REQUEST_COUNT, 0 ); // 0 -> run forever
init( NETWORK_TEST_REQUEST_SIZE, 1 );
init( NETWORK_TEST_SCRIPT_MODE, false );
init( MAX_CACHED_EXPIRED_TOKENS, 1024 );
//AsyncFileCached
init( PAGE_CACHE_4K, 2LL<<30 );
@ -286,6 +287,7 @@ void FlowKnobs::initialize(Randomize randomize, IsSimulated isSimulated) {
if ( randomize && BUGGIFY) { ENCRYPT_CIPHER_KEY_CACHE_TTL = deterministicRandom()->randomInt(50, 100); }
init( ENCRYPT_KEY_REFRESH_INTERVAL, isSimulated ? 60 : 8 * 60 );
if ( randomize && BUGGIFY) { ENCRYPT_KEY_REFRESH_INTERVAL = deterministicRandom()->randomInt(2, 10); }
init( TOKEN_CACHE_SIZE, 100 );
// REST Client
init( RESTCLIENT_MAX_CONNECTIONPOOL_SIZE, 10 );

View File

@ -199,11 +199,11 @@ TEST_CASE("flow/StreamCipher") {
StreamCipherKey const* key = StreamCipherKey::getGlobalCipherKey();
StreamCipher::IV iv;
generateRandomData(iv.data(), iv.size());
deterministicRandom()->randomBytes(iv.data(), iv.size());
Arena arena;
std::vector<unsigned char> plaintext(deterministicRandom()->randomInt(0, 10001));
generateRandomData(&plaintext.front(), plaintext.size());
deterministicRandom()->randomBytes(&plaintext.front(), plaintext.size());
std::vector<unsigned char> ciphertext(plaintext.size() + AES_BLOCK_SIZE);
std::vector<unsigned char> decryptedtext(plaintext.size() + AES_BLOCK_SIZE);

View File

@ -20,6 +20,7 @@
#include "flow/flow.h"
#include "flow/DeterministicRandom.h"
#include "flow/Error.h"
#include "flow/UnitTest.h"
#include "flow/rte_memcpy.h"
#ifdef WITH_FOLLY_MEMCPY
@ -27,6 +28,8 @@
#endif
#include <stdarg.h>
#include <cinttypes>
#include <openssl/err.h>
#include <openssl/rand.h>
std::atomic<bool> startSampling = false;
LineageReference rootLineage;
@ -374,6 +377,77 @@ void enableBuggify(bool enabled, BuggifyType type) {
buggifyActivated[int(type)] = enabled;
}
// Make OpenSSL use DeterministicRandom as RNG source such that simulation runs stay deterministic w/ e.g. signature ops
void bindDeterministicRandomToOpenssl() {
// TODO: implement ifdef branch for 3.x using provider API
#ifndef OPENSSL_IS_BORINGSSL
static const RAND_METHOD method = {
// replacement for RAND_seed(), which reseeds OpenSSL RNG
[](const void*, int) -> int { return 1; },
// replacement for RAND_bytes(), which fills given buffer with random byte sequence
[](unsigned char* buf, int length) -> int {
if (g_network)
ASSERT_ABORT(g_network->isSimulated());
deterministicRandom()->randomBytes(buf, length);
return 1;
},
// replacement for RAND_cleanup(), a no-op for simulation
[]() -> void {},
// replacement for RAND_add(), which reseeds OpenSSL RNG with randomness hint
[](const void*, int, double) -> int { return 1; },
// replacement for default pseudobytes getter (same as RAND_bytes by default)
[](unsigned char* buf, int length) -> int {
if (g_network)
ASSERT_ABORT(g_network->isSimulated());
deterministicRandom()->randomBytes(buf, length);
return 1;
},
// status function for PRNG readiness check
[]() -> int { return 1; },
};
if (1 != ::RAND_set_rand_method(&method)) {
auto ec = ::ERR_get_error();
char msg[256]{
0,
};
if (ec) {
::ERR_error_string_n(ec, msg, sizeof(msg));
}
fprintf(stderr,
"ERROR: Failed to bind DeterministicRandom to OpenSSL RNG\n"
" OpenSSL error message: '%s'\n",
msg);
throw internal_error();
} else {
printf("DeterministicRandom successfully bound to OpenSSL RNG\n");
}
#else // OPENSSL_IS_BORINGSSL
static const RAND_METHOD method = {
[](const void*, int) -> void {},
[](unsigned char* buf, unsigned long length) -> int {
if (g_network)
ASSERT_ABORT(g_network->isSimulated());
ASSERT(length <= std::numeric_limits<int>::max());
deterministicRandom()->randomBytes(buf, length);
return 1;
},
[]() -> void {},
[](const void*, int, double) -> void {},
[](unsigned char* buf, unsigned long length) -> int {
if (g_network)
ASSERT_ABORT(g_network->isSimulated());
ASSERT(length <= std::numeric_limits<int>::max());
deterministicRandom()->randomBytes(buf, length);
return 1;
},
[]() -> int { return 1; },
};
::RAND_set_rand_method(&method);
printf("DeterministicRandom successfully bound to OpenSSL RNG\n");
#endif // OPENSSL_IS_BORINGSSL
}
namespace {
// Simple message for flatbuffers unittests
struct Int {

View File

@ -34,6 +34,7 @@
#include <algorithm>
#include <boost/functional/hash.hpp>
#include <stdint.h>
#include <string_view>
#include <string>
#include <cstring>
#include <limits>
@ -532,7 +533,9 @@ public:
return substr(0, size() - s.size());
}
std::string toString() const { return std::string((const char*)data, length); }
std::string toString() const { return std::string(reinterpret_cast<const char*>(data), length); }
std::string_view toStringView() const { return std::string_view(reinterpret_cast<const char*>(data), length); }
static bool isPrintable(char c) { return c > 32 && c < 127; }
inline std::string printable() const;

View File

@ -228,7 +228,7 @@ struct CodeProbeImpl : ICodeProbe {
evt.detail("File", filename())
.detail("Line", Line)
.detail("Condition", Condition::value())
.detail("ProbeHit", condition)
.detail("Covered", condition)
.detail("Comment", Comment::value());
annotations.trace(this, evt, condition);
}

View File

@ -49,6 +49,7 @@ public:
UID randomUniqueID() override;
char randomAlphaNumeric() override;
std::string randomAlphaNumeric(int length) override;
void randomBytes(uint8_t* buf, int length) override;
uint64_t peek() const override;
void addref() override;
void delref() override;

View File

@ -143,6 +143,7 @@ public:
virtual UID randomUniqueID() = 0;
virtual char randomAlphaNumeric() = 0;
virtual std::string randomAlphaNumeric(int length) = 0;
virtual void randomBytes(uint8_t* buf, int length) = 0;
virtual uint32_t randomSkewedUInt32(uint32_t min, uint32_t maxPlusOne) = 0;
virtual uint64_t peek() const = 0; // returns something that is probably different for different random states.
// Deterministic (and idempotent) for a deterministic generator.
@ -209,7 +210,4 @@ Reference<IRandom> nondeterministicRandom();
// WARNING: This is not thread safe and must not be called from any other thread than the network thread!
Reference<IRandom> debugRandom();
// Populates a buffer with a random sequence of bytes
void generateRandomData(uint8_t* buffer, int length);
#endif

View File

@ -195,6 +195,8 @@ public:
int NETWORK_TEST_REQUEST_SIZE;
bool NETWORK_TEST_SCRIPT_MODE;
int MAX_CACHED_EXPIRED_TOKENS;
// AsyncFileCached
int64_t PAGE_CACHE_4K;
int64_t PAGE_CACHE_64K;
@ -354,6 +356,9 @@ public:
int64_t ENCRYPT_CIPHER_KEY_CACHE_TTL;
int64_t ENCRYPT_KEY_REFRESH_INTERVAL;
// Authorization
int TOKEN_CACHE_SIZE;
// RESTClient
int RESTCLIENT_MAX_CONNECTIONPOOL_SIZE;
int RESTCLIENT_CONNECT_TRIES;

View File

@ -25,8 +25,15 @@
#include "flow/ProtocolVersion.h"
#include <unordered_map>
#include <any>
using ContextVariableMap = std::unordered_map<std::string_view, void*>;
using ContextVariableMap = std::unordered_map<std::string_view, std::any>;
template <class T>
struct HasVariableMap_t : std::false_type {};
template <class T>
constexpr bool HasVariableMap = HasVariableMap_t<T>::value;
template <class Ar>
struct LoadContext {
@ -53,6 +60,11 @@ struct LoadContext {
void addArena(Arena& arena) { arena = ar->arena(); }
LoadContext& context() { return *this; }
template <class Archiver = Ar>
std::enable_if_t<HasVariableMap<Archiver>, std::any&> variable(std::string_view name) {
return ar->variable(name);
}
};
template <class Ar, class Allocator>
@ -110,23 +122,9 @@ public:
deserialize(FileIdentifierFor<Item>::value, item);
}
template <class T>
bool variable(std::string_view name, T* val) {
auto p = variables->insert(std::make_pair(name, val));
return p.second;
}
std::any& variable(std::string_view name) { return variables->at(name); }
template <class T>
T& variable(std::string_view name) {
auto res = variables->at(name);
return *reinterpret_cast<T*>(res);
}
template <class T>
T const& variable(std::string_view name) const {
auto res = variables->at(name);
return *reinterpret_cast<T*>(res);
}
std::any const& variable(std::string_view name) const { return variables->at(name); }
};
class ObjectReader : public _ObjectReader<ObjectReader> {
@ -267,6 +265,11 @@ private:
int size = 0;
};
template <>
struct HasVariableMap_t<ObjectReader> : std::true_type {};
template <>
struct HasVariableMap_t<ArenaObjectReader> : std::true_type {};
// this special case is needed - the code expects
// Standalone<T> and T to be equivalent for serialization
namespace detail {

View File

@ -53,7 +53,7 @@ public:
int size() const { return keySize; }
uint8_t* data() const { return arr.get(); }
void initializeKey(uint8_t* data, int len);
void initializeRandomTestKey() { generateRandomData(arr.get(), keySize); }
void initializeRandomTestKey() { deterministicRandom()->randomBytes(arr.get(), keySize); }
void reset() { memset(arr.get(), 0, keySize); }
static bool isGlobalKeyPresent();

View File

@ -1365,5 +1365,7 @@ inline bool check_yield(TaskPriority taskID = TaskPriority::DefaultYield) {
return g_network->check_yield(taskID);
}
void bindDeterministicRandomToOpenssl();
#include "flow/genericactors.actor.h"
#endif

View File

@ -522,6 +522,7 @@ public:
enDiskFailureInjector = 16,
enBitFlipper = 17,
enHistogram = 18,
enTokenCache = 19,
COUNT // Add new fields before this enumerator
};

View File

@ -854,10 +854,6 @@ struct PacketWriter {
}
ProtocolVersion protocolVersion() const { return m_protocolVersion; }
void setProtocolVersion(ProtocolVersion pv) { m_protocolVersion = pv; }
private:
void serializeBytesAcrossBoundary(const void* data, int bytes);
void nextBuffer(size_t size = 0 /* downstream it will default to at least 4k minus some padding */);
uint8_t* writeBytes(size_t size) {
if (size > buffer->bytes_unwritten()) {
nextBuffer(size);
@ -868,6 +864,9 @@ private:
return result;
}
private:
void serializeBytesAcrossBoundary(const void* data, int bytes);
void nextBuffer(size_t size = 0 /* downstream it will default to at least 4k minus some padding */);
template <class, class>
friend class MakeSerializeSource;

View File

@ -25,7 +25,7 @@
static StreamCipher::IV getRandomIV() {
StreamCipher::IV iv;
generateRandomData(iv.data(), iv.size());
deterministicRandom()->randomBytes(iv.data(), iv.size());
return iv;
}

View File

@ -28,7 +28,7 @@ static inline void initGlobalData() {
if (!globalData) {
globalData = static_cast<uint8_t*>(allocateFast(globalDataSize));
}
generateRandomData(globalData, globalDataSize);
deterministicRandom()->randomBytes(globalData, globalDataSize);
}
KeyValueRef getKV(size_t keySize, size_t valueSize) {

View File

@ -178,6 +178,7 @@ if(WITH_PYTHON)
add_fdb_test(TEST_FILES fast/SwizzledRollbackSideband.toml)
add_fdb_test(TEST_FILES fast/SystemRebootTestCycle.toml)
add_fdb_test(TEST_FILES fast/TaskBucketCorrectness.toml)
add_fdb_test(TEST_FILES fast/TenantCycle.toml)
add_fdb_test(TEST_FILES fast/TimeKeeperCorrectness.toml)
add_fdb_test(TEST_FILES fast/TxnStateStoreCycleTest.toml)
add_fdb_test(TEST_FILES fast/UDP.toml)

View File

@ -0,0 +1,31 @@
[configuration]
allowDefaultTenant = false
allowDisablingTenants = false
[[test]]
testTitle = 'TenantCreation'
[[test.workload]]
testName = 'CreateTenant'
name = 'First'
[[test.workload]]
testName = 'CreateTenant'
name = 'Second'
[[test]]
testTitle = 'Cycle'
[[test.workload]]
testName = 'TenantCycle'
tenant = 'First'
transactionsPerSecond = 250.0
testDuration = 10.0
expectedRate = 0.80
[[test.workload]]
testName = 'TenantCycle'
tenant = 'Second'
transactionsPerSecond = 2500.0
testDuration = 10.0
expectedRate = 0.80