Merge branch 'main' of github.com:apple/foundationdb into debug2

This commit is contained in:
Ankita Kejriwal 2023-02-08 13:26:39 -08:00
commit 0789ab35e9
38 changed files with 879 additions and 360 deletions

View File

@ -31,19 +31,24 @@ import random
import string
import toml
sys.path[:0] = [os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "tests", "TestRunner")]
# fmt: off
from tmp_cluster import TempCluster
from local_cluster import TLSConfig
# fmt: on
sys.path[:0] = [
os.path.join(
os.path.dirname(__file__), "..", "..", "..", "..", "tests", "TestRunner"
)
]
TESTER_STATS_INTERVAL_SEC = 5
def random_string(len):
return "".join(random.choice(string.ascii_letters + string.digits) for i in range(len))
return "".join(
random.choice(string.ascii_letters + string.digits) for i in range(len)
)
def get_logger():
@ -77,7 +82,9 @@ def dump_client_logs(log_dir):
def run_tester(args, cluster, test_file):
build_dir = Path(args.build_dir).resolve()
tester_binary = Path(args.api_tester_bin).resolve()
external_client_library = build_dir.joinpath("bindings", "c", "libfdb_c_external.so")
external_client_library = build_dir.joinpath(
"bindings", "c", "libfdb_c_external.so"
)
log_dir = Path(cluster.log).joinpath("client")
log_dir.mkdir(exist_ok=True)
cmd = [
@ -141,7 +148,9 @@ def run_tester(args, cluster, test_file):
reason = signal.Signals(-ret_code).name
else:
reason = "exit code: %d" % ret_code
get_logger().error("\n'%s' did not complete succesfully (%s)" % (cmd[0], reason))
get_logger().error(
"\n'%s' did not complete succesfully (%s)" % (cmd[0], reason)
)
if log_dir is not None and not args.disable_log_dump:
dump_client_logs(log_dir)
@ -160,7 +169,9 @@ class TestConfig:
self.server_chain_len = server_config.get("tls_server_chain_len", 3)
self.min_num_processes = server_config.get("min_num_processes", 1)
self.max_num_processes = server_config.get("max_num_processes", 3)
self.num_processes = random.randint(self.min_num_processes, self.max_num_processes)
self.num_processes = random.randint(
self.min_num_processes, self.max_num_processes
)
def run_test(args, test_file):
@ -210,9 +221,20 @@ def run_tests(args):
def parse_args(argv):
parser = argparse.ArgumentParser(description="FoundationDB C API Tester")
parser.add_argument("--build-dir", "-b", type=str, required=True, help="FDB build directory")
parser.add_argument("--api-tester-bin", type=str, help="Path to the fdb_c_api_tester executable.", required=True)
parser.add_argument("--external-client-library", type=str, help="Path to the external client library.")
parser.add_argument(
"--build-dir", "-b", type=str, required=True, help="FDB build directory"
)
parser.add_argument(
"--api-tester-bin",
type=str,
help="Path to the fdb_c_api_tester executable.",
required=True,
)
parser.add_argument(
"--external-client-library",
type=str,
help="Path to the external client library.",
)
parser.add_argument(
"--retain-client-lib-copies",
action="store_true",

View File

@ -22,6 +22,7 @@
#include "future.hpp"
#include "logger.hpp"
#include "tenant.hpp"
#include "time.hpp"
#include "utils.hpp"
#include <map>
#include <cerrno>
@ -30,39 +31,19 @@
#include <sstream>
#include <stdexcept>
#include <thread>
#include <tuple>
#include <unistd.h>
#include <sys/wait.h>
#include <boost/archive/binary_iarchive.hpp>
#include <boost/archive/binary_oarchive.hpp>
#include <boost/serialization/optional.hpp>
#include <boost/serialization/string.hpp>
#include <boost/serialization/variant.hpp>
#include <boost/variant/apply_visitor.hpp>
#include "rapidjson/document.h"
extern thread_local mako::Logger logr;
using oarchive = boost::archive::binary_oarchive;
using iarchive = boost::archive::binary_iarchive;
namespace {
template <class T>
void sendObject(boost::process::pstream& pipe, T obj) {
oarchive oa(pipe);
oa << obj;
}
template <class T>
T receiveObject(boost::process::pstream& pipe) {
iarchive ia(pipe);
T obj;
ia >> obj;
return obj;
}
fdb::Database getOrCreateDatabase(std::map<std::string, fdb::Database>& db_map, const std::string& cluster_file) {
auto iter = db_map.find(cluster_file);
if (iter == db_map.end()) {
@ -122,35 +103,54 @@ void AdminServer::start() {
}
});
while (true) {
bool stop = false;
while (!stop) {
try {
auto req = receiveObject<Request>(pipe_to_server);
if (setup_error) {
sendObject(pipe_to_client, Response{ setup_error });
} else if (boost::get<PingRequest>(&req)) {
sendObject(pipe_to_client, Response{});
} else if (boost::get<StopRequest>(&req)) {
logr.info("server was requested to stop");
sendObject(pipe_to_client, Response{});
return;
} else if (auto p = boost::get<BatchCreateTenantRequest>(&req)) {
logr.info("received request to batch-create tenants [{}:{}) in database '{}'",
p->id_begin,
p->id_end,
p->cluster_file);
auto err_msg = createTenant(getOrCreateDatabase(databases, p->cluster_file), p->id_begin, p->id_end);
sendObject(pipe_to_client, Response{ std::move(err_msg) });
} else if (auto p = boost::get<BatchDeleteTenantRequest>(&req)) {
logr.info("received request to batch-delete tenants [{}:{}) in database '{}'",
p->id_begin,
p->id_end,
p->cluster_file);
auto err_msg = deleteTenant(getOrCreateDatabase(databases, p->cluster_file), p->id_begin, p->id_end);
sendObject(pipe_to_client, Response{ std::move(err_msg) });
} else {
logr.error("unknown request received");
sendObject(pipe_to_client, Response{ std::string("unknown request type") });
}
boost::apply_visitor(
[this, &databases, &setup_error, &stop](auto&& request) -> void {
using ReqType = std::decay_t<decltype(request)>;
if (setup_error) {
sendResponse<ReqType>(pipe_to_client, ReqType::ResponseType::makeError(*setup_error));
return;
}
if constexpr (std::is_same_v<ReqType, PingRequest>) {
sendResponse<ReqType>(pipe_to_client, DefaultResponse{});
} else if constexpr (std::is_same_v<ReqType, StopRequest>) {
logr.info("server was requested to stop");
sendResponse<ReqType>(pipe_to_client, DefaultResponse{});
stop = true;
} else if constexpr (std::is_same_v<ReqType, BatchCreateTenantRequest>) {
logr.info("received request to batch-create tenants [{}:{}) in database '{}'",
request.id_begin,
request.id_end,
request.cluster_file);
auto err_msg = createTenant(
getOrCreateDatabase(databases, request.cluster_file), request.id_begin, request.id_end);
sendResponse<ReqType>(pipe_to_client, DefaultResponse{ std::move(err_msg) });
} else if constexpr (std::is_same_v<ReqType, BatchDeleteTenantRequest>) {
logr.info("received request to batch-delete tenants [{}:{}) in database '{}'",
request.id_begin,
request.id_end,
request.cluster_file);
auto err_msg = deleteTenant(
getOrCreateDatabase(databases, request.cluster_file), request.id_begin, request.id_end);
sendResponse<ReqType>(pipe_to_client, DefaultResponse{ std::move(err_msg) });
} else if constexpr (std::is_same_v<ReqType, FetchTenantIdsRequest>) {
logr.info("received request to fetch tenant IDs [{}:{}) in database '{}'",
request.id_begin,
request.id_end,
request.cluster_file);
sendResponse<ReqType>(pipe_to_client,
fetchTenantIds(getOrCreateDatabase(databases, request.cluster_file),
request.id_begin,
request.id_end));
} else {
logr.error("unknown request received, typename '{}'", typeid(ReqType).name());
sendResponse<ReqType>(pipe_to_client, ReqType::ResponseType::makeError("unknown request type"));
}
},
req);
} catch (const std::exception& e) {
logr.error("fatal exception: {}", e.what());
return;
@ -161,6 +161,7 @@ void AdminServer::start() {
boost::optional<std::string> AdminServer::createTenant(fdb::Database db, int id_begin, int id_end) {
try {
auto tx = db.createTransaction();
auto stopwatch = Stopwatch(StartAtCtor{});
logr.info("create_tenants [{}-{})", id_begin, id_end);
while (true) {
for (auto id = id_begin; id < id_end; id++) {
@ -180,7 +181,8 @@ boost::optional<std::string> AdminServer::createTenant(fdb::Database db, int id_
return fmt::format("create_tenants [{}:{}) failed with '{}'", id_begin, id_end, f.error().what());
}
}
logr.info("create_tenants [{}-{}) OK", id_begin, id_end);
logr.info("create_tenants [{}-{}) OK ({:.3f}s)", id_begin, id_end, toDoubleSeconds(stopwatch.stop().diff()));
stopwatch.start();
logr.info("blobbify_tenants [{}-{})", id_begin, id_end);
for (auto id = id_begin; id < id_end; id++) {
while (true) {
@ -204,7 +206,7 @@ boost::optional<std::string> AdminServer::createTenant(fdb::Database db, int id_
}
}
}
logr.info("blobbify_tenants [{}-{}) OK", id_begin, id_end);
logr.info("blobbify_tenants [{}-{}) OK ({:.3f}s)", id_begin, id_end, toDoubleSeconds(stopwatch.stop().diff()));
return {};
} catch (const std::exception& e) {
return std::string(e.what());
@ -307,12 +309,52 @@ boost::optional<std::string> AdminServer::deleteTenant(fdb::Database db, int id_
}
}
Response AdminServer::request(Request req) {
// should always be invoked from client side (currently just the main process)
assert(server_pid > 0);
assert(logr.isFor(ProcKind::MAIN));
sendObject(pipe_to_server, std::move(req));
return receiveObject<Response>(pipe_to_client);
TenantIdsResponse AdminServer::fetchTenantIds(fdb::Database db, int id_begin, int id_end) {
try {
logr.info("fetch_tenant_ids [{}:{})", id_begin, id_end);
auto stopwatch = Stopwatch(StartAtCtor{});
size_t const count = id_end - id_begin;
std::vector<int64_t> ids(count);
std::vector<std::tuple<fdb::TypedFuture<fdb::future_var::Int64>, fdb::Tenant, bool>> state(count);
boost::optional<std::string> err_msg;
for (auto idx = id_begin; idx < id_end; idx++) {
auto& [future, tenant, done] = state[idx - id_begin];
tenant = db.openTenant(fdb::toBytesRef(getTenantNameByIndex(idx)));
future = tenant.getId();
done = false;
}
while (true) {
bool has_retries = false;
for (auto idx = id_begin; idx < id_end; idx++) {
auto& [future, tenant, done] = state[idx - id_begin];
if (!done) {
if (auto err = future.blockUntilReady()) {
return TenantIdsResponse::makeError(
fmt::format("error while waiting for tenant ID of tenant {}: {}", idx, err.what()));
}
if (auto err = future.error()) {
if (err.retryable()) {
logr.debug("retryable error while getting tenant ID of tenant {}: {}", idx, err.what());
future = tenant.getId();
has_retries = true;
} else {
return TenantIdsResponse::makeError(fmt::format(
"unretryable error while getting tenant ID of tenant {}: {}", idx, err.what()));
}
} else {
ids[idx - id_begin] = future.get();
done = true;
}
}
}
if (!has_retries)
break;
}
logr.info("fetch_tenant_ids [{}:{}) OK ({:.3f}s)", id_begin, id_end, toDoubleSeconds(stopwatch.stop().diff()));
return TenantIdsResponse{ {}, std::move(ids) };
} catch (const std::exception& e) {
return TenantIdsResponse::makeError(fmt::format("unexpected exception: {}", e.what()));
}
}
AdminServer::~AdminServer() {
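The rewritten request loop above replaces a chain of boost::get checks with a single boost::apply_visitor whose generic lambda branches on the concrete request type via if constexpr and replies with that request's associated ResponseType. Below is a minimal standalone sketch of the same dispatch shape, with std::variant/std::visit standing in for boost::variant/apply_visitor and std::cout standing in for the IPC pipe and logger; PingRequest, StopRequest, and DefaultResponse here are simplified stand-ins, not the mako types.

#include <iostream>
#include <optional>
#include <string>
#include <type_traits>
#include <variant>

struct DefaultResponse {
    std::optional<std::string> error_message;
};
struct PingRequest {
    using ResponseType = DefaultResponse;
};
struct StopRequest {
    using ResponseType = DefaultResponse;
};
using Request = std::variant<PingRequest, StopRequest>;

// Stand-in for sendResponse<ReqType>(pipe_to_client, ...): "sends" the response type
// associated with the request being handled.
template <class ResponseType>
void sendResponse(ResponseType const& r) {
    std::cout << (r.error_message ? *r.error_message : std::string("OK")) << "\n";
}

void handleOne(Request const& req, bool& stop) {
    std::visit(
        [&stop](auto&& request) {
            using ReqType = std::decay_t<decltype(request)>;
            if constexpr (std::is_same_v<ReqType, PingRequest>) {
                sendResponse(typename ReqType::ResponseType{});
            } else if constexpr (std::is_same_v<ReqType, StopRequest>) {
                stop = true; // mirrors the 'stop' flag that ends the server loop above
                sendResponse(typename ReqType::ResponseType{});
            }
        },
        req);
}

int main() {
    bool stop = false;
    handleOne(Request{ PingRequest{} }, stop); // prints OK
    handleOne(Request{ StopRequest{} }, stop); // prints OK and sets stop
    return stop ? 0 : 1;
}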

View File

@ -23,6 +23,13 @@
#include <boost/process/pipe.hpp>
#include <boost/optional.hpp>
#include <boost/variant.hpp>
#include <boost/archive/binary_iarchive.hpp>
#include <boost/archive/binary_oarchive.hpp>
#include <boost/serialization/optional.hpp>
#include <boost/serialization/string.hpp>
#include <boost/serialization/variant.hpp>
#include <boost/serialization/vector.hpp>
#include <unistd.h>
#include "fdb_api.hpp"
#include "logger.hpp"
@ -35,16 +42,32 @@ extern thread_local mako::Logger logr;
// Therefore, in order to benchmark for authorization
namespace mako::ipc {
struct Response {
struct DefaultResponse {
boost::optional<std::string> error_message;
static DefaultResponse makeError(std::string msg) { return DefaultResponse{ msg }; }
template <class Ar>
void serialize(Ar& ar, unsigned int) {
ar& error_message;
}
};
struct TenantIdsResponse {
boost::optional<std::string> error_message;
std::vector<int64_t> ids;
static TenantIdsResponse makeError(std::string msg) { return TenantIdsResponse{ msg, {} }; }
template <class Ar>
void serialize(Ar& ar, unsigned int) {
ar& error_message;
ar& ids;
}
};
struct BatchCreateTenantRequest {
using ResponseType = DefaultResponse;
std::string cluster_file;
int id_begin = 0;
int id_end = 0;
@ -58,6 +81,7 @@ struct BatchCreateTenantRequest {
};
struct BatchDeleteTenantRequest {
using ResponseType = DefaultResponse;
std::string cluster_file;
int id_begin = 0;
int id_end = 0;
@ -70,17 +94,34 @@ struct BatchDeleteTenantRequest {
}
};
struct FetchTenantIdsRequest {
using ResponseType = TenantIdsResponse;
std::string cluster_file;
int id_begin;
int id_end;
template <class Ar>
void serialize(Ar& ar, unsigned int) {
ar& cluster_file;
ar& id_begin;
ar& id_end;
}
};
struct PingRequest {
using ResponseType = DefaultResponse;
template <class Ar>
void serialize(Ar&, unsigned int) {}
};
struct StopRequest {
using ResponseType = DefaultResponse;
template <class Ar>
void serialize(Ar&, unsigned int) {}
};
using Request = boost::variant<PingRequest, StopRequest, BatchCreateTenantRequest, BatchDeleteTenantRequest>;
using Request =
boost::variant<PingRequest, StopRequest, BatchCreateTenantRequest, BatchDeleteTenantRequest, FetchTenantIdsRequest>;
class AdminServer {
const Arguments& args;
@ -89,13 +130,32 @@ class AdminServer {
boost::process::pstream pipe_to_client;
void start();
void configure();
Response request(Request req);
boost::optional<std::string> getTenantPrefixes(fdb::Transaction tx,
int id_begin,
int id_end,
std::vector<fdb::ByteString>& out_prefixes);
boost::optional<std::string> createTenant(fdb::Database db, int id_begin, int id_end);
boost::optional<std::string> deleteTenant(fdb::Database db, int id_begin, int id_end);
TenantIdsResponse fetchTenantIds(fdb::Database db, int id_begin, int id_end);
template <class T>
static void sendObject(boost::process::pstream& pipe, T obj) {
boost::archive::binary_oarchive oa(pipe);
oa << obj;
}
template <class T>
static T receiveObject(boost::process::pstream& pipe) {
boost::archive::binary_iarchive ia(pipe);
T obj;
ia >> obj;
return obj;
}
template <class RequestType>
static void sendResponse(boost::process::pstream& pipe, typename RequestType::ResponseType obj) {
sendObject(pipe, std::move(obj));
}
public:
AdminServer(const Arguments& args)
@ -107,9 +167,14 @@ public:
// forks a server subprocess internally
bool isClient() const noexcept { return server_pid > 0; }
template <class T>
Response send(T req) {
return request(Request(std::forward<T>(req)));
typename T::ResponseType send(T req) {
// should always be invoked from client side (currently just the main process)
assert(server_pid > 0);
assert(logr.isFor(ProcKind::MAIN));
sendObject(pipe_to_server, Request(std::move(req)));
return receiveObject<typename T::ResponseType>(pipe_to_client);
}
AdminServer(const AdminServer&) = delete;
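The sendObject/receiveObject helpers above stream requests and responses between the main process and the forked admin server as Boost binary archives over a boost::process::pstream, using each struct's serialize(Ar&, unsigned int) member. Here is a standalone round-trip sketch of that mechanism, with a std::stringstream standing in for the pipe and TenantIdsResponseSketch as a simplified, hypothetical stand-in for TenantIdsResponse; the helper takes the object by const reference for simplicity, and the program needs to link against boost_serialization.

#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include <boost/archive/binary_iarchive.hpp>
#include <boost/archive/binary_oarchive.hpp>
#include <boost/serialization/string.hpp>
#include <boost/serialization/vector.hpp>

struct TenantIdsResponseSketch {
    std::string error_message;
    std::vector<int64_t> ids;

    template <class Ar>
    void serialize(Ar& ar, unsigned int /*version*/) {
        ar& error_message;
        ar& ids;
    }
};

template <class T>
void sendObject(std::ostream& pipe, T const& obj) {
    boost::archive::binary_oarchive oa(pipe); // header and payload are flushed when oa is destroyed
    oa << obj;
}

template <class T>
T receiveObject(std::istream& pipe) {
    boost::archive::binary_iarchive ia(pipe);
    T obj;
    ia >> obj;
    return obj;
}

int main() {
    std::stringstream pipe; // stand-in for the parent/child pstream
    sendObject(pipe, TenantIdsResponseSketch{ "", { 1, 2, 3 } });
    auto resp = receiveObject<TenantIdsResponseSketch>(pipe);
    std::cout << "received " << resp.ids.size() << " tenant IDs\n"; // prints 3
    return 0;
}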

View File

@ -90,10 +90,11 @@ using namespace mako;
thread_local Logger logr = Logger(MainProcess{}, VERBOSE_DEFAULT);
Transaction createNewTransaction(Database db, Arguments const& args, int id = -1, Tenant* tenants = nullptr) {
std::pair<Transaction, std::optional<std::string> /*token*/>
createNewTransaction(Database db, Arguments const& args, int id, std::optional<std::vector<Tenant>>& tenants) {
// No tenants specified
if (args.active_tenants <= 0) {
return db.createTransaction();
return { db.createTransaction(), {} };
}
// Create Tenant Transaction
int tenant_id = (id == -1) ? urand(0, args.active_tenants - 1) : id;
@ -101,7 +102,7 @@ Transaction createNewTransaction(Database db, Arguments const& args, int id = -1
std::string tenant_name;
// If provided tenants array, use it
if (tenants) {
tr = tenants[tenant_id].createTransaction();
tr = (*tenants)[tenant_id].createTransaction();
} else {
tenant_name = getTenantNameByIndex(tenant_id);
Tenant t = db.openTenant(toBytesRef(tenant_name));
@ -120,7 +121,7 @@ Transaction createNewTransaction(Database db, Arguments const& args, int id = -1
_exit(1);
}
}
return tr;
return { tr, { tenant_name } };
}
int cleanupTenants(ipc::AdminServer& server, Arguments const& args, int db_id) {
@ -197,14 +198,11 @@ int populate(Database db, const ThreadArgs& thread_args, int thread_tps, Workflo
auto watch_trace = Stopwatch(watch_total.getStart());
// tenants are assumed to have been generated by populateTenants() at main process, pre-fork
Tenant tenants[args.active_tenants];
for (int i = 0; i < args.active_tenants; ++i) {
tenants[i] = db.openTenant(toBytesRef(getTenantNameByIndex(i)));
}
std::optional<std::vector<Tenant>> tenants = args.prepareTenants(db);
int populate_iters = args.active_tenants > 0 ? args.active_tenants : 1;
// Each tenant should have the same range populated
for (auto t_id = 0; t_id < populate_iters; ++t_id) {
Transaction tx = createNewTransaction(db, args, t_id, args.active_tenants > 0 ? tenants : nullptr);
auto [tx, token] = createNewTransaction(db, args, t_id, tenants);
const auto key_begin = insertBegin(args.rows, process_idx, thread_idx, args.num_processes, args.num_threads);
const auto key_end = insertEnd(args.rows, process_idx, thread_idx, args.num_processes, args.num_threads);
auto key_checkpoint = key_begin; // in case of commit failure, restart from this key
@ -262,7 +260,7 @@ int populate(Database db, const ThreadArgs& thread_args, int thread_tps, Workflo
auto tx_restarter = ExitGuard([&watch_tx]() { watch_tx.startFromStop(); });
if (rc == FutureRC::OK) {
key_checkpoint = i + 1; // restart on failures from next key
tx = createNewTransaction(db, args, t_id, args.active_tenants > 0 ? tenants : nullptr);
std::tie(tx, token) = createNewTransaction(db, args, t_id, tenants);
} else if (rc == FutureRC::ABORT) {
return -1;
} else {
@ -306,6 +304,7 @@ void updateErrorStatsRunMode(WorkflowStatistics& stats, fdb::Error err, int op)
/* run one iteration of configured transaction */
int runOneTransaction(Transaction& tx,
std::optional<std::string> const& token,
Arguments const& args,
WorkflowStatistics& stats,
ByteString& key1,
@ -356,6 +355,8 @@ transaction_begin:
stats.addLatency(OP_COMMIT, step_latency);
}
tx.reset();
if (token)
tx.setOption(FDB_TR_OPTION_AUTHORIZATION_TOKEN, *token);
stats.incrOpCount(OP_COMMIT);
needs_commit = false;
}
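The retry path above resets the transaction and, when a tenant token was issued for it, immediately re-attaches it via FDB_TR_OPTION_AUTHORIZATION_TOKEN before the next attempt. A minimal sketch of that shape follows; Transaction here is a hypothetical stub whose reset() drops previously set options, not fdb::Transaction.

#include <iostream>
#include <optional>
#include <string>

struct Transaction {
    std::optional<std::string> authorization_token;
    void setAuthorizationToken(std::string token) { authorization_token = std::move(token); }
    void reset() { authorization_token.reset(); } // stub: reset drops per-transaction options
};

// Mirrors the shape of the retry loop: after every reset, the token (if any) has to be
// attached again before the transaction is reused.
void resetForRetry(Transaction& tx, std::optional<std::string> const& token) {
    tx.reset();
    if (token)
        tx.setAuthorizationToken(*token);
}

int main() {
    Transaction tx;
    std::optional<std::string> token{ "signed-jwt-for-tenant-42" }; // hypothetical token string
    resetForRetry(tx, token);
    std::cout << (tx.authorization_token ? "token attached\n" : "no token\n");
    return 0;
}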
@ -444,12 +445,7 @@ int runWorkload(Database db,
auto val = ByteString{};
val.resize(args.value_length);
// mimic typical tenant usage: keep tenants in memory
// and create transactions as needed
Tenant tenants[args.active_tenants];
for (int i = 0; i < args.active_tenants; ++i) {
tenants[i] = db.openTenant(toBytesRef(getTenantNameByIndex(i)));
}
std::optional<std::vector<fdb::Tenant>> tenants = args.prepareTenants(db);
/* main transaction loop */
while (1) {
@ -470,7 +466,7 @@ int runWorkload(Database db,
}
if (current_tps > 0 || thread_tps == 0 /* throttling off */) {
Transaction tx = createNewTransaction(db, args, -1, args.active_tenants > 0 ? tenants : nullptr);
auto [tx, token] = createNewTransaction(db, args, -1, tenants);
setTransactionTimeoutIfEnabled(args, tx);
/* enable transaction trace */
@ -507,7 +503,7 @@ int runWorkload(Database db,
}
}
rc = runOneTransaction(tx, args, workflow_stats, key1, key2, val);
rc = runOneTransaction(tx, token, args, workflow_stats, key1, key2, val);
if (rc) {
logr.warn("runOneTransaction failed ({})", rc);
}
@ -580,7 +576,7 @@ void runAsyncWorkload(Arguments const& args,
auto state =
std::make_shared<ResumableStateForPopulate>(Logger(WorkerProcess{}, args.verbose, process_idx, i),
db,
createNewTransaction(db, args),
db.createTransaction(),
io_context,
args,
shm.workerStatsSlot(process_idx, i),
@ -613,7 +609,7 @@ void runAsyncWorkload(Arguments const& args,
auto state =
std::make_shared<ResumableStateForRunWorkload>(Logger(WorkerProcess{}, args.verbose, process_idx, i),
db,
createNewTransaction(db, args),
db.createTransaction(),
io_context,
args,
shm.workerStatsSlot(process_idx, i),
@ -1685,6 +1681,10 @@ int Arguments::validate() {
}
if (enable_token_based_authorization) {
if (num_fdb_clusters > 1) {
logr.error("for simplicity, --enable_token_based_authorization must be used with exactly one fdb cluster");
return -1;
}
if (active_tenants <= 0 || total_tenants <= 0) {
logr.error("--enable_token_based_authorization must be used with at least one tenant");
return -1;
@ -1708,18 +1708,40 @@ bool Arguments::isAuthorizationEnabled() const noexcept {
private_key_pem.has_value();
}
void Arguments::collectTenantIds() {
auto db = Database(cluster_files[0]);
tenant_ids.clear();
tenant_ids.reserve(active_tenants);
}
void Arguments::generateAuthorizationTokens() {
assert(active_tenants > 0);
assert(keypair_id.has_value());
assert(private_key_pem.has_value());
authorization_tokens.clear();
assert(num_fdb_clusters == 1);
assert(!tenant_ids.empty());
// assumes tenants have already been populated
logr.info("generating authorization tokens to be used by worker threads");
auto stopwatch = Stopwatch(StartAtCtor{});
authorization_tokens = generateAuthorizationTokenMap(active_tenants, keypair_id.value(), private_key_pem.value());
authorization_tokens =
generateAuthorizationTokenMap(active_tenants, keypair_id.value(), private_key_pem.value(), tenant_ids);
assert(authorization_tokens.size() == active_tenants);
logr.info("generated {} tokens in {:6.3f} seconds", active_tenants, toDoubleSeconds(stopwatch.stop().diff()));
}
std::optional<std::vector<fdb::Tenant>> Arguments::prepareTenants(fdb::Database db) const {
if (active_tenants > 0) {
std::vector<fdb::Tenant> tenants(active_tenants);
for (auto i = 0; i < active_tenants; i++) {
tenants[i] = db.openTenant(toBytesRef(getTenantNameByIndex(i)));
}
return tenants;
} else {
return {};
}
}
void printStats(Arguments const& args, WorkflowStatistics const* stats, double const duration_sec, FILE* fp) {
static WorkflowStatistics prev;
@ -2503,10 +2525,6 @@ int main(int argc, char* argv[]) {
logr.setVerbosity(args.verbose);
if (args.isAuthorizationEnabled()) {
args.generateAuthorizationTokens();
}
if (args.mode == MODE_CLEAN) {
/* cleanup will be done from a single thread */
args.num_processes = 1;
@ -2525,7 +2543,8 @@ int main(int argc, char* argv[]) {
return 0;
}
if (args.total_tenants > 0 && (args.mode == MODE_BUILD || args.mode == MODE_CLEAN)) {
if (args.total_tenants > 0 &&
(args.isAuthorizationEnabled() || args.mode == MODE_BUILD || args.mode == MODE_CLEAN)) {
// below construction fork()s internally
auto server = ipc::AdminServer(args);
@ -2542,12 +2561,11 @@ int main(int argc, char* argv[]) {
logr.info("admin server ready");
}
}
// Use admin server to request tenant creation or deletion.
// This is necessary when tenant authorization is enabled,
// in which case the worker threads connect to database as untrusted clients,
// as which they wouldn't be allowed to create/delete tenants on their own.
// Although it is possible to allow worker threads to create/delete
// tenants in a authorization-disabled mode, use the admin server anyway for simplicity.
// Use the admin server as a proxy for creating/deleting tenants and for pre-fetching tenant IDs for token
// signing when authorization is enabled. This is necessary because, with tenant authorization enabled, the
// worker threads connect to the database as untrusted clients and would not be allowed to create/delete
// tenants on their own. Although it is possible to allow worker threads to create/delete tenants in an
// authorization-disabled mode, use the admin server anyway for simplicity.
if (args.mode == MODE_CLEAN) {
// short-circuit tenant cleanup
const auto num_dbs = std::min(args.num_fdb_clusters, args.num_databases);
@ -2563,6 +2581,21 @@ int main(int argc, char* argv[]) {
return -1;
}
}
if ((args.mode == MODE_BUILD || args.mode == MODE_RUN) && args.isAuthorizationEnabled()) {
assert(args.num_fdb_clusters == 1);
// need to fetch tenant IDs to pre-generate tokens
// fetch all IDs in one go
auto res = server.send(ipc::FetchTenantIdsRequest{ args.cluster_files[0], 0, args.active_tenants });
if (res.error_message) {
logr.error("tenant ID fetch failed: {}", *res.error_message);
return -1;
} else {
logr.info("Successfully prefetched {} tenant IDs", res.ids.size());
assert(res.ids.size() == args.active_tenants);
args.tenant_ids = std::move(res.ids);
}
args.generateAuthorizationTokens();
}
}
const auto pid_main = getpid();

View File

@ -31,6 +31,7 @@
#include <chrono>
#include <list>
#include <map>
#include <optional>
#include <string>
#include <string_view>
#include <vector>
@ -143,10 +144,13 @@ constexpr const int MAX_REPORT_FILES = 200;
struct Arguments {
Arguments();
int validate();
void collectTenantIds();
bool isAuthorizationEnabled() const noexcept;
std::optional<std::vector<fdb::Tenant>> prepareTenants(fdb::Database db) const;
void generateAuthorizationTokens();
// Needs to be called once per fdb-accessing process
// Needs to be called once per fdb client process from a clean state,
// i.e. before any FDB API has been called
int setGlobalOptions() const;
bool isAnyTimeoutEnabled() const;
@ -206,6 +210,7 @@ struct Arguments {
std::optional<std::string> keypair_id;
std::optional<std::string> private_key_pem;
std::map<std::string, std::string> authorization_tokens; // maps tenant name to token string
std::vector<int64_t> tenant_ids; // maps tenant index to tenant id for signing tokens
int transaction_timeout_db;
int transaction_timeout_tx;
};

View File

@ -28,7 +28,8 @@ namespace mako {
std::map<std::string, std::string> generateAuthorizationTokenMap(int num_tenants,
std::string public_key_id,
std::string private_key_pem) {
std::string private_key_pem,
const std::vector<int64_t>& tenant_ids) {
std::map<std::string, std::string> m;
auto t = authz::jwt::stdtypes::TokenSpec{};
auto const now = toIntegerSeconds(std::chrono::system_clock::now().time_since_epoch());
@ -40,14 +41,14 @@ std::map<std::string, std::string> generateAuthorizationTokenMap(int num_tenants
t.issuedAtUnixTime = now;
t.expiresAtUnixTime = now + 60 * 60 * 12; // Good for 12 hours
t.notBeforeUnixTime = now - 60 * 5; // activated 5 mins ago
const int tokenIdLen = 36; // UUID length
auto tokenId = std::string(tokenIdLen, '\0');
const int tokenid_len = 36; // UUID length
auto tokenid = std::string(tokenid_len, '\0');
for (auto i = 0; i < num_tenants; i++) {
std::string tenant_name = getTenantNameByIndex(i);
// swap out only the token ids and tenant names
randomAlphanumString(tokenId.data(), tokenIdLen);
t.tokenId = tokenId;
t.tenants = std::vector<std::string>{ tenant_name };
randomAlphanumString(tokenid.data(), tokenid_len);
t.tokenId = tokenid;
t.tenants = std::vector<int64_t>{ tenant_ids[i] };
m[tenant_name] = authz::jwt::stdtypes::signToken(t, private_key_pem);
}
return m;
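generateAuthorizationTokenMap above stamps every per-tenant token with the same time window: issued now, not valid before five minutes ago, expiring in twelve hours, and carrying exactly one tenant ID in the "tenants" claim. Below is a standalone sketch of that claim arithmetic using std::chrono; TokenSpecSketch is a simplified stand-in for authz::jwt::stdtypes::TokenSpec and no signing is performed.

#include <chrono>
#include <cstdint>
#include <iostream>
#include <vector>

struct TokenSpecSketch {
    uint64_t issuedAtUnixTime = 0;
    uint64_t expiresAtUnixTime = 0;
    uint64_t notBeforeUnixTime = 0;
    std::vector<int64_t> tenants;
};

TokenSpecSketch makeTokenSpec(int64_t tenant_id) {
    using namespace std::chrono;
    auto const now =
        static_cast<uint64_t>(duration_cast<seconds>(system_clock::now().time_since_epoch()).count());
    TokenSpecSketch t;
    t.issuedAtUnixTime = now;
    t.expiresAtUnixTime = now + 60 * 60 * 12; // good for 12 hours
    t.notBeforeUnixTime = now - 60 * 5;       // activated 5 minutes ago
    t.tenants = { tenant_id };                // the claim covers exactly one tenant
    return t;
}

int main() {
    auto spec = makeTokenSpec(/*tenant_id=*/42);
    std::cout << "iat=" << spec.issuedAtUnixTime << " exp=" << spec.expiresAtUnixTime
              << " nbf=" << spec.notBeforeUnixTime << "\n";
    return 0;
}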

View File

@ -21,6 +21,7 @@
#include <cassert>
#include <map>
#include <string>
#include <vector>
#include "fdb_api.hpp"
#include "utils.hpp"
@ -28,7 +29,8 @@ namespace mako {
std::map<std::string, std::string> generateAuthorizationTokenMap(int tenants,
std::string public_key_id,
std::string private_key_pem);
std::string private_key_pem,
const std::vector<int64_t>& tenant_ids);
inline std::string getTenantNameByIndex(int index) {
assert(index >= 0);

View File

@ -3312,7 +3312,6 @@ Reference<TransactionState> TransactionState::cloneAndReset(Reference<Transactio
newState->startTime = startTime;
newState->committedVersion = committedVersion;
newState->conflictingKeys = conflictingKeys;
newState->authToken = authToken;
newState->tenantSet = tenantSet;
return newState;

View File

@ -684,6 +684,8 @@ struct GlobalConfigRefreshRequest {
GlobalConfigRefreshRequest() {}
explicit GlobalConfigRefreshRequest(Version lastKnown) : lastKnown(lastKnown) {}
bool verify() const noexcept { return true; }
template <class Ar>
void serialize(Ar& ar) {
serializer(ar, lastKnown, reply);

View File

@ -317,6 +317,9 @@ struct ProtocolInfoRequest {
constexpr static FileIdentifier file_identifier = 13261233;
ReplyPromise<ProtocolInfoReply> reply{ PeerCompatibilityPolicy{ RequirePeer::AtLeast,
ProtocolVersion::withStableInterfaces() } };
bool verify() const noexcept { return true; }
template <class Ar>
void serialize(Ar& ar) {
serializer(ar, reply);

View File

@ -43,7 +43,7 @@ struct GrvProxyInterface {
// committed)
RequestStream<ReplyPromise<Void>> waitFailure; // reports heartbeat to master.
RequestStream<struct GetHealthMetricsRequest> getHealthMetrics;
RequestStream<struct GlobalConfigRefreshRequest> refreshGlobalConfig;
PublicRequestStream<struct GlobalConfigRefreshRequest> refreshGlobalConfig;
UID id() const { return getConsistentReadVersion.getEndpoint().token; }
std::string toString() const { return id().shortString(); }
@ -60,7 +60,7 @@ struct GrvProxyInterface {
RequestStream<ReplyPromise<Void>>(getConsistentReadVersion.getEndpoint().getAdjustedEndpoint(1));
getHealthMetrics = RequestStream<struct GetHealthMetricsRequest>(
getConsistentReadVersion.getEndpoint().getAdjustedEndpoint(2));
refreshGlobalConfig = RequestStream<struct GlobalConfigRefreshRequest>(
refreshGlobalConfig = PublicRequestStream<struct GlobalConfigRefreshRequest>(
getConsistentReadVersion.getEndpoint().getAdjustedEndpoint(3));
}
}

View File

@ -331,7 +331,8 @@ description is not currently required but encouraged.
hidden="true"/>
<Option name="authorization_token" code="2000"
description="Attach given authorization token to the transaction such that subsequent tenant-aware requests are authorized"
paramType="String" paramDescription="A JSON Web Token authorized to access data belonging to one or more tenants, indicated by 'tenants' claim of the token's payload."/>
paramType="String" paramDescription="A JSON Web Token authorized to access data belonging to one or more tenants, indicated by 'tenants' claim of the token's payload."
persistent="true"/>
</Scope>
<!-- The enumeration values matter - do not change them without

View File

@ -1074,7 +1074,8 @@ ACTOR static void deliver(TransportData* self,
if (receiver) {
TraceEvent(SevWarnAlways, "AttemptedRPCToPrivatePrevented")
.detail("From", peerAddress)
.detail("Token", destination.token);
.detail("Token", destination.token)
.detail("Receiver", typeid(*receiver).name());
ASSERT(!self->isLocalAddress(destination.getPrimaryAddress()));
Reference<Peer> peer = self->getOrOpenPeer(destination.getPrimaryAddress());
sendPacket(self,

View File

@ -22,6 +22,8 @@
#include "flow/actorcompiler.h" // has to be last include
using authz::TenantId;
template <class Key, class Value>
class LRUCache {
public:
@ -132,28 +134,26 @@ TEST_CASE("/fdbrpc/authz/LRUCache") {
struct CacheEntry {
Arena arena;
VectorRef<TenantNameRef> tenants;
VectorRef<TenantId> tenants;
Optional<StringRef> tokenId;
double expirationTime = 0.0;
};
struct AuditEntry {
NetworkAddress address;
TenantId tenantId;
Optional<Standalone<StringRef>> tokenId;
explicit AuditEntry(NetworkAddress const& address, CacheEntry const& cacheEntry)
: address(address),
bool operator==(const AuditEntry& other) const noexcept = default;
explicit AuditEntry(NetworkAddress const& address, TenantId tenantId, CacheEntry const& cacheEntry)
: address(address), tenantId(tenantId),
tokenId(cacheEntry.tokenId.present() ? Standalone<StringRef>(cacheEntry.tokenId.get(), cacheEntry.arena)
: Optional<Standalone<StringRef>>()) {}
};
bool operator==(AuditEntry const& lhs, AuditEntry const& rhs) {
return (lhs.address == rhs.address) && (lhs.tokenId.present() == rhs.tokenId.present()) &&
(!lhs.tokenId.present() || lhs.tokenId.get() == rhs.tokenId.get());
}
std::size_t hash_value(AuditEntry const& value) {
std::size_t seed = 0;
boost::hash_combine(seed, value.address);
boost::hash_combine(seed, value.tenantId);
if (value.tokenId.present()) {
boost::hash_combine(seed, value.tokenId.get());
}
@ -161,38 +161,17 @@ std::size_t hash_value(AuditEntry const& value) {
}
struct TokenCacheImpl {
TokenCacheImpl();
LRUCache<StringRef, CacheEntry> cache;
boost::unordered_set<AuditEntry> usedTokens;
Future<Void> auditor;
TokenCacheImpl();
double lastResetTime;
bool validate(TenantNameRef tenant, StringRef token);
bool validate(TenantId tenantId, StringRef token);
bool validateAndAdd(double currentTime, StringRef token, NetworkAddress const& peer);
void logTokenUsage(double currentTime, AuditEntry&& entry);
};
ACTOR Future<Void> tokenCacheAudit(TokenCacheImpl* self) {
state boost::unordered_set<AuditEntry> audits;
state boost::unordered_set<AuditEntry>::iterator iter;
state double lastLoggedTime = 0;
loop {
auto const timeSinceLog = g_network->timer() - lastLoggedTime;
if (timeSinceLog < FLOW_KNOBS->AUDIT_TIME_WINDOW) {
wait(delay(FLOW_KNOBS->AUDIT_TIME_WINDOW - timeSinceLog));
}
lastLoggedTime = g_network->timer();
audits.swap(self->usedTokens);
for (iter = audits.begin(); iter != audits.end(); ++iter) {
CODE_PROBE(true, "Audit Logging Running");
TraceEvent("AuditTokenUsed").detail("Client", iter->address).detail("TokenId", iter->tokenId).log();
wait(yield());
}
audits.clear();
}
}
TokenCacheImpl::TokenCacheImpl() : cache(FLOW_KNOBS->TOKEN_CACHE_SIZE) {
auditor = tokenCacheAudit(this);
}
TokenCacheImpl::TokenCacheImpl() : cache(FLOW_KNOBS->TOKEN_CACHE_SIZE), usedTokens(), lastResetTime(0) {}
TokenCache::TokenCache() : impl(new TokenCacheImpl()) {}
TokenCache::~TokenCache() {
@ -207,8 +186,8 @@ TokenCache& TokenCache::instance() {
return *reinterpret_cast<TokenCache*>(g_network->global(INetwork::enTokenCache));
}
bool TokenCache::validate(TenantNameRef name, StringRef token) {
return impl->validate(name, token);
bool TokenCache::validate(TenantId tenantId, StringRef token) {
return impl->validate(tenantId, token);
}
#define TRACE_INVALID_PARSED_TOKEN(reason, token) \
@ -280,8 +259,8 @@ bool TokenCacheImpl::validateAndAdd(double currentTime, StringRef token, Network
CacheEntry c;
c.expirationTime = t.expiresAtUnixTime.get();
c.tenants.reserve(c.arena, t.tenants.get().size());
for (auto tenant : t.tenants.get()) {
c.tenants.push_back_deep(c.arena, tenant);
for (auto tenantId : t.tenants.get()) {
c.tenants.push_back(c.arena, tenantId);
}
if (t.tokenId.present()) {
c.tokenId = StringRef(c.arena, t.tokenId.get());
@ -291,7 +270,7 @@ bool TokenCacheImpl::validateAndAdd(double currentTime, StringRef token, Network
}
}
bool TokenCacheImpl::validate(TenantNameRef name, StringRef token) {
bool TokenCacheImpl::validate(TenantId tenantId, StringRef token) {
NetworkAddress peer = FlowTransport::transport().currentDeliveryPeerAddress();
auto cachedEntry = cache.get(token);
double currentTime = g_network->timer();
@ -314,21 +293,44 @@ bool TokenCacheImpl::validate(TenantNameRef name, StringRef token) {
}
bool tenantFound = false;
for (auto const& t : entry->tenants) {
if (t == name) {
if (t == tenantId) {
tenantFound = true;
break;
}
}
if (!tenantFound) {
CODE_PROBE(true, "Valid token doesn't reference tenant");
TraceEvent(SevWarn, "TenantTokenMismatch").detail("From", peer).detail("Tenant", name.toString());
TraceEvent(SevWarn, "TenantTokenMismatch")
.detail("From", peer)
.detail("RequestedTenant", fmt::format("{:#x}", tenantId))
.detail("TenantsInToken", fmt::format("{:#x}", fmt::join(entry->tenants, " ")));
return false;
}
// audit logging
usedTokens.insert(AuditEntry(peer, *cachedEntry.get()));
if (FLOW_KNOBS->AUDIT_LOGGING_ENABLED)
logTokenUsage(currentTime, AuditEntry(peer, tenantId, *cachedEntry.get()));
return true;
}
void TokenCacheImpl::logTokenUsage(double currentTime, AuditEntry&& entry) {
if (currentTime > lastResetTime + FLOW_KNOBS->AUDIT_TIME_WINDOW) {
// clear usage cache every AUDIT_TIME_WINDOW seconds
usedTokens.clear();
lastResetTime = currentTime;
}
auto [iter, inserted] = usedTokens.insert(std::move(entry));
if (inserted) {
// Access in the context of this (client_ip, tenant, token_id) tuple hasn't been logged in the current
// window. Log usage.
CODE_PROBE(true, "Audit Logging Running");
TraceEvent("AuditTokenUsed")
.detail("Client", iter->address)
.detail("TenantId", fmt::format("{:#x}", iter->tenantId))
.detail("TokenId", iter->tokenId)
.log();
}
}
namespace authz::jwt {
extern TokenRef makeRandomTokenSpec(Arena&, IRandom&, authz::Algorithm);
}
@ -375,9 +377,9 @@ TEST_CASE("/fdbrpc/authz/TokenCache/BadTokens") {
},
{
[](Arena& arena, IRandom&, authz::jwt::TokenRef& token) {
StringRef* newTenants = new (arena) StringRef[1];
*newTenants = token.tenants.get()[0].substr(1);
token.tenants = VectorRef<StringRef>(newTenants, 1);
TenantId* newTenants = new (arena) TenantId[1];
*newTenants = token.tenants.get()[0] + 1;
token.tenants = VectorRef<TenantId>(newTenants, 1);
},
"UnmatchedTenant",
},
@ -443,15 +445,15 @@ TEST_CASE("/fdbrpc/authz/TokenCache/BadTokens") {
}
}
}
if (TokenCache::instance().validate("TenantNameDontMatterHere"_sr, StringRef())) {
if (TokenCache::instance().validate(TenantInfo::INVALID_TENANT, StringRef())) {
fmt::print("Unexpected successful validation of ill-formed token (no signature part)\n");
ASSERT(false);
}
if (TokenCache::instance().validate("TenantNameDontMatterHere"_sr, "1111.22"_sr)) {
if (TokenCache::instance().validate(TenantInfo::INVALID_TENANT, "1111.22"_sr)) {
fmt::print("Unexpected successful validation of ill-formed token (no signature part)\n");
ASSERT(false);
}
if (TokenCache::instance().validate("TenantNameDontMatterHere2"_sr, "////.////.////"_sr)) {
if (TokenCache::instance().validate(TenantInfo::INVALID_TENANT, "////.////.////"_sr)) {
fmt::print("Unexpected successful validation of unparseable token\n");
ASSERT(false);
}
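The logTokenUsage change above replaces the background audit actor with a windowed de-duplication scheme: the set of seen (client, tenant, token_id) tuples is cleared every AUDIT_TIME_WINDOW seconds, and a tuple is logged only on its first insertion within the current window. Below is a standalone sketch of that scheme with simplified stand-in types, a hard-coded 5-second window, and std::cout in place of TraceEvent (the defaulted operator== needs C++20, which the original also uses).

#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_set>

struct AuditEntrySketch {
    std::string client_address;
    int64_t tenant_id;
    std::string token_id;
    bool operator==(AuditEntrySketch const&) const = default;
};

struct AuditEntryHash {
    size_t operator()(AuditEntrySketch const& e) const {
        size_t seed = std::hash<std::string>{}(e.client_address);
        seed ^= std::hash<int64_t>{}(e.tenant_id) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        seed ^= std::hash<std::string>{}(e.token_id) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        return seed;
    }
};

struct AuditLogger {
    double audit_time_window = 5.0; // stand-in for FLOW_KNOBS->AUDIT_TIME_WINDOW
    double last_reset_time = 0.0;
    std::unordered_set<AuditEntrySketch, AuditEntryHash> used_tokens;

    void logTokenUsage(double current_time, AuditEntrySketch entry) {
        if (current_time > last_reset_time + audit_time_window) {
            used_tokens.clear(); // start a new de-duplication window
            last_reset_time = current_time;
        }
        auto [it, inserted] = used_tokens.insert(std::move(entry));
        if (inserted) // first use of this tuple in the current window: emit one log line
            std::cout << "AuditTokenUsed client=" << it->client_address << " tenant=" << it->tenant_id << "\n";
    }
};

int main() {
    AuditLogger logger;
    AuditEntrySketch e{ "10.0.0.1:4500", 42, "token-abc" };
    logger.logTokenUsage(1.0, e);  // logged
    logger.logTokenUsage(2.0, e);  // suppressed (same window)
    logger.logTokenUsage(10.0, e); // logged again (new window)
    return 0;
}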

View File

@ -52,11 +52,11 @@
namespace {
// test-only constants for generating random tenant/key names
// test-only constants for generating random tenant ID and key names
constexpr int MinIssuerNameLen = 16;
constexpr int MaxIssuerNameLenPlus1 = 25;
constexpr int MinTenantNameLen = 8;
constexpr int MaxTenantNameLenPlus1 = 17;
constexpr authz::TenantId MinTenantId = 1;
constexpr authz::TenantId MaxTenantIdPlus1 = 0xffffffffll;
constexpr int MinKeyNameLen = 10;
constexpr int MaxKeyNameLenPlus1 = 21;
@ -176,6 +176,14 @@ void appendField(fmt::memory_buffer& b, char const (&name)[NameLen], Optional<Fi
fmt::format_to(bi, fmt::runtime(f[i].toStringView()));
}
fmt::format_to(bi, "]");
} else if constexpr (std::is_same_v<FieldType, VectorRef<TenantId>>) {
fmt::format_to(bi, " {}=[", name);
for (auto i = 0; i < f.size(); i++) {
if (i)
fmt::format_to(bi, ",");
fmt::format_to(bi, "{:#x}", f[i]);
}
fmt::format_to(bi, "]");
} else if constexpr (std::is_same_v<FieldType, StringRef>) {
fmt::format_to(bi, " {}={}", name, f.toStringView());
} else {
@ -202,33 +210,34 @@ StringRef toStringRef(Arena& arena, const TokenRef& tokenSpec) {
return StringRef(str, buf.size());
}
template <class FieldType, class Writer, bool MakeStringArrayBase64 = false>
void putField(Optional<FieldType> const& field,
Writer& wr,
const char* fieldName,
std::bool_constant<MakeStringArrayBase64> _ = std::bool_constant<false>{}) {
template <class FieldType, class Writer>
void putField(Optional<FieldType> const& field, Writer& wr, const char* fieldName) {
if (!field.present())
return;
wr.Key(fieldName);
auto const& value = field.get();
static_assert(std::is_same_v<StringRef, FieldType> || std::is_same_v<FieldType, uint64_t> ||
std::is_same_v<FieldType, VectorRef<StringRef>>);
std::is_same_v<FieldType, VectorRef<StringRef>> || std::is_same_v<FieldType, VectorRef<TenantId>>);
if constexpr (std::is_same_v<StringRef, FieldType>) {
wr.String(reinterpret_cast<const char*>(value.begin()), value.size());
} else if constexpr (std::is_same_v<FieldType, uint64_t>) {
wr.Uint64(value);
} else if constexpr (std::is_same_v<FieldType, VectorRef<TenantId>>) {
// "tenants" array = array of base64-encoded tenant key prefix
// key prefix = bytestring representation of big-endian tenant ID (int64_t)
Arena arena;
wr.StartArray();
for (auto elem : value) {
auto const bigEndianId = bigEndian64(elem);
auto encodedElem =
base64::encode(arena, StringRef(reinterpret_cast<const uint8_t*>(&bigEndianId), sizeof(bigEndianId)));
wr.String(reinterpret_cast<const char*>(encodedElem.begin()), encodedElem.size());
}
wr.EndArray();
} else {
wr.StartArray();
if constexpr (MakeStringArrayBase64) {
Arena arena;
for (auto elem : value) {
auto encodedElem = base64::encode(arena, elem);
wr.String(reinterpret_cast<const char*>(encodedElem.begin()), encodedElem.size());
}
} else {
for (auto elem : value) {
wr.String(reinterpret_cast<const char*>(elem.begin()), elem.size());
}
for (auto elem : value) {
wr.String(reinterpret_cast<const char*>(elem.begin()), elem.size());
}
wr.EndArray();
}
@ -259,7 +268,7 @@ StringRef makeSignInput(Arena& arena, const TokenRef& tokenSpec) {
putField(tokenSpec.expiresAtUnixTime, payload, "exp");
putField(tokenSpec.notBeforeUnixTime, payload, "nbf");
putField(tokenSpec.tokenId, payload, "jti");
putField(tokenSpec.tenants, payload, "tenants", std::bool_constant<true>{} /* encode tenants in base64 */);
putField(tokenSpec.tenants, payload, "tenants");
payload.EndObject();
auto const headerPartLen = base64::url::encodedLength(headerBuffer.GetSize());
auto const payloadPartLen = base64::url::encodedLength(payloadBuffer.GetSize());
@ -347,18 +356,17 @@ Optional<StringRef> parseHeaderPart(Arena& arena, TokenRef& token, StringRef b64
return {};
}
template <class FieldType, bool ExpectBase64StringArray = false>
template <class FieldType>
Optional<StringRef> parseField(Arena& arena,
Optional<FieldType>& out,
const rapidjson::Document& d,
const char* fieldName,
std::bool_constant<ExpectBase64StringArray> _ = std::bool_constant<false>{}) {
const char* fieldName) {
auto fieldItr = d.FindMember(fieldName);
if (fieldItr == d.MemberEnd())
return {};
auto const& field = fieldItr->value;
static_assert(std::is_same_v<StringRef, FieldType> || std::is_same_v<FieldType, uint64_t> ||
std::is_same_v<FieldType, VectorRef<StringRef>>);
std::is_same_v<FieldType, VectorRef<StringRef>> || std::is_same_v<FieldType, VectorRef<TenantId>>);
if constexpr (std::is_same_v<FieldType, StringRef>) {
if (!field.IsString()) {
return StringRef(arena, fmt::format("'{}' is not a string", fieldName));
@ -369,7 +377,7 @@ Optional<StringRef> parseField(Arena& arena,
return StringRef(arena, fmt::format("'{}' is not a number", fieldName));
}
out = static_cast<uint64_t>(field.GetDouble());
} else {
} else if constexpr (std::is_same_v<FieldType, VectorRef<StringRef>>) {
if (!field.IsArray()) {
return StringRef(arena, fmt::format("'{}' is not an array", fieldName));
}
@ -379,26 +387,50 @@ Optional<StringRef> parseField(Arena& arena,
if (!field[i].IsString()) {
return StringRef(arena, fmt::format("{}th element of '{}' is not a string", i + 1, fieldName));
}
if constexpr (ExpectBase64StringArray) {
Optional<StringRef> decodedString = base64::decode(
arena,
StringRef(reinterpret_cast<const uint8_t*>(field[i].GetString()), field[i].GetStringLength()));
if (decodedString.present()) {
vector[i] = decodedString.get();
} else {
CODE_PROBE(true, "Base64 token field has failed to be parsed");
return StringRef(arena,
fmt::format("Failed to base64-decode {}th element of '{}'", i + 1, fieldName));
}
} else {
vector[i] = StringRef(
arena, reinterpret_cast<const uint8_t*>(field[i].GetString()), field[i].GetStringLength());
}
vector[i] = StringRef(
arena, reinterpret_cast<const uint8_t*>(field[i].GetString()), field[i].GetStringLength());
}
out = VectorRef<StringRef>(vector, field.Size());
} else {
out = VectorRef<StringRef>();
}
} else {
// Tenant IDs case: convert an array of base64-encoded, length-8 bytestrings containing big-endian
// int64_t values to host-endian int64_t
if (!field.IsArray()) {
return StringRef(arena, fmt::format("'{}' is not an array", fieldName));
}
if (field.Size() > 0) {
auto vector = new (arena) TenantId[field.Size()];
for (auto i = 0; i < field.Size(); i++) {
if (!field[i].IsString()) {
return StringRef(arena, fmt::format("{}th element of '{}' is not a string", i + 1, fieldName));
}
Optional<StringRef> decodedString = base64::decode(
arena,
StringRef(reinterpret_cast<const uint8_t*>(field[i].GetString()), field[i].GetStringLength()));
if (decodedString.present()) {
auto const tenantPrefix = decodedString.get();
if (tenantPrefix.size() != sizeof(TenantId)) {
CODE_PROBE(true, "Tenant prefix has an invalid length");
return StringRef(arena,
fmt::format("{}th element of '{}' has an invalid bytewise length of {}",
i + 1,
fieldName,
tenantPrefix.size()));
}
TenantId tenantId = *reinterpret_cast<const TenantId*>(tenantPrefix.begin());
vector[i] = fromBigEndian64(tenantId);
} else {
CODE_PROBE(true, "Tenant field has failed to be parsed");
return StringRef(arena,
fmt::format("Failed to base64-decode {}th element of '{}'", i + 1, fieldName));
}
}
out = VectorRef<TenantId>(vector, field.Size());
} else {
out = VectorRef<TenantId>();
}
}
return {};
}
@ -431,12 +463,7 @@ Optional<StringRef> parsePayloadPart(Arena& arena, TokenRef& token, StringRef b6
return err;
if ((err = parseField(arena, token.notBeforeUnixTime, d, "nbf")).present())
return err;
if ((err = parseField(arena,
token.tenants,
d,
"tenants",
std::bool_constant<true>{} /* expect field elements encoded in base64 */))
.present())
if ((err = parseField(arena, token.tenants, d, "tenants")).present())
return err;
return {};
}
@ -526,16 +553,16 @@ TokenRef makeRandomTokenSpec(Arena& arena, IRandom& rng, Algorithm alg) {
auto numAudience = rng.randomInt(1, 5);
auto aud = new (arena) StringRef[numAudience];
for (auto i = 0; i < numAudience; i++)
aud[i] = genRandomAlphanumStringRef(arena, rng, MinTenantNameLen, MaxTenantNameLenPlus1);
aud[i] = genRandomAlphanumStringRef(arena, rng, MinIssuerNameLen, MaxIssuerNameLenPlus1);
ret.audience = VectorRef<StringRef>(aud, numAudience);
ret.issuedAtUnixTime = g_network->timer();
ret.notBeforeUnixTime = ret.issuedAtUnixTime.get();
ret.expiresAtUnixTime = ret.issuedAtUnixTime.get() + rng.randomInt(360, 1080 + 1);
auto numTenants = rng.randomInt(1, 3);
auto tenants = new (arena) StringRef[numTenants];
auto tenants = new (arena) TenantId[numTenants];
for (auto i = 0; i < numTenants; i++)
tenants[i] = genRandomAlphanumStringRef(arena, rng, MinTenantNameLen, MaxTenantNameLenPlus1);
ret.tenants = VectorRef<StringRef>(tenants, numTenants);
tenants[i] = rng.randomInt64(MinTenantId, MaxTenantIdPlus1);
ret.tenants = VectorRef<TenantId>(tenants, numTenants);
return ret;
}
@ -584,8 +611,7 @@ TEST_CASE("/fdbrpc/TokenSign/JWT") {
ASSERT(verifyOk);
}
// try tampering with signed token by adding one more tenant
tokenSpec.tenants.get().push_back(
arena, genRandomAlphanumStringRef(arena, rng, MinTenantNameLen, MaxTenantNameLenPlus1));
tokenSpec.tenants.get().push_back(arena, rng.randomInt64(MinTenantId, MaxTenantIdPlus1));
auto tamperedTokenPart = makeSignInput(arena, tokenSpec);
auto tamperedTokenString = fmt::format("{}.{}", tamperedTokenPart.toString(), signaturePart.toString());
std::tie(verifyOk, verifyErr) = authz::jwt::verifyToken(StringRef(tamperedTokenString), publicKey);
@ -608,12 +634,12 @@ TEST_CASE("/fdbrpc/TokenSign/JWT/ToStringRef") {
t.notBeforeUnixTime = 789ul;
t.keyId = "keyId"_sr;
t.tokenId = "tokenId"_sr;
StringRef tenants[2]{ "tenant1"_sr, "tenant2"_sr };
t.tenants = VectorRef<StringRef>(tenants, 2);
authz::TenantId tenants[2]{ 0x1ll, 0xabcdefabcdefll };
t.tenants = VectorRef<authz::TenantId>(tenants, 2);
auto arena = Arena();
auto tokenStr = toStringRef(arena, t);
auto tokenStrExpected =
"alg=ES256 kid=keyId iss=issuer sub=subject aud=[aud1,aud2,aud3] iat=123 exp=456 nbf=789 jti=tokenId tenants=[tenant1,tenant2]"_sr;
"alg=ES256 kid=keyId iss=issuer sub=subject aud=[aud1,aud2,aud3] iat=123 exp=456 nbf=789 jti=tokenId tenants=[0x1,0xabcdefabcdef]"_sr;
if (tokenStr != tokenStrExpected) {
fmt::print("Expected: {}\nGot : {}\n", tokenStrExpected.toStringView(), tokenStr.toStringView());
ASSERT(false);
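Per the putField/parseField changes above, the JWT "tenants" claim now carries each tenant's key prefix rather than its name: the int64 tenant ID is written as 8 big-endian bytes and base64-encoded, and parsing reverses both steps (rejecting prefixes that are not exactly 8 bytes long). Below is a standalone sketch of that encoding; the helpers are simplified stand-ins for bigEndian64/fromBigEndian64 and the fdbrpc base64 utilities, not the library functions themselves.

#include <array>
#include <cstdint>
#include <iostream>
#include <string>

std::array<uint8_t, 8> toBigEndianBytes(int64_t id) {
    std::array<uint8_t, 8> out{};
    auto u = static_cast<uint64_t>(id);
    for (int i = 0; i < 8; i++)
        out[i] = static_cast<uint8_t>(u >> (8 * (7 - i))); // most significant byte first
    return out;
}

int64_t fromBigEndianBytes(std::array<uint8_t, 8> const& bytes) {
    uint64_t u = 0;
    for (uint8_t b : bytes)
        u = (u << 8) | b;
    return static_cast<int64_t>(u);
}

std::string base64Encode(uint8_t const* data, size_t len) {
    static constexpr char tbl[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    std::string out;
    for (size_t i = 0; i < len; i += 3) {
        uint32_t n = static_cast<uint32_t>(data[i]) << 16;
        if (i + 1 < len) n |= static_cast<uint32_t>(data[i + 1]) << 8;
        if (i + 2 < len) n |= data[i + 2];
        out.push_back(tbl[(n >> 18) & 63]);
        out.push_back(tbl[(n >> 12) & 63]);
        out.push_back(i + 1 < len ? tbl[(n >> 6) & 63] : '=');
        out.push_back(i + 2 < len ? tbl[n & 63] : '=');
    }
    return out;
}

int main() {
    int64_t const tenantId = 0xabcdefabcdefll;
    auto prefix = toBigEndianBytes(tenantId);                       // 8-byte tenant key prefix
    auto claimElement = base64Encode(prefix.data(), prefix.size()); // what goes into "tenants"
    std::cout << "claim element: " << claimElement << "\n";
    std::cout << "round trip ok: " << (fromBigEndianBytes(prefix) == tenantId) << "\n";
    return 0;
}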

View File

@ -29,38 +29,48 @@ namespace {
// converts std::optional<STANDARD_TYPE(S)> to Optional<FLOW_TYPE(T)>
template <class T, class S>
void convertAndAssign(Arena& arena, Optional<T>& to, const std::optional<S>& from) {
if (!from.has_value()) {
to.reset();
return;
}
if constexpr (std::is_same_v<S, std::vector<std::string>>) {
static_assert(std::is_same_v<T, VectorRef<StringRef>>,
"Source type std::vector<std::string> must convert to VectorRef<StringRef>");
if (from.has_value()) {
const auto& value = from.value();
if (value.empty()) {
to = VectorRef<StringRef>();
} else {
// no need to deep copy string because we have the underlying memory for the duration of token signing.
auto buf = new (arena) StringRef[value.size()];
for (auto i = 0u; i < value.size(); i++) {
buf[i] = StringRef(value[i]);
}
to = VectorRef<StringRef>(buf, value.size());
const auto& value = from.value();
if (value.empty()) {
to = VectorRef<StringRef>();
} else {
// no need to deep copy string because we have the underlying memory for the duration of token signing.
auto buf = new (arena) StringRef[value.size()];
for (auto i = 0u; i < value.size(); i++) {
buf[i] = StringRef(value[i]);
}
to = VectorRef<StringRef>(buf, value.size());
}
} else if constexpr (std::is_same_v<S, std::vector<int64_t>>) {
static_assert(std::is_same_v<T, VectorRef<int64_t>>,
"Source type std::vector<int64_t> must convert to VectorRef<int64_t>");
const auto& value = from.value();
if (value.empty()) {
to = VectorRef<int64_t>();
} else {
auto buf = new (arena) int64_t[value.size()];
for (auto i = 0; i < value.size(); i++)
buf[i] = value[i];
to = VectorRef<int64_t>(buf, value.size());
}
} else if constexpr (std::is_same_v<S, std::string>) {
static_assert(std::is_same_v<T, StringRef>, "Source type std::string must convert to StringRef");
if (from.has_value()) {
const auto& value = from.value();
// no need to deep copy string because we have the underlying memory for the duration of token signing.
to = StringRef(value);
}
const auto& value = from.value();
// no need to deep copy string because we have the underlying memory for the duration of token signing.
to = StringRef(value);
} else {
static_assert(
std::is_same_v<S, T>,
"Source types that aren't std::vector<std::string> or std::string must have the same destination type");
static_assert(std::is_same_v<S, T>,
"Source types that aren't std::vector<std::string>, std::vector<int64_t>, or std::string must "
"have the same destination type");
static_assert(std::is_trivially_copy_assignable_v<S>,
"Source types that aren't std::vector<std::string> or std::string must not use heap memory");
if (from.has_value()) {
to = from.value();
}
to = from.value();
}
}

View File

@ -76,9 +76,7 @@ struct serializable_traits<TenantInfo> : std::true_type {
if constexpr (Archiver::isDeserializing) {
bool tenantAuthorized = FLOW_KNOBS->ALLOW_TOKENLESS_TENANT_ACCESS;
if (!tenantAuthorized && v.tenantId != TenantInfo::INVALID_TENANT && v.token.present()) {
// TODO: update tokens to be ID based
// tenantAuthorized = TokenCache::instance().validate(v.tenantId, v.token.get());
tenantAuthorized = true;
tenantAuthorized = TokenCache::instance().validate(v.tenantId, v.token.get());
}
v.trusted = FlowTransport::transport().currentDeliveryPeerIsTrusted();
v.tenantAuthorized = tenantAuthorized;

View File

@ -21,6 +21,7 @@
#ifndef TOKENCACHE_H_
#define TOKENCACHE_H_
#include "fdbrpc/TenantName.h"
#include "fdbrpc/TokenSpec.h"
#include "flow/Arena.h"
class TokenCache : NonCopyable {
@ -31,7 +32,7 @@ public:
~TokenCache();
static void createInstance();
static TokenCache& instance();
bool validate(TenantNameRef tenant, StringRef token);
bool validate(authz::TenantId tenant, StringRef token);
};
#endif // TOKENCACHE_H_

View File

@ -28,6 +28,8 @@
namespace authz {
using TenantId = int64_t;
enum class Algorithm : int {
RS256,
ES256,
@ -67,7 +69,7 @@ struct BasicTokenSpec {
OptionalType<uint64_t> expiresAtUnixTime; // exp
OptionalType<uint64_t> notBeforeUnixTime; // nbf
OptionalType<StringType> tokenId; // jti
OptionalType<VectorType<StringType>> tenants; // tenants
OptionalType<VectorType<TenantId>> tenants; // tenants
// signature part
StringType signature;
};

View File

@ -695,6 +695,10 @@ struct NetNotifiedQueue final : NotifiedQueue<T>, FlowReceiver, FastAllocated<Ne
if (!message.verify()) {
if constexpr (HasReply<T>) {
message.reply.sendError(permission_denied());
TraceEvent(SevWarnAlways, "UnauthorizedAccessPrevented")
.detail("RequestType", typeid(T).name())
.detail("ClientIP", FlowTransport::transport().currentDeliveryPeerAddress())
.log();
}
} else {
this->send(std::move(message));
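The new trace event above fires when a request arriving on a public endpoint fails its verify() check; the request is not delivered and, when it carries a reply promise, the sender receives permission_denied. A minimal sketch of that gate with hypothetical stand-in types follows; the real check lives in NetNotifiedQueue::receive and uses ReplyPromise, TraceEvent, and the token-based TenantInfo verification shown elsewhere in this commit.

#include <functional>
#include <iostream>
#include <string>
#include <typeinfo>

struct PublicRequestSketch {
    bool authorized = false;                // stand-in for the token/trust state carried by the request
    std::function<void(std::string)> reply; // stand-in for ReplyPromise::sendError
    bool verify() const noexcept { return authorized; }
};

template <class T, class SendFn>
void deliver(T& message, SendFn&& send) {
    if (!message.verify()) {
        message.reply("permission_denied");
        std::cout << "UnauthorizedAccessPrevented RequestType=" << typeid(T).name() << "\n";
    } else {
        send(message);
    }
}

int main() {
    auto send = [](PublicRequestSketch const&) { std::cout << "delivered\n"; };
    PublicRequestSketch denied{ false, [](std::string err) { std::cout << "reply: " << err << "\n"; } };
    PublicRequestSketch allowed{ true, [](std::string) {} };
    deliver(denied, send);  // prints the reply error and the audit line
    deliver(allowed, send); // prints "delivered"
    return 0;
}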

View File

@ -346,7 +346,7 @@ public:
double checkDisabled(const std::string& desc) const;
// generate authz token for use in simulation environment
Standalone<StringRef> makeToken(StringRef tenantName, uint64_t ttlSecondsFromNow);
Standalone<StringRef> makeToken(int64_t tenantId, uint64_t ttlSecondsFromNow);
static thread_local ProcessInfo* currentProcess;

View File

@ -174,7 +174,7 @@ void ISimulator::displayWorkers() const {
return;
}
Standalone<StringRef> ISimulator::makeToken(StringRef tenantName, uint64_t ttlSecondsFromNow) {
Standalone<StringRef> ISimulator::makeToken(int64_t tenantId, uint64_t ttlSecondsFromNow) {
ASSERT_GT(authKeys.size(), 0);
auto tokenSpec = authz::jwt::TokenRef{};
auto [keyName, key] = *authKeys.begin();
@ -188,7 +188,7 @@ Standalone<StringRef> ISimulator::makeToken(StringRef tenantName, uint64_t ttlSe
tokenSpec.expiresAtUnixTime = now + ttlSecondsFromNow;
auto const tokenId = deterministicRandom()->randomAlphaNumeric(10);
tokenSpec.tokenId = StringRef(tokenId);
tokenSpec.tenants = VectorRef<StringRef>(&tenantName, 1);
tokenSpec.tenants = VectorRef<int64_t>(&tenantId, 1);
auto ret = Standalone<StringRef>();
ret.contents() = authz::jwt::signToken(ret.arena(), tokenSpec, key);
return ret;

View File

@ -3511,7 +3511,7 @@ ACTOR Future<Void> monitorLeaderWithDelayedCandidacy(
extern void setupStackSignal();
ACTOR Future<Void> serveProtocolInfo() {
state RequestStream<ProtocolInfoRequest> protocolInfo(
state PublicRequestStream<ProtocolInfoRequest> protocolInfo(
PeerCompatibilityPolicy{ RequirePeer::AtLeast, ProtocolVersion::withStableInterfaces() });
protocolInfo.makeWellKnownEndpoint(WLTOKEN_PROTOCOL_INFO, TaskPriority::DefaultEndpoint);
loop {

View File

@ -46,6 +46,7 @@ struct AuthzSecurityWorkload : TestWorkload {
std::vector<Future<Void>> clients;
Arena arena;
Reference<Tenant> tenant;
Reference<Tenant> anotherTenant;
TenantName tenantName;
TenantName anotherTenantName;
Standalone<StringRef> signedToken;
@ -68,10 +69,6 @@ struct AuthzSecurityWorkload : TestWorkload {
tLogConfigKey = getOption(options, "tLogConfigKey"_sr, "TLogInterface"_sr);
ASSERT(g_network->isSimulated());
// make it comfortably longer than the timeout of the workload
signedToken = g_simulator->makeToken(
tenantName, uint64_t(std::lround(getCheckTimeout())) + uint64_t(std::lround(testDuration)) + 100);
signedTokenAnotherTenant = g_simulator->makeToken(
anotherTenantName, uint64_t(std::lround(getCheckTimeout())) + uint64_t(std::lround(testDuration)) + 100);
testFunctions.push_back(
[this](Database cx) { return testCrossTenantGetDisallowed(this, cx, PositiveTestcase::True); });
testFunctions.push_back(
@ -87,10 +84,15 @@ struct AuthzSecurityWorkload : TestWorkload {
Future<Void> setup(Database const& cx) override {
tenant = makeReference<Tenant>(cx, tenantName);
return tenant->ready();
anotherTenant = makeReference<Tenant>(cx, anotherTenantName);
return tenant->ready() && anotherTenant->ready();
}
Future<Void> start(Database const& cx) override {
signedToken = g_simulator->makeToken(
tenant->id(), uint64_t(std::lround(getCheckTimeout())) + uint64_t(std::lround(testDuration)) + 100);
signedTokenAnotherTenant = g_simulator->makeToken(
anotherTenant->id(), uint64_t(std::lround(getCheckTimeout())) + uint64_t(std::lround(testDuration)) + 100);
for (int c = 0; c < actorCount; c++)
clients.push_back(timeout(runTestClient(this, cx->clone()), testDuration, Void()));
return waitForAll(clients);
@ -128,9 +130,9 @@ struct AuthzSecurityWorkload : TestWorkload {
StringRef key,
StringRef value) {
state Transaction tr(cx, tenant);
self->setAuthToken(tr, token);
loop {
try {
self->setAuthToken(tr, token);
tr.set(key, value);
wait(tr.commit());
return tr.getCommittedVersion();
@ -146,10 +148,10 @@ struct AuthzSecurityWorkload : TestWorkload {
Standalone<StringRef> token,
StringRef key) {
state Transaction tr(cx, tenant);
self->setAuthToken(tr, token);
loop {
try {
// trigger GetKeyServerLocationsRequest and subsequent cache update
self->setAuthToken(tr, token);
wait(success(tr.get(key)));
auto loc = cx->getCachedLocation(tr.trState->getTenantInfo(), key);
if (loc.present()) {
@ -342,10 +344,10 @@ struct AuthzSecurityWorkload : TestWorkload {
state Version committedVersion =
wait(setAndCommitKeyValueAndGetVersion(self, cx, self->tenant, self->signedToken, key, value));
state Transaction tr(cx, self->tenant);
self->setAuthToken(tr, self->signedToken);
state Optional<Value> tLogConfigString;
loop {
try {
self->setAuthToken(tr, self->signedToken);
Optional<Value> value = wait(tr.get(self->tLogConfigKey));
ASSERT(value.present());
tLogConfigString = value;

View File

@ -24,6 +24,7 @@
#include "flow/serialize.h"
#include "fdbrpc/simulator.h"
#include "fdbrpc/TokenSign.h"
#include "fdbrpc/TenantInfo.h"
#include "fdbclient/FDBOptions.g.h"
#include "fdbclient/NativeAPI.actor.h"
#include "fdbserver/TesterInterface.actor.h"
@ -39,13 +40,20 @@ template <>
struct CycleMembers<true> {
Arena arena;
TenantName tenant;
int64_t tenantId;
Standalone<StringRef> signedToken;
bool useToken;
};
template <bool>
struct CycleWorkload;
ACTOR Future<Void> prepareToken(Database cx, CycleWorkload<true>* self);
template <bool MultiTenancy>
struct CycleWorkload : TestWorkload, CycleMembers<MultiTenancy> {
static constexpr auto NAME = MultiTenancy ? "TenantCycle" : "Cycle";
static constexpr auto TenantEnabled = MultiTenancy;
int actorCount, nodeCount;
double testDuration, transactionsPerSecond, minExpectedTransactionsPerSecond, traceParentProbability;
Key keyPrefix;
@ -68,17 +76,18 @@ struct CycleWorkload : TestWorkload, CycleMembers<MultiTenancy> {
ASSERT(g_network->isSimulated());
this->useToken = getOption(options, "useToken"_sr, true);
this->tenant = getOption(options, "tenant"_sr, "CycleTenant"_sr);
// make it comfortably longer than the timeout of the workload
this->signedToken = g_simulator->makeToken(
this->tenant, uint64_t(std::lround(getCheckTimeout())) + uint64_t(std::lround(testDuration)) + 100);
this->tenantId = TenantInfo::INVALID_TENANT;
}
}
Future<Void> setup(Database const& cx) override {
Future<Void> prepare;
if constexpr (MultiTenancy) {
cx->defaultTenant = this->tenant;
prepare = prepareToken(cx, this);
} else {
prepare = Void();
}
return bulkSetup(cx, this, nodeCount, Promise<double>());
return runAfter(prepare, [this, cx](Void) { return bulkSetup(cx, this, nodeCount, Promise<double>()); });
}
Future<Void> start(Database const& cx) override {
if constexpr (MultiTenancy) {
@ -144,7 +153,6 @@ struct CycleWorkload : TestWorkload, CycleMembers<MultiTenancy> {
state double tstart = now();
state int r = deterministicRandom()->randomInt(0, self->nodeCount);
state Transaction tr(cx);
self->setAuthToken(tr);
if (deterministicRandom()->random01() <= self->traceParentProbability) {
state Span span("CycleClient"_loc);
TraceEvent("CycleTracingTransaction", span.context.traceID).log();
@ -153,6 +161,7 @@ struct CycleWorkload : TestWorkload, CycleMembers<MultiTenancy> {
}
while (true) {
try {
self->setAuthToken(tr);
// Reverse next and next^2 node
Optional<Value> v = wait(tr.get(self->key(r)));
if (!v.present())
@ -291,9 +300,9 @@ struct CycleWorkload : TestWorkload, CycleMembers<MultiTenancy> {
// One client checks the validity of the cycle
state Transaction tr(cx);
state int retryCount = 0;
self->setAuthToken(tr);
loop {
try {
self->setAuthToken(tr);
state Version v = wait(tr.getReadVersion());
RangeResult data = wait(tr.getRange(firstGreaterOrEqual(doubleToTestKey(0.0, self->keyPrefix)),
firstGreaterOrEqual(doubleToTestKey(1.0, self->keyPrefix)),
@ -316,5 +325,17 @@ struct CycleWorkload : TestWorkload, CycleMembers<MultiTenancy> {
}
};
ACTOR Future<Void> prepareToken(Database cx, CycleWorkload<true>* self) {
cx->defaultTenant = self->tenant;
int64_t tenantId = wait(cx->lookupTenant(self->tenant));
self->tenantId = tenantId;
ASSERT_NE(self->tenantId, TenantInfo::INVALID_TENANT);
// make the lifetime comfortably longer than the timeout of the workload
self->signedToken = g_simulator->makeToken(self->tenantId,
uint64_t(std::lround(self->getCheckTimeout())) +
uint64_t(std::lround(self->testDuration)) + 100);
return Void();
}
WorkloadFactory<CycleWorkload<false>> CycleWorkloadFactory(UntrustedMode::False);
WorkloadFactory<CycleWorkload<true>> TenantCycleWorkloadFactory(UntrustedMode::True);

View File

@ -140,10 +140,11 @@ void FlowKnobs::initialize(Randomize randomize, IsSimulated isSimulated) {
//Authorization
init( ALLOW_TOKENLESS_TENANT_ACCESS, false );
init( AUDIT_LOGGING_ENABLED, true );
init( PUBLIC_KEY_FILE_MAX_SIZE, 1024 * 1024 );
init( PUBLIC_KEY_FILE_REFRESH_INTERVAL_SECONDS, 30 );
init( MAX_CACHED_EXPIRED_TOKENS, 1024 );
init( AUDIT_TIME_WINDOW, 5.0 );
init( TOKEN_CACHE_SIZE, 2000 );
//AsyncFileCached
init( PAGE_CACHE_4K, 2LL<<30 );
@ -307,14 +308,12 @@ void FlowKnobs::initialize(Randomize randomize, IsSimulated isSimulated) {
if ( randomize && BUGGIFY) { ENCRYPT_CIPHER_KEY_CACHE_TTL = deterministicRandom()->randomInt(2, 10) * 60; }
init( ENCRYPT_KEY_REFRESH_INTERVAL, isSimulated ? 60 : 8 * 60 );
if ( randomize && BUGGIFY) { ENCRYPT_KEY_REFRESH_INTERVAL = deterministicRandom()->randomInt(2, 10); }
init( TOKEN_CACHE_SIZE, 100 );
init( ENCRYPT_KEY_CACHE_LOGGING_INTERVAL, 5.0 );
init( ENCRYPT_KEY_CACHE_LOGGING_SKETCH_ACCURACY, 0.01 );
// Refer to EncryptUtil::EncryptAuthTokenAlgo for more details
init( ENCRYPT_HEADER_AUTH_TOKEN_ENABLED, false ); if ( randomize && BUGGIFY ) { ENCRYPT_HEADER_AUTH_TOKEN_ENABLED = !ENCRYPT_HEADER_AUTH_TOKEN_ENABLED; }
init( ENCRYPT_HEADER_AUTH_TOKEN_ALGO, 0 ); if ( randomize && ENCRYPT_HEADER_AUTH_TOKEN_ENABLED ) { ENCRYPT_HEADER_AUTH_TOKEN_ALGO = getRandomAuthTokenAlgo(); }
// REST Client
init( RESTCLIENT_MAX_CONNECTIONPOOL_SIZE, 10 );
init( RESTCLIENT_CONNECT_TRIES, 10 );

View File

@ -204,10 +204,11 @@ public:
// Authorization
bool ALLOW_TOKENLESS_TENANT_ACCESS;
bool AUDIT_LOGGING_ENABLED;
int PUBLIC_KEY_FILE_MAX_SIZE;
int PUBLIC_KEY_FILE_REFRESH_INTERVAL_SECONDS;
int MAX_CACHED_EXPIRED_TOKENS;
double AUDIT_TIME_WINDOW;
int TOKEN_CACHE_SIZE;
// AsyncFileCached
int64_t PAGE_CACHE_4K;
@ -375,9 +376,6 @@ public:
bool ENCRYPT_HEADER_AUTH_TOKEN_ENABLED;
int ENCRYPT_HEADER_AUTH_TOKEN_ALGO;
// Authorization
int TOKEN_CACHE_SIZE;
// RESTClient
int RESTCLIENT_MAX_CONNECTIONPOOL_SIZE;
int RESTCLIENT_CONNECT_TRIES;

View File

@ -126,7 +126,7 @@ if(WITH_PYTHON)
add_fdb_test(TEST_FILES fast/AtomicBackupToDBCorrectness.toml)
add_fdb_test(TEST_FILES fast/AtomicOps.toml)
add_fdb_test(TEST_FILES fast/AtomicOpsApiCorrectness.toml)
add_fdb_test(TEST_FILES fast/AuthzSecurity.toml IGNORE) # TODO re-enable once authz uses ID tokens
add_fdb_test(TEST_FILES fast/AuthzSecurity.toml)
add_fdb_test(TEST_FILES fast/AutomaticIdempotency.toml)
add_fdb_test(TEST_FILES fast/BackupAzureBlobCorrectness.toml IGNORE)
add_fdb_test(TEST_FILES fast/BackupS3BlobCorrectness.toml IGNORE)
@ -504,15 +504,37 @@ if(WITH_PYTHON)
set_tests_properties(authorization_venv_setup PROPERTIES FIXTURES_SETUP authz_virtual_env TIMEOUT 60)
set(authz_script_dir ${CMAKE_SOURCE_DIR}/tests/authorization)
set(authz_test_cmd "${authz_venv_activate} && pytest ${authz_script_dir}/authz_test.py -rA --build-dir ${CMAKE_BINARY_DIR} -vvv")
# TODO: reenable when authz is updated to validate based on tenant IDs
#add_test(
#NAME token_based_tenant_authorization
#WORKING_DIRECTORY ${authz_venv_dir}
#COMMAND bash -c ${authz_test_cmd})
#set_tests_properties(token_based_tenant_authorization PROPERTIES ENVIRONMENT "PYTHONPATH=${CMAKE_SOURCE_DIR}/tests/TestRunner;${ld_env_name}=${CMAKE_BINARY_DIR}/lib")
#set_tests_properties(token_based_tenant_authorization PROPERTIES FIXTURES_REQUIRED authz_virtual_env)
#set_tests_properties(token_based_tenant_authorization PROPERTIES TIMEOUT 120)
set(enable_grv_cache 0 1)
set(force_mvc 0 1)
foreach(is_grv_cache_enabled IN LISTS enable_grv_cache)
foreach(is_mvc_forced IN LISTS force_mvc)
if(NOT is_mvc_forced AND is_grv_cache_enabled)
continue() # grv cache requires setting up shared database state, which is only available in MVC mode
endif()
set(authz_test_name "authz")
set(test_opt "")
if(is_grv_cache_enabled)
string(APPEND test_opt " --use-grv-cache")
string(APPEND authz_test_name "_with_grv_cache")
else()
string(APPEND authz_test_name "_no_grv_cache")
endif()
if(is_mvc_forced)
string(APPEND test_opt " --force-multi-version-client")
string(APPEND authz_test_name "_with_forced_mvc")
else()
string(APPEND authz_test_name "_no_forced_mvc")
endif()
set(authz_test_cmd "${authz_venv_activate} && pytest ${authz_script_dir}/authz_test.py -rA --build-dir ${CMAKE_BINARY_DIR} -vvv${test_opt}")
add_test(
NAME ${authz_test_name}
WORKING_DIRECTORY ${authz_venv_dir}
COMMAND bash -c ${authz_test_cmd})
set_tests_properties(${authz_test_name} PROPERTIES ENVIRONMENT "PYTHONPATH=${CMAKE_SOURCE_DIR}/tests/TestRunner;${ld_env_name}=${CMAKE_BINARY_DIR}/lib")
set_tests_properties(${authz_test_name} PROPERTIES FIXTURES_REQUIRED authz_virtual_env)
set_tests_properties(${authz_test_name} PROPERTIES TIMEOUT 120)
endforeach()
endforeach()
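
# For reference (not part of the diff): the two loops above should register three ctest variants,
# since the grv-cache-without-MVC combination is skipped:
#   authz_no_grv_cache_no_forced_mvc
#   authz_no_grv_cache_with_forced_mvc     (adds --force-multi-version-client)
#   authz_with_grv_cache_with_forced_mvc   (adds --use-grv-cache --force-multi-version-client)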
endif()
else()
message(WARNING "Python not found, won't configure ctest")

View File

@ -1,6 +1,7 @@
from authlib.jose import JsonWebKey, KeySet, jwt
from typing import List
import json
import time
def private_key_gen(kty: str, kid: str):
assert kty == "EC" or kty == "RSA"
@ -29,3 +30,17 @@ def token_gen(private_key, claims, headers={}):
"kid": private_key.kid,
}
return jwt.encode(headers, claims, private_key)
def token_claim_1h(tenant_id: int):
# JWT claim that is valid for 1 hour since time of invocation
now = time.time()
return {
"iss": "fdb-authz-tester",
"sub": "authz-test",
"aud": ["tmp-cluster"],
"iat": now,
"nbf": now - 1,
"exp": now + 60 * 60,
"jti": random_alphanum_str(10),
"tenants": [to_str(base64.b64encode(tenant_id.to_bytes(8, "big")))],
}
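
(For orientation only, not part of the change: a minimal usage sketch of the helper above. The key type, key ID, and tenant ID are illustrative assumptions; private_key_gen and token_gen are the helpers defined earlier in this file, and the claim helper relies on the module's other utilities such as random_alphanum_str and to_str.)

# hypothetical example: mint a token that is valid for one hour for tenant ID 1
key = private_key_gen(kty="EC", kid="example-key-id")
token = token_gen(key, token_claim_1h(1))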

View File

@ -30,7 +30,9 @@ class PortProvider:
while True:
counter += 1
if counter > MAX_PORT_ACQUIRE_ATTEMPTS:
assert False, "Failed to acquire a free port after {} attempts".format(MAX_PORT_ACQUIRE_ATTEMPTS)
assert False, "Failed to acquire a free port after {} attempts".format(
MAX_PORT_ACQUIRE_ATTEMPTS
)
port = PortProvider._get_free_port_internal()
if port in self._used_ports:
continue
@ -42,7 +44,12 @@ class PortProvider:
self._used_ports.add(port)
return port
except OSError:
print("Failed to lock file {}. Trying to aquire another port".format(lock_path), file=sys.stderr)
print(
"Failed to lock file {}. Trying to aquire another port".format(
lock_path
),
file=sys.stderr,
)
pass
def is_port_in_use(port):
@ -59,10 +66,11 @@ class PortProvider:
fd.close()
try:
os.remove(fd.name)
except:
except Exception:
pass
self._lock_files.clear()
class TLSConfig:
# Passing a negative chain length generates an expired leaf certificate
def __init__(
@ -75,6 +83,7 @@ class TLSConfig:
self.client_chain_len = client_chain_len
self.verify_peers = verify_peers
class LocalCluster:
configuration_template = """
## foundationdb.conf
@ -170,8 +179,12 @@ logdir = {logdir}
if self.first_port is not None:
self.last_used_port = int(self.first_port) - 1
self.server_ports = {server_id: self.__next_port() for server_id in range(self.process_number)}
self.server_by_port = {port: server_id for server_id, port in self.server_ports.items()}
self.server_ports = {
server_id: self.__next_port() for server_id in range(self.process_number)
}
self.server_by_port = {
port: server_id for server_id, port in self.server_ports.items()
}
self.next_server_id = self.process_number
self.cluster_desc = random_alphanum_string(8)
self.cluster_secret = random_alphanum_string(8)
@ -198,16 +211,23 @@ logdir = {logdir}
self.client_ca_file = self.cert.joinpath("client_ca.pem")
if self.authorization_kty:
assert self.authorization_keypair_id, "keypair ID must be set to enable authorization"
assert (
self.authorization_keypair_id
), "keypair ID must be set to enable authorization"
self.public_key_json_file = self.etc.joinpath("public_keys.json")
self.private_key = private_key_gen(
kty=self.authorization_kty, kid=self.authorization_keypair_id)
kty=self.authorization_kty, kid=self.authorization_keypair_id
)
self.public_key_jwks_str = public_keyset_from_keys([self.private_key])
with open(self.public_key_json_file, "w") as pubkeyfile:
pubkeyfile.write(self.public_key_jwks_str)
self.authorization_private_key_pem_file = self.etc.joinpath("authorization_private_key.pem")
self.authorization_private_key_pem_file = self.etc.joinpath(
"authorization_private_key.pem"
)
with open(self.authorization_private_key_pem_file, "w") as privkeyfile:
privkeyfile.write(self.private_key.as_pem(is_private=True).decode("utf8"))
privkeyfile.write(
self.private_key.as_pem(is_private=True).decode("utf8")
)
if create_config:
self.create_cluster_file()
@ -246,7 +266,10 @@ logdir = {logdir}
authz_public_key_config=self.authz_public_key_conf_string(),
optional_tls=":tls" if self.tls_config is not None else "",
custom_config="\n".join(
["{} = {}".format(key, value) for key, value in self.custom_config.items()]
[
"{} = {}".format(key, value)
for key, value in self.custom_config.items()
]
),
use_future_protocol_version="use-future-protocol-version = true"
if self.use_future_protocol_version
@ -259,7 +282,11 @@ logdir = {logdir}
# Then 4000,4001,4002,4003,4004 will be used as ports
# If port number is not given, we will randomly pick free ports
for server_id in self.active_servers:
f.write("[fdbserver.{server_port}]\n".format(server_port=self.server_ports[server_id]))
f.write(
"[fdbserver.{server_port}]\n".format(
server_port=self.server_ports[server_id]
)
)
if self.use_legacy_conf_syntax:
f.write("machine_id = {}\n".format(server_id))
else:
@ -353,8 +380,12 @@ logdir = {logdir}
]
if self.use_future_protocol_version:
args += ["--use-future-protocol-version"]
res = subprocess.run(args, env=self.process_env(), stderr=stderr, stdout=stdout, timeout=timeout)
assert res.returncode == 0, "fdbcli command {} failed with {}".format(cmd, res.returncode)
res = subprocess.run(
args, env=self.process_env(), stderr=stderr, stdout=stdout, timeout=timeout
)
assert res.returncode == 0, "fdbcli command {} failed with {}".format(
cmd, res.returncode
)
return res.stdout
# Execute a fdbcli command
@ -376,9 +407,9 @@ logdir = {logdir}
# Generate and install test certificate chains and keys
def create_tls_cert(self):
assert self.tls_config is not None, "TLS not enabled"
assert self.mkcert_binary.exists() and self.mkcert_binary.is_file(), "{} does not exist".format(
self.mkcert_binary
)
assert (
self.mkcert_binary.exists() and self.mkcert_binary.is_file()
), "{} does not exist".format(self.mkcert_binary)
self.cert.mkdir(exist_ok=True)
server_chain_len = abs(self.tls_config.server_chain_len)
client_chain_len = abs(self.tls_config.client_chain_len)
@ -425,7 +456,9 @@ logdir = {logdir}
def authz_public_key_conf_string(self):
if self.public_key_json_file is not None:
return "authorization-public-key-file = {}".format(self.public_key_json_file)
return "authorization-public-key-file = {}".format(
self.public_key_json_file
)
else:
return ""
@ -441,7 +474,11 @@ logdir = {logdir}
return {}
servers_found = set()
addresses = [proc_info["address"] for proc_info in status["cluster"]["processes"].values() if filter(proc_info)]
addresses = [
proc_info["address"]
for proc_info in status["cluster"]["processes"].values()
if filter(proc_info)
]
for addr in addresses:
port = int(addr.split(":", 1)[1])
assert port in self.server_by_port, "Unknown server port {}".format(port)
@ -472,7 +509,9 @@ logdir = {logdir}
# Need to call save_config to apply the changes
def add_server(self):
server_id = self.next_server_id
assert server_id not in self.server_ports, "Server ID {} is already in use".format(server_id)
assert (
server_id not in self.server_ports
), "Server ID {} is already in use".format(server_id)
self.next_server_id += 1
port = self.__next_port()
self.server_ports[server_id] = port
@ -483,7 +522,9 @@ logdir = {logdir}
# Remove the server with the given ID from the cluster
# Need to call save_config to apply the changes
def remove_server(self, server_id):
assert server_id in self.active_servers, "Server {} does not exist".format(server_id)
assert server_id in self.active_servers, "Server {} does not exist".format(
server_id
)
self.active_servers.remove(server_id)
# Wait until changes to the set of servers (additions & removals) are applied
@ -501,7 +542,10 @@ logdir = {logdir}
# Apply changes to the set of the coordinators, based on the current value of self.coordinators
def update_coordinators(self):
urls = ["{}:{}".format(self.ip_address, self.server_ports[id]) for id in self.coordinators]
urls = [
"{}:{}".format(self.ip_address, self.server_ports[id])
for id in self.coordinators
]
self.fdbcli_exec("coordinators {}".format(" ".join(urls)))
# Wait until the changes to the set of the coordinators are applied
@ -521,13 +565,20 @@ logdir = {logdir}
for server_id in self.coordinators:
assert (
connection_string.find(str(self.server_ports[server_id])) != -1
), "Missing coordinator {} port {} in the cluster file".format(server_id, self.server_ports[server_id])
), "Missing coordinator {} port {} in the cluster file".format(
server_id, self.server_ports[server_id]
)
# Exclude the servers with the given ID from the cluster, i.e. move out their data
# The method waits until the changes are applied
def exclude_servers(self, server_ids):
urls = ["{}:{}".format(self.ip_address, self.server_ports[id]) for id in server_ids]
self.fdbcli_exec("exclude FORCE {}".format(" ".join(urls)), timeout=EXCLUDE_SERVERS_TIMEOUT_SEC)
urls = [
"{}:{}".format(self.ip_address, self.server_ports[id]) for id in server_ids
]
self.fdbcli_exec(
"exclude FORCE {}".format(" ".join(urls)),
timeout=EXCLUDE_SERVERS_TIMEOUT_SEC,
)
# Perform a cluster wiggle: replace all servers with new ones
def cluster_wiggle(self):
@ -552,7 +603,11 @@ logdir = {logdir}
)
self.save_config()
self.wait_for_server_update()
print("New servers successfully added to the cluster. Time: {}s".format(time.time() - start_time))
print(
"New servers successfully added to the cluster. Time: {}s".format(
time.time() - start_time
)
)
# Step 2: change coordinators
start_time = time.time()
@ -561,12 +616,20 @@ logdir = {logdir}
self.coordinators = new_coordinators.copy()
self.update_coordinators()
self.wait_for_coordinator_update()
print("Coordinators successfully changed. Time: {}s".format(time.time() - start_time))
print(
"Coordinators successfully changed. Time: {}s".format(
time.time() - start_time
)
)
# Step 3: exclude old servers from the cluster, i.e. move out their data
start_time = time.time()
self.exclude_servers(old_servers)
print("Old servers successfully excluded from the cluster. Time: {}s".format(time.time() - start_time))
print(
"Old servers successfully excluded from the cluster. Time: {}s".format(
time.time() - start_time
)
)
# Step 4: remove the old servers
start_time = time.time()
@ -574,11 +637,21 @@ logdir = {logdir}
self.remove_server(server_id)
self.save_config()
self.wait_for_server_update()
print("Old servers successfully removed from the cluster. Time: {}s".format(time.time() - start_time))
print(
"Old servers successfully removed from the cluster. Time: {}s".format(
time.time() - start_time
)
)
# Check the cluster log for errors
def check_cluster_logs(self, error_limit=100):
sev40s = subprocess.getoutput("grep -r 'Severity=\"40\"' {}".format(self.log.as_posix())).rstrip().splitlines()
sev40s = (
subprocess.getoutput(
"grep -r 'Severity=\"40\"' {}".format(self.log.as_posix())
)
.rstrip()
.splitlines()
)
err_cnt = 0
for line in sev40s:

View File

@ -104,8 +104,10 @@ if __name__ == "__main__":
tls_config = None
if args.tls_enabled:
tls_config = TLSConfig(server_chain_len=args.server_cert_chain_len,
client_chain_len=args.client_cert_chain_len)
tls_config = TLSConfig(
server_chain_len=args.server_cert_chain_len,
client_chain_len=args.client_cert_chain_len,
)
errcode = 1
with TempCluster(
args.build_dir,
@ -133,16 +135,15 @@ if __name__ == "__main__":
("@SERVER_CA_FILE@", str(cluster.server_ca_file)),
("@CLIENT_CERT_FILE@", str(cluster.client_cert_file)),
("@CLIENT_KEY_FILE@", str(cluster.client_key_file)),
("@CLIENT_CA_FILE@", str(cluster.client_ca_file))]
("@CLIENT_CA_FILE@", str(cluster.client_ca_file)),
]
for cmd in args.cmd:
for (placeholder, value) in substitution_table:
cmd = cmd.replace(placeholder, value)
cmd_args.append(cmd)
env = dict(**os.environ)
env["FDB_CLUSTER_FILE"] = env.get(
"FDB_CLUSTER_FILE", cluster.cluster_file
)
env["FDB_CLUSTER_FILE"] = env.get("FDB_CLUSTER_FILE", cluster.cluster_file)
print("command: {}".format(cmd_args))
errcode = subprocess.run(
cmd_args, stdout=sys.stdout, stderr=sys.stderr, env=env

View File

@ -36,6 +36,7 @@ class _admin_request(object):
def main_loop(main_pipe, pipe):
main_pipe.close()
use_grv_cache = False
db = None
while True:
try:
@ -49,7 +50,14 @@ def main_loop(main_pipe, pipe):
args = req.args
resp = True
try:
if op == "connect":
if op == "configure_client":
force_multi_version_client, use_grv_cache, logdir = req.args[:3]
if force_multi_version_client:
fdb.options.set_disable_client_bypass()
if len(logdir) > 0:
fdb.options.set_trace_enable(logdir)
fdb.options.set_trace_file_identifier("adminserver")
elif op == "connect":
db = fdb.open(req.args[0])
elif op == "configure_tls":
keyfile, certfile, cafile = req.args[:3]

View File

@ -21,16 +21,18 @@
import admin_server
import argparse
import authlib
import base64
import fdb
import os
import pytest
import random
import sys
import time
from collections.abc import Callable
from multiprocessing import Process, Pipe
from typing import Union
from authz_util import token_gen, private_key_gen, public_keyset_from_keys, alg_from_kty
from util import random_alphanum_str, random_alphanum_bytes, to_str, to_bytes, KeyFileReverter, token_claim_1h, wait_until_tenant_tr_succeeds, wait_until_tenant_tr_fails
from util import random_alphanum_str, random_alphanum_bytes, to_str, to_bytes, KeyFileReverter, wait_until_tenant_tr_succeeds, wait_until_tenant_tr_fails
special_key_ranges = [
("transaction description", b"/description", b"/description\x00"),
@ -43,25 +45,79 @@ special_key_ranges = [
("kill storage", b"/globals/killStorage", b"/globals/killStorage\x00"),
]
def test_simple_tenant_access(cluster, default_tenant, tenant_tr_gen):
# retry helper for cases where the calling code expects to loop:
# e.g. enabling the GRV cache removes the guarantee that a transaction always gets the latest read version before it starts,
# which can introduce conflicts even on idle test clusters; those need to be resolved by retrying.
def loop_until_success(tr: fdb.Transaction, func):
while True:
try:
return func(tr)
except fdb.FDBError as e:
tr.on_error(e).wait()
# test that the token option on a transaction survives soft transaction resets,
# is cleared by hard transaction resets, and can also be cleared by setting an empty value
def test_token_option(cluster, default_tenant, tenant_tr_gen, token_claim_1h):
token = token_gen(cluster.private_key, token_claim_1h(default_tenant))
tr = tenant_tr_gen(default_tenant)
tr.options.set_authorization_token(token)
tr[b"abc"] = b"def"
tr.commit().wait()
def commit_some_value(tr):
tr[b"abc"] = b"def"
return tr.commit().wait()
loop_until_success(tr, commit_some_value)
# token option should survive a soft reset by a retryable error
tr.on_error(fdb.FDBError(1020)).wait() # not_committed (conflict)
def read_back_value(tr):
return tr[b"abc"].value
value = loop_until_success(tr, read_back_value)
assert value == b"def", f"unexpected value found: {value}"
tr.reset() # token shouldn't survive a hard reset
try:
value = read_back_value(tr)
assert False, "expected permission_denied, but succeeded"
except fdb.FDBError as e:
assert e.code == 6000, f"expected permission_denied, got {e} instead"
tr.reset()
tr.options.set_authorization_token(token)
tr.options.set_authorization_token() # option set with no arg should clear the token
try:
value = read_back_value(tr)
assert False, "expected permission_denied, but succeeded"
except fdb.FDBError as e:
assert e.code == 6000, f"expected permission_denied, got {e} instead"
def test_simple_tenant_access(cluster, default_tenant, tenant_tr_gen, token_claim_1h):
token = token_gen(cluster.private_key, token_claim_1h(default_tenant))
tr = tenant_tr_gen(default_tenant)
tr.options.set_authorization_token(token)
assert tr[b"abc"] == b"def", "tenant write transaction not visible"
def commit_some_value(tr):
tr[b"abc"] = b"def"
tr.commit().wait()
def test_cross_tenant_access_disallowed(cluster, default_tenant, tenant_gen, tenant_tr_gen):
loop_until_success(tr, commit_some_value)
tr = tenant_tr_gen(default_tenant)
tr.options.set_authorization_token(token)
def read_back_value(tr):
return tr[b"abc"].value
value = loop_until_success(tr, read_back_value)
assert value == b"def", "tenant write transaction not visible"
def test_cross_tenant_access_disallowed(cluster, default_tenant, tenant_gen, tenant_tr_gen, token_claim_1h):
# use default tenant token with second tenant transaction and see it fail
second_tenant = random_alphanum_bytes(12)
tenant_gen(second_tenant)
token_second = token_gen(cluster.private_key, token_claim_1h(second_tenant))
tr_second = tenant_tr_gen(second_tenant)
tr_second.options.set_authorization_token(token_second)
tr_second[b"abc"] = b"def"
tr_second.commit().wait()
def commit_some_value(tr):
tr[b"abc"] = b"def"
return tr.commit().wait()
loop_until_success(tr_second, commit_some_value)
token_default = token_gen(cluster.private_key, token_claim_1h(default_tenant))
tr_second = tenant_tr_gen(second_tenant)
tr_second.options.set_authorization_token(token_default)
@ -81,6 +137,44 @@ def test_cross_tenant_access_disallowed(cluster, default_tenant, tenant_gen, ten
except fdb.FDBError as e:
assert e.code == 6000, f"expected permission_denied, got {e} instead"
def test_cross_tenant_raw_access_disallowed_with_token(cluster, db, default_tenant, tenant_gen, tenant_tr_gen, token_claim_1h):
def commit_some_value(tr):
tr[b"abc"] = b"def"
return tr.commit().wait()
second_tenant = random_alphanum_bytes(12)
tenant_gen(second_tenant)
first_tenant_token_claim = token_claim_1h(default_tenant)
second_tenant_token_claim = token_claim_1h(second_tenant)
# create a token that's good for both tenants
first_tenant_token_claim["tenants"] += second_tenant_token_claim["tenants"]
token = token_gen(cluster.private_key, first_tenant_token_claim)
tr_first = tenant_tr_gen(default_tenant)
tr_first.options.set_authorization_token(token)
loop_until_success(tr_first, commit_some_value)
tr_second = tenant_tr_gen(second_tenant)
tr_second.options.set_authorization_token(token)
loop_until_success(tr_second, commit_some_value)
# now try a normal keyspace transaction to raw-access both tenants' keyspace at once, with token
tr = db.create_transaction()
tr.options.set_authorization_token(token)
tr.options.set_raw_access()
prefix_first = base64.b64decode(first_tenant_token_claim["tenants"][0])
assert len(prefix_first) == 8
prefix_second = base64.b64decode(first_tenant_token_claim["tenants"][1])
assert len(prefix_second) == 8
lhs = min(prefix_first, prefix_second)
rhs = max(prefix_first, prefix_second)
rhs = bytearray(rhs)
rhs[-1] += 1 # exclusive end
try:
value = tr[lhs:bytes(rhs)].to_list()
assert False, f"expected permission_denied, but succeeded, value: {value}"
except fdb.FDBError as e:
assert e.code == 6000, f"expected permission_denied, got {e} instead"
def test_system_and_special_key_range_disallowed(db, tenant_tr_gen):
second_tenant = random_alphanum_bytes(12)
try:
@ -137,7 +231,7 @@ def test_system_and_special_key_range_disallowed(db, tenant_tr_gen):
def test_public_key_set_rollover(
kty, public_key_refresh_interval,
cluster, default_tenant, tenant_gen, tenant_tr_gen):
cluster, default_tenant, tenant_gen, tenant_tr_gen, token_claim_1h):
new_kid = random_alphanum_str(12)
new_kty = "EC" if kty == "RSA" else "RSA"
new_key = private_key_gen(kty=new_kty, kid=new_kid)
@ -160,16 +254,16 @@ def test_public_key_set_rollover(
with KeyFileReverter(cluster.public_key_json_file, old_key_json, delay):
with open(cluster.public_key_json_file, "w") as keyfile:
keyfile.write(interim_set)
wait_until_tenant_tr_succeeds(second_tenant, new_key, tenant_tr_gen, max_repeat, delay)
wait_until_tenant_tr_succeeds(second_tenant, new_key, tenant_tr_gen, max_repeat, delay, token_claim_1h)
print("interim key set activated")
final_set = public_keyset_from_keys([new_key])
print(f"final keyset: {final_set}")
with open(cluster.public_key_json_file, "w") as keyfile:
keyfile.write(final_set)
wait_until_tenant_tr_fails(default_tenant, cluster.private_key, tenant_tr_gen, max_repeat, delay)
wait_until_tenant_tr_fails(default_tenant, cluster.private_key, tenant_tr_gen, max_repeat, delay, token_claim_1h)
def test_public_key_set_broken_file_tolerance(
cluster, public_key_refresh_interval, default_tenant, tenant_tr_gen):
cluster, public_key_refresh_interval, default_tenant, tenant_tr_gen, token_claim_1h):
delay = public_key_refresh_interval
# retry limit in waiting for keyset file update to propagate to FDB server's internal keyset
max_repeat = 10
@ -187,10 +281,10 @@ def test_public_key_set_broken_file_tolerance(
with open(cluster.public_key_json_file, "w") as keyfile:
keyfile.write('{"keys":[]}')
# eventually internal key set will become empty and won't accept any new tokens
wait_until_tenant_tr_fails(default_tenant, cluster.private_key, tenant_tr_gen, max_repeat, delay)
wait_until_tenant_tr_fails(default_tenant, cluster.private_key, tenant_tr_gen, max_repeat, delay, token_claim_1h)
def test_public_key_set_deletion_tolerance(
cluster, public_key_refresh_interval, default_tenant, tenant_tr_gen):
cluster, public_key_refresh_interval, default_tenant, tenant_tr_gen, token_claim_1h):
delay = public_key_refresh_interval
# retry limit in waiting for keyset file update to propagate to FDB server's internal keyset
max_repeat = 10
@ -200,16 +294,16 @@ def test_public_key_set_deletion_tolerance(
with open(cluster.public_key_json_file, "w") as keyfile:
keyfile.write('{"keys":[]}')
time.sleep(delay)
wait_until_tenant_tr_fails(default_tenant, cluster.private_key, tenant_tr_gen, max_repeat, delay)
wait_until_tenant_tr_fails(default_tenant, cluster.private_key, tenant_tr_gen, max_repeat, delay, token_claim_1h)
os.remove(cluster.public_key_json_file)
time.sleep(delay * 2)
with open(cluster.public_key_json_file, "w") as keyfile:
keyfile.write(cluster.public_key_jwks_str)
# eventually updated key set should take effect and transaction should be accepted
wait_until_tenant_tr_succeeds(default_tenant, cluster.private_key, tenant_tr_gen, max_repeat, delay)
wait_until_tenant_tr_succeeds(default_tenant, cluster.private_key, tenant_tr_gen, max_repeat, delay, token_claim_1h)
def test_public_key_set_empty_file_tolerance(
cluster, public_key_refresh_interval, default_tenant, tenant_tr_gen):
cluster, public_key_refresh_interval, default_tenant, tenant_tr_gen, token_claim_1h):
delay = public_key_refresh_interval
# retry limit in waiting for keyset file update to propagate to FDB server's internal keyset
max_repeat = 10
@ -219,7 +313,7 @@ def test_public_key_set_empty_file_tolerance(
with open(cluster.public_key_json_file, "w") as keyfile:
keyfile.write('{"keys":[]}')
# eventually internal key set will become empty and won't accept any new tokens
wait_until_tenant_tr_fails(default_tenant, cluster.private_key, tenant_tr_gen, max_repeat, delay)
wait_until_tenant_tr_fails(default_tenant, cluster.private_key, tenant_tr_gen, max_repeat, delay, token_claim_1h)
# empty the key file
with open(cluster.public_key_json_file, "w") as keyfile:
pass
@ -227,9 +321,9 @@ def test_public_key_set_empty_file_tolerance(
with open(cluster.public_key_json_file, "w") as keyfile:
keyfile.write(cluster.public_key_jwks_str)
# eventually key file should update and transactions should go through
wait_until_tenant_tr_succeeds(default_tenant, cluster.private_key, tenant_tr_gen, max_repeat, delay)
wait_until_tenant_tr_succeeds(default_tenant, cluster.private_key, tenant_tr_gen, max_repeat, delay, token_claim_1h)
def test_bad_token(cluster, default_tenant, tenant_tr_gen):
def test_bad_token(cluster, default_tenant, tenant_tr_gen, token_claim_1h):
def del_attr(d, attr):
del d[attr]
return d

View File

@ -22,10 +22,14 @@ import fdb
import pytest
import subprocess
import admin_server
import base64
import glob
import time
from local_cluster import TLSConfig
from tmp_cluster import TempCluster
from typing import Union
from util import random_alphanum_str, random_alphanum_bytes, to_str, to_bytes
import xml.etree.ElementTree as ET
fdb.api_version(720)
@ -48,6 +52,18 @@ def pytest_addoption(parser):
default=1,
dest="public_key_refresh_interval",
help="How frequently server refreshes authorization public key file")
parser.addoption(
"--force-multi-version-client",
action="store_true",
default=False,
dest="force_multi_version_client",
help="Whether to force multi-version client mode")
parser.addoption(
"--use-grv-cache",
action="store_true",
default=False,
dest="use_grv_cache",
help="Whether to make client use cached GRV from database context")
@pytest.fixture(scope="session")
def build_dir(request):
@ -65,6 +81,14 @@ def trusted_client(request):
def public_key_refresh_interval(request):
return request.config.option.public_key_refresh_interval
@pytest.fixture(scope="session")
def force_multi_version_client(request):
return request.config.option.force_multi_version_client
@pytest.fixture(scope="session")
def use_grv_cache(request):
return request.config.option.use_grv_cache
@pytest.fixture(scope="session")
def kid():
return random_alphanum_str(12)
@ -77,7 +101,7 @@ def admin_ipc():
server.join()
@pytest.fixture(autouse=True, scope=cluster_scope)
def cluster(admin_ipc, build_dir, public_key_refresh_interval, trusted_client):
def cluster(admin_ipc, build_dir, public_key_refresh_interval, trusted_client, force_multi_version_client, use_grv_cache):
with TempCluster(
build_dir=build_dir,
tls_config=TLSConfig(server_chain_len=3, client_chain_len=2),
@ -90,19 +114,44 @@ def cluster(admin_ipc, build_dir, public_key_refresh_interval, trusted_client):
keyfile = str(cluster.client_key_file)
certfile = str(cluster.client_cert_file)
cafile = str(cluster.server_ca_file)
logdir = str(cluster.log)
fdb.options.set_tls_key_path(keyfile if trusted_client else "")
fdb.options.set_tls_cert_path(certfile if trusted_client else "")
fdb.options.set_tls_ca_path(cafile)
fdb.options.set_trace_enable()
fdb.options.set_trace_enable(logdir)
fdb.options.set_trace_file_identifier("testclient")
if force_multi_version_client:
fdb.options.set_disable_client_bypass()
admin_ipc.request("configure_client", [force_multi_version_client, use_grv_cache, logdir])
admin_ipc.request("configure_tls", [keyfile, certfile, cafile])
admin_ipc.request("connect", [str(cluster.cluster_file)])
yield cluster
err_count = {}
for file in glob.glob(str(cluster.log.joinpath("*.xml"))):
lineno = 1
for line in open(file):
try:
doc = ET.fromstring(line)
except Exception:
continue
if doc.attrib.get("Severity", "") == "40":
ev_type = doc.attrib.get("Type", "[unset]")
err = doc.attrib.get("Error", "[unset]")
tup = (file, ev_type, err)
err_count[tup] = err_count.get(tup, 0) + 1
lineno += 1
print("Sev40 Summary:")
if len(err_count) == 0:
print(" No errors")
else:
for tup, count in err_count.items():
print(" {}: {}".format(tup, count))
@pytest.fixture
def db(cluster, admin_ipc):
db = fdb.open(str(cluster.cluster_file))
db.options.set_transaction_timeout(2000) # 2 seconds
db.options.set_transaction_retry_limit(3)
#db.options.set_transaction_timeout(2000) # 2 seconds
db.options.set_transaction_retry_limit(10)
yield db
admin_ipc.request("cleanup_database")
db = None
@ -129,8 +178,30 @@ def default_tenant(tenant_gen, tenant_del):
tenant_del(tenant)
@pytest.fixture
def tenant_tr_gen(db):
def tenant_tr_gen(db, use_grv_cache):
def fn(tenant):
tenant = db.open_tenant(to_bytes(tenant))
return tenant.create_transaction()
tr = tenant.create_transaction()
if use_grv_cache:
tr.options.set_use_grv_cache()
return tr
return fn
@pytest.fixture
def token_claim_1h(db):
# JWT claim that is valid for 1 hour since time of invocation
def fn(tenant_name: Union[bytes, str]):
tenant = db.open_tenant(to_bytes(tenant_name))
tenant_id = tenant.get_id().wait()
now = time.time()
return {
"iss": "fdb-authz-tester",
"sub": "authz-test",
"aud": ["tmp-cluster"],
"iat": now,
"nbf": now - 1,
"exp": now + 60 * 60,
"jti": random_alphanum_str(10),
"tenants": [to_str(base64.b64encode(tenant_id.to_bytes(8, "big")))],
}
return fn
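
(Side note, not part of the diff: the "tenants" claim entries produced above are 8-byte big-endian tenant IDs, base64-encoded. A small sketch of the reverse mapping, with a hypothetical helper name, can be handy when inspecting tokens while debugging.)

import base64

def tenant_id_from_claim_entry(entry: str) -> int:
    # inverse of base64.b64encode(tenant_id.to_bytes(8, "big")) used in the fixture above
    return int.from_bytes(base64.b64decode(entry), "big")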

View File

@ -48,23 +48,9 @@ class KeyFileReverter(object):
print(f"key file reverted. waiting {self.refresh_delay * 2} seconds for the update to take effect...")
time.sleep(self.refresh_delay * 2)
# JWT claim that is valid for 1 hour since time of invocation
def token_claim_1h(tenant_name):
now = time.time()
return {
"iss": "fdb-authz-tester",
"sub": "authz-test",
"aud": ["tmp-cluster"],
"iat": now,
"nbf": now - 1,
"exp": now + 60 * 60,
"jti": random_alphanum_str(10),
"tenants": [to_str(base64.b64encode(tenant_name))],
}
# repeat try-wait loop up to max_repeat times until both read and write transactions fail for the tenant with permission_denied
# important: only use this function if you don't have any data dependencies on key "abc"
def wait_until_tenant_tr_fails(tenant, private_key, tenant_tr_gen, max_repeat, delay):
def wait_until_tenant_tr_fails(tenant, private_key, tenant_tr_gen, max_repeat, delay, token_claim_1h):
repeat = 0
read_blocked = False
write_blocked = False
@ -97,7 +83,7 @@ def wait_until_tenant_tr_fails(tenant, private_key, tenant_tr_gen, max_repeat, d
# repeat try-wait loop up to max_repeat times until both read and write transactions succeed for the tenant
# important: only use this function if you don't have any data dependencies on key "abc"
def wait_until_tenant_tr_succeeds(tenant, private_key, tenant_tr_gen, max_repeat, delay):
def wait_until_tenant_tr_succeeds(tenant, private_key, tenant_tr_gen, max_repeat, delay, token_claim_1h):
repeat = 0
token = token_gen(private_key, token_claim_1h(tenant))
while repeat < max_repeat:

View File

@ -2,6 +2,10 @@
allowDefaultTenant = false
tenantModes = ['optional', 'required']
[[knobs]]
audit_logging_enabled = false
max_trace_lines = 2000000
[[test]]
testTitle = 'TenantCreation'

View File

@ -2,6 +2,10 @@
allowDefaultTenant = false
tenantModes = ['optional', 'required']
[[knobs]]
audit_logging_enabled = false
max_trace_lines = 2000000
[[test]]
testTitle = 'TenantCreation'

View File

@ -4,6 +4,8 @@ tenantModes = ['optional', 'required']
[[knobs]]
allow_tokenless_tenant_access = true
audit_logging_enabled = false
max_trace_lines = 2000000
[[test]]
testTitle = 'TenantCreation'