Merge branch 'main' of https://github.com/apple/foundationdb into feature/dd-refactor-simple

Xiaoxi Wang 2022-09-01 09:29:30 -07:00
commit b18561bc31
95 changed files with 3992 additions and 1629 deletions

View File

@ -345,7 +345,6 @@ if(NOT WIN32)
)
set_tests_properties("fdb_c_upgrade_to_future_version" PROPERTIES ENVIRONMENT "${SANITIZER_OPTIONS}")
if (0) # reenable after stabilizing the test
add_test(NAME fdb_c_upgrade_to_future_version_blob_granules
COMMAND ${CMAKE_SOURCE_DIR}/tests/TestRunner/upgrade_test.py
--build-dir ${CMAKE_BINARY_DIR}
@ -354,7 +353,6 @@ if (0) # reenable after stabilizing the test
--blob-granules-enabled
--process-number 3
)
endif()
if(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND NOT USE_SANITIZER)
add_test(NAME fdb_c_upgrade_single_threaded_630api
@ -489,7 +487,7 @@ elseif(NOT WIN32 AND NOT APPLE AND NOT USE_SANITIZER) # Linux Only, non-sanitizer
DEPENDS ${IMPLIBSO_SRC}
COMMENT "Generating source code for C shim library")
add_library(fdb_c_shim SHARED ${SHIM_LIB_GEN_SRC} foundationdb/fdb_c_shim.h fdb_c_shim.cpp)
add_library(fdb_c_shim STATIC ${SHIM_LIB_GEN_SRC} foundationdb/fdb_c_shim.h fdb_c_shim.cpp)
target_link_options(fdb_c_shim PRIVATE "LINKER:--version-script=${CMAKE_CURRENT_SOURCE_DIR}/fdb_c.map,-z,nodelete,-z,noexecstack")
target_link_libraries(fdb_c_shim PUBLIC dl)
target_include_directories(fdb_c_shim PUBLIC

View File

@ -943,6 +943,57 @@ extern "C" DLLEXPORT FDBResult* fdb_transaction_read_blob_granules(FDBTransactio
return (FDBResult*)(TXN(tr)->readBlobGranules(range, beginVersion, rv, context).extractPtr());
}
extern "C" DLLEXPORT FDBFuture* fdb_transaction_read_blob_granules_start(FDBTransaction* tr,
uint8_t const* begin_key_name,
int begin_key_name_length,
uint8_t const* end_key_name,
int end_key_name_length,
int64_t beginVersion,
int64_t readVersion,
int64_t* readVersionOut) {
Optional<Version> rv;
if (readVersion != latestVersion) {
rv = readVersion;
}
return (FDBFuture*)(TXN(tr)
->readBlobGranulesStart(KeyRangeRef(KeyRef(begin_key_name, begin_key_name_length),
KeyRef(end_key_name, end_key_name_length)),
beginVersion,
rv,
readVersionOut)
.extractPtr());
}
extern "C" DLLEXPORT FDBResult* fdb_transaction_read_blob_granules_finish(FDBTransaction* tr,
FDBFuture* f,
uint8_t const* begin_key_name,
int begin_key_name_length,
uint8_t const* end_key_name,
int end_key_name_length,
int64_t beginVersion,
int64_t readVersion,
FDBReadBlobGranuleContext* granule_context) {
// FIXME: better way to convert?
ReadBlobGranuleContext context;
context.userContext = granule_context->userContext;
context.start_load_f = granule_context->start_load_f;
context.get_load_f = granule_context->get_load_f;
context.free_load_f = granule_context->free_load_f;
context.debugNoMaterialize = granule_context->debugNoMaterialize;
context.granuleParallelism = granule_context->granuleParallelism;
ThreadFuture<Standalone<VectorRef<BlobGranuleChunkRef>>> startFuture(
TSAV(Standalone<VectorRef<BlobGranuleChunkRef>>, f));
return (FDBResult*)(TXN(tr)
->readBlobGranulesFinish(startFuture,
KeyRangeRef(KeyRef(begin_key_name, begin_key_name_length),
KeyRef(end_key_name, end_key_name_length)),
beginVersion,
readVersion,
context)
.extractPtr());
}
#include "fdb_c_function_pointers.g.h"
#define FDB_API_CHANGED(func, ver) \

View File

@ -51,6 +51,27 @@ DLLEXPORT WARN_UNUSED_RESULT fdb_error_t fdb_create_database_from_connection_str
DLLEXPORT void fdb_use_future_protocol_version();
// the logical read_blob_granules is broken out (at different points depending on the client type) into the asynchronous
// start(), which happens on the fdb network thread, and the synchronous finish(), which happens off it
DLLEXPORT FDBFuture* fdb_transaction_read_blob_granules_start(FDBTransaction* tr,
uint8_t const* begin_key_name,
int begin_key_name_length,
uint8_t const* end_key_name,
int end_key_name_length,
int64_t beginVersion,
int64_t readVersion,
int64_t* readVersionOut);
DLLEXPORT FDBResult* fdb_transaction_read_blob_granules_finish(FDBTransaction* tr,
FDBFuture* f,
uint8_t const* begin_key_name,
int begin_key_name_length,
uint8_t const* end_key_name,
int end_key_name_length,
int64_t beginVersion,
int64_t readVersion,
FDBReadBlobGranuleContext* granuleContext);
#ifdef __cplusplus
}
#endif
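
As a rough usage sketch (not code from this change): a caller starts the granule read, waits on the future, then finishes and materializes the chunks off the network thread. `tr`, the key buffers, `chosen_read_version`, and `granule_context` are placeholder names, and error handling and cleanup are abbreviated.

```
// Illustrative sketch only.
int64_t read_version_out = -1;
FDBFuture* start = fdb_transaction_read_blob_granules_start(
    tr, begin_key, begin_key_length, end_key, end_key_length,
    /* beginVersion */ 0, /* readVersion */ chosen_read_version, &read_version_out);

// Wait for the network thread to produce the chunk descriptions.
fdb_error_t err = fdb_future_block_until_ready(start);
if (!err) {
    // finish() materializes the granules off the fdb network thread.
    FDBResult* result = fdb_transaction_read_blob_granules_finish(
        tr, start, begin_key, begin_key_length, end_key, end_key_length,
        /* beginVersion */ 0, read_version_out, &granule_context);
    // ... consume and destroy result ...
}
```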

View File

@ -622,6 +622,13 @@ func (o TransactionOptions) SetUseGrvCache() error {
return o.setOpt(1101, nil)
}
// Attach the given authorization token to the transaction so that subsequent tenant-aware requests are authorized
//
// Parameter: A JSON Web Token authorized to access data belonging to one or more tenants, indicated by 'tenants' claim of the token's payload.
func (o TransactionOptions) SetAuthorizationToken(param string) error {
return o.setOpt(2000, []byte(param))
}
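
For comparison, a hedged sketch of setting the same option through the C API; the generated enum name for option code 2000 is assumed here, as are the `tr`/`token` variables.

```
// Sketch only: attach a JWT so that subsequent tenant-aware requests on this
// transaction are authorized. FDB_TR_OPTION_AUTHORIZATION_TOKEN is assumed to
// be the generated constant for option code 2000.
fdb_error_t err = fdb_transaction_set_option(
    tr, FDB_TR_OPTION_AUTHORIZATION_TOKEN, (const uint8_t*)token, token_length);
```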
type StreamingMode int
const (

View File

@ -242,7 +242,7 @@ public interface Database extends AutoCloseable, TransactionContext {
}
/**
* Sets a range to be unblobbified in the database.
* Unsets a blobbified range in the database. The range must be aligned to known blob ranges.
*
* @param beginKey start of the key range
* @param endKey end of the key range
@ -260,7 +260,7 @@ public interface Database extends AutoCloseable, TransactionContext {
* @param rangeLimit batch size
* @param e the {@link Executor} to use for asynchronous callbacks
* @return a future with the list of blobbified ranges.
* @return a future with the list of blobbified ranges: [lastLessThan(beginKey), firstGreaterThanOrEqual(endKey)]
*/
default CompletableFuture<KeyRangeArrayResult> listBlobbifiedRanges(byte[] beginKey, byte[] endKey, int rangeLimit) {
return listBlobbifiedRanges(beginKey, endKey, rangeLimit, getExecutor());
@ -274,7 +274,7 @@ public interface Database extends AutoCloseable, TransactionContext {
* @param rangeLimit batch size
* @param e the {@link Executor} to use for asynchronous callbacks
* @return a future with the list of blobbified ranges.
* @return a future with the list of blobbified ranges: [lastLessThan(beginKey), firstGreaterThanOrEqual(endKey)]
*/
CompletableFuture<KeyRangeArrayResult> listBlobbifiedRanges(byte[] beginKey, byte[] endKey, int rangeLimit, Executor e);

View File

@ -175,7 +175,7 @@ class Config:
self.cov_include_files_args = {'help': 'Only consider coverage traces that originated in files matching regex'}
self.cov_exclude_files: str = r'.^'
self.cov_exclude_files_args = {'help': 'Ignore coverage traces that originated in files matching regex'}
self.max_stderr_bytes: int = 1000
self.max_stderr_bytes: int = 10000
self.write_stats: bool = True
self.read_stats: bool = True
self.reproduce_prefix: str | None = None
@ -234,7 +234,10 @@ class Config:
assert type(None) != attr_type
e = os.getenv(env_name)
if e is not None:
self.__setattr__(attr, attr_type(e))
# Use the env var to supply the default value, so that if the
# environment variable is set and the corresponding command line
# flag is not, the environment variable has an effect.
self._config_map[attr].kwargs['default'] = attr_type(e)
def build_arguments(self, parser: argparse.ArgumentParser):
for val in self._config_map.values():

View File

@ -421,11 +421,17 @@ class Summary:
child.attributes['Severity'] = '40'
self.out.append(child)
if self.error_out is not None and len(self.error_out) > 0:
if self.stderr_severity == '40':
self.error = True
lines = self.error_out.split('\n')
lines = self.error_out.splitlines()
stderr_bytes = 0
for line in lines:
if line.endswith("WARNING: ASan doesn't fully support makecontext/swapcontext functions and may produce false positives in some cases!"):
# When running ASan we expect to see this message. Boost coroutine should be using the correct ASan annotations, so it shouldn't produce any false positives.
continue
if line.endswith("Warning: unimplemented fcntl command: 1036"):
# Valgrind produces this warning when F_SET_RW_HINT is used
continue
if self.stderr_severity == '40':
self.error = True
remaining_bytes = config.max_stderr_bytes - stderr_bytes
if remaining_bytes > 0:
out_err = line[0:remaining_bytes] + ('...' if len(line) > remaining_bytes else '')
@ -437,7 +443,7 @@ class Summary:
if stderr_bytes > config.max_stderr_bytes:
child = SummaryTree('StdErrOutputTruncated')
child.attributes['Severity'] = self.stderr_severity
child.attributes['BytesRemaining'] = stderr_bytes - config.max_stderr_bytes
child.attributes['BytesRemaining'] = str(stderr_bytes - config.max_stderr_bytes)
self.out.append(child)
self.out.attributes['Ok'] = '1' if self.ok() else '0'

View File

@ -125,20 +125,3 @@ In each test, the `GlobalTagThrottlerTesting::monitor` function is used to perio
Every `SERVER_KNOBS->TAG_THROTTLE_PUSH_INTERVAL` seconds, the ratekeeper calls `GlobalTagThrottler::getClientRates`. At the end of the rate calculation for each tag, a trace event of type `GlobalTagThrottler_GotClientRate` is produced. This trace event reports the relevant inputs that went into the rate calculation and can be used for debugging.
On storage servers, every `SERVER_KNOBS->TAG_MEASUREMENT_INTERVAL` seconds, there are `BusyReadTag` events for every tag that has sufficient read cost to be reported to the ratekeeper. Both cost and fractional busyness are reported.
### Status
For each storage server, the busiest read tag is reported in the full status output, along with its cost and fractional busyness.
At path `.cluster.qos.global_tag_throttler`, throttling limitations for each tag are reported:
```
{
"<tagName>": {
"desired_tps": <number>,
"reserved_tps": <number>,
"limiting_tps": [<number>|"unset"],
"target_tps": <number>
},
...
}
```

View File

@ -2,6 +2,17 @@
Release Notes
#############
7.1.21
======
* Same as 7.1.20 release with AVX enabled.
7.1.20
======
* Released with AVX disabled.
* Fixed missing localities for fdbserver that can cause cross DC calls among storage servers. `(PR #7995) <https://github.com/apple/foundationdb/pull/7995>`_
* Removed extremely spammy trace event in FetchKeys and fixed transaction_profiling_analyzer.py. `(PR #7934) <https://github.com/apple/foundationdb/pull/7934>`_
* Fixed bugs when GRV proxy returns an error. `(PR #7860) <https://github.com/apple/foundationdb/pull/7860>`_
7.1.19
======
* Same as 7.1.18 release with AVX enabled.

View File

@ -186,16 +186,12 @@ TEST_CASE("/fdbserver/blobgranule/isRangeCoveredByBlob") {
ASSERT(range.end == "key_b5"_sr);
}
// check unsorted chunks
{
Standalone<VectorRef<BlobGranuleChunkRef>> unsortedChunks(chunks);
testAddChunkRange("key_0"_sr, "key_a"_sr, unsortedChunks);
ASSERT(isRangeFullyCovered(KeyRangeRef("key_00"_sr, "key_01"_sr), unsortedChunks));
}
// check continued chunks
{
Standalone<VectorRef<BlobGranuleChunkRef>> continuedChunks(chunks);
Standalone<VectorRef<BlobGranuleChunkRef>> continuedChunks;
testAddChunkRange("key_a1"_sr, "key_a9"_sr, continuedChunks);
testAddChunkRange("key_a9"_sr, "key_b1"_sr, continuedChunks);
testAddChunkRange("key_b1"_sr, "key_b9"_sr, continuedChunks);
ASSERT(isRangeFullyCovered(KeyRangeRef("key_a1"_sr, "key_b9"_sr), continuedChunks) == false);
}
return Void();

View File

@ -23,6 +23,7 @@
#include "fdbclient/CommitTransaction.h"
#include "fdbclient/FDBTypes.h"
#include "fdbclient/ReadYourWrites.h"
#include "flow/UnitTest.h"
#include "flow/actorcompiler.h" // has to be last include
void KeyRangeActorMap::getRangesAffectedByInsertion(const KeyRangeRef& keys, std::vector<KeyRange>& affectedRanges) {
@ -35,32 +36,54 @@ void KeyRangeActorMap::getRangesAffectedByInsertion(const KeyRangeRef& keys, std
affectedRanges.push_back(KeyRangeRef(keys.end, e.end()));
}
RangeResult krmDecodeRanges(KeyRef mapPrefix, KeyRange keys, RangeResult kv) {
RangeResult krmDecodeRanges(KeyRef mapPrefix, KeyRange keys, RangeResult kv, bool align) {
ASSERT(!kv.more || kv.size() > 1);
KeyRange withPrefix =
KeyRangeRef(mapPrefix.toString() + keys.begin.toString(), mapPrefix.toString() + keys.end.toString());
ValueRef beginValue, endValue;
if (kv.size() && kv[0].key.startsWith(mapPrefix))
beginValue = kv[0].value;
if (kv.size() && kv.end()[-1].key.startsWith(mapPrefix))
endValue = kv.end()[-1].value;
RangeResult result;
result.arena().dependsOn(kv.arena());
result.arena().dependsOn(keys.arena());
result.push_back(result.arena(), KeyValueRef(keys.begin, beginValue));
// Always push a kv pair whose key is <= keys.begin.
KeyRef beginKey = keys.begin;
if (!align && !kv.empty() && kv.front().key.startsWith(mapPrefix) && kv.front().key < withPrefix.begin) {
beginKey = kv[0].key.removePrefix(mapPrefix);
}
ValueRef beginValue;
if (!kv.empty() && kv.front().key.startsWith(mapPrefix) && kv.front().key <= withPrefix.begin) {
beginValue = kv.front().value;
}
result.push_back(result.arena(), KeyValueRef(beginKey, beginValue));
for (int i = 0; i < kv.size(); i++) {
if (kv[i].key > withPrefix.begin && kv[i].key < withPrefix.end) {
KeyRef k = kv[i].key.removePrefix(mapPrefix);
result.push_back(result.arena(), KeyValueRef(k, kv[i].value));
} else if (kv[i].key >= withPrefix.end)
} else if (kv[i].key >= withPrefix.end) {
kv.more = false;
// There should be at most 1 value past mapPrefix + keys.end.
ASSERT(i == kv.size() - 1);
break;
}
}
if (!kv.more)
result.push_back(result.arena(), KeyValueRef(keys.end, endValue));
if (!kv.more) {
KeyRef endKey = keys.end;
if (!align && !kv.empty() && kv.back().key.startsWith(mapPrefix) && kv.back().key >= withPrefix.end) {
endKey = kv.back().key.removePrefix(mapPrefix);
}
ValueRef endValue;
if (!kv.empty()) {
// In the aligned case, carry the last value to be the end value.
if (align && kv.back().key.startsWith(mapPrefix) && kv.back().key > withPrefix.end) {
endValue = result.back().value;
} else {
endValue = kv.back().value;
}
}
result.push_back(result.arena(), KeyValueRef(endKey, endValue));
}
result.more = kv.more;
return result;
@ -93,6 +116,37 @@ ACTOR Future<RangeResult> krmGetRanges(Reference<ReadYourWritesTransaction> tr,
return krmDecodeRanges(mapPrefix, keys, kv);
}
// Returns keys.begin, all transitional points in keys, and keys.end, and their values
ACTOR Future<RangeResult> krmGetRangesUnaligned(Transaction* tr,
Key mapPrefix,
KeyRange keys,
int limit,
int limitBytes) {
KeyRange withPrefix =
KeyRangeRef(mapPrefix.toString() + keys.begin.toString(), mapPrefix.toString() + keys.end.toString());
state GetRangeLimits limits(limit, limitBytes);
limits.minRows = 2;
RangeResult kv = wait(tr->getRange(lastLessOrEqual(withPrefix.begin), firstGreaterThan(withPrefix.end), limits));
return krmDecodeRanges(mapPrefix, keys, kv, false);
}
ACTOR Future<RangeResult> krmGetRangesUnaligned(Reference<ReadYourWritesTransaction> tr,
Key mapPrefix,
KeyRange keys,
int limit,
int limitBytes) {
KeyRange withPrefix =
KeyRangeRef(mapPrefix.toString() + keys.begin.toString(), mapPrefix.toString() + keys.end.toString());
state GetRangeLimits limits(limit, limitBytes);
limits.minRows = 2;
RangeResult kv = wait(tr->getRange(lastLessOrEqual(withPrefix.begin), firstGreaterThan(withPrefix.end), limits));
return krmDecodeRanges(mapPrefix, keys, kv, false);
}
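A minimal usage sketch of the unaligned variant defined just above (illustrative only; `tr`, `beginKey`, and `endKey` are placeholders, and the call is assumed to sit inside an ACTOR):

```
// Sketch: read the blob range map without snapping the first/last keys to
// keys.begin/keys.end. The first returned key may lie before beginKey and the
// last may lie at or after endKey, since the unaligned decode keeps the actual
// boundary keys stored in the map.
state RangeResult ranges = wait(krmGetRangesUnaligned(
    tr, blobRangeKeys.begin, KeyRangeRef(beginKey, endKey), CLIENT_KNOBS->KRM_GET_RANGE_LIMIT));
```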
void krmSetPreviouslyEmptyRange(Transaction* tr,
const KeyRef& mapPrefix,
const KeyRangeRef& keys,
@ -254,3 +308,87 @@ Future<Void> krmSetRangeCoalescing(Reference<ReadYourWritesTransaction> const& t
Value const& value) {
return holdWhile(tr, krmSetRangeCoalescing_(tr.getPtr(), mapPrefix, range, maxRange, value));
}
TEST_CASE("/keyrangemap/decoderange/aligned") {
Arena arena;
Key prefix = LiteralStringRef("/prefix/");
StringRef fullKeyA = StringRef(arena, LiteralStringRef("/prefix/a"));
StringRef fullKeyB = StringRef(arena, LiteralStringRef("/prefix/b"));
StringRef fullKeyC = StringRef(arena, LiteralStringRef("/prefix/c"));
StringRef fullKeyD = StringRef(arena, LiteralStringRef("/prefix/d"));
StringRef keyA = StringRef(arena, LiteralStringRef("a"));
StringRef keyB = StringRef(arena, LiteralStringRef("b"));
StringRef keyC = StringRef(arena, LiteralStringRef("c"));
StringRef keyD = StringRef(arena, LiteralStringRef("d"));
StringRef keyE = StringRef(arena, LiteralStringRef("e"));
StringRef keyAB = StringRef(arena, LiteralStringRef("ab"));
StringRef keyCD = StringRef(arena, LiteralStringRef("cd"));
// Fake getRange() call.
RangeResult kv;
kv.push_back(arena, KeyValueRef(fullKeyA, keyA));
kv.push_back(arena, KeyValueRef(fullKeyB, keyB));
kv.push_back(arena, KeyValueRef(fullKeyC, keyC));
kv.push_back(arena, KeyValueRef(fullKeyD, keyD));
// [A, AB(start), B, C, CD(end), D]
RangeResult decodedRanges = krmDecodeRanges(prefix, KeyRangeRef(keyAB, keyCD), kv);
ASSERT(decodedRanges.size() == 4);
ASSERT(decodedRanges.front().key == keyAB);
ASSERT(decodedRanges.front().value == keyA);
ASSERT(decodedRanges.back().key == keyCD);
ASSERT(decodedRanges.back().value == keyC);
// [""(start), A, B, C, D, E(end)]
decodedRanges = krmDecodeRanges(prefix, KeyRangeRef(StringRef(), keyE), kv);
ASSERT(decodedRanges.size() == 6);
ASSERT(decodedRanges.front().key == StringRef());
ASSERT(decodedRanges.front().value == StringRef());
ASSERT(decodedRanges.back().key == keyE);
ASSERT(decodedRanges.back().value == keyD);
return Void();
}
TEST_CASE("/keyrangemap/decoderange/unaligned") {
Arena arena;
Key prefix = LiteralStringRef("/prefix/");
StringRef fullKeyA = StringRef(arena, LiteralStringRef("/prefix/a"));
StringRef fullKeyB = StringRef(arena, LiteralStringRef("/prefix/b"));
StringRef fullKeyC = StringRef(arena, LiteralStringRef("/prefix/c"));
StringRef fullKeyD = StringRef(arena, LiteralStringRef("/prefix/d"));
StringRef keyA = StringRef(arena, LiteralStringRef("a"));
StringRef keyB = StringRef(arena, LiteralStringRef("b"));
StringRef keyC = StringRef(arena, LiteralStringRef("c"));
StringRef keyD = StringRef(arena, LiteralStringRef("d"));
StringRef keyE = StringRef(arena, LiteralStringRef("e"));
StringRef keyAB = StringRef(arena, LiteralStringRef("ab"));
StringRef keyCD = StringRef(arena, LiteralStringRef("cd"));
// Fake getRange() call.
RangeResult kv;
kv.push_back(arena, KeyValueRef(fullKeyA, keyA));
kv.push_back(arena, KeyValueRef(fullKeyB, keyB));
kv.push_back(arena, KeyValueRef(fullKeyC, keyC));
kv.push_back(arena, KeyValueRef(fullKeyD, keyD));
// [A, AB(start), B, C, CD(end), D]
RangeResult decodedRanges = krmDecodeRanges(prefix, KeyRangeRef(keyAB, keyCD), kv, false);
ASSERT(decodedRanges.size() == 4);
ASSERT(decodedRanges.front().key == keyA);
ASSERT(decodedRanges.front().value == keyA);
ASSERT(decodedRanges.back().key == keyD);
ASSERT(decodedRanges.back().value == keyD);
// [""(start), A, B, C, D, E(end)]
decodedRanges = krmDecodeRanges(prefix, KeyRangeRef(StringRef(), keyE), kv, false);
ASSERT(decodedRanges.size() == 6);
ASSERT(decodedRanges.front().key == StringRef());
ASSERT(decodedRanges.front().value == StringRef());
ASSERT(decodedRanges.back().key == keyE);
ASSERT(decodedRanges.back().value == keyD);
return Void();
}

View File

@ -280,10 +280,46 @@ ThreadResult<RangeResult> DLTransaction::readBlobGranules(const KeyRangeRef& key
Version beginVersion,
Optional<Version> readVersion,
ReadBlobGranuleContext granuleContext) {
if (!api->transactionReadBlobGranules) {
return unsupported_operation();
}
ThreadFuture<Standalone<VectorRef<BlobGranuleChunkRef>>> DLTransaction::readBlobGranulesStart(
const KeyRangeRef& keyRange,
Version beginVersion,
Optional<Version> readVersion,
Version* readVersionOut) {
if (!api->transactionReadBlobGranulesStart) {
return unsupported_operation();
}
int64_t rv = readVersion.present() ? readVersion.get() : latestVersion;
FdbCApi::FDBFuture* f = api->transactionReadBlobGranulesStart(tr,
keyRange.begin.begin(),
keyRange.begin.size(),
keyRange.end.begin(),
keyRange.end.size(),
beginVersion,
rv,
readVersionOut);
return ThreadFuture<Standalone<VectorRef<BlobGranuleChunkRef>>>(
(ThreadSingleAssignmentVar<Standalone<VectorRef<BlobGranuleChunkRef>>>*)(f));
};
ThreadResult<RangeResult> DLTransaction::readBlobGranulesFinish(
ThreadFuture<Standalone<VectorRef<BlobGranuleChunkRef>>> startFuture,
const KeyRangeRef& keyRange,
Version beginVersion,
Version readVersion,
ReadBlobGranuleContext granuleContext) {
if (!api->transactionReadBlobGranulesFinish) {
return unsupported_operation();
}
// convert back to fdb future for API
FdbCApi::FDBFuture* f = (FdbCApi::FDBFuture*)(startFuture.extractPtr());
// FIXME: better way to convert here?
FdbCApi::FDBReadBlobGranuleContext context;
context.userContext = granuleContext.userContext;
@ -293,18 +329,18 @@ ThreadResult<RangeResult> DLTransaction::readBlobGranules(const KeyRangeRef& key
context.debugNoMaterialize = granuleContext.debugNoMaterialize;
context.granuleParallelism = granuleContext.granuleParallelism;
int64_t rv = readVersion.present() ? readVersion.get() : latestVersion;
FdbCApi::FDBResult* r = api->transactionReadBlobGranulesFinish(tr,
f,
keyRange.begin.begin(),
keyRange.begin.size(),
keyRange.end.begin(),
keyRange.end.size(),
beginVersion,
readVersion,
&context);
FdbCApi::FDBResult* r = api->transactionReadBlobGranules(tr,
keyRange.begin.begin(),
keyRange.begin.size(),
keyRange.end.begin(),
keyRange.end.size(),
beginVersion,
rv,
context);
return ThreadResult<RangeResult>((ThreadSingleAssignmentVar<RangeResult>*)(r));
}
};
void DLTransaction::addReadConflictRange(const KeyRangeRef& keys) {
throwIfError(api->transactionAddConflictRange(
@ -812,6 +848,16 @@ void DLApi::init() {
headerVersion >= 710);
loadClientFunction(
&api->transactionReadBlobGranules, lib, fdbCPath, "fdb_transaction_read_blob_granules", headerVersion >= 710);
loadClientFunction(&api->transactionReadBlobGranulesStart,
lib,
fdbCPath,
"fdb_transaction_read_blob_granules_start",
headerVersion >= 720);
loadClientFunction(&api->transactionReadBlobGranulesFinish,
lib,
fdbCPath,
"fdb_transaction_read_blob_granules_finish",
headerVersion >= 720);
loadClientFunction(&api->futureGetInt64,
lib,
fdbCPath,
@ -1165,14 +1211,45 @@ ThreadResult<RangeResult> MultiVersionTransaction::readBlobGranules(const KeyRan
Version beginVersion,
Optional<Version> readVersion,
ReadBlobGranuleContext granuleContext) {
// FIXME: prevent from calling this from another main thread?
auto tr = getTransaction();
if (tr.transaction) {
return tr.transaction->readBlobGranules(keyRange, beginVersion, readVersion, granuleContext);
Version readVersionOut;
auto f = tr.transaction->readBlobGranulesStart(keyRange, beginVersion, readVersion, &readVersionOut);
auto abortableF = abortableFuture(f, tr.onChange);
abortableF.blockUntilReadyCheckOnMainThread();
if (abortableF.isError()) {
return ThreadResult<RangeResult>(abortableF.getError());
}
if (granuleContext.debugNoMaterialize) {
return ThreadResult<RangeResult>(blob_granule_not_materialized());
}
return tr.transaction->readBlobGranulesFinish(
abortableF, keyRange, beginVersion, readVersionOut, granuleContext);
} else {
return abortableTimeoutResult<RangeResult>(tr.onChange);
}
}
ThreadFuture<Standalone<VectorRef<BlobGranuleChunkRef>>> MultiVersionTransaction::readBlobGranulesStart(
const KeyRangeRef& keyRange,
Version beginVersion,
Optional<Version> readVersion,
Version* readVersionOut) {
// can't call this directly
return ThreadFuture<Standalone<VectorRef<BlobGranuleChunkRef>>>(unsupported_operation());
}
ThreadResult<RangeResult> MultiVersionTransaction::readBlobGranulesFinish(
ThreadFuture<Standalone<VectorRef<BlobGranuleChunkRef>>> startFuture,
const KeyRangeRef& keyRange,
Version beginVersion,
Version readVersion,
ReadBlobGranuleContext granuleContext) {
// can't call this directly
return ThreadResult<RangeResult>(unsupported_operation());
}
void MultiVersionTransaction::atomicOp(const KeyRef& key, const ValueRef& value, uint32_t operationType) {
auto tr = getTransaction();
if (tr.transaction) {

View File

@ -103,6 +103,8 @@
#endif
#include "flow/actorcompiler.h" // This must be the last #include.
FDB_DEFINE_BOOLEAN_PARAM(CacheResult);
extern const char* getSourceVersion();
namespace {
@ -231,8 +233,9 @@ void DatabaseContext::getLatestCommitVersions(const Reference<LocationInfo>& loc
VersionVector& latestCommitVersions) {
latestCommitVersions.clear();
if (info->debugID.present()) {
g_traceBatch.addEvent("TransactionDebug", info->debugID.get().first(), "NativeAPI.getLatestCommitVersions");
if (info->readOptions.present() && info->readOptions.get().debugID.present()) {
g_traceBatch.addEvent(
"TransactionDebug", info->readOptions.get().debugID.get().first(), "NativeAPI.getLatestCommitVersions");
}
if (!info->readVersionObtainedFromGrvProxy) {
@ -2968,7 +2971,7 @@ Future<KeyRangeLocationInfo> getKeyLocation(Reference<TransactionState> trState,
key,
member,
trState->spanContext,
trState->debugID,
trState->readOptions.present() ? trState->readOptions.get().debugID : Optional<UID>(),
trState->useProvisionalProxies,
isBackward,
version);
@ -3109,7 +3112,7 @@ Future<std::vector<KeyRangeLocationInfo>> getKeyRangeLocations(Reference<Transac
reverse,
member,
trState->spanContext,
trState->debugID,
trState->readOptions.present() ? trState->readOptions.get().debugID : Optional<UID>(),
trState->useProvisionalProxies,
version);
@ -3131,16 +3134,16 @@ ACTOR Future<Void> warmRange_impl(Reference<TransactionState> trState, KeyRange
state Version version = wait(fVersion);
loop {
std::vector<KeyRangeLocationInfo> locations =
wait(getKeyRangeLocations_internal(trState->cx,
trState->getTenantInfo(),
keys,
CLIENT_KNOBS->WARM_RANGE_SHARD_LIMIT,
Reverse::False,
trState->spanContext,
trState->debugID,
trState->useProvisionalProxies,
version));
std::vector<KeyRangeLocationInfo> locations = wait(getKeyRangeLocations_internal(
trState->cx,
trState->getTenantInfo(),
keys,
CLIENT_KNOBS->WARM_RANGE_SHARD_LIMIT,
Reverse::False,
trState->spanContext,
trState->readOptions.present() ? trState->readOptions.get().debugID : Optional<UID>(),
trState->useProvisionalProxies,
version));
totalRanges += CLIENT_KNOBS->WARM_RANGE_SHARD_LIMIT;
totalRequests++;
if (locations.size() == 0 || totalRanges >= trState->cx->locationCacheSize ||
@ -3298,12 +3301,16 @@ ACTOR Future<Optional<Value>> getValue(Reference<TransactionState> trState,
state uint64_t startTime;
state double startTimeD;
state VersionVector ssLatestCommitVersions;
state Optional<ReadOptions> readOptions = trState->readOptions;
trState->cx->getLatestCommitVersions(locationInfo.locations, ver, trState, ssLatestCommitVersions);
try {
if (trState->debugID.present()) {
if (trState->readOptions.present() && trState->readOptions.get().debugID.present()) {
getValueID = nondeterministicRandom()->randomUniqueID();
readOptions.get().debugID = getValueID;
g_traceBatch.addAttach("GetValueAttachID", trState->debugID.get().first(), getValueID.get().first());
g_traceBatch.addAttach(
"GetValueAttachID", trState->readOptions.get().debugID.get().first(), getValueID.get().first());
g_traceBatch.addEvent("GetValueDebug",
getValueID.get().first(),
"NativeAPI.getValue.Before"); //.detail("TaskID", g_network->getCurrentTask());
@ -3336,7 +3343,7 @@ ACTOR Future<Optional<Value>> getValue(Reference<TransactionState> trState,
ver,
trState->cx->sampleReadTags() ? trState->options.readTags
: Optional<TagSet>(),
getValueID,
readOptions,
ssLatestCommitVersions),
TaskPriority::DefaultPromiseEndpoint,
AtMostOnce::False,
@ -3411,12 +3418,16 @@ ACTOR Future<Key> getKey(Reference<TransactionState> trState,
UseTenant useTenant = UseTenant::True) {
wait(success(version));
state Optional<UID> getKeyID = Optional<UID>();
state Span span("NAPI:getKey"_loc, trState->spanContext);
if (trState->debugID.present()) {
getKeyID = nondeterministicRandom()->randomUniqueID();
state Optional<UID> getKeyID;
state Optional<ReadOptions> readOptions = trState->readOptions;
g_traceBatch.addAttach("GetKeyAttachID", trState->debugID.get().first(), getKeyID.get().first());
state Span span("NAPI:getKey"_loc, trState->spanContext);
if (trState->readOptions.present() && trState->readOptions.get().debugID.present()) {
getKeyID = nondeterministicRandom()->randomUniqueID();
readOptions.get().debugID = getKeyID;
g_traceBatch.addAttach(
"GetKeyAttachID", trState->readOptions.get().debugID.get().first(), getKeyID.get().first());
g_traceBatch.addEvent(
"GetKeyDebug",
getKeyID.get().first(),
@ -3459,7 +3470,7 @@ ACTOR Future<Key> getKey(Reference<TransactionState> trState,
k,
version.get(),
trState->cx->sampleReadTags() ? trState->options.readTags : Optional<TagSet>(),
getKeyID,
readOptions,
ssLatestCommitVersions);
req.arena.dependsOn(k.arena());
@ -3924,13 +3935,15 @@ Future<RangeResultFamily> getExactRange(Reference<TransactionState> trState,
// FIXME: buggify byte limits on internal functions that use them, instead of globally
req.tags = trState->cx->sampleReadTags() ? trState->options.readTags : Optional<TagSet>();
req.debugID = trState->debugID;
req.options = trState->readOptions;
try {
if (trState->debugID.present()) {
g_traceBatch.addEvent(
"TransactionDebug", trState->debugID.get().first(), "NativeAPI.getExactRange.Before");
/*TraceEvent("TransactionDebugGetExactRangeInfo", trState->debugID.get())
if (trState->readOptions.present() && trState->readOptions.get().debugID.present()) {
g_traceBatch.addEvent("TransactionDebug",
trState->readOptions.get().debugID.get().first(),
"NativeAPI.getExactRange.Before");
/*TraceEvent("TransactionDebugGetExactRangeInfo", trState->readOptions.debugID.get())
.detail("ReqBeginKey", req.begin.getKey())
.detail("ReqEndKey", req.end.getKey())
.detail("ReqLimit", req.limit)
@ -3960,9 +3973,10 @@ Future<RangeResultFamily> getExactRange(Reference<TransactionState> trState,
++trState->cx->transactionPhysicalReadsCompleted;
throw;
}
if (trState->debugID.present())
g_traceBatch.addEvent(
"TransactionDebug", trState->debugID.get().first(), "NativeAPI.getExactRange.After");
if (trState->readOptions.present() && trState->readOptions.get().debugID.present())
g_traceBatch.addEvent("TransactionDebug",
trState->readOptions.get().debugID.get().first(),
"NativeAPI.getExactRange.After");
output.arena().dependsOn(rep.arena);
output.append(output.arena(), rep.data.begin(), rep.data.size());
@ -4290,7 +4304,7 @@ Future<RangeResultFamily> getRange(Reference<TransactionState> trState,
req.arena.dependsOn(mapper.arena());
setMatchIndex<GetKeyValuesFamilyRequest>(req, matchIndex);
req.tenantInfo = useTenant ? trState->getTenantInfo() : TenantInfo();
req.isFetchKeys = (trState->taskID == TaskPriority::FetchKeys);
req.options = trState->readOptions;
req.version = readVersion;
trState->cx->getLatestCommitVersions(
@ -4328,13 +4342,13 @@ Future<RangeResultFamily> getRange(Reference<TransactionState> trState,
ASSERT(req.limitBytes > 0 && req.limit != 0 && req.limit < 0 == reverse);
req.tags = trState->cx->sampleReadTags() ? trState->options.readTags : Optional<TagSet>();
req.debugID = trState->debugID;
req.spanContext = span.context;
try {
if (trState->debugID.present()) {
g_traceBatch.addEvent(
"TransactionDebug", trState->debugID.get().first(), "NativeAPI.getRange.Before");
/*TraceEvent("TransactionDebugGetRangeInfo", trState->debugID.get())
if (trState->readOptions.present() && trState->readOptions.get().debugID.present()) {
g_traceBatch.addEvent("TransactionDebug",
trState->readOptions.get().debugID.get().first(),
"NativeAPI.getRange.Before");
/*TraceEvent("TransactionDebugGetRangeInfo", trState->readOptions.debugID.get())
.detail("ReqBeginKey", req.begin.getKey())
.detail("ReqEndKey", req.end.getKey())
.detail("OriginalBegin", originalBegin.toString())
@ -4373,11 +4387,11 @@ Future<RangeResultFamily> getRange(Reference<TransactionState> trState,
throw;
}
if (trState->debugID.present()) {
if (trState->readOptions.present() && trState->readOptions.get().debugID.present()) {
g_traceBatch.addEvent("TransactionDebug",
trState->debugID.get().first(),
trState->readOptions.get().debugID.get().first(),
"NativeAPI.getRange.After"); //.detail("SizeOf", rep.data.size());
/*TraceEvent("TransactionDebugGetRangeDone", trState->debugID.get())
/*TraceEvent("TransactionDebugGetRangeDone", trState->readOptions.debugID.get())
.detail("ReqBeginKey", req.begin.getKey())
.detail("ReqEndKey", req.end.getKey())
.detail("RepIsMore", rep.more)
@ -4489,10 +4503,11 @@ Future<RangeResultFamily> getRange(Reference<TransactionState> trState,
}
} catch (Error& e) {
if (trState->debugID.present()) {
g_traceBatch.addEvent(
"TransactionDebug", trState->debugID.get().first(), "NativeAPI.getRange.Error");
TraceEvent("TransactionDebugError", trState->debugID.get()).error(e);
if (trState->readOptions.present() && trState->readOptions.get().debugID.present()) {
g_traceBatch.addEvent("TransactionDebug",
trState->readOptions.get().debugID.get().first(),
"NativeAPI.getRange.Error");
TraceEvent("TransactionDebugError", trState->readOptions.get().debugID.get()).error(e);
}
if (e.code() == error_code_wrong_shard_server || e.code() == error_code_all_alternatives_failed ||
(e.code() == error_code_transaction_too_old && readVersion == latestVersion)) {
@ -4744,9 +4759,8 @@ ACTOR Future<Void> getRangeStreamFragment(Reference<TransactionState> trState,
req.spanContext = spanContext;
req.limit = reverse ? -CLIENT_KNOBS->REPLY_BYTE_LIMIT : CLIENT_KNOBS->REPLY_BYTE_LIMIT;
req.limitBytes = std::numeric_limits<int>::max();
// leaving the flag off for now to prevent data fetches stall under heavy load
// it is used to inform the storage that the rangeRead is for Fetch
// req.isFetchKeys = (trState->taskID == TaskPriority::FetchKeys);
req.options = trState->readOptions;
trState->cx->getLatestCommitVersions(
locations[shard].locations, req.version, trState, req.ssLatestCommitVersions);
@ -4757,12 +4771,12 @@ ACTOR Future<Void> getRangeStreamFragment(Reference<TransactionState> trState,
// FIXME: buggify byte limits on internal functions that use them, instead of globally
req.tags = trState->cx->sampleReadTags() ? trState->options.readTags : Optional<TagSet>();
req.debugID = trState->debugID;
try {
if (trState->debugID.present()) {
g_traceBatch.addEvent(
"TransactionDebug", trState->debugID.get().first(), "NativeAPI.RangeStream.Before");
if (trState->readOptions.present() && trState->readOptions.get().debugID.present()) {
g_traceBatch.addEvent("TransactionDebug",
trState->readOptions.get().debugID.get().first(),
"NativeAPI.RangeStream.Before");
}
++trState->cx->transactionPhysicalReads;
state GetKeyValuesStreamReply rep;
@ -4856,9 +4870,10 @@ ACTOR Future<Void> getRangeStreamFragment(Reference<TransactionState> trState,
}
rep = GetKeyValuesStreamReply();
}
if (trState->debugID.present())
g_traceBatch.addEvent(
"TransactionDebug", trState->debugID.get().first(), "NativeAPI.getExactRange.After");
if (trState->readOptions.present() && trState->readOptions.get().debugID.present())
g_traceBatch.addEvent("TransactionDebug",
trState->readOptions.get().debugID.get().first(),
"NativeAPI.getExactRange.After");
RangeResult output(RangeResultRef(rep.data, rep.more), rep.arena);
if (tssDuplicateStream.present() && !tssDuplicateStream.get().done()) {
@ -5337,7 +5352,7 @@ Future<Void> Transaction::watch(Reference<Watch> watch) {
trState->options.readTags,
trState->spanContext,
trState->taskID,
trState->debugID,
trState->readOptions.present() ? trState->readOptions.get().debugID : Optional<UID>(),
trState->useProvisionalProxies);
}
@ -6005,16 +6020,17 @@ void Transaction::setupWatches() {
Future<Version> watchVersion = getCommittedVersion() > 0 ? getCommittedVersion() : getReadVersion();
for (int i = 0; i < watches.size(); ++i)
watches[i]->setWatch(watchValueMap(watchVersion,
trState->getTenantInfo(),
watches[i]->key,
watches[i]->value,
trState->cx,
trState->options.readTags,
trState->spanContext,
trState->taskID,
trState->debugID,
trState->useProvisionalProxies));
watches[i]->setWatch(
watchValueMap(watchVersion,
trState->getTenantInfo(),
watches[i]->key,
watches[i]->value,
trState->cx,
trState->options.readTags,
trState->spanContext,
trState->taskID,
trState->readOptions.present() ? trState->readOptions.get().debugID : Optional<UID>(),
trState->useProvisionalProxies));
watches.clear();
} catch (Error&) {
@ -6135,7 +6151,7 @@ ACTOR static Future<Void> tryCommit(Reference<TransactionState> trState,
state TraceInterval interval("TransactionCommit");
state double startTime = now();
state Span span("NAPI:tryCommit"_loc, trState->spanContext);
state Optional<UID> debugID = trState->debugID;
state Optional<UID> debugID = trState->readOptions.present() ? trState->readOptions.get().debugID : Optional<UID>();
state TenantPrefixPrepended tenantPrefixPrepended = TenantPrefixPrepended::False;
if (debugID.present()) {
TraceEvent(interval.begin()).detail("Parent", debugID.get());
@ -6539,10 +6555,10 @@ void Transaction::setOption(FDBTransactionOptions::Option option, Optional<Strin
makeReference<TransactionLogInfo>(value.get().printable(), TransactionLogInfo::DONT_LOG);
trState->trLogInfo->maxFieldLength = trState->options.maxTransactionLoggingFieldLength;
}
if (trState->debugID.present()) {
if (trState->readOptions.present() && trState->readOptions.get().debugID.present()) {
TraceEvent(SevInfo, "TransactionBeingTraced")
.detail("DebugTransactionID", trState->trLogInfo->identifier)
.detail("ServerTraceID", trState->debugID.get());
.detail("ServerTraceID", trState->readOptions.get().debugID.get());
}
break;
@ -6574,10 +6590,11 @@ void Transaction::setOption(FDBTransactionOptions::Option option, Optional<Strin
case FDBTransactionOptions::SERVER_REQUEST_TRACING:
validateOptionValueNotPresent(value);
debugTransaction(deterministicRandom()->randomUniqueID());
if (trState->trLogInfo && !trState->trLogInfo->identifier.empty()) {
if (trState->trLogInfo && !trState->trLogInfo->identifier.empty() && trState->readOptions.present() &&
trState->readOptions.get().debugID.present()) {
TraceEvent(SevInfo, "TransactionBeingTraced")
.detail("DebugTransactionID", trState->trLogInfo->identifier)
.detail("ServerTraceID", trState->debugID.get());
.detail("ServerTraceID", trState->readOptions.get().debugID.get());
}
break;
@ -7048,7 +7065,9 @@ Future<Version> Transaction::getReadVersion(uint32_t flags) {
Location location = "NAPI:getReadVersion"_loc;
SpanContext spanContext = generateSpanID(trState->cx->transactionTracingSample, trState->spanContext);
auto const req = DatabaseContext::VersionRequest(spanContext, trState->options.tags, trState->debugID);
Optional<UID> versionDebugID =
trState->readOptions.present() ? trState->readOptions.get().debugID : Optional<UID>();
auto const req = DatabaseContext::VersionRequest(spanContext, trState->options.tags, versionDebugID);
batcher.stream.send(req);
trState->startTime = now();
readVersion = extractReadVersion(trState, location, spanContext, req.reply.getFuture(), metadataVersion);
@ -7208,7 +7227,8 @@ Future<Void> Transaction::onError(Error const& e) {
if (e.code() == error_code_not_committed || e.code() == error_code_commit_unknown_result ||
e.code() == error_code_database_locked || e.code() == error_code_commit_proxy_memory_limit_exceeded ||
e.code() == error_code_grv_proxy_memory_limit_exceeded || e.code() == error_code_process_behind ||
e.code() == error_code_batch_transaction_throttled || e.code() == error_code_tag_throttled) {
e.code() == error_code_batch_transaction_throttled || e.code() == error_code_tag_throttled ||
e.code() == error_code_blob_granule_request_failed) {
if (e.code() == error_code_not_committed)
++trState->cx->transactionsNotCommitted;
else if (e.code() == error_code_commit_unknown_result)
@ -7616,14 +7636,15 @@ ACTOR Future<TenantMapEntry> blobGranuleGetTenantEntry(Transaction* self, Key ra
Optional<KeyRangeLocationInfo> cachedLocationInfo =
self->trState->cx->getCachedLocation(self->getTenant().get(), rangeStartKey, Reverse::False);
if (!cachedLocationInfo.present()) {
KeyRangeLocationInfo l = wait(getKeyLocation_internal(self->trState->cx,
self->trState->getTenantInfo(AllowInvalidTenantID::True),
rangeStartKey,
self->trState->spanContext,
self->trState->debugID,
self->trState->useProvisionalProxies,
Reverse::False,
latestVersion));
KeyRangeLocationInfo l = wait(getKeyLocation_internal(
self->trState->cx,
self->trState->getTenantInfo(AllowInvalidTenantID::True),
rangeStartKey,
self->trState->spanContext,
self->trState->readOptions.present() ? self->trState->readOptions.get().debugID : Optional<UID>(),
self->trState->useProvisionalProxies,
Reverse::False,
latestVersion));
self->trState->trySetTenantId(l.tenantEntry.id);
return l.tenantEntry;
} else {
@ -7839,11 +7860,9 @@ ACTOR Future<Standalone<VectorRef<BlobGranuleChunkRef>>> readBlobGranulesActor(
getValue(self->trState, blobWorkerListKeyFor(workerId), self->getReadVersion(), UseTenant::False)));
// from the time the mapping was read from the db, the associated blob worker
// could have died and so its interface wouldn't be present as part of the blobWorkerList
// we persist in the db. So throw wrong_shard_server to get the new mapping
// we persist in the db. So throw blob_granule_request_failed to get the new mapping
if (!workerInterface.present()) {
// need to re-read mapping, throw transaction_too_old so client retries. TODO better error?
// throw wrong_shard_server();
throw transaction_too_old();
throw blob_granule_request_failed();
}
// FIXME: maybe just want to insert here if there are racing queries for the same worker or something?
self->trState->cx->blobWorker_interf[workerId] = decodeBlobWorkerListValue(workerInterface.get());
@ -7978,10 +7997,8 @@ ACTOR Future<Standalone<VectorRef<BlobGranuleChunkRef>>> readBlobGranulesActor(
e.name());
}
// worker is up but didn't actually have granule, or connection failed
if (e.code() == error_code_wrong_shard_server || e.code() == error_code_connection_failed ||
e.code() == error_code_unknown_tenant) {
// need to re-read mapping, throw transaction_too_old so client retries. TODO better error?
throw transaction_too_old();
if (e.code() == error_code_wrong_shard_server || e.code() == error_code_connection_failed) {
throw blob_granule_request_failed();
}
throw e;
}
@ -9587,7 +9604,7 @@ ACTOR Future<Void> getChangeFeedStreamActor(Reference<DatabaseContext> db,
if (e.code() == error_code_wrong_shard_server || e.code() == error_code_all_alternatives_failed ||
e.code() == error_code_connection_failed || e.code() == error_code_unknown_change_feed ||
e.code() == error_code_broken_promise) {
e.code() == error_code_broken_promise || e.code() == error_code_future_version) {
db->changeFeedCache.erase(rangeID);
cx->invalidateCache(Key(), keys);
if (begin == lastBeginVersion) {
@ -9709,7 +9726,8 @@ ACTOR Future<OverlappingChangeFeedsInfo> getOverlappingChangeFeedsActor(Referenc
}
return result;
} catch (Error& e) {
if (e.code() == error_code_wrong_shard_server || e.code() == error_code_all_alternatives_failed) {
if (e.code() == error_code_wrong_shard_server || e.code() == error_code_all_alternatives_failed ||
e.code() == error_code_future_version) {
cx->invalidateCache(Key(), range);
wait(delay(CLIENT_KNOBS->WRONG_SHARD_SERVER_DELAY));
} else {
@ -9926,6 +9944,39 @@ Future<Void> DatabaseContext::waitPurgeGranulesComplete(Key purgeKey) {
return waitPurgeGranulesCompleteActor(Reference<DatabaseContext>::addRef(this), purgeKey);
}
ACTOR Future<Standalone<VectorRef<KeyRangeRef>>> getBlobRanges(Reference<ReadYourWritesTransaction> tr,
KeyRange range,
int batchLimit) {
state Standalone<VectorRef<KeyRangeRef>> blobRanges;
state Key beginKey = range.begin;
loop {
try {
tr->setOption(FDBTransactionOptions::ACCESS_SYSTEM_KEYS);
state RangeResult results = wait(
krmGetRangesUnaligned(tr, blobRangeKeys.begin, KeyRangeRef(beginKey, range.end), 2 * batchLimit + 2));
blobRanges.arena().dependsOn(results.arena());
for (int i = 0; i < results.size() - 1; i++) {
if (results[i].value == blobRangeActive) {
blobRanges.push_back(blobRanges.arena(), KeyRangeRef(results[i].key, results[i + 1].key));
}
if (blobRanges.size() == batchLimit) {
return blobRanges;
}
}
if (!results.more) {
return blobRanges;
}
beginKey = results.back().key;
} catch (Error& e) {
wait(tr->onError(e));
}
}
}
ACTOR Future<bool> setBlobRangeActor(Reference<DatabaseContext> cx, KeyRange range, bool active) {
state Database db(cx);
state Reference<ReadYourWritesTransaction> tr = makeReference<ReadYourWritesTransaction>(db);
@ -9937,18 +9988,26 @@ ACTOR Future<bool> setBlobRangeActor(Reference<DatabaseContext> cx, KeyRange ran
tr->setOption(FDBTransactionOptions::ACCESS_SYSTEM_KEYS);
tr->setOption(FDBTransactionOptions::PRIORITY_SYSTEM_IMMEDIATE);
state Standalone<VectorRef<KeyRangeRef>> startBlobRanges = wait(getBlobRanges(tr, range, 10));
state Standalone<VectorRef<KeyRangeRef>> endBlobRanges =
wait(getBlobRanges(tr, KeyRangeRef(range.end, keyAfter(range.end)), 10));
if (active) {
state RangeResult results = wait(krmGetRanges(tr, blobRangeKeys.begin, range));
ASSERT(results.size() >= 2);
if (results[0].key == range.begin && results[1].key == range.end &&
results[0].value == blobRangeActive) {
// Idempotent request.
if (!startBlobRanges.empty() && !endBlobRanges.empty()) {
return startBlobRanges.front().begin == range.begin && endBlobRanges.front().end == range.end;
}
} else {
// An unblobbify request must be aligned to boundaries.
// It is okay to unblobbify multiple regions all at once.
if (startBlobRanges.empty() && endBlobRanges.empty()) {
return true;
} else {
for (int i = 0; i < results.size(); i++) {
if (results[i].value == blobRangeActive) {
return false;
}
}
}
// If there is a blob at the beginning of the range and it isn't aligned,
// or there is a blob range that begins before the end of the range, then fail.
if ((!startBlobRanges.empty() && startBlobRanges.front().begin != range.begin) ||
(!endBlobRanges.empty() && endBlobRanges.front().begin < range.end)) {
return false;
}
}
@ -9980,29 +10039,10 @@ ACTOR Future<Standalone<VectorRef<KeyRangeRef>>> listBlobbifiedRangesActor(Refer
int rangeLimit) {
state Database db(cx);
state Reference<ReadYourWritesTransaction> tr = makeReference<ReadYourWritesTransaction>(db);
state Standalone<VectorRef<KeyRangeRef>> blobRanges;
loop {
try {
tr->setOption(FDBTransactionOptions::ACCESS_SYSTEM_KEYS);
state Standalone<VectorRef<KeyRangeRef>> blobRanges = wait(getBlobRanges(tr, range, rangeLimit));
state RangeResult results = wait(krmGetRanges(tr, blobRangeKeys.begin, range, 2 * rangeLimit + 2));
blobRanges.arena().dependsOn(results.arena());
for (int i = 0; i < results.size() - 1; i++) {
if (results[i].value == LiteralStringRef("1")) {
blobRanges.push_back(blobRanges.arena(), KeyRangeRef(results[i].key, results[i + 1].key));
}
if (blobRanges.size() == rangeLimit) {
return blobRanges;
}
}
return blobRanges;
} catch (Error& e) {
wait(tr->onError(e));
}
}
return blobRanges;
}
Future<Standalone<VectorRef<KeyRangeRef>>> DatabaseContext::listBlobbifiedRanges(KeyRange range, int rowLimit) {

View File

@ -681,7 +681,8 @@ public:
break;
if (it.is_unknown_range()) {
if (limits.hasByteLimit() && result.size() && itemsPastEnd >= 1 - end.offset) {
if (limits.hasByteLimit() && limits.hasSatisfiedMinRows() && result.size() &&
itemsPastEnd >= 1 - end.offset) {
result.more = true;
break;
}

View File

@ -418,6 +418,7 @@ void ServerKnobs::initialize(Randomize randomize, ClientKnobs* clientKnobs, IsSi
init( ROCKSDB_BLOCK_SIZE, 32768 ); // 32 KB, size of the block in rocksdb cache.
init( ENABLE_SHARDED_ROCKSDB, false );
init( ROCKSDB_WRITE_BUFFER_SIZE, 1 << 30 ); // 1G
init( ROCKSDB_CF_WRITE_BUFFER_SIZE, 64 << 20 ); // 64M, RocksDB default.
init( ROCKSDB_MAX_TOTAL_WAL_SIZE, 0 ); // RocksDB default.
init( ROCKSDB_MAX_BACKGROUND_JOBS, 2 ); // RocksDB default.
init( ROCKSDB_DELETE_OBSOLETE_FILE_PERIOD, 21600 ); // 6h, RocksDB default.
@ -485,7 +486,7 @@ void ServerKnobs::initialize(Randomize randomize, ClientKnobs* clientKnobs, IsSi
init( REPORT_TRANSACTION_COST_ESTIMATION_DELAY, 0.1 );
init( PROXY_REJECT_BATCH_QUEUED_TOO_LONG, true );
bool buggfyUseResolverPrivateMutations = randomize && BUGGIFY && !ENABLE_VERSION_VECTOR_TLOG_UNICAST;
init( PROXY_USE_RESOLVER_PRIVATE_MUTATIONS, false ); if( buggfyUseResolverPrivateMutations ) PROXY_USE_RESOLVER_PRIVATE_MUTATIONS = deterministicRandom()->coinflip();
init( RESET_MASTER_BATCHES, 200 );
@ -636,6 +637,8 @@ void ServerKnobs::initialize(Randomize randomize, ClientKnobs* clientKnobs, IsSi
init( SPRING_BYTES_STORAGE_SERVER_BATCH, 100e6 ); if( smallStorageTarget ) SPRING_BYTES_STORAGE_SERVER_BATCH = 150e3;
init( STORAGE_HARD_LIMIT_BYTES, 1500e6 ); if( smallStorageTarget ) STORAGE_HARD_LIMIT_BYTES = 4500e3;
init( STORAGE_HARD_LIMIT_BYTES_OVERAGE, 5000e3 ); if( smallStorageTarget ) STORAGE_HARD_LIMIT_BYTES_OVERAGE = 100e3; // byte+version overage ensures storage server makes enough progress on freeing up storage queue memory at hard limit by ensuring it advances desiredOldestVersion enough per commit cycle.
init( STORAGE_HARD_LIMIT_BYTES_SPEED_UP_SIM, STORAGE_HARD_LIMIT_BYTES ); if( smallStorageTarget ) STORAGE_HARD_LIMIT_BYTES_SPEED_UP_SIM *= 10;
init( STORAGE_HARD_LIMIT_BYTES_OVERAGE_SPEED_UP_SIM, STORAGE_HARD_LIMIT_BYTES_OVERAGE ); if( smallStorageTarget ) STORAGE_HARD_LIMIT_BYTES_OVERAGE_SPEED_UP_SIM *= 10;
init( STORAGE_HARD_LIMIT_VERSION_OVERAGE, VERSIONS_PER_SECOND / 4.0 );
init( STORAGE_DURABILITY_LAG_HARD_MAX, 2000e6 ); if( smallStorageTarget ) STORAGE_DURABILITY_LAG_HARD_MAX = 100e6;
init( STORAGE_DURABILITY_LAG_SOFT_MAX, 250e6 ); if( smallStorageTarget ) STORAGE_DURABILITY_LAG_SOFT_MAX = 10e6;
@ -689,6 +692,7 @@ void ServerKnobs::initialize(Randomize randomize, ClientKnobs* clientKnobs, IsSi
init( BW_FETCH_WORKERS_INTERVAL, 5.0 );
init( BW_RW_LOGGING_INTERVAL, 5.0 );
init( BW_MAX_BLOCKED_INTERVAL, 10.0 ); if(buggifySmallBWLag) BW_MAX_BLOCKED_INTERVAL = 2.0;
init( BW_RK_SIM_QUIESCE_DELAY, 150.0 );
init( MAX_AUTO_THROTTLED_TRANSACTION_TAGS, 5 ); if(randomize && BUGGIFY) MAX_AUTO_THROTTLED_TRANSACTION_TAGS = 1;
init( MAX_MANUAL_THROTTLED_TRANSACTION_TAGS, 40 ); if(randomize && BUGGIFY) MAX_MANUAL_THROTTLED_TRANSACTION_TAGS = 1;
@ -903,18 +907,18 @@ void ServerKnobs::initialize(Randomize randomize, ClientKnobs* clientKnobs, IsSi
init ( CLUSTER_RECOVERY_EVENT_NAME_PREFIX, "Master" );
// Encryption
init( ENABLE_ENCRYPTION, false ); if ( randomize && BUGGIFY ) { ENABLE_ENCRYPTION = deterministicRandom()->coinflip(); }
init( ENABLE_ENCRYPTION, false ); if ( randomize && BUGGIFY ) ENABLE_ENCRYPTION = !ENABLE_ENCRYPTION;
init( ENCRYPTION_MODE, "AES-256-CTR" );
init( SIM_KMS_MAX_KEYS, 4096 );
init( ENCRYPT_PROXY_MAX_DBG_TRACE_LENGTH, 100000 );
init( ENABLE_TLOG_ENCRYPTION, ENABLE_ENCRYPTION ); if ( randomize && BUGGIFY ) { ENABLE_TLOG_ENCRYPTION = (ENABLE_ENCRYPTION && !PROXY_USE_RESOLVER_PRIVATE_MUTATIONS && deterministicRandom()->coinflip()); }
init( ENABLE_BLOB_GRANULE_ENCRYPTION, ENABLE_ENCRYPTION ); if ( randomize && BUGGIFY ) { ENABLE_BLOB_GRANULE_ENCRYPTION = (ENABLE_ENCRYPTION && deterministicRandom()->coinflip()); }
init( ENABLE_TLOG_ENCRYPTION, ENABLE_ENCRYPTION ); if ( randomize && BUGGIFY && ENABLE_ENCRYPTION && !PROXY_USE_RESOLVER_PRIVATE_MUTATIONS ) ENABLE_TLOG_ENCRYPTION = true;
init( ENABLE_STORAGE_SERVER_ENCRYPTION, ENABLE_ENCRYPTION ); if ( randomize && BUGGIFY) ENABLE_STORAGE_SERVER_ENCRYPTION = !ENABLE_STORAGE_SERVER_ENCRYPTION;
init( ENABLE_BLOB_GRANULE_ENCRYPTION, ENABLE_ENCRYPTION ); if ( randomize && BUGGIFY) ENABLE_BLOB_GRANULE_ENCRYPTION = !ENABLE_BLOB_GRANULE_ENCRYPTION;
// encrypt key proxy
init( ENABLE_BLOB_GRANULE_COMPRESSION, false ); if ( randomize && BUGGIFY ) { ENABLE_BLOB_GRANULE_COMPRESSION = deterministicRandom()->coinflip(); }
init( BLOB_GRANULE_COMPRESSION_FILTER, "GZIP" ); if ( randomize && BUGGIFY ) { BLOB_GRANULE_COMPRESSION_FILTER = "NONE"; }
// KMS connector type
init( KMS_CONNECTOR_TYPE, "RESTKmsConnector" );

View File

@ -1332,7 +1332,7 @@ int64_t decodeBlobManagerEpochValue(ValueRef const& value) {
// blob granule data
const KeyRef blobRangeActive = LiteralStringRef("1");
const KeyRef blobRangeInactive = LiteralStringRef("0");
const KeyRef blobRangeInactive = StringRef();
const KeyRangeRef blobGranuleFileKeys(LiteralStringRef("\xff\x02/bgf/"), LiteralStringRef("\xff\x02/bgf0"));
const KeyRangeRef blobGranuleMappingKeys(LiteralStringRef("\xff\x02/bgm/"), LiteralStringRef("\xff\x02/bgm0"));

View File

@ -26,11 +26,11 @@
Key TenantMapEntry::idToPrefix(int64_t id) {
int64_t swapped = bigEndian64(id);
return StringRef(reinterpret_cast<const uint8_t*>(&swapped), 8);
return StringRef(reinterpret_cast<const uint8_t*>(&swapped), TENANT_PREFIX_SIZE);
}
int64_t TenantMapEntry::prefixToId(KeyRef prefix) {
ASSERT(prefix.size() == 8);
ASSERT(prefix.size() == TENANT_PREFIX_SIZE);
int64_t id = *reinterpret_cast<const int64_t*>(prefix.begin());
id = bigEndian64(id);
ASSERT(id >= 0);

View File

@ -400,34 +400,33 @@ ThreadResult<RangeResult> ThreadSafeTransaction::readBlobGranules(const KeyRange
Version beginVersion,
Optional<Version> readVersion,
ReadBlobGranuleContext granule_context) {
// FIXME: prevent from calling this from another main thread!
// This should not be called directly, bypassMultiversionApi should not be set
return ThreadResult<RangeResult>(unsupported_operation());
}
ThreadFuture<Standalone<VectorRef<BlobGranuleChunkRef>>> ThreadSafeTransaction::readBlobGranulesStart(
const KeyRangeRef& keyRange,
Version beginVersion,
Optional<Version> readVersion,
Version* readVersionOut) {
ISingleThreadTransaction* tr = this->tr;
KeyRange r = keyRange;
int64_t readVersionOut;
ThreadFuture<Standalone<VectorRef<BlobGranuleChunkRef>>> getFilesFuture = onMainThread(
[tr, r, beginVersion, readVersion, &readVersionOut]() -> Future<Standalone<VectorRef<BlobGranuleChunkRef>>> {
return onMainThread(
[tr, r, beginVersion, readVersion, readVersionOut]() -> Future<Standalone<VectorRef<BlobGranuleChunkRef>>> {
tr->checkDeferredError();
return tr->readBlobGranules(r, beginVersion, readVersion, &readVersionOut);
return tr->readBlobGranules(r, beginVersion, readVersion, readVersionOut);
});
// FIXME: can this safely avoid another main thread jump?
getFilesFuture.blockUntilReadyCheckOnMainThread();
// propagate error to client
if (getFilesFuture.isError()) {
return ThreadResult<RangeResult>(getFilesFuture.getError());
}
Standalone<VectorRef<BlobGranuleChunkRef>> files = getFilesFuture.get();
}
ThreadResult<RangeResult> ThreadSafeTransaction::readBlobGranulesFinish(
ThreadFuture<Standalone<VectorRef<BlobGranuleChunkRef>>> startFuture,
const KeyRangeRef& keyRange,
Version beginVersion,
Version readVersion,
ReadBlobGranuleContext granuleContext) {
// do this work off of fdb network threads for performance!
if (granule_context.debugNoMaterialize) {
return ThreadResult<RangeResult>(blob_granule_not_materialized());
} else {
return loadAndMaterializeBlobGranules(files, keyRange, beginVersion, readVersionOut, granule_context);
}
Standalone<VectorRef<BlobGranuleChunkRef>> files = startFuture.get();
return loadAndMaterializeBlobGranules(files, keyRange, beginVersion, readVersion, granuleContext);
}
void ThreadSafeTransaction::addReadConflictRange(const KeyRangeRef& keys) {

View File

@ -208,7 +208,7 @@ Tuple& Tuple::append(double value) {
return *this;
}
Tuple& Tuple::append(nullptr_t) {
Tuple& Tuple::append(std::nullptr_t) {
offsets.push_back(data.size());
data.push_back(data.arena(), (uint8_t)'\x00');
return *this;

View File

@ -1531,6 +1531,42 @@ struct StorageWiggleValue {
}
};
enum class ReadType {
EAGER,
FETCH,
LOW,
NORMAL,
HIGH,
};
FDB_DECLARE_BOOLEAN_PARAM(CacheResult);
// store options for storage engine read
// ReadType describes the usage and priority of the read
// cacheResult determines whether the storage engine should cache the result of this read
// consistencyCheckStartVersion indicates the version at which the consistency check began
// debugID helps to trace the path of the read
struct ReadOptions {
ReadType type;
// Once CacheResult is serializable, change type from bool to CacheResult
bool cacheResult;
Optional<UID> debugID;
Optional<Version> consistencyCheckStartVersion;
ReadOptions() : type(ReadType::NORMAL), cacheResult(CacheResult::True){};
ReadOptions(Optional<UID> debugID,
ReadType type = ReadType::NORMAL,
CacheResult cache = CacheResult::False,
Optional<Version> version = Optional<Version>())
: type(type), cacheResult(cache), debugID(debugID), consistencyCheckStartVersion(version){};
template <class Ar>
void serialize(Ar& ar) {
serializer(ar, type, cacheResult, debugID, consistencyCheckStartVersion);
}
};
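As a hedged sketch of how these options travel with a read (the request type mirrors the NativeAPI call sites changed above; treat the exact field wiring as illustrative):

```
// Sketch: tag a fetch-style read with a debug ID and mark it non-cacheable,
// then hand the options to the storage request instead of a bare debugID.
ReadOptions options(nondeterministicRandom()->randomUniqueID(), ReadType::FETCH, CacheResult::False);
GetKeyValuesRequest req;
req.options = options;
```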
// Can be used to identify types (e.g. IDatabase) that can be used to create transactions with a `createTransaction`
// function
template <typename, typename = void>

View File

@ -250,27 +250,28 @@ struct TextAndHeaderCipherKeys {
Reference<BlobCipherKey> cipherHeaderKey;
};
// Helper method to get latest cipher text key and cipher header key for system domain,
// used for encrypting system data.
ACTOR template <class T>
Future<TextAndHeaderCipherKeys> getLatestSystemEncryptCipherKeys(Reference<AsyncVar<T> const> db) {
static std::unordered_map<EncryptCipherDomainId, EncryptCipherDomainName> domains = {
{ SYSTEM_KEYSPACE_ENCRYPT_DOMAIN_ID, FDB_DEFAULT_ENCRYPT_DOMAIN_NAME },
{ ENCRYPT_HEADER_DOMAIN_ID, FDB_DEFAULT_ENCRYPT_DOMAIN_NAME }
};
Future<TextAndHeaderCipherKeys> getLatestEncryptCipherKeysForDomain(Reference<AsyncVar<T> const> db,
EncryptCipherDomainId domainId,
EncryptCipherDomainName domainName) {
std::unordered_map<EncryptCipherDomainId, EncryptCipherDomainName> domains;
domains[domainId] = domainName;
domains[ENCRYPT_HEADER_DOMAIN_ID] = FDB_DEFAULT_ENCRYPT_DOMAIN_NAME;
std::unordered_map<EncryptCipherDomainId, Reference<BlobCipherKey>> cipherKeys =
wait(getLatestEncryptCipherKeys(db, domains));
ASSERT(cipherKeys.count(SYSTEM_KEYSPACE_ENCRYPT_DOMAIN_ID) > 0);
ASSERT(cipherKeys.count(domainId) > 0);
ASSERT(cipherKeys.count(ENCRYPT_HEADER_DOMAIN_ID) > 0);
TextAndHeaderCipherKeys result{ cipherKeys.at(SYSTEM_KEYSPACE_ENCRYPT_DOMAIN_ID),
cipherKeys.at(ENCRYPT_HEADER_DOMAIN_ID) };
TextAndHeaderCipherKeys result{ cipherKeys.at(domainId), cipherKeys.at(ENCRYPT_HEADER_DOMAIN_ID) };
ASSERT(result.cipherTextKey.isValid());
ASSERT(result.cipherHeaderKey.isValid());
return result;
}
// Helper method to get both text cipher key and header cipher key for the given encryption header,
// used for decrypting given encrypted data with encryption header.
template <class T>
Future<TextAndHeaderCipherKeys> getLatestSystemEncryptCipherKeys(const Reference<AsyncVar<T> const>& db) {
return getLatestEncryptCipherKeysForDomain(db, SYSTEM_KEYSPACE_ENCRYPT_DOMAIN_ID, FDB_DEFAULT_ENCRYPT_DOMAIN_NAME);
}
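A sketch of how an actor might use the generalized helper for a non-system domain (the domain id/name arguments are placeholders supplied by the caller):
ACTOR template <class T>
Future<Void> encryptForDomainSketch(Reference<AsyncVar<T> const> db,
                                    EncryptCipherDomainId domainId,
                                    EncryptCipherDomainName domainName) {
	state TextAndHeaderCipherKeys keys = wait(getLatestEncryptCipherKeysForDomain(db, domainId, domainName));
	ASSERT(keys.cipherTextKey.isValid() && keys.cipherHeaderKey.isValid());
	// use keys.cipherTextKey for the payload and keys.cipherHeaderKey for the header
	return Void();
}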
ACTOR template <class T>
Future<TextAndHeaderCipherKeys> getEncryptCipherKeys(Reference<AsyncVar<T> const> db, BlobCipherEncryptHeader header) {
std::unordered_set<BlobCipherDetails> cipherDetails{ header.cipherTextDetails, header.cipherHeaderDetails };

View File

@ -22,6 +22,7 @@
#define FDBCLIENT_ICLIENTAPI_H
#pragma once
#include "fdbclient/BlobGranuleCommon.h"
#include "fdbclient/FDBOptions.g.h"
#include "fdbclient/FDBTypes.h"
#include "fdbclient/Tenant.h"
@ -86,6 +87,19 @@ public:
Optional<Version> readVersion,
ReadBlobGranuleContext granuleContext) = 0;
virtual ThreadFuture<Standalone<VectorRef<BlobGranuleChunkRef>>> readBlobGranulesStart(
const KeyRangeRef& keyRange,
Version beginVersion,
Optional<Version> readVersion,
Version* readVersionOut) = 0;
virtual ThreadResult<RangeResult> readBlobGranulesFinish(
ThreadFuture<Standalone<VectorRef<BlobGranuleChunkRef>>> startFuture,
const KeyRangeRef& keyRange,
Version beginVersion,
Version readVersion,
ReadBlobGranuleContext granuleContext) = 0;
virtual void atomicOp(const KeyRef& key, const ValueRef& value, uint32_t operationType) = 0;
virtual void set(const KeyRef& key, const ValueRef& value) = 0;
virtual void clear(const KeyRef& begin, const KeyRef& end) = 0;

View File

@ -136,6 +136,16 @@ Future<RangeResult> krmGetRanges(Reference<ReadYourWritesTransaction> const& tr,
KeyRange const& keys,
int const& limit = CLIENT_KNOBS->KRM_GET_RANGE_LIMIT,
int const& limitBytes = CLIENT_KNOBS->KRM_GET_RANGE_LIMIT_BYTES);
Future<RangeResult> krmGetRangesUnaligned(Transaction* const& tr,
Key const& mapPrefix,
KeyRange const& keys,
int const& limit = CLIENT_KNOBS->KRM_GET_RANGE_LIMIT,
int const& limitBytes = CLIENT_KNOBS->KRM_GET_RANGE_LIMIT_BYTES);
Future<RangeResult> krmGetRangesUnaligned(Reference<ReadYourWritesTransaction> const& tr,
Key const& mapPrefix,
KeyRange const& keys,
int const& limit = CLIENT_KNOBS->KRM_GET_RANGE_LIMIT,
int const& limitBytes = CLIENT_KNOBS->KRM_GET_RANGE_LIMIT_BYTES);
void krmSetPreviouslyEmptyRange(Transaction* tr,
const KeyRef& mapPrefix,
const KeyRangeRef& keys,
@ -162,7 +172,7 @@ Future<Void> krmSetRangeCoalescing(Reference<ReadYourWritesTransaction> const& t
KeyRange const& range,
KeyRange const& maxRange,
Value const& value);
RangeResult krmDecodeRanges(KeyRef mapPrefix, KeyRange keys, RangeResult kv);
RangeResult krmDecodeRanges(KeyRef mapPrefix, KeyRange keys, RangeResult kv, bool align = true);
template <class Val, class Metric, class MetricFunc>
std::vector<KeyRangeWith<Val>> KeyRangeMap<Val, Metric, MetricFunc>::getAffectedRangesAfterInsertion(

View File

@ -298,21 +298,39 @@ struct FdbCApi : public ThreadSafeReferenceCounted<FdbCApi> {
int end_key_name_length,
int64_t chunkSize);
FDBFuture* (*transactionGetBlobGranuleRanges)(FDBTransaction* db,
FDBFuture* (*transactionGetBlobGranuleRanges)(FDBTransaction* tr,
uint8_t const* begin_key_name,
int begin_key_name_length,
uint8_t const* end_key_name,
int end_key_name_length,
int rangeLimit);
FDBResult* (*transactionReadBlobGranules)(FDBTransaction* db,
FDBResult* (*transactionReadBlobGranules)(FDBTransaction* tr,
uint8_t const* begin_key_name,
int begin_key_name_length,
uint8_t const* end_key_name,
int end_key_name_length,
int64_t beginVersion,
int64_t readVersion,
FDBReadBlobGranuleContext granule_context);
int64_t readVersion);
FDBFuture* (*transactionReadBlobGranulesStart)(FDBTransaction* tr,
uint8_t const* begin_key_name,
int begin_key_name_length,
uint8_t const* end_key_name,
int end_key_name_length,
int64_t beginVersion,
int64_t readVersion,
int64_t* readVersionOut);
FDBResult* (*transactionReadBlobGranulesFinish)(FDBTransaction* tr,
FDBFuture* startFuture,
uint8_t const* begin_key_name,
int begin_key_name_length,
uint8_t const* end_key_name,
int end_key_name_length,
int64_t beginVersion,
int64_t readVersion,
FDBReadBlobGranuleContext* granule_context);
FDBFuture* (*transactionCommit)(FDBTransaction* tr);
fdb_error_t (*transactionGetCommittedVersion)(FDBTransaction* tr, int64_t* outVersion);
@ -411,6 +429,18 @@ public:
Optional<Version> readVersion,
ReadBlobGranuleContext granule_context) override;
ThreadFuture<Standalone<VectorRef<BlobGranuleChunkRef>>> readBlobGranulesStart(const KeyRangeRef& keyRange,
Version beginVersion,
Optional<Version> readVersion,
Version* readVersionOut) override;
ThreadResult<RangeResult> readBlobGranulesFinish(
ThreadFuture<Standalone<VectorRef<BlobGranuleChunkRef>>> startFuture,
const KeyRangeRef& keyRange,
Version beginVersion,
Version readVersion,
ReadBlobGranuleContext granuleContext) override;
void addReadConflictRange(const KeyRangeRef& keys) override;
void atomicOp(const KeyRef& key, const ValueRef& value, uint32_t operationType) override;
@ -616,6 +646,18 @@ public:
Optional<Version> readVersion,
ReadBlobGranuleContext granule_context) override;
ThreadFuture<Standalone<VectorRef<BlobGranuleChunkRef>>> readBlobGranulesStart(const KeyRangeRef& keyRange,
Version beginVersion,
Optional<Version> readVersion,
Version* readVersionOut) override;
ThreadResult<RangeResult> readBlobGranulesFinish(
ThreadFuture<Standalone<VectorRef<BlobGranuleChunkRef>>> startFuture,
const KeyRangeRef& keyRange,
Version beginVersion,
Version readVersion,
ReadBlobGranuleContext granuleContext) override;
void atomicOp(const KeyRef& key, const ValueRef& value, uint32_t operationType) override;
void set(const KeyRef& key, const ValueRef& value) override;
void clear(const KeyRef& begin, const KeyRef& end) override;
@ -681,6 +723,9 @@ private:
template <class T>
ThreadResult<T> abortableTimeoutResult(ThreadFuture<Void> abortSignal);
template <class T>
ThreadResult<T> abortableResult(ThreadResult<T> result, ThreadFuture<Void> abortSignal);
TransactionInfo transaction;
TransactionInfo getTransaction();

View File

@ -242,8 +242,8 @@ struct TransactionState : ReferenceCounted<TransactionState> {
Optional<Standalone<StringRef>> authToken;
Reference<TransactionLogInfo> trLogInfo;
TransactionOptions options;
Optional<ReadOptions> readOptions;
Optional<UID> debugID;
TaskPriority taskID;
SpanContext spanContext;
UseProvisionalProxies useProvisionalProxies = UseProvisionalProxies::False;
@ -459,7 +459,13 @@ public:
void fullReset();
double getBackoff(int errCode);
void debugTransaction(UID dID) { trState->debugID = dID; }
void debugTransaction(UID dID) {
if (trState->readOptions.present()) {
trState->readOptions.get().debugID = dID;
} else {
trState->readOptions = ReadOptions(dID);
}
}
VersionVector getVersionVector() const;
SpanContext getSpanContext() const { return trState->spanContext; }

View File

@ -340,6 +340,7 @@ public:
int64_t ROCKSDB_BLOCK_SIZE;
bool ENABLE_SHARDED_ROCKSDB;
int64_t ROCKSDB_WRITE_BUFFER_SIZE;
int64_t ROCKSDB_CF_WRITE_BUFFER_SIZE;
int64_t ROCKSDB_MAX_TOTAL_WAL_SIZE;
int64_t ROCKSDB_MAX_BACKGROUND_JOBS;
int64_t ROCKSDB_DELETE_OBSOLETE_FILE_PERIOD;
@ -567,6 +568,8 @@ public:
int64_t SPRING_BYTES_STORAGE_SERVER_BATCH;
int64_t STORAGE_HARD_LIMIT_BYTES;
int64_t STORAGE_HARD_LIMIT_BYTES_OVERAGE;
int64_t STORAGE_HARD_LIMIT_BYTES_SPEED_UP_SIM;
int64_t STORAGE_HARD_LIMIT_BYTES_OVERAGE_SPEED_UP_SIM;
int64_t STORAGE_HARD_LIMIT_VERSION_OVERAGE;
int64_t STORAGE_DURABILITY_LAG_HARD_MAX;
int64_t STORAGE_DURABILITY_LAG_SOFT_MAX;
@ -643,6 +646,7 @@ public:
double BW_FETCH_WORKERS_INTERVAL;
double BW_RW_LOGGING_INTERVAL;
double BW_MAX_BLOCKED_INTERVAL;
double BW_RK_SIM_QUIESCE_DELAY;
// disk snapshot
int64_t MAX_FORKED_PROCESS_OUTPUT;
@ -881,6 +885,7 @@ public:
int SIM_KMS_MAX_KEYS;
int ENCRYPT_PROXY_MAX_DBG_TRACE_LENGTH;
bool ENABLE_TLOG_ENCRYPTION;
bool ENABLE_STORAGE_SERVER_ENCRYPTION; // Currently only Redwood engine supports encryption
bool ENABLE_BLOB_GRANULE_ENCRYPTION;
// Compression

View File

@ -295,28 +295,28 @@ struct GetValueRequest : TimedRequest {
Key key;
Version version;
Optional<TagSet> tags;
Optional<UID> debugID;
ReplyPromise<GetValueReply> reply;
Optional<ReadOptions> options;
VersionVector ssLatestCommitVersions; // includes the latest commit versions, as known
// to this client, of all storage replicas that
// serve the given key
GetValueRequest() {}
bool verify() const { return tenantInfo.isAuthorized(); }
GetValueRequest() {}
GetValueRequest(SpanContext spanContext,
const TenantInfo& tenantInfo,
const Key& key,
Version ver,
Optional<TagSet> tags,
Optional<UID> debugID,
Optional<ReadOptions> options,
VersionVector latestCommitVersions)
: spanContext(spanContext), tenantInfo(tenantInfo), key(key), version(ver), tags(tags), debugID(debugID),
: spanContext(spanContext), tenantInfo(tenantInfo), key(key), version(ver), tags(tags), options(options),
ssLatestCommitVersions(latestCommitVersions) {}
template <class Ar>
void serialize(Ar& ar) {
serializer(ar, key, version, tags, debugID, reply, spanContext, tenantInfo, ssLatestCommitVersions);
serializer(ar, key, version, tags, reply, spanContext, tenantInfo, options, ssLatestCommitVersions);
}
};
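The practical effect on callers, sketched with hypothetical identifiers: the per-request debug ID now travels inside an Optional<ReadOptions> instead of its own field.
GetValueRequest req(spanContext,
                    tenantInfo,
                    key,
                    readVersion,
                    tags,
                    ReadOptions(debugID), // debugID is now carried by ReadOptions
                    ssLatestCommitVersions);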
@ -392,15 +392,14 @@ struct GetKeyValuesRequest : TimedRequest {
KeyRef mapper = KeyRef();
Version version; // or latestVersion
int limit, limitBytes;
bool isFetchKeys;
Optional<TagSet> tags;
Optional<UID> debugID;
Optional<ReadOptions> options;
ReplyPromise<GetKeyValuesReply> reply;
VersionVector ssLatestCommitVersions; // includes the latest commit versions, as known
// to this client, of all storage replicas that
// serve the given key
GetKeyValuesRequest() : isFetchKeys(false) {}
GetKeyValuesRequest() {}
bool verify() const { return tenantInfo.isAuthorized(); }
@ -412,12 +411,11 @@ struct GetKeyValuesRequest : TimedRequest {
version,
limit,
limitBytes,
isFetchKeys,
tags,
debugID,
reply,
spanContext,
tenantInfo,
options,
arena,
ssLatestCommitVersions);
}
@ -451,15 +449,14 @@ struct GetMappedKeyValuesRequest : TimedRequest {
Version version; // or latestVersion
int limit, limitBytes;
int matchIndex;
bool isFetchKeys;
Optional<TagSet> tags;
Optional<UID> debugID;
Optional<ReadOptions> options;
ReplyPromise<GetMappedKeyValuesReply> reply;
VersionVector ssLatestCommitVersions; // includes the latest commit versions, as known
// to this client, of all storage replicas that
// serve the given key range
GetMappedKeyValuesRequest() : isFetchKeys(false) {}
GetMappedKeyValuesRequest() {}
bool verify() const { return tenantInfo.isAuthorized(); }
@ -472,12 +469,11 @@ struct GetMappedKeyValuesRequest : TimedRequest {
version,
limit,
limitBytes,
isFetchKeys,
tags,
debugID,
reply,
spanContext,
tenantInfo,
options,
arena,
ssLatestCommitVersions,
matchIndex);
@ -519,15 +515,14 @@ struct GetKeyValuesStreamRequest {
KeySelectorRef begin, end;
Version version; // or latestVersion
int limit, limitBytes;
bool isFetchKeys;
Optional<TagSet> tags;
Optional<UID> debugID;
Optional<ReadOptions> options;
ReplyPromiseStream<GetKeyValuesStreamReply> reply;
VersionVector ssLatestCommitVersions; // includes the latest commit versions, as known
// to this client, of all storage replicas that
// serve the given key range
GetKeyValuesStreamRequest() : isFetchKeys(false) {}
GetKeyValuesStreamRequest() {}
bool verify() const { return tenantInfo.isAuthorized(); }
@ -539,12 +534,11 @@ struct GetKeyValuesStreamRequest {
version,
limit,
limitBytes,
isFetchKeys,
tags,
debugID,
reply,
spanContext,
tenantInfo,
options,
arena,
ssLatestCommitVersions);
}
@ -572,29 +566,29 @@ struct GetKeyRequest : TimedRequest {
KeySelectorRef sel;
Version version; // or latestVersion
Optional<TagSet> tags;
Optional<UID> debugID;
ReplyPromise<GetKeyReply> reply;
Optional<ReadOptions> options;
VersionVector ssLatestCommitVersions; // includes the latest commit versions, as known
// to this client, of all storage replicas that
// serve the given key
bool verify() const { return tenantInfo.isAuthorized(); }
GetKeyRequest() {}
bool verify() const { return tenantInfo.isAuthorized(); }
GetKeyRequest(SpanContext spanContext,
TenantInfo tenantInfo,
KeySelectorRef const& sel,
Version version,
Optional<TagSet> tags,
Optional<UID> debugID,
Optional<ReadOptions> options,
VersionVector latestCommitVersions)
: spanContext(spanContext), tenantInfo(tenantInfo), sel(sel), version(version), debugID(debugID),
: spanContext(spanContext), tenantInfo(tenantInfo), sel(sel), version(version), tags(tags), options(options),
ssLatestCommitVersions(latestCommitVersions) {}
template <class Ar>
void serialize(Ar& ar) {
serializer(ar, sel, version, tags, debugID, reply, spanContext, tenantInfo, arena, ssLatestCommitVersions);
serializer(ar, sel, version, tags, reply, spanContext, tenantInfo, options, arena, ssLatestCommitVersions);
}
};

View File

@ -62,6 +62,8 @@ enum class TenantState { REGISTERING, READY, REMOVING, UPDATING_CONFIGURATION, R
// Can be used in conjunction with the other tenant states above.
enum class TenantLockState { UNLOCKED, READ_ONLY, LOCKED };
constexpr int TENANT_PREFIX_SIZE = sizeof(int64_t);
struct TenantMapEntry {
constexpr static FileIdentifier file_identifier = 12247338;
@ -201,6 +203,6 @@ struct TenantMetadata {
};
typedef VersionedMap<TenantName, TenantMapEntry> TenantMap;
typedef VersionedMap<Key, TenantName> TenantPrefixIndex;
class TenantPrefixIndex : public VersionedMap<Key, TenantName>, public ReferenceCounted<TenantPrefixIndex> {};
#endif

View File

@ -164,6 +164,18 @@ public:
Optional<Version> readVersion,
ReadBlobGranuleContext granuleContext) override;
ThreadFuture<Standalone<VectorRef<BlobGranuleChunkRef>>> readBlobGranulesStart(const KeyRangeRef& keyRange,
Version beginVersion,
Optional<Version> readVersion,
Version* readVersionOut) override;
ThreadResult<RangeResult> readBlobGranulesFinish(
ThreadFuture<Standalone<VectorRef<BlobGranuleChunkRef>>> startFuture,
const KeyRangeRef& keyRange,
Version beginVersion,
Version readVersion,
ReadBlobGranuleContext granuleContext) override;
void addReadConflictRange(const KeyRangeRef& keys) override;
void makeSelfConflicting();

View File

@ -306,9 +306,8 @@ description is not currently required but encouraged.
description="Specifically instruct this transaction to NOT use cached GRV. Primarily used for the read version cache's background updater to avoid attempting to read a cached entry in specific situations."
hidden="true"/>
<Option name="authorization_token" code="2000"
description="Add a given authorization token to the network thread so that future requests are authorized"
paramType="String" paramDescription="A signed token serialized using flatbuffers"
hidden="true" />
description="Attach given authorization token to the transaction such that subsequent tenant-aware requests are authorized"
paramType="String" paramDescription="A JSON Web Token authorized to access data belonging to one or more tenants, indicated by 'tenants' claim of the token's payload."/>
</Scope>
<!-- The enumeration values matter - do not change them without

View File

@ -80,3 +80,5 @@ target_compile_definitions(fdbrpc_sampling PRIVATE -DENABLE_SAMPLING)
if(WIN32)
add_dependencies(fdbrpc_sampling_actors fdbrpc_actors)
endif()
add_subdirectory(tests)

View File

@ -34,6 +34,7 @@
#include "fdbrpc/fdbrpc.h"
#include "fdbrpc/FailureMonitor.h"
#include "fdbrpc/HealthMonitor.h"
#include "fdbrpc/JsonWebKeySet.h"
#include "fdbrpc/genericactors.actor.h"
#include "fdbrpc/IPAllowList.h"
#include "fdbrpc/TokenCache.h"
@ -44,8 +45,10 @@
#include "flow/Net2Packet.h"
#include "flow/TDMetric.actor.h"
#include "flow/ObjectSerializer.h"
#include "flow/Platform.h"
#include "flow/ProtocolVersion.h"
#include "flow/UnitTest.h"
#include "flow/WatchFile.actor.h"
#define XXH_INLINE_ALL
#include "flow/xxhash.h"
#include "flow/actorcompiler.h" // This must be the last #include.
@ -309,6 +312,7 @@ public:
// Returns true if given network address 'address' is one of the addresses we are listening on.
bool isLocalAddress(const NetworkAddress& address) const;
void applyPublicKeySet(StringRef jwkSetString);
NetworkAddressCachedString localAddresses;
std::vector<Future<Void>> listeners;
@ -341,6 +345,7 @@ public:
Future<Void> multiVersionCleanup;
Future<Void> pingLogger;
Future<Void> publicKeyFileWatch;
std::unordered_map<Standalone<StringRef>, PublicKey> publicKeys;
};
@ -958,7 +963,7 @@ void Peer::onIncomingConnection(Reference<Peer> self, Reference<IConnection> con
.detail("FromAddr", conn->getPeerAddress())
.detail("CanonicalAddr", destination)
.detail("IsPublic", destination.isPublic())
.detail("Trusted", self->transport->allowList(conn->getPeerAddress().ip));
.detail("Trusted", self->transport->allowList(conn->getPeerAddress().ip) && conn->hasTrustedPeer());
connect.cancel();
prependConnectPacket();
@ -1257,7 +1262,7 @@ ACTOR static Future<Void> connectionReader(TransportData* transport,
state bool incompatiblePeerCounted = false;
state NetworkAddress peerAddress;
state ProtocolVersion peerProtocolVersion;
state bool trusted = transport->allowList(conn->getPeerAddress().ip);
state bool trusted = transport->allowList(conn->getPeerAddress().ip) && conn->hasTrustedPeer();
peerAddress = conn->getPeerAddress();
if (!peer) {
@ -1529,6 +1534,27 @@ bool TransportData::isLocalAddress(const NetworkAddress& address) const {
address == localAddresses.getAddressList().secondaryAddress.get());
}
void TransportData::applyPublicKeySet(StringRef jwkSetString) {
auto jwks = JsonWebKeySet::parse(jwkSetString, {});
if (!jwks.present())
throw pkey_decode_error();
const auto& keySet = jwks.get().keys;
publicKeys.clear();
int numPrivateKeys = 0;
for (auto [keyName, key] : keySet) {
// ignore private keys
if (key.isPublic()) {
publicKeys[keyName] = key.getPublic();
} else {
numPrivateKeys++;
}
}
TraceEvent(SevInfo, "AuthzPublicKeySetApply").detail("NumPublicKeys", publicKeys.size());
if (numPrivateKeys > 0) {
TraceEvent(SevWarnAlways, "AuthzPublicKeySetContainsPrivateKeys").detail("NumPrivateKeys", numPrivateKeys);
}
}
ACTOR static Future<Void> multiVersionCleanupWorker(TransportData* self) {
loop {
wait(delay(FLOW_KNOBS->CONNECTION_CLEANUP_DELAY));
@ -1967,3 +1993,62 @@ void FlowTransport::removePublicKey(StringRef name) {
void FlowTransport::removeAllPublicKeys() {
self->publicKeys.clear();
}
void FlowTransport::loadPublicKeyFile(const std::string& filePath) {
if (!fileExists(filePath)) {
throw file_not_found();
}
int64_t const len = fileSize(filePath);
if (len <= 0) {
TraceEvent(SevWarn, "AuthzPublicKeySetEmpty").detail("Path", filePath);
} else if (len > FLOW_KNOBS->PUBLIC_KEY_FILE_MAX_SIZE) {
throw file_too_large();
} else {
auto json = readFileBytes(filePath, len);
self->applyPublicKeySet(StringRef(json));
}
}
ACTOR static Future<Void> watchPublicKeyJwksFile(std::string filePath, TransportData* self) {
state AsyncTrigger fileChanged;
state Future<Void> fileWatch;
state unsigned errorCount = 0; // error since watch start or last successful refresh
// Make sure this watch setup does not break due to async file system initialization not having been called
loop {
if (IAsyncFileSystem::filesystem())
break;
wait(delay(1.0));
}
const int& intervalSeconds = FLOW_KNOBS->PUBLIC_KEY_FILE_REFRESH_INTERVAL_SECONDS;
fileWatch = watchFileForChanges(filePath, &fileChanged, &intervalSeconds, "AuthzPublicKeySetRefreshStatError");
loop {
try {
wait(fileChanged.onTrigger());
state Reference<IAsyncFile> file = wait(IAsyncFileSystem::filesystem()->open(
filePath, IAsyncFile::OPEN_READONLY | IAsyncFile::OPEN_UNCACHED, 0));
state int64_t filesize = wait(file->size());
state std::string json(filesize, '\0');
if (filesize > FLOW_KNOBS->PUBLIC_KEY_FILE_MAX_SIZE)
throw file_too_large();
if (filesize <= 0) {
TraceEvent(SevWarn, "AuthzPublicKeySetEmpty").suppressFor(60);
continue;
}
wait(success(file->read(&json[0], filesize, 0)));
self->applyPublicKeySet(StringRef(json));
errorCount = 0;
} catch (Error& e) {
if (e.code() == error_code_actor_cancelled) {
throw;
}
// parse/read error
errorCount++;
TraceEvent(SevWarn, "AuthzPublicKeySetRefreshError").error(e).detail("ErrorCount", errorCount);
}
}
}
void FlowTransport::watchPublicKeyFile(const std::string& publicKeyFilePath) {
self->publicKeyFileWatch = watchPublicKeyJwksFile(publicKeyFilePath, self);
}

View File

@ -830,11 +830,18 @@ TEST_CASE("/fdbrpc/JsonWebKeySet/EC/PrivateKey") {
}
TEST_CASE("/fdbrpc/JsonWebKeySet/RSA/PublicKey") {
testPublicKey(&mkcert::makeRsa2048Bit);
testPublicKey(&mkcert::makeRsa4096Bit);
return Void();
}
TEST_CASE("/fdbrpc/JsonWebKeySet/RSA/PrivateKey") {
testPrivateKey(&mkcert::makeRsa2048Bit);
testPrivateKey(&mkcert::makeRsa4096Bit);
return Void();
}
TEST_CASE("/fdbrpc/JsonWebKeySet/Empty") {
auto keyset = JsonWebKeySet::parse("{\"keys\":[]}"_sr, {});
ASSERT(keyset.present());
ASSERT(keyset.get().keys.empty());
return Void();
}

View File

@ -125,6 +125,10 @@ NetworkAddress SimExternalConnection::getPeerAddress() const {
}
}
bool SimExternalConnection::hasTrustedPeer() const {
return true;
}
UID SimExternalConnection::getDebugID() const {
return dbgid;
}

View File

@ -177,6 +177,10 @@ bool TokenCacheImpl::validateAndAdd(double currentTime, StringRef token, Network
CODE_PROBE(true, "Token referencing non-existing key");
TRACE_INVALID_PARSED_TOKEN("UnknownKey", t);
return false;
} else if (!t.issuedAtUnixTime.present()) {
CODE_PROBE(true, "Token has no issued-at field");
TRACE_INVALID_PARSED_TOKEN("NoIssuedAt", t);
return false;
} else if (!t.expiresAtUnixTime.present()) {
CODE_PROBE(true, "Token has no expiration time");
TRACE_INVALID_PARSED_TOKEN("NoExpirationTime", t);
@ -203,7 +207,7 @@ bool TokenCacheImpl::validateAndAdd(double currentTime, StringRef token, Network
return false;
} else {
CacheEntry c;
c.expirationTime = double(t.expiresAtUnixTime.get());
c.expirationTime = t.expiresAtUnixTime.get();
c.tenants.reserve(c.arena, t.tenants.get().size());
for (auto tenant : t.tenants.get()) {
c.tenants.push_back_deep(c.arena, tenant);
@ -265,7 +269,7 @@ TEST_CASE("/fdbrpc/authz/TokenCache/BadTokens") {
},
{
[](Arena&, IRandom& rng, authz::jwt::TokenRef& token) {
token.expiresAtUnixTime = uint64_t(std::max<double>(g_network->timer() - 10 - rng.random01() * 50, 0));
token.expiresAtUnixTime = std::max<double>(g_network->timer() - 10 - rng.random01() * 50, 0);
},
"ExpiredToken",
},
@ -275,10 +279,15 @@ TEST_CASE("/fdbrpc/authz/TokenCache/BadTokens") {
},
{
[](Arena&, IRandom& rng, authz::jwt::TokenRef& token) {
token.notBeforeUnixTime = uint64_t(g_network->timer() + 10 + rng.random01() * 50);
token.notBeforeUnixTime = g_network->timer() + 10 + rng.random01() * 50;
},
"TokenNotYetValid",
},
{
[](Arena&, IRandom&, authz::jwt::TokenRef& token) { token.issuedAtUnixTime.reset(); },
"NoIssuedAt",
},
{
[](Arena& arena, IRandom&, authz::jwt::TokenRef& token) { token.tenants.reset(); },
"NoTenants",
@ -336,7 +345,7 @@ TEST_CASE("/fdbrpc/authz/TokenCache/GoodTokens") {
authz::jwt::makeRandomTokenSpec(arena, *deterministicRandom(), authz::Algorithm::ES256);
state StringRef signedToken;
FlowTransport::transport().addPublicKey(pubKeyName, privateKey.toPublic());
tokenSpec.expiresAtUnixTime = static_cast<uint64_t>(g_network->timer() + 2.0);
tokenSpec.expiresAtUnixTime = g_network->timer() + 2.0;
tokenSpec.keyId = pubKeyName;
signedToken = authz::jwt::signToken(arena, tokenSpec, privateKey);
if (!TokenCache::instance().validate(tokenSpec.tenants.get()[0], signedToken)) {

View File

@ -22,6 +22,7 @@
#include "flow/network.h"
#include "flow/serialize.h"
#include "flow/Arena.h"
#include "flow/AutoCPointer.h"
#include "flow/Error.h"
#include "flow/IRandom.h"
#include "flow/MkCert.h"
@ -30,6 +31,7 @@
#include "flow/Trace.h"
#include "flow/UnitTest.h"
#include <fmt/format.h>
#include <cmath>
#include <iterator>
#include <string_view>
#include <type_traits>
@ -87,6 +89,51 @@ bool checkSignAlgorithm(PKeyAlgorithm algo, PrivateKey key) {
}
}
Optional<StringRef> convertEs256P1363ToDer(Arena& arena, StringRef p1363) {
const int SIGLEN = p1363.size();
const int HALF_SIGLEN = SIGLEN / 2;
auto r = AutoCPointer(BN_bin2bn(p1363.begin(), HALF_SIGLEN, nullptr), &::BN_free);
auto s = AutoCPointer(BN_bin2bn(p1363.begin() + HALF_SIGLEN, HALF_SIGLEN, nullptr), &::BN_free);
if (!r || !s)
return {};
auto sig = AutoCPointer(::ECDSA_SIG_new(), &ECDSA_SIG_free);
if (!sig)
return {};
::ECDSA_SIG_set0(sig, r.release(), s.release());
auto const derLen = ::i2d_ECDSA_SIG(sig, nullptr);
if (derLen < 0)
return {};
auto buf = new (arena) uint8_t[derLen];
auto bufPtr = buf;
::i2d_ECDSA_SIG(sig, &bufPtr);
return StringRef(buf, derLen);
}
Optional<StringRef> convertEs256DerToP1363(Arena& arena, StringRef der) {
uint8_t const* derPtr = der.begin();
auto sig = AutoCPointer(::d2i_ECDSA_SIG(nullptr, &derPtr, der.size()), &::ECDSA_SIG_free);
if (!sig) {
return {};
}
// ES256-specific constant. Adapt as needed
constexpr const int SIGLEN = 64;
constexpr const int HALF_SIGLEN = SIGLEN / 2;
auto buf = new (arena) uint8_t[SIGLEN];
::memset(buf, 0, SIGLEN);
auto bufr = buf;
auto bufs = bufr + HALF_SIGLEN;
auto r = std::add_pointer_t<BIGNUM const>();
auto s = std::add_pointer_t<BIGNUM const>();
ECDSA_SIG_get0(sig, &r, &s);
auto const lenr = BN_num_bytes(r);
auto const lens = BN_num_bytes(s);
if (lenr > HALF_SIGLEN || lens > HALF_SIGLEN)
return {};
BN_bn2bin(r, bufr + (HALF_SIGLEN - lenr));
BN_bn2bin(s, bufs + (HALF_SIGLEN - lens));
return StringRef(buf, SIGLEN);
}
} // namespace
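For context: ES256 JWT signatures use the fixed-width IEEE P1363 encoding (32-byte r concatenated with 32-byte s), while OpenSSL produces and consumes the ASN.1/DER SEQUENCE encoding, so the two helpers above translate between them. A minimal round-trip sketch, assuming it lives in this translation unit where the helpers are visible:
static void es256SigRoundTripSketch(Arena& arena, StringRef p1363Sig /* 64 bytes: r || s */) {
	Optional<StringRef> der = convertEs256P1363ToDer(arena, p1363Sig); // SEQUENCE { INTEGER r, INTEGER s }
	ASSERT(der.present());
	Optional<StringRef> back = convertEs256DerToP1363(arena, der.get());
	ASSERT(back.present() && back.get() == p1363Sig);
}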
namespace authz {
@ -130,11 +177,7 @@ SignedTokenRef signToken(Arena& arena, TokenRef token, StringRef keyName, Privat
auto writer = ObjectWriter([&arena](size_t len) { return new (arena) uint8_t[len]; }, IncludeVersion());
writer.serialize(token);
auto tokenStr = writer.toStringRef();
auto [signAlgo, digest] = getMethod(Algorithm::ES256);
if (!checkSignAlgorithm(signAlgo, privateKey)) {
throw digital_signature_ops_error();
}
auto sig = privateKey.sign(arena, tokenStr, *digest);
auto sig = privateKey.sign(arena, tokenStr, *::EVP_sha256());
ret.token = tokenStr;
ret.signature = sig;
ret.keyName = StringRef(arena, keyName);
@ -142,10 +185,7 @@ SignedTokenRef signToken(Arena& arena, TokenRef token, StringRef keyName, Privat
}
bool verifyToken(SignedTokenRef signedToken, PublicKey publicKey) {
auto [keyAlg, digest] = getMethod(Algorithm::ES256);
if (!checkVerifyAlgorithm(keyAlg, publicKey))
return false;
return publicKey.verify(signedToken.token, signedToken.signature, *digest);
return publicKey.verify(signedToken.token, signedToken.signature, *::EVP_sha256());
}
TokenRef makeRandomTokenSpec(Arena& arena, IRandom& rng) {
@ -268,6 +308,17 @@ StringRef signToken(Arena& arena, TokenRef tokenSpec, PrivateKey privateKey) {
throw digital_signature_ops_error();
}
auto plainSig = privateKey.sign(tmpArena, tokenPart, *digest);
if (tokenSpec.algorithm == Algorithm::ES256) {
// Need to convert ASN.1/DER signature to IEEE-P1363
auto convertedSig = convertEs256DerToP1363(tmpArena, plainSig);
if (!convertedSig.present()) {
auto tmpArena = Arena();
TraceEvent(SevWarn, "TokenSigConversionFailure")
.detail("TokenSpec", tokenSpec.toStringRef(tmpArena).toString());
throw digital_signature_ops_error();
}
plainSig = convertedSig.get();
}
auto const sigPartLen = base64url::encodedLength(plainSig.size());
auto const totalLen = tokenPart.size() + 1 + sigPartLen;
auto out = new (arena) uint8_t[totalLen];
@ -335,9 +386,9 @@ bool parseField(Arena& arena, Optional<FieldType>& out, const rapidjson::Documen
return false;
out = StringRef(arena, reinterpret_cast<const uint8_t*>(field.GetString()), field.GetStringLength());
} else if constexpr (std::is_same_v<FieldType, uint64_t>) {
if (!field.IsUint64())
if (!field.IsNumber())
return false;
out = field.GetUint64();
out = static_cast<uint64_t>(field.GetDouble());
} else {
if (!field.IsArray())
return false;
@ -442,13 +493,17 @@ bool verifyToken(StringRef signedToken, PublicKey publicKey) {
auto [verifyAlgo, digest] = getMethod(parsedToken.algorithm);
if (!checkVerifyAlgorithm(verifyAlgo, publicKey))
return false;
if (parsedToken.algorithm == Algorithm::ES256) {
// Need to convert IEEE-P1363 signature to ASN.1/DER
auto convertedSig = convertEs256P1363ToDer(arena, sig);
if (!convertedSig.present())
return false;
sig = convertedSig.get();
}
return publicKey.verify(b64urlTokenPart, sig, *digest);
}
TokenRef makeRandomTokenSpec(Arena& arena, IRandom& rng, Algorithm alg) {
if (alg != Algorithm::ES256) {
throw unsupported_operation();
}
auto ret = TokenRef{};
ret.algorithm = alg;
ret.keyId = genRandomAlphanumStringRef(arena, rng, MaxKeyNameLenPlus1);
@ -460,7 +515,7 @@ TokenRef makeRandomTokenSpec(Arena& arena, IRandom& rng, Algorithm alg) {
for (auto i = 0; i < numAudience; i++)
aud[i] = genRandomAlphanumStringRef(arena, rng, MaxTenantNameLenPlus1);
ret.audience = VectorRef<StringRef>(aud, numAudience);
ret.issuedAtUnixTime = uint64_t(std::floor(g_network->timer()));
ret.issuedAtUnixTime = g_network->timer();
ret.notBeforeUnixTime = ret.issuedAtUnixTime.get();
ret.expiresAtUnixTime = ret.issuedAtUnixTime.get() + rng.randomInt(360, 1080 + 1);
auto numTenants = rng.randomInt(1, 3);
@ -569,51 +624,68 @@ TEST_CASE("/fdbrpc/TokenSign/JWT/ToStringRef") {
}
TEST_CASE("/fdbrpc/TokenSign/bench") {
constexpr auto repeat = 5;
constexpr auto numSamples = 10000;
auto keys = std::vector<PrivateKey>(numSamples);
auto pubKeys = std::vector<PublicKey>(numSamples);
for (auto i = 0; i < numSamples; i++) {
keys[i] = mkcert::makeEcP256();
pubKeys[i] = keys[i].toPublic();
}
fmt::print("{} keys generated\n", numSamples);
auto& rng = *deterministicRandom();
auto arena = Arena();
auto jwts = new (arena) StringRef[numSamples];
auto fbs = new (arena) StringRef[numSamples];
{
auto tmpArena = Arena();
auto keyTypes = std::array<StringRef, 2>{ "EC"_sr, "RSA"_sr };
for (auto kty : keyTypes) {
constexpr auto repeat = 5;
constexpr auto numSamples = 10000;
fmt::print("=== {} keys case\n", kty.toString());
auto key = kty == "EC"_sr ? mkcert::makeEcP256() : mkcert::makeRsa4096Bit();
auto pubKey = key.toPublic();
auto& rng = *deterministicRandom();
auto arena = Arena();
auto jwtSpecs = new (arena) authz::jwt::TokenRef[numSamples];
auto fbSpecs = new (arena) authz::flatbuffers::TokenRef[numSamples];
auto jwts = new (arena) StringRef[numSamples];
auto fbs = new (arena) StringRef[numSamples];
for (auto i = 0; i < numSamples; i++) {
auto jwtSpec = authz::jwt::makeRandomTokenSpec(tmpArena, rng, authz::Algorithm::ES256);
jwts[i] = authz::jwt::signToken(arena, jwtSpec, keys[i]);
auto fbSpec = authz::flatbuffers::makeRandomTokenSpec(tmpArena, rng);
auto fbToken = authz::flatbuffers::signToken(tmpArena, fbSpec, "defaultKey"_sr, keys[i]);
auto wr = ObjectWriter([&arena](size_t len) { return new (arena) uint8_t[len]; }, Unversioned());
wr.serialize(fbToken);
fbs[i] = wr.toStringRef();
jwtSpecs[i] = authz::jwt::makeRandomTokenSpec(
arena, rng, kty == "EC"_sr ? authz::Algorithm::ES256 : authz::Algorithm::RS256);
fbSpecs[i] = authz::flatbuffers::makeRandomTokenSpec(arena, rng);
}
{
auto const jwtSignBegin = timer_monotonic();
for (auto i = 0; i < numSamples; i++) {
jwts[i] = authz::jwt::signToken(arena, jwtSpecs[i], key);
}
auto const jwtSignEnd = timer_monotonic();
fmt::print("JWT Sign : {:.2f} OPS\n", numSamples / (jwtSignEnd - jwtSignBegin));
}
{
auto const jwtVerifyBegin = timer_monotonic();
for (auto rep = 0; rep < repeat; rep++) {
for (auto i = 0; i < numSamples; i++) {
auto verifyOk = authz::jwt::verifyToken(jwts[i], pubKey);
ASSERT(verifyOk);
}
}
auto const jwtVerifyEnd = timer_monotonic();
fmt::print("JWT Verify : {:.2f} OPS\n", repeat * numSamples / (jwtVerifyEnd - jwtVerifyBegin));
}
{
auto tmpArena = Arena();
auto const fbSignBegin = timer_monotonic();
for (auto i = 0; i < numSamples; i++) {
auto fbToken = authz::flatbuffers::signToken(tmpArena, fbSpecs[i], "defaultKey"_sr, key);
auto wr = ObjectWriter([&arena](size_t len) { return new (arena) uint8_t[len]; }, Unversioned());
wr.serialize(fbToken);
fbs[i] = wr.toStringRef();
}
auto const fbSignEnd = timer_monotonic();
fmt::print("FlatBuffers Sign : {:.2f} OPS\n", numSamples / (fbSignEnd - fbSignBegin));
}
{
auto const fbVerifyBegin = timer_monotonic();
for (auto rep = 0; rep < repeat; rep++) {
for (auto i = 0; i < numSamples; i++) {
auto signedToken = ObjectReader::fromStringRef<Standalone<authz::flatbuffers::SignedTokenRef>>(
fbs[i], Unversioned());
auto verifyOk = authz::flatbuffers::verifyToken(signedToken, pubKey);
ASSERT(verifyOk);
}
}
auto const fbVerifyEnd = timer_monotonic();
fmt::print("FlatBuffers Verify : {:.2f} OPS\n", repeat * numSamples / (fbVerifyEnd - fbVerifyBegin));
}
}
fmt::print("{} FB/JWT tokens generated\n", numSamples);
auto jwtBegin = timer_monotonic();
for (auto rep = 0; rep < repeat; rep++) {
for (auto i = 0; i < numSamples; i++) {
auto verifyOk = authz::jwt::verifyToken(jwts[i], pubKeys[i]);
ASSERT(verifyOk);
}
}
auto jwtEnd = timer_monotonic();
fmt::print("JWT: {:.2f} OPS\n", repeat * numSamples / (jwtEnd - jwtBegin));
auto fbBegin = timer_monotonic();
for (auto rep = 0; rep < repeat; rep++) {
for (auto i = 0; i < numSamples; i++) {
auto signedToken =
ObjectReader::fromStringRef<Standalone<authz::flatbuffers::SignedTokenRef>>(fbs[i], Unversioned());
auto verifyOk = authz::flatbuffers::verifyToken(signedToken, pubKeys[i]);
ASSERT(verifyOk);
}
}
auto fbEnd = timer_monotonic();
fmt::print("FlatBuffers: {:.2f} OPS\n", repeat * numSamples / (fbEnd - fbBegin));
return Void();
}

View File

@ -298,6 +298,12 @@ public:
void removePublicKey(StringRef name);
void removeAllPublicKeys();
// Synchronously load and apply JWKS (RFC 7517) public key file with which to verify authorization tokens.
void loadPublicKeyFile(const std::string& publicKeyFilePath);
// Periodically read JWKS (RFC 7517) public key file to refresh public key set.
void watchPublicKeyFile(const std::string& publicKeyFilePath);
private:
class TransportData* self;
};
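A sketch of how a process might wire these up at startup (the file path below is a placeholder, not a documented default):
// Load the JWKS file once at boot, then keep refreshing it in the background.
FlowTransport::transport().loadPublicKeyFile("/etc/foundationdb/authz-public-keys.jwks");
FlowTransport::transport().watchPublicKeyFile("/etc/foundationdb/authz-public-keys.jwks");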

View File

@ -47,6 +47,7 @@ public:
int read(uint8_t* begin, uint8_t* end) override;
int write(SendBuffer const* buffer, int limit) override;
NetworkAddress getPeerAddress() const override;
bool hasTrustedPeer() const override;
UID getDebugID() const override;
boost::asio::ip::tcp::socket& getSocket() override { return socket; }
static Future<std::vector<NetworkAddress>> resolveTCPEndpoint(const std::string& host,

View File

@ -208,7 +208,7 @@ SimClogging g_clogging;
struct Sim2Conn final : IConnection, ReferenceCounted<Sim2Conn> {
Sim2Conn(ISimulator::ProcessInfo* process)
: opened(false), closedByCaller(false), stableConnection(false), process(process),
: opened(false), closedByCaller(false), stableConnection(false), trustedPeer(true), process(process),
dbgid(deterministicRandom()->randomUniqueID()), stopReceive(Never()) {
pipes = sender(this) && receiver(this);
}
@ -259,6 +259,8 @@ struct Sim2Conn final : IConnection, ReferenceCounted<Sim2Conn> {
bool isPeerGone() const { return !peer || peerProcess->failed; }
bool hasTrustedPeer() const override { return trustedPeer; }
bool isStableConnection() const override { return stableConnection; }
void peerClosed() {
@ -327,7 +329,7 @@ struct Sim2Conn final : IConnection, ReferenceCounted<Sim2Conn> {
boost::asio::ip::tcp::socket& getSocket() override { throw operation_failed(); }
bool opened, closedByCaller, stableConnection;
bool opened, closedByCaller, stableConnection, trustedPeer;
private:
ISimulator::ProcessInfo *process, *peerProcess;

View File

@ -0,0 +1,357 @@
/*
* AuthzTlsTest.cpp
*
* This source file is part of the FoundationDB open source project
*
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef _WIN32
#include <algorithm>
#include <cstring>
#include <fmt/format.h>
#include <unistd.h>
#include <string_view>
#include <signal.h>
#include <sys/wait.h>
#include "flow/Arena.h"
#include "flow/MkCert.h"
#include "flow/ScopeExit.h"
#include "flow/TLSConfig.actor.h"
#include "fdbrpc/fdbrpc.h"
#include "fdbrpc/FlowTransport.h"
#include "flow/actorcompiler.h" // This must be the last #include.
std::FILE* outp = stdout;
template <class... Args>
void log(Args&&... args) {
auto buf = fmt::memory_buffer{};
fmt::format_to(std::back_inserter(buf), std::forward<Args>(args)...);
fmt::print(outp, "{}\n", std::string_view(buf.data(), buf.size()));
}
template <class... Args>
void logc(Args&&... args) {
auto buf = fmt::memory_buffer{};
fmt::format_to(std::back_inserter(buf), "[CLIENT] ");
fmt::format_to(std::back_inserter(buf), std::forward<Args>(args)...);
fmt::print(outp, "{}\n", std::string_view(buf.data(), buf.size()));
}
template <class... Args>
void logs(Args&&... args) {
auto buf = fmt::memory_buffer{};
fmt::format_to(std::back_inserter(buf), "[SERVER] ");
fmt::format_to(std::back_inserter(buf), std::forward<Args>(args)...);
fmt::print(outp, "{}\n", std::string_view(buf.data(), buf.size()));
}
template <class... Args>
void logm(Args&&... args) {
auto buf = fmt::memory_buffer{};
fmt::format_to(std::back_inserter(buf), "[ MAIN ] ");
fmt::format_to(std::back_inserter(buf), std::forward<Args>(args)...);
fmt::print(outp, "{}\n", std::string_view(buf.data(), buf.size()));
}
struct TLSCreds {
std::string certBytes;
std::string keyBytes;
std::string caBytes;
};
TLSCreds makeCreds(int chainLen, mkcert::ESide side) {
if (chainLen == 0)
return {};
auto arena = Arena();
auto ret = TLSCreds{};
auto specs = mkcert::makeCertChainSpec(arena, std::labs(chainLen), side);
if (chainLen < 0) {
specs[0].offsetNotBefore = -60l * 60 * 24 * 365;
specs[0].offsetNotAfter = -10l; // cert that expired 10 seconds ago
}
auto chain = mkcert::makeCertChain(arena, specs, {} /* create root CA cert from spec*/);
if (chain.size() == 1) {
ret.certBytes = concatCertChain(arena, chain).toString();
} else {
auto nonRootChain = chain;
nonRootChain.pop_back();
ret.certBytes = concatCertChain(arena, nonRootChain).toString();
}
ret.caBytes = chain.back().certPem.toString();
ret.keyBytes = chain.front().privateKeyPem.toString();
return ret;
}
enum class Result : int {
TRUSTED = 0,
UNTRUSTED,
ERROR,
};
template <>
struct fmt::formatter<Result> {
constexpr auto parse(format_parse_context& ctx) -> decltype(ctx.begin()) { return ctx.begin(); }
template <class FormatContext>
auto format(const Result& r, FormatContext& ctx) -> decltype(ctx.out()) {
if (r == Result::TRUSTED)
return fmt::format_to(ctx.out(), "TRUSTED");
else if (r == Result::UNTRUSTED)
return fmt::format_to(ctx.out(), "UNTRUSTED");
else
return fmt::format_to(ctx.out(), "ERROR");
}
};
ACTOR template <class T>
Future<T> stopNetworkAfter(Future<T> what) {
T t = wait(what);
g_network->stop();
return t;
}
// Reflective struct containing information about the requester from a server PoV
struct SessionInfo {
constexpr static FileIdentifier file_identifier = 1578312;
bool isPeerTrusted = false;
NetworkAddress peerAddress;
template <class Ar>
void serialize(Ar& ar) {
serializer(ar, isPeerTrusted, peerAddress);
}
};
struct SessionProbeRequest {
constexpr static FileIdentifier file_identifier = 1559713;
ReplyPromise<SessionInfo> reply{ PeerCompatibilityPolicy{ RequirePeer::AtLeast,
ProtocolVersion::withStableInterfaces() } };
bool verify() const { return true; }
template <class Ar>
void serialize(Ar& ar) {
serializer(ar, reply);
}
};
struct SessionProbeReceiver final : NetworkMessageReceiver {
SessionProbeReceiver() {}
void receive(ArenaObjectReader& reader) override {
SessionProbeRequest req;
reader.deserialize(req);
SessionInfo res;
res.isPeerTrusted = FlowTransport::transport().currentDeliveryPeerIsTrusted();
res.peerAddress = FlowTransport::transport().currentDeliveryPeerAddress();
req.reply.send(res);
}
PeerCompatibilityPolicy peerCompatibilityPolicy() const override {
return PeerCompatibilityPolicy{ RequirePeer::AtLeast, ProtocolVersion::withStableInterfaces() };
}
bool isPublic() const override { return true; }
};
Future<Void> runServer(Future<Void> listenFuture, const Endpoint& endpoint, int addrPipe, int completionPipe) {
auto realAddr = FlowTransport::transport().getLocalAddresses().address;
logs("Listening at {}", realAddr.toString());
logs("Endpoint token is {}", endpoint.token.toString());
// below writes/reads would block, but this is good enough for a test.
if (sizeof(realAddr) != ::write(addrPipe, &realAddr, sizeof(realAddr))) {
logs("Failed to write server addr to pipe: {}", strerror(errno));
return Void();
}
if (sizeof(endpoint.token) != ::write(addrPipe, &endpoint.token, sizeof(endpoint.token))) {
logs("Failed to write server endpoint to pipe: {}", strerror(errno));
return Void();
}
auto done = false;
if (sizeof(done) != ::read(completionPipe, &done, sizeof(done))) {
logs("Failed to read completion flag from pipe: {}", strerror(errno));
return Void();
}
return Void();
}
ACTOR Future<Void> waitAndPrintResponse(Future<SessionInfo> response, Result* rc) {
try {
SessionInfo info = wait(response);
logc("Probe response: trusted={} peerAddress={}", info.isPeerTrusted, info.peerAddress.toString());
*rc = info.isPeerTrusted ? Result::TRUSTED : Result::UNTRUSTED;
} catch (Error& err) {
logc("Error: {}", err.what());
*rc = Result::ERROR;
}
return Void();
}
template <bool IsServer>
int runHost(TLSCreds creds, int addrPipe, int completionPipe, Result expect) {
auto tlsConfig = TLSConfig(IsServer ? TLSEndpointType::SERVER : TLSEndpointType::CLIENT);
tlsConfig.setCertificateBytes(creds.certBytes);
tlsConfig.setCABytes(creds.caBytes);
tlsConfig.setKeyBytes(creds.keyBytes);
g_network = newNet2(tlsConfig);
openTraceFile(NetworkAddress(),
10 << 20,
10 << 20,
".",
IsServer ? "authz_tls_unittest_server" : "authz_tls_unittest_client");
FlowTransport::createInstance(!IsServer, 1, WLTOKEN_RESERVED_COUNT);
auto& transport = FlowTransport::transport();
if constexpr (IsServer) {
auto addr = NetworkAddress::parse("127.0.0.1:0:tls");
auto thread = std::thread([]() {
g_network->run();
flushTraceFileVoid();
});
auto endpoint = Endpoint();
auto receiver = SessionProbeReceiver();
transport.addEndpoint(endpoint, &receiver, TaskPriority::ReadSocket);
runServer(transport.bind(addr, addr), endpoint, addrPipe, completionPipe);
auto cleanupGuard = ScopeExit([&thread]() {
g_network->stop();
thread.join();
});
} else {
auto dest = Endpoint();
auto& serverAddr = dest.addresses.address;
if (sizeof(serverAddr) != ::read(addrPipe, &serverAddr, sizeof(serverAddr))) {
logc("Failed to read server addr from pipe: {}", strerror(errno));
return 1;
}
auto& token = dest.token;
if (sizeof(token) != ::read(addrPipe, &token, sizeof(token))) {
logc("Failed to read server endpoint token from pipe: {}", strerror(errno));
return 2;
}
logc("Server address is {}", serverAddr.toString());
logc("Server endpoint token is {}", token.toString());
auto sessionProbeReq = SessionProbeRequest{};
transport.sendUnreliable(SerializeSource(sessionProbeReq), dest, true /*openConnection*/);
logc("Request is sent");
auto probeResponse = sessionProbeReq.reply.getFuture();
auto result = Result::TRUSTED;
auto timeout = delay(5);
auto complete = waitAndPrintResponse(probeResponse, &result);
auto f = stopNetworkAfter(complete || timeout);
auto rc = 0;
g_network->run();
if (!complete.isReady()) {
logc("Error: Probe request timed out");
rc = 3;
}
auto done = true;
if (sizeof(done) != ::write(completionPipe, &done, sizeof(done))) {
logc("Failed to signal server to terminate: {}", strerror(errno));
rc = 4;
}
if (rc == 0) {
if (expect != result) {
logc("Test failed: expected {}, got {}", expect, result);
rc = 5;
} else {
logc("Response OK: got {} as expected", result);
}
}
return rc;
}
return 0;
}
int runTlsTest(int serverChainLen, int clientChainLen) {
log("==== BEGIN TESTCASE ====");
auto expect = Result::ERROR;
if (serverChainLen > 0) {
if (clientChainLen > 0)
expect = Result::TRUSTED;
else if (clientChainLen == 0)
expect = Result::UNTRUSTED;
}
log("Cert chain length: server={} client={}", serverChainLen, clientChainLen);
auto arena = Arena();
auto serverCreds = makeCreds(serverChainLen, mkcert::ESide::Server);
auto clientCreds = makeCreds(clientChainLen, mkcert::ESide::Client);
// make server and client trust each other
std::swap(serverCreds.caBytes, clientCreds.caBytes);
auto clientPid = pid_t{};
auto serverPid = pid_t{};
int addrPipe[2];
int completionPipe[2];
if (::pipe(addrPipe) || ::pipe(completionPipe)) {
logm("Pipe open failed: {}", strerror(errno));
return 1;
}
auto pipeCleanup = ScopeExit([&addrPipe, &completionPipe]() {
::close(addrPipe[0]);
::close(addrPipe[1]);
::close(completionPipe[0]);
::close(completionPipe[1]);
});
serverPid = fork();
if (serverPid == 0) {
_exit(runHost<true>(std::move(serverCreds), addrPipe[1], completionPipe[0], expect));
}
clientPid = fork();
if (clientPid == 0) {
_exit(runHost<false>(std::move(clientCreds), addrPipe[0], completionPipe[1], expect));
}
auto pid = pid_t{};
auto status = int{};
pid = waitpid(clientPid, &status, 0);
auto ok = true;
if (pid < 0) {
logm("waitpid() for client failed with {}", strerror(errno));
ok = false;
} else {
if (status != 0) {
logm("Client error: rc={}", status);
ok = false;
} else {
logm("Client OK");
}
}
pid = waitpid(serverPid, &status, 0);
if (pid < 0) {
logm("waitpid() for server failed with {}", strerror(errno));
ok = false;
} else {
if (status != 0) {
logm("Server error: rc={}", status);
ok = false;
} else {
logm("Server OK");
}
}
log(ok ? "OK" : "FAILED");
return 0;
}
int main() {
std::pair<int, int> inputs[] = { { 3, 2 }, { 4, 0 }, { 1, 3 }, { 1, 0 }, { 2, 0 }, { 3, 3 }, { 3, 0 } };
for (auto input : inputs) {
auto [serverChainLen, clientChainLen] = input;
if (auto rc = runTlsTest(serverChainLen, clientChainLen))
return rc;
}
return 0;
}
#else // _WIN32
int main() {
return 0;
}
#endif // _WIN32

View File

@ -0,0 +1,6 @@
if(NOT WIN32)
add_flow_target(EXECUTABLE NAME authz_tls_unittest SRCS AuthzTlsTest.actor.cpp)
target_link_libraries(authz_tls_unittest PRIVATE flow fdbrpc fmt::fmt)
add_test(NAME authorization_tls_unittest
COMMAND $<TARGET_FILE:authz_tls_unittest>)
endif()

View File

@ -3304,7 +3304,8 @@ ACTOR Future<Void> loadForcePurgedRanges(Reference<BlobManagerData> bmData) {
beginKey = results.back().key;
} catch (Error& e) {
if (BM_DEBUG) {
fmt::print("BM {0} got error reading granule mapping during recovery: {1}\n", bmData->epoch, e.name());
fmt::print(
"BM {0} got error reading force purge ranges during recovery: {1}\n", bmData->epoch, e.name());
}
wait(tr->onError(e));
}

View File

@ -1185,16 +1185,16 @@ ACTOR Future<BlobFileIndex> compactFromBlob(Reference<BlobWorkerData> bwData,
}
ASSERT(lastDeltaVersion >= version);
chunk.includedVersion = version;
if (BW_DEBUG) {
fmt::print("Re-snapshotting [{0} - {1}) @ {2} from blob\n",
metadata->keyRange.begin.printable(),
metadata->keyRange.end.printable(),
version);
}
chunksToRead.push_back(readBlobGranule(chunk, metadata->keyRange, 0, version, bstore, &bwData->stats));
}
if (BW_DEBUG) {
fmt::print("Re-snapshotting [{0} - {1}) @ {2} from blob\n",
metadata->keyRange.begin.printable(),
metadata->keyRange.end.printable(),
version);
}
try {
state PromiseStream<RangeResult> rowsStream;
state Future<BlobFileIndex> snapshotWriter = writeSnapshot(bwData,
@ -1839,6 +1839,20 @@ ACTOR Future<Void> waitVersionCommitted(Reference<BlobWorkerData> bwData,
return Void();
}
ACTOR Future<bool> checkFileNotFoundForcePurgeRace(Reference<BlobWorkerData> bwData, KeyRange range) {
state Transaction tr(bwData->db);
loop {
try {
tr.setOption(FDBTransactionOptions::PRIORITY_SYSTEM_IMMEDIATE);
tr.setOption(FDBTransactionOptions::LOCK_AWARE);
ForcedPurgeState purgeState = wait(getForcePurgedState(&tr, range));
return purgeState != ForcedPurgeState::NonePurged;
} catch (Error& e) {
wait(tr.onError(e));
}
}
}
// updater for a single granule
// TODO: this is getting kind of large. Should try to split out this actor if it continues to grow?
ACTOR Future<Void> blobGranuleUpdateFiles(Reference<BlobWorkerData> bwData,
@ -2637,17 +2651,31 @@ ACTOR Future<Void> blobGranuleUpdateFiles(Reference<BlobWorkerData> bwData,
throw e;
}
state Error e2 = e;
if (e.code() == error_code_file_not_found) {
// FIXME: better way to fix this?
bool isForcePurging = wait(checkFileNotFoundForcePurgeRace(bwData, metadata->keyRange));
if (isForcePurging) {
CODE_PROBE(true, "Granule got file not found from force purge");
TraceEvent("GranuleFileUpdaterFileNotFoundForcePurge", bwData->id)
.error(e2)
.detail("KeyRange", metadata->keyRange)
.detail("GranuleID", startState.granuleID);
return Void();
}
}
TraceEvent(SevError, "GranuleFileUpdaterUnexpectedError", bwData->id)
.error(e)
.error(e2)
.detail("Granule", metadata->keyRange)
.detail("GranuleID", startState.granuleID);
ASSERT_WE_THINK(false);
// if not simulation, kill the BW
if (bwData->fatalError.canBeSet()) {
bwData->fatalError.sendError(e);
bwData->fatalError.sendError(e2);
}
throw e;
throw e2;
}
}
@ -3195,7 +3223,7 @@ bool canReplyWith(Error e) {
switch (e.code()) {
case error_code_blob_granule_transaction_too_old:
case error_code_transaction_too_old:
case error_code_future_version: // not thrown yet
case error_code_future_version:
case error_code_wrong_shard_server:
case error_code_process_behind: // not thrown yet
case error_code_blob_worker_full:
@ -3234,6 +3262,7 @@ ACTOR Future<Void> waitForVersion(Reference<GranuleMetadata> metadata, Version v
// wait for change feed version to catch up to ensure we have all data
if (metadata->activeCFData.get()->getVersion() < v) {
// FIXME: add future version timeout and throw here, same as SS
wait(metadata->activeCFData.get()->whenAtLeast(v));
ASSERT(metadata->activeCFData.get()->getVersion() >= v);
}
@ -4918,4 +4947,4 @@ ACTOR Future<Void> blobWorker(BlobWorkerInterface bwInterf,
return Void();
}
// TODO add unit tests for assign/revoke range, especially version ordering
// TODO add unit tests for assign/revoke range, especially version ordering

View File

@ -103,7 +103,7 @@ struct RatekeeperSingleton : Singleton<RatekeeperInterface> {
}
}
void halt(ClusterControllerData* cc, Optional<Standalone<StringRef>> pid) const {
if (interface.present()) {
if (interface.present() && cc->id_worker.count(pid)) {
cc->id_worker[pid].haltRatekeeper =
brokenPromiseToNever(interface.get().haltRatekeeper.getReply(HaltRatekeeperRequest(cc->id)));
}
@ -128,7 +128,7 @@ struct DataDistributorSingleton : Singleton<DataDistributorInterface> {
}
}
void halt(ClusterControllerData* cc, Optional<Standalone<StringRef>> pid) const {
if (interface.present()) {
if (interface.present() && cc->id_worker.count(pid)) {
cc->id_worker[pid].haltDistributor =
brokenPromiseToNever(interface.get().haltDataDistributor.getReply(HaltDataDistributorRequest(cc->id)));
}
@ -153,7 +153,7 @@ struct BlobManagerSingleton : Singleton<BlobManagerInterface> {
}
}
void halt(ClusterControllerData* cc, Optional<Standalone<StringRef>> pid) const {
if (interface.present()) {
if (interface.present() && cc->id_worker.count(pid)) {
cc->id_worker[pid].haltBlobManager =
brokenPromiseToNever(interface.get().haltBlobManager.getReply(HaltBlobManagerRequest(cc->id)));
}
@ -185,7 +185,7 @@ struct EncryptKeyProxySingleton : Singleton<EncryptKeyProxyInterface> {
}
}
void halt(ClusterControllerData* cc, Optional<Standalone<StringRef>> pid) const {
if (interface.present()) {
if (interface.present() && cc->id_worker.count(pid)) {
cc->id_worker[pid].haltEncryptKeyProxy =
brokenPromiseToNever(interface.get().haltEncryptKeyProxy.getReply(HaltEncryptKeyProxyRequest(cc->id)));
}
@ -2058,8 +2058,9 @@ ACTOR Future<Void> monitorDataDistributor(ClusterControllerData* self) {
choose {
when(wait(waitFailureClient(self->db.serverInfo->get().distributor.get().waitFailure,
SERVER_KNOBS->DD_FAILURE_TIME))) {
TraceEvent("CCDataDistributorDied", self->id)
.detail("DDID", self->db.serverInfo->get().distributor.get().id());
const auto& distributor = self->db.serverInfo->get().distributor;
TraceEvent("CCDataDistributorDied", self->id).detail("DDID", distributor.get().id());
DataDistributorSingleton(distributor).halt(self, distributor.get().locality.processId());
self->db.clearInterf(ProcessClass::DataDistributorClass);
}
when(wait(self->recruitDistributor.onChange())) {}
@ -2149,8 +2150,9 @@ ACTOR Future<Void> monitorRatekeeper(ClusterControllerData* self) {
choose {
when(wait(waitFailureClient(self->db.serverInfo->get().ratekeeper.get().waitFailure,
SERVER_KNOBS->RATEKEEPER_FAILURE_TIME))) {
TraceEvent("CCRatekeeperDied", self->id)
.detail("RKID", self->db.serverInfo->get().ratekeeper.get().id());
const auto& ratekeeper = self->db.serverInfo->get().ratekeeper;
TraceEvent("CCRatekeeperDied", self->id).detail("RKID", ratekeeper.get().id());
RatekeeperSingleton(ratekeeper).halt(self, ratekeeper.get().locality.processId());
self->db.clearInterf(ProcessClass::RatekeeperClass);
}
when(wait(self->recruitRatekeeper.onChange())) {}
@ -2245,6 +2247,8 @@ ACTOR Future<Void> monitorEncryptKeyProxy(ClusterControllerData* self) {
when(wait(waitFailureClient(self->db.serverInfo->get().encryptKeyProxy.get().waitFailure,
SERVER_KNOBS->ENCRYPT_KEY_PROXY_FAILURE_TIME))) {
TraceEvent("CCEKP_Died", self->id);
const auto& encryptKeyProxy = self->db.serverInfo->get().encryptKeyProxy;
EncryptKeyProxySingleton(encryptKeyProxy).halt(self, encryptKeyProxy.get().locality.processId());
self->db.clearInterf(ProcessClass::EncryptKeyProxyClass);
}
when(wait(self->recruitEncryptKeyProxy.onChange())) {}
@ -2389,8 +2393,9 @@ ACTOR Future<Void> monitorBlobManager(ClusterControllerData* self) {
loop {
choose {
when(wait(wfClient)) {
TraceEvent("CCBlobManagerDied", self->id)
.detail("BMID", self->db.serverInfo->get().blobManager.get().id());
const auto& blobManager = self->db.serverInfo->get().blobManager;
TraceEvent("CCBlobManagerDied", self->id).detail("BMID", blobManager.get().id());
BlobManagerSingleton(blobManager).halt(self, blobManager.get().locality.processId());
self->db.clearInterf(ProcessClass::BlobManagerClass);
break;
}

View File

@ -1744,6 +1744,12 @@ ACTOR Future<Void> clusterRecoveryCore(Reference<ClusterRecoveryData> self) {
.detail("RecoveryDuration", recoveryDuration)
.trackLatest(self->clusterRecoveryStateEventHolder->trackingKey);
TraceEvent(getRecoveryEventName(ClusterRecoveryEventType::CLUSTER_RECOVERY_AVAILABLE_EVENT_NAME).c_str(),
self->dbgid)
.detail("NumOfOldGensOfLogs", self->cstate.myDBState.oldTLogData.size())
.detail("AvailableAtVersion", self->recoveryTransactionVersion)
.trackLatest(self->clusterRecoveryAvailableEventHolder->trackingKey);
self->addActor.send(changeCoordinators(self));
Database cx = openDBOnServer(self->dbInfo, TaskPriority::DefaultEndpoint, LockAware::True);
self->addActor.send(configurationMonitor(self, cx));

View File

@ -168,7 +168,6 @@ class GlobalTagThrottlerImpl {
std::unordered_map<UID, Optional<double>> throttlingRatios;
std::unordered_map<TransactionTag, PerTagStatistics> tagStatistics;
std::unordered_map<UID, std::unordered_map<TransactionTag, ThroughputCounters>> throughput;
GlobalTagThrottlerStatusReply statusReply;
// Returns the cost rate for the given tag on the given storage server
Optional<double> getCurrentCost(UID storageServerId, TransactionTag tag, OpType opType) const {
@ -422,12 +421,6 @@ class GlobalTagThrottlerImpl {
.detail("DesiredTps", desiredTps)
.detail("NumStorageServers", throughput.size());
auto& tagStats = statusReply.status[tag];
tagStats.desiredTps = desiredTps.get();
tagStats.limitingTps = limitingTps;
tagStats.targetTps = targetTps.get();
tagStats.reservedTps = reservedTps.get();
return targetTps;
}
@ -440,7 +433,6 @@ public:
PrioritizedTransactionTagMap<double> getProxyRates(int numProxies) {
PrioritizedTransactionTagMap<double> result;
lastBusyReadTagCount = lastBusyWriteTagCount = 0;
statusReply = {};
for (auto& [tag, stats] : tagStatistics) {
// Currently there is no differentiation between batch priority and default priority transactions
@ -468,7 +460,6 @@ public:
PrioritizedTransactionTagMap<ClientTagThrottleLimits> getClientRates() {
PrioritizedTransactionTagMap<ClientTagThrottleLimits> result;
lastBusyReadTagCount = lastBusyWriteTagCount = 0;
statusReply = {};
for (auto& [tag, stats] : tagStatistics) {
// Currently there is no differentiation between batch priority and default priority transactions
@ -524,8 +515,6 @@ public:
}
void removeQuota(TransactionTagRef tag) { tagStatistics[tag].clearQuota(); }
GlobalTagThrottlerStatusReply getStatus() const { return statusReply; }
};
GlobalTagThrottler::GlobalTagThrottler(Database db, UID id) : impl(PImpl<GlobalTagThrottlerImpl>::create(db, id)) {}
@ -574,10 +563,6 @@ void GlobalTagThrottler::removeQuota(TransactionTagRef tag) {
return impl->removeQuota(tag);
}
GlobalTagThrottlerStatusReply GlobalTagThrottler::getGlobalTagThrottlerStatusReply() const {
return impl->getStatus();
}
namespace GlobalTagThrottlerTesting {
enum class LimitType { RESERVED, TOTAL };
@ -727,11 +712,6 @@ bool isNear(Optional<double> a, Optional<double> b) {
}
}
bool isNear(GlobalTagThrottlerStatusReply::TagStats const& a, GlobalTagThrottlerStatusReply::TagStats const& b) {
return isNear(a.desiredTps, b.desiredTps) && isNear(a.targetTps, b.targetTps) &&
isNear(a.reservedTps, b.reservedTps) && isNear(a.limitingTps, b.limitingTps);
}
bool targetRateIsNear(GlobalTagThrottler& globalTagThrottler, TransactionTag tag, Optional<double> expected) {
Optional<double> rate;
auto targetRates = globalTagThrottler.getProxyRates(1);
@ -766,22 +746,6 @@ bool clientRateIsNear(GlobalTagThrottler& globalTagThrottler, TransactionTag tag
return isNear(rate, expected);
}
bool statusIsNear(GlobalTagThrottler const& globalTagThrottler,
TransactionTag tag,
GlobalTagThrottlerStatusReply::TagStats expectedStats) {
auto const stats = globalTagThrottler.getGlobalTagThrottlerStatusReply().status[tag];
TraceEvent("GlobalTagThrottling_StatusMonitor")
.detail("DesiredTps", stats.desiredTps)
.detail("ExpectedDesiredTps", expectedStats.desiredTps)
.detail("LimitingTps", stats.limitingTps)
.detail("ExpectedLimitingTps", expectedStats.limitingTps)
.detail("TargetTps", stats.targetTps)
.detail("ExpectedTargetTps", expectedStats.targetTps)
.detail("ReservedTps", stats.reservedTps)
.detail("ExpectedReservedTps", expectedStats.reservedTps);
return isNear(stats, expectedStats);
}
ACTOR Future<Void> updateGlobalTagThrottler(GlobalTagThrottler* globalTagThrottler,
StorageServerCollection const* storageServers) {
loop {
@ -809,7 +773,7 @@ TEST_CASE("/GlobalTagThrottler/Simple") {
});
state Future<Void> updater =
GlobalTagThrottlerTesting::updateGlobalTagThrottler(&globalTagThrottler, &storageServers);
wait(timeoutError(monitor || client || updater, 300.0));
wait(timeoutError(monitor || client || updater, 600.0));
return Void();
}
@ -827,7 +791,7 @@ TEST_CASE("/GlobalTagThrottler/WriteThrottling") {
});
state Future<Void> updater =
GlobalTagThrottlerTesting::updateGlobalTagThrottler(&globalTagThrottler, &storageServers);
wait(timeoutError(monitor || client || updater, 300.0));
wait(timeoutError(monitor || client || updater, 600.0));
return Void();
}
@ -852,7 +816,7 @@ TEST_CASE("/GlobalTagThrottler/MultiTagThrottling") {
return GlobalTagThrottlerTesting::targetRateIsNear(gtt, testTag1, 100.0 / 6.0) &&
GlobalTagThrottlerTesting::targetRateIsNear(gtt, testTag2, 100.0 / 6.0);
});
wait(timeoutError(waitForAny(futures) || monitor, 300.0));
wait(timeoutError(waitForAny(futures) || monitor, 600.0));
return Void();
}
@ -870,7 +834,7 @@ TEST_CASE("/GlobalTagThrottler/AttemptWorkloadAboveQuota") {
});
state Future<Void> updater =
GlobalTagThrottlerTesting::updateGlobalTagThrottler(&globalTagThrottler, &storageServers);
wait(timeoutError(monitor || client || updater, 300.0));
wait(timeoutError(monitor || client || updater, 600.0));
return Void();
}
@ -891,7 +855,7 @@ TEST_CASE("/GlobalTagThrottler/MultiClientThrottling") {
});
state Future<Void> updater =
GlobalTagThrottlerTesting::updateGlobalTagThrottler(&globalTagThrottler, &storageServers);
wait(timeoutError(monitor || client || client2 || updater, 300.0));
wait(timeoutError(monitor || client || client2 || updater, 600.0));
return Void();
}
@ -912,7 +876,7 @@ TEST_CASE("/GlobalTagThrottler/MultiClientThrottling2") {
});
state Future<Void> updater =
GlobalTagThrottlerTesting::updateGlobalTagThrottler(&globalTagThrottler, &storageServers);
wait(timeoutError(monitor || client || updater, 300.0));
wait(timeoutError(monitor || client || updater, 600.0));
return Void();
}
@ -934,7 +898,7 @@ TEST_CASE("/GlobalTagThrottler/SkewedMultiClientThrottling") {
});
state Future<Void> updater =
GlobalTagThrottlerTesting::updateGlobalTagThrottler(&globalTagThrottler, &storageServers);
wait(timeoutError(monitor || client || updater, 300.0));
wait(timeoutError(monitor || client || updater, 600.0));
return Void();
}
@ -953,13 +917,13 @@ TEST_CASE("/GlobalTagThrottler/UpdateQuota") {
});
state Future<Void> updater =
GlobalTagThrottlerTesting::updateGlobalTagThrottler(&globalTagThrottler, &storageServers);
wait(timeoutError(monitor || client || updater, 300.0));
wait(timeoutError(monitor || client || updater, 600.0));
tagQuotaValue.totalReadQuota = 50.0;
globalTagThrottler.setQuota(testTag, tagQuotaValue);
monitor = GlobalTagThrottlerTesting::monitor(&globalTagThrottler, [](auto& gtt) {
return GlobalTagThrottlerTesting::targetRateIsNear(gtt, "sampleTag1"_sr, 50.0 / 6.0);
});
wait(timeoutError(monitor || client || updater, 300.0));
wait(timeoutError(monitor || client || updater, 600.0));
return Void();
}
@ -977,12 +941,12 @@ TEST_CASE("/GlobalTagThrottler/RemoveQuota") {
});
state Future<Void> updater =
GlobalTagThrottlerTesting::updateGlobalTagThrottler(&globalTagThrottler, &storageServers);
wait(timeoutError(monitor || client || updater, 300.0));
wait(timeoutError(monitor || client || updater, 600.0));
globalTagThrottler.removeQuota(testTag);
monitor = GlobalTagThrottlerTesting::monitor(&globalTagThrottler, [](auto& gtt) {
return GlobalTagThrottlerTesting::targetRateIsNear(gtt, "sampleTag1"_sr, {});
});
wait(timeoutError(monitor || client || updater, 300.0));
wait(timeoutError(monitor || client || updater, 600.0));
return Void();
}
@ -1000,7 +964,7 @@ TEST_CASE("/GlobalTagThrottler/ActiveThrottling") {
});
state Future<Void> updater =
GlobalTagThrottlerTesting::updateGlobalTagThrottler(&globalTagThrottler, &storageServers);
wait(timeoutError(monitor || client || updater, 300.0));
wait(timeoutError(monitor || client || updater, 600.0));
return Void();
}
@ -1027,7 +991,7 @@ TEST_CASE("/GlobalTagThrottler/MultiTagActiveThrottling") {
gtt.busyReadTagCount() == 2;
});
futures.push_back(GlobalTagThrottlerTesting::updateGlobalTagThrottler(&globalTagThrottler, &storageServers));
wait(timeoutError(waitForAny(futures) || monitor, 300.0));
wait(timeoutError(waitForAny(futures) || monitor, 600.0));
return Void();
}
@ -1053,7 +1017,7 @@ TEST_CASE("/GlobalTagThrottler/MultiTagActiveThrottling2") {
GlobalTagThrottlerTesting::targetRateIsNear(gtt, testTag2, 50 / 6.0) && gtt.busyReadTagCount() == 2;
});
futures.push_back(GlobalTagThrottlerTesting::updateGlobalTagThrottler(&globalTagThrottler, &storageServers));
wait(timeoutError(waitForAny(futures) || monitor, 300.0));
wait(timeoutError(waitForAny(futures) || monitor, 600.0));
return Void();
}
@ -1079,7 +1043,7 @@ TEST_CASE("/GlobalTagThrottler/MultiTagActiveThrottling3") {
GlobalTagThrottlerTesting::targetRateIsNear(gtt, testTag2, 100 / 6.0) && gtt.busyReadTagCount() == 1;
});
futures.push_back(GlobalTagThrottlerTesting::updateGlobalTagThrottler(&globalTagThrottler, &storageServers));
wait(timeoutError(waitForAny(futures) || monitor, 300.0));
wait(timeoutError(waitForAny(futures) || monitor, 600.0));
return Void();
}
@ -1098,7 +1062,7 @@ TEST_CASE("/GlobalTagThrottler/ReservedReadQuota") {
});
state Future<Void> updater =
GlobalTagThrottlerTesting::updateGlobalTagThrottler(&globalTagThrottler, &storageServers);
wait(timeoutError(monitor || client || updater, 300.0));
wait(timeoutError(monitor || client || updater, 600.0));
return Void();
}
@ -1117,30 +1081,6 @@ TEST_CASE("/GlobalTagThrottler/ReservedWriteQuota") {
});
state Future<Void> updater =
GlobalTagThrottlerTesting::updateGlobalTagThrottler(&globalTagThrottler, &storageServers);
wait(timeoutError(monitor || client || updater, 300.0));
return Void();
}
TEST_CASE("/GlobalTagThrottler/Status") {
state GlobalTagThrottler globalTagThrottler(Database{}, UID{});
state GlobalTagThrottlerTesting::StorageServerCollection storageServers(10, 100);
GlobalTagThrottlerStatusReply::TagStats expectedStats;
ThrottleApi::TagQuotaValue tagQuotaValue;
TransactionTag testTag = "sampleTag1"_sr;
tagQuotaValue.totalReadQuota = tagQuotaValue.totalWriteQuota = 100.0;
globalTagThrottler.setQuota(testTag, tagQuotaValue);
expectedStats.desiredTps = 100.0 / 6.0;
expectedStats.limitingTps = {};
expectedStats.targetTps = 100.0 / 6.0;
expectedStats.reservedTps = 0.0;
state Future<Void> client = GlobalTagThrottlerTesting::runClient(
&globalTagThrottler, &storageServers, testTag, 5.0, 6.0, GlobalTagThrottlerTesting::OpType::READ);
state Future<Void> monitor =
GlobalTagThrottlerTesting::monitor(&globalTagThrottler, [testTag, expectedStats](auto& gtt) {
return GlobalTagThrottlerTesting::statusIsNear(gtt, testTag, expectedStats);
});
state Future<Void> updater =
GlobalTagThrottlerTesting::updateGlobalTagThrottler(&globalTagThrottler, &storageServers);
wait(timeoutError(monitor || client || updater, 300.0));
wait(timeoutError(monitor || client || updater, 600.0));
return Void();
}

View File

@ -56,30 +56,30 @@ struct KeyValueStoreCompressTestData final : IKeyValueStore {
void clear(KeyRangeRef range, const Arena* arena = nullptr) override { store->clear(range, arena); }
Future<Void> commit(bool sequential = false) override { return store->commit(sequential); }
Future<Optional<Value>> readValue(KeyRef key, IKeyValueStore::ReadType, Optional<UID> debugID) override {
return doReadValue(store, key, debugID);
Future<Optional<Value>> readValue(KeyRef key, Optional<ReadOptions> options) override {
return doReadValue(store, key, options);
}
// Note that readValuePrefix doesn't do anything in this implementation of IKeyValueStore, so the "atomic bomb"
// problem is still present if you are using this storage interface; however, this interface is never used by
// customers. If you want to test malicious atomic op workloads with compressed values for some reason, you will
// need to fix this.
Future<Optional<Value>> readValuePrefix(KeyRef key,
int maxLength,
IKeyValueStore::ReadType,
Optional<UID> debugID) override {
return doReadValuePrefix(store, key, maxLength, debugID);
Future<Optional<Value>> readValuePrefix(KeyRef key, int maxLength, Optional<ReadOptions> options) override {
return doReadValuePrefix(store, key, maxLength, options);
}
// If rowLimit>=0, reads first rows sorted ascending, otherwise reads last rows sorted descending
// The total size of the returned value (less the last entry) will be less than byteLimit
Future<RangeResult> readRange(KeyRangeRef keys, int rowLimit, int byteLimit, IKeyValueStore::ReadType) override {
return doReadRange(store, keys, rowLimit, byteLimit);
Future<RangeResult> readRange(KeyRangeRef keys,
int rowLimit,
int byteLimit,
Optional<ReadOptions> options = Optional<ReadOptions>()) override {
return doReadRange(store, keys, rowLimit, byteLimit, options);
}
private:
ACTOR static Future<Optional<Value>> doReadValue(IKeyValueStore* store, Key key, Optional<UID> debugID) {
Optional<Value> v = wait(store->readValue(key, IKeyValueStore::ReadType::NORMAL, debugID));
ACTOR static Future<Optional<Value>> doReadValue(IKeyValueStore* store, Key key, Optional<ReadOptions> options) {
Optional<Value> v = wait(store->readValue(key, options));
if (!v.present())
return v;
return unpack(v.get());
@ -88,8 +88,8 @@ private:
ACTOR static Future<Optional<Value>> doReadValuePrefix(IKeyValueStore* store,
Key key,
int maxLength,
Optional<UID> debugID) {
Optional<Value> v = wait(doReadValue(store, key, debugID));
Optional<ReadOptions> options) {
Optional<Value> v = wait(doReadValue(store, key, options));
if (!v.present())
return v;
if (maxLength < v.get().size()) {
@ -98,8 +98,12 @@ private:
return v;
}
}
ACTOR Future<RangeResult> doReadRange(IKeyValueStore* store, KeyRangeRef keys, int rowLimit, int byteLimit) {
RangeResult _vs = wait(store->readRange(keys, rowLimit, byteLimit));
ACTOR Future<RangeResult> doReadRange(IKeyValueStore* store,
KeyRangeRef keys,
int rowLimit,
int byteLimit,
Optional<ReadOptions> options) {
RangeResult _vs = wait(store->readRange(keys, rowLimit, byteLimit, options));
RangeResult vs = _vs; // Get rid of implicit const& from wait statement
Arena& a = vs.arena();
for (int i = 0; i < vs.size(); i++)

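The readRange contract repeated in these interfaces (a non-negative rowLimit reads the first rows in ascending order, a negative one reads the last rows in descending order, and byteLimit caps the accumulated size not counting the entry that crosses it) can be sketched standalone. The helper below is purely illustrative, using std::map instead of the FDB key-value store; none of these names are FDB APIs.

#include <map>
#include <string>
#include <vector>

// One returned row (illustrative stand-in, not an FDB type).
struct Row {
    std::string key, value;
};

// Hypothetical helper mirroring the documented readRange semantics:
// rowLimit >= 0 reads the first rows ascending, rowLimit < 0 reads the last rows
// descending; byteLimit caps the accumulated size, not counting the entry that crosses it.
std::vector<Row> readRangeSketch(const std::map<std::string, std::string>& data,
                                 const std::string& begin,
                                 const std::string& end,
                                 int rowLimit,
                                 int byteLimit) {
    auto lo = data.lower_bound(begin);
    auto hi = data.lower_bound(end);
    std::vector<Row> out;
    int bytes = 0;
    if (rowLimit >= 0) {
        for (auto it = lo; it != hi && (int)out.size() < rowLimit && bytes < byteLimit; ++it) {
            out.push_back({ it->first, it->second });
            bytes += (int)(it->first.size() + it->second.size());
        }
    } else {
        for (auto it = hi; it != lo && (int)out.size() < -rowLimit && bytes < byteLimit;) {
            --it;
            out.push_back({ it->first, it->second });
            bytes += (int)(it->first.size() + it->second.size());
        }
    }
    return out;
}

For example, readRangeSketch(data, "a", "z", -10, 10000) returns up to the last ten rows of the range, largest key first.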
View File

@ -198,11 +198,11 @@ public:
return c;
}
Future<Optional<Value>> readValue(KeyRef key, IKeyValueStore::ReadType, Optional<UID> debugID) override {
Future<Optional<Value>> readValue(KeyRef key, Optional<ReadOptions> options) override {
if (recovering.isError())
throw recovering.getError();
if (!recovering.isReady())
return waitAndReadValue(this, key);
return waitAndReadValue(this, key, options);
auto it = data.find(key);
if (it == data.end())
@ -210,14 +210,11 @@ public:
return Optional<Value>(it.getValue());
}
Future<Optional<Value>> readValuePrefix(KeyRef key,
int maxLength,
IKeyValueStore::ReadType,
Optional<UID> debugID) override {
Future<Optional<Value>> readValuePrefix(KeyRef key, int maxLength, Optional<ReadOptions> options) override {
if (recovering.isError())
throw recovering.getError();
if (!recovering.isReady())
return waitAndReadValuePrefix(this, key, maxLength);
return waitAndReadValuePrefix(this, key, maxLength, options);
auto it = data.find(key);
if (it == data.end())
@ -232,11 +229,14 @@ public:
// If rowLimit>=0, reads first rows sorted ascending, otherwise reads last rows sorted descending
// The total size of the returned value (less the last entry) will be less than byteLimit
Future<RangeResult> readRange(KeyRangeRef keys, int rowLimit, int byteLimit, IKeyValueStore::ReadType) override {
Future<RangeResult> readRange(KeyRangeRef keys,
int rowLimit,
int byteLimit,
Optional<ReadOptions> options) override {
if (recovering.isError())
throw recovering.getError();
if (!recovering.isReady())
return waitAndReadRange(this, keys, rowLimit, byteLimit);
return waitAndReadRange(this, keys, rowLimit, byteLimit, options);
RangeResult result;
if (rowLimit == 0) {
@ -926,20 +926,26 @@ private:
}
}
ACTOR static Future<Optional<Value>> waitAndReadValue(KeyValueStoreMemory* self, Key key) {
ACTOR static Future<Optional<Value>> waitAndReadValue(KeyValueStoreMemory* self,
Key key,
Optional<ReadOptions> options) {
wait(self->recovering);
return static_cast<IKeyValueStore*>(self)->readValue(key).get();
return static_cast<IKeyValueStore*>(self)->readValue(key, options).get();
}
ACTOR static Future<Optional<Value>> waitAndReadValuePrefix(KeyValueStoreMemory* self, Key key, int maxLength) {
ACTOR static Future<Optional<Value>> waitAndReadValuePrefix(KeyValueStoreMemory* self,
Key key,
int maxLength,
Optional<ReadOptions> options) {
wait(self->recovering);
return static_cast<IKeyValueStore*>(self)->readValuePrefix(key, maxLength).get();
return static_cast<IKeyValueStore*>(self)->readValuePrefix(key, maxLength, options).get();
}
ACTOR static Future<RangeResult> waitAndReadRange(KeyValueStoreMemory* self,
KeyRange keys,
int rowLimit,
int byteLimit) {
int byteLimit,
Optional<ReadOptions> options) {
wait(self->recovering);
return static_cast<IKeyValueStore*>(self)->readRange(keys, rowLimit, byteLimit).get();
return static_cast<IKeyValueStore*>(self)->readRange(keys, rowLimit, byteLimit, options).get();
}
ACTOR static Future<Void> waitAndCommit(KeyValueStoreMemory* self, bool sequential) {
wait(self->recovering);

View File

@ -1858,8 +1858,8 @@ struct RocksDBKeyValueStore : IKeyValueStore {
// We don't throttle eager reads and reads to the FF keyspace because FDB struggles when those reads fail.
// Thus far, they have been low enough volume to not cause an issue.
static bool shouldThrottle(IKeyValueStore::ReadType type, KeyRef key) {
return type != IKeyValueStore::ReadType::EAGER && !(key.startsWith(systemKeys.begin));
static bool shouldThrottle(ReadType type, KeyRef key) {
return type != ReadType::EAGER && !(key.startsWith(systemKeys.begin));
}
ACTOR template <class Action>
@ -1880,7 +1880,15 @@ struct RocksDBKeyValueStore : IKeyValueStore {
return result;
}
Future<Optional<Value>> readValue(KeyRef key, IKeyValueStore::ReadType type, Optional<UID> debugID) override {
Future<Optional<Value>> readValue(KeyRef key, Optional<ReadOptions> options) override {
ReadType type = ReadType::NORMAL;
Optional<UID> debugID;
if (options.present()) {
type = options.get().type;
debugID = options.get().debugID;
}
if (!shouldThrottle(type, key)) {
auto a = new Reader::ReadValueAction(key, debugID);
auto res = a->result.getFuture();
@ -1888,18 +1896,23 @@ struct RocksDBKeyValueStore : IKeyValueStore {
return res;
}
auto& semaphore = (type == IKeyValueStore::ReadType::FETCH) ? fetchSemaphore : readSemaphore;
int maxWaiters = (type == IKeyValueStore::ReadType::FETCH) ? numFetchWaiters : numReadWaiters;
auto& semaphore = (type == ReadType::FETCH) ? fetchSemaphore : readSemaphore;
int maxWaiters = (type == ReadType::FETCH) ? numFetchWaiters : numReadWaiters;
checkWaiters(semaphore, maxWaiters);
auto a = std::make_unique<Reader::ReadValueAction>(key, debugID);
return read(a.release(), &semaphore, readThreads.getPtr(), &counters.failedToAcquire);
}
Future<Optional<Value>> readValuePrefix(KeyRef key,
int maxLength,
IKeyValueStore::ReadType type,
Optional<UID> debugID) override {
Future<Optional<Value>> readValuePrefix(KeyRef key, int maxLength, Optional<ReadOptions> options) override {
ReadType type = ReadType::NORMAL;
Optional<UID> debugID;
if (options.present()) {
type = options.get().type;
debugID = options.get().debugID;
}
if (!shouldThrottle(type, key)) {
auto a = new Reader::ReadValuePrefixAction(key, maxLength, debugID);
auto res = a->result.getFuture();
@ -1907,8 +1920,8 @@ struct RocksDBKeyValueStore : IKeyValueStore {
return res;
}
auto& semaphore = (type == IKeyValueStore::ReadType::FETCH) ? fetchSemaphore : readSemaphore;
int maxWaiters = (type == IKeyValueStore::ReadType::FETCH) ? numFetchWaiters : numReadWaiters;
auto& semaphore = (type == ReadType::FETCH) ? fetchSemaphore : readSemaphore;
int maxWaiters = (type == ReadType::FETCH) ? numFetchWaiters : numReadWaiters;
checkWaiters(semaphore, maxWaiters);
auto a = std::make_unique<Reader::ReadValuePrefixAction>(key, maxLength, debugID);
@ -1938,7 +1951,13 @@ struct RocksDBKeyValueStore : IKeyValueStore {
Future<RangeResult> readRange(KeyRangeRef keys,
int rowLimit,
int byteLimit,
IKeyValueStore::ReadType type) override {
Optional<ReadOptions> options) override {
ReadType type = ReadType::NORMAL;
if (options.present()) {
type = options.get().type;
}
if (!shouldThrottle(type, keys.begin)) {
auto a = new Reader::ReadRangeAction(keys, rowLimit, byteLimit);
auto res = a->result.getFuture();
@ -1946,8 +1965,8 @@ struct RocksDBKeyValueStore : IKeyValueStore {
return res;
}
auto& semaphore = (type == IKeyValueStore::ReadType::FETCH) ? fetchSemaphore : readSemaphore;
int maxWaiters = (type == IKeyValueStore::ReadType::FETCH) ? numFetchWaiters : numReadWaiters;
auto& semaphore = (type == ReadType::FETCH) ? fetchSemaphore : readSemaphore;
int maxWaiters = (type == ReadType::FETCH) ? numFetchWaiters : numReadWaiters;
checkWaiters(semaphore, maxWaiters);
auto a = std::make_unique<Reader::ReadRangeAction>(keys, rowLimit, byteLimit);

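Each reader above unwraps the optional ReadOptions the same way, defaulting to ReadType::NORMAL and no debug ID when nothing is supplied. A minimal standalone sketch of that defaulting pattern, using std::optional in place of FDB's Optional and a reduced illustrative ReadType, looks like this:

#include <optional>
#include <string>
#include <utility>

// Illustrative subset of the read types seen above; not the full FDB enum.
enum class ReadType { EAGER, FETCH, NORMAL };

struct ReadOptions {
    ReadType type = ReadType::NORMAL;
    std::optional<std::string> debugID; // stand-in for FDB's Optional<UID>
};

// Unwrap the optional options with the same defaults the overrides above use:
// ReadType::NORMAL and no debug ID when nothing was supplied.
inline std::pair<ReadType, std::optional<std::string>> unpackReadOptions(
    const std::optional<ReadOptions>& options) {
    ReadType type = ReadType::NORMAL;
    std::optional<std::string> debugID;
    if (options.has_value()) {
        type = options->type;
        debugID = options->debugID;
    }
    return { type, debugID };
}

Factoring the unwrap into one helper like this would avoid repeating the present() checks in every readValue/readValuePrefix/readRange override; the diff above keeps the checks inline.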
View File

@ -1589,12 +1589,12 @@ public:
void clear(KeyRangeRef range, const Arena* arena = nullptr) override;
Future<Void> commit(bool sequential = false) override;
Future<Optional<Value>> readValue(KeyRef key, IKeyValueStore::ReadType, Optional<UID> debugID) override;
Future<Optional<Value>> readValuePrefix(KeyRef key,
int maxLength,
IKeyValueStore::ReadType,
Optional<UID> debugID) override;
Future<RangeResult> readRange(KeyRangeRef keys, int rowLimit, int byteLimit, IKeyValueStore::ReadType) override;
Future<Optional<Value>> readValue(KeyRef key, Optional<ReadOptions> options) override;
Future<Optional<Value>> readValuePrefix(KeyRef key, int maxLength, Optional<ReadOptions> options) override;
Future<RangeResult> readRange(KeyRangeRef keys,
int rowLimit,
int byteLimit,
Optional<ReadOptions> options) override;
KeyValueStoreSQLite(std::string const& filename,
UID logID,
@ -2216,18 +2216,23 @@ Future<Void> KeyValueStoreSQLite::commit(bool sequential) {
writeThread->post(p);
return f;
}
Future<Optional<Value>> KeyValueStoreSQLite::readValue(KeyRef key, IKeyValueStore::ReadType, Optional<UID> debugID) {
Future<Optional<Value>> KeyValueStoreSQLite::readValue(KeyRef key, Optional<ReadOptions> options) {
++readsRequested;
Optional<UID> debugID;
if (options.present()) {
debugID = options.get().debugID;
}
auto p = new Reader::ReadValueAction(key, debugID);
auto f = p->result.getFuture();
readThreads->post(p);
return f;
}
Future<Optional<Value>> KeyValueStoreSQLite::readValuePrefix(KeyRef key,
int maxLength,
IKeyValueStore::ReadType,
Optional<UID> debugID) {
Future<Optional<Value>> KeyValueStoreSQLite::readValuePrefix(KeyRef key, int maxLength, Optional<ReadOptions> options) {
++readsRequested;
Optional<UID> debugID;
if (options.present()) {
debugID = options.get().debugID;
}
auto p = new Reader::ReadValuePrefixAction(key, maxLength, debugID);
auto f = p->result.getFuture();
readThreads->post(p);
@ -2236,7 +2241,7 @@ Future<Optional<Value>> KeyValueStoreSQLite::readValuePrefix(KeyRef key,
Future<RangeResult> KeyValueStoreSQLite::readRange(KeyRangeRef keys,
int rowLimit,
int byteLimit,
IKeyValueStore::ReadType) {
Optional<ReadOptions> options) {
++readsRequested;
auto p = new Reader::ReadRangeAction(keys, rowLimit, byteLimit);
auto f = p->result.getFuture();

View File

@ -288,6 +288,7 @@ rocksdb::Options getOptions() {
options.max_background_jobs = SERVER_KNOBS->ROCKSDB_MAX_BACKGROUND_JOBS;
options.db_write_buffer_size = SERVER_KNOBS->ROCKSDB_WRITE_BUFFER_SIZE;
options.write_buffer_size = SERVER_KNOBS->ROCKSDB_CF_WRITE_BUFFER_SIZE;
options.statistics = rocksdb::CreateDBStatistics();
options.statistics->set_stats_level(rocksdb::kExceptHistogramOrTimers);
options.db_log_dir = SERVER_KNOBS->LOG_DIRECTORY;
@ -2309,8 +2310,8 @@ struct ShardedRocksDBKeyValueStore : IKeyValueStore {
// We don't throttle eager reads and reads to the FF keyspace because FDB struggles when those reads fail.
// Thus far, they have been low enough volume to not cause an issue.
static bool shouldThrottle(IKeyValueStore::ReadType type, KeyRef key) {
return type != IKeyValueStore::ReadType::EAGER && !(key.startsWith(systemKeys.begin));
static bool shouldThrottle(ReadType type, KeyRef key) {
return type != ReadType::EAGER && !(key.startsWith(systemKeys.begin));
}
ACTOR template <class Action>
@ -2331,7 +2332,7 @@ struct ShardedRocksDBKeyValueStore : IKeyValueStore {
return result;
}
Future<Optional<Value>> readValue(KeyRef key, IKeyValueStore::ReadType type, Optional<UID> debugID) override {
Future<Optional<Value>> readValue(KeyRef key, Optional<ReadOptions> options) override {
auto* shard = shardManager.getDataShard(key);
if (shard == nullptr || !shard->physicalShard->initialized()) {
// TODO: read non-exist system key range should not cause an error.
@ -2341,6 +2342,14 @@ struct ShardedRocksDBKeyValueStore : IKeyValueStore {
return Optional<Value>();
}
ReadType type = ReadType::NORMAL;
Optional<UID> debugID;
if (options.present()) {
type = options.get().type;
debugID = options.get().debugID;
}
if (!shouldThrottle(type, key)) {
auto a = new Reader::ReadValueAction(key, shard->physicalShard, debugID);
auto res = a->result.getFuture();
@ -2348,18 +2357,15 @@ struct ShardedRocksDBKeyValueStore : IKeyValueStore {
return res;
}
auto& semaphore = (type == IKeyValueStore::ReadType::FETCH) ? fetchSemaphore : readSemaphore;
int maxWaiters = (type == IKeyValueStore::ReadType::FETCH) ? numFetchWaiters : numReadWaiters;
auto& semaphore = (type == ReadType::FETCH) ? fetchSemaphore : readSemaphore;
int maxWaiters = (type == ReadType::FETCH) ? numFetchWaiters : numReadWaiters;
checkWaiters(semaphore, maxWaiters);
auto a = std::make_unique<Reader::ReadValueAction>(key, shard->physicalShard, debugID);
return read(a.release(), &semaphore, readThreads.getPtr(), &counters.failedToAcquire);
}
Future<Optional<Value>> readValuePrefix(KeyRef key,
int maxLength,
IKeyValueStore::ReadType type,
Optional<UID> debugID) override {
Future<Optional<Value>> readValuePrefix(KeyRef key, int maxLength, Optional<ReadOptions> options) override {
auto* shard = shardManager.getDataShard(key);
if (shard == nullptr || !shard->physicalShard->initialized()) {
// TODO: read non-exist system key range should not cause an error.
@ -2369,6 +2375,14 @@ struct ShardedRocksDBKeyValueStore : IKeyValueStore {
return Optional<Value>();
}
ReadType type = ReadType::NORMAL;
Optional<UID> debugID;
if (options.present()) {
type = options.get().type;
debugID = options.get().debugID;
}
if (!shouldThrottle(type, key)) {
auto a = new Reader::ReadValuePrefixAction(key, maxLength, shard->physicalShard, debugID);
auto res = a->result.getFuture();
@ -2376,8 +2390,8 @@ struct ShardedRocksDBKeyValueStore : IKeyValueStore {
return res;
}
auto& semaphore = (type == IKeyValueStore::ReadType::FETCH) ? fetchSemaphore : readSemaphore;
int maxWaiters = (type == IKeyValueStore::ReadType::FETCH) ? numFetchWaiters : numReadWaiters;
auto& semaphore = (type == ReadType::FETCH) ? fetchSemaphore : readSemaphore;
int maxWaiters = (type == ReadType::FETCH) ? numFetchWaiters : numReadWaiters;
checkWaiters(semaphore, maxWaiters);
auto a = std::make_unique<Reader::ReadValuePrefixAction>(key, maxLength, shard->physicalShard, debugID);
@ -2407,10 +2421,15 @@ struct ShardedRocksDBKeyValueStore : IKeyValueStore {
Future<RangeResult> readRange(KeyRangeRef keys,
int rowLimit,
int byteLimit,
IKeyValueStore::ReadType type) override {
Optional<ReadOptions> options = Optional<ReadOptions>()) override {
TraceEvent(SevVerbose, "ShardedRocksReadRangeBegin", this->id).detail("Range", keys);
auto shards = shardManager.getDataShardsByRange(keys);
ReadType type = ReadType::NORMAL;
if (options.present()) {
type = options.get().type;
}
if (!shouldThrottle(type, keys.begin)) {
auto a = new Reader::ReadRangeAction(keys, shards, rowLimit, byteLimit);
auto res = a->result.getFuture();
@ -2418,8 +2437,8 @@ struct ShardedRocksDBKeyValueStore : IKeyValueStore {
return res;
}
auto& semaphore = (type == IKeyValueStore::ReadType::FETCH) ? fetchSemaphore : readSemaphore;
int maxWaiters = (type == IKeyValueStore::ReadType::FETCH) ? numFetchWaiters : numReadWaiters;
auto& semaphore = (type == ReadType::FETCH) ? fetchSemaphore : readSemaphore;
int maxWaiters = (type == ReadType::FETCH) ? numFetchWaiters : numReadWaiters;
checkWaiters(semaphore, maxWaiters);
auto a = std::make_unique<Reader::ReadRangeAction>(keys, shards, rowLimit, byteLimit);
@ -2608,24 +2627,21 @@ TEST_CASE("noSim/ShardedRocksDB/RangeOps") {
// Range read
// Read forward full range.
RangeResult result =
wait(kvStore->readRange(KeyRangeRef("0"_sr, ":"_sr), 1000, 10000, IKeyValueStore::ReadType::NORMAL));
RangeResult result = wait(kvStore->readRange(KeyRangeRef("0"_sr, ":"_sr), 1000, 10000));
ASSERT_EQ(result.size(), expectedRows.size());
for (int i = 0; i < expectedRows.size(); ++i) {
ASSERT(result[i] == expectedRows[i]);
}
// Read backward full range.
RangeResult result =
wait(kvStore->readRange(KeyRangeRef("0"_sr, ":"_sr), -1000, 10000, IKeyValueStore::ReadType::NORMAL));
RangeResult result = wait(kvStore->readRange(KeyRangeRef("0"_sr, ":"_sr), -1000, 10000));
ASSERT_EQ(result.size(), expectedRows.size());
for (int i = 0; i < expectedRows.size(); ++i) {
ASSERT(result[i] == expectedRows[59 - i]);
}
// Forward with row limit.
RangeResult result =
wait(kvStore->readRange(KeyRangeRef("2"_sr, "6"_sr), 10, 10000, IKeyValueStore::ReadType::NORMAL));
RangeResult result = wait(kvStore->readRange(KeyRangeRef("2"_sr, "6"_sr), 10, 10000));
ASSERT_EQ(result.size(), 10);
for (int i = 0; i < 10; ++i) {
ASSERT(result[i] == expectedRows[20 + i]);
@ -2651,16 +2667,14 @@ TEST_CASE("noSim/ShardedRocksDB/RangeOps") {
wait(kvStore->init());
// Read all values.
RangeResult result =
wait(kvStore->readRange(KeyRangeRef("0"_sr, ":"_sr), 1000, 10000, IKeyValueStore::ReadType::NORMAL));
RangeResult result = wait(kvStore->readRange(KeyRangeRef("0"_sr, ":"_sr), 1000, 10000));
ASSERT_EQ(result.size(), expectedRows.size());
for (int i = 0; i < expectedRows.size(); ++i) {
ASSERT(result[i] == expectedRows[i]);
}
// Read partial range with row limit
RangeResult result =
wait(kvStore->readRange(KeyRangeRef("5"_sr, ":"_sr), 35, 10000, IKeyValueStore::ReadType::NORMAL));
RangeResult result = wait(kvStore->readRange(KeyRangeRef("5"_sr, ":"_sr), 35, 10000));
ASSERT_EQ(result.size(), 35);
for (int i = 0; i < result.size(); ++i) {
ASSERT(result[i] == expectedRows[40 + i]);
@ -2670,8 +2684,7 @@ TEST_CASE("noSim/ShardedRocksDB/RangeOps") {
kvStore->clear(KeyRangeRef("40"_sr, "45"_sr));
wait(kvStore->commit(false));
RangeResult result =
wait(kvStore->readRange(KeyRangeRef("4"_sr, "5"_sr), 20, 10000, IKeyValueStore::ReadType::NORMAL));
RangeResult result = wait(kvStore->readRange(KeyRangeRef("4"_sr, "5"_sr), 20, 10000));
ASSERT_EQ(result.size(), 5);
// Clear a single value.
@ -2691,12 +2704,10 @@ TEST_CASE("noSim/ShardedRocksDB/RangeOps") {
kvStore = new ShardedRocksDBKeyValueStore(rocksDBTestDir, deterministicRandom()->randomUniqueID());
wait(kvStore->init());
RangeResult result =
wait(kvStore->readRange(KeyRangeRef("1"_sr, "8"_sr), 1000, 10000, IKeyValueStore::ReadType::NORMAL));
RangeResult result = wait(kvStore->readRange(KeyRangeRef("1"_sr, "8"_sr), 1000, 10000));
ASSERT_EQ(result.size(), 0);
RangeResult result =
wait(kvStore->readRange(KeyRangeRef("0"_sr, ":"_sr), 1000, 10000, IKeyValueStore::ReadType::NORMAL));
RangeResult result = wait(kvStore->readRange(KeyRangeRef("0"_sr, ":"_sr), 1000, 10000));
ASSERT_EQ(result.size(), 19);
Future<Void> closed = kvStore->onClosed();

View File

@ -520,10 +520,6 @@ public:
TraceEvent("RatekeeperHalted", rkInterf.id()).detail("ReqID", req.requesterID);
break;
}
when(GlobalTagThrottlerStatusRequest req = waitNext(rkInterf.getGlobalTagThrottlerStatus.getFuture())) {
req.reply.send(self.tagThrottler->getGlobalTagThrottlerStatusReply());
break;
}
when(ReportCommitCostEstimationRequest req =
waitNext(rkInterf.reportCommitCostEstimation.getFuture())) {
self.updateCommitCostEstimation(req.ssTrTagCommitCost);
@ -930,7 +926,7 @@ void Ratekeeper::updateRate(RatekeeperLimits* limits) {
if (blobWorkerLag > 3 * limits->bwLagTarget) {
targetRateRatio = 0;
ASSERT(!g_network->isSimulated() || limits->bwLagTarget != SERVER_KNOBS->TARGET_BW_LAG ||
now() < FLOW_KNOBS->SIM_SPEEDUP_AFTER_SECONDS + 50);
now() < FLOW_KNOBS->SIM_SPEEDUP_AFTER_SECONDS + SERVER_KNOBS->BW_RK_SIM_QUIESCE_DELAY);
} else if (blobWorkerLag > limits->bwLagTarget) {
targetRateRatio = SERVER_KNOBS->BW_LAG_DECREASE_AMOUNT;
} else {
@ -988,7 +984,7 @@ void Ratekeeper::updateRate(RatekeeperLimits* limits) {
}
limitReason = limitReason_t::blob_worker_missing;
ASSERT(!g_network->isSimulated() || limits->bwLagTarget != SERVER_KNOBS->TARGET_BW_LAG ||
now() < FLOW_KNOBS->SIM_SPEEDUP_AFTER_SECONDS + 50);
now() < FLOW_KNOBS->SIM_SPEEDUP_AFTER_SECONDS + SERVER_KNOBS->BW_RK_SIM_QUIESCE_DELAY);
} else if (bwTPS < limits->tpsLimit) {
if (printRateKeepLimitReasonDetails) {
TraceEvent("RatekeeperLimitReasonDetails")
@ -1017,7 +1013,7 @@ void Ratekeeper::updateRate(RatekeeperLimits* limits) {
}
limitReason = limitReason_t::blob_worker_missing;
ASSERT(!g_network->isSimulated() || limits->bwLagTarget != SERVER_KNOBS->TARGET_BW_LAG ||
now() < FLOW_KNOBS->SIM_SPEEDUP_AFTER_SECONDS + 50);
now() < FLOW_KNOBS->SIM_SPEEDUP_AFTER_SECONDS + SERVER_KNOBS->BW_RK_SIM_QUIESCE_DELAY);
}
} else if (blobWorkerLag > 3 * limits->bwLagTarget) {
limits->tpsLimit = 0.0;
@ -1030,7 +1026,7 @@ void Ratekeeper::updateRate(RatekeeperLimits* limits) {
}
limitReason = limitReason_t::blob_worker_missing;
ASSERT(!g_network->isSimulated() || limits->bwLagTarget != SERVER_KNOBS->TARGET_BW_LAG ||
now() < FLOW_KNOBS->SIM_SPEEDUP_AFTER_SECONDS + 50);
now() < FLOW_KNOBS->SIM_SPEEDUP_AFTER_SECONDS + SERVER_KNOBS->BW_RK_SIM_QUIESCE_DELAY);
}
} else {
blobWorkerTime = now();

View File

@ -18,6 +18,7 @@
* limitations under the License.
*/
#include "fdbserver/IKeyValueStore.h"
#include "flow/ActorCollection.h"
#include "flow/Error.h"
#include "flow/Platform.h"
@ -99,8 +100,7 @@ ACTOR Future<Void> runIKVS(OpenKVStoreRequest openReq, IKVSInterface ikvsInterfa
try {
choose {
when(IKVSGetValueRequest getReq = waitNext(ikvsInterface.getValue.getFuture())) {
actors.add(cancellableForwardPromise(getReq.reply,
kvStore->readValue(getReq.key, getReq.type, getReq.debugID)));
actors.add(cancellableForwardPromise(getReq.reply, kvStore->readValue(getReq.key, getReq.options)));
}
when(IKVSSetRequest req = waitNext(ikvsInterface.set.getFuture())) { kvStore->set(req.keyValue); }
when(IKVSClearRequest req = waitNext(ikvsInterface.clear.getFuture())) { kvStore->clear(req.range); }
@ -110,16 +110,16 @@ ACTOR Future<Void> runIKVS(OpenKVStoreRequest openReq, IKVSInterface ikvsInterfa
when(IKVSReadValuePrefixRequest readPrefixReq = waitNext(ikvsInterface.readValuePrefix.getFuture())) {
actors.add(cancellableForwardPromise(
readPrefixReq.reply,
kvStore->readValuePrefix(
readPrefixReq.key, readPrefixReq.maxLength, readPrefixReq.type, readPrefixReq.debugID)));
kvStore->readValuePrefix(readPrefixReq.key, readPrefixReq.maxLength, readPrefixReq.options)));
}
when(IKVSReadRangeRequest readRangeReq = waitNext(ikvsInterface.readRange.getFuture())) {
actors.add(cancellableForwardPromise(
readRangeReq.reply,
fmap(
[](const RangeResult& result) { return IKVSReadRangeReply(result); },
kvStore->readRange(
readRangeReq.keys, readRangeReq.rowLimit, readRangeReq.byteLimit, readRangeReq.type))));
fmap([](const RangeResult& result) { return IKVSReadRangeReply(result); },
kvStore->readRange(readRangeReq.keys,
readRangeReq.rowLimit,
readRangeReq.byteLimit,
readRangeReq.options))));
}
when(IKVSGetStorageByteRequest req = waitNext(ikvsInterface.getStorageBytes.getFuture())) {
StorageBytes storageBytes = kvStore->getStorageBytes();

View File

@ -1147,6 +1147,8 @@ ACTOR Future<Void> restartSimulatedSystem(std::vector<Future<Void>>* systemActor
if (testConfig.disableEncryption) {
g_knobs.setKnob("enable_encryption", KnobValueRef::create(bool{ false }));
g_knobs.setKnob("enable_tlog_encryption", KnobValueRef::create(bool{ false }));
g_knobs.setKnob("enable_storage_server_encryption", KnobValueRef::create(bool{ false }));
g_knobs.setKnob("enable_blob_granule_encryption", KnobValueRef::create(bool{ false }));
TraceEvent(SevDebug, "DisableEncryption");
}
*pConnString = conn;
@ -1930,6 +1932,8 @@ void setupSimulatedSystem(std::vector<Future<Void>>* systemActors,
if (testConfig.disableEncryption) {
g_knobs.setKnob("enable_encryption", KnobValueRef::create(bool{ false }));
g_knobs.setKnob("enable_tlog_encryption", KnobValueRef::create(bool{ false }));
g_knobs.setKnob("enable_storage_server_encryption", KnobValueRef::create(bool{ false }));
g_knobs.setKnob("enable_blob_granule_encryption", KnobValueRef::create(bool{ false }));
TraceEvent(SevDebug, "DisableEncryption");
}
auto configDBType = testConfig.getConfigDBType();

View File

@ -2159,16 +2159,9 @@ ACTOR static Future<JsonBuilderObject> workloadStatusFetcher(
timeoutError(rkWorker.interf.eventLogRequest.getReply(EventLogRequest(LiteralStringRef("RkUpdate"))), 1.0);
state Future<TraceEventFields> f2 = timeoutError(
rkWorker.interf.eventLogRequest.getReply(EventLogRequest(LiteralStringRef("RkUpdateBatch"))), 1.0);
state Future<GlobalTagThrottlerStatusReply> f3 =
SERVER_KNOBS->GLOBAL_TAG_THROTTLING
? timeoutError(db->get().ratekeeper.get().getGlobalTagThrottlerStatus.getReply(
GlobalTagThrottlerStatusRequest{}),
1.0)
: Future<GlobalTagThrottlerStatusReply>(GlobalTagThrottlerStatusReply{});
wait(success(f1) && success(f2) && success(f3));
wait(success(f1) && success(f2));
TraceEventFields ratekeeper = f1.get();
TraceEventFields batchRatekeeper = f2.get();
auto const globalTagThrottlerStatus = f3.get();
bool autoThrottlingEnabled = ratekeeper.getInt("AutoThrottlingEnabled");
double tpsLimit = ratekeeper.getDouble("TPSLimit");
@ -2230,23 +2223,6 @@ ACTOR static Future<JsonBuilderObject> workloadStatusFetcher(
(*qos)["throttled_tags"] = throttledTagsObj;
if (SERVER_KNOBS->GLOBAL_TAG_THROTTLING) {
JsonBuilderObject globalTagThrottlerObj;
for (const auto& [tag, tagStats] : globalTagThrottlerStatus.status) {
JsonBuilderObject tagStatsObj;
tagStatsObj["desired_tps"] = tagStats.desiredTps;
tagStatsObj["reserved_tps"] = tagStats.reservedTps;
if (tagStats.limitingTps.present()) {
tagStatsObj["limiting_tps"] = tagStats.limitingTps.get();
} else {
tagStatsObj["limiting_tps"] = "<unset>"_sr;
}
tagStatsObj["target_tps"] = tagStats.targetTps;
globalTagThrottlerObj[printable(tag)] = tagStatsObj;
}
(*qos)["global_tag_throttler"] = globalTagThrottlerObj;
}
JsonBuilderObject perfLimit = getPerfLimit(ratekeeper, transPerSec, tpsLimit);
if (!perfLimit.empty()) {
(*qos)["performance_limited_by"] = perfLimit;

View File

@ -18,6 +18,7 @@
* limitations under the License.
*/
#include "fdbclient/FDBTypes.h"
#include "fdbserver/OTELSpanContextMessage.h"
#include "flow/Arena.h"
#include "fdbclient/FDBOptions.g.h"
@ -480,18 +481,18 @@ ACTOR Future<Void> getValueQ(StorageCacheData* data, GetValueRequest req) {
// TODO what's this?
wait(delay(0, TaskPriority::DefaultEndpoint));
if (req.debugID.present()) {
if (req.options.present() && req.options.get().debugID.present()) {
g_traceBatch.addEvent("GetValueDebug",
req.debugID.get().first(),
req.options.get().debugID.get().first(),
"getValueQ.DoRead"); //.detail("TaskID", g_network->getCurrentTask());
// FIXME
}
state Optional<Value> v;
state Version version = wait(waitForVersion(data, req.version));
if (req.debugID.present())
if (req.options.present() && req.options.get().debugID.present())
g_traceBatch.addEvent("GetValueDebug",
req.debugID.get().first(),
req.options.get().debugID.get().first(),
"getValueQ.AfterVersion"); //.detail("TaskID", g_network->getCurrentTask());
state uint64_t changeCounter = data->cacheRangeChangeCounter;
@ -526,9 +527,9 @@ ACTOR Future<Void> getValueQ(StorageCacheData* data, GetValueRequest req) {
//TraceEvent(SevDebug, "SCGetValueQPresent", data->thisServerID).detail("ResultSize",resultSize).detail("Version", version).detail("ReqKey",req.key).detail("Value",v);
}
if (req.debugID.present())
if (req.options.present() && req.options.get().debugID.present())
g_traceBatch.addEvent("GetValueDebug",
req.debugID.get().first(),
req.options.get().debugID.get().first(),
"getValueQ.AfterRead"); //.detail("TaskID", g_network->getCurrentTask());
GetValueReply reply(v, true);
@ -731,7 +732,7 @@ ACTOR Future<Void> getKeyValues(StorageCacheData* data, GetKeyValuesRequest req)
// Active load balancing runs at a very high priority (to obtain accurate queue lengths)
// so we need to downgrade here
TaskPriority taskType = TaskPriority::DefaultEndpoint;
if (SERVER_KNOBS->FETCH_KEYS_LOWER_PRIORITY && req.isFetchKeys) {
if (SERVER_KNOBS->FETCH_KEYS_LOWER_PRIORITY && req.options.present() && req.options.get().type == ReadType::FETCH) {
taskType = TaskPriority::FetchKeys;
// } else if (false) {
// // Placeholder for up-prioritizing fetches for important requests
@ -740,17 +741,18 @@ ACTOR Future<Void> getKeyValues(StorageCacheData* data, GetKeyValuesRequest req)
wait(delay(0, taskType));
try {
if (req.debugID.present())
g_traceBatch.addEvent("TransactionDebug", req.debugID.get().first(), "storagecache.getKeyValues.Before");
if (req.options.present() && req.options.get().debugID.present())
g_traceBatch.addEvent(
"TransactionDebug", req.options.get().debugID.get().first(), "storagecache.getKeyValues.Before");
state Version version = wait(waitForVersion(data, req.version));
state uint64_t changeCounter = data->cacheRangeChangeCounter;
state KeyRange cachedKeyRange = getCachedKeyRange(data, req.begin);
if (req.debugID.present())
if (req.options.present() && req.options.get().debugID.present())
g_traceBatch.addEvent(
"TransactionDebug", req.debugID.get().first(), "storagecache.getKeyValues.AfterVersion");
"TransactionDebug", req.options.get().debugID.get().first(), "storagecache.getKeyValues.AfterVersion");
//.detail("CacheRangeBegin", cachedKeyRange.begin).detail("CacheRangeEnd", cachedKeyRange.end);
if (!selectorInRange(req.end, cachedKeyRange) &&
@ -768,8 +770,9 @@ ACTOR Future<Void> getKeyValues(StorageCacheData* data, GetKeyValuesRequest req)
: findKey(data, req.begin, version, cachedKeyRange, &offset1);
state Key end = req.end.isFirstGreaterOrEqual() ? req.end.getKey()
: findKey(data, req.end, version, cachedKeyRange, &offset2);
if (req.debugID.present())
g_traceBatch.addEvent("TransactionDebug", req.debugID.get().first(), "storagecache.getKeyValues.AfterKeys");
if (req.options.present() && req.options.get().debugID.present())
g_traceBatch.addEvent(
"TransactionDebug", req.options.get().debugID.get().first(), "storagecache.getKeyValues.AfterKeys");
//.detail("Off1",offset1).detail("Off2",offset2).detail("ReqBegin",req.begin.getKey()).detail("ReqEnd",req.end.getKey());
// Offsets of zero indicate begin/end keys in this cachedKeyRange, which obviously means we can answer the query
@ -794,8 +797,9 @@ ACTOR Future<Void> getKeyValues(StorageCacheData* data, GetKeyValuesRequest req)
// offset1).detail("EndOffset", offset2);
if (begin >= end) {
if (req.debugID.present())
g_traceBatch.addEvent("TransactionDebug", req.debugID.get().first(), "storagecache.getKeyValues.Send");
if (req.options.present() && req.options.get().debugID.present())
g_traceBatch.addEvent(
"TransactionDebug", req.options.get().debugID.get().first(), "storagecache.getKeyValues.Send");
//.detail("Begin",begin).detail("End",end);
GetKeyValuesReply none;
@ -811,9 +815,10 @@ ACTOR Future<Void> getKeyValues(StorageCacheData* data, GetKeyValuesRequest req)
GetKeyValuesReply _r = readRange(data, version, KeyRangeRef(begin, end), req.limit, &remainingLimitBytes);
GetKeyValuesReply r = _r;
if (req.debugID.present())
g_traceBatch.addEvent(
"TransactionDebug", req.debugID.get().first(), "storagecache.getKeyValues.AfterReadRange");
if (req.options.present() && req.options.get().debugID.present())
g_traceBatch.addEvent("TransactionDebug",
req.options.get().debugID.get().first(),
"storagecache.getKeyValues.AfterReadRange");
data->checkChangeCounter(
changeCounter,
KeyRangeRef(std::min<KeyRef>(begin, std::min<KeyRef>(req.begin.getKey(), req.end.getKey())),
@ -1182,6 +1187,7 @@ ACTOR Future<RangeResult> tryFetchRange(Database cx,
state RangeResult output;
state KeySelectorRef begin = firstGreaterOrEqual(keys.begin);
state KeySelectorRef end = firstGreaterOrEqual(keys.end);
state ReadOptions options = ReadOptions(Optional<UID>(), ReadType::FETCH);
if (*isTooOld)
throw transaction_too_old();
@ -1189,6 +1195,7 @@ ACTOR Future<RangeResult> tryFetchRange(Database cx,
ASSERT(!cx->switchable);
tr.setVersion(version);
tr.trState->taskID = TaskPriority::FetchKeys;
tr.trState->readOptions = options;
limits.minRows = 0;
try {

View File

@ -22,6 +22,537 @@
#include "fdbserver/StorageMetrics.h"
#include "flow/actorcompiler.h" // This must be the last #include.
int64_t StorageMetricSample::getEstimate(KeyRangeRef keys) const {
return sample.sumRange(keys.begin, keys.end);
}
KeyRef StorageMetricSample::splitEstimate(KeyRangeRef range, int64_t offset, bool front) const {
auto fwd_split = sample.index(front ? sample.sumTo(sample.lower_bound(range.begin)) + offset
: sample.sumTo(sample.lower_bound(range.end)) - offset);
if (fwd_split == sample.end() || *fwd_split >= range.end)
return range.end;
if (!front && *fwd_split <= range.begin)
return range.begin;
auto bck_split = fwd_split;
// Butterfly search - start at midpoint then go in both directions.
while ((fwd_split != sample.end() && *fwd_split < range.end) ||
(bck_split != sample.begin() && *bck_split > range.begin)) {
if (bck_split != sample.begin() && *bck_split > range.begin) {
auto it = bck_split;
bck_split.decrementNonEnd();
KeyRef split = keyBetween(KeyRangeRef(
bck_split != sample.begin() ? std::max<KeyRef>(*bck_split, range.begin) : range.begin, *it));
if (!front || (getEstimate(KeyRangeRef(range.begin, split)) > 0 &&
split.size() <= CLIENT_KNOBS->SPLIT_KEY_SIZE_LIMIT))
return split;
}
if (fwd_split != sample.end() && *fwd_split < range.end) {
auto it = fwd_split;
++it;
KeyRef split =
keyBetween(KeyRangeRef(*fwd_split, it != sample.end() ? std::min<KeyRef>(*it, range.end) : range.end));
if (front ||
(getEstimate(KeyRangeRef(split, range.end)) > 0 && split.size() <= CLIENT_KNOBS->SPLIT_KEY_SIZE_LIMIT))
return split;
fwd_split = it;
}
}
// If we didn't return above, we didn't find anything.
TraceEvent(SevWarn, "CannotSplitLastSampleKey").detail("Range", range).detail("Offset", offset);
return front ? range.end : range.begin;
}
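The loop above is what the comment calls a butterfly search: start at the candidate nearest the requested offset, then probe alternately forward and backward until a key satisfies the split constraints. A stripped-down standalone version of that outward walk, over a plain sorted vector with a caller-supplied acceptance predicate (all names hypothetical, not the FDB sample types), could look like:

#include <functional>
#include <optional>
#include <string>
#include <vector>

// Probe indices outward from startIdx (startIdx, startIdx + 1, startIdx - 1, startIdx + 2, ...)
// within [lo, hi) and return the first key the predicate accepts; startIdx should lie in [lo, hi].
std::optional<std::string> butterflySearch(const std::vector<std::string>& keys,
                                           size_t lo,
                                           size_t hi,
                                           size_t startIdx,
                                           const std::function<bool(const std::string&)>& accept) {
    size_t fwd = startIdx; // walks toward hi
    size_t bck = startIdx; // walks toward lo
    while (fwd < hi || bck > lo) {
        if (fwd < hi) {
            if (accept(keys[fwd]))
                return keys[fwd];
            ++fwd;
        }
        if (bck > lo) {
            --bck;
            if (accept(keys[bck]))
                return keys[bck];
        }
    }
    return std::nullopt;
}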
// Get the current estimated metrics for the given keys
StorageMetrics StorageServerMetrics::getMetrics(KeyRangeRef const& keys) const {
StorageMetrics result;
result.bytes = byteSample.getEstimate(keys);
result.bytesPerKSecond =
bandwidthSample.getEstimate(keys) * SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL_PER_KSECONDS;
result.iosPerKSecond = iopsSample.getEstimate(keys) * SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL_PER_KSECONDS;
result.bytesReadPerKSecond =
bytesReadSample.getEstimate(keys) * SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL_PER_KSECONDS;
return result;
}
// Called when metrics should change (IO for a given key)
// Notifies waiting WaitMetricsRequests through waitMetricsMap, and updates metricsAverageQueue and metricsSampleMap
void StorageServerMetrics::notify(KeyRef key, StorageMetrics& metrics) {
ASSERT(metrics.bytes == 0); // ShardNotifyMetrics
if (g_network->isSimulated()) {
CODE_PROBE(metrics.bytesPerKSecond != 0, "ShardNotifyMetrics bytes");
CODE_PROBE(metrics.iosPerKSecond != 0, "ShardNotifyMetrics ios");
CODE_PROBE(metrics.bytesReadPerKSecond != 0, "ShardNotifyMetrics bytesRead");
}
double expire = now() + SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL;
StorageMetrics notifyMetrics;
if (metrics.bytesPerKSecond)
notifyMetrics.bytesPerKSecond = bandwidthSample.addAndExpire(key, metrics.bytesPerKSecond, expire) *
SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL_PER_KSECONDS;
if (metrics.iosPerKSecond)
notifyMetrics.iosPerKSecond = iopsSample.addAndExpire(key, metrics.iosPerKSecond, expire) *
SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL_PER_KSECONDS;
if (metrics.bytesReadPerKSecond)
notifyMetrics.bytesReadPerKSecond = bytesReadSample.addAndExpire(key, metrics.bytesReadPerKSecond, expire) *
SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL_PER_KSECONDS;
if (!notifyMetrics.allZero()) {
auto& v = waitMetricsMap[key];
for (int i = 0; i < v.size(); i++) {
if (g_network->isSimulated()) {
CODE_PROBE(true, "shard notify metrics");
}
// ShardNotifyMetrics
v[i].send(notifyMetrics);
}
}
}
// Because read sampling is called on every read, use this specialized function to avoid the overhead of
// branch misses and unnecessary stack allocation, which eventually adds up under heavy load.
void StorageServerMetrics::notifyBytesReadPerKSecond(KeyRef key, int64_t in) {
double expire = now() + SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL;
int64_t bytesReadPerKSecond =
bytesReadSample.addAndExpire(key, in, expire) * SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL_PER_KSECONDS;
if (bytesReadPerKSecond > 0) {
StorageMetrics notifyMetrics;
notifyMetrics.bytesReadPerKSecond = bytesReadPerKSecond;
auto& v = waitMetricsMap[key];
for (int i = 0; i < v.size(); i++) {
CODE_PROBE(true, "ShardNotifyMetrics");
v[i].send(notifyMetrics);
}
}
}
// Called by StorageServerDisk when the size of a key in byteSample changes, to notify WaitMetricsRequest
// Should not be called for keys past allKeys.end
void StorageServerMetrics::notifyBytes(
RangeMap<Key, std::vector<PromiseStream<StorageMetrics>>, KeyRangeRef>::iterator shard,
int64_t bytes) {
ASSERT(shard.end() <= allKeys.end);
StorageMetrics notifyMetrics;
notifyMetrics.bytes = bytes;
for (int i = 0; i < shard.value().size(); i++) {
CODE_PROBE(true, "notifyBytes");
shard.value()[i].send(notifyMetrics);
}
}
// Called by StorageServerDisk when the size of a key in byteSample changes, to notify WaitMetricsRequest
void StorageServerMetrics::notifyBytes(KeyRef key, int64_t bytes) {
if (key >= allKeys.end) // Do not notify on changes to internal storage server state
return;
notifyBytes(waitMetricsMap.rangeContaining(key), bytes);
}
// Called when a range of keys becomes unassigned (and therefore not readable), to notify waiting
// WaitMetricsRequests (and possibly other types of wait requests in the future?)
void StorageServerMetrics::notifyNotReadable(KeyRangeRef keys) {
auto rs = waitMetricsMap.intersectingRanges(keys);
for (auto r = rs.begin(); r != rs.end(); ++r) {
auto& v = r->value();
CODE_PROBE(v.size(), "notifyNotReadable() sending errors to intersecting ranges");
for (int n = 0; n < v.size(); n++)
v[n].sendError(wrong_shard_server());
}
}
// Called periodically (~1 sec intervals) to remove older IOs from the averages
// Removes old entries from metricsAverageQueue, updates metricsSampleMap accordingly, and notifies
// WaitMetricsRequests through waitMetricsMap.
void StorageServerMetrics::poll() {
{
StorageMetrics m;
m.bytesPerKSecond = SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL_PER_KSECONDS;
bandwidthSample.poll(waitMetricsMap, m);
}
{
StorageMetrics m;
m.iosPerKSecond = SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL_PER_KSECONDS;
iopsSample.poll(waitMetricsMap, m);
}
{
StorageMetrics m;
m.bytesReadPerKSecond = SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL_PER_KSECONDS;
bytesReadSample.poll(waitMetricsMap, m);
}
// bytesSample doesn't need polling because we never call addExpire() on it
}
// This function can run on untrusted user data. We must validate all divisions carefully.
KeyRef StorageServerMetrics::getSplitKey(int64_t remaining,
int64_t estimated,
int64_t limits,
int64_t used,
int64_t infinity,
bool isLastShard,
const StorageMetricSample& sample,
double divisor,
KeyRef const& lastKey,
KeyRef const& key,
bool hasUsed) const {
ASSERT(remaining >= 0);
ASSERT(limits > 0);
ASSERT(divisor > 0);
if (limits < infinity / 2) {
int64_t expectedSize;
if (isLastShard || remaining > estimated) {
double remaining_divisor = (double(remaining) / limits) + 0.5;
expectedSize = remaining / remaining_divisor;
} else {
// If we are here, then estimated >= remaining >= 0
double estimated_divisor = (double(estimated) / limits) + 0.5;
expectedSize = remaining / estimated_divisor;
}
if (remaining > expectedSize) {
// This does the conversion from native units to bytes using the divisor.
double offset = (expectedSize - used) / divisor;
if (offset <= 0)
return hasUsed ? lastKey : key;
return sample.splitEstimate(
KeyRangeRef(lastKey, key),
offset * ((1.0 - SERVER_KNOBS->SPLIT_JITTER_AMOUNT) +
2 * deterministicRandom()->random01() * SERVER_KNOBS->SPLIT_JITTER_AMOUNT));
}
}
return key;
}
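A quick worked reading of the branch above, for the byte metric (divisor = 1): with remaining = 250 and limits = 100, the divisor is 250/100 + 0.5 = 3.0, so expectedSize is about 83 and the shard is aimed at three roughly equal pieces rather than 100 + 100 + 50; with remaining = 150 the divisor is 2.0, giving two pieces of about 75 each.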
void StorageServerMetrics::splitMetrics(SplitMetricsRequest req) const {
int minSplitBytes = req.minSplitBytes.present() ? req.minSplitBytes.get() : SERVER_KNOBS->MIN_SHARD_BYTES;
try {
SplitMetricsReply reply;
KeyRef lastKey = req.keys.begin;
StorageMetrics used = req.used;
StorageMetrics estimated = req.estimated;
StorageMetrics remaining = getMetrics(req.keys) + used;
//TraceEvent("SplitMetrics").detail("Begin", req.keys.begin).detail("End", req.keys.end).detail("Remaining", remaining.bytes).detail("Used", used.bytes).detail("MinSplitBytes", minSplitBytes);
while (true) {
if (remaining.bytes < 2 * minSplitBytes)
break;
KeyRef key = req.keys.end;
bool hasUsed = used.bytes != 0 || used.bytesPerKSecond != 0 || used.iosPerKSecond != 0;
key = getSplitKey(remaining.bytes,
estimated.bytes,
req.limits.bytes,
used.bytes,
req.limits.infinity,
req.isLastShard,
byteSample,
1,
lastKey,
key,
hasUsed);
if (used.bytes < minSplitBytes)
key = std::max(
key, byteSample.splitEstimate(KeyRangeRef(lastKey, req.keys.end), minSplitBytes - used.bytes));
key = getSplitKey(remaining.iosPerKSecond,
estimated.iosPerKSecond,
req.limits.iosPerKSecond,
used.iosPerKSecond,
req.limits.infinity,
req.isLastShard,
iopsSample,
SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL_PER_KSECONDS,
lastKey,
key,
hasUsed);
key = getSplitKey(remaining.bytesPerKSecond,
estimated.bytesPerKSecond,
req.limits.bytesPerKSecond,
used.bytesPerKSecond,
req.limits.infinity,
req.isLastShard,
bandwidthSample,
SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL_PER_KSECONDS,
lastKey,
key,
hasUsed);
ASSERT(key != lastKey || hasUsed);
if (key == req.keys.end)
break;
reply.splits.push_back_deep(reply.splits.arena(), key);
StorageMetrics diff = (getMetrics(KeyRangeRef(lastKey, key)) + used);
remaining -= diff;
estimated -= diff;
used = StorageMetrics();
lastKey = key;
}
reply.used = getMetrics(KeyRangeRef(lastKey, req.keys.end)) + used;
req.reply.send(reply);
} catch (Error& e) {
req.reply.sendError(e);
}
}
void StorageServerMetrics::getStorageMetrics(GetStorageMetricsRequest req,
StorageBytes sb,
double bytesInputRate,
int64_t versionLag,
double lastUpdate) const {
GetStorageMetricsReply rep;
// SOMEDAY: make bytes dynamic with hard disk space
rep.load = getMetrics(allKeys);
if (sb.free < 1e9) {
TraceEvent(SevWarn, "PhysicalDiskMetrics")
.suppressFor(60.0)
.detail("Free", sb.free)
.detail("Total", sb.total)
.detail("Available", sb.available)
.detail("Load", rep.load.bytes);
}
rep.available.bytes = sb.available;
rep.available.iosPerKSecond = 10e6;
rep.available.bytesPerKSecond = 100e9;
rep.available.bytesReadPerKSecond = 100e9;
rep.capacity.bytes = sb.total;
rep.capacity.iosPerKSecond = 10e6;
rep.capacity.bytesPerKSecond = 100e9;
rep.capacity.bytesReadPerKSecond = 100e9;
rep.bytesInputRate = bytesInputRate;
rep.versionLag = versionLag;
rep.lastUpdate = lastUpdate;
req.reply.send(rep);
}
// Given a read hot shard, this function will divide the shard into chunks and find those chunks whose
// readBytes/sizeBytes exceeds the `readDensityRatio`. Please make sure to run the unit tests in
// `StorageMetricsSampleTests.txt` after making changes.
std::vector<ReadHotRangeWithMetrics> StorageServerMetrics::getReadHotRanges(
KeyRangeRef shard,
double readDensityRatio,
int64_t baseChunkSize,
int64_t minShardReadBandwidthPerKSeconds) const {
std::vector<ReadHotRangeWithMetrics> toReturn;
double shardSize = (double)byteSample.getEstimate(shard);
int64_t shardReadBandwidth = bytesReadSample.getEstimate(shard);
if (shardReadBandwidth * SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL_PER_KSECONDS <=
minShardReadBandwidthPerKSeconds) {
return toReturn;
}
if (shardSize <= baseChunkSize) {
// Shard is small, use it as is
if (bytesReadSample.getEstimate(shard) > (readDensityRatio * shardSize)) {
toReturn.emplace_back(shard,
bytesReadSample.getEstimate(shard) / shardSize,
bytesReadSample.getEstimate(shard) / SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL);
}
return toReturn;
}
KeyRef beginKey = shard.begin;
auto endKey =
byteSample.sample.index(byteSample.sample.sumTo(byteSample.sample.lower_bound(beginKey)) + baseChunkSize);
while (endKey != byteSample.sample.end()) {
if (*endKey > shard.end) {
endKey = byteSample.sample.lower_bound(shard.end);
if (*endKey == beginKey) {
// No need to increment endKey, since otherwise it would get stuck here forever.
break;
}
}
if (*endKey == beginKey) {
++endKey;
continue;
}
if (bytesReadSample.getEstimate(KeyRangeRef(beginKey, *endKey)) >
(readDensityRatio * std::max(baseChunkSize, byteSample.getEstimate(KeyRangeRef(beginKey, *endKey))))) {
auto range = KeyRangeRef(beginKey, *endKey);
if (!toReturn.empty() && toReturn.back().keys.end == range.begin) {
// If two consecutive chunks are both over the ratio, merge them.
range = KeyRangeRef(toReturn.back().keys.begin, *endKey);
toReturn.pop_back();
}
toReturn.emplace_back(range,
(double)bytesReadSample.getEstimate(range) /
std::max(baseChunkSize, byteSample.getEstimate(range)),
bytesReadSample.getEstimate(range) / SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL);
}
beginKey = *endKey;
endKey =
byteSample.sample.index(byteSample.sample.sumTo(byteSample.sample.lower_bound(beginKey)) + baseChunkSize);
}
return toReturn;
}
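The chunking above can be mirrored in a standalone sketch over two plain key-to-bytes samples (stored bytes and read bytes): walk the shard in pieces of roughly baseChunkSize stored bytes, flag pieces whose read/size ratio exceeds the density threshold, and merge adjacent hot pieces. Everything below is illustrative (plain std::map, hypothetical names), not the FDB sample types.

#include <algorithm>
#include <map>
#include <string>
#include <vector>

using Sample = std::map<std::string, long long>; // key -> sampled bytes (illustrative)

static long long sumRange(const Sample& s, const std::string& begin, const std::string& end) {
    long long total = 0;
    for (auto it = s.lower_bound(begin); it != s.lower_bound(end); ++it)
        total += it->second;
    return total;
}

struct HotRange {
    std::string begin, end;
    double density;
};

// Walk [shardBegin, shardEnd) in chunks of roughly baseChunkSize stored bytes and report
// chunks whose read bytes exceed readDensityRatio times their size, merging consecutive
// hot chunks the way the loop above does.
std::vector<HotRange> readHotChunks(const Sample& sizeSample,
                                    const Sample& readSample,
                                    const std::string& shardBegin,
                                    const std::string& shardEnd,
                                    double readDensityRatio,
                                    long long baseChunkSize) {
    std::vector<HotRange> hot;
    std::string chunkBegin = shardBegin;
    auto it = sizeSample.lower_bound(shardBegin);
    const auto last = sizeSample.lower_bound(shardEnd);
    while (it != last) {
        // Extend the chunk until it holds at least baseChunkSize stored bytes.
        long long stored = 0;
        while (it != last && stored < baseChunkSize) {
            stored += it->second;
            ++it;
        }
        std::string chunkEnd = (it == last) ? shardEnd : it->first;
        long long read = sumRange(readSample, chunkBegin, chunkEnd);
        if (read > readDensityRatio * std::max(baseChunkSize, stored)) {
            if (!hot.empty() && hot.back().end == chunkBegin) {
                chunkBegin = hot.back().begin; // merge with the previous hot chunk
                hot.pop_back();
                read = sumRange(readSample, chunkBegin, chunkEnd);
                stored = sumRange(sizeSample, chunkBegin, chunkEnd);
            }
            hot.push_back({ chunkBegin, chunkEnd,
                            double(read) / double(std::max(baseChunkSize, stored)) });
        }
        chunkBegin = chunkEnd;
    }
    return hot;
}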
void StorageServerMetrics::getReadHotRanges(ReadHotSubRangeRequest req) const {
ReadHotSubRangeReply reply;
auto _ranges = getReadHotRanges(req.keys,
SERVER_KNOBS->SHARD_MAX_READ_DENSITY_RATIO,
SERVER_KNOBS->READ_HOT_SUB_RANGE_CHUNK_SIZE,
SERVER_KNOBS->SHARD_READ_HOT_BANDWIDTH_MIN_PER_KSECONDS);
reply.readHotRanges = VectorRef(_ranges.data(), _ranges.size());
req.reply.send(reply);
}
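The scan above walks the byte sample in chunks of roughly baseChunkSize stored bytes and keeps the chunks whose read-to-size ratio exceeds readDensityRatio, merging adjacent hot chunks. A minimal standalone sketch of that idea using plain C++ containers rather than FoundationDB's byteSample/IndexedSet types (SampledKey, readHotChunks, and the demo values are made up for illustration):
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>
struct SampledKey {
	std::string key;
	long long bytes; // estimated stored bytes attributed to this key
	long long readBytes; // estimated bytes read at this key over the sampling interval
};
struct HotRange {
	std::string begin, end; // half-open key range
	double density; // readBytes / max(storedBytes, baseChunkSize)
};
std::vector<HotRange> readHotChunks(const std::vector<SampledKey>& sample,
                                    double readDensityRatio,
                                    long long baseChunkSize) {
	std::vector<HotRange> out;
	size_t i = 0;
	while (i < sample.size()) {
		size_t start = i;
		long long chunkBytes = 0, chunkReads = 0;
		// Accumulate sampled keys until the chunk reaches roughly baseChunkSize stored bytes.
		while (i < sample.size() && chunkBytes < baseChunkSize) {
			chunkBytes += sample[i].bytes;
			chunkReads += sample[i].readBytes;
			++i;
		}
		std::string begin = sample[start].key;
		std::string end = (i < sample.size()) ? sample[i].key : std::string("\xff");
		double density = (double)chunkReads / std::max(chunkBytes, baseChunkSize);
		if (density > readDensityRatio) {
			if (!out.empty() && out.back().end == begin) {
				// Two consecutive chunks are both over the ratio: merge them.
				begin = out.back().begin;
				out.pop_back();
			}
			out.push_back({ begin, end, density });
		}
	}
	return out;
}
int main() {
	std::vector<SampledKey> sample = {
		{ "a", 100, 10 }, { "b", 100, 900 }, { "c", 100, 800 }, { "d", 100, 5 }, { "e", 100, 1 }
	};
	for (const auto& r : readHotChunks(sample, 2.0, 200))
		std::cout << "[" << r.begin << ", " << r.end << ") density=" << r.density << "\n";
}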
void StorageServerMetrics::getSplitPoints(SplitRangeRequest req, Optional<Key> prefix) const {
SplitRangeReply reply;
KeyRangeRef range = req.keys;
if (prefix.present()) {
range = range.withPrefix(prefix.get(), req.arena);
}
std::vector<KeyRef> points = getSplitPoints(range, req.chunkSize, prefix);
reply.splitPoints.append_deep(reply.splitPoints.arena(), points.data(), points.size());
req.reply.send(reply);
}
std::vector<KeyRef> StorageServerMetrics::getSplitPoints(KeyRangeRef range,
int64_t chunkSize,
Optional<Key> prefixToRemove) const {
std::vector<KeyRef> toReturn;
KeyRef beginKey = range.begin;
IndexedSet<Key, int64_t>::const_iterator endKey =
byteSample.sample.index(byteSample.sample.sumTo(byteSample.sample.lower_bound(beginKey)) + chunkSize);
while (endKey != byteSample.sample.end()) {
if (*endKey > range.end) {
break;
}
if (*endKey == beginKey) {
++endKey;
continue;
}
KeyRef splitPoint = *endKey;
if (prefixToRemove.present()) {
splitPoint = splitPoint.removePrefix(prefixToRemove.get());
}
toReturn.push_back(splitPoint);
beginKey = *endKey;
endKey = byteSample.sample.index(byteSample.sample.sumTo(byteSample.sample.lower_bound(beginKey)) + chunkSize);
}
return toReturn;
}
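As a hypothetical worked example of the function above: if the byte sample's cumulative estimate over the requested range crosses chunkSize at key `f` and roughly 2*chunkSize at key `m`, getSplitPoints returns `f` and `m` (with prefixToRemove stripped when present), so each resulting chunk holds roughly chunkSize bytes; a final partial chunk smaller than chunkSize produces no extra split point.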
void StorageServerMetrics::collapse(KeyRangeMap<int>& map, KeyRef const& key) {
auto range = map.rangeContaining(key);
if (range == map.ranges().begin() || range == map.ranges().end())
return;
int value = range->value();
auto prev = range;
--prev;
if (prev->value() != value)
return;
KeyRange keys = KeyRangeRef(prev->begin(), range->end());
map.insert(keys, value);
}
void StorageServerMetrics::add(KeyRangeMap<int>& map, KeyRangeRef const& keys, int delta) {
auto rs = map.modify(keys);
for (auto r = rs.begin(); r != rs.end(); ++r)
r->value() += delta;
collapse(map, keys.begin);
collapse(map, keys.end);
}
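collapse() and add() above maintain a key-range map of ints in which adjacent ranges holding equal values are merged back together after each update. A minimal standalone sketch of the same idea over a std::map keyed by range-begin (IntervalMap, ensureBoundary, and the demo keys are illustrative, not FDB's KeyRangeMap API):
#include <iostream>
#include <iterator>
#include <map>
#include <string>
using IntervalMap = std::map<std::string, int>; // begin key -> value for [key, nextKey)
void collapse(IntervalMap& m, const std::string& key) {
	auto it = m.find(key);
	if (it == m.end() || it == m.begin())
		return;
	auto prev = std::prev(it);
	if (prev->second == it->second)
		m.erase(it); // the previous range absorbs this one
}
void add(IntervalMap& m, const std::string& begin, const std::string& end, int delta) {
	// Split so that boundaries exist at begin and end, then bump every range inside.
	auto ensureBoundary = [&](const std::string& k) {
		auto it = m.upper_bound(k);
		int v = (it == m.begin()) ? 0 : std::prev(it)->second;
		m.emplace(k, v); // no-op if a boundary already exists at k
	};
	ensureBoundary(begin);
	ensureBoundary(end);
	for (auto it = m.find(begin); it != m.end() && it->first < end; ++it)
		it->second += delta;
	collapse(m, begin);
	collapse(m, end);
}
int main() {
	IntervalMap m = { { "", 0 } }; // one range covering everything, value 0
	add(m, "b", "d", 1);
	add(m, "d", "f", 1); // coalesces with [b, d) because both now hold value 1
	for (auto& [k, v] : m)
		std::cout << "[" << (k.empty() ? "<min>" : k) << ", ...) = " << v << "\n";
}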
// Returns the sampled metric value (possibly 0, possibly increased by the sampling factor)
int64_t TransientStorageMetricSample::addAndExpire(KeyRef key, int64_t metric, double expiration) {
int64_t x = add(key, metric);
if (x)
queue.emplace_back(expiration, std::make_pair(*sample.find(key), -x));
return x;
}
// FIXME: both versions of erase are broken, because they do not remove items in the queue which will subtract a
// metric from the value sometime in the future
int64_t TransientStorageMetricSample::erase(KeyRef key) {
auto it = sample.find(key);
if (it == sample.end())
return 0;
int64_t x = sample.getMetric(it);
sample.erase(it);
return x;
}
void TransientStorageMetricSample::erase(KeyRangeRef keys) {
sample.erase(keys.begin, keys.end);
}
bool TransientStorageMetricSample::roll(KeyRef key, int64_t metric) const {
return deterministicRandom()->random01() < (double)metric / metricUnitsPerSample; //< SOMEDAY: Better randomInt64?
}
void TransientStorageMetricSample::poll(KeyRangeMap<std::vector<PromiseStream<StorageMetrics>>>& waitMap,
StorageMetrics m) {
double now = ::now();
while (queue.size() && queue.front().first <= now) {
KeyRef key = queue.front().second.first;
int64_t delta = queue.front().second.second;
ASSERT(delta != 0);
if (sample.addMetric(key, delta) == 0)
sample.erase(key);
StorageMetrics deltaM = m * delta;
auto v = waitMap[key];
for (int i = 0; i < v.size(); i++) {
CODE_PROBE(true, "TransientStorageMetricSample poll update");
v[i].send(deltaM);
}
queue.pop_front();
}
}
void TransientStorageMetricSample::poll() {
double now = ::now();
while (queue.size() && queue.front().first <= now) {
KeyRef key = queue.front().second.first;
int64_t delta = queue.front().second.second;
ASSERT(delta != 0);
if (sample.addMetric(key, delta) == 0)
sample.erase(key);
queue.pop_front();
}
}
int64_t TransientStorageMetricSample::add(KeyRef key, int64_t metric) {
if (!metric)
return 0;
int64_t mag = metric < 0 ? -metric : metric;
if (mag < metricUnitsPerSample) {
if (!roll(key, mag))
return 0;
metric = metric < 0 ? -metricUnitsPerSample : metricUnitsPerSample;
}
if (sample.addMetric(key, metric) == 0)
sample.erase(key);
return metric;
}
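add() and roll() above implement unbiased sub-sampling: a metric smaller than metricUnitsPerSample is either dropped or rounded up to a full sample unit, with the recording probability chosen so the expected recorded value equals the true value. A standalone sketch of that rule (sampleMetric and the demo numbers are illustrative):
#include <cstdint>
#include <iostream>
#include <random>
int64_t sampleMetric(int64_t metric, int64_t unitsPerSample, std::mt19937_64& rng) {
	if (metric == 0)
		return 0;
	int64_t mag = metric < 0 ? -metric : metric;
	if (mag < unitsPerSample) {
		std::uniform_real_distribution<double> u(0.0, 1.0);
		if (u(rng) >= (double)mag / unitsPerSample)
			return 0; // dropped; happens with probability 1 - mag/unitsPerSample
		// Recorded, rounded up to a full sample unit (sign preserved), so E[recorded] == metric.
		return metric < 0 ? -unitsPerSample : unitsPerSample;
	}
	return metric; // large metrics are recorded exactly
}
int main() {
	std::mt19937_64 rng(42);
	// For metric=100, unitsPerSample=1000: recorded is 1000 with probability 0.1, else 0,
	// so the long-run average should be close to 100.
	int64_t total = 0;
	const int trials = 100000;
	for (int i = 0; i < trials; i++)
		total += sampleMetric(100, 1000, rng);
	std::cout << "average recorded metric: " << (double)total / trials << "\n";
}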
TEST_CASE("/fdbserver/StorageMetricSample/simple") {
StorageMetricSample s(1000);
s.sample.insert(LiteralStringRef("Apple"), 1000);

View File

@ -2192,15 +2192,15 @@ public:
int64_t remapCleanupWindowBytes,
int concurrentExtentReads,
bool memoryOnly,
std::shared_ptr<IEncryptionKeyProvider> keyProvider,
Reference<IEncryptionKeyProvider> keyProvider,
Promise<Void> errorPromise = {})
: keyProvider(keyProvider), ioLock(FLOW_KNOBS->MAX_OUTSTANDING, ioMaxPriority, FLOW_KNOBS->MAX_OUTSTANDING / 2),
pageCacheBytes(pageCacheSizeBytes), desiredPageSize(desiredPageSize), desiredExtentSize(desiredExtentSize),
filename(filename), memoryOnly(memoryOnly), errorPromise(errorPromise),
remapCleanupWindowBytes(remapCleanupWindowBytes), concurrentExtentReads(new FlowLock(concurrentExtentReads)) {
if (keyProvider == nullptr) {
keyProvider = std::make_shared<NullKeyProvider>();
if (!keyProvider) {
keyProvider = makeReference<NullKeyProvider>();
}
// This sets the page cache size for all PageCacheT instances using the same evictor
@ -2963,11 +2963,8 @@ public:
page->rawData(),
header);
int readBytes = wait(readPhysicalBlock(self,
page->rawData(),
page->rawSize(),
(int64_t)pageID * page->rawSize(),
std::min(priority, ioMaxPriority)));
int readBytes = wait(
readPhysicalBlock(self, page->rawData(), page->rawSize(), (int64_t)pageID * page->rawSize(), priority));
debug_printf("DWALPager(%s) op=readPhysicalDiskReadComplete %s ptr=%p bytes=%d\n",
self->filename.c_str(),
toString(pageID).c_str(),
@ -3958,7 +3955,7 @@ private:
int physicalExtentSize;
int pagesPerExtent;
std::shared_ptr<IEncryptionKeyProvider> keyProvider;
Reference<IEncryptionKeyProvider> keyProvider;
PriorityMultiLock ioLock;
@ -5039,7 +5036,7 @@ public:
VersionedBTree(IPager2* pager,
std::string name,
EncodingType defaultEncodingType,
std::shared_ptr<IEncryptionKeyProvider> keyProvider)
Reference<IEncryptionKeyProvider> keyProvider)
: m_pager(pager), m_encodingType(defaultEncodingType), m_enforceEncodingType(false), m_keyProvider(keyProvider),
m_pBuffer(nullptr), m_mutationCount(0), m_name(name) {
@ -5047,13 +5044,13 @@ public:
// This prevents an attack where an encrypted page is replaced by an attacker with an unencrypted page
// or an encrypted page fabricated using a compromised scheme.
if (ArenaPage::isEncodingTypeEncrypted(m_encodingType)) {
ASSERT(keyProvider != nullptr);
ASSERT(keyProvider.isValid());
m_enforceEncodingType = true;
}
// If key provider isn't given, instantiate the null provider
if (m_keyProvider == nullptr) {
m_keyProvider = std::make_shared<NullKeyProvider>();
if (!m_keyProvider) {
m_keyProvider = makeReference<NullKeyProvider>();
}
m_pBoundaryVerifier = DecodeBoundaryVerifier::getVerifier(name);
@ -5239,6 +5236,17 @@ public:
self->m_lazyClearQueue.recover(self->m_pager, self->m_header.lazyDeleteQueue, "LazyClearQueueRecovered");
debug_printf("BTree recovered.\n");
if (ArenaPage::isEncodingTypeEncrypted(self->m_header.encodingType) &&
self->m_encodingType == EncodingType::XXHash64) {
// On restart the encryption config of the cluster could be unknown. In that case, if we find the Redwood
// instance is encrypted, we should use the same encryption encoding.
self->m_encodingType = self->m_header.encodingType;
self->m_enforceEncodingType = true;
TraceEvent("RedwoodBTreeNodeForceEncryption")
.detail("InstanceName", self->m_pager->getName())
.detail("EncodingFound", self->m_header.encodingType)
.detail("EncodingDesired", self->m_encodingType);
}
if (self->m_header.encodingType != self->m_encodingType) {
TraceEvent(SevWarn, "RedwoodBTreeNodeEncodingMismatch")
.detail("InstanceName", self->m_pager->getName())
@ -5535,7 +5543,7 @@ private:
IPager2* m_pager;
EncodingType m_encodingType;
bool m_enforceEncodingType;
std::shared_ptr<IEncryptionKeyProvider> m_keyProvider;
Reference<IEncryptionKeyProvider> m_keyProvider;
// Counter to update with DecodeCache memory usage
int64_t* m_pDecodeCacheMemory = nullptr;
@ -7355,6 +7363,7 @@ public:
private:
PagerEventReasons reason;
Optional<ReadOptions> options;
VersionedBTree* btree;
Reference<IPagerSnapshot> pager;
bool valid;
@ -7420,7 +7429,7 @@ public:
link.get().getChildPage(),
ioMaxPriority,
false,
true),
!options.present() || options.get().cacheResult || path.back().btPage()->height != 2),
[=](Reference<const ArenaPage> p) {
#if REDWOOD_DEBUG
path.push_back({ p, btree->getCursor(p.getPtr(), link), link.get().getChildPage() });
@ -7454,10 +7463,12 @@ public:
// Initialize or reinitialize cursor
Future<Void> init(VersionedBTree* btree_in,
PagerEventReasons reason_in,
Optional<ReadOptions> options_in,
Reference<IPagerSnapshot> pager_in,
BTreeNodeLink root) {
btree = btree_in;
reason = reason_in;
options = options_in;
pager = pager_in;
path.clear();
path.reserve(6);
@ -7652,7 +7663,10 @@ public:
Future<Void> movePrev() { return path.empty() ? Void() : move_impl(this, false); }
};
Future<Void> initBTreeCursor(BTreeCursor* cursor, Version snapshotVersion, PagerEventReasons reason) {
Future<Void> initBTreeCursor(BTreeCursor* cursor,
Version snapshotVersion,
PagerEventReasons reason,
Optional<ReadOptions> options = Optional<ReadOptions>()) {
Reference<IPagerSnapshot> snapshot = m_pager->getReadSnapshot(snapshotVersion);
BTreeNodeLinkRef root;
@ -7669,7 +7683,7 @@ public:
root = *snapshot->extra.getPtr<BTreeNodeLink>();
}
return cursor->init(this, reason, snapshot, root);
return cursor->init(this, reason, options, snapshot, root);
}
};
@ -7680,7 +7694,7 @@ RedwoodRecordRef VersionedBTree::dbEnd(LiteralStringRef("\xff\xff\xff\xff\xff"))
class KeyValueStoreRedwood : public IKeyValueStore {
public:
KeyValueStoreRedwood(std::string filename, UID logID)
KeyValueStoreRedwood(std::string filename, UID logID, Reference<IEncryptionKeyProvider> encryptionKeyProvider)
: m_filename(filename), m_concurrentReads(SERVER_KNOBS->REDWOOD_KVSTORE_CONCURRENT_READS, 0),
prefetch(SERVER_KNOBS->REDWOOD_KVSTORE_RANGE_PREFETCH) {
@ -7703,10 +7717,15 @@ public:
EncodingType encodingType = EncodingType::XXHash64;
// Deterministically enable encryption based on uid
if (g_network->isSimulated() && logID.hash() % 2 == 0) {
encodingType = EncodingType::XOREncryption;
m_keyProvider = std::make_shared<XOREncryptionKeyProvider>(filename);
// When reopening Redwood on restart, the cluster encryption config could be unknown at this point,
// in which case shouldEnableEncryption will return false. In that case, if the Redwood instance was encrypted
// before, the encoding type in the header page will be used instead.
//
// TODO(yiwu): When the cluster encryption config is available later, fail if the cluster is configured to
// enable encryption, but the Redwood instance is unencrypted.
if (encryptionKeyProvider && encryptionKeyProvider->shouldEnableEncryption()) {
encodingType = EncodingType::AESEncryptionV1;
m_keyProvider = encryptionKeyProvider;
}
IPager2* pager = new DWALPager(pageSize,
@ -7798,18 +7817,26 @@ public:
m_tree->set(keyValue);
}
Future<RangeResult> readRange(KeyRangeRef keys, int rowLimit, int byteLimit, IKeyValueStore::ReadType) override {
Future<RangeResult> readRange(KeyRangeRef keys,
int rowLimit,
int byteLimit,
Optional<ReadOptions> options) override {
debug_printf("READRANGE %s\n", printable(keys).c_str());
return catchError(readRange_impl(this, keys, rowLimit, byteLimit));
return catchError(readRange_impl(this, keys, rowLimit, byteLimit, options));
}
ACTOR static Future<RangeResult> readRange_impl(KeyValueStoreRedwood* self,
KeyRange keys,
int rowLimit,
int byteLimit) {
int byteLimit,
Optional<ReadOptions> options) {
state PagerEventReasons reason = PagerEventReasons::RangeRead;
state VersionedBTree::BTreeCursor cur;
wait(
self->m_tree->initBTreeCursor(&cur, self->m_tree->getLastCommittedVersion(), PagerEventReasons::RangeRead));
if (options.present() && options.get().type == ReadType::FETCH) {
reason = PagerEventReasons::FetchRange;
}
wait(self->m_tree->initBTreeCursor(&cur, self->m_tree->getLastCommittedVersion(), reason, options));
state PriorityMultiLock::Lock lock;
state Future<Void> f;
++g_redwoodMetrics.metric.opGetRange;
@ -7945,10 +7972,12 @@ public:
return result;
}
ACTOR static Future<Optional<Value>> readValue_impl(KeyValueStoreRedwood* self, Key key, Optional<UID> debugID) {
ACTOR static Future<Optional<Value>> readValue_impl(KeyValueStoreRedwood* self,
Key key,
Optional<ReadOptions> options) {
state VersionedBTree::BTreeCursor cur;
wait(
self->m_tree->initBTreeCursor(&cur, self->m_tree->getLastCommittedVersion(), PagerEventReasons::PointRead));
wait(self->m_tree->initBTreeCursor(
&cur, self->m_tree->getLastCommittedVersion(), PagerEventReasons::PointRead, options));
// Not locking for point reads, instead relying on IO priority lock
// state PriorityMultiLock::Lock lock = wait(self->m_concurrentReads.lock());
@ -7967,15 +7996,12 @@ public:
return Optional<Value>();
}
Future<Optional<Value>> readValue(KeyRef key, IKeyValueStore::ReadType, Optional<UID> debugID) override {
return catchError(readValue_impl(this, key, debugID));
Future<Optional<Value>> readValue(KeyRef key, Optional<ReadOptions> options) override {
return catchError(readValue_impl(this, key, options));
}
Future<Optional<Value>> readValuePrefix(KeyRef key,
int maxLength,
IKeyValueStore::ReadType,
Optional<UID> debugID) override {
return catchError(map(readValue_impl(this, key, debugID), [maxLength](Optional<Value> v) {
Future<Optional<Value>> readValuePrefix(KeyRef key, int maxLength, Optional<ReadOptions> options) override {
return catchError(map(readValue_impl(this, key, options), [maxLength](Optional<Value> v) {
if (v.present() && v.get().size() > maxLength) {
v.get().contents() = v.get().substr(0, maxLength);
}
@ -7994,7 +8020,7 @@ private:
PriorityMultiLock m_concurrentReads;
bool prefetch;
Version m_nextCommitVersion;
std::shared_ptr<IEncryptionKeyProvider> m_keyProvider;
Reference<IEncryptionKeyProvider> m_keyProvider;
Future<Void> m_lastCommit = Void();
template <typename T>
@ -8003,8 +8029,10 @@ private:
}
};
IKeyValueStore* keyValueStoreRedwoodV1(std::string const& filename, UID logID) {
return new KeyValueStoreRedwood(filename, logID);
IKeyValueStore* keyValueStoreRedwoodV1(std::string const& filename,
UID logID,
Reference<IEncryptionKeyProvider> encryptionKeyProvider) {
return new KeyValueStoreRedwood(filename, logID, encryptionKeyProvider);
}
int randomSize(int max) {
@ -9735,7 +9763,7 @@ TEST_CASE("Lredwood/correctness/btree") {
state bool shortTest = params.getInt("shortTest").orDefault(deterministicRandom()->random01() < 0.25);
state int pageSize =
shortTest ? 200 : (deterministicRandom()->coinflip() ? 4096 : deterministicRandom()->randomInt(200, 400));
shortTest ? 250 : (deterministicRandom()->coinflip() ? 4096 : deterministicRandom()->randomInt(250, 400));
state int extentSize =
params.getInt("extentSize")
.orDefault(deterministicRandom()->coinflip() ? SERVER_KNOBS->REDWOOD_DEFAULT_EXTENT_SIZE
@ -9784,12 +9812,13 @@ TEST_CASE("Lredwood/correctness/btree") {
// Max number of records in the BTree or the versioned written map to visit
state int64_t maxRecordsRead = params.getInt("maxRecordsRead").orDefault(300e6);
state EncodingType encodingType = EncodingType::XXHash64;
state std::shared_ptr<IEncryptionKeyProvider> keyProvider;
if (deterministicRandom()->coinflip()) {
encodingType = EncodingType::XOREncryption;
keyProvider = std::make_shared<XOREncryptionKeyProvider>(file);
state EncodingType encodingType =
static_cast<EncodingType>(deterministicRandom()->randomInt(0, EncodingType::MAX_ENCODING_TYPE));
state Reference<IEncryptionKeyProvider> keyProvider;
if (encodingType == EncodingType::AESEncryptionV1) {
keyProvider = makeReference<RandomEncryptionKeyProvider>();
} else if (encodingType == EncodingType::XOREncryption_TestOnly) {
keyProvider = makeReference<XOREncryptionKeyProvider_TestOnly>(file);
}
printf("\n");
@ -10271,7 +10300,7 @@ TEST_CASE(":/redwood/performance/extentQueue") {
remapCleanupWindowBytes,
concurrentExtentReads,
false,
nullptr);
Reference<IEncryptionKeyProvider>());
wait(success(pager->init()));
@ -10322,8 +10351,14 @@ TEST_CASE(":/redwood/performance/extentQueue") {
}
printf("Reopening pager file from disk.\n");
pager = new DWALPager(
pageSize, extentSize, fileName, cacheSizeBytes, remapCleanupWindowBytes, concurrentExtentReads, false, nullptr);
pager = new DWALPager(pageSize,
extentSize,
fileName,
cacheSizeBytes,
remapCleanupWindowBytes,
concurrentExtentReads,
false,
Reference<IEncryptionKeyProvider>());
wait(success(pager->init()));
printf("Starting ExtentQueue FastPath Recovery from Disk.\n");
@ -10468,8 +10503,9 @@ TEST_CASE(":/redwood/performance/set") {
remapCleanupWindowBytes,
concurrentExtentReads,
pagerMemoryOnly,
nullptr);
state VersionedBTree* btree = new VersionedBTree(pager, file, EncodingType::XXHash64, nullptr);
Reference<IEncryptionKeyProvider>());
state VersionedBTree* btree =
new VersionedBTree(pager, file, EncodingType::XXHash64, Reference<IEncryptionKeyProvider>());
wait(btree->init());
printf("Initialized. StorageBytes=%s\n", btree->getStorageBytes().toString().c_str());
@ -10997,7 +11033,9 @@ ACTOR Future<Void> randomRangeScans(IKeyValueStore* kvs,
int valueSize,
int recordCountTarget,
bool singlePrefix,
int rowLimit) {
int rowLimit,
int byteLimit,
Optional<ReadOptions> options = Optional<ReadOptions>()) {
fmt::print("\nstoreType: {}\n", static_cast<int>(kvs->getType()));
fmt::print("prefixSource: {}\n", source.toString());
fmt::print("suffixSize: {}\n", suffixSize);
@ -11030,7 +11068,7 @@ ACTOR Future<Void> randomRangeScans(IKeyValueStore* kvs,
KeyRangeRef range = source.getKeyRangeRef(singlePrefix, suffixSize);
int rowLim = (deterministicRandom()->randomInt(0, 2) != 0) ? rowLimit : -rowLimit;
RangeResult result = wait(kvs->readRange(range, rowLim));
RangeResult result = wait(kvs->readRange(range, rowLim, byteLimit, options));
recordsRead += result.size();
bytesRead += result.size() * recordSize;
@ -11053,6 +11091,7 @@ TEST_CASE(":/redwood/performance/randomRangeScans") {
state int prefixLen = 30;
state int suffixSize = 12;
state int valueSize = 100;
state int maxByteLimit = std::numeric_limits<int>::max();
// TODO change to 100e8 after figuring out no-disk redwood mode
state int writeRecordCountTarget = 1e6;
@ -11068,11 +11107,11 @@ TEST_CASE(":/redwood/performance/randomRangeScans") {
redwood, suffixSize, valueSize, source, writeRecordCountTarget, writePrefixesInOrder, false));
// divide targets for tiny queries by 10 because they are much slower
wait(randomRangeScans(redwood, suffixSize, source, valueSize, queryRecordTarget / 10, true, 10));
wait(randomRangeScans(redwood, suffixSize, source, valueSize, queryRecordTarget, true, 1000));
wait(randomRangeScans(redwood, suffixSize, source, valueSize, queryRecordTarget / 10, false, 100));
wait(randomRangeScans(redwood, suffixSize, source, valueSize, queryRecordTarget, false, 10000));
wait(randomRangeScans(redwood, suffixSize, source, valueSize, queryRecordTarget, false, 1000000));
wait(randomRangeScans(redwood, suffixSize, source, valueSize, queryRecordTarget / 10, true, 10, maxByteLimit));
wait(randomRangeScans(redwood, suffixSize, source, valueSize, queryRecordTarget, true, 1000, maxByteLimit));
wait(randomRangeScans(redwood, suffixSize, source, valueSize, queryRecordTarget / 10, false, 100, maxByteLimit));
wait(randomRangeScans(redwood, suffixSize, source, valueSize, queryRecordTarget, false, 10000, maxByteLimit));
wait(randomRangeScans(redwood, suffixSize, source, valueSize, queryRecordTarget, false, 1000000, maxByteLimit));
wait(closeKVS(redwood));
printf("\n");
return Void();

View File

@ -113,7 +113,7 @@ enum {
OPT_METRICSPREFIX, OPT_LOGGROUP, OPT_LOCALITY, OPT_IO_TRUST_SECONDS, OPT_IO_TRUST_WARN_ONLY, OPT_FILESYSTEM, OPT_PROFILER_RSS_SIZE, OPT_KVFILE,
OPT_TRACE_FORMAT, OPT_WHITELIST_BINPATH, OPT_BLOB_CREDENTIAL_FILE, OPT_CONFIG_PATH, OPT_USE_TEST_CONFIG_DB, OPT_FAULT_INJECTION, OPT_PROFILER, OPT_PRINT_SIMTIME,
OPT_FLOW_PROCESS_NAME, OPT_FLOW_PROCESS_ENDPOINT, OPT_IP_TRUSTED_MASK, OPT_KMS_CONN_DISCOVERY_URL_FILE, OPT_KMS_CONNECTOR_TYPE, OPT_KMS_CONN_VALIDATION_TOKEN_DETAILS,
OPT_KMS_CONN_GET_ENCRYPTION_KEYS_ENDPOINT, OPT_NEW_CLUSTER_KEY, OPT_USE_FUTURE_PROTOCOL_VERSION
OPT_KMS_CONN_GET_ENCRYPTION_KEYS_ENDPOINT, OPT_NEW_CLUSTER_KEY, OPT_AUTHZ_PUBLIC_KEY_FILE, OPT_USE_FUTURE_PROTOCOL_VERSION
};
CSimpleOpt::SOption g_rgOptions[] = {
@ -128,8 +128,8 @@ CSimpleOpt::SOption g_rgOptions[] = {
{ OPT_LISTEN, "-l", SO_REQ_SEP },
{ OPT_LISTEN, "--listen-address", SO_REQ_SEP },
#ifdef __linux__
{ OPT_FILESYSTEM, "--data-filesystem", SO_REQ_SEP },
{ OPT_PROFILER_RSS_SIZE, "--rsssize", SO_REQ_SEP },
{ OPT_FILESYSTEM, "--data-filesystem", SO_REQ_SEP },
{ OPT_PROFILER_RSS_SIZE, "--rsssize", SO_REQ_SEP },
#endif
{ OPT_DATAFOLDER, "-d", SO_REQ_SEP },
{ OPT_DATAFOLDER, "--datadir", SO_REQ_SEP },
@ -208,6 +208,7 @@ CSimpleOpt::SOption g_rgOptions[] = {
{ OPT_FLOW_PROCESS_ENDPOINT, "--process-endpoint", SO_REQ_SEP },
{ OPT_IP_TRUSTED_MASK, "--trusted-subnet-", SO_REQ_SEP },
{ OPT_NEW_CLUSTER_KEY, "--new-cluster-key", SO_REQ_SEP },
{ OPT_AUTHZ_PUBLIC_KEY_FILE, "--authorization-public-key-file", SO_REQ_SEP },
{ OPT_KMS_CONN_DISCOVERY_URL_FILE, "--discover-kms-conn-url-file", SO_REQ_SEP },
{ OPT_KMS_CONNECTOR_TYPE, "--kms-connector-type", SO_REQ_SEP },
{ OPT_KMS_CONN_VALIDATION_TOKEN_DETAILS, "--kms-conn-validation-token-details", SO_REQ_SEP },
@ -1022,8 +1023,8 @@ enum class ServerRole {
};
struct CLIOptions {
std::string commandLine;
std::string fileSystemPath, dataFolder, connFile, seedConnFile, seedConnString, logFolder = ".", metricsConnFile,
metricsPrefix, newClusterKey;
std::string fileSystemPath, dataFolder, connFile, seedConnFile, seedConnString,
logFolder = ".", metricsConnFile, metricsPrefix, newClusterKey, authzPublicKeyFile;
std::string logGroup = "default";
uint64_t rollsize = TRACE_DEFAULT_ROLL_SIZE;
uint64_t maxLogsSize = TRACE_DEFAULT_MAX_LOGS_SIZE;
@ -1713,6 +1714,10 @@ private:
}
break;
}
case OPT_AUTHZ_PUBLIC_KEY_FILE: {
authzPublicKeyFile = args.OptionArg();
break;
}
case OPT_USE_FUTURE_PROTOCOL_VERSION: {
if (!strcmp(args.OptionArg(), "true")) {
::useFutureProtocolVersion();
@ -2029,6 +2034,16 @@ int main(int argc, char* argv[]) {
openTraceFile(
opts.publicAddresses.address, opts.rollsize, opts.maxLogsSize, opts.logFolder, "trace", opts.logGroup);
g_network->initTLS();
if (!opts.authzPublicKeyFile.empty()) {
try {
FlowTransport::transport().loadPublicKeyFile(opts.authzPublicKeyFile);
} catch (Error& e) {
TraceEvent("AuthzPublicKeySetLoadError").error(e);
}
FlowTransport::transport().watchPublicKeyFile(opts.authzPublicKeyFile);
} else {
TraceEvent(SevInfo, "AuthzPublicKeyFileNotSet");
}
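With the wiring above, the public key file is loaded once at startup (load errors are traced but not fatal) and then watched for changes; if the option is omitted, only an informational trace event is emitted. A hypothetical invocation (the paths, and the assumption that the file holds the authorization token-verification keys, are illustrative):
fdbserver -d /var/lib/foundationdb/data --authorization-public-key-file /etc/foundationdb/authz-public-keys.json [other options]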
if (expectsPublicAddress) {
for (int ii = 0; ii < (opts.publicAddresses.secondaryAddress.present() ? 2 : 1); ++ii) {
@ -2238,6 +2253,8 @@ int main(int argc, char* argv[]) {
KnobValue::create(ini.GetBoolValue("META", "enableEncryption", false)));
g_knobs.setKnob("enable_tlog_encryption",
KnobValue::create(ini.GetBoolValue("META", "enableTLogEncryption", false)));
g_knobs.setKnob("enable_storage_server_encryption",
KnobValue::create(ini.GetBoolValue("META", "enableStorageServerEncryption", false)));
g_knobs.setKnob("enable_blob_granule_encryption",
KnobValue::create(ini.GetBoolValue("META", "enableBlobGranuleEncryption", false)));
g_knobs.setKnob("enable_blob_granule_compression",

View File

@ -352,7 +352,7 @@ FDB_DECLARE_BOOLEAN_PARAM(MoveKeyRangeOutPhysicalShard);
class PhysicalShardCollection : public ReferenceCounted<PhysicalShardCollection> {
public:
PhysicalShardCollection() : lastTransitionStartTime(now()), requireTransition(false) {}
PhysicalShardCollection() : requireTransition(false), lastTransitionStartTime(now()) {}
enum class PhysicalShardCreationTime { DDInit, DDRelocator };

View File

@ -27,19 +27,20 @@
typedef enum { TLOG_ENCRYPTION = 0, STORAGE_SERVER_ENCRYPTION = 1, BLOB_GRANULE_ENCRYPTION = 2 } EncryptOperationType;
inline bool isEncryptionOpSupported(EncryptOperationType operation_type, ClientDBInfo dbInfo) {
inline bool isEncryptionOpSupported(EncryptOperationType operation_type, const ClientDBInfo& dbInfo) {
if (!dbInfo.isEncryptionEnabled) {
return false;
}
if (operation_type == TLOG_ENCRYPTION) {
return SERVER_KNOBS->ENABLE_TLOG_ENCRYPTION;
} else if (operation_type == STORAGE_SERVER_ENCRYPTION) {
return SERVER_KNOBS->ENABLE_STORAGE_SERVER_ENCRYPTION;
} else if (operation_type == BLOB_GRANULE_ENCRYPTION) {
bool supported = SERVER_KNOBS->ENABLE_BLOB_GRANULE_ENCRYPTION && SERVER_KNOBS->BG_METADATA_SOURCE == "tenant";
ASSERT((supported && SERVER_KNOBS->ENABLE_ENCRYPTION) || !supported);
return supported;
} else {
// TODO (Nim): Add once storage server encryption knob is created
return false;
}
}

View File

@ -0,0 +1,284 @@
/*
* IEncryptionKeyProvider.actor.h
*
* This source file is part of the FoundationDB open source project
*
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#if defined(NO_INTELLISENSE) && !defined(FDBSERVER_IENCRYPTIONKEYPROVIDER_ACTOR_G_H)
#define FDBSERVER_IENCRYPTIONKEYPROVIDER_ACTOR_G_H
#include "fdbserver/IEncryptionKeyProvider.actor.g.h"
#elif !defined(FDBSERVER_IENCRYPTIONKEYPROVIDER_ACTOR_H)
#define FDBSERVER_IENCRYPTIONKEYPROVIDER_ACTOR_H
#include "fdbclient/GetEncryptCipherKeys.actor.h"
#include "fdbclient/Tenant.h"
#include "fdbserver/EncryptionOpsUtils.h"
#include "fdbserver/ServerDBInfo.h"
#include "flow/Arena.h"
#define XXH_INLINE_ALL
#include "flow/xxhash.h"
#include "flow/actorcompiler.h" // This must be the last #include.
typedef uint64_t XOREncryptionKeyID;
// EncryptionKeyRef is somewhat multi-variant, it will contain members representing the union
// of all fields relevant to any implemented encryption scheme. They are generally of
// the form
// Page Fields - fields which come from or are stored in the Page
// Secret Fields - fields which are only known by the Key Provider
// but it is up to each encoding and provider which fields are which and which ones are used
//
// TODO(yiwu): Rename and/or refactor this struct. It doesn't sound like an encryption key should
// contain page fields like encryption header.
struct EncryptionKeyRef {
EncryptionKeyRef(){};
EncryptionKeyRef(Arena& arena, const EncryptionKeyRef& toCopy)
: cipherKeys(toCopy.cipherKeys), secret(arena, toCopy.secret), id(toCopy.id) {}
int expectedSize() const { return secret.size(); }
// Fields for AESEncryptionV1
TextAndHeaderCipherKeys cipherKeys;
Optional<BlobCipherEncryptHeader> cipherHeader;
// Fields for XOREncryption_TestOnly
StringRef secret;
Optional<XOREncryptionKeyID> id;
};
typedef Standalone<EncryptionKeyRef> EncryptionKey;
// Interface used by the pager to get encryption keys when reading pages from disk
// and by the BTree to get encryption keys to use for new pages
class IEncryptionKeyProvider : public ReferenceCounted<IEncryptionKeyProvider> {
public:
virtual ~IEncryptionKeyProvider() {}
// Get an EncryptionKey with Secret Fields populated based on the given Page Fields.
// It is up to the implementation which fields those are.
// The output Page Fields must match the input Page Fields.
virtual Future<EncryptionKey> getSecrets(const EncryptionKeyRef& key) = 0;
// Get encryption key that should be used for a given user Key-Value range
virtual Future<EncryptionKey> getByRange(const KeyRef& begin, const KeyRef& end) = 0;
// Sets the tenant prefix to tenant name map.
virtual void setTenantPrefixIndex(Reference<TenantPrefixIndex> tenantPrefixIndex) {}
virtual bool shouldEnableEncryption() const = 0;
};
// The null key provider is useful to simplify page decoding.
// It throws an error for any key info requested.
class NullKeyProvider : public IEncryptionKeyProvider {
public:
virtual ~NullKeyProvider() {}
bool shouldEnableEncryption() const override { return true; }
Future<EncryptionKey> getSecrets(const EncryptionKeyRef& key) override { throw encryption_key_not_found(); }
Future<EncryptionKey> getByRange(const KeyRef& begin, const KeyRef& end) override {
throw encryption_key_not_found();
}
};
// Key provider for dummy XOR encryption scheme
class XOREncryptionKeyProvider_TestOnly : public IEncryptionKeyProvider {
public:
XOREncryptionKeyProvider_TestOnly(std::string filename) {
ASSERT(g_network->isSimulated());
// Deterministically choose a byte from the filename (without the path) for secret generation.
// Remove any leading directory names
size_t lastSlash = filename.find_last_of("\\/");
if (lastSlash != filename.npos) {
filename.erase(0, lastSlash);
}
xorWith = filename.empty() ? 0x5e
: (uint8_t)filename[XXH3_64bits(filename.data(), filename.size()) % filename.size()];
}
virtual ~XOREncryptionKeyProvider_TestOnly() {}
bool shouldEnableEncryption() const override { return true; }
Future<EncryptionKey> getSecrets(const EncryptionKeyRef& key) override {
if (!key.id.present()) {
throw encryption_key_not_found();
}
EncryptionKey s = key;
uint8_t secret = ~(uint8_t)key.id.get() ^ xorWith;
s.secret = StringRef(s.arena(), &secret, 1);
return s;
}
Future<EncryptionKey> getByRange(const KeyRef& begin, const KeyRef& end) override {
EncryptionKeyRef k;
k.id = end.empty() ? 0 : *(end.end() - 1);
return getSecrets(k);
}
uint8_t xorWith;
};
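As a quick worked example of the provider above: with the empty-filename fallback xorWith of 0x5e and key id 0x41, getSecrets() yields the single secret byte ~0x41 ^ 0x5e = 0xbe ^ 0x5e = 0xe0; the key id itself comes from the last byte of the range's end key in getByRange().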
// Key provider that provides cipher keys randomly from a pre-generated pool. Used for testing.
class RandomEncryptionKeyProvider : public IEncryptionKeyProvider {
public:
RandomEncryptionKeyProvider() {
for (unsigned i = 0; i < NUM_CIPHER; i++) {
BlobCipherDetails cipherDetails;
cipherDetails.encryptDomainId = i;
cipherDetails.baseCipherId = deterministicRandom()->randomUInt64();
cipherDetails.salt = deterministicRandom()->randomUInt64();
cipherKeys[i] = generateCipherKey(cipherDetails);
}
}
virtual ~RandomEncryptionKeyProvider() = default;
bool shouldEnableEncryption() const override { return true; }
Future<EncryptionKey> getSecrets(const EncryptionKeyRef& key) override {
ASSERT(key.cipherHeader.present());
EncryptionKey s = key;
s.cipherKeys.cipherTextKey = cipherKeys[key.cipherHeader.get().cipherTextDetails.encryptDomainId];
s.cipherKeys.cipherHeaderKey = cipherKeys[key.cipherHeader.get().cipherHeaderDetails.encryptDomainId];
return s;
}
Future<EncryptionKey> getByRange(const KeyRef& /*begin*/, const KeyRef& /*end*/) override {
EncryptionKey s;
s.cipherKeys.cipherTextKey = getRandomCipherKey();
s.cipherKeys.cipherHeaderKey = getRandomCipherKey();
return s;
}
private:
Reference<BlobCipherKey> generateCipherKey(const BlobCipherDetails& cipherDetails) {
static unsigned char SHA_KEY[] = "3ab9570b44b8315fdb261da6b1b6c13b";
Arena arena;
StringRef digest = computeAuthToken(reinterpret_cast<const unsigned char*>(&cipherDetails.baseCipherId),
sizeof(EncryptCipherBaseKeyId),
SHA_KEY,
AES_256_KEY_LENGTH,
arena);
return makeReference<BlobCipherKey>(cipherDetails.encryptDomainId,
cipherDetails.baseCipherId,
digest.begin(),
AES_256_KEY_LENGTH,
cipherDetails.salt,
std::numeric_limits<int64_t>::max() /* refreshAt */,
std::numeric_limits<int64_t>::max() /* expireAt */);
}
Reference<BlobCipherKey> getRandomCipherKey() {
return cipherKeys[deterministicRandom()->randomInt(0, NUM_CIPHER)];
}
static constexpr int NUM_CIPHER = 1000;
Reference<BlobCipherKey> cipherKeys[NUM_CIPHER];
};
// Key provider which extracts the tenant id from range key prefixes, and fetches tenant-specific encryption keys from
// EncryptKeyProxy.
class TenantAwareEncryptionKeyProvider : public IEncryptionKeyProvider {
public:
TenantAwareEncryptionKeyProvider(Reference<AsyncVar<ServerDBInfo> const> db) : db(db) {}
virtual ~TenantAwareEncryptionKeyProvider() = default;
bool shouldEnableEncryption() const override {
return isEncryptionOpSupported(EncryptOperationType::STORAGE_SERVER_ENCRYPTION, db->get().client);
}
ACTOR static Future<EncryptionKey> getSecrets(TenantAwareEncryptionKeyProvider* self, EncryptionKeyRef key) {
if (!key.cipherHeader.present()) {
TraceEvent("TenantAwareEncryptionKeyProvider_CipherHeaderMissing");
throw encrypt_ops_error();
}
TextAndHeaderCipherKeys cipherKeys = wait(getEncryptCipherKeys(self->db, key.cipherHeader.get()));
EncryptionKey s = key;
s.cipherKeys = cipherKeys;
return s;
}
Future<EncryptionKey> getSecrets(const EncryptionKeyRef& key) override { return getSecrets(this, key); }
ACTOR static Future<EncryptionKey> getByRange(TenantAwareEncryptionKeyProvider* self, KeyRef begin, KeyRef end) {
EncryptCipherDomainName domainName;
EncryptCipherDomainId domainId = self->getEncryptionDomainId(begin, end, &domainName);
TextAndHeaderCipherKeys cipherKeys = wait(getLatestEncryptCipherKeysForDomain(self->db, domainId, domainName));
EncryptionKey s;
s.cipherKeys = cipherKeys;
return s;
}
Future<EncryptionKey> getByRange(const KeyRef& begin, const KeyRef& end) override {
return getByRange(this, begin, end);
}
void setTenantPrefixIndex(Reference<TenantPrefixIndex> tenantPrefixIndex) override {
ASSERT(tenantPrefixIndex.isValid());
this->tenantPrefixIndex = tenantPrefixIndex;
}
private:
EncryptCipherDomainId getEncryptionDomainId(const KeyRef& begin,
const KeyRef& end,
EncryptCipherDomainName* domainName) {
int64_t domainId = SYSTEM_KEYSPACE_ENCRYPT_DOMAIN_ID;
int64_t beginTenantId = getTenant(begin, true /*inclusive*/);
int64_t endTenantId = getTenant(end, false /*inclusive*/);
if (beginTenantId == endTenantId && beginTenantId != SYSTEM_KEYSPACE_ENCRYPT_DOMAIN_ID) {
ASSERT(tenantPrefixIndex.isValid());
Key tenantPrefix = TenantMapEntry::idToPrefix(beginTenantId);
auto view = tenantPrefixIndex->atLatest();
auto itr = view.find(tenantPrefix);
if (itr != view.end()) {
*domainName = *itr;
domainId = beginTenantId;
} else {
// No tenant with the same tenant id. We could be in optional or disabled tenant mode.
}
}
if (domainId == SYSTEM_KEYSPACE_ENCRYPT_DOMAIN_ID) {
*domainName = FDB_DEFAULT_ENCRYPT_DOMAIN_NAME;
}
return domainId;
}
int64_t getTenant(const KeyRef& key, bool inclusive) {
// A valid tenant id is always a valid encrypt domain id.
static_assert(ENCRYPT_INVALID_DOMAIN_ID < 0);
if (key.size() < TENANT_PREFIX_SIZE || key >= systemKeys.begin) {
return SYSTEM_KEYSPACE_ENCRYPT_DOMAIN_ID;
}
// TODO(yiwu): Use TenantMapEntry::prefixToId() instead.
int64_t tenantId = bigEndian64(*reinterpret_cast<const int64_t*>(key.begin()));
if (tenantId < 0) {
return SYSTEM_KEYSPACE_ENCRYPT_DOMAIN_ID;
}
if (!inclusive && key.size() == TENANT_PREFIX_SIZE) {
tenantId = tenantId - 1;
}
ASSERT(tenantId >= 0);
return tenantId;
}
Reference<AsyncVar<ServerDBInfo> const> db;
Reference<TenantPrefixIndex> tenantPrefixIndex;
};
#include "flow/unactorcompiler.h"
#endif
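getTenant() above treats the first TENANT_PREFIX_SIZE bytes of a key as a big-endian signed 64-bit tenant id, falling back to the system-keyspace domain for short keys, system keys, or negative ids. A standalone sketch of that decode in plain C++ (the 8-byte prefix size and the -1 sentinel are assumptions for illustration; the real constants are defined elsewhere):
#include <cstdint>
#include <iostream>
#include <string>
constexpr int kTenantPrefixSize = 8; // assumption: 8-byte tenant prefixes
constexpr int64_t kSystemDomainId = -1; // placeholder for SYSTEM_KEYSPACE_ENCRYPT_DOMAIN_ID
int64_t decodeTenantId(const std::string& key, bool inclusive) {
	if ((int)key.size() < kTenantPrefixSize)
		return kSystemDomainId;
	uint64_t raw = 0;
	for (int i = 0; i < kTenantPrefixSize; i++)
		raw = (raw << 8) | (uint8_t)key[i]; // big-endian decode
	int64_t tenantId = (int64_t)raw;
	if (tenantId < 0)
		return kSystemDomainId;
	if (!inclusive && (int)key.size() == kTenantPrefixSize)
		tenantId -= 1; // an exclusive end falling exactly on a prefix boundary belongs to the previous tenant
	return tenantId;
}
int main() {
	std::string key(8, '\0');
	key[7] = 0x2a; // tenant id 42 encoded big-endian
	key += "/some/user/key";
	std::cout << decodeTenantId(key, /*inclusive=*/true) << "\n"; // prints 42
}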

View File

@ -23,9 +23,11 @@
#pragma once
#include "fdbclient/FDBTypes.h"
#include "fdbserver/Knobs.h"
#include "fdbserver/ServerDBInfo.h"
#include "fdbclient/StorageCheckpoint.h"
#include "fdbclient/Tenant.h"
#include "fdbserver/Knobs.h"
#include "fdbserver/IEncryptionKeyProvider.actor.h"
#include "fdbserver/ServerDBInfo.h"
struct CheckpointRequest {
const Version version; // The FDB version at which the checkpoint is created.
@ -70,30 +72,19 @@ public:
virtual Future<Void> commit(
bool sequential = false) = 0; // returns when prior sets and clears are (atomically) durable
enum class ReadType {
EAGER,
FETCH,
LOW,
NORMAL,
HIGH,
};
virtual Future<Optional<Value>> readValue(KeyRef key,
ReadType type = ReadType::NORMAL,
Optional<UID> debugID = Optional<UID>()) = 0;
virtual Future<Optional<Value>> readValue(KeyRef key, Optional<ReadOptions> options = Optional<ReadOptions>()) = 0;
// Like readValue(), but returns only the first maxLength bytes of the value if it is longer
virtual Future<Optional<Value>> readValuePrefix(KeyRef key,
int maxLength,
ReadType type = ReadType::NORMAL,
Optional<UID> debugID = Optional<UID>()) = 0;
Optional<ReadOptions> options = Optional<ReadOptions>()) = 0;
// If rowLimit>=0, reads first rows sorted ascending, otherwise reads last rows sorted descending
// The total size of the returned value (less the last entry) will be less than byteLimit
virtual Future<RangeResult> readRange(KeyRangeRef keys,
int rowLimit = 1 << 30,
int byteLimit = 1 << 30,
ReadType type = ReadType::NORMAL) = 0;
Optional<ReadOptions> options = Optional<ReadOptions>()) = 0;
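For example, under the contract above, readRange(range, -10) returns the last ten rows of range in descending order, while readRange(range, 10) returns the first ten in ascending order; byteLimit caps the total size of all but the last returned entry.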
// Shard management APIs.
// Adds key range to a physical shard.
@ -158,7 +149,9 @@ extern IKeyValueStore* keyValueStoreSQLite(std::string const& filename,
KeyValueStoreType storeType,
bool checkChecksums = false,
bool checkIntegrity = false);
extern IKeyValueStore* keyValueStoreRedwoodV1(std::string const& filename, UID logID);
extern IKeyValueStore* keyValueStoreRedwoodV1(std::string const& filename,
UID logID,
Reference<IEncryptionKeyProvider> encryptionKeyProvider = {});
extern IKeyValueStore* keyValueStoreRocksDB(std::string const& path,
UID logID,
KeyValueStoreType storeType,
@ -196,7 +189,8 @@ inline IKeyValueStore* openKVStore(KeyValueStoreType storeType,
int64_t memoryLimit,
bool checkChecksums = false,
bool checkIntegrity = false,
bool openRemotely = false) {
bool openRemotely = false,
Reference<IEncryptionKeyProvider> encryptionKeyProvider = {}) {
if (openRemotely) {
return openRemoteKVStore(storeType, filename, logID, memoryLimit, checkChecksums, checkIntegrity);
}
@ -208,7 +202,7 @@ inline IKeyValueStore* openKVStore(KeyValueStoreType storeType,
case KeyValueStoreType::MEMORY:
return keyValueStoreMemory(filename, logID, memoryLimit);
case KeyValueStoreType::SSD_REDWOOD_V1:
return keyValueStoreRedwoodV1(filename, logID);
return keyValueStoreRedwoodV1(filename, logID, encryptionKeyProvider);
case KeyValueStoreType::SSD_ROCKSDB_V1:
return keyValueStoreRocksDB(filename, logID, storeType);
case KeyValueStoreType::SSD_SHARDED_ROCKSDB:

View File

@ -17,20 +17,22 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifndef FDBSERVER_IPAGER_H
#define FDBSERVER_IPAGER_H
#include "flow/Error.h"
#include "flow/FastAlloc.h"
#include "flow/ProtocolVersion.h"
#include <cstddef>
#include <stdint.h>
#pragma once
#include "fdbserver/IKeyValueStore.h"
#include "flow/flow.h"
#include "fdbclient/FDBTypes.h"
#include "fdbclient/Tenant.h"
#include "fdbserver/IEncryptionKeyProvider.actor.h"
#include "fdbserver/IKeyValueStore.h"
#include "flow/BlobCipher.h"
#include "flow/Error.h"
#include "flow/FastAlloc.h"
#include "flow/flow.h"
#include "flow/ProtocolVersion.h"
#define XXH_INLINE_ALL
#include "flow/xxhash.h"
@ -46,10 +48,18 @@ typedef uint32_t QueueID;
enum class PagerEvents { CacheLookup = 0, CacheHit, CacheMiss, PageWrite, MAXEVENTS };
static const char* const PagerEventsStrings[] = { "Lookup", "Hit", "Miss", "Write", "Unknown" };
// Reasons for page level events.
enum class PagerEventReasons { PointRead = 0, RangeRead, RangePrefetch, Commit, LazyClear, MetaData, MAXEVENTREASONS };
static const char* const PagerEventReasonsStrings[] = {
"Get", "GetR", "GetRPF", "Commit", "LazyClr", "Meta", "Unknown"
enum class PagerEventReasons {
PointRead = 0,
FetchRange,
RangeRead,
RangePrefetch,
Commit,
LazyClear,
MetaData,
MAXEVENTREASONS
};
static const char* const PagerEventReasonsStrings[] = { "Get", "FetchR", "GetR", "GetRPF",
"Commit", "LazyClr", "Meta", "Unknown" };
static const unsigned int nonBtreeLevel = 0;
static const std::vector<std::pair<PagerEvents, PagerEventReasons>> possibleEventReasonPairs = {
@ -57,14 +67,17 @@ static const std::vector<std::pair<PagerEvents, PagerEventReasons>> possibleEven
{ PagerEvents::CacheLookup, PagerEventReasons::LazyClear },
{ PagerEvents::CacheLookup, PagerEventReasons::PointRead },
{ PagerEvents::CacheLookup, PagerEventReasons::RangeRead },
{ PagerEvents::CacheLookup, PagerEventReasons::FetchRange },
{ PagerEvents::CacheHit, PagerEventReasons::Commit },
{ PagerEvents::CacheHit, PagerEventReasons::LazyClear },
{ PagerEvents::CacheHit, PagerEventReasons::PointRead },
{ PagerEvents::CacheHit, PagerEventReasons::RangeRead },
{ PagerEvents::CacheHit, PagerEventReasons::FetchRange },
{ PagerEvents::CacheMiss, PagerEventReasons::Commit },
{ PagerEvents::CacheMiss, PagerEventReasons::LazyClear },
{ PagerEvents::CacheMiss, PagerEventReasons::PointRead },
{ PagerEvents::CacheMiss, PagerEventReasons::RangeRead },
{ PagerEvents::CacheMiss, PagerEventReasons::FetchRange },
{ PagerEvents::PageWrite, PagerEventReasons::Commit },
{ PagerEvents::PageWrite, PagerEventReasons::LazyClear },
};
@ -78,11 +91,7 @@ static const std::vector<std::pair<PagerEvents, PagerEventReasons>> L0PossibleEv
{ PagerEvents::PageWrite, PagerEventReasons::MetaData },
};
enum EncodingType : uint8_t {
XXHash64 = 0,
// For testing purposes
XOREncryption = 1
};
enum EncodingType : uint8_t { XXHash64 = 0, XOREncryption_TestOnly = 1, AESEncryptionV1 = 2, MAX_ENCODING_TYPE = 3 };
enum PageType : uint8_t {
HeaderPage = 0,
@ -93,41 +102,6 @@ enum PageType : uint8_t {
QueuePageInExtent = 5
};
// Encryption key ID
typedef uint64_t KeyID;
// EncryptionKeyRef is somewhat multi-variant, it will contain members representing the union
// of all fields relevant to any implemented encryption scheme. They are generally of
// the form
// Page Fields - fields which come from or are stored in the Page
// Secret Fields - fields which are only known by the Key Provider
// but it is up to each encoding and provider which fields are which and which ones are used
struct EncryptionKeyRef {
EncryptionKeyRef(){};
EncryptionKeyRef(Arena& arena, const EncryptionKeyRef& toCopy) : secret(arena, toCopy.secret), id(toCopy.id) {}
int expectedSize() const { return secret.size(); }
StringRef secret;
Optional<KeyID> id;
};
typedef Standalone<EncryptionKeyRef> EncryptionKey;
// Interface used by pager to get encryption keys by ID when reading pages from disk
// and by the BTree to get encryption keys to use for new pages
class IEncryptionKeyProvider {
public:
virtual ~IEncryptionKeyProvider() {}
// Get an EncryptionKey with Secret Fields populated based on the given Page Fields.
// It is up to the implementation which fields those are.
// The output Page Fields must match the input Page Fields.
virtual Future<EncryptionKey> getSecrets(const EncryptionKeyRef& key) = 0;
// Get encryption key that should be used for a given user Key-Value range
virtual Future<EncryptionKey> getByRange(const KeyRef& begin, const KeyRef& end) = 0;
};
// This is a hacky way to attach an additional object of an arbitrary type at runtime to another object.
// It stores an arbitrary void pointer and a void pointer function to call when the ArbitraryObject
// is destroyed.
@ -328,7 +302,7 @@ public:
};
// An encoding that validates the payload with an XXHash checksum
struct XXHashEncodingHeader {
struct XXHashEncoder {
XXH64_hash_t checksum;
void encode(uint8_t* payload, int len, PhysicalPageID seed) {
checksum = XXH3_64bits_withSeed(payload, len, seed);
@ -342,7 +316,7 @@ public:
// A dummy "encrypting" encoding which uses XOR with a 1 byte secret key on
// the payload to obfuscate it and protects the payload with an XXHash checksum.
struct XOREncryptionEncodingHeader {
struct XOREncryptionEncoder {
// Checksum is on unencrypted payload
XXH64_hash_t checksum;
uint8_t keyID;
@ -362,6 +336,27 @@ public:
}
}
};
struct AESEncryptionV1Encoder {
BlobCipherEncryptHeader header;
void encode(const TextAndHeaderCipherKeys& cipherKeys, uint8_t* payload, int len) {
EncryptBlobCipherAes265Ctr cipher(
cipherKeys.cipherTextKey, cipherKeys.cipherHeaderKey, ENCRYPT_HEADER_AUTH_TOKEN_MODE_SINGLE);
Arena arena;
StringRef ciphertext = cipher.encrypt(payload, len, &header, arena)->toStringRef();
ASSERT_EQ(len, ciphertext.size());
memcpy(payload, ciphertext.begin(), len);
}
void decode(const TextAndHeaderCipherKeys& cipherKeys, uint8_t* payload, int len) {
DecryptBlobCipherAes256Ctr cipher(cipherKeys.cipherTextKey, cipherKeys.cipherHeaderKey, header.iv);
Arena arena;
StringRef plaintext = cipher.decrypt(payload, len, header, arena)->toStringRef();
ASSERT_EQ(len, plaintext.size());
memcpy(payload, plaintext.begin(), len);
}
};
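A note on the encoder above: AES-256 in CTR mode is a stream construction, so the ciphertext is exactly as long as the plaintext, which is why encode() and decode() can assert length equality and copy the result back over the payload in place. The BlobCipherEncryptHeader stored in the encoding header carries the IV and cipher details that decode() and the key provider need to recover the keys later.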
#pragma pack(pop)
// Get the size of the encoding header based on type
@ -369,9 +364,11 @@ public:
// existing pages, the payload offset is stored in the page.
static int encodingHeaderSize(EncodingType t) {
if (t == EncodingType::XXHash64) {
return sizeof(XXHashEncodingHeader);
} else if (t == EncodingType::XOREncryption) {
return sizeof(XOREncryptionEncodingHeader);
return sizeof(XXHashEncoder);
} else if (t == EncodingType::XOREncryption_TestOnly) {
return sizeof(XOREncryptionEncoder);
} else if (t == EncodingType::AESEncryptionV1) {
return sizeof(AESEncryptionV1Encoder);
} else {
throw page_encoding_not_supported();
}
@ -475,12 +472,15 @@ public:
ASSERT(VALGRIND_CHECK_MEM_IS_DEFINED(pPayload, payloadSize) == 0);
if (page->encodingType == EncodingType::XXHash64) {
page->getEncodingHeader<XXHashEncodingHeader>()->encode(pPayload, payloadSize, pageID);
} else if (page->encodingType == EncodingType::XOREncryption) {
page->getEncodingHeader<XXHashEncoder>()->encode(pPayload, payloadSize, pageID);
} else if (page->encodingType == EncodingType::XOREncryption_TestOnly) {
ASSERT(encryptionKey.secret.size() == 1);
XOREncryptionEncodingHeader* xh = page->getEncodingHeader<XOREncryptionEncodingHeader>();
XOREncryptionEncoder* xh = page->getEncodingHeader<XOREncryptionEncoder>();
xh->keyID = encryptionKey.id.orDefault(0);
xh->encode(encryptionKey.secret[0], pPayload, payloadSize, pageID);
} else if (page->encodingType == EncodingType::AESEncryptionV1) {
AESEncryptionV1Encoder* eh = page->getEncodingHeader<AESEncryptionV1Encoder>();
eh->encode(encryptionKey.cipherKeys, pPayload, payloadSize);
} else {
throw page_encoding_not_supported();
}
@ -504,8 +504,11 @@ public:
payloadSize = logicalSize - (pPayload - buffer);
// Populate encryption key with relevant fields from page
if (page->encodingType == EncodingType::XOREncryption) {
encryptionKey.id = page->getEncodingHeader<XOREncryptionEncodingHeader>()->keyID;
if (page->encodingType == EncodingType::XOREncryption_TestOnly) {
encryptionKey.id = page->getEncodingHeader<XOREncryptionEncoder>()->keyID;
} else if (page->encodingType == EncodingType::AESEncryptionV1) {
AESEncryptionV1Encoder* eh = page->getEncodingHeader<AESEncryptionV1Encoder>();
encryptionKey.cipherHeader = eh->header;
}
if (page->headerVersion == 1) {
@ -525,11 +528,13 @@ public:
// Post: Payload has been verified and decrypted if necessary
void postReadPayload(PhysicalPageID pageID) {
if (page->encodingType == EncodingType::XXHash64) {
page->getEncodingHeader<XXHashEncodingHeader>()->decode(pPayload, payloadSize, pageID);
} else if (page->encodingType == EncodingType::XOREncryption) {
page->getEncodingHeader<XXHashEncoder>()->decode(pPayload, payloadSize, pageID);
} else if (page->encodingType == EncodingType::XOREncryption_TestOnly) {
ASSERT(encryptionKey.secret.size() == 1);
page->getEncodingHeader<XOREncryptionEncodingHeader>()->decode(
page->getEncodingHeader<XOREncryptionEncoder>()->decode(
encryptionKey.secret[0], pPayload, payloadSize, pageID);
} else if (page->encodingType == EncodingType::AESEncryptionV1) {
page->getEncodingHeader<AESEncryptionV1Encoder>()->decode(encryptionKey.cipherKeys, pPayload, payloadSize);
} else {
throw page_encoding_not_supported();
}
@ -537,7 +542,9 @@ public:
const Arena& getArena() const { return arena; }
static bool isEncodingTypeEncrypted(EncodingType t) { return t == EncodingType::XOREncryption; }
static bool isEncodingTypeEncrypted(EncodingType t) {
return t == EncodingType::AESEncryptionV1 || t == EncodingType::XOREncryption_TestOnly;
}
// Returns true if the page's encoding type employs encryption
bool isEncrypted() const { return isEncodingTypeEncrypted(getEncodingType()); }
@ -739,52 +746,4 @@ protected:
~IPager2() {} // Destruction should be done using close()/dispose() from the IClosable interface
};
// The null key provider is useful to simplify page decoding.
// It throws an error for any key info requested.
class NullKeyProvider : public IEncryptionKeyProvider {
public:
virtual ~NullKeyProvider() {}
Future<EncryptionKey> getSecrets(const EncryptionKeyRef& key) override { throw encryption_key_not_found(); }
Future<EncryptionKey> getByRange(const KeyRef& begin, const KeyRef& end) override {
throw encryption_key_not_found();
}
};
// Key provider for dummy XOR encryption scheme
class XOREncryptionKeyProvider : public IEncryptionKeyProvider {
public:
XOREncryptionKeyProvider(std::string filename) {
ASSERT(g_network->isSimulated());
// Choose a deterministic random filename (without path) byte for secret generation
// Remove any leading directory names
size_t lastSlash = filename.find_last_of("\\/");
if (lastSlash != filename.npos) {
filename.erase(0, lastSlash);
}
xorWith = filename.empty() ? 0x5e
: (uint8_t)filename[XXH3_64bits(filename.data(), filename.size()) % filename.size()];
}
virtual ~XOREncryptionKeyProvider() {}
virtual Future<EncryptionKey> getSecrets(const EncryptionKeyRef& key) override {
if (!key.id.present()) {
throw encryption_key_not_found();
}
EncryptionKey s = key;
uint8_t secret = ~(uint8_t)key.id.get() ^ xorWith;
s.secret = StringRef(s.arena(), &secret, 1);
return s;
}
virtual Future<EncryptionKey> getByRange(const KeyRef& begin, const KeyRef& end) override {
EncryptionKeyRef k;
k.id = end.empty() ? 0 : *(end.end() - 1);
return getSecrets(k);
}
uint8_t xorWith;
};
#endif

View File

@ -34,7 +34,6 @@ struct RatekeeperInterface {
RequestStream<struct ReportCommitCostEstimationRequest> reportCommitCostEstimation;
struct LocalityData locality;
UID myId;
RequestStream<struct GlobalTagThrottlerStatusRequest> getGlobalTagThrottlerStatus;
RatekeeperInterface() {}
explicit RatekeeperInterface(const struct LocalityData& l, UID id) : locality(l), myId(id) {}
@ -47,14 +46,7 @@ struct RatekeeperInterface {
template <class Archive>
void serialize(Archive& ar) {
serializer(ar,
waitFailure,
getRateInfo,
haltRatekeeper,
reportCommitCostEstimation,
locality,
myId,
getGlobalTagThrottlerStatus);
serializer(ar, waitFailure, getRateInfo, haltRatekeeper, reportCommitCostEstimation, locality, myId);
}
};
@ -167,39 +159,4 @@ struct ReportCommitCostEstimationRequest {
}
};
struct GlobalTagThrottlerStatusReply {
constexpr static FileIdentifier file_identifier = 9510482;
struct TagStats {
constexpr static FileIdentifier file_identifier = 6018293;
double desiredTps;
Optional<double> limitingTps;
double targetTps;
double reservedTps;
template <class Ar>
void serialize(Ar& ar) {
serializer(ar, desiredTps, limitingTps, targetTps, reservedTps);
}
};
std::unordered_map<TransactionTag, TagStats> status;
template <class Ar>
void serialize(Ar& ar) {
serializer(ar, status);
}
};
struct GlobalTagThrottlerStatusRequest {
constexpr static FileIdentifier file_identifier = 5620934;
ReplyPromise<struct GlobalTagThrottlerStatusReply> reply;
template <class Ar>
void serialize(Ar& ar) {
serializer(ar, reply);
}
};
#endif // FDBSERVER_RATEKEEPERINTERFACE_H

View File

@ -155,13 +155,12 @@ struct OpenKVStoreRequest {
struct IKVSGetValueRequest {
constexpr static FileIdentifier file_identifier = 1029439;
KeyRef key;
IKeyValueStore::ReadType type;
Optional<UID> debugID = Optional<UID>();
Optional<ReadOptions> options;
ReplyPromise<Optional<Value>> reply;
template <class Ar>
void serialize(Ar& ar) {
serializer(ar, key, type, debugID, reply);
serializer(ar, key, options, reply);
}
};
@ -202,13 +201,12 @@ struct IKVSReadValuePrefixRequest {
constexpr static FileIdentifier file_identifier = 1928374;
KeyRef key;
int maxLength;
IKeyValueStore::ReadType type;
Optional<UID> debugID = Optional<UID>();
Optional<ReadOptions> options;
ReplyPromise<Optional<Value>> reply;
template <class Ar>
void serialize(Ar& ar) {
serializer(ar, key, maxLength, type, debugID, reply);
serializer(ar, key, maxLength, options, reply);
}
};
@ -246,12 +244,12 @@ struct IKVSReadRangeRequest {
KeyRangeRef keys;
int rowLimit;
int byteLimit;
IKeyValueStore::ReadType type;
Optional<ReadOptions> options;
ReplyPromise<IKVSReadRangeReply> reply;
template <class Ar>
void serialize(Ar& ar) {
serializer(ar, keys, rowLimit, byteLimit, type, reply);
serializer(ar, keys, rowLimit, byteLimit, options, reply);
}
};
@ -402,25 +400,22 @@ struct RemoteIKeyValueStore : public IKeyValueStore {
return commitAndGetStorageBytes(this, commitReply);
}
Future<Optional<Value>> readValue(KeyRef key,
ReadType type = ReadType::NORMAL,
Optional<UID> debugID = Optional<UID>()) override {
return readValueImpl(this, IKVSGetValueRequest{ key, type, debugID, ReplyPromise<Optional<Value>>() });
Future<Optional<Value>> readValue(KeyRef key, Optional<ReadOptions> options = Optional<ReadOptions>()) override {
return readValueImpl(this, IKVSGetValueRequest{ key, options, ReplyPromise<Optional<Value>>() });
}
Future<Optional<Value>> readValuePrefix(KeyRef key,
int maxLength,
ReadType type = ReadType::NORMAL,
Optional<UID> debugID = Optional<UID>()) override {
Optional<ReadOptions> options = Optional<ReadOptions>()) override {
return interf.readValuePrefix.getReply(
IKVSReadValuePrefixRequest{ key, maxLength, type, debugID, ReplyPromise<Optional<Value>>() });
IKVSReadValuePrefixRequest{ key, maxLength, options, ReplyPromise<Optional<Value>>() });
}
Future<RangeResult> readRange(KeyRangeRef keys,
int rowLimit = 1 << 30,
int byteLimit = 1 << 30,
ReadType type = ReadType::NORMAL) override {
IKVSReadRangeRequest req{ keys, rowLimit, byteLimit, type, ReplyPromise<IKVSReadRangeReply>() };
Optional<ReadOptions> options = Optional<ReadOptions>()) override {
IKVSReadRangeRequest req{ keys, rowLimit, byteLimit, options, ReplyPromise<IKVSReadRangeReply>() };
return fmap([](const IKVSReadRangeReply& reply) { return reply.toRangeResult(); },
interf.readRange.getReply(req));
}

View File

@ -46,51 +46,8 @@ struct StorageMetricSample {
explicit StorageMetricSample(int64_t metricUnitsPerSample) : metricUnitsPerSample(metricUnitsPerSample) {}
int64_t getEstimate(KeyRangeRef keys) const { return sample.sumRange(keys.begin, keys.end); }
KeyRef splitEstimate(KeyRangeRef range, int64_t offset, bool front = true) const {
auto fwd_split = sample.index(front ? sample.sumTo(sample.lower_bound(range.begin)) + offset
: sample.sumTo(sample.lower_bound(range.end)) - offset);
if (fwd_split == sample.end() || *fwd_split >= range.end)
return range.end;
if (!front && *fwd_split <= range.begin)
return range.begin;
auto bck_split = fwd_split;
// Butterfly search - start at midpoint then go in both directions.
while ((fwd_split != sample.end() && *fwd_split < range.end) ||
(bck_split != sample.begin() && *bck_split > range.begin)) {
if (bck_split != sample.begin() && *bck_split > range.begin) {
auto it = bck_split;
bck_split.decrementNonEnd();
KeyRef split = keyBetween(KeyRangeRef(
bck_split != sample.begin() ? std::max<KeyRef>(*bck_split, range.begin) : range.begin, *it));
if (!front || (getEstimate(KeyRangeRef(range.begin, split)) > 0 &&
split.size() <= CLIENT_KNOBS->SPLIT_KEY_SIZE_LIMIT))
return split;
}
if (fwd_split != sample.end() && *fwd_split < range.end) {
auto it = fwd_split;
++it;
KeyRef split = keyBetween(
KeyRangeRef(*fwd_split, it != sample.end() ? std::min<KeyRef>(*it, range.end) : range.end));
if (front || (getEstimate(KeyRangeRef(split, range.end)) > 0 &&
split.size() <= CLIENT_KNOBS->SPLIT_KEY_SIZE_LIMIT))
return split;
fwd_split = it;
}
}
// If we didn't return above, we didn't find anything.
TraceEvent(SevWarn, "CannotSplitLastSampleKey").detail("Range", range).detail("Offset", offset);
return front ? range.end : range.begin;
}
int64_t getEstimate(KeyRangeRef keys) const;
KeyRef splitEstimate(KeyRangeRef range, int64_t offset, bool front = true) const;
};
struct TransientStorageMetricSample : StorageMetricSample {
@ -98,83 +55,18 @@ struct TransientStorageMetricSample : StorageMetricSample {
explicit TransientStorageMetricSample(int64_t metricUnitsPerSample) : StorageMetricSample(metricUnitsPerSample) {}
// Returns the sampled metric value (possibly 0, possibly increased by the sampling factor)
int64_t addAndExpire(KeyRef key, int64_t metric, double expiration) {
int64_t x = add(key, metric);
if (x)
queue.emplace_back(expiration, std::make_pair(*sample.find(key), -x));
return x;
}
int64_t addAndExpire(KeyRef key, int64_t metric, double expiration);
// FIXME: both versions of erase are broken, because they do not remove items in the queue which will subtract a
// metric from the value sometime in the future
int64_t erase(KeyRef key) {
auto it = sample.find(key);
if (it == sample.end())
return 0;
int64_t x = sample.getMetric(it);
sample.erase(it);
return x;
}
void erase(KeyRangeRef keys) { sample.erase(keys.begin, keys.end); }
int64_t erase(KeyRef key);
void erase(KeyRangeRef keys);
void poll(KeyRangeMap<std::vector<PromiseStream<StorageMetrics>>>& waitMap, StorageMetrics m) {
double now = ::now();
while (queue.size() && queue.front().first <= now) {
KeyRef key = queue.front().second.first;
int64_t delta = queue.front().second.second;
ASSERT(delta != 0);
void poll(KeyRangeMap<std::vector<PromiseStream<StorageMetrics>>>& waitMap, StorageMetrics m);
if (sample.addMetric(key, delta) == 0)
sample.erase(key);
StorageMetrics deltaM = m * delta;
auto v = waitMap[key];
for (int i = 0; i < v.size(); i++) {
CODE_PROBE(true, "TransientStorageMetricSample poll update");
v[i].send(deltaM);
}
queue.pop_front();
}
}
void poll() {
double now = ::now();
while (queue.size() && queue.front().first <= now) {
KeyRef key = queue.front().second.first;
int64_t delta = queue.front().second.second;
ASSERT(delta != 0);
if (sample.addMetric(key, delta) == 0)
sample.erase(key);
queue.pop_front();
}
}
void poll();
private:
bool roll(KeyRef key, int64_t metric) const {
return deterministicRandom()->random01() <
(double)metric / metricUnitsPerSample; //< SOMEDAY: Better randomInt64?
}
int64_t add(KeyRef key, int64_t metric) {
if (!metric)
return 0;
int64_t mag = metric < 0 ? -metric : metric;
if (mag < metricUnitsPerSample) {
if (!roll(key, mag))
return 0;
metric = metric < 0 ? -metricUnitsPerSample : metricUnitsPerSample;
}
if (sample.addMetric(key, metric) == 0)
sample.erase(key);
return metric;
}
bool roll(KeyRef key, int64_t metric) const;
int64_t add(KeyRef key, int64_t metric);
};
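// A minimal, self-contained sketch (illustrative only; sampleMetric is a hypothetical name) of the
// sampling rule implemented by roll()/add() above: a metric smaller than metricUnitsPerSample is
// either dropped or rounded up to metricUnitsPerSample, with the keep probability chosen so that
// the expected recorded value equals the true value; larger metrics are recorded exactly.
#include <cstdint>
#include <random>

inline int64_t sampleMetric(int64_t metric, int64_t metricUnitsPerSample, std::mt19937_64& rng) {
	if (metric == 0)
		return 0;
	int64_t mag = metric < 0 ? -metric : metric;
	if (mag >= metricUnitsPerSample)
		return metric; // large values are recorded exactly
	// Keep with probability mag / metricUnitsPerSample, so E[recorded] == metric.
	std::uniform_real_distribution<double> uniform(0.0, 1.0);
	if (uniform(rng) < (double)mag / metricUnitsPerSample)
		return metric < 0 ? -metricUnitsPerSample : metricUnitsPerSample;
	return 0;
}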
struct StorageServerMetrics {
@ -190,131 +82,23 @@ struct StorageServerMetrics {
bandwidthSample(SERVER_KNOBS->BANDWIDTH_UNITS_PER_SAMPLE),
bytesReadSample(SERVER_KNOBS->BYTES_READ_UNITS_PER_SAMPLE) {}
// Get the current estimated metrics for the given keys
StorageMetrics getMetrics(KeyRangeRef const& keys) const {
StorageMetrics result;
result.bytes = byteSample.getEstimate(keys);
result.bytesPerKSecond =
bandwidthSample.getEstimate(keys) * SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL_PER_KSECONDS;
result.iosPerKSecond =
iopsSample.getEstimate(keys) * SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL_PER_KSECONDS;
result.bytesReadPerKSecond =
bytesReadSample.getEstimate(keys) * SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL_PER_KSECONDS;
return result;
}
StorageMetrics getMetrics(KeyRangeRef const& keys) const;
// Called when metrics should change (IO for a given key)
// Notifies waiting WaitMetricsRequests through waitMetricsMap, and updates metricsAverageQueue and metricsSampleMap
void notify(KeyRef key, StorageMetrics& metrics) {
ASSERT(metrics.bytes == 0); // ShardNotifyMetrics
if (g_network->isSimulated()) {
CODE_PROBE(metrics.bytesPerKSecond != 0, "ShardNotifyMetrics bytes");
CODE_PROBE(metrics.iosPerKSecond != 0, "ShardNotifyMetrics ios");
CODE_PROBE(metrics.bytesReadPerKSecond != 0, "ShardNotifyMetrics bytesRead");
}
void notify(KeyRef key, StorageMetrics& metrics);
double expire = now() + SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL;
void notifyBytesReadPerKSecond(KeyRef key, int64_t in);
StorageMetrics notifyMetrics;
if (metrics.bytesPerKSecond)
notifyMetrics.bytesPerKSecond = bandwidthSample.addAndExpire(key, metrics.bytesPerKSecond, expire) *
SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL_PER_KSECONDS;
if (metrics.iosPerKSecond)
notifyMetrics.iosPerKSecond = iopsSample.addAndExpire(key, metrics.iosPerKSecond, expire) *
SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL_PER_KSECONDS;
if (metrics.bytesReadPerKSecond)
notifyMetrics.bytesReadPerKSecond = bytesReadSample.addAndExpire(key, metrics.bytesReadPerKSecond, expire) *
SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL_PER_KSECONDS;
if (!notifyMetrics.allZero()) {
auto& v = waitMetricsMap[key];
for (int i = 0; i < v.size(); i++) {
if (g_network->isSimulated()) {
CODE_PROBE(true, "shard notify metrics");
}
// ShardNotifyMetrics
v[i].send(notifyMetrics);
}
}
}
// Because read sampling is called on every read, use this specialized function to avoid overhead
// from branch misses and unnecessary stack allocation, which eventually adds up under heavy load.
void notifyBytesReadPerKSecond(KeyRef key, int64_t in) {
double expire = now() + SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL;
int64_t bytesReadPerKSecond =
bytesReadSample.addAndExpire(key, in, expire) * SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL_PER_KSECONDS;
if (bytesReadPerKSecond > 0) {
StorageMetrics notifyMetrics;
notifyMetrics.bytesReadPerKSecond = bytesReadPerKSecond;
auto& v = waitMetricsMap[key];
for (int i = 0; i < v.size(); i++) {
CODE_PROBE(true, "ShardNotifyMetrics");
v[i].send(notifyMetrics);
}
}
}
// Called by StorageServerDisk when the size of a key in byteSample changes, to notify WaitMetricsRequest
// Should not be called for keys past allKeys.end
void notifyBytes(RangeMap<Key, std::vector<PromiseStream<StorageMetrics>>, KeyRangeRef>::iterator shard,
int64_t bytes) {
ASSERT(shard.end() <= allKeys.end);
int64_t bytes);
StorageMetrics notifyMetrics;
notifyMetrics.bytes = bytes;
for (int i = 0; i < shard.value().size(); i++) {
CODE_PROBE(true, "notifyBytes");
shard.value()[i].send(notifyMetrics);
}
}
void notifyBytes(KeyRef key, int64_t bytes);
// Called by StorageServerDisk when the size of a key in byteSample changes, to notify WaitMetricsRequest
void notifyBytes(KeyRef key, int64_t bytes) {
if (key >= allKeys.end) // Do not notify on changes to internal storage server state
return;
void notifyNotReadable(KeyRangeRef keys);
notifyBytes(waitMetricsMap.rangeContaining(key), bytes);
}
// Called when a range of keys becomes unassigned (and therefore not readable), to notify waiting
// WaitMetricsRequests (also other types of wait
// requests in the future?)
void notifyNotReadable(KeyRangeRef keys) {
auto rs = waitMetricsMap.intersectingRanges(keys);
for (auto r = rs.begin(); r != rs.end(); ++r) {
auto& v = r->value();
CODE_PROBE(v.size(), "notifyNotReadable() sending errors to intersecting ranges");
for (int n = 0; n < v.size(); n++)
v[n].sendError(wrong_shard_server());
}
}
// Called periodically (~1 sec intervals) to remove older IOs from the averages
// Removes old entries from metricsAverageQueue, updates metricsSampleMap accordingly, and notifies
// WaitMetricsRequests through waitMetricsMap.
void poll() {
{
StorageMetrics m;
m.bytesPerKSecond = SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL_PER_KSECONDS;
bandwidthSample.poll(waitMetricsMap, m);
}
{
StorageMetrics m;
m.iosPerKSecond = SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL_PER_KSECONDS;
iopsSample.poll(waitMetricsMap, m);
}
{
StorageMetrics m;
m.bytesReadPerKSecond = SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL_PER_KSECONDS;
bytesReadSample.poll(waitMetricsMap, m);
}
// byteSample doesn't need polling because we never call addAndExpire() on it
}
void poll();
// static void waitMetrics( StorageServerMetrics* const& self, WaitMetricsRequest const& req );
// This function can run on untrusted user data. We must validate all divisions carefully.
KeyRef getSplitKey(int64_t remaining,
int64_t estimated,
int64_t limits,
@ -325,276 +109,32 @@ struct StorageServerMetrics {
double divisor,
KeyRef const& lastKey,
KeyRef const& key,
bool hasUsed) const {
ASSERT(remaining >= 0);
ASSERT(limits > 0);
ASSERT(divisor > 0);
bool hasUsed) const;
if (limits < infinity / 2) {
int64_t expectedSize;
if (isLastShard || remaining > estimated) {
double remaining_divisor = (double(remaining) / limits) + 0.5;
expectedSize = remaining / remaining_divisor;
} else {
// If we are here, then estimated >= remaining >= 0
double estimated_divisor = (double(estimated) / limits) + 0.5;
expectedSize = remaining / estimated_divisor;
}
if (remaining > expectedSize) {
// This does the conversion from native units to bytes using the divisor.
double offset = (expectedSize - used) / divisor;
if (offset <= 0)
return hasUsed ? lastKey : key;
return sample.splitEstimate(
KeyRangeRef(lastKey, key),
offset * ((1.0 - SERVER_KNOBS->SPLIT_JITTER_AMOUNT) +
2 * deterministicRandom()->random01() * SERVER_KNOBS->SPLIT_JITTER_AMOUNT));
}
}
return key;
}
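// Worked example with illustrative numbers (not from a real workload): with remaining = 150 MB,
// limits = 100 MB and isLastShard, the divisor is 150/100 + 0.5 = 2.0, so expectedSize = 75 MB.
// Since remaining (150 MB) > expectedSize (75 MB), a split is attempted at roughly
// (expectedSize - used) / divisor bytes past lastKey (for the byte sample, divisor == 1),
// jittered by SPLIT_JITTER_AMOUNT.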
void splitMetrics(SplitMetricsRequest req) const {
int minSplitBytes = req.minSplitBytes.present() ? req.minSplitBytes.get() : SERVER_KNOBS->MIN_SHARD_BYTES;
try {
SplitMetricsReply reply;
KeyRef lastKey = req.keys.begin;
StorageMetrics used = req.used;
StorageMetrics estimated = req.estimated;
StorageMetrics remaining = getMetrics(req.keys) + used;
//TraceEvent("SplitMetrics").detail("Begin", req.keys.begin).detail("End", req.keys.end).detail("Remaining", remaining.bytes).detail("Used", used.bytes).detail("MinSplitBytes", minSplitBytes);
while (true) {
if (remaining.bytes < 2 * minSplitBytes)
break;
KeyRef key = req.keys.end;
bool hasUsed = used.bytes != 0 || used.bytesPerKSecond != 0 || used.iosPerKSecond != 0;
key = getSplitKey(remaining.bytes,
estimated.bytes,
req.limits.bytes,
used.bytes,
req.limits.infinity,
req.isLastShard,
byteSample,
1,
lastKey,
key,
hasUsed);
if (used.bytes < minSplitBytes)
key = std::max(
key, byteSample.splitEstimate(KeyRangeRef(lastKey, req.keys.end), minSplitBytes - used.bytes));
key = getSplitKey(remaining.iosPerKSecond,
estimated.iosPerKSecond,
req.limits.iosPerKSecond,
used.iosPerKSecond,
req.limits.infinity,
req.isLastShard,
iopsSample,
SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL_PER_KSECONDS,
lastKey,
key,
hasUsed);
key = getSplitKey(remaining.bytesPerKSecond,
estimated.bytesPerKSecond,
req.limits.bytesPerKSecond,
used.bytesPerKSecond,
req.limits.infinity,
req.isLastShard,
bandwidthSample,
SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL_PER_KSECONDS,
lastKey,
key,
hasUsed);
ASSERT(key != lastKey || hasUsed);
if (key == req.keys.end)
break;
reply.splits.push_back_deep(reply.splits.arena(), key);
StorageMetrics diff = (getMetrics(KeyRangeRef(lastKey, key)) + used);
remaining -= diff;
estimated -= diff;
used = StorageMetrics();
lastKey = key;
}
reply.used = getMetrics(KeyRangeRef(lastKey, req.keys.end)) + used;
req.reply.send(reply);
} catch (Error& e) {
req.reply.sendError(e);
}
}
void splitMetrics(SplitMetricsRequest req) const;
void getStorageMetrics(GetStorageMetricsRequest req,
StorageBytes sb,
double bytesInputRate,
int64_t versionLag,
double lastUpdate) const {
GetStorageMetricsReply rep;
// SOMEDAY: make bytes dynamic with hard disk space
rep.load = getMetrics(allKeys);
if (sb.free < 1e9) {
TraceEvent(SevWarn, "PhysicalDiskMetrics")
.suppressFor(60.0)
.detail("Free", sb.free)
.detail("Total", sb.total)
.detail("Available", sb.available)
.detail("Load", rep.load.bytes);
}
rep.available.bytes = sb.available;
rep.available.iosPerKSecond = 10e6;
rep.available.bytesPerKSecond = 100e9;
rep.available.bytesReadPerKSecond = 100e9;
rep.capacity.bytes = sb.total;
rep.capacity.iosPerKSecond = 10e6;
rep.capacity.bytesPerKSecond = 100e9;
rep.capacity.bytesReadPerKSecond = 100e9;
rep.bytesInputRate = bytesInputRate;
rep.versionLag = versionLag;
rep.lastUpdate = lastUpdate;
req.reply.send(rep);
}
double lastUpdate) const;
Future<Void> waitMetrics(WaitMetricsRequest req, Future<Void> delay);
// Given a read hot shard, this function will divide the shard into chunks and find those chunks whose
// readBytes/sizeBytes exceeds the `readDensityRatio`. Please make sure to run the unit tests in
// `StorageMetricsSampleTests.txt` after making changes.
std::vector<ReadHotRangeWithMetrics> getReadHotRanges(KeyRangeRef shard,
double readDensityRatio,
int64_t baseChunkSize,
int64_t minShardReadBandwidthPerKSeconds) const {
std::vector<ReadHotRangeWithMetrics> toReturn;
int64_t minShardReadBandwidthPerKSeconds) const;
double shardSize = (double)byteSample.getEstimate(shard);
int64_t shardReadBandwidth = bytesReadSample.getEstimate(shard);
if (shardReadBandwidth * SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL_PER_KSECONDS <=
minShardReadBandwidthPerKSeconds) {
return toReturn;
}
if (shardSize <= baseChunkSize) {
// Shard is small, use it as is
if (bytesReadSample.getEstimate(shard) > (readDensityRatio * shardSize)) {
toReturn.emplace_back(shard,
bytesReadSample.getEstimate(shard) / shardSize,
bytesReadSample.getEstimate(shard) /
SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL);
}
return toReturn;
}
KeyRef beginKey = shard.begin;
auto endKey =
byteSample.sample.index(byteSample.sample.sumTo(byteSample.sample.lower_bound(beginKey)) + baseChunkSize);
while (endKey != byteSample.sample.end()) {
if (*endKey > shard.end) {
endKey = byteSample.sample.lower_bound(shard.end);
if (*endKey == beginKey) {
// No need to increment endKey since otherwise it would get stuck here forever.
break;
}
}
if (*endKey == beginKey) {
++endKey;
continue;
}
if (bytesReadSample.getEstimate(KeyRangeRef(beginKey, *endKey)) >
(readDensityRatio * std::max(baseChunkSize, byteSample.getEstimate(KeyRangeRef(beginKey, *endKey))))) {
auto range = KeyRangeRef(beginKey, *endKey);
if (!toReturn.empty() && toReturn.back().keys.end == range.begin) {
// In case two consecutive chunks are both over the ratio, merge them.
range = KeyRangeRef(toReturn.back().keys.begin, *endKey);
toReturn.pop_back();
}
toReturn.emplace_back(
range,
(double)bytesReadSample.getEstimate(range) / std::max(baseChunkSize, byteSample.getEstimate(range)),
bytesReadSample.getEstimate(range) / SERVER_KNOBS->STORAGE_METRICS_AVERAGE_INTERVAL);
}
beginKey = *endKey;
endKey = byteSample.sample.index(byteSample.sample.sumTo(byteSample.sample.lower_bound(beginKey)) +
baseChunkSize);
}
return toReturn;
}
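// Worked example with illustrative numbers (not the knob defaults): with baseChunkSize = 10 MB,
// readDensityRatio = 8 and a chunk whose byte estimate is 4 MB, the chunk qualifies as read-hot
// only if its sampled read bytes exceed 8 * max(10 MB, 4 MB) = 80 MB; the reported density is then
// sampled read bytes / max(baseChunkSize, chunk byte estimate).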
void getReadHotRanges(ReadHotSubRangeRequest req) const;
void getReadHotRanges(ReadHotSubRangeRequest req) const {
ReadHotSubRangeReply reply;
auto _ranges = getReadHotRanges(req.keys,
SERVER_KNOBS->SHARD_MAX_READ_DENSITY_RATIO,
SERVER_KNOBS->READ_HOT_SUB_RANGE_CHUNK_SIZE,
SERVER_KNOBS->SHARD_READ_HOT_BANDWIDTH_MIN_PER_KSECONDS);
reply.readHotRanges = VectorRef(_ranges.data(), _ranges.size());
req.reply.send(reply);
}
std::vector<KeyRef> getSplitPoints(KeyRangeRef range, int64_t chunkSize, Optional<Key> prefixToRemove) const;
std::vector<KeyRef> getSplitPoints(KeyRangeRef range, int64_t chunkSize, Optional<Key> prefixToRemove) const {
std::vector<KeyRef> toReturn;
KeyRef beginKey = range.begin;
IndexedSet<Key, int64_t>::const_iterator endKey =
byteSample.sample.index(byteSample.sample.sumTo(byteSample.sample.lower_bound(beginKey)) + chunkSize);
while (endKey != byteSample.sample.end()) {
if (*endKey > range.end) {
break;
}
if (*endKey == beginKey) {
++endKey;
continue;
}
KeyRef splitPoint = *endKey;
if (prefixToRemove.present()) {
splitPoint = splitPoint.removePrefix(prefixToRemove.get());
}
toReturn.push_back(splitPoint);
beginKey = *endKey;
endKey =
byteSample.sample.index(byteSample.sample.sumTo(byteSample.sample.lower_bound(beginKey)) + chunkSize);
}
return toReturn;
}
void getSplitPoints(SplitRangeRequest req, Optional<Key> prefix) const {
SplitRangeReply reply;
KeyRangeRef range = req.keys;
if (prefix.present()) {
range = range.withPrefix(prefix.get(), req.arena);
}
std::vector<KeyRef> points = getSplitPoints(range, req.chunkSize, prefix);
reply.splitPoints.append_deep(reply.splitPoints.arena(), points.data(), points.size());
req.reply.send(reply);
}
void getSplitPoints(SplitRangeRequest req, Optional<Key> prefix) const;
private:
static void collapse(KeyRangeMap<int>& map, KeyRef const& key) {
auto range = map.rangeContaining(key);
if (range == map.ranges().begin() || range == map.ranges().end())
return;
int value = range->value();
auto prev = range;
--prev;
if (prev->value() != value)
return;
KeyRange keys = KeyRangeRef(prev->begin(), range->end());
map.insert(keys, value);
}
static void add(KeyRangeMap<int>& map, KeyRangeRef const& keys, int delta) {
auto rs = map.modify(keys);
for (auto r = rs.begin(); r != rs.end(); ++r)
r->value() += delta;
collapse(map, keys.begin);
collapse(map, keys.end);
}
static void collapse(KeyRangeMap<int>& map, KeyRef const& key);
static void add(KeyRangeMap<int>& map, KeyRangeRef const& keys, int delta);
};
// Contains information about whether or not a key-value pair should be included in a byte sample

View File

@ -50,8 +50,6 @@ public:
virtual int64_t manualThrottleCount() const = 0;
virtual bool isAutoThrottlingEnabled() const = 0;
virtual GlobalTagThrottlerStatusReply getGlobalTagThrottlerStatusReply() const = 0;
// Based on the busiest read and write tags in the provided storage queue info, update
// tag throttling limits.
virtual Future<Void> tryUpdateAutoThrottling(StorageQueueInfo const&) = 0;
@ -75,7 +73,6 @@ public:
int64_t manualThrottleCount() const override;
bool isAutoThrottlingEnabled() const override;
Future<Void> tryUpdateAutoThrottling(StorageQueueInfo const&) override;
GlobalTagThrottlerStatusReply getGlobalTagThrottlerStatusReply() const override { return {}; }
};
class GlobalTagThrottler : public ITagThrottler {
@ -99,8 +96,6 @@ public:
PrioritizedTransactionTagMap<ClientTagThrottleLimits> getClientRates() override;
PrioritizedTransactionTagMap<double> getProxyRates(int numProxies) override;
GlobalTagThrottlerStatusReply getGlobalTagThrottlerStatusReply() const override;
// Testing only:
public:
void setQuota(TransactionTagRef, ThrottleApi::TagQuotaValue const&);

View File

@ -1092,6 +1092,7 @@ ACTOR Future<Void> encryptKeyProxyServer(EncryptKeyProxyInterface ei, Reference<
class IKeyValueStore;
class ServerCoordinators;
class IDiskQueue;
class IEncryptionKeyProvider;
ACTOR Future<Void> storageServer(IKeyValueStore* persistentData,
StorageServerInterface ssi,
Tag seedTag,
@ -1100,7 +1101,8 @@ ACTOR Future<Void> storageServer(IKeyValueStore* persistentData,
Version tssSeedVersion,
ReplyPromise<InitializeStorageReply> recruitReply,
Reference<AsyncVar<ServerDBInfo> const> db,
std::string folder);
std::string folder,
Reference<IEncryptionKeyProvider> encryptionKeyProvider);
ACTOR Future<Void> storageServer(
IKeyValueStore* persistentData,
StorageServerInterface ssi,
@ -1108,7 +1110,8 @@ ACTOR Future<Void> storageServer(
std::string folder,
Promise<Void> recovered,
Reference<IClusterConnectionRecord>
connRecord); // changes pssi->id() to be the recovered ID); // changes pssi->id() to be the recovered ID
connRecord, // changes pssi->id() to be the recovered ID
Reference<IEncryptionKeyProvider> encryptionKeyProvider);
ACTOR Future<Void> masterServer(MasterInterface mi,
Reference<AsyncVar<ServerDBInfo> const> db,
Reference<AsyncVar<Optional<ClusterControllerFullInterface>> const> ccInterface,

File diff suppressed because it is too large

View File

@ -1253,7 +1253,8 @@ ACTOR Future<Void> storageServerRollbackRebooter(std::set<std::pair<UID, KeyValu
int64_t memoryLimit,
IKeyValueStore* store,
bool validateDataFiles,
Promise<Void>* rebootKVStore) {
Promise<Void>* rebootKVStore,
Reference<IEncryptionKeyProvider> encryptionKeyProvider) {
state TrackRunningStorage _(id, storeType, runningStorages);
loop {
ErrorOr<Void> e = wait(errorOr(prevStorageServer));
@ -1320,8 +1321,13 @@ ACTOR Future<Void> storageServerRollbackRebooter(std::set<std::pair<UID, KeyValu
DUMPTOKEN(recruited.changeFeedPop);
DUMPTOKEN(recruited.changeFeedVersionUpdate);
prevStorageServer =
storageServer(store, recruited, db, folder, Promise<Void>(), Reference<IClusterConnectionRecord>(nullptr));
prevStorageServer = storageServer(store,
recruited,
db,
folder,
Promise<Void>(),
Reference<IClusterConnectionRecord>(nullptr),
encryptionKeyProvider);
prevStorageServer = handleIOErrors(prevStorageServer, store, id, store->onClosed());
}
}
@ -1718,6 +1724,8 @@ ACTOR Future<Void> workerServer(Reference<IClusterConnectionRecord> connRecord,
if (s.storedComponent == DiskStore::Storage) {
LocalLineage _;
getCurrentLineage()->modify(&RoleLineage::role) = ProcessClass::ClusterRole::Storage;
Reference<IEncryptionKeyProvider> encryptionKeyProvider =
makeReference<TenantAwareEncryptionKeyProvider>(dbInfo);
IKeyValueStore* kv = openKVStore(
s.storeType,
s.filename,
@ -1730,7 +1738,8 @@ ACTOR Future<Void> workerServer(Reference<IClusterConnectionRecord> connRecord,
? (/* Disable for RocksDB */ s.storeType != KeyValueStoreType::SSD_ROCKSDB_V1 &&
s.storeType != KeyValueStoreType::SSD_SHARDED_ROCKSDB &&
deterministicRandom()->coinflip())
: true));
: true),
encryptionKeyProvider);
Future<Void> kvClosed =
kv->onClosed() ||
rebootKVSPromise.getFuture() /* clear the onClosed() Future in actorCollection when rebooting */;
@ -1778,7 +1787,8 @@ ACTOR Future<Void> workerServer(Reference<IClusterConnectionRecord> connRecord,
DUMPTOKEN(recruited.changeFeedVersionUpdate);
Promise<Void> recovery;
Future<Void> f = storageServer(kv, recruited, dbInfo, folder, recovery, connRecord);
Future<Void> f =
storageServer(kv, recruited, dbInfo, folder, recovery, connRecord, encryptionKeyProvider);
recoveries.push_back(recovery.getFuture());
f = handleIOErrors(f, kv, s.storeID, kvClosed);
f = storageServerRollbackRebooter(&runningStorages,
@ -1794,7 +1804,8 @@ ACTOR Future<Void> workerServer(Reference<IClusterConnectionRecord> connRecord,
memoryLimit,
kv,
validateDataFiles,
&rebootKVSPromise);
&rebootKVSPromise,
encryptionKeyProvider);
errorForwarders.add(forwardError(errors, ssRole, recruited.id(), f));
} else if (s.storedComponent == DiskStore::TLogData) {
LocalLineage _;
@ -2329,7 +2340,8 @@ ACTOR Future<Void> workerServer(Reference<IClusterConnectionRecord> connRecord,
folder,
isTss ? testingStoragePrefix.toString() : fileStoragePrefix.toString(),
recruited.id());
Reference<IEncryptionKeyProvider> encryptionKeyProvider =
makeReference<TenantAwareEncryptionKeyProvider>(dbInfo);
IKeyValueStore* data = openKVStore(
req.storeType,
filename,
@ -2342,7 +2354,8 @@ ACTOR Future<Void> workerServer(Reference<IClusterConnectionRecord> connRecord,
? (/* Disable for RocksDB */ req.storeType != KeyValueStoreType::SSD_ROCKSDB_V1 &&
req.storeType != KeyValueStoreType::SSD_SHARDED_ROCKSDB &&
deterministicRandom()->coinflip())
: true));
: true),
encryptionKeyProvider);
Future<Void> kvClosed =
data->onClosed() ||
@ -2359,7 +2372,8 @@ ACTOR Future<Void> workerServer(Reference<IClusterConnectionRecord> connRecord,
isTss ? req.tssPairIDAndVersion.get().second : 0,
storageReady,
dbInfo,
folder);
folder,
encryptionKeyProvider);
s = handleIOErrors(s, data, recruited.id(), kvClosed);
s = storageCache.removeOnReady(req.reqId, s);
s = storageServerRollbackRebooter(&runningStorages,
@ -2375,7 +2389,8 @@ ACTOR Future<Void> workerServer(Reference<IClusterConnectionRecord> connRecord,
memoryLimit,
data,
false,
&rebootKVSPromise2);
&rebootKVSPromise2,
encryptionKeyProvider);
errorForwarders.add(forwardError(errors, ssRole, recruited.id(), s));
} else if (storageCache.exists(req.reqId)) {
forwardPromise(req.reply, storageCache.get(req.reqId));

View File

@ -357,7 +357,7 @@ struct BlobGranuleRangesWorkload : TestWorkload {
bool fail7 = wait(self->isRangeActive(cx, KeyRangeRef(activeRange.begin, range.end)));
ASSERT(!fail7);
wait(self->tearDownRangeAfterUnit(cx, self, range));
wait(self->tearDownRangeAfterUnit(cx, self, activeRange));
return Void();
}

View File

@ -932,6 +932,7 @@ struct BlobGranuleVerifierWorkload : TestWorkload {
loop {
state RangeResult output;
state Version readVersion = invalidVersion;
state int64_t bufferedBytes = 0;
try {
Version ver = wait(tr.getReadVersion());
readVersion = ver;
@ -943,6 +944,11 @@ struct BlobGranuleVerifierWorkload : TestWorkload {
Standalone<RangeResultRef> res = waitNext(results.getFuture());
output.arena().dependsOn(res.arena());
output.append(output.arena(), res.begin(), res.size());
bufferedBytes += res.expectedSize();
// force checking if we have enough data
if (bufferedBytes >= 10 * SERVER_KNOBS->BG_SNAPSHOT_FILE_TARGET_BYTES) {
break;
}
}
} catch (Error& e) {
if (e.code() == error_code_operation_cancelled) {

View File

@ -284,6 +284,7 @@ struct ConfigureDatabaseWorkload : TestWorkload {
}
ACTOR Future<bool> _check(ConfigureDatabaseWorkload* self, Database cx) {
wait(delay(30.0));
// Only storage_migration_type=gradual && perpetual_storage_wiggle=1 need this check, because in QuietDatabase
// the perpetual wiggle will be forced to close. For other cases, the later ConsistencyCheck will check the KV
// store type there.

View File

@ -70,6 +70,8 @@ struct SaveAndKillWorkload : TestWorkload {
ini.SetBoolValue("META", "enableEncryption", SERVER_KNOBS->ENABLE_ENCRYPTION);
ini.SetBoolValue("META", "enableTLogEncryption", SERVER_KNOBS->ENABLE_TLOG_ENCRYPTION);
ini.SetBoolValue("META", "enableStorageServerEncryption", SERVER_KNOBS->ENABLE_STORAGE_SERVER_ENCRYPTION);
ini.SetBoolValue("META", "enableBlobGranuleEncryption", SERVER_KNOBS->ENABLE_BLOB_GRANULE_ENCRYPTION);
std::vector<ISimulator::ProcessInfo*> processes = g_simulator.getAllProcesses();
std::map<NetworkAddress, ISimulator::ProcessInfo*> rebootingProcesses = g_simulator.currentlyRebootingProcesses;

View File

@ -88,8 +88,3 @@ endif()
add_executable(mkcert MkCertCli.cpp)
target_link_libraries(mkcert PUBLIC flow)
add_executable(mtls_unittest TLSTest.cpp)
target_link_libraries(mtls_unittest PUBLIC flow)
add_test(NAME mutual_tls_unittest
COMMAND $<TARGET_FILE:mtls_unittest>)

View File

@ -129,6 +129,10 @@ void FlowKnobs::initialize(Randomize randomize, IsSimulated isSimulated) {
init( NETWORK_TEST_REQUEST_COUNT, 0 ); // 0 -> run forever
init( NETWORK_TEST_REQUEST_SIZE, 1 );
init( NETWORK_TEST_SCRIPT_MODE, false );
//Authorization
init( PUBLIC_KEY_FILE_MAX_SIZE, 1024 * 1024 );
init( PUBLIC_KEY_FILE_REFRESH_INTERVAL_SECONDS, 30 );
init( MAX_CACHED_EXPIRED_TOKENS, 1024 );
//AsyncFileCached

View File

@ -166,13 +166,13 @@ PrivateKey makeEcP256() {
return PrivateKey(DerEncoded{}, StringRef(buf, len));
}
PrivateKey makeRsa2048Bit() {
PrivateKey makeRsa4096Bit() {
auto kctx = AutoCPointer(::EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, nullptr), &::EVP_PKEY_CTX_free);
OSSL_ASSERT(kctx);
auto key = AutoCPointer(nullptr, &::EVP_PKEY_free);
auto keyRaw = std::add_pointer_t<EVP_PKEY>();
OSSL_ASSERT(0 < ::EVP_PKEY_keygen_init(kctx));
OSSL_ASSERT(0 < ::EVP_PKEY_CTX_set_rsa_keygen_bits(kctx, 2048));
OSSL_ASSERT(0 < ::EVP_PKEY_CTX_set_rsa_keygen_bits(kctx, 4096));
OSSL_ASSERT(0 < ::EVP_PKEY_keygen(kctx, &keyRaw));
OSSL_ASSERT(keyRaw);
key.reset(keyRaw);

View File

@ -50,6 +50,7 @@
#include "flow/ProtocolVersion.h"
#include "flow/SendBufferIterator.h"
#include "flow/TLSConfig.actor.h"
#include "flow/WatchFile.actor.h"
#include "flow/genericactors.actor.h"
#include "flow/Util.h"
#include "flow/UnitTest.h"
@ -238,6 +239,7 @@ public:
int sslHandshakerThreadsStarted;
int sslPoolHandshakesInProgress;
TLSConfig tlsConfig;
Reference<TLSPolicy> activeTlsPolicy;
Future<Void> backgroundCertRefresh;
ETLSInitState tlsInitializedState;
@ -507,6 +509,8 @@ public:
NetworkAddress getPeerAddress() const override { return peer_address; }
bool hasTrustedPeer() const override { return true; }
UID getDebugID() const override { return id; }
tcp::socket& getSocket() override { return socket; }
@ -839,7 +843,7 @@ public:
explicit SSLConnection(boost::asio::io_service& io_service,
Reference<ReferencedObject<boost::asio::ssl::context>> context)
: id(nondeterministicRandom()->randomUniqueID()), socket(io_service), ssl_sock(socket, context->mutate()),
sslContext(context) {}
sslContext(context), has_trusted_peer(false) {}
explicit SSLConnection(Reference<ReferencedObject<boost::asio::ssl::context>> context, tcp::socket* existingSocket)
: id(nondeterministicRandom()->randomUniqueID()), socket(std::move(*existingSocket)),
@ -900,6 +904,9 @@ public:
try {
Future<Void> onHandshook;
ConfigureSSLStream(N2::g_net2->activeTlsPolicy, self->ssl_sock, [this](bool verifyOk) {
self->has_trusted_peer = verifyOk;
});
// If the background handshakers are not all busy, use one
if (N2::g_net2->sslPoolHandshakesInProgress < N2::g_net2->sslHandshakerThreadsStarted) {
@ -975,6 +982,10 @@ public:
try {
Future<Void> onHandshook;
ConfigureSSLStream(N2::g_net2->activeTlsPolicy, self->ssl_sock, [this](bool verifyOk) {
self->has_trusted_peer = verifyOk;
});
// If the background handshakers are not all busy, use one
if (N2::g_net2->sslPoolHandshakesInProgress < N2::g_net2->sslHandshakerThreadsStarted) {
holder = Hold(&N2::g_net2->sslPoolHandshakesInProgress);
@ -1108,6 +1119,8 @@ public:
NetworkAddress getPeerAddress() const override { return peer_address; }
bool hasTrustedPeer() const override { return has_trusted_peer; }
UID getDebugID() const override { return id; }
tcp::socket& getSocket() override { return socket; }
@ -1120,6 +1133,7 @@ private:
ssl_socket ssl_sock;
NetworkAddress peer_address;
Reference<ReferencedObject<boost::asio::ssl::context>> sslContext;
bool has_trusted_peer;
void init() {
// Socket settings that have to be set after connect or accept succeeds
@ -1165,6 +1179,16 @@ public:
NetworkAddress listenAddress)
: io_service(io_service), listenAddress(listenAddress), acceptor(io_service, tcpEndpoint(listenAddress)),
contextVar(contextVar) {
// When port 0 is passed in, a random port will be opened.
// Set listenAddress to the address with the actual port that was opened instead of port 0.
if (listenAddress.port == 0) {
this->listenAddress = NetworkAddress::parse(acceptor.local_endpoint()
.address()
.to_string()
.append(":")
.append(std::to_string(acceptor.local_endpoint().port()))
.append(listenAddress.isTLS() ? ":tls" : ""));
}
platform::setCloseOnExec(acceptor.native_handle());
}
@ -1240,45 +1264,11 @@ Net2::Net2(const TLSConfig& tlsConfig, bool useThreadPool, bool useMetrics)
updateNow();
}
ACTOR static Future<Void> watchFileForChanges(std::string filename, AsyncTrigger* fileChanged) {
if (filename == "") {
return Never();
}
state bool firstRun = true;
state bool statError = false;
state std::time_t lastModTime = 0;
loop {
try {
std::time_t modtime = wait(IAsyncFileSystem::filesystem()->lastWriteTime(filename));
if (firstRun) {
lastModTime = modtime;
firstRun = false;
}
if (lastModTime != modtime || statError) {
lastModTime = modtime;
statError = false;
fileChanged->trigger();
}
} catch (Error& e) {
if (e.code() == error_code_io_error) {
// EACCES, ELOOP, ENOENT all come out as io_error(), but are more of a system
// configuration issue than an FDB problem. If we managed to load valid
// certificates, then there's no point in crashing, but we should complain
// loudly. IAsyncFile will log the error, but not necessarily as a warning.
TraceEvent(SevWarnAlways, "TLSCertificateRefreshStatError").detail("File", filename);
statError = true;
} else {
throw;
}
}
wait(delay(FLOW_KNOBS->TLS_CERT_REFRESH_DELAY_SECONDS));
}
}
ACTOR static Future<Void> reloadCertificatesOnChange(
TLSConfig config,
std::function<void()> onPolicyFailure,
AsyncVar<Reference<ReferencedObject<boost::asio::ssl::context>>>* contextVar) {
AsyncVar<Reference<ReferencedObject<boost::asio::ssl::context>>>* contextVar,
Reference<TLSPolicy>* policy) {
if (FLOW_KNOBS->TLS_CERT_REFRESH_DELAY_SECONDS <= 0) {
return Void();
}
@ -1292,9 +1282,13 @@ ACTOR static Future<Void> reloadCertificatesOnChange(
state int mismatches = 0;
state AsyncTrigger fileChanged;
state std::vector<Future<Void>> lifetimes;
lifetimes.push_back(watchFileForChanges(config.getCertificatePathSync(), &fileChanged));
lifetimes.push_back(watchFileForChanges(config.getKeyPathSync(), &fileChanged));
lifetimes.push_back(watchFileForChanges(config.getCAPathSync(), &fileChanged));
const int& intervalSeconds = FLOW_KNOBS->TLS_CERT_REFRESH_DELAY_SECONDS;
lifetimes.push_back(watchFileForChanges(
config.getCertificatePathSync(), &fileChanged, &intervalSeconds, "TLSCertificateRefreshStatError"));
lifetimes.push_back(
watchFileForChanges(config.getKeyPathSync(), &fileChanged, &intervalSeconds, "TLSKeyRefreshStatError"));
lifetimes.push_back(
watchFileForChanges(config.getCAPathSync(), &fileChanged, &intervalSeconds, "TLSCARefreshStatError"));
loop {
wait(fileChanged.onTrigger());
TraceEvent("TLSCertificateRefreshBegin").log();
@ -1302,7 +1296,8 @@ ACTOR static Future<Void> reloadCertificatesOnChange(
try {
LoadedTLSConfig loaded = wait(config.loadAsync());
boost::asio::ssl::context context(boost::asio::ssl::context::tls);
ConfigureSSLContext(loaded, &context, onPolicyFailure);
ConfigureSSLContext(loaded, context);
*policy = makeReference<TLSPolicy>(loaded, onPolicyFailure);
TraceEvent(SevInfo, "TLSCertificateRefreshSucceeded").log();
mismatches = 0;
contextVar->set(ReferencedObject<boost::asio::ssl::context>::from(std::move(context)));
@ -1334,12 +1329,15 @@ void Net2::initTLS(ETLSInitState targetState) {
.detail("KeyPath", tlsConfig.getKeyPathSync())
.detail("HasPassword", !loaded.getPassword().empty())
.detail("VerifyPeers", boost::algorithm::join(loaded.getVerifyPeers(), "|"));
ConfigureSSLContext(tlsConfig.loadSync(), &newContext, onPolicyFailure);
auto loadedTlsConfig = tlsConfig.loadSync();
ConfigureSSLContext(loadedTlsConfig, newContext);
activeTlsPolicy = makeReference<TLSPolicy>(loadedTlsConfig, onPolicyFailure);
sslContextVar.set(ReferencedObject<boost::asio::ssl::context>::from(std::move(newContext)));
} catch (Error& e) {
TraceEvent("Net2TLSInitError").error(e);
}
backgroundCertRefresh = reloadCertificatesOnChange(tlsConfig, onPolicyFailure, &sslContextVar);
backgroundCertRefresh =
reloadCertificatesOnChange(tlsConfig, onPolicyFailure, &sslContextVar, &activeTlsPolicy);
}
// If a TLS connection is actually going to be used then start background threads if configured

View File

@ -81,7 +81,7 @@ void LoadedTLSConfig::print(FILE* fp) {
int num_certs = 0;
boost::asio::ssl::context context(boost::asio::ssl::context::tls);
try {
ConfigureSSLContext(*this, &context);
ConfigureSSLContext(*this, context);
} catch (Error& e) {
fprintf(fp, "There was an error in loading the certificate chain.\n");
throw;
@ -109,51 +109,58 @@ void LoadedTLSConfig::print(FILE* fp) {
X509_STORE_CTX_free(store_ctx);
}
void ConfigureSSLContext(const LoadedTLSConfig& loaded,
boost::asio::ssl::context* context,
std::function<void()> onPolicyFailure) {
void ConfigureSSLContext(const LoadedTLSConfig& loaded, boost::asio::ssl::context& context) {
try {
context->set_options(boost::asio::ssl::context::default_workarounds);
context->set_verify_mode(boost::asio::ssl::context::verify_peer |
boost::asio::ssl::verify_fail_if_no_peer_cert);
context.set_options(boost::asio::ssl::context::default_workarounds);
auto verifyFailIfNoPeerCert = boost::asio::ssl::verify_fail_if_no_peer_cert;
// Servers get to accept connections without peer certs as "untrusted" clients
if (loaded.getEndpointType() == TLSEndpointType::SERVER)
verifyFailIfNoPeerCert = 0;
context.set_verify_mode(boost::asio::ssl::context::verify_peer | verifyFailIfNoPeerCert);
if (loaded.isTLSEnabled()) {
auto tlsPolicy = makeReference<TLSPolicy>(loaded.getEndpointType());
tlsPolicy->set_verify_peers({ loaded.getVerifyPeers() });
context->set_verify_callback(
[policy = tlsPolicy, onPolicyFailure](bool preverified, boost::asio::ssl::verify_context& ctx) {
bool success = policy->verify_peer(preverified, ctx.native_handle());
if (!success) {
onPolicyFailure();
}
return success;
});
} else {
// Insecurely always accept if TLS is not enabled.
context->set_verify_callback([](bool, boost::asio::ssl::verify_context&) { return true; });
}
context->set_password_callback([password = loaded.getPassword()](
size_t, boost::asio::ssl::context::password_purpose) { return password; });
context.set_password_callback([password = loaded.getPassword()](
size_t, boost::asio::ssl::context::password_purpose) { return password; });
const std::string& CABytes = loaded.getCABytes();
if (CABytes.size()) {
context->add_certificate_authority(boost::asio::buffer(CABytes.data(), CABytes.size()));
context.add_certificate_authority(boost::asio::buffer(CABytes.data(), CABytes.size()));
}
const std::string& keyBytes = loaded.getKeyBytes();
if (keyBytes.size()) {
context->use_private_key(boost::asio::buffer(keyBytes.data(), keyBytes.size()),
boost::asio::ssl::context::pem);
context.use_private_key(boost::asio::buffer(keyBytes.data(), keyBytes.size()),
boost::asio::ssl::context::pem);
}
const std::string& certBytes = loaded.getCertificateBytes();
if (certBytes.size()) {
context->use_certificate_chain(boost::asio::buffer(certBytes.data(), certBytes.size()));
context.use_certificate_chain(boost::asio::buffer(certBytes.data(), certBytes.size()));
}
} catch (boost::system::system_error& e) {
TraceEvent("TLSConfigureError")
TraceEvent("TLSContextConfigureError")
.detail("What", e.what())
.detail("Value", e.code().value())
.detail("WhichMeans", TLSPolicy::ErrorString(e.code()));
throw tls_error();
}
}
void ConfigureSSLStream(Reference<TLSPolicy> policy,
boost::asio::ssl::stream<boost::asio::ip::tcp::socket&>& stream,
std::function<void(bool)> callback) {
try {
stream.set_verify_callback([policy, callback](bool preverified, boost::asio::ssl::verify_context& ctx) {
bool success = policy->verify_peer(preverified, ctx.native_handle());
if (!success) {
if (policy->on_failure)
policy->on_failure();
}
if (callback)
callback(success);
return success;
});
} catch (boost::system::system_error& e) {
TraceEvent("TLSStreamConfigureError")
.detail("What", e.what())
.detail("Value", e.code().value())
.detail("WhichMeans", TLSPolicy::ErrorString(e.code()));
@ -261,6 +268,11 @@ LoadedTLSConfig TLSConfig::loadSync() const {
return loaded;
}
TLSPolicy::TLSPolicy(const LoadedTLSConfig& loaded, std::function<void()> on_failure)
: rules(), on_failure(std::move(on_failure)), is_client(loaded.getEndpointType() == TLSEndpointType::CLIENT) {
set_verify_peers(loaded.getVerifyPeers());
}
// And now do the same thing, but async...
ACTOR static Future<Void> readEntireFile(std::string filename, std::string* destination) {

View File

@ -195,6 +195,9 @@ public:
int NETWORK_TEST_REQUEST_SIZE;
bool NETWORK_TEST_SCRIPT_MODE;
// Authorization
int PUBLIC_KEY_FILE_MAX_SIZE;
int PUBLIC_KEY_FILE_REFRESH_INTERVAL_SECONDS;
int MAX_CACHED_EXPIRED_TOKENS;
// AsyncFileCached

View File

@ -39,7 +39,7 @@ void printPrivateKey(FILE* out, StringRef privateKeyPem);
PrivateKey makeEcP256();
PrivateKey makeRsa2048Bit();
PrivateKey makeRsa4096Bit();
struct Asn1EntryRef {
// field must match one of ASN.1 object short/long names: e.g. "C", "countryName", "CN", "commonName",

View File

@ -320,7 +320,7 @@ std::string readFileBytes(std::string const& filename, int maxSize);
// Read a file into memory supplied by the caller
// If 'len' is greater than file size, then read the filesize bytes.
void readFileBytes(std::string const& filename, uint8_t* buff, int64_t len);
size_t readFileBytes(std::string const& filename, uint8_t* buff, int64_t len);
// Write data buffer into file
void writeFileBytes(std::string const& filename, const char* data, size_t count);

View File

@ -33,6 +33,8 @@
#include <string>
#include <vector>
#include <boost/system/system_error.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/ssl.hpp>
#include "flow/FastRef.h"
#include "flow/Knobs.h"
#include "flow/flow.h"
@ -201,21 +203,23 @@ private:
TLSEndpointType endpointType = TLSEndpointType::UNSET;
};
namespace boost {
namespace asio {
namespace ssl {
struct context;
}
} // namespace asio
} // namespace boost
void ConfigureSSLContext(
const LoadedTLSConfig& loaded,
boost::asio::ssl::context* context,
std::function<void()> onPolicyFailure = []() {});
class TLSPolicy;
void ConfigureSSLContext(const LoadedTLSConfig& loaded, boost::asio::ssl::context& context);
// Set up SSL for a stream object based on the given policy.
// Optionally arm a callback that gets called with the verify outcome of each certificate in the peer's chain:
// e.g. for a peer with a valid, trusted length-3 certificate chain (root CA, intermediate CA, and server certs),
// callback(true) will be called 3 times.
void ConfigureSSLStream(Reference<TLSPolicy> policy,
boost::asio::ssl::stream<boost::asio::ip::tcp::socket&>& stream,
std::function<void(bool)> callback);
class TLSPolicy : ReferenceCounted<TLSPolicy> {
void set_verify_peers(std::vector<std::string> verify_peers);
public:
TLSPolicy(TLSEndpointType client) : is_client(client == TLSEndpointType::CLIENT) {}
TLSPolicy(const LoadedTLSConfig& loaded, std::function<void()> on_failure);
virtual ~TLSPolicy();
virtual void addref() { ReferenceCounted<TLSPolicy>::addref(); }
@ -223,7 +227,6 @@ public:
static std::string ErrorString(boost::system::error_code e);
void set_verify_peers(std::vector<std::string> verify_peers);
bool verify_peer(bool preverified, X509_STORE_CTX* store_ctx);
std::string toString() const;
@ -242,6 +245,7 @@ public:
};
std::vector<Rule> rules;
std::function<void()> on_failure;
bool is_client;
};
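// A brief usage sketch (illustrative only; exampleConfigureStream and the variable names are
// assumptions, not part of the API) mirroring how Net2 arms the verify callback: build a TLSPolicy
// from the loaded config, then record the outcome of peer-certificate verification on the stream.
inline void exampleConfigureStream(const LoadedTLSConfig& loadedTlsConfig,
                                   boost::asio::ssl::stream<boost::asio::ip::tcp::socket&>& sslSock,
                                   bool* trustedPeer) {
	Reference<TLSPolicy> policy = makeReference<TLSPolicy>(loadedTlsConfig, []() { /* policy-failure hook */ });
	ConfigureSSLStream(policy, sslSock, [trustedPeer](bool verifyOk) { *trustedPeer = verifyOk; });
}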

View File

@ -0,0 +1,77 @@
/*
* WatchFile.actor.h
*
* This source file is part of the FoundationDB open source project
*
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
// When actually compiled (NO_INTELLISENSE), include the generated
// version of this file. In intellisense use the source version.
#if defined(NO_INTELLISENSE) && !defined(FLOW_WATCH_FILE_ACTOR_G_H)
#define FLOW_WATCH_FILE_ACTOR_G_H
#include "flow/WatchFile.actor.g.h"
#elif !defined(FLOW_WATCH_FILE_ACTOR_H)
#define FLOW_WATCH_FILE_ACTOR_H
#include <ctime>
#include <string>
#include "flow/IAsyncFile.h"
#include "flow/genericactors.actor.h"
#include "flow/actorcompiler.h"
ACTOR static Future<Void> watchFileForChanges(std::string filename,
AsyncTrigger* fileChanged,
const int* intervalSeconds,
const char* errorType) {
if (filename == "") {
return Never();
}
state bool firstRun = true;
state bool statError = false;
state std::time_t lastModTime = 0;
loop {
try {
std::time_t modtime = wait(IAsyncFileSystem::filesystem()->lastWriteTime(filename));
if (firstRun) {
lastModTime = modtime;
firstRun = false;
}
if (lastModTime != modtime || statError) {
lastModTime = modtime;
statError = false;
fileChanged->trigger();
}
} catch (Error& e) {
if (e.code() == error_code_io_error) {
// EACCES, ELOOP, ENOENT all come out as io_error(), but are more of a system
// configuration issue than an FDB problem. If we managed to load valid
// certificates, then there's no point in crashing, but we should complain
// loudly. IAsyncFile will log the error, but not necessarily as a warning.
TraceEvent(SevWarnAlways, errorType).detail("File", filename);
statError = true;
} else {
throw;
}
}
wait(delay(*intervalSeconds));
}
}
#include "flow/unactorcompiler.h"
#endif // FLOW_WATCH_FILE_ACTOR_H

View File

@ -99,6 +99,7 @@ ERROR( data_move_cancelled, 1075, "Data move was cancelled" )
ERROR( data_move_dest_team_not_found, 1076, "Dest team was not found for data move" )
ERROR( blob_worker_full, 1077, "Blob worker cannot take on more granule assignments" )
ERROR( grv_proxy_memory_limit_exceeded, 1078, "GetReadVersion proxy memory limit exceeded" )
ERROR( blob_granule_request_failed, 1079, "BlobGranule request failed" )
ERROR( broken_promise, 1100, "Broken promise" )
ERROR( operation_cancelled, 1101, "Asynchronous operation cancelled" )

View File

@ -467,6 +467,11 @@ public:
// this may not be an address we can connect to!
virtual NetworkAddress getPeerAddress() const = 0;
// Returns whether the peer is trusted.
// For TLS-enabled connections, this is true if the peer has presented a valid chain of certificates trusted by the
// local endpoint. For non-TLS connections this is always true for any valid open connection.
virtual bool hasTrustedPeer() const = 0;
virtual UID getDebugID() const = 0;
// At present, implemented by Sim2Conn where we want to disable bits flip for connections between parent process and

View File

@ -40,7 +40,7 @@ if(WITH_PYTHON)
configure_testing(TEST_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}"
ERROR_ON_ADDITIONAL_FILES
IGNORE_PATTERNS ".*/CMakeLists.txt")
IGNORE_PATTERNS ".*/CMakeLists.txt" ".*/requirements.txt")
add_fdb_test(TEST_FILES AsyncFileCorrectness.txt UNIT IGNORE)
add_fdb_test(TEST_FILES AsyncFileMix.txt UNIT IGNORE)
@ -396,6 +396,39 @@ if(WITH_PYTHON)
create_valgrind_correctness_package()
endif()
endif()
if (NOT WIN32)
# set up venv for testing token-based authorization
set(authz_venv_dir ${CMAKE_CURRENT_BINARY_DIR}/authorization_test_venv)
set(authz_venv_activate ". ${authz_venv_dir}/bin/activate")
set(authz_venv_stamp_file ${authz_venv_dir}/venv.ready)
set(authz_venv_cmd "")
string(APPEND authz_venv_cmd "[[ ! -f ${authz_venv_stamp_file} ]] && ")
string(APPEND authz_venv_cmd "${Python3_EXECUTABLE} -m venv ${authz_venv_dir} ")
string(APPEND authz_venv_cmd "&& ${authz_venv_activate} ")
string(APPEND authz_venv_cmd "&& pip install --upgrade pip ")
string(APPEND authz_venv_cmd "&& pip install --upgrade -r ${CMAKE_SOURCE_DIR}/tests/authorization/requirements.txt ")
string(APPEND authz_venv_cmd "&& (cd ${CMAKE_BINARY_DIR}/bindings/python && python3 setup.py install) ")
string(APPEND authz_venv_cmd "&& touch ${authz_venv_stamp_file} ")
string(APPEND authz_venv_cmd "|| echo 'venv already set up'")
add_test(
NAME authorization_venv_setup
COMMAND bash -c ${authz_venv_cmd}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
set_tests_properties(authorization_venv_setup PROPERTIES FIXTURES_SETUP authz_virtual_env TIMEOUT 60)
set(authz_script_dir ${CMAKE_SOURCE_DIR}/tests/authorization)
set(authz_test_cmd "")
string(APPEND authz_test_cmd "${authz_venv_activate} && ")
string(APPEND authz_test_cmd "LD_LIBRARY_PATH=${CMAKE_BINARY_DIR}/lib pytest ${authz_script_dir}/authz_test.py -rA --build-dir ${CMAKE_BINARY_DIR} -vvv")
add_test(
NAME token_based_tenant_authorization
WORKING_DIRECTORY ${authz_script_dir}
COMMAND bash -c ${authz_test_cmd})
set_tests_properties(token_based_tenant_authorization PROPERTIES ENVIRONMENT PYTHONPATH=${CMAKE_SOURCE_DIR}/tests/TestRunner) # (local|tmp)_cluster.py
set_tests_properties(token_based_tenant_authorization PROPERTIES FIXTURES_REQUIRED authz_virtual_env)
set_tests_properties(token_based_tenant_authorization PROPERTIES TIMEOUT 120)
endif()
else()
message(WARNING "Python not found, won't configure ctest")
endif()

View File

@ -86,6 +86,8 @@ datadir = {datadir}/$ID
logdir = {logdir}
{bg_knob_line}
{tls_config}
{authz_public_key_config}
{custom_config}
{use_future_protocol_version}
# logsize = 10MiB
# maxlogssize = 100MiB
@ -117,6 +119,8 @@ logdir = {logdir}
redundancy: str = "single",
tls_config: TLSConfig = None,
mkcert_binary: str = "",
custom_config: dict = {},
public_key_json_str: str = "",
):
self.basedir = Path(basedir)
self.etc = self.basedir.joinpath("etc")
@ -137,6 +141,7 @@ logdir = {logdir}
self.redundancy = redundancy
self.ip_address = "127.0.0.1" if ip_address is None else ip_address
self.first_port = port
self.custom_config = custom_config
self.blob_granules_enabled = blob_granules_enabled
if blob_granules_enabled:
# add extra process for blob_worker
@ -158,6 +163,7 @@ logdir = {logdir}
self.coordinators = set()
self.active_servers = set(self.server_ports.keys())
self.tls_config = tls_config
self.public_key_json_file = None
self.mkcert_binary = Path(mkcert_binary)
self.server_cert_file = self.cert.joinpath("server_cert.pem")
self.client_cert_file = self.cert.joinpath("client_cert.pem")
@ -166,6 +172,11 @@ logdir = {logdir}
self.server_ca_file = self.cert.joinpath("server_ca.pem")
self.client_ca_file = self.cert.joinpath("client_ca.pem")
if public_key_json_str:
self.public_key_json_file = self.etc.joinpath("public_keys.json")
with open(self.public_key_json_file, "w") as pubkeyfile:
pubkeyfile.write(public_key_json_str)
if create_config:
self.create_cluster_file()
self.save_config()
@ -173,6 +184,8 @@ logdir = {logdir}
if self.tls_config is not None:
self.create_tls_cert()
self.cluster_file = self.etc.joinpath("fdb.cluster")
def __next_port(self):
if self.first_port is None:
return get_free_port()
@ -198,10 +211,10 @@ logdir = {logdir}
ip_address=self.ip_address,
bg_knob_line=bg_knob_line,
tls_config=self.tls_conf_string(),
authz_public_key_config=self.authz_public_key_conf_string(),
optional_tls=":tls" if self.tls_config is not None else "",
use_future_protocol_version="use-future-protocol-version = true"
if self.use_future_protocol_version
else "",
custom_config='\n'.join(["{} = {}".format(key, value) for key, value in self.custom_config.items()]),
use_future_protocol_version="use-future-protocol-version = true" if self.use_future_protocol_version else "",
)
)
# By default, the cluster only has one process
@ -369,6 +382,12 @@ logdir = {logdir}
}
return "\n".join("{} = {}".format(k, v) for k, v in conf_map.items())
def authz_public_key_conf_string(self):
if self.public_key_json_file is not None:
return "authorization-public-key-file = {}".format(self.public_key_json_file)
else:
return ""
# Get cluster status using fdbcli
def get_status(self):
status_output = self.fdbcli_exec_and_get("status json")

View File

@ -18,6 +18,9 @@ class TempCluster(LocalCluster):
port: str = None,
blob_granules_enabled: bool = False,
tls_config: TLSConfig = None,
public_key_json_str: str = None,
remove_at_exit: bool = True,
custom_config: dict = {},
enable_tenants: bool = True,
):
self.build_dir = Path(build_dir).resolve()
@ -26,6 +29,7 @@ class TempCluster(LocalCluster):
tmp_dir = self.build_dir.joinpath("tmp", random_secret_string(16))
tmp_dir.mkdir(parents=True)
self.tmp_dir = tmp_dir
self.remove_at_exit = remove_at_exit
self.enable_tenants = enable_tenants
super().__init__(
tmp_dir,
@ -37,6 +41,8 @@ class TempCluster(LocalCluster):
blob_granules_enabled=blob_granules_enabled,
tls_config=tls_config,
mkcert_binary=self.build_dir.joinpath("bin", "mkcert"),
public_key_json_str=public_key_json_str,
custom_config=custom_config,
)
def __enter__(self):
@ -49,11 +55,13 @@ class TempCluster(LocalCluster):
def __exit__(self, xc_type, exc_value, traceback):
super().__exit__(xc_type, exc_value, traceback)
shutil.rmtree(self.tmp_dir)
if self.remove_at_exit:
shutil.rmtree(self.tmp_dir)
def close(self):
super().__exit__(None, None, None)
shutil.rmtree(self.tmp_dir)
if self.remove_at_exit:
shutil.rmtree(self.tmp_dir)
if __name__ == "__main__":
@ -147,11 +155,11 @@ if __name__ == "__main__":
print("log-dir: {}".format(cluster.log))
print("etc-dir: {}".format(cluster.etc))
print("data-dir: {}".format(cluster.data))
print("cluster-file: {}".format(cluster.etc.joinpath("fdb.cluster")))
print("cluster-file: {}".format(cluster.cluster_file))
cmd_args = []
for cmd in args.cmd:
if cmd == "@CLUSTER_FILE@":
cmd_args.append(str(cluster.etc.joinpath("fdb.cluster")))
cmd_args.append(str(cluster.cluster_file))
elif cmd == "@DATA_DIR@":
cmd_args.append(str(cluster.data))
elif cmd == "@LOG_DIR@":
@ -178,7 +186,7 @@ if __name__ == "__main__":
cmd_args.append(cmd)
env = dict(**os.environ)
env["FDB_CLUSTER_FILE"] = env.get(
"FDB_CLUSTER_FILE", cluster.etc.joinpath("fdb.cluster")
"FDB_CLUSTER_FILE", cluster.cluster_file
)
errcode = subprocess.run(
cmd_args, stdout=sys.stdout, stderr=sys.stderr, env=env

View File

@ -0,0 +1,135 @@
#!/usr/bin/python
#
# admin_server.py
#
# This source file is part of the FoundationDB open source project
#
# Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import fdb
from multiprocessing import Pipe, Process
from typing import Union, List
from util import to_str, to_bytes, cleanup_tenant
class _admin_request(object):
def __init__(self, op: str, args: List[Union[str, bytes]]=[]):
self.op = op
self.args = args
def __str__(self):
return f"admin_request({self.op}, {self.args})"
def __repr__(self):
return f"admin_request({self.op}, {self.args})"
def main_loop(main_pipe, pipe):
main_pipe.close()
db = None
while True:
try:
req = pipe.recv()
except EOFError:
return
if not isinstance(req, _admin_request):
pipe.send(TypeError("unexpected type {}".format(type(req))))
continue
op = req.op
args = req.args
resp = True
try:
if op == "connect":
db = fdb.open(req.args[0])
elif op == "configure_tls":
keyfile, certfile, cafile = req.args[:3]
fdb.options.set_tls_key_path(keyfile)
fdb.options.set_tls_cert_path(certfile)
fdb.options.set_tls_ca_path(cafile)
elif op == "create_tenant":
if db is None:
resp = Exception("db not open")
else:
for tenant in req.args:
tenant_str = to_str(tenant)
tenant_bytes = to_bytes(tenant)
fdb.tenant_management.create_tenant(db, tenant_bytes)
elif op == "delete_tenant":
if db is None:
resp = Exception("db not open")
else:
for tenant in req.args:
tenant_str = to_str(tenant)
tenant_bytes = to_bytes(tenant)
cleanup_tenant(db, tenant_bytes)
elif op == "cleanup_database":
if db is None:
resp = Exception("db not open")
else:
tr = db.create_transaction()
del tr[b'':b'\xff']
tr.commit().wait()
tenants = list(map(lambda x: x.key, list(fdb.tenant_management.list_tenants(db, b'', b'\xff', 0).to_list())))
for tenant in tenants:
fdb.tenant_management.delete_tenant(db, tenant)
elif op == "terminate":
pipe.send(True)
return
else:
resp = ValueError("unknown operation: {}".format(req))
except Exception as e:
resp = e
pipe.send(resp)
_admin_server = None
def get():
return _admin_server
# The server needs to be a singleton running in a subprocess, because the FDB network layer (including the active TLS config) is a global variable.
class Server(object):
def __init__(self):
global _admin_server
assert _admin_server is None, "admin server may be setup once per process"
_admin_server = self
self._main_pipe, self._admin_pipe = Pipe(duplex=True)
self._admin_proc = Process(target=main_loop, args=(self._main_pipe, self._admin_pipe))
def start(self):
self._admin_proc.start()
def join(self):
self._main_pipe.close()
self._admin_pipe.close()
self._admin_proc.join()
def __enter__(self):
self.start()
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
self.join()
def request(self, op, args=[]):
req = _admin_request(op, args)
try:
self._main_pipe.send(req)
resp = self._main_pipe.recv()
if resp != True:
print("{} failed: {}".format(req, resp))
raise resp
else:
print("{} succeeded".format(req))
except Exception as e:
print("{} failed by exception: {}".format(req, e))
raise
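# A minimal usage sketch (hypothetical paths, for illustration only). Requests are sent over the
# pipe into the subprocess, which owns the process-wide FDB network and TLS configuration.
if __name__ == "__main__":
    fdb.api_version(720)  # assumption: select the API version before any fdb call; inherited by the fork
    with Server() as server:
        server.request("configure_tls", ["client.key", "client.crt", "ca.crt"])  # hypothetical paths
        server.request("connect", ["/tmp/cluster/fdb.cluster"])  # hypothetical cluster file
        server.request("create_tenant", [b"example_tenant"])
        server.request("cleanup_database")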

View File

@ -0,0 +1,297 @@
#!/usr/bin/python
#
# authz_test.py
#
# This source file is part of the FoundationDB open source project
#
# Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import admin_server
import argparse
import authlib
import fdb
import os
import pytest
import random
import sys
import time
from multiprocessing import Process, Pipe
from typing import Union
from util import alg_from_kty, public_keyset_from_keys, random_alphanum_str, random_alphanum_bytes, to_str, to_bytes, KeyFileReverter, token_claim_1h, wait_until_tenant_tr_succeeds, wait_until_tenant_tr_fails
special_key_ranges = [
("transaction description", b"/description", b"/description\x00"),
("global knobs", b"/globalKnobs", b"/globalKnobs\x00"),
("knobs", b"/knobs0", b"/knobs0\x00"),
("conflicting keys", b"/transaction/conflicting_keys/", b"/transaction/conflicting_keys/\xff\xff"),
("read conflict range", b"/transaction/read_conflict_range/", b"/transaction/read_conflict_range/\xff\xff"),
("conflicting keys", b"/transaction/write_conflict_range/", b"/transaction/write_conflict_range/\xff\xff"),
("data distribution stats", b"/metrics/data_distribution_stats/", b"/metrics/data_distribution_stats/\xff\xff"),
("kill storage", b"/globals/killStorage", b"/globals/killStorage\x00"),
]
def test_simple_tenant_access(private_key, token_gen, default_tenant, tenant_tr_gen):
token = token_gen(private_key, token_claim_1h(default_tenant))
tr = tenant_tr_gen(default_tenant)
tr.options.set_authorization_token(token)
tr[b"abc"] = b"def"
tr.commit().wait()
tr = tenant_tr_gen(default_tenant)
tr.options.set_authorization_token(token)
assert tr[b"abc"] == b"def", "tenant write transaction not visible"
def test_cross_tenant_access_disallowed(private_key, default_tenant, token_gen, tenant_gen, tenant_tr_gen):
# use default tenant token with second tenant transaction and see it fail
second_tenant = random_alphanum_bytes(12)
tenant_gen(second_tenant)
token_second = token_gen(private_key, token_claim_1h(second_tenant))
tr_second = tenant_tr_gen(second_tenant)
tr_second.options.set_authorization_token(token_second)
tr_second[b"abc"] = b"def"
tr_second.commit().wait()
token_default = token_gen(private_key, token_claim_1h(default_tenant))
tr_second = tenant_tr_gen(second_tenant)
tr_second.options.set_authorization_token(token_default)
# test that read transaction fails
try:
value = tr_second[b"abc"].value
assert False, f"expected permission denied, but read transaction went through, value: {value}"
except fdb.FDBError as e:
assert e.code == 6000, f"expected permission_denied, got {e} instead"
# test that write transaction fails
tr_second = tenant_tr_gen(second_tenant)
tr_second.options.set_authorization_token(token_default)
try:
tr_second[b"def"] = b"ghi"
tr_second.commit().wait()
assert False, "expected permission denied, but write transaction went through"
except fdb.FDBError as e:
assert e.code == 6000, f"expected permission_denied, got {e} instead"
def test_system_and_special_key_range_disallowed(db, tenant_tr_gen, token_gen):
second_tenant = random_alphanum_bytes(12)
try:
fdb.tenant_management.create_tenant(db, second_tenant)
assert False, "disallowed create_tenant has succeeded"
except fdb.FDBError as e:
assert e.code == 6000, f"expected permission_denied, got {e} instead"
try:
tr = db.create_transaction()
tr.options.set_access_system_keys()
kvs = tr.get_range(b"\xff", b"\xff\xff", limit=1).to_list()
assert False, f"disallowed system keyspace read has succeeded. found item: {kvs}"
except fdb.FDBError as e:
assert e.code == 6000, f"expected permission_denied, got {e} instead"
for range_name, special_range_begin, special_range_end in special_key_ranges:
tr = db.create_transaction()
tr.options.set_access_system_keys()
tr.options.set_special_key_space_relaxed()
try:
kvs = tr.get_range(special_range_begin, special_range_end, limit=1).to_list()
assert False, f"disallowed special keyspace read for range {range_name} has succeeded. found item {kvs}"
except fdb.FDBError as e:
assert e.code == 6000, f"expected permission_denied from attempted read to range {range_name}, got {e} instead"
try:
tr = db.create_transaction()
tr.options.set_access_system_keys()
del tr[b"\xff":b"\xff\xff"]
tr.commit().wait()
assert False, f"disallowed system keyspace write has succeeded"
except fdb.FDBError as e:
assert e.code == 6000, f"expected permission_denied, got {e} instead"
for range_name, special_range_begin, special_range_end in special_key_ranges:
tr = db.create_transaction()
tr.options.set_access_system_keys()
tr.options.set_special_key_space_relaxed()
try:
del tr[special_range_begin:special_range_end]
tr.commit().wait()
assert False, f"write to disallowed special keyspace range {range_name} has succeeded"
except fdb.FDBError as e:
assert e.code == 6000, f"expected permission_denied from attempted write to range {range_name}, got {e} instead"
try:
tr = db.create_transaction()
tr.options.set_access_system_keys()
kvs = tr.get_range(b"", b"\xff", limit=1).to_list()
assert False, f"disallowed normal keyspace read has succeeded. found item {kvs}"
except fdb.FDBError as e:
assert e.code == 6000, f"expected permission_denied, got {e} instead"
def test_public_key_set_rollover(
kty, private_key_gen, private_key, public_key_refresh_interval,
cluster, default_tenant, token_gen, tenant_gen, tenant_tr_gen):
new_kid = random_alphanum_str(12)
new_kty = "EC" if kty == "RSA" else "RSA"
new_key = private_key_gen(kty=new_kty, kid=new_kid)
token_default = token_gen(private_key, token_claim_1h(default_tenant))
second_tenant = random_alphanum_bytes(12)
tenant_gen(second_tenant)
token_second = token_gen(new_key, token_claim_1h(second_tenant))
interim_set = public_keyset_from_keys([new_key, private_key])
max_repeat = 10
print(f"interim keyset: {interim_set}")
old_key_json = None
with open(cluster.public_key_json_file, "r") as keyfile:
old_key_json = keyfile.read()
delay = public_key_refresh_interval
with KeyFileReverter(cluster.public_key_json_file, old_key_json, delay):
with open(cluster.public_key_json_file, "w") as keyfile:
keyfile.write(interim_set)
wait_until_tenant_tr_succeeds(second_tenant, new_key, tenant_tr_gen, token_gen, max_repeat, delay)
print("interim key set activated")
final_set = public_keyset_from_keys([new_key])
print(f"final keyset: {final_set}")
with open(cluster.public_key_json_file, "w") as keyfile:
keyfile.write(final_set)
wait_until_tenant_tr_fails(default_tenant, private_key, tenant_tr_gen, token_gen, max_repeat, delay)
def test_public_key_set_broken_file_tolerance(
private_key, public_key_refresh_interval,
cluster, public_key_jwks_str, default_tenant, token_gen, tenant_tr_gen):
delay = public_key_refresh_interval
# retry limit in waiting for keyset file update to propagate to FDB server's internal keyset
max_repeat = 10
with KeyFileReverter(cluster.public_key_json_file, public_key_jwks_str, delay):
        # key file updates should take effect even after a broken key file has been observed
with open(cluster.public_key_json_file, "w") as keyfile:
keyfile.write(public_key_jwks_str.strip()[:10]) # make the file partial, injecting parse error
time.sleep(delay * 2)
# should still work; internal key set only clears with a valid, empty key set file
tr_default = tenant_tr_gen(default_tenant)
tr_default.options.set_authorization_token(token_gen(private_key, token_claim_1h(default_tenant)))
tr_default[b"abc"] = b"def"
tr_default.commit().wait()
with open(cluster.public_key_json_file, "w") as keyfile:
keyfile.write('{"keys":[]}')
# eventually internal key set will become empty and won't accept any new tokens
wait_until_tenant_tr_fails(default_tenant, private_key, tenant_tr_gen, token_gen, max_repeat, delay)
def test_public_key_set_deletion_tolerance(
private_key, public_key_refresh_interval,
cluster, public_key_jwks_str, default_tenant, token_gen, tenant_tr_gen):
delay = public_key_refresh_interval
# retry limit in waiting for keyset file update to propagate to FDB server's internal keyset
max_repeat = 10
with KeyFileReverter(cluster.public_key_json_file, public_key_jwks_str, delay):
        # key file updates should take effect even after the key file has been deleted
with open(cluster.public_key_json_file, "w") as keyfile:
keyfile.write('{"keys":[]}')
time.sleep(delay)
wait_until_tenant_tr_fails(default_tenant, private_key, tenant_tr_gen, token_gen, max_repeat, delay)
os.remove(cluster.public_key_json_file)
time.sleep(delay * 2)
with open(cluster.public_key_json_file, "w") as keyfile:
keyfile.write(public_key_jwks_str)
# eventually updated key set should take effect and transaction should be accepted
wait_until_tenant_tr_succeeds(default_tenant, private_key, tenant_tr_gen, token_gen, max_repeat, delay)
def test_public_key_set_empty_file_tolerance(
private_key, public_key_refresh_interval,
cluster, public_key_jwks_str, default_tenant, token_gen, tenant_tr_gen):
delay = public_key_refresh_interval
# retry limit in waiting for keyset file update to propagate to FDB server's internal keyset
max_repeat = 10
with KeyFileReverter(cluster.public_key_json_file, public_key_jwks_str, delay):
# key file update should take effect even after witnessing an empty file
with open(cluster.public_key_json_file, "w") as keyfile:
keyfile.write('{"keys":[]}')
# eventually internal key set will become empty and won't accept any new tokens
wait_until_tenant_tr_fails(default_tenant, private_key, tenant_tr_gen, token_gen, max_repeat, delay)
# empty the key file
with open(cluster.public_key_json_file, "w") as keyfile:
pass
time.sleep(delay * 2)
with open(cluster.public_key_json_file, "w") as keyfile:
keyfile.write(public_key_jwks_str)
# eventually key file should update and transactions should go through
wait_until_tenant_tr_succeeds(default_tenant, private_key, tenant_tr_gen, token_gen, max_repeat, delay)
def test_bad_token(private_key, token_gen, default_tenant, tenant_tr_gen):
def del_attr(d, attr):
del d[attr]
return d
def set_attr(d, attr, value):
d[attr] = value
return d
claim_mutations = [
("no nbf", lambda claim: del_attr(claim, "nbf")),
("no exp", lambda claim: del_attr(claim, "exp")),
("no iat", lambda claim: del_attr(claim, "iat")),
("too early", lambda claim: set_attr(claim, "nbf", time.time() + 30)),
("too late", lambda claim: set_attr(claim, "exp", time.time() - 10)),
("no tenants", lambda claim: del_attr(claim, "tenants")),
("empty tenants", lambda claim: set_attr(claim, "tenants", [])),
]
for case_name, mutation in claim_mutations:
tr = tenant_tr_gen(default_tenant)
tr.options.set_authorization_token(token_gen(private_key, mutation(token_claim_1h(default_tenant))))
try:
value = tr[b"abc"].value
assert False, f"expected permission_denied for case {case_name}, but read transaction went through"
except fdb.FDBError as e:
assert e.code == 6000, f"expected permission_denied for case {case_name}, got {e} instead"
tr = tenant_tr_gen(default_tenant)
tr.options.set_authorization_token(token_gen(private_key, mutation(token_claim_1h(default_tenant))))
tr[b"abc"] = b"def"
try:
tr.commit().wait()
assert False, f"expected permission_denied for case {case_name}, but write transaction went through"
except fdb.FDBError as e:
assert e.code == 6000, f"expected permission_denied for case {case_name}, got {e} instead"
# unknown key case: override "kid" field in header
# first, update only the kid field of key with export-update-import
key_dict = private_key.as_dict(is_private=True)
key_dict["kid"] = random_alphanum_str(10)
renamed_key = authlib.jose.JsonWebKey.import_key(key_dict)
unknown_key_token = token_gen(
renamed_key,
token_claim_1h(default_tenant),
headers={
"typ": "JWT",
"kty": renamed_key.kty,
"alg": alg_from_kty(renamed_key.kty),
"kid": renamed_key.kid,
})
tr = tenant_tr_gen(default_tenant)
tr.options.set_authorization_token(unknown_key_token)
try:
value = tr[b"abc"].value
assert False, f"expected permission_denied for 'unknown key' case, but read transaction went through"
except fdb.FDBError as e:
assert e.code == 6000, f"expected permission_denied for 'unknown key' case, got {e} instead"
tr = tenant_tr_gen(default_tenant)
tr.options.set_authorization_token(unknown_key_token)
tr[b"abc"] = b"def"
try:
tr.commit().wait()
assert False, f"expected permission_denied for 'unknown key' case, but write transaction went through"
except fdb.FDBError as e:
assert e.code == 6000, f"expected permission_denied for 'unknown key' case, got {e} instead"

View File

@ -0,0 +1,173 @@
#!/usr/bin/python
#
# conftest.py
#
# This source file is part of the FoundationDB open source project
#
# Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import fdb
import pytest
import subprocess
import admin_server
from authlib.jose import JsonWebKey, KeySet, jwt
from local_cluster import TLSConfig
from tmp_cluster import TempCluster
from typing import Union
from util import alg_from_kty, public_keyset_from_keys, random_alphanum_str, random_alphanum_bytes, to_str, to_bytes
fdb.api_version(720)
cluster_scope = "module"
def pytest_addoption(parser):
parser.addoption(
"--build-dir", action="store", dest="build_dir", help="FDB build directory", required=True)
parser.addoption(
"--kty", action="store", choices=["EC", "RSA"], default="EC", dest="kty", help="Token signature algorithm")
parser.addoption(
"--trusted-client",
action="store_true",
default=False,
dest="trusted_client",
help="Whether client shall be configured trusted, i.e. mTLS-ready")
parser.addoption(
"--public-key-refresh-interval",
action="store",
        type=int,
        default=1,
        dest="public_key_refresh_interval",
        help="How frequently (in seconds) the server refreshes the authorization public key file")
@pytest.fixture(scope="session")
def build_dir(request):
return request.config.option.build_dir
@pytest.fixture(scope="session")
def kty(request):
return request.config.option.kty
@pytest.fixture(scope="session")
def trusted_client(request):
return request.config.option.trusted_client
@pytest.fixture(scope="session")
def public_key_refresh_interval(request):
return request.config.option.public_key_refresh_interval
@pytest.fixture(scope="session")
def alg(kty):
if kty == "EC":
return "ES256"
else:
return "RS256"
@pytest.fixture(scope="session")
def kid():
return random_alphanum_str(12)
@pytest.fixture(scope="session")
def private_key_gen():
def fn(kty: str, kid: str):
if kty == "EC":
return JsonWebKey.generate_key(kty=kty, crv_or_size="P-256", is_private=True, options={"kid": kid})
else:
return JsonWebKey.generate_key(kty=kty, crv_or_size=4096, is_private=True, options={"kid": kid})
return fn
@pytest.fixture(scope="session")
def private_key(kty, kid, private_key_gen):
return private_key_gen(kty, kid)
@pytest.fixture(scope="session")
def public_key_jwks_str(private_key):
return public_keyset_from_keys([private_key])
@pytest.fixture(scope="session")
def token_gen():
def fn(private_key, claims, headers={}):
if not headers:
headers = {
"typ": "JWT",
"kty": private_key.kty,
"alg": alg_from_kty(private_key.kty),
"kid": private_key.kid,
}
return jwt.encode(headers, claims, private_key)
return fn
@pytest.fixture(scope=cluster_scope)
def admin_ipc():
server = admin_server.Server()
server.start()
yield server
server.join()
@pytest.fixture(autouse=True, scope=cluster_scope)
def cluster(admin_ipc, build_dir, public_key_jwks_str, public_key_refresh_interval, trusted_client):
with TempCluster(
build_dir=build_dir,
tls_config=TLSConfig(server_chain_len=3, client_chain_len=2),
public_key_json_str=public_key_jwks_str,
remove_at_exit=True,
custom_config={
"knob-public-key-file-refresh-interval-seconds": public_key_refresh_interval,
}) as cluster:
keyfile = str(cluster.client_key_file)
certfile = str(cluster.client_cert_file)
cafile = str(cluster.server_ca_file)
fdb.options.set_tls_key_path(keyfile if trusted_client else "")
fdb.options.set_tls_cert_path(certfile if trusted_client else "")
fdb.options.set_tls_ca_path(cafile)
fdb.options.set_trace_enable()
admin_ipc.request("configure_tls", [keyfile, certfile, cafile])
admin_ipc.request("connect", [str(cluster.cluster_file)])
yield cluster
@pytest.fixture
def db(cluster, admin_ipc):
db = fdb.open(str(cluster.cluster_file))
db.options.set_transaction_timeout(2000) # 2 seconds
db.options.set_transaction_retry_limit(3)
yield db
admin_ipc.request("cleanup_database")
db = None
@pytest.fixture
def tenant_gen(db, admin_ipc):
def fn(tenant):
tenant = to_bytes(tenant)
admin_ipc.request("create_tenant", [tenant])
return fn
@pytest.fixture
def tenant_del(db, admin_ipc):
def fn(tenant):
tenant = to_str(tenant)
admin_ipc.request("delete_tenant", [tenant])
return fn
@pytest.fixture
def default_tenant(tenant_gen, tenant_del):
tenant = random_alphanum_bytes(8)
tenant_gen(tenant)
yield tenant
tenant_del(tenant)
@pytest.fixture
def tenant_tr_gen(db):
def fn(tenant):
tenant = db.open_tenant(to_bytes(tenant))
return tenant.create_transaction()
return fn

View File

@ -0,0 +1,12 @@
attrs==22.1.0
Authlib==1.0.1
cffi==1.15.1
cryptography==37.0.4
iniconfig==1.1.1
packaging==21.3
pluggy==1.0.0
py==1.11.0
pycparser==2.21
pyparsing==3.0.9
pytest==7.1.2
tomli==2.0.1

tests/authorization/util.py Normal file
View File

@ -0,0 +1,124 @@
import fdb
import json
import random
import string
import time
from typing import Union, List
def to_str(s: Union[str, bytes]):
if isinstance(s, bytes):
s = s.decode("utf8")
return s
def to_bytes(s: Union[str, bytes]):
if isinstance(s, str):
s = s.encode("utf8")
return s
def random_alphanum_str(k: int):
return ''.join(random.choices(string.ascii_letters + string.digits, k=k))
def random_alphanum_bytes(k: int):
return random_alphanum_str(k).encode("ascii")
def cleanup_tenant(db, tenant_name):
try:
tenant = db.open_tenant(tenant_name)
del tenant[:]
fdb.tenant_management.delete_tenant(db, tenant_name)
except fdb.FDBError as e:
if e.code == 2131: # tenant not found
pass
else:
raise
def alg_from_kty(kty: str):
if kty == "EC":
return "ES256"
else:
return "RS256"
def public_keyset_from_keys(keys: List):
keys = list(map(lambda key: key.as_dict(is_private=False, alg=alg_from_kty(key.kty)), keys))
return json.dumps({ "keys": keys })
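# A hedged example (illustration only, not used by the tests): build a JWKS string
# for a freshly generated key, mirroring how conftest.py generates its signing key.
# The function name and the local authlib import are assumptions for this sketch;
# util.py itself does not otherwise depend on authlib.
def _example_public_keyset(kty: str = "EC"):
    from authlib.jose import JsonWebKey
    crv_or_size = "P-256" if kty == "EC" else 4096
    key = JsonWebKey.generate_key(kty=kty, crv_or_size=crv_or_size, is_private=True,
                                  options={"kid": random_alphanum_str(12)})
    return public_keyset_from_keys([key])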
class KeyFileReverter(object):
def __init__(self, filename: str, content: str, refresh_delay: int):
self.filename = filename
self.content = content
self.refresh_delay = refresh_delay
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
with open(self.filename, "w") as keyfile:
keyfile.write(self.content)
print(f"key file reverted. waiting {self.refresh_delay * 2} seconds for the update to take effect...")
time.sleep(self.refresh_delay * 2)
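# A hedged usage sketch (illustration only): swap in an interim keyset while the
# block runs and have the original contents restored (plus a settle delay) on exit,
# as the public-key rollover tests do. The function name and its arguments are
# hypothetical placeholders.
def _example_key_file_swap(key_file: str, original_json: str, interim_json: str, refresh_delay: int):
    with KeyFileReverter(key_file, original_json, refresh_delay):
        with open(key_file, "w") as keyfile:
            keyfile.write(interim_json)
        # ... exercise the cluster while the interim keyset is active ...
    # on exit the original keyset has been written back and refresh_delay * 2 seconds have elapsed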
# JWT claim that is valid for 1 hour since time of invocation
def token_claim_1h(tenant_name):
now = time.time()
return {
"iss": "fdb-authz-tester",
"sub": "authz-test",
"aud": ["tmp-cluster"],
"iat": now,
"nbf": now - 1,
"exp": now + 60 * 60,
"jti": random_alphanum_str(10),
"tenants": [to_str(tenant_name)],
}
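# A hedged example (illustration only) of signing the one-hour claim set into a JWT,
# mirroring the token_gen fixture in conftest.py. The tenant name is a placeholder,
# and the local authlib import is an assumption for this sketch.
def _example_sign_token(private_key, tenant_name=b"example_tenant"):
    from authlib.jose import jwt
    headers = {
        "typ": "JWT",
        "kty": private_key.kty,
        "alg": alg_from_kty(private_key.kty),
        "kid": private_key.kid,
    }
    return jwt.encode(headers, token_claim_1h(tenant_name), private_key)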
# repeat the try-wait loop up to max_repeat times until both the read and the write transaction fail for the tenant with permission_denied
# important: only use this function if you don't have any data dependencies on key "abc"
def wait_until_tenant_tr_fails(tenant, private_key, tenant_tr_gen, token_gen, max_repeat, delay):
repeat = 0
read_blocked = False
write_blocked = False
while (not read_blocked or not write_blocked) and repeat < max_repeat:
time.sleep(delay)
tr = tenant_tr_gen(tenant)
        # a token needs to be generated at every iteration because once it is accepted and cached,
        # it keeps passing verification via the cache until it expires
tr.options.set_authorization_token(token_gen(private_key, token_claim_1h(tenant)))
try:
if not read_blocked:
value = tr[b"abc"].value
except fdb.FDBError as e:
assert e.code == 6000, f"expected permission_denied, got {e} instead"
read_blocked = True
if not read_blocked:
repeat += 1
continue
try:
if not write_blocked:
tr[b"abc"] = b"def"
tr.commit().wait()
except fdb.FDBError as e:
assert e.code == 6000, f"expected permission_denied, got {e} instead"
write_blocked = True
if not write_blocked:
repeat += 1
assert repeat < max_repeat, f"tenant transaction did not start to fail in {max_repeat * delay} seconds"
# repeat the try-wait loop up to max_repeat times until both the read and the write transaction succeed for the tenant
# important: only use this function if you don't have any data dependencies on key "abc"
def wait_until_tenant_tr_succeeds(tenant, private_key, tenant_tr_gen, token_gen, max_repeat, delay):
repeat = 0
token = token_gen(private_key, token_claim_1h(tenant))
while repeat < max_repeat:
try:
time.sleep(delay)
tr = tenant_tr_gen(tenant)
tr.options.set_authorization_token(token)
value = tr[b"abc"].value
tr[b"abc"] = b"qwe"
tr.commit().wait()
break
except fdb.FDBError as e:
assert e.code == 6000, f"expected permission_denied, got {e} instead"
repeat += 1
assert repeat < max_repeat, f"tenant transaction did not start to succeed in {max_repeat * delay} seconds"