Merge branch 'main' of github.com:apple/foundationdb into debug2
This commit is contained in:
commit
0789ab35e9
|
@ -31,19 +31,24 @@ import random
|
|||
import string
|
||||
import toml
|
||||
|
||||
sys.path[:0] = [os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "tests", "TestRunner")]
|
||||
|
||||
# fmt: off
|
||||
from tmp_cluster import TempCluster
|
||||
from local_cluster import TLSConfig
|
||||
# fmt: on
|
||||
|
||||
sys.path[:0] = [
|
||||
os.path.join(
|
||||
os.path.dirname(__file__), "..", "..", "..", "..", "tests", "TestRunner"
|
||||
)
|
||||
]
|
||||
|
||||
TESTER_STATS_INTERVAL_SEC = 5
|
||||
|
||||
|
||||
def random_string(len):
|
||||
return "".join(random.choice(string.ascii_letters + string.digits) for i in range(len))
|
||||
return "".join(
|
||||
random.choice(string.ascii_letters + string.digits) for i in range(len)
|
||||
)
|
||||
|
||||
|
||||
def get_logger():
|
||||
|
@ -77,7 +82,9 @@ def dump_client_logs(log_dir):
|
|||
def run_tester(args, cluster, test_file):
|
||||
build_dir = Path(args.build_dir).resolve()
|
||||
tester_binary = Path(args.api_tester_bin).resolve()
|
||||
external_client_library = build_dir.joinpath("bindings", "c", "libfdb_c_external.so")
|
||||
external_client_library = build_dir.joinpath(
|
||||
"bindings", "c", "libfdb_c_external.so"
|
||||
)
|
||||
log_dir = Path(cluster.log).joinpath("client")
|
||||
log_dir.mkdir(exist_ok=True)
|
||||
cmd = [
|
||||
|
@ -141,7 +148,9 @@ def run_tester(args, cluster, test_file):
|
|||
reason = signal.Signals(-ret_code).name
|
||||
else:
|
||||
reason = "exit code: %d" % ret_code
|
||||
get_logger().error("\n'%s' did not complete succesfully (%s)" % (cmd[0], reason))
|
||||
get_logger().error(
|
||||
"\n'%s' did not complete succesfully (%s)" % (cmd[0], reason)
|
||||
)
|
||||
if log_dir is not None and not args.disable_log_dump:
|
||||
dump_client_logs(log_dir)
|
||||
|
||||
|
@ -160,7 +169,9 @@ class TestConfig:
|
|||
self.server_chain_len = server_config.get("tls_server_chain_len", 3)
|
||||
self.min_num_processes = server_config.get("min_num_processes", 1)
|
||||
self.max_num_processes = server_config.get("max_num_processes", 3)
|
||||
self.num_processes = random.randint(self.min_num_processes, self.max_num_processes)
|
||||
self.num_processes = random.randint(
|
||||
self.min_num_processes, self.max_num_processes
|
||||
)
|
||||
|
||||
|
||||
def run_test(args, test_file):
|
||||
|
@ -210,9 +221,20 @@ def run_tests(args):
|
|||
|
||||
def parse_args(argv):
|
||||
parser = argparse.ArgumentParser(description="FoundationDB C API Tester")
|
||||
parser.add_argument("--build-dir", "-b", type=str, required=True, help="FDB build directory")
|
||||
parser.add_argument("--api-tester-bin", type=str, help="Path to the fdb_c_api_tester executable.", required=True)
|
||||
parser.add_argument("--external-client-library", type=str, help="Path to the external client library.")
|
||||
parser.add_argument(
|
||||
"--build-dir", "-b", type=str, required=True, help="FDB build directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--api-tester-bin",
|
||||
type=str,
|
||||
help="Path to the fdb_c_api_tester executable.",
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--external-client-library",
|
||||
type=str,
|
||||
help="Path to the external client library.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--retain-client-lib-copies",
|
||||
action="store_true",
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "future.hpp"
|
||||
#include "logger.hpp"
|
||||
#include "tenant.hpp"
|
||||
#include "time.hpp"
|
||||
#include "utils.hpp"
|
||||
#include <map>
|
||||
#include <cerrno>
|
||||
|
@ -30,39 +31,19 @@
|
|||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <thread>
|
||||
#include <tuple>
|
||||
|
||||
#include <unistd.h>
|
||||
#include <sys/wait.h>
|
||||
|
||||
#include <boost/archive/binary_iarchive.hpp>
|
||||
#include <boost/archive/binary_oarchive.hpp>
|
||||
#include <boost/serialization/optional.hpp>
|
||||
#include <boost/serialization/string.hpp>
|
||||
#include <boost/serialization/variant.hpp>
|
||||
#include <boost/variant/apply_visitor.hpp>
|
||||
|
||||
#include "rapidjson/document.h"
|
||||
|
||||
extern thread_local mako::Logger logr;
|
||||
|
||||
using oarchive = boost::archive::binary_oarchive;
|
||||
using iarchive = boost::archive::binary_iarchive;
|
||||
|
||||
namespace {
|
||||
|
||||
template <class T>
|
||||
void sendObject(boost::process::pstream& pipe, T obj) {
|
||||
oarchive oa(pipe);
|
||||
oa << obj;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
T receiveObject(boost::process::pstream& pipe) {
|
||||
iarchive ia(pipe);
|
||||
T obj;
|
||||
ia >> obj;
|
||||
return obj;
|
||||
}
|
||||
|
||||
fdb::Database getOrCreateDatabase(std::map<std::string, fdb::Database>& db_map, const std::string& cluster_file) {
|
||||
auto iter = db_map.find(cluster_file);
|
||||
if (iter == db_map.end()) {
|
||||
|
@ -122,35 +103,54 @@ void AdminServer::start() {
|
|||
}
|
||||
});
|
||||
|
||||
while (true) {
|
||||
bool stop = false;
|
||||
while (!stop) {
|
||||
try {
|
||||
auto req = receiveObject<Request>(pipe_to_server);
|
||||
if (setup_error) {
|
||||
sendObject(pipe_to_client, Response{ setup_error });
|
||||
} else if (boost::get<PingRequest>(&req)) {
|
||||
sendObject(pipe_to_client, Response{});
|
||||
} else if (boost::get<StopRequest>(&req)) {
|
||||
logr.info("server was requested to stop");
|
||||
sendObject(pipe_to_client, Response{});
|
||||
return;
|
||||
} else if (auto p = boost::get<BatchCreateTenantRequest>(&req)) {
|
||||
logr.info("received request to batch-create tenants [{}:{}) in database '{}'",
|
||||
p->id_begin,
|
||||
p->id_end,
|
||||
p->cluster_file);
|
||||
auto err_msg = createTenant(getOrCreateDatabase(databases, p->cluster_file), p->id_begin, p->id_end);
|
||||
sendObject(pipe_to_client, Response{ std::move(err_msg) });
|
||||
} else if (auto p = boost::get<BatchDeleteTenantRequest>(&req)) {
|
||||
logr.info("received request to batch-delete tenants [{}:{}) in database '{}'",
|
||||
p->id_begin,
|
||||
p->id_end,
|
||||
p->cluster_file);
|
||||
auto err_msg = deleteTenant(getOrCreateDatabase(databases, p->cluster_file), p->id_begin, p->id_end);
|
||||
sendObject(pipe_to_client, Response{ std::move(err_msg) });
|
||||
} else {
|
||||
logr.error("unknown request received");
|
||||
sendObject(pipe_to_client, Response{ std::string("unknown request type") });
|
||||
}
|
||||
boost::apply_visitor(
|
||||
[this, &databases, &setup_error, &stop](auto&& request) -> void {
|
||||
using ReqType = std::decay_t<decltype(request)>;
|
||||
if (setup_error) {
|
||||
sendResponse<ReqType>(pipe_to_client, ReqType::ResponseType::makeError(*setup_error));
|
||||
return;
|
||||
}
|
||||
if constexpr (std::is_same_v<ReqType, PingRequest>) {
|
||||
sendResponse<ReqType>(pipe_to_client, DefaultResponse{});
|
||||
} else if constexpr (std::is_same_v<ReqType, StopRequest>) {
|
||||
logr.info("server was requested to stop");
|
||||
sendResponse<ReqType>(pipe_to_client, DefaultResponse{});
|
||||
stop = true;
|
||||
} else if constexpr (std::is_same_v<ReqType, BatchCreateTenantRequest>) {
|
||||
logr.info("received request to batch-create tenants [{}:{}) in database '{}'",
|
||||
request.id_begin,
|
||||
request.id_end,
|
||||
request.cluster_file);
|
||||
auto err_msg = createTenant(
|
||||
getOrCreateDatabase(databases, request.cluster_file), request.id_begin, request.id_end);
|
||||
sendResponse<ReqType>(pipe_to_client, DefaultResponse{ std::move(err_msg) });
|
||||
} else if constexpr (std::is_same_v<ReqType, BatchDeleteTenantRequest>) {
|
||||
logr.info("received request to batch-delete tenants [{}:{}) in database '{}'",
|
||||
request.id_begin,
|
||||
request.id_end,
|
||||
request.cluster_file);
|
||||
auto err_msg = deleteTenant(
|
||||
getOrCreateDatabase(databases, request.cluster_file), request.id_begin, request.id_end);
|
||||
sendResponse<ReqType>(pipe_to_client, DefaultResponse{ std::move(err_msg) });
|
||||
} else if constexpr (std::is_same_v<ReqType, FetchTenantIdsRequest>) {
|
||||
logr.info("received request to fetch tenant IDs [{}:{}) in database '{}'",
|
||||
request.id_begin,
|
||||
request.id_end,
|
||||
request.cluster_file);
|
||||
sendResponse<ReqType>(pipe_to_client,
|
||||
fetchTenantIds(getOrCreateDatabase(databases, request.cluster_file),
|
||||
request.id_begin,
|
||||
request.id_end));
|
||||
} else {
|
||||
logr.error("unknown request received, typename '{}'", typeid(ReqType).name());
|
||||
sendResponse<ReqType>(pipe_to_client, ReqType::ResponseType::makeError("unknown request type"));
|
||||
}
|
||||
},
|
||||
req);
|
||||
} catch (const std::exception& e) {
|
||||
logr.error("fatal exception: {}", e.what());
|
||||
return;
|
||||
|
@ -161,6 +161,7 @@ void AdminServer::start() {
|
|||
boost::optional<std::string> AdminServer::createTenant(fdb::Database db, int id_begin, int id_end) {
|
||||
try {
|
||||
auto tx = db.createTransaction();
|
||||
auto stopwatch = Stopwatch(StartAtCtor{});
|
||||
logr.info("create_tenants [{}-{})", id_begin, id_end);
|
||||
while (true) {
|
||||
for (auto id = id_begin; id < id_end; id++) {
|
||||
|
@ -180,7 +181,8 @@ boost::optional<std::string> AdminServer::createTenant(fdb::Database db, int id_
|
|||
return fmt::format("create_tenants [{}:{}) failed with '{}'", id_begin, id_end, f.error().what());
|
||||
}
|
||||
}
|
||||
logr.info("create_tenants [{}-{}) OK", id_begin, id_end);
|
||||
logr.info("create_tenants [{}-{}) OK ({:.3f}s)", id_begin, id_end, toDoubleSeconds(stopwatch.stop().diff()));
|
||||
stopwatch.start();
|
||||
logr.info("blobbify_tenants [{}-{})", id_begin, id_end);
|
||||
for (auto id = id_begin; id < id_end; id++) {
|
||||
while (true) {
|
||||
|
@ -204,7 +206,7 @@ boost::optional<std::string> AdminServer::createTenant(fdb::Database db, int id_
|
|||
}
|
||||
}
|
||||
}
|
||||
logr.info("blobbify_tenants [{}-{}) OK", id_begin, id_end);
|
||||
logr.info("blobbify_tenants [{}-{}) OK ({:.3f}s)", id_begin, id_end, toDoubleSeconds(stopwatch.stop().diff()));
|
||||
return {};
|
||||
} catch (const std::exception& e) {
|
||||
return std::string(e.what());
|
||||
|
@ -307,12 +309,52 @@ boost::optional<std::string> AdminServer::deleteTenant(fdb::Database db, int id_
|
|||
}
|
||||
}
|
||||
|
||||
Response AdminServer::request(Request req) {
|
||||
// should always be invoked from client side (currently just the main process)
|
||||
assert(server_pid > 0);
|
||||
assert(logr.isFor(ProcKind::MAIN));
|
||||
sendObject(pipe_to_server, std::move(req));
|
||||
return receiveObject<Response>(pipe_to_client);
|
||||
TenantIdsResponse AdminServer::fetchTenantIds(fdb::Database db, int id_begin, int id_end) {
|
||||
try {
|
||||
logr.info("fetch_tenant_ids [{}:{})", id_begin, id_end);
|
||||
auto stopwatch = Stopwatch(StartAtCtor{});
|
||||
size_t const count = id_end - id_begin;
|
||||
std::vector<int64_t> ids(count);
|
||||
std::vector<std::tuple<fdb::TypedFuture<fdb::future_var::Int64>, fdb::Tenant, bool>> state(count);
|
||||
boost::optional<std::string> err_msg;
|
||||
for (auto idx = id_begin; idx < id_end; idx++) {
|
||||
auto& [future, tenant, done] = state[idx - id_begin];
|
||||
tenant = db.openTenant(fdb::toBytesRef(getTenantNameByIndex(idx)));
|
||||
future = tenant.getId();
|
||||
done = false;
|
||||
}
|
||||
while (true) {
|
||||
bool has_retries = false;
|
||||
for (auto idx = id_begin; idx < id_end; idx++) {
|
||||
auto& [future, tenant, done] = state[idx - id_begin];
|
||||
if (!done) {
|
||||
if (auto err = future.blockUntilReady()) {
|
||||
return TenantIdsResponse::makeError(
|
||||
fmt::format("error while waiting for tenant ID of tenant {}: {}", idx, err.what()));
|
||||
}
|
||||
if (auto err = future.error()) {
|
||||
if (err.retryable()) {
|
||||
logr.debug("retryable error while getting tenant ID of tenant {}: {}", idx, err.what());
|
||||
future = tenant.getId();
|
||||
has_retries = true;
|
||||
} else {
|
||||
return TenantIdsResponse::makeError(fmt::format(
|
||||
"unretryable error while getting tenant ID of tenant {}: {}", idx, err.what()));
|
||||
}
|
||||
} else {
|
||||
ids[idx - id_begin] = future.get();
|
||||
done = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!has_retries)
|
||||
break;
|
||||
}
|
||||
logr.info("fetch_tenant_ids [{}:{}) OK ({:.3f}s)", id_begin, id_end, toDoubleSeconds(stopwatch.stop().diff()));
|
||||
return TenantIdsResponse{ {}, std::move(ids) };
|
||||
} catch (const std::exception& e) {
|
||||
return TenantIdsResponse::makeError(fmt::format("unexpected exception: {}", e.what()));
|
||||
}
|
||||
}
|
||||
|
||||
AdminServer::~AdminServer() {
|
||||
|
|
|
@ -23,6 +23,13 @@
|
|||
#include <boost/process/pipe.hpp>
|
||||
#include <boost/optional.hpp>
|
||||
#include <boost/variant.hpp>
|
||||
#include <boost/archive/binary_iarchive.hpp>
|
||||
#include <boost/archive/binary_oarchive.hpp>
|
||||
#include <boost/serialization/optional.hpp>
|
||||
#include <boost/serialization/string.hpp>
|
||||
#include <boost/serialization/variant.hpp>
|
||||
#include <boost/serialization/vector.hpp>
|
||||
|
||||
#include <unistd.h>
|
||||
#include "fdb_api.hpp"
|
||||
#include "logger.hpp"
|
||||
|
@ -35,16 +42,32 @@ extern thread_local mako::Logger logr;
|
|||
// Therefore, order to benchmark for authorization
|
||||
namespace mako::ipc {
|
||||
|
||||
struct Response {
|
||||
struct DefaultResponse {
|
||||
boost::optional<std::string> error_message;
|
||||
|
||||
static DefaultResponse makeError(std::string msg) { return DefaultResponse{ msg }; }
|
||||
|
||||
template <class Ar>
|
||||
void serialize(Ar& ar, unsigned int) {
|
||||
ar& error_message;
|
||||
}
|
||||
};
|
||||
|
||||
struct TenantIdsResponse {
|
||||
boost::optional<std::string> error_message;
|
||||
std::vector<int64_t> ids;
|
||||
|
||||
static TenantIdsResponse makeError(std::string msg) { return TenantIdsResponse{ msg, {} }; }
|
||||
|
||||
template <class Ar>
|
||||
void serialize(Ar& ar, unsigned int) {
|
||||
ar& error_message;
|
||||
ar& ids;
|
||||
}
|
||||
};
|
||||
|
||||
struct BatchCreateTenantRequest {
|
||||
using ResponseType = DefaultResponse;
|
||||
std::string cluster_file;
|
||||
int id_begin = 0;
|
||||
int id_end = 0;
|
||||
|
@ -58,6 +81,7 @@ struct BatchCreateTenantRequest {
|
|||
};
|
||||
|
||||
struct BatchDeleteTenantRequest {
|
||||
using ResponseType = DefaultResponse;
|
||||
std::string cluster_file;
|
||||
int id_begin = 0;
|
||||
int id_end = 0;
|
||||
|
@ -70,17 +94,34 @@ struct BatchDeleteTenantRequest {
|
|||
}
|
||||
};
|
||||
|
||||
struct FetchTenantIdsRequest {
|
||||
using ResponseType = TenantIdsResponse;
|
||||
std::string cluster_file;
|
||||
int id_begin;
|
||||
int id_end;
|
||||
|
||||
template <class Ar>
|
||||
void serialize(Ar& ar, unsigned int) {
|
||||
ar& cluster_file;
|
||||
ar& id_begin;
|
||||
ar& id_end;
|
||||
}
|
||||
};
|
||||
|
||||
struct PingRequest {
|
||||
using ResponseType = DefaultResponse;
|
||||
template <class Ar>
|
||||
void serialize(Ar&, unsigned int) {}
|
||||
};
|
||||
|
||||
struct StopRequest {
|
||||
using ResponseType = DefaultResponse;
|
||||
template <class Ar>
|
||||
void serialize(Ar&, unsigned int) {}
|
||||
};
|
||||
|
||||
using Request = boost::variant<PingRequest, StopRequest, BatchCreateTenantRequest, BatchDeleteTenantRequest>;
|
||||
using Request =
|
||||
boost::variant<PingRequest, StopRequest, BatchCreateTenantRequest, BatchDeleteTenantRequest, FetchTenantIdsRequest>;
|
||||
|
||||
class AdminServer {
|
||||
const Arguments& args;
|
||||
|
@ -89,13 +130,32 @@ class AdminServer {
|
|||
boost::process::pstream pipe_to_client;
|
||||
void start();
|
||||
void configure();
|
||||
Response request(Request req);
|
||||
boost::optional<std::string> getTenantPrefixes(fdb::Transaction tx,
|
||||
int id_begin,
|
||||
int id_end,
|
||||
std::vector<fdb::ByteString>& out_prefixes);
|
||||
boost::optional<std::string> createTenant(fdb::Database db, int id_begin, int id_end);
|
||||
boost::optional<std::string> deleteTenant(fdb::Database db, int id_begin, int id_end);
|
||||
TenantIdsResponse fetchTenantIds(fdb::Database db, int id_begin, int id_end);
|
||||
|
||||
template <class T>
|
||||
static void sendObject(boost::process::pstream& pipe, T obj) {
|
||||
boost::archive::binary_oarchive oa(pipe);
|
||||
oa << obj;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
static T receiveObject(boost::process::pstream& pipe) {
|
||||
boost::archive::binary_iarchive ia(pipe);
|
||||
T obj;
|
||||
ia >> obj;
|
||||
return obj;
|
||||
}
|
||||
|
||||
template <class RequestType>
|
||||
static void sendResponse(boost::process::pstream& pipe, typename RequestType::ResponseType obj) {
|
||||
sendObject(pipe, std::move(obj));
|
||||
}
|
||||
|
||||
public:
|
||||
AdminServer(const Arguments& args)
|
||||
|
@ -107,9 +167,14 @@ public:
|
|||
// forks a server subprocess internally
|
||||
|
||||
bool isClient() const noexcept { return server_pid > 0; }
|
||||
|
||||
template <class T>
|
||||
Response send(T req) {
|
||||
return request(Request(std::forward<T>(req)));
|
||||
typename T::ResponseType send(T req) {
|
||||
// should always be invoked from client side (currently just the main process)
|
||||
assert(server_pid > 0);
|
||||
assert(logr.isFor(ProcKind::MAIN));
|
||||
sendObject(pipe_to_server, Request(std::move(req)));
|
||||
return receiveObject<typename T::ResponseType>(pipe_to_client);
|
||||
}
|
||||
|
||||
AdminServer(const AdminServer&) = delete;
|
||||
|
|
|
@ -90,10 +90,11 @@ using namespace mako;
|
|||
|
||||
thread_local Logger logr = Logger(MainProcess{}, VERBOSE_DEFAULT);
|
||||
|
||||
Transaction createNewTransaction(Database db, Arguments const& args, int id = -1, Tenant* tenants = nullptr) {
|
||||
std::pair<Transaction, std::optional<std::string> /*token*/>
|
||||
createNewTransaction(Database db, Arguments const& args, int id, std::optional<std::vector<Tenant>>& tenants) {
|
||||
// No tenants specified
|
||||
if (args.active_tenants <= 0) {
|
||||
return db.createTransaction();
|
||||
return { db.createTransaction(), {} };
|
||||
}
|
||||
// Create Tenant Transaction
|
||||
int tenant_id = (id == -1) ? urand(0, args.active_tenants - 1) : id;
|
||||
|
@ -101,7 +102,7 @@ Transaction createNewTransaction(Database db, Arguments const& args, int id = -1
|
|||
std::string tenant_name;
|
||||
// If provided tenants array, use it
|
||||
if (tenants) {
|
||||
tr = tenants[tenant_id].createTransaction();
|
||||
tr = (*tenants)[tenant_id].createTransaction();
|
||||
} else {
|
||||
tenant_name = getTenantNameByIndex(tenant_id);
|
||||
Tenant t = db.openTenant(toBytesRef(tenant_name));
|
||||
|
@ -120,7 +121,7 @@ Transaction createNewTransaction(Database db, Arguments const& args, int id = -1
|
|||
_exit(1);
|
||||
}
|
||||
}
|
||||
return tr;
|
||||
return { tr, { tenant_name } };
|
||||
}
|
||||
|
||||
int cleanupTenants(ipc::AdminServer& server, Arguments const& args, int db_id) {
|
||||
|
@ -197,14 +198,11 @@ int populate(Database db, const ThreadArgs& thread_args, int thread_tps, Workflo
|
|||
auto watch_trace = Stopwatch(watch_total.getStart());
|
||||
|
||||
// tenants are assumed to have been generated by populateTenants() at main process, pre-fork
|
||||
Tenant tenants[args.active_tenants];
|
||||
for (int i = 0; i < args.active_tenants; ++i) {
|
||||
tenants[i] = db.openTenant(toBytesRef(getTenantNameByIndex(i)));
|
||||
}
|
||||
std::optional<std::vector<Tenant>> tenants = args.prepareTenants(db);
|
||||
int populate_iters = args.active_tenants > 0 ? args.active_tenants : 1;
|
||||
// Each tenant should have the same range populated
|
||||
for (auto t_id = 0; t_id < populate_iters; ++t_id) {
|
||||
Transaction tx = createNewTransaction(db, args, t_id, args.active_tenants > 0 ? tenants : nullptr);
|
||||
auto [tx, token] = createNewTransaction(db, args, t_id, tenants);
|
||||
const auto key_begin = insertBegin(args.rows, process_idx, thread_idx, args.num_processes, args.num_threads);
|
||||
const auto key_end = insertEnd(args.rows, process_idx, thread_idx, args.num_processes, args.num_threads);
|
||||
auto key_checkpoint = key_begin; // in case of commit failure, restart from this key
|
||||
|
@ -262,7 +260,7 @@ int populate(Database db, const ThreadArgs& thread_args, int thread_tps, Workflo
|
|||
auto tx_restarter = ExitGuard([&watch_tx]() { watch_tx.startFromStop(); });
|
||||
if (rc == FutureRC::OK) {
|
||||
key_checkpoint = i + 1; // restart on failures from next key
|
||||
tx = createNewTransaction(db, args, t_id, args.active_tenants > 0 ? tenants : nullptr);
|
||||
std::tie(tx, token) = createNewTransaction(db, args, t_id, tenants);
|
||||
} else if (rc == FutureRC::ABORT) {
|
||||
return -1;
|
||||
} else {
|
||||
|
@ -306,6 +304,7 @@ void updateErrorStatsRunMode(WorkflowStatistics& stats, fdb::Error err, int op)
|
|||
|
||||
/* run one iteration of configured transaction */
|
||||
int runOneTransaction(Transaction& tx,
|
||||
std::optional<std::string> const& token,
|
||||
Arguments const& args,
|
||||
WorkflowStatistics& stats,
|
||||
ByteString& key1,
|
||||
|
@ -356,6 +355,8 @@ transaction_begin:
|
|||
stats.addLatency(OP_COMMIT, step_latency);
|
||||
}
|
||||
tx.reset();
|
||||
if (token)
|
||||
tx.setOption(FDB_TR_OPTION_AUTHORIZATION_TOKEN, *token);
|
||||
stats.incrOpCount(OP_COMMIT);
|
||||
needs_commit = false;
|
||||
}
|
||||
|
@ -444,12 +445,7 @@ int runWorkload(Database db,
|
|||
auto val = ByteString{};
|
||||
val.resize(args.value_length);
|
||||
|
||||
// mimic typical tenant usage: keep tenants in memory
|
||||
// and create transactions as needed
|
||||
Tenant tenants[args.active_tenants];
|
||||
for (int i = 0; i < args.active_tenants; ++i) {
|
||||
tenants[i] = db.openTenant(toBytesRef(getTenantNameByIndex(i)));
|
||||
}
|
||||
std::optional<std::vector<fdb::Tenant>> tenants = args.prepareTenants(db);
|
||||
|
||||
/* main transaction loop */
|
||||
while (1) {
|
||||
|
@ -470,7 +466,7 @@ int runWorkload(Database db,
|
|||
}
|
||||
|
||||
if (current_tps > 0 || thread_tps == 0 /* throttling off */) {
|
||||
Transaction tx = createNewTransaction(db, args, -1, args.active_tenants > 0 ? tenants : nullptr);
|
||||
auto [tx, token] = createNewTransaction(db, args, -1, tenants);
|
||||
setTransactionTimeoutIfEnabled(args, tx);
|
||||
|
||||
/* enable transaction trace */
|
||||
|
@ -507,7 +503,7 @@ int runWorkload(Database db,
|
|||
}
|
||||
}
|
||||
|
||||
rc = runOneTransaction(tx, args, workflow_stats, key1, key2, val);
|
||||
rc = runOneTransaction(tx, token, args, workflow_stats, key1, key2, val);
|
||||
if (rc) {
|
||||
logr.warn("runOneTransaction failed ({})", rc);
|
||||
}
|
||||
|
@ -580,7 +576,7 @@ void runAsyncWorkload(Arguments const& args,
|
|||
auto state =
|
||||
std::make_shared<ResumableStateForPopulate>(Logger(WorkerProcess{}, args.verbose, process_idx, i),
|
||||
db,
|
||||
createNewTransaction(db, args),
|
||||
db.createTransaction(),
|
||||
io_context,
|
||||
args,
|
||||
shm.workerStatsSlot(process_idx, i),
|
||||
|
@ -613,7 +609,7 @@ void runAsyncWorkload(Arguments const& args,
|
|||
auto state =
|
||||
std::make_shared<ResumableStateForRunWorkload>(Logger(WorkerProcess{}, args.verbose, process_idx, i),
|
||||
db,
|
||||
createNewTransaction(db, args),
|
||||
db.createTransaction(),
|
||||
io_context,
|
||||
args,
|
||||
shm.workerStatsSlot(process_idx, i),
|
||||
|
@ -1685,6 +1681,10 @@ int Arguments::validate() {
|
|||
}
|
||||
|
||||
if (enable_token_based_authorization) {
|
||||
if (num_fdb_clusters > 1) {
|
||||
logr.error("for simplicity, --enable_token_based_authorization must be used with exactly one fdb cluster");
|
||||
return -1;
|
||||
}
|
||||
if (active_tenants <= 0 || total_tenants <= 0) {
|
||||
logr.error("--enable_token_based_authorization must be used with at least one tenant");
|
||||
return -1;
|
||||
|
@ -1708,18 +1708,40 @@ bool Arguments::isAuthorizationEnabled() const noexcept {
|
|||
private_key_pem.has_value();
|
||||
}
|
||||
|
||||
void Arguments::collectTenantIds() {
|
||||
auto db = Database(cluster_files[0]);
|
||||
tenant_ids.clear();
|
||||
tenant_ids.reserve(active_tenants);
|
||||
}
|
||||
|
||||
void Arguments::generateAuthorizationTokens() {
|
||||
assert(active_tenants > 0);
|
||||
assert(keypair_id.has_value());
|
||||
assert(private_key_pem.has_value());
|
||||
authorization_tokens.clear();
|
||||
assert(num_fdb_clusters == 1);
|
||||
assert(!tenant_ids.empty());
|
||||
// assumes tenants have already been populated
|
||||
logr.info("generating authorization tokens to be used by worker threads");
|
||||
auto stopwatch = Stopwatch(StartAtCtor{});
|
||||
authorization_tokens = generateAuthorizationTokenMap(active_tenants, keypair_id.value(), private_key_pem.value());
|
||||
authorization_tokens =
|
||||
generateAuthorizationTokenMap(active_tenants, keypair_id.value(), private_key_pem.value(), tenant_ids);
|
||||
assert(authorization_tokens.size() == active_tenants);
|
||||
logr.info("generated {} tokens in {:6.3f} seconds", active_tenants, toDoubleSeconds(stopwatch.stop().diff()));
|
||||
}
|
||||
|
||||
std::optional<std::vector<fdb::Tenant>> Arguments::prepareTenants(fdb::Database db) const {
|
||||
if (active_tenants > 0) {
|
||||
std::vector<fdb::Tenant> tenants(active_tenants);
|
||||
for (auto i = 0; i < active_tenants; i++) {
|
||||
tenants[i] = db.openTenant(toBytesRef(getTenantNameByIndex(i)));
|
||||
}
|
||||
return tenants;
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
void printStats(Arguments const& args, WorkflowStatistics const* stats, double const duration_sec, FILE* fp) {
|
||||
static WorkflowStatistics prev;
|
||||
|
||||
|
@ -2503,10 +2525,6 @@ int main(int argc, char* argv[]) {
|
|||
|
||||
logr.setVerbosity(args.verbose);
|
||||
|
||||
if (args.isAuthorizationEnabled()) {
|
||||
args.generateAuthorizationTokens();
|
||||
}
|
||||
|
||||
if (args.mode == MODE_CLEAN) {
|
||||
/* cleanup will be done from a single thread */
|
||||
args.num_processes = 1;
|
||||
|
@ -2525,7 +2543,8 @@ int main(int argc, char* argv[]) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
if (args.total_tenants > 0 && (args.mode == MODE_BUILD || args.mode == MODE_CLEAN)) {
|
||||
if (args.total_tenants > 0 &&
|
||||
(args.isAuthorizationEnabled() || args.mode == MODE_BUILD || args.mode == MODE_CLEAN)) {
|
||||
|
||||
// below construction fork()s internally
|
||||
auto server = ipc::AdminServer(args);
|
||||
|
@ -2542,12 +2561,11 @@ int main(int argc, char* argv[]) {
|
|||
logr.info("admin server ready");
|
||||
}
|
||||
}
|
||||
// Use admin server to request tenant creation or deletion.
|
||||
// This is necessary when tenant authorization is enabled,
|
||||
// in which case the worker threads connect to database as untrusted clients,
|
||||
// as which they wouldn't be allowed to create/delete tenants on their own.
|
||||
// Although it is possible to allow worker threads to create/delete
|
||||
// tenants in a authorization-disabled mode, use the admin server anyway for simplicity.
|
||||
// Use admin server as proxy to creating/deleting tenants or pre-fetching tenant IDs for token signing when
|
||||
// authorization is enabled This is necessary when tenant authorization is enabled, in which case the worker
|
||||
// threads connect to database as untrusted clients, as which they wouldn't be allowed to create/delete tenants
|
||||
// on their own. Although it is possible to allow worker threads to create/delete tenants in a
|
||||
// authorization-disabled mode, use the admin server anyway for simplicity.
|
||||
if (args.mode == MODE_CLEAN) {
|
||||
// short-circuit tenant cleanup
|
||||
const auto num_dbs = std::min(args.num_fdb_clusters, args.num_databases);
|
||||
|
@ -2563,6 +2581,21 @@ int main(int argc, char* argv[]) {
|
|||
return -1;
|
||||
}
|
||||
}
|
||||
if ((args.mode == MODE_BUILD || args.mode == MODE_RUN) && args.isAuthorizationEnabled()) {
|
||||
assert(args.num_fdb_clusters == 1);
|
||||
// need to fetch tenant IDs to pre-generate tokens
|
||||
// fetch all IDs in one go
|
||||
auto res = server.send(ipc::FetchTenantIdsRequest{ args.cluster_files[0], 0, args.active_tenants });
|
||||
if (res.error_message) {
|
||||
logr.error("tenant ID fetch failed: {}", *res.error_message);
|
||||
return -1;
|
||||
} else {
|
||||
logr.info("Successfully prefetched {} tenant IDs", res.ids.size());
|
||||
assert(res.ids.size() == args.active_tenants);
|
||||
args.tenant_ids = std::move(res.ids);
|
||||
}
|
||||
args.generateAuthorizationTokens();
|
||||
}
|
||||
}
|
||||
|
||||
const auto pid_main = getpid();
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
#include <chrono>
|
||||
#include <list>
|
||||
#include <map>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
|
@ -143,10 +144,13 @@ constexpr const int MAX_REPORT_FILES = 200;
|
|||
struct Arguments {
|
||||
Arguments();
|
||||
int validate();
|
||||
void collectTenantIds();
|
||||
bool isAuthorizationEnabled() const noexcept;
|
||||
std::optional<std::vector<fdb::Tenant>> prepareTenants(fdb::Database db) const;
|
||||
void generateAuthorizationTokens();
|
||||
|
||||
// Needs to be called once per fdb-accessing process
|
||||
// Needs to be called once per fdb client process from a clean state:
|
||||
// i.e. no FDB API called
|
||||
int setGlobalOptions() const;
|
||||
bool isAnyTimeoutEnabled() const;
|
||||
|
||||
|
@ -206,6 +210,7 @@ struct Arguments {
|
|||
std::optional<std::string> keypair_id;
|
||||
std::optional<std::string> private_key_pem;
|
||||
std::map<std::string, std::string> authorization_tokens; // maps tenant name to token string
|
||||
std::vector<int64_t> tenant_ids; // maps tenant index to tenant id for signing tokens
|
||||
int transaction_timeout_db;
|
||||
int transaction_timeout_tx;
|
||||
};
|
||||
|
|
|
@ -28,7 +28,8 @@ namespace mako {
|
|||
|
||||
std::map<std::string, std::string> generateAuthorizationTokenMap(int num_tenants,
|
||||
std::string public_key_id,
|
||||
std::string private_key_pem) {
|
||||
std::string private_key_pem,
|
||||
const std::vector<int64_t>& tenant_ids) {
|
||||
std::map<std::string, std::string> m;
|
||||
auto t = authz::jwt::stdtypes::TokenSpec{};
|
||||
auto const now = toIntegerSeconds(std::chrono::system_clock::now().time_since_epoch());
|
||||
|
@ -40,14 +41,14 @@ std::map<std::string, std::string> generateAuthorizationTokenMap(int num_tenants
|
|||
t.issuedAtUnixTime = now;
|
||||
t.expiresAtUnixTime = now + 60 * 60 * 12; // Good for 12 hours
|
||||
t.notBeforeUnixTime = now - 60 * 5; // activated 5 mins ago
|
||||
const int tokenIdLen = 36; // UUID length
|
||||
auto tokenId = std::string(tokenIdLen, '\0');
|
||||
const int tokenid_len = 36; // UUID length
|
||||
auto tokenid = std::string(tokenid_len, '\0');
|
||||
for (auto i = 0; i < num_tenants; i++) {
|
||||
std::string tenant_name = getTenantNameByIndex(i);
|
||||
// swap out only the token ids and tenant names
|
||||
randomAlphanumString(tokenId.data(), tokenIdLen);
|
||||
t.tokenId = tokenId;
|
||||
t.tenants = std::vector<std::string>{ tenant_name };
|
||||
randomAlphanumString(tokenid.data(), tokenid_len);
|
||||
t.tokenId = tokenid;
|
||||
t.tenants = std::vector<int64_t>{ tenant_ids[i] };
|
||||
m[tenant_name] = authz::jwt::stdtypes::signToken(t, private_key_pem);
|
||||
}
|
||||
return m;
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <cassert>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "fdb_api.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
|
@ -28,7 +29,8 @@ namespace mako {
|
|||
|
||||
std::map<std::string, std::string> generateAuthorizationTokenMap(int tenants,
|
||||
std::string public_key_id,
|
||||
std::string private_key_pem);
|
||||
std::string private_key_pem,
|
||||
const std::vector<int64_t>& tenant_ids);
|
||||
|
||||
inline std::string getTenantNameByIndex(int index) {
|
||||
assert(index >= 0);
|
||||
|
|
|
@ -3312,7 +3312,6 @@ Reference<TransactionState> TransactionState::cloneAndReset(Reference<Transactio
|
|||
newState->startTime = startTime;
|
||||
newState->committedVersion = committedVersion;
|
||||
newState->conflictingKeys = conflictingKeys;
|
||||
newState->authToken = authToken;
|
||||
newState->tenantSet = tenantSet;
|
||||
|
||||
return newState;
|
||||
|
|
|
@ -684,6 +684,8 @@ struct GlobalConfigRefreshRequest {
|
|||
GlobalConfigRefreshRequest() {}
|
||||
explicit GlobalConfigRefreshRequest(Version lastKnown) : lastKnown(lastKnown) {}
|
||||
|
||||
bool verify() const noexcept { return true; }
|
||||
|
||||
template <class Ar>
|
||||
void serialize(Ar& ar) {
|
||||
serializer(ar, lastKnown, reply);
|
||||
|
|
|
@ -317,6 +317,9 @@ struct ProtocolInfoRequest {
|
|||
constexpr static FileIdentifier file_identifier = 13261233;
|
||||
ReplyPromise<ProtocolInfoReply> reply{ PeerCompatibilityPolicy{ RequirePeer::AtLeast,
|
||||
ProtocolVersion::withStableInterfaces() } };
|
||||
|
||||
bool verify() const noexcept { return true; }
|
||||
|
||||
template <class Ar>
|
||||
void serialize(Ar& ar) {
|
||||
serializer(ar, reply);
|
||||
|
|
|
@ -43,7 +43,7 @@ struct GrvProxyInterface {
|
|||
// committed)
|
||||
RequestStream<ReplyPromise<Void>> waitFailure; // reports heartbeat to master.
|
||||
RequestStream<struct GetHealthMetricsRequest> getHealthMetrics;
|
||||
RequestStream<struct GlobalConfigRefreshRequest> refreshGlobalConfig;
|
||||
PublicRequestStream<struct GlobalConfigRefreshRequest> refreshGlobalConfig;
|
||||
|
||||
UID id() const { return getConsistentReadVersion.getEndpoint().token; }
|
||||
std::string toString() const { return id().shortString(); }
|
||||
|
@ -60,7 +60,7 @@ struct GrvProxyInterface {
|
|||
RequestStream<ReplyPromise<Void>>(getConsistentReadVersion.getEndpoint().getAdjustedEndpoint(1));
|
||||
getHealthMetrics = RequestStream<struct GetHealthMetricsRequest>(
|
||||
getConsistentReadVersion.getEndpoint().getAdjustedEndpoint(2));
|
||||
refreshGlobalConfig = RequestStream<struct GlobalConfigRefreshRequest>(
|
||||
refreshGlobalConfig = PublicRequestStream<struct GlobalConfigRefreshRequest>(
|
||||
getConsistentReadVersion.getEndpoint().getAdjustedEndpoint(3));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -331,7 +331,8 @@ description is not currently required but encouraged.
|
|||
hidden="true"/>
|
||||
<Option name="authorization_token" code="2000"
|
||||
description="Attach given authorization token to the transaction such that subsequent tenant-aware requests are authorized"
|
||||
paramType="String" paramDescription="A JSON Web Token authorized to access data belonging to one or more tenants, indicated by 'tenants' claim of the token's payload."/>
|
||||
paramType="String" paramDescription="A JSON Web Token authorized to access data belonging to one or more tenants, indicated by 'tenants' claim of the token's payload."
|
||||
persistent="true"/>
|
||||
</Scope>
|
||||
|
||||
<!-- The enumeration values matter - do not change them without
|
||||
|
|
|
@ -1074,7 +1074,8 @@ ACTOR static void deliver(TransportData* self,
|
|||
if (receiver) {
|
||||
TraceEvent(SevWarnAlways, "AttemptedRPCToPrivatePrevented")
|
||||
.detail("From", peerAddress)
|
||||
.detail("Token", destination.token);
|
||||
.detail("Token", destination.token)
|
||||
.detail("Receiver", typeid(*receiver).name());
|
||||
ASSERT(!self->isLocalAddress(destination.getPrimaryAddress()));
|
||||
Reference<Peer> peer = self->getOrOpenPeer(destination.getPrimaryAddress());
|
||||
sendPacket(self,
|
||||
|
|
|
@ -22,6 +22,8 @@
|
|||
|
||||
#include "flow/actorcompiler.h" // has to be last include
|
||||
|
||||
using authz::TenantId;
|
||||
|
||||
template <class Key, class Value>
|
||||
class LRUCache {
|
||||
public:
|
||||
|
@ -132,28 +134,26 @@ TEST_CASE("/fdbrpc/authz/LRUCache") {
|
|||
|
||||
struct CacheEntry {
|
||||
Arena arena;
|
||||
VectorRef<TenantNameRef> tenants;
|
||||
VectorRef<TenantId> tenants;
|
||||
Optional<StringRef> tokenId;
|
||||
double expirationTime = 0.0;
|
||||
};
|
||||
|
||||
struct AuditEntry {
|
||||
NetworkAddress address;
|
||||
TenantId tenantId;
|
||||
Optional<Standalone<StringRef>> tokenId;
|
||||
explicit AuditEntry(NetworkAddress const& address, CacheEntry const& cacheEntry)
|
||||
: address(address),
|
||||
bool operator==(const AuditEntry& other) const noexcept = default;
|
||||
explicit AuditEntry(NetworkAddress const& address, TenantId tenantId, CacheEntry const& cacheEntry)
|
||||
: address(address), tenantId(tenantId),
|
||||
tokenId(cacheEntry.tokenId.present() ? Standalone<StringRef>(cacheEntry.tokenId.get(), cacheEntry.arena)
|
||||
: Optional<Standalone<StringRef>>()) {}
|
||||
};
|
||||
|
||||
bool operator==(AuditEntry const& lhs, AuditEntry const& rhs) {
|
||||
return (lhs.address == rhs.address) && (lhs.tokenId.present() == rhs.tokenId.present()) &&
|
||||
(!lhs.tokenId.present() || lhs.tokenId.get() == rhs.tokenId.get());
|
||||
}
|
||||
|
||||
std::size_t hash_value(AuditEntry const& value) {
|
||||
std::size_t seed = 0;
|
||||
boost::hash_combine(seed, value.address);
|
||||
boost::hash_combine(seed, value.tenantId);
|
||||
if (value.tokenId.present()) {
|
||||
boost::hash_combine(seed, value.tokenId.get());
|
||||
}
|
||||
|
@ -161,38 +161,17 @@ std::size_t hash_value(AuditEntry const& value) {
|
|||
}
|
||||
|
||||
struct TokenCacheImpl {
|
||||
TokenCacheImpl();
|
||||
LRUCache<StringRef, CacheEntry> cache;
|
||||
boost::unordered_set<AuditEntry> usedTokens;
|
||||
Future<Void> auditor;
|
||||
TokenCacheImpl();
|
||||
double lastResetTime;
|
||||
|
||||
bool validate(TenantNameRef tenant, StringRef token);
|
||||
bool validate(TenantId tenantId, StringRef token);
|
||||
bool validateAndAdd(double currentTime, StringRef token, NetworkAddress const& peer);
|
||||
void logTokenUsage(double currentTime, AuditEntry&& entry);
|
||||
};
|
||||
|
||||
ACTOR Future<Void> tokenCacheAudit(TokenCacheImpl* self) {
|
||||
state boost::unordered_set<AuditEntry> audits;
|
||||
state boost::unordered_set<AuditEntry>::iterator iter;
|
||||
state double lastLoggedTime = 0;
|
||||
loop {
|
||||
auto const timeSinceLog = g_network->timer() - lastLoggedTime;
|
||||
if (timeSinceLog < FLOW_KNOBS->AUDIT_TIME_WINDOW) {
|
||||
wait(delay(FLOW_KNOBS->AUDIT_TIME_WINDOW - timeSinceLog));
|
||||
}
|
||||
lastLoggedTime = g_network->timer();
|
||||
audits.swap(self->usedTokens);
|
||||
for (iter = audits.begin(); iter != audits.end(); ++iter) {
|
||||
CODE_PROBE(true, "Audit Logging Running");
|
||||
TraceEvent("AuditTokenUsed").detail("Client", iter->address).detail("TokenId", iter->tokenId).log();
|
||||
wait(yield());
|
||||
}
|
||||
audits.clear();
|
||||
}
|
||||
}
|
||||
|
||||
TokenCacheImpl::TokenCacheImpl() : cache(FLOW_KNOBS->TOKEN_CACHE_SIZE) {
|
||||
auditor = tokenCacheAudit(this);
|
||||
}
|
||||
TokenCacheImpl::TokenCacheImpl() : cache(FLOW_KNOBS->TOKEN_CACHE_SIZE), usedTokens(), lastResetTime(0) {}
|
||||
|
||||
TokenCache::TokenCache() : impl(new TokenCacheImpl()) {}
|
||||
TokenCache::~TokenCache() {
|
||||
|
@ -207,8 +186,8 @@ TokenCache& TokenCache::instance() {
|
|||
return *reinterpret_cast<TokenCache*>(g_network->global(INetwork::enTokenCache));
|
||||
}
|
||||
|
||||
bool TokenCache::validate(TenantNameRef name, StringRef token) {
|
||||
return impl->validate(name, token);
|
||||
bool TokenCache::validate(TenantId tenantId, StringRef token) {
|
||||
return impl->validate(tenantId, token);
|
||||
}
|
||||
|
||||
#define TRACE_INVALID_PARSED_TOKEN(reason, token) \
|
||||
|
@ -280,8 +259,8 @@ bool TokenCacheImpl::validateAndAdd(double currentTime, StringRef token, Network
|
|||
CacheEntry c;
|
||||
c.expirationTime = t.expiresAtUnixTime.get();
|
||||
c.tenants.reserve(c.arena, t.tenants.get().size());
|
||||
for (auto tenant : t.tenants.get()) {
|
||||
c.tenants.push_back_deep(c.arena, tenant);
|
||||
for (auto tenantId : t.tenants.get()) {
|
||||
c.tenants.push_back(c.arena, tenantId);
|
||||
}
|
||||
if (t.tokenId.present()) {
|
||||
c.tokenId = StringRef(c.arena, t.tokenId.get());
|
||||
|
@ -291,7 +270,7 @@ bool TokenCacheImpl::validateAndAdd(double currentTime, StringRef token, Network
|
|||
}
|
||||
}
|
||||
|
||||
bool TokenCacheImpl::validate(TenantNameRef name, StringRef token) {
|
||||
bool TokenCacheImpl::validate(TenantId tenantId, StringRef token) {
|
||||
NetworkAddress peer = FlowTransport::transport().currentDeliveryPeerAddress();
|
||||
auto cachedEntry = cache.get(token);
|
||||
double currentTime = g_network->timer();
|
||||
|
@ -314,21 +293,44 @@ bool TokenCacheImpl::validate(TenantNameRef name, StringRef token) {
|
|||
}
|
||||
bool tenantFound = false;
|
||||
for (auto const& t : entry->tenants) {
|
||||
if (t == name) {
|
||||
if (t == tenantId) {
|
||||
tenantFound = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!tenantFound) {
|
||||
CODE_PROBE(true, "Valid token doesn't reference tenant");
|
||||
TraceEvent(SevWarn, "TenantTokenMismatch").detail("From", peer).detail("Tenant", name.toString());
|
||||
TraceEvent(SevWarn, "TenantTokenMismatch")
|
||||
.detail("From", peer)
|
||||
.detail("RequestedTenant", fmt::format("{:#x}", tenantId))
|
||||
.detail("TenantsInToken", fmt::format("{:#x}", fmt::join(entry->tenants, " ")));
|
||||
return false;
|
||||
}
|
||||
// audit logging
|
||||
usedTokens.insert(AuditEntry(peer, *cachedEntry.get()));
|
||||
if (FLOW_KNOBS->AUDIT_LOGGING_ENABLED)
|
||||
logTokenUsage(currentTime, AuditEntry(peer, tenantId, *cachedEntry.get()));
|
||||
return true;
|
||||
}
|
||||
|
||||
void TokenCacheImpl::logTokenUsage(double currentTime, AuditEntry&& entry) {
|
||||
if (currentTime > lastResetTime + FLOW_KNOBS->AUDIT_TIME_WINDOW) {
|
||||
// clear usage cache every AUDIT_TIME_WINDOW seconds
|
||||
usedTokens.clear();
|
||||
lastResetTime = currentTime;
|
||||
}
|
||||
auto [iter, inserted] = usedTokens.insert(std::move(entry));
|
||||
if (inserted) {
|
||||
// access in the context of this (client_ip, tenant, token_id) tuple hasn't been logged in current window. log
|
||||
// usage.
|
||||
CODE_PROBE(true, "Audit Logging Running");
|
||||
TraceEvent("AuditTokenUsed")
|
||||
.detail("Client", iter->address)
|
||||
.detail("TenantId", fmt::format("{:#x}", iter->tenantId))
|
||||
.detail("TokenId", iter->tokenId)
|
||||
.log();
|
||||
}
|
||||
}
|
||||
|
||||
namespace authz::jwt {
|
||||
extern TokenRef makeRandomTokenSpec(Arena&, IRandom&, authz::Algorithm);
|
||||
}
|
||||
|
@ -375,9 +377,9 @@ TEST_CASE("/fdbrpc/authz/TokenCache/BadTokens") {
|
|||
},
|
||||
{
|
||||
[](Arena& arena, IRandom&, authz::jwt::TokenRef& token) {
|
||||
StringRef* newTenants = new (arena) StringRef[1];
|
||||
*newTenants = token.tenants.get()[0].substr(1);
|
||||
token.tenants = VectorRef<StringRef>(newTenants, 1);
|
||||
TenantId* newTenants = new (arena) TenantId[1];
|
||||
*newTenants = token.tenants.get()[0] + 1;
|
||||
token.tenants = VectorRef<TenantId>(newTenants, 1);
|
||||
},
|
||||
"UnmatchedTenant",
|
||||
},
|
||||
|
@ -443,15 +445,15 @@ TEST_CASE("/fdbrpc/authz/TokenCache/BadTokens") {
|
|||
}
|
||||
}
|
||||
}
|
||||
if (TokenCache::instance().validate("TenantNameDontMatterHere"_sr, StringRef())) {
|
||||
if (TokenCache::instance().validate(TenantInfo::INVALID_TENANT, StringRef())) {
|
||||
fmt::print("Unexpected successful validation of ill-formed token (no signature part)\n");
|
||||
ASSERT(false);
|
||||
}
|
||||
if (TokenCache::instance().validate("TenantNameDontMatterHere"_sr, "1111.22"_sr)) {
|
||||
if (TokenCache::instance().validate(TenantInfo::INVALID_TENANT, "1111.22"_sr)) {
|
||||
fmt::print("Unexpected successful validation of ill-formed token (no signature part)\n");
|
||||
ASSERT(false);
|
||||
}
|
||||
if (TokenCache::instance().validate("TenantNameDontMatterHere2"_sr, "////.////.////"_sr)) {
|
||||
if (TokenCache::instance().validate(TenantInfo::INVALID_TENANT, "////.////.////"_sr)) {
|
||||
fmt::print("Unexpected successful validation of unparseable token\n");
|
||||
ASSERT(false);
|
||||
}
|
||||
|
|
|
@ -52,11 +52,11 @@
|
|||
|
||||
namespace {
|
||||
|
||||
// test-only constants for generating random tenant/key names
|
||||
// test-only constants for generating random tenant ID and key names
|
||||
constexpr int MinIssuerNameLen = 16;
|
||||
constexpr int MaxIssuerNameLenPlus1 = 25;
|
||||
constexpr int MinTenantNameLen = 8;
|
||||
constexpr int MaxTenantNameLenPlus1 = 17;
|
||||
constexpr authz::TenantId MinTenantId = 1;
|
||||
constexpr authz::TenantId MaxTenantIdPlus1 = 0xffffffffll;
|
||||
constexpr int MinKeyNameLen = 10;
|
||||
constexpr int MaxKeyNameLenPlus1 = 21;
|
||||
|
||||
|
@ -176,6 +176,14 @@ void appendField(fmt::memory_buffer& b, char const (&name)[NameLen], Optional<Fi
|
|||
fmt::format_to(bi, fmt::runtime(f[i].toStringView()));
|
||||
}
|
||||
fmt::format_to(bi, "]");
|
||||
} else if constexpr (std::is_same_v<FieldType, VectorRef<TenantId>>) {
|
||||
fmt::format_to(bi, " {}=[", name);
|
||||
for (auto i = 0; i < f.size(); i++) {
|
||||
if (i)
|
||||
fmt::format_to(bi, ",");
|
||||
fmt::format_to(bi, "{:#x}", f[i]);
|
||||
}
|
||||
fmt::format_to(bi, "]");
|
||||
} else if constexpr (std::is_same_v<FieldType, StringRef>) {
|
||||
fmt::format_to(bi, " {}={}", name, f.toStringView());
|
||||
} else {
|
||||
|
@ -202,33 +210,34 @@ StringRef toStringRef(Arena& arena, const TokenRef& tokenSpec) {
|
|||
return StringRef(str, buf.size());
|
||||
}
|
||||
|
||||
template <class FieldType, class Writer, bool MakeStringArrayBase64 = false>
|
||||
void putField(Optional<FieldType> const& field,
|
||||
Writer& wr,
|
||||
const char* fieldName,
|
||||
std::bool_constant<MakeStringArrayBase64> _ = std::bool_constant<false>{}) {
|
||||
template <class FieldType, class Writer>
|
||||
void putField(Optional<FieldType> const& field, Writer& wr, const char* fieldName) {
|
||||
if (!field.present())
|
||||
return;
|
||||
wr.Key(fieldName);
|
||||
auto const& value = field.get();
|
||||
static_assert(std::is_same_v<StringRef, FieldType> || std::is_same_v<FieldType, uint64_t> ||
|
||||
std::is_same_v<FieldType, VectorRef<StringRef>>);
|
||||
std::is_same_v<FieldType, VectorRef<StringRef>> || std::is_same_v<FieldType, VectorRef<TenantId>>);
|
||||
if constexpr (std::is_same_v<StringRef, FieldType>) {
|
||||
wr.String(reinterpret_cast<const char*>(value.begin()), value.size());
|
||||
} else if constexpr (std::is_same_v<FieldType, uint64_t>) {
|
||||
wr.Uint64(value);
|
||||
} else if constexpr (std::is_same_v<FieldType, VectorRef<TenantId>>) {
|
||||
// "tenants" array = array of base64-encoded tenant key prefix
|
||||
// key prefix = bytestring representation of big-endian tenant ID (int64_t)
|
||||
Arena arena;
|
||||
wr.StartArray();
|
||||
for (auto elem : value) {
|
||||
auto const bigEndianId = bigEndian64(elem);
|
||||
auto encodedElem =
|
||||
base64::encode(arena, StringRef(reinterpret_cast<const uint8_t*>(&bigEndianId), sizeof(bigEndianId)));
|
||||
wr.String(reinterpret_cast<const char*>(encodedElem.begin()), encodedElem.size());
|
||||
}
|
||||
wr.EndArray();
|
||||
} else {
|
||||
wr.StartArray();
|
||||
if constexpr (MakeStringArrayBase64) {
|
||||
Arena arena;
|
||||
for (auto elem : value) {
|
||||
auto encodedElem = base64::encode(arena, elem);
|
||||
wr.String(reinterpret_cast<const char*>(encodedElem.begin()), encodedElem.size());
|
||||
}
|
||||
} else {
|
||||
for (auto elem : value) {
|
||||
wr.String(reinterpret_cast<const char*>(elem.begin()), elem.size());
|
||||
}
|
||||
for (auto elem : value) {
|
||||
wr.String(reinterpret_cast<const char*>(elem.begin()), elem.size());
|
||||
}
|
||||
wr.EndArray();
|
||||
}
|
||||
|
@ -259,7 +268,7 @@ StringRef makeSignInput(Arena& arena, const TokenRef& tokenSpec) {
|
|||
putField(tokenSpec.expiresAtUnixTime, payload, "exp");
|
||||
putField(tokenSpec.notBeforeUnixTime, payload, "nbf");
|
||||
putField(tokenSpec.tokenId, payload, "jti");
|
||||
putField(tokenSpec.tenants, payload, "tenants", std::bool_constant<true>{} /* encode tenants in base64 */);
|
||||
putField(tokenSpec.tenants, payload, "tenants");
|
||||
payload.EndObject();
|
||||
auto const headerPartLen = base64::url::encodedLength(headerBuffer.GetSize());
|
||||
auto const payloadPartLen = base64::url::encodedLength(payloadBuffer.GetSize());
|
||||
|
@ -347,18 +356,17 @@ Optional<StringRef> parseHeaderPart(Arena& arena, TokenRef& token, StringRef b64
|
|||
return {};
|
||||
}
|
||||
|
||||
template <class FieldType, bool ExpectBase64StringArray = false>
|
||||
template <class FieldType>
|
||||
Optional<StringRef> parseField(Arena& arena,
|
||||
Optional<FieldType>& out,
|
||||
const rapidjson::Document& d,
|
||||
const char* fieldName,
|
||||
std::bool_constant<ExpectBase64StringArray> _ = std::bool_constant<false>{}) {
|
||||
const char* fieldName) {
|
||||
auto fieldItr = d.FindMember(fieldName);
|
||||
if (fieldItr == d.MemberEnd())
|
||||
return {};
|
||||
auto const& field = fieldItr->value;
|
||||
static_assert(std::is_same_v<StringRef, FieldType> || std::is_same_v<FieldType, uint64_t> ||
|
||||
std::is_same_v<FieldType, VectorRef<StringRef>>);
|
||||
std::is_same_v<FieldType, VectorRef<StringRef>> || std::is_same_v<FieldType, VectorRef<TenantId>>);
|
||||
if constexpr (std::is_same_v<FieldType, StringRef>) {
|
||||
if (!field.IsString()) {
|
||||
return StringRef(arena, fmt::format("'{}' is not a string", fieldName));
|
||||
|
@ -369,7 +377,7 @@ Optional<StringRef> parseField(Arena& arena,
|
|||
return StringRef(arena, fmt::format("'{}' is not a number", fieldName));
|
||||
}
|
||||
out = static_cast<uint64_t>(field.GetDouble());
|
||||
} else {
|
||||
} else if constexpr (std::is_same_v<FieldType, VectorRef<StringRef>>) {
|
||||
if (!field.IsArray()) {
|
||||
return StringRef(arena, fmt::format("'{}' is not an array", fieldName));
|
||||
}
|
||||
|
@ -379,26 +387,50 @@ Optional<StringRef> parseField(Arena& arena,
|
|||
if (!field[i].IsString()) {
|
||||
return StringRef(arena, fmt::format("{}th element of '{}' is not a string", i + 1, fieldName));
|
||||
}
|
||||
if constexpr (ExpectBase64StringArray) {
|
||||
Optional<StringRef> decodedString = base64::decode(
|
||||
arena,
|
||||
StringRef(reinterpret_cast<const uint8_t*>(field[i].GetString()), field[i].GetStringLength()));
|
||||
if (decodedString.present()) {
|
||||
vector[i] = decodedString.get();
|
||||
} else {
|
||||
CODE_PROBE(true, "Base64 token field has failed to be parsed");
|
||||
return StringRef(arena,
|
||||
fmt::format("Failed to base64-decode {}th element of '{}'", i + 1, fieldName));
|
||||
}
|
||||
} else {
|
||||
vector[i] = StringRef(
|
||||
arena, reinterpret_cast<const uint8_t*>(field[i].GetString()), field[i].GetStringLength());
|
||||
}
|
||||
vector[i] = StringRef(
|
||||
arena, reinterpret_cast<const uint8_t*>(field[i].GetString()), field[i].GetStringLength());
|
||||
}
|
||||
out = VectorRef<StringRef>(vector, field.Size());
|
||||
} else {
|
||||
out = VectorRef<StringRef>();
|
||||
}
|
||||
} else {
|
||||
// tenant ids case: convert array of base64-encoded length-8 bytestring containing big-endian int64_t to
|
||||
// local-endian int64_t
|
||||
if (!field.IsArray()) {
|
||||
return StringRef(arena, fmt::format("'{}' is not an array", fieldName));
|
||||
}
|
||||
if (field.Size() > 0) {
|
||||
auto vector = new (arena) TenantId[field.Size()];
|
||||
for (auto i = 0; i < field.Size(); i++) {
|
||||
if (!field[i].IsString()) {
|
||||
return StringRef(arena, fmt::format("{}th element of '{}' is not a string", i + 1, fieldName));
|
||||
}
|
||||
Optional<StringRef> decodedString = base64::decode(
|
||||
arena,
|
||||
StringRef(reinterpret_cast<const uint8_t*>(field[i].GetString()), field[i].GetStringLength()));
|
||||
if (decodedString.present()) {
|
||||
auto const tenantPrefix = decodedString.get();
|
||||
if (tenantPrefix.size() != sizeof(TenantId)) {
|
||||
CODE_PROBE(true, "Tenant prefix has an invalid length");
|
||||
return StringRef(arena,
|
||||
fmt::format("{}th element of '{}' has an invalid bytewise length of {}",
|
||||
i + 1,
|
||||
fieldName,
|
||||
tenantPrefix.size()));
|
||||
}
|
||||
TenantId tenantId = *reinterpret_cast<const TenantId*>(tenantPrefix.begin());
|
||||
vector[i] = fromBigEndian64(tenantId);
|
||||
} else {
|
||||
CODE_PROBE(true, "Tenant field has failed to be parsed");
|
||||
return StringRef(arena,
|
||||
fmt::format("Failed to base64-decode {}th element of '{}'", i + 1, fieldName));
|
||||
}
|
||||
}
|
||||
out = VectorRef<TenantId>(vector, field.Size());
|
||||
} else {
|
||||
out = VectorRef<TenantId>();
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
@ -431,12 +463,7 @@ Optional<StringRef> parsePayloadPart(Arena& arena, TokenRef& token, StringRef b6
|
|||
return err;
|
||||
if ((err = parseField(arena, token.notBeforeUnixTime, d, "nbf")).present())
|
||||
return err;
|
||||
if ((err = parseField(arena,
|
||||
token.tenants,
|
||||
d,
|
||||
"tenants",
|
||||
std::bool_constant<true>{} /* expect field elements encoded in base64 */))
|
||||
.present())
|
||||
if ((err = parseField(arena, token.tenants, d, "tenants")).present())
|
||||
return err;
|
||||
return {};
|
||||
}
|
||||
|
@ -526,16 +553,16 @@ TokenRef makeRandomTokenSpec(Arena& arena, IRandom& rng, Algorithm alg) {
|
|||
auto numAudience = rng.randomInt(1, 5);
|
||||
auto aud = new (arena) StringRef[numAudience];
|
||||
for (auto i = 0; i < numAudience; i++)
|
||||
aud[i] = genRandomAlphanumStringRef(arena, rng, MinTenantNameLen, MaxTenantNameLenPlus1);
|
||||
aud[i] = genRandomAlphanumStringRef(arena, rng, MinIssuerNameLen, MaxIssuerNameLenPlus1);
|
||||
ret.audience = VectorRef<StringRef>(aud, numAudience);
|
||||
ret.issuedAtUnixTime = g_network->timer();
|
||||
ret.notBeforeUnixTime = ret.issuedAtUnixTime.get();
|
||||
ret.expiresAtUnixTime = ret.issuedAtUnixTime.get() + rng.randomInt(360, 1080 + 1);
|
||||
auto numTenants = rng.randomInt(1, 3);
|
||||
auto tenants = new (arena) StringRef[numTenants];
|
||||
auto tenants = new (arena) TenantId[numTenants];
|
||||
for (auto i = 0; i < numTenants; i++)
|
||||
tenants[i] = genRandomAlphanumStringRef(arena, rng, MinTenantNameLen, MaxTenantNameLenPlus1);
|
||||
ret.tenants = VectorRef<StringRef>(tenants, numTenants);
|
||||
tenants[i] = rng.randomInt64(MinTenantId, MaxTenantIdPlus1);
|
||||
ret.tenants = VectorRef<TenantId>(tenants, numTenants);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -584,8 +611,7 @@ TEST_CASE("/fdbrpc/TokenSign/JWT") {
|
|||
ASSERT(verifyOk);
|
||||
}
|
||||
// try tampering with signed token by adding one more tenant
|
||||
tokenSpec.tenants.get().push_back(
|
||||
arena, genRandomAlphanumStringRef(arena, rng, MinTenantNameLen, MaxTenantNameLenPlus1));
|
||||
tokenSpec.tenants.get().push_back(arena, rng.randomInt64(MinTenantId, MaxTenantIdPlus1));
|
||||
auto tamperedTokenPart = makeSignInput(arena, tokenSpec);
|
||||
auto tamperedTokenString = fmt::format("{}.{}", tamperedTokenPart.toString(), signaturePart.toString());
|
||||
std::tie(verifyOk, verifyErr) = authz::jwt::verifyToken(StringRef(tamperedTokenString), publicKey);
|
||||
|
@ -608,12 +634,12 @@ TEST_CASE("/fdbrpc/TokenSign/JWT/ToStringRef") {
|
|||
t.notBeforeUnixTime = 789ul;
|
||||
t.keyId = "keyId"_sr;
|
||||
t.tokenId = "tokenId"_sr;
|
||||
StringRef tenants[2]{ "tenant1"_sr, "tenant2"_sr };
|
||||
t.tenants = VectorRef<StringRef>(tenants, 2);
|
||||
authz::TenantId tenants[2]{ 0x1ll, 0xabcdefabcdefll };
|
||||
t.tenants = VectorRef<authz::TenantId>(tenants, 2);
|
||||
auto arena = Arena();
|
||||
auto tokenStr = toStringRef(arena, t);
|
||||
auto tokenStrExpected =
|
||||
"alg=ES256 kid=keyId iss=issuer sub=subject aud=[aud1,aud2,aud3] iat=123 exp=456 nbf=789 jti=tokenId tenants=[tenant1,tenant2]"_sr;
|
||||
"alg=ES256 kid=keyId iss=issuer sub=subject aud=[aud1,aud2,aud3] iat=123 exp=456 nbf=789 jti=tokenId tenants=[0x1,0xabcdefabcdef]"_sr;
|
||||
if (tokenStr != tokenStrExpected) {
|
||||
fmt::print("Expected: {}\nGot : {}\n", tokenStrExpected.toStringView(), tokenStr.toStringView());
|
||||
ASSERT(false);
|
||||
|
|
|
@ -29,38 +29,48 @@ namespace {
|
|||
// converts std::optional<STANDARD_TYPE(S)> to Optional<FLOW_TYPE(T)>
|
||||
template <class T, class S>
|
||||
void convertAndAssign(Arena& arena, Optional<T>& to, const std::optional<S>& from) {
|
||||
if (!from.has_value()) {
|
||||
to.reset();
|
||||
return;
|
||||
}
|
||||
if constexpr (std::is_same_v<S, std::vector<std::string>>) {
|
||||
static_assert(std::is_same_v<T, VectorRef<StringRef>>,
|
||||
"Source type std::vector<std::string> must convert to VectorRef<StringRef>");
|
||||
if (from.has_value()) {
|
||||
const auto& value = from.value();
|
||||
if (value.empty()) {
|
||||
to = VectorRef<StringRef>();
|
||||
} else {
|
||||
// no need to deep copy string because we have the underlying memory for the duration of token signing.
|
||||
auto buf = new (arena) StringRef[value.size()];
|
||||
for (auto i = 0u; i < value.size(); i++) {
|
||||
buf[i] = StringRef(value[i]);
|
||||
}
|
||||
to = VectorRef<StringRef>(buf, value.size());
|
||||
const auto& value = from.value();
|
||||
if (value.empty()) {
|
||||
to = VectorRef<StringRef>();
|
||||
} else {
|
||||
// no need to deep copy string because we have the underlying memory for the duration of token signing.
|
||||
auto buf = new (arena) StringRef[value.size()];
|
||||
for (auto i = 0u; i < value.size(); i++) {
|
||||
buf[i] = StringRef(value[i]);
|
||||
}
|
||||
to = VectorRef<StringRef>(buf, value.size());
|
||||
}
|
||||
} else if constexpr (std::is_same_v<S, std::vector<int64_t>>) {
|
||||
static_assert(std::is_same_v<T, VectorRef<int64_t>>,
|
||||
"Source type std::vector<int64_t> must convert to VectorRef<int64_t>");
|
||||
const auto& value = from.value();
|
||||
if (value.empty()) {
|
||||
to = VectorRef<int64_t>();
|
||||
} else {
|
||||
auto buf = new (arena) int64_t[value.size()];
|
||||
for (auto i = 0; i < value.size(); i++)
|
||||
buf[i] = value[i];
|
||||
to = VectorRef<int64_t>(buf, value.size());
|
||||
}
|
||||
} else if constexpr (std::is_same_v<S, std::string>) {
|
||||
static_assert(std::is_same_v<T, StringRef>, "Source type std::string must convert to StringRef");
|
||||
if (from.has_value()) {
|
||||
const auto& value = from.value();
|
||||
// no need to deep copy string because we have the underlying memory for the duration of token signing.
|
||||
to = StringRef(value);
|
||||
}
|
||||
const auto& value = from.value();
|
||||
// no need to deep copy string because we have the underlying memory for the duration of token signing.
|
||||
to = StringRef(value);
|
||||
} else {
|
||||
static_assert(
|
||||
std::is_same_v<S, T>,
|
||||
"Source types that aren't std::vector<std::string> or std::string must have the same destination type");
|
||||
static_assert(std::is_same_v<S, T>,
|
||||
"Source types that aren't std::vector<std::string>, std::vector<int64_t>, or std::string must "
|
||||
"have the same destination type");
|
||||
static_assert(std::is_trivially_copy_assignable_v<S>,
|
||||
"Source types that aren't std::vector<std::string> or std::string must not use heap memory");
|
||||
if (from.has_value()) {
|
||||
to = from.value();
|
||||
}
|
||||
to = from.value();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -76,9 +76,7 @@ struct serializable_traits<TenantInfo> : std::true_type {
|
|||
if constexpr (Archiver::isDeserializing) {
|
||||
bool tenantAuthorized = FLOW_KNOBS->ALLOW_TOKENLESS_TENANT_ACCESS;
|
||||
if (!tenantAuthorized && v.tenantId != TenantInfo::INVALID_TENANT && v.token.present()) {
|
||||
// TODO: update tokens to be ID based
|
||||
// tenantAuthorized = TokenCache::instance().validate(v.tenantId, v.token.get());
|
||||
tenantAuthorized = true;
|
||||
tenantAuthorized = TokenCache::instance().validate(v.tenantId, v.token.get());
|
||||
}
|
||||
v.trusted = FlowTransport::transport().currentDeliveryPeerIsTrusted();
|
||||
v.tenantAuthorized = tenantAuthorized;
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#ifndef TOKENCACHE_H_
|
||||
#define TOKENCACHE_H_
|
||||
#include "fdbrpc/TenantName.h"
|
||||
#include "fdbrpc/TokenSpec.h"
|
||||
#include "flow/Arena.h"
|
||||
|
||||
class TokenCache : NonCopyable {
|
||||
|
@ -31,7 +32,7 @@ public:
|
|||
~TokenCache();
|
||||
static void createInstance();
|
||||
static TokenCache& instance();
|
||||
bool validate(TenantNameRef tenant, StringRef token);
|
||||
bool validate(authz::TenantId tenant, StringRef token);
|
||||
};
|
||||
|
||||
#endif // TOKENCACHE_H_
|
||||
|
|
|
@ -28,6 +28,8 @@
|
|||
|
||||
namespace authz {
|
||||
|
||||
using TenantId = int64_t;
|
||||
|
||||
enum class Algorithm : int {
|
||||
RS256,
|
||||
ES256,
|
||||
|
@ -67,7 +69,7 @@ struct BasicTokenSpec {
|
|||
OptionalType<uint64_t> expiresAtUnixTime; // exp
|
||||
OptionalType<uint64_t> notBeforeUnixTime; // nbf
|
||||
OptionalType<StringType> tokenId; // jti
|
||||
OptionalType<VectorType<StringType>> tenants; // tenants
|
||||
OptionalType<VectorType<TenantId>> tenants; // tenants
|
||||
// signature part
|
||||
StringType signature;
|
||||
};
|
||||
|
|
|
@ -695,6 +695,10 @@ struct NetNotifiedQueue final : NotifiedQueue<T>, FlowReceiver, FastAllocated<Ne
|
|||
if (!message.verify()) {
|
||||
if constexpr (HasReply<T>) {
|
||||
message.reply.sendError(permission_denied());
|
||||
TraceEvent(SevWarnAlways, "UnauthorizedAccessPrevented")
|
||||
.detail("RequestType", typeid(T).name())
|
||||
.detail("ClientIP", FlowTransport::transport().currentDeliveryPeerAddress())
|
||||
.log();
|
||||
}
|
||||
} else {
|
||||
this->send(std::move(message));
|
||||
|
|
|
@ -346,7 +346,7 @@ public:
|
|||
double checkDisabled(const std::string& desc) const;
|
||||
|
||||
// generate authz token for use in simulation environment
|
||||
Standalone<StringRef> makeToken(StringRef tenantName, uint64_t ttlSecondsFromNow);
|
||||
Standalone<StringRef> makeToken(int64_t tenantId, uint64_t ttlSecondsFromNow);
|
||||
|
||||
static thread_local ProcessInfo* currentProcess;
|
||||
|
||||
|
|
|
@ -174,7 +174,7 @@ void ISimulator::displayWorkers() const {
|
|||
return;
|
||||
}
|
||||
|
||||
Standalone<StringRef> ISimulator::makeToken(StringRef tenantName, uint64_t ttlSecondsFromNow) {
|
||||
Standalone<StringRef> ISimulator::makeToken(int64_t tenantId, uint64_t ttlSecondsFromNow) {
|
||||
ASSERT_GT(authKeys.size(), 0);
|
||||
auto tokenSpec = authz::jwt::TokenRef{};
|
||||
auto [keyName, key] = *authKeys.begin();
|
||||
|
@ -188,7 +188,7 @@ Standalone<StringRef> ISimulator::makeToken(StringRef tenantName, uint64_t ttlSe
|
|||
tokenSpec.expiresAtUnixTime = now + ttlSecondsFromNow;
|
||||
auto const tokenId = deterministicRandom()->randomAlphaNumeric(10);
|
||||
tokenSpec.tokenId = StringRef(tokenId);
|
||||
tokenSpec.tenants = VectorRef<StringRef>(&tenantName, 1);
|
||||
tokenSpec.tenants = VectorRef<int64_t>(&tenantId, 1);
|
||||
auto ret = Standalone<StringRef>();
|
||||
ret.contents() = authz::jwt::signToken(ret.arena(), tokenSpec, key);
|
||||
return ret;
|
||||
|
|
|
@ -3511,7 +3511,7 @@ ACTOR Future<Void> monitorLeaderWithDelayedCandidacy(
|
|||
extern void setupStackSignal();
|
||||
|
||||
ACTOR Future<Void> serveProtocolInfo() {
|
||||
state RequestStream<ProtocolInfoRequest> protocolInfo(
|
||||
state PublicRequestStream<ProtocolInfoRequest> protocolInfo(
|
||||
PeerCompatibilityPolicy{ RequirePeer::AtLeast, ProtocolVersion::withStableInterfaces() });
|
||||
protocolInfo.makeWellKnownEndpoint(WLTOKEN_PROTOCOL_INFO, TaskPriority::DefaultEndpoint);
|
||||
loop {
|
||||
|
|
|
@ -46,6 +46,7 @@ struct AuthzSecurityWorkload : TestWorkload {
|
|||
std::vector<Future<Void>> clients;
|
||||
Arena arena;
|
||||
Reference<Tenant> tenant;
|
||||
Reference<Tenant> anotherTenant;
|
||||
TenantName tenantName;
|
||||
TenantName anotherTenantName;
|
||||
Standalone<StringRef> signedToken;
|
||||
|
@ -68,10 +69,6 @@ struct AuthzSecurityWorkload : TestWorkload {
|
|||
tLogConfigKey = getOption(options, "tLogConfigKey"_sr, "TLogInterface"_sr);
|
||||
ASSERT(g_network->isSimulated());
|
||||
// make it comfortably longer than the timeout of the workload
|
||||
signedToken = g_simulator->makeToken(
|
||||
tenantName, uint64_t(std::lround(getCheckTimeout())) + uint64_t(std::lround(testDuration)) + 100);
|
||||
signedTokenAnotherTenant = g_simulator->makeToken(
|
||||
anotherTenantName, uint64_t(std::lround(getCheckTimeout())) + uint64_t(std::lround(testDuration)) + 100);
|
||||
testFunctions.push_back(
|
||||
[this](Database cx) { return testCrossTenantGetDisallowed(this, cx, PositiveTestcase::True); });
|
||||
testFunctions.push_back(
|
||||
|
@ -87,10 +84,15 @@ struct AuthzSecurityWorkload : TestWorkload {
|
|||
|
||||
Future<Void> setup(Database const& cx) override {
|
||||
tenant = makeReference<Tenant>(cx, tenantName);
|
||||
return tenant->ready();
|
||||
anotherTenant = makeReference<Tenant>(cx, anotherTenantName);
|
||||
return tenant->ready() && anotherTenant->ready();
|
||||
}
|
||||
|
||||
Future<Void> start(Database const& cx) override {
|
||||
signedToken = g_simulator->makeToken(
|
||||
tenant->id(), uint64_t(std::lround(getCheckTimeout())) + uint64_t(std::lround(testDuration)) + 100);
|
||||
signedTokenAnotherTenant = g_simulator->makeToken(
|
||||
anotherTenant->id(), uint64_t(std::lround(getCheckTimeout())) + uint64_t(std::lround(testDuration)) + 100);
|
||||
for (int c = 0; c < actorCount; c++)
|
||||
clients.push_back(timeout(runTestClient(this, cx->clone()), testDuration, Void()));
|
||||
return waitForAll(clients);
|
||||
|
@ -128,9 +130,9 @@ struct AuthzSecurityWorkload : TestWorkload {
|
|||
StringRef key,
|
||||
StringRef value) {
|
||||
state Transaction tr(cx, tenant);
|
||||
self->setAuthToken(tr, token);
|
||||
loop {
|
||||
try {
|
||||
self->setAuthToken(tr, token);
|
||||
tr.set(key, value);
|
||||
wait(tr.commit());
|
||||
return tr.getCommittedVersion();
|
||||
|
@ -146,10 +148,10 @@ struct AuthzSecurityWorkload : TestWorkload {
|
|||
Standalone<StringRef> token,
|
||||
StringRef key) {
|
||||
state Transaction tr(cx, tenant);
|
||||
self->setAuthToken(tr, token);
|
||||
loop {
|
||||
try {
|
||||
// trigger GetKeyServerLocationsRequest and subsequent cache update
|
||||
self->setAuthToken(tr, token);
|
||||
wait(success(tr.get(key)));
|
||||
auto loc = cx->getCachedLocation(tr.trState->getTenantInfo(), key);
|
||||
if (loc.present()) {
|
||||
|
@ -342,10 +344,10 @@ struct AuthzSecurityWorkload : TestWorkload {
|
|||
state Version committedVersion =
|
||||
wait(setAndCommitKeyValueAndGetVersion(self, cx, self->tenant, self->signedToken, key, value));
|
||||
state Transaction tr(cx, self->tenant);
|
||||
self->setAuthToken(tr, self->signedToken);
|
||||
state Optional<Value> tLogConfigString;
|
||||
loop {
|
||||
try {
|
||||
self->setAuthToken(tr, self->signedToken);
|
||||
Optional<Value> value = wait(tr.get(self->tLogConfigKey));
|
||||
ASSERT(value.present());
|
||||
tLogConfigString = value;
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "flow/serialize.h"
|
||||
#include "fdbrpc/simulator.h"
|
||||
#include "fdbrpc/TokenSign.h"
|
||||
#include "fdbrpc/TenantInfo.h"
|
||||
#include "fdbclient/FDBOptions.g.h"
|
||||
#include "fdbclient/NativeAPI.actor.h"
|
||||
#include "fdbserver/TesterInterface.actor.h"
|
||||
|
@ -39,13 +40,20 @@ template <>
|
|||
struct CycleMembers<true> {
|
||||
Arena arena;
|
||||
TenantName tenant;
|
||||
int64_t tenantId;
|
||||
Standalone<StringRef> signedToken;
|
||||
bool useToken;
|
||||
};
|
||||
|
||||
template <bool>
|
||||
struct CycleWorkload;
|
||||
|
||||
ACTOR Future<Void> prepareToken(Database cx, CycleWorkload<true>* self);
|
||||
|
||||
template <bool MultiTenancy>
|
||||
struct CycleWorkload : TestWorkload, CycleMembers<MultiTenancy> {
|
||||
static constexpr auto NAME = MultiTenancy ? "TenantCycle" : "Cycle";
|
||||
static constexpr auto TenantEnabled = MultiTenancy;
|
||||
int actorCount, nodeCount;
|
||||
double testDuration, transactionsPerSecond, minExpectedTransactionsPerSecond, traceParentProbability;
|
||||
Key keyPrefix;
|
||||
|
@ -68,17 +76,18 @@ struct CycleWorkload : TestWorkload, CycleMembers<MultiTenancy> {
|
|||
ASSERT(g_network->isSimulated());
|
||||
this->useToken = getOption(options, "useToken"_sr, true);
|
||||
this->tenant = getOption(options, "tenant"_sr, "CycleTenant"_sr);
|
||||
// make it comfortably longer than the timeout of the workload
|
||||
this->signedToken = g_simulator->makeToken(
|
||||
this->tenant, uint64_t(std::lround(getCheckTimeout())) + uint64_t(std::lround(testDuration)) + 100);
|
||||
this->tenantId = TenantInfo::INVALID_TENANT;
|
||||
}
|
||||
}
|
||||
|
||||
Future<Void> setup(Database const& cx) override {
|
||||
Future<Void> prepare;
|
||||
if constexpr (MultiTenancy) {
|
||||
cx->defaultTenant = this->tenant;
|
||||
prepare = prepareToken(cx, this);
|
||||
} else {
|
||||
prepare = Void();
|
||||
}
|
||||
return bulkSetup(cx, this, nodeCount, Promise<double>());
|
||||
return runAfter(prepare, [this, cx](Void) { return bulkSetup(cx, this, nodeCount, Promise<double>()); });
|
||||
}
|
||||
Future<Void> start(Database const& cx) override {
|
||||
if constexpr (MultiTenancy) {
|
||||
|
@ -144,7 +153,6 @@ struct CycleWorkload : TestWorkload, CycleMembers<MultiTenancy> {
|
|||
state double tstart = now();
|
||||
state int r = deterministicRandom()->randomInt(0, self->nodeCount);
|
||||
state Transaction tr(cx);
|
||||
self->setAuthToken(tr);
|
||||
if (deterministicRandom()->random01() <= self->traceParentProbability) {
|
||||
state Span span("CycleClient"_loc);
|
||||
TraceEvent("CycleTracingTransaction", span.context.traceID).log();
|
||||
|
@ -153,6 +161,7 @@ struct CycleWorkload : TestWorkload, CycleMembers<MultiTenancy> {
|
|||
}
|
||||
while (true) {
|
||||
try {
|
||||
self->setAuthToken(tr);
|
||||
// Reverse next and next^2 node
|
||||
Optional<Value> v = wait(tr.get(self->key(r)));
|
||||
if (!v.present())
|
||||
|
@ -291,9 +300,9 @@ struct CycleWorkload : TestWorkload, CycleMembers<MultiTenancy> {
|
|||
// One client checks the validity of the cycle
|
||||
state Transaction tr(cx);
|
||||
state int retryCount = 0;
|
||||
self->setAuthToken(tr);
|
||||
loop {
|
||||
try {
|
||||
self->setAuthToken(tr);
|
||||
state Version v = wait(tr.getReadVersion());
|
||||
RangeResult data = wait(tr.getRange(firstGreaterOrEqual(doubleToTestKey(0.0, self->keyPrefix)),
|
||||
firstGreaterOrEqual(doubleToTestKey(1.0, self->keyPrefix)),
|
||||
|
@ -316,5 +325,17 @@ struct CycleWorkload : TestWorkload, CycleMembers<MultiTenancy> {
|
|||
}
|
||||
};
|
||||
|
||||
ACTOR Future<Void> prepareToken(Database cx, CycleWorkload<true>* self) {
|
||||
cx->defaultTenant = self->tenant;
|
||||
int64_t tenantId = wait(cx->lookupTenant(self->tenant));
|
||||
self->tenantId = tenantId;
|
||||
ASSERT_NE(self->tenantId, TenantInfo::INVALID_TENANT);
|
||||
// make the lifetime comfortably longer than the timeout of the workload
|
||||
self->signedToken = g_simulator->makeToken(self->tenantId,
|
||||
uint64_t(std::lround(self->getCheckTimeout())) +
|
||||
uint64_t(std::lround(self->testDuration)) + 100);
|
||||
return Void();
|
||||
}
|
||||
|
||||
WorkloadFactory<CycleWorkload<false>> CycleWorkloadFactory(UntrustedMode::False);
|
||||
WorkloadFactory<CycleWorkload<true>> TenantCycleWorkloadFactory(UntrustedMode::True);
|
||||
|
|
|
@ -140,10 +140,11 @@ void FlowKnobs::initialize(Randomize randomize, IsSimulated isSimulated) {
|
|||
|
||||
//Authorization
|
||||
init( ALLOW_TOKENLESS_TENANT_ACCESS, false );
|
||||
init( AUDIT_LOGGING_ENABLED, true );
|
||||
init( PUBLIC_KEY_FILE_MAX_SIZE, 1024 * 1024 );
|
||||
init( PUBLIC_KEY_FILE_REFRESH_INTERVAL_SECONDS, 30 );
|
||||
init( MAX_CACHED_EXPIRED_TOKENS, 1024 );
|
||||
init( AUDIT_TIME_WINDOW, 5.0 );
|
||||
init( TOKEN_CACHE_SIZE, 2000 );
|
||||
|
||||
//AsyncFileCached
|
||||
init( PAGE_CACHE_4K, 2LL<<30 );
|
||||
|
@ -307,14 +308,12 @@ void FlowKnobs::initialize(Randomize randomize, IsSimulated isSimulated) {
|
|||
if ( randomize && BUGGIFY) { ENCRYPT_CIPHER_KEY_CACHE_TTL = deterministicRandom()->randomInt(2, 10) * 60; }
|
||||
init( ENCRYPT_KEY_REFRESH_INTERVAL, isSimulated ? 60 : 8 * 60 );
|
||||
if ( randomize && BUGGIFY) { ENCRYPT_KEY_REFRESH_INTERVAL = deterministicRandom()->randomInt(2, 10); }
|
||||
init( TOKEN_CACHE_SIZE, 100 );
|
||||
init( ENCRYPT_KEY_CACHE_LOGGING_INTERVAL, 5.0 );
|
||||
init( ENCRYPT_KEY_CACHE_LOGGING_SKETCH_ACCURACY, 0.01 );
|
||||
// Refer to EncryptUtil::EncryptAuthTokenAlgo for more details
|
||||
init( ENCRYPT_HEADER_AUTH_TOKEN_ENABLED, false ); if ( randomize && BUGGIFY ) { ENCRYPT_HEADER_AUTH_TOKEN_ENABLED = !ENCRYPT_HEADER_AUTH_TOKEN_ENABLED; }
|
||||
init( ENCRYPT_HEADER_AUTH_TOKEN_ALGO, 0 ); if ( randomize && ENCRYPT_HEADER_AUTH_TOKEN_ENABLED ) { ENCRYPT_HEADER_AUTH_TOKEN_ALGO = getRandomAuthTokenAlgo(); }
|
||||
|
||||
|
||||
// REST Client
|
||||
init( RESTCLIENT_MAX_CONNECTIONPOOL_SIZE, 10 );
|
||||
init( RESTCLIENT_CONNECT_TRIES, 10 );
|
||||
|
|
|
@ -204,10 +204,11 @@ public:
|
|||
|
||||
// Authorization
|
||||
bool ALLOW_TOKENLESS_TENANT_ACCESS;
|
||||
bool AUDIT_LOGGING_ENABLED;
|
||||
int PUBLIC_KEY_FILE_MAX_SIZE;
|
||||
int PUBLIC_KEY_FILE_REFRESH_INTERVAL_SECONDS;
|
||||
int MAX_CACHED_EXPIRED_TOKENS;
|
||||
double AUDIT_TIME_WINDOW;
|
||||
int TOKEN_CACHE_SIZE;
|
||||
|
||||
// AsyncFileCached
|
||||
int64_t PAGE_CACHE_4K;
|
||||
|
@ -375,9 +376,6 @@ public:
|
|||
bool ENCRYPT_HEADER_AUTH_TOKEN_ENABLED;
|
||||
int ENCRYPT_HEADER_AUTH_TOKEN_ALGO;
|
||||
|
||||
// Authorization
|
||||
int TOKEN_CACHE_SIZE;
|
||||
|
||||
// RESTClient
|
||||
int RESTCLIENT_MAX_CONNECTIONPOOL_SIZE;
|
||||
int RESTCLIENT_CONNECT_TRIES;
|
||||
|
|
|
@ -126,7 +126,7 @@ if(WITH_PYTHON)
|
|||
add_fdb_test(TEST_FILES fast/AtomicBackupToDBCorrectness.toml)
|
||||
add_fdb_test(TEST_FILES fast/AtomicOps.toml)
|
||||
add_fdb_test(TEST_FILES fast/AtomicOpsApiCorrectness.toml)
|
||||
add_fdb_test(TEST_FILES fast/AuthzSecurity.toml IGNORE) # TODO re-enable once authz uses ID tokens
|
||||
add_fdb_test(TEST_FILES fast/AuthzSecurity.toml)
|
||||
add_fdb_test(TEST_FILES fast/AutomaticIdempotency.toml)
|
||||
add_fdb_test(TEST_FILES fast/BackupAzureBlobCorrectness.toml IGNORE)
|
||||
add_fdb_test(TEST_FILES fast/BackupS3BlobCorrectness.toml IGNORE)
|
||||
|
@ -504,15 +504,37 @@ if(WITH_PYTHON)
|
|||
set_tests_properties(authorization_venv_setup PROPERTIES FIXTURES_SETUP authz_virtual_env TIMEOUT 60)
|
||||
|
||||
set(authz_script_dir ${CMAKE_SOURCE_DIR}/tests/authorization)
|
||||
set(authz_test_cmd "${authz_venv_activate} && pytest ${authz_script_dir}/authz_test.py -rA --build-dir ${CMAKE_BINARY_DIR} -vvv")
|
||||
# TODO: reenable when authz is updated to validate based on tenant IDs
|
||||
#add_test(
|
||||
#NAME token_based_tenant_authorization
|
||||
#WORKING_DIRECTORY ${authz_venv_dir}
|
||||
#COMMAND bash -c ${authz_test_cmd})
|
||||
#set_tests_properties(token_based_tenant_authorization PROPERTIES ENVIRONMENT "PYTHONPATH=${CMAKE_SOURCE_DIR}/tests/TestRunner;${ld_env_name}=${CMAKE_BINARY_DIR}/lib")
|
||||
#set_tests_properties(token_based_tenant_authorization PROPERTIES FIXTURES_REQUIRED authz_virtual_env)
|
||||
#set_tests_properties(token_based_tenant_authorization PROPERTIES TIMEOUT 120)
|
||||
set(enable_grv_cache 0 1)
|
||||
set(force_mvc 0 1)
|
||||
foreach(is_grv_cache_enabled IN LISTS enable_grv_cache)
|
||||
foreach(is_mvc_forced IN LISTS force_mvc)
|
||||
if(NOT is_mvc_forced AND is_grv_cache_enabled)
|
||||
continue() # grv cache requires setting up of shared database state which is only available in MVC mode
|
||||
endif()
|
||||
set(authz_test_name "authz")
|
||||
set(test_opt "")
|
||||
if(is_grv_cache_enabled)
|
||||
string(APPEND test_opt " --use-grv-cache")
|
||||
string(APPEND authz_test_name "_with_grv_cache")
|
||||
else()
|
||||
string(APPEND authz_test_name "_no_grv_cache")
|
||||
endif()
|
||||
if(is_mvc_forced)
|
||||
string(APPEND test_opt " --force-multi-version-client")
|
||||
string(APPEND authz_test_name "_with_forced_mvc")
|
||||
else()
|
||||
string(APPEND authz_test_name "_no_forced_mvc")
|
||||
endif()
|
||||
set(authz_test_cmd "${authz_venv_activate} && pytest ${authz_script_dir}/authz_test.py -rA --build-dir ${CMAKE_BINARY_DIR} -vvv${test_opt}")
|
||||
add_test(
|
||||
NAME ${authz_test_name}
|
||||
WORKING_DIRECTORY ${authz_venv_dir}
|
||||
COMMAND bash -c ${authz_test_cmd})
|
||||
set_tests_properties(${authz_test_name} PROPERTIES ENVIRONMENT "PYTHONPATH=${CMAKE_SOURCE_DIR}/tests/TestRunner;${ld_env_name}=${CMAKE_BINARY_DIR}/lib")
|
||||
set_tests_properties(${authz_test_name} PROPERTIES FIXTURES_REQUIRED authz_virtual_env)
|
||||
set_tests_properties(${authz_test_name} PROPERTIES TIMEOUT 120)
|
||||
endforeach()
|
||||
endforeach()
|
||||
endif()
|
||||
else()
|
||||
message(WARNING "Python not found, won't configure ctest")
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from authlib.jose import JsonWebKey, KeySet, jwt
|
||||
from typing import List
|
||||
import json
|
||||
import time
|
||||
|
||||
def private_key_gen(kty: str, kid: str):
|
||||
assert kty == "EC" or kty == "RSA"
|
||||
|
@ -29,3 +30,17 @@ def token_gen(private_key, claims, headers={}):
|
|||
"kid": private_key.kid,
|
||||
}
|
||||
return jwt.encode(headers, claims, private_key)
|
||||
|
||||
def token_claim_1h(tenant_id: int):
|
||||
# JWT claim that is valid for 1 hour since time of invocation
|
||||
now = time.time()
|
||||
return {
|
||||
"iss": "fdb-authz-tester",
|
||||
"sub": "authz-test",
|
||||
"aud": ["tmp-cluster"],
|
||||
"iat": now,
|
||||
"nbf": now - 1,
|
||||
"exp": now + 60 * 60,
|
||||
"jti": random_alphanum_str(10),
|
||||
"tenants": [to_str(base64.b64encode(tenant_id.to_bytes(8, "big")))],
|
||||
}
|
||||
|
|
|
@ -30,7 +30,9 @@ class PortProvider:
|
|||
while True:
|
||||
counter += 1
|
||||
if counter > MAX_PORT_ACQUIRE_ATTEMPTS:
|
||||
assert False, "Failed to acquire a free port after {} attempts".format(MAX_PORT_ACQUIRE_ATTEMPTS)
|
||||
assert False, "Failed to acquire a free port after {} attempts".format(
|
||||
MAX_PORT_ACQUIRE_ATTEMPTS
|
||||
)
|
||||
port = PortProvider._get_free_port_internal()
|
||||
if port in self._used_ports:
|
||||
continue
|
||||
|
@ -42,7 +44,12 @@ class PortProvider:
|
|||
self._used_ports.add(port)
|
||||
return port
|
||||
except OSError:
|
||||
print("Failed to lock file {}. Trying to aquire another port".format(lock_path), file=sys.stderr)
|
||||
print(
|
||||
"Failed to lock file {}. Trying to aquire another port".format(
|
||||
lock_path
|
||||
),
|
||||
file=sys.stderr,
|
||||
)
|
||||
pass
|
||||
|
||||
def is_port_in_use(port):
|
||||
|
@ -59,10 +66,11 @@ class PortProvider:
|
|||
fd.close()
|
||||
try:
|
||||
os.remove(fd.name)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
self._lock_files.clear()
|
||||
|
||||
|
||||
class TLSConfig:
|
||||
# Passing a negative chain length generates expired leaf certificate
|
||||
def __init__(
|
||||
|
@ -75,6 +83,7 @@ class TLSConfig:
|
|||
self.client_chain_len = client_chain_len
|
||||
self.verify_peers = verify_peers
|
||||
|
||||
|
||||
class LocalCluster:
|
||||
configuration_template = """
|
||||
## foundationdb.conf
|
||||
|
@ -170,8 +179,12 @@ logdir = {logdir}
|
|||
|
||||
if self.first_port is not None:
|
||||
self.last_used_port = int(self.first_port) - 1
|
||||
self.server_ports = {server_id: self.__next_port() for server_id in range(self.process_number)}
|
||||
self.server_by_port = {port: server_id for server_id, port in self.server_ports.items()}
|
||||
self.server_ports = {
|
||||
server_id: self.__next_port() for server_id in range(self.process_number)
|
||||
}
|
||||
self.server_by_port = {
|
||||
port: server_id for server_id, port in self.server_ports.items()
|
||||
}
|
||||
self.next_server_id = self.process_number
|
||||
self.cluster_desc = random_alphanum_string(8)
|
||||
self.cluster_secret = random_alphanum_string(8)
|
||||
|
@ -198,16 +211,23 @@ logdir = {logdir}
|
|||
self.client_ca_file = self.cert.joinpath("client_ca.pem")
|
||||
|
||||
if self.authorization_kty:
|
||||
assert self.authorization_keypair_id, "keypair ID must be set to enable authorization"
|
||||
assert (
|
||||
self.authorization_keypair_id
|
||||
), "keypair ID must be set to enable authorization"
|
||||
self.public_key_json_file = self.etc.joinpath("public_keys.json")
|
||||
self.private_key = private_key_gen(
|
||||
kty=self.authorization_kty, kid=self.authorization_keypair_id)
|
||||
kty=self.authorization_kty, kid=self.authorization_keypair_id
|
||||
)
|
||||
self.public_key_jwks_str = public_keyset_from_keys([self.private_key])
|
||||
with open(self.public_key_json_file, "w") as pubkeyfile:
|
||||
pubkeyfile.write(self.public_key_jwks_str)
|
||||
self.authorization_private_key_pem_file = self.etc.joinpath("authorization_private_key.pem")
|
||||
self.authorization_private_key_pem_file = self.etc.joinpath(
|
||||
"authorization_private_key.pem"
|
||||
)
|
||||
with open(self.authorization_private_key_pem_file, "w") as privkeyfile:
|
||||
privkeyfile.write(self.private_key.as_pem(is_private=True).decode("utf8"))
|
||||
privkeyfile.write(
|
||||
self.private_key.as_pem(is_private=True).decode("utf8")
|
||||
)
|
||||
|
||||
if create_config:
|
||||
self.create_cluster_file()
|
||||
|
@ -246,7 +266,10 @@ logdir = {logdir}
|
|||
authz_public_key_config=self.authz_public_key_conf_string(),
|
||||
optional_tls=":tls" if self.tls_config is not None else "",
|
||||
custom_config="\n".join(
|
||||
["{} = {}".format(key, value) for key, value in self.custom_config.items()]
|
||||
[
|
||||
"{} = {}".format(key, value)
|
||||
for key, value in self.custom_config.items()
|
||||
]
|
||||
),
|
||||
use_future_protocol_version="use-future-protocol-version = true"
|
||||
if self.use_future_protocol_version
|
||||
|
@ -259,7 +282,11 @@ logdir = {logdir}
|
|||
# Then 4000,4001,4002,4003,4004 will be used as ports
|
||||
# If port number is not given, we will randomly pick free ports
|
||||
for server_id in self.active_servers:
|
||||
f.write("[fdbserver.{server_port}]\n".format(server_port=self.server_ports[server_id]))
|
||||
f.write(
|
||||
"[fdbserver.{server_port}]\n".format(
|
||||
server_port=self.server_ports[server_id]
|
||||
)
|
||||
)
|
||||
if self.use_legacy_conf_syntax:
|
||||
f.write("machine_id = {}\n".format(server_id))
|
||||
else:
|
||||
|
@ -353,8 +380,12 @@ logdir = {logdir}
|
|||
]
|
||||
if self.use_future_protocol_version:
|
||||
args += ["--use-future-protocol-version"]
|
||||
res = subprocess.run(args, env=self.process_env(), stderr=stderr, stdout=stdout, timeout=timeout)
|
||||
assert res.returncode == 0, "fdbcli command {} failed with {}".format(cmd, res.returncode)
|
||||
res = subprocess.run(
|
||||
args, env=self.process_env(), stderr=stderr, stdout=stdout, timeout=timeout
|
||||
)
|
||||
assert res.returncode == 0, "fdbcli command {} failed with {}".format(
|
||||
cmd, res.returncode
|
||||
)
|
||||
return res.stdout
|
||||
|
||||
# Execute a fdbcli command
|
||||
|
@ -376,9 +407,9 @@ logdir = {logdir}
|
|||
# Generate and install test certificate chains and keys
|
||||
def create_tls_cert(self):
|
||||
assert self.tls_config is not None, "TLS not enabled"
|
||||
assert self.mkcert_binary.exists() and self.mkcert_binary.is_file(), "{} does not exist".format(
|
||||
self.mkcert_binary
|
||||
)
|
||||
assert (
|
||||
self.mkcert_binary.exists() and self.mkcert_binary.is_file()
|
||||
), "{} does not exist".format(self.mkcert_binary)
|
||||
self.cert.mkdir(exist_ok=True)
|
||||
server_chain_len = abs(self.tls_config.server_chain_len)
|
||||
client_chain_len = abs(self.tls_config.client_chain_len)
|
||||
|
@ -425,7 +456,9 @@ logdir = {logdir}
|
|||
|
||||
def authz_public_key_conf_string(self):
|
||||
if self.public_key_json_file is not None:
|
||||
return "authorization-public-key-file = {}".format(self.public_key_json_file)
|
||||
return "authorization-public-key-file = {}".format(
|
||||
self.public_key_json_file
|
||||
)
|
||||
else:
|
||||
return ""
|
||||
|
||||
|
@ -441,7 +474,11 @@ logdir = {logdir}
|
|||
return {}
|
||||
|
||||
servers_found = set()
|
||||
addresses = [proc_info["address"] for proc_info in status["cluster"]["processes"].values() if filter(proc_info)]
|
||||
addresses = [
|
||||
proc_info["address"]
|
||||
for proc_info in status["cluster"]["processes"].values()
|
||||
if filter(proc_info)
|
||||
]
|
||||
for addr in addresses:
|
||||
port = int(addr.split(":", 1)[1])
|
||||
assert port in self.server_by_port, "Unknown server port {}".format(port)
|
||||
|
@ -472,7 +509,9 @@ logdir = {logdir}
|
|||
# Need to call save_config to apply the changes
|
||||
def add_server(self):
|
||||
server_id = self.next_server_id
|
||||
assert server_id not in self.server_ports, "Server ID {} is already in use".format(server_id)
|
||||
assert (
|
||||
server_id not in self.server_ports
|
||||
), "Server ID {} is already in use".format(server_id)
|
||||
self.next_server_id += 1
|
||||
port = self.__next_port()
|
||||
self.server_ports[server_id] = port
|
||||
|
@ -483,7 +522,9 @@ logdir = {logdir}
|
|||
# Remove the server with the given ID from the cluster
|
||||
# Need to call save_config to apply the changes
|
||||
def remove_server(self, server_id):
|
||||
assert server_id in self.active_servers, "Server {} does not exist".format(server_id)
|
||||
assert server_id in self.active_servers, "Server {} does not exist".format(
|
||||
server_id
|
||||
)
|
||||
self.active_servers.remove(server_id)
|
||||
|
||||
# Wait until changes to the set of servers (additions & removals) are applied
|
||||
|
@ -501,7 +542,10 @@ logdir = {logdir}
|
|||
|
||||
# Apply changes to the set of the coordinators, based on the current value of self.coordinators
|
||||
def update_coordinators(self):
|
||||
urls = ["{}:{}".format(self.ip_address, self.server_ports[id]) for id in self.coordinators]
|
||||
urls = [
|
||||
"{}:{}".format(self.ip_address, self.server_ports[id])
|
||||
for id in self.coordinators
|
||||
]
|
||||
self.fdbcli_exec("coordinators {}".format(" ".join(urls)))
|
||||
|
||||
# Wait until the changes to the set of the coordinators are applied
|
||||
|
@ -521,13 +565,20 @@ logdir = {logdir}
|
|||
for server_id in self.coordinators:
|
||||
assert (
|
||||
connection_string.find(str(self.server_ports[server_id])) != -1
|
||||
), "Missing coordinator {} port {} in the cluster file".format(server_id, self.server_ports[server_id])
|
||||
), "Missing coordinator {} port {} in the cluster file".format(
|
||||
server_id, self.server_ports[server_id]
|
||||
)
|
||||
|
||||
# Exclude the servers with the given ID from the cluster, i.e. move out their data
|
||||
# The method waits until the changes are applied
|
||||
def exclude_servers(self, server_ids):
|
||||
urls = ["{}:{}".format(self.ip_address, self.server_ports[id]) for id in server_ids]
|
||||
self.fdbcli_exec("exclude FORCE {}".format(" ".join(urls)), timeout=EXCLUDE_SERVERS_TIMEOUT_SEC)
|
||||
urls = [
|
||||
"{}:{}".format(self.ip_address, self.server_ports[id]) for id in server_ids
|
||||
]
|
||||
self.fdbcli_exec(
|
||||
"exclude FORCE {}".format(" ".join(urls)),
|
||||
timeout=EXCLUDE_SERVERS_TIMEOUT_SEC,
|
||||
)
|
||||
|
||||
# Perform a cluster wiggle: replace all servers with new ones
|
||||
def cluster_wiggle(self):
|
||||
|
@ -552,7 +603,11 @@ logdir = {logdir}
|
|||
)
|
||||
self.save_config()
|
||||
self.wait_for_server_update()
|
||||
print("New servers successfully added to the cluster. Time: {}s".format(time.time() - start_time))
|
||||
print(
|
||||
"New servers successfully added to the cluster. Time: {}s".format(
|
||||
time.time() - start_time
|
||||
)
|
||||
)
|
||||
|
||||
# Step 2: change coordinators
|
||||
start_time = time.time()
|
||||
|
@ -561,12 +616,20 @@ logdir = {logdir}
|
|||
self.coordinators = new_coordinators.copy()
|
||||
self.update_coordinators()
|
||||
self.wait_for_coordinator_update()
|
||||
print("Coordinators successfully changed. Time: {}s".format(time.time() - start_time))
|
||||
print(
|
||||
"Coordinators successfully changed. Time: {}s".format(
|
||||
time.time() - start_time
|
||||
)
|
||||
)
|
||||
|
||||
# Step 3: exclude old servers from the cluster, i.e. move out their data
|
||||
start_time = time.time()
|
||||
self.exclude_servers(old_servers)
|
||||
print("Old servers successfully excluded from the cluster. Time: {}s".format(time.time() - start_time))
|
||||
print(
|
||||
"Old servers successfully excluded from the cluster. Time: {}s".format(
|
||||
time.time() - start_time
|
||||
)
|
||||
)
|
||||
|
||||
# Step 4: remove the old servers
|
||||
start_time = time.time()
|
||||
|
@ -574,11 +637,21 @@ logdir = {logdir}
|
|||
self.remove_server(server_id)
|
||||
self.save_config()
|
||||
self.wait_for_server_update()
|
||||
print("Old servers successfully removed from the cluster. Time: {}s".format(time.time() - start_time))
|
||||
print(
|
||||
"Old servers successfully removed from the cluster. Time: {}s".format(
|
||||
time.time() - start_time
|
||||
)
|
||||
)
|
||||
|
||||
# Check the cluster log for errors
|
||||
def check_cluster_logs(self, error_limit=100):
|
||||
sev40s = subprocess.getoutput("grep -r 'Severity=\"40\"' {}".format(self.log.as_posix())).rstrip().splitlines()
|
||||
sev40s = (
|
||||
subprocess.getoutput(
|
||||
"grep -r 'Severity=\"40\"' {}".format(self.log.as_posix())
|
||||
)
|
||||
.rstrip()
|
||||
.splitlines()
|
||||
)
|
||||
|
||||
err_cnt = 0
|
||||
for line in sev40s:
|
||||
|
|
|
@ -104,8 +104,10 @@ if __name__ == "__main__":
|
|||
|
||||
tls_config = None
|
||||
if args.tls_enabled:
|
||||
tls_config = TLSConfig(server_chain_len=args.server_cert_chain_len,
|
||||
client_chain_len=args.client_cert_chain_len)
|
||||
tls_config = TLSConfig(
|
||||
server_chain_len=args.server_cert_chain_len,
|
||||
client_chain_len=args.client_cert_chain_len,
|
||||
)
|
||||
errcode = 1
|
||||
with TempCluster(
|
||||
args.build_dir,
|
||||
|
@ -133,16 +135,15 @@ if __name__ == "__main__":
|
|||
("@SERVER_CA_FILE@", str(cluster.server_ca_file)),
|
||||
("@CLIENT_CERT_FILE@", str(cluster.client_cert_file)),
|
||||
("@CLIENT_KEY_FILE@", str(cluster.client_key_file)),
|
||||
("@CLIENT_CA_FILE@", str(cluster.client_ca_file))]
|
||||
("@CLIENT_CA_FILE@", str(cluster.client_ca_file)),
|
||||
]
|
||||
|
||||
for cmd in args.cmd:
|
||||
for (placeholder, value) in substitution_table:
|
||||
cmd = cmd.replace(placeholder, value)
|
||||
cmd_args.append(cmd)
|
||||
env = dict(**os.environ)
|
||||
env["FDB_CLUSTER_FILE"] = env.get(
|
||||
"FDB_CLUSTER_FILE", cluster.cluster_file
|
||||
)
|
||||
env["FDB_CLUSTER_FILE"] = env.get("FDB_CLUSTER_FILE", cluster.cluster_file)
|
||||
print("command: {}".format(cmd_args))
|
||||
errcode = subprocess.run(
|
||||
cmd_args, stdout=sys.stdout, stderr=sys.stderr, env=env
|
||||
|
|
|
@ -36,6 +36,7 @@ class _admin_request(object):
|
|||
|
||||
def main_loop(main_pipe, pipe):
|
||||
main_pipe.close()
|
||||
use_grv_cache = False
|
||||
db = None
|
||||
while True:
|
||||
try:
|
||||
|
@ -49,7 +50,14 @@ def main_loop(main_pipe, pipe):
|
|||
args = req.args
|
||||
resp = True
|
||||
try:
|
||||
if op == "connect":
|
||||
if op == "configure_client":
|
||||
force_multi_version_client, use_grv_cache, logdir = req.args[:3]
|
||||
if force_multi_version_client:
|
||||
fdb.options.set_disable_client_bypass()
|
||||
if len(logdir) > 0:
|
||||
fdb.options.set_trace_enable(logdir)
|
||||
fdb.options.set_trace_file_identifier("adminserver")
|
||||
elif op == "connect":
|
||||
db = fdb.open(req.args[0])
|
||||
elif op == "configure_tls":
|
||||
keyfile, certfile, cafile = req.args[:3]
|
||||
|
|
|
@ -21,16 +21,18 @@
|
|||
import admin_server
|
||||
import argparse
|
||||
import authlib
|
||||
import base64
|
||||
import fdb
|
||||
import os
|
||||
import pytest
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from multiprocessing import Process, Pipe
|
||||
from typing import Union
|
||||
from authz_util import token_gen, private_key_gen, public_keyset_from_keys, alg_from_kty
|
||||
from util import random_alphanum_str, random_alphanum_bytes, to_str, to_bytes, KeyFileReverter, token_claim_1h, wait_until_tenant_tr_succeeds, wait_until_tenant_tr_fails
|
||||
from util import random_alphanum_str, random_alphanum_bytes, to_str, to_bytes, KeyFileReverter, wait_until_tenant_tr_succeeds, wait_until_tenant_tr_fails
|
||||
|
||||
special_key_ranges = [
|
||||
("transaction description", b"/description", b"/description\x00"),
|
||||
|
@ -43,25 +45,79 @@ special_key_ranges = [
|
|||
("kill storage", b"/globals/killStorage", b"/globals/killStorage\x00"),
|
||||
]
|
||||
|
||||
def test_simple_tenant_access(cluster, default_tenant, tenant_tr_gen):
|
||||
# handler for when looping is assumed with usage
|
||||
# e.g. GRV cache enablement removes the guarantee that transaction always gets the latest read version before it starts,
|
||||
# which could introduce arbitrary conflicts even on idle test clusters, and those need to be resolved via retrying.
|
||||
def loop_until_success(tr: fdb.Transaction, func):
|
||||
while True:
|
||||
try:
|
||||
return func(tr)
|
||||
except fdb.FDBError as e:
|
||||
tr.on_error(e).wait()
|
||||
|
||||
# test that token option on a transaction should survive soft transaction resets,
|
||||
# be cleared by hard transaction resets, and also clearable by setting empty value
|
||||
def test_token_option(cluster, default_tenant, tenant_tr_gen, token_claim_1h):
|
||||
token = token_gen(cluster.private_key, token_claim_1h(default_tenant))
|
||||
tr = tenant_tr_gen(default_tenant)
|
||||
tr.options.set_authorization_token(token)
|
||||
tr[b"abc"] = b"def"
|
||||
tr.commit().wait()
|
||||
def commit_some_value(tr):
|
||||
tr[b"abc"] = b"def"
|
||||
return tr.commit().wait()
|
||||
|
||||
loop_until_success(tr, commit_some_value)
|
||||
|
||||
# token option should survive a soft reset by a retryable error
|
||||
tr.on_error(fdb.FDBError(1020)).wait() # not_committed (conflict)
|
||||
def read_back_value(tr):
|
||||
return tr[b"abc"].value
|
||||
value = loop_until_success(tr, read_back_value)
|
||||
assert value == b"def", f"unexpected value found: {value}"
|
||||
|
||||
tr.reset() # token shouldn't survive a hard reset
|
||||
try:
|
||||
value = read_back_value(tr)
|
||||
assert False, "expected permission_denied, but succeeded"
|
||||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied, got {e} instead"
|
||||
|
||||
tr.reset()
|
||||
tr.options.set_authorization_token(token)
|
||||
tr.options.set_authorization_token() # option set with no arg should clear the token
|
||||
|
||||
try:
|
||||
value = read_back_value(tr)
|
||||
assert False, "expected permission_denied, but succeeded"
|
||||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied, got {e} instead"
|
||||
|
||||
def test_simple_tenant_access(cluster, default_tenant, tenant_tr_gen, token_claim_1h):
|
||||
token = token_gen(cluster.private_key, token_claim_1h(default_tenant))
|
||||
tr = tenant_tr_gen(default_tenant)
|
||||
tr.options.set_authorization_token(token)
|
||||
assert tr[b"abc"] == b"def", "tenant write transaction not visible"
|
||||
def commit_some_value(tr):
|
||||
tr[b"abc"] = b"def"
|
||||
tr.commit().wait()
|
||||
|
||||
def test_cross_tenant_access_disallowed(cluster, default_tenant, tenant_gen, tenant_tr_gen):
|
||||
loop_until_success(tr, commit_some_value)
|
||||
tr = tenant_tr_gen(default_tenant)
|
||||
tr.options.set_authorization_token(token)
|
||||
def read_back_value(tr):
|
||||
return tr[b"abc"].value
|
||||
value = loop_until_success(tr, read_back_value)
|
||||
assert value == b"def", "tenant write transaction not visible"
|
||||
|
||||
def test_cross_tenant_access_disallowed(cluster, default_tenant, tenant_gen, tenant_tr_gen, token_claim_1h):
|
||||
# use default tenant token with second tenant transaction and see it fail
|
||||
second_tenant = random_alphanum_bytes(12)
|
||||
tenant_gen(second_tenant)
|
||||
token_second = token_gen(cluster.private_key, token_claim_1h(second_tenant))
|
||||
tr_second = tenant_tr_gen(second_tenant)
|
||||
tr_second.options.set_authorization_token(token_second)
|
||||
tr_second[b"abc"] = b"def"
|
||||
tr_second.commit().wait()
|
||||
def commit_some_value(tr):
|
||||
tr[b"abc"] = b"def"
|
||||
return tr.commit().wait()
|
||||
loop_until_success(tr_second, commit_some_value)
|
||||
token_default = token_gen(cluster.private_key, token_claim_1h(default_tenant))
|
||||
tr_second = tenant_tr_gen(second_tenant)
|
||||
tr_second.options.set_authorization_token(token_default)
|
||||
|
@ -81,6 +137,44 @@ def test_cross_tenant_access_disallowed(cluster, default_tenant, tenant_gen, ten
|
|||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied, got {e} instead"
|
||||
|
||||
def test_cross_tenant_raw_access_disallowed_with_token(cluster, db, default_tenant, tenant_gen, tenant_tr_gen, token_claim_1h):
|
||||
def commit_some_value(tr):
|
||||
tr[b"abc"] = b"def"
|
||||
return tr.commit().wait()
|
||||
|
||||
second_tenant = random_alphanum_bytes(12)
|
||||
tenant_gen(second_tenant)
|
||||
|
||||
first_tenant_token_claim = token_claim_1h(default_tenant)
|
||||
second_tenant_token_claim = token_claim_1h(second_tenant)
|
||||
# create a token that's good for both tenants
|
||||
first_tenant_token_claim["tenants"] += second_tenant_token_claim["tenants"]
|
||||
token = token_gen(cluster.private_key, first_tenant_token_claim)
|
||||
tr_first = tenant_tr_gen(default_tenant)
|
||||
tr_first.options.set_authorization_token(token)
|
||||
loop_until_success(tr_first, commit_some_value)
|
||||
tr_second = tenant_tr_gen(second_tenant)
|
||||
tr_second.options.set_authorization_token(token)
|
||||
loop_until_success(tr_second, commit_some_value)
|
||||
|
||||
# now try a normal keyspace transaction to raw-access both tenants' keyspace at once, with token
|
||||
tr = db.create_transaction()
|
||||
tr.options.set_authorization_token(token)
|
||||
tr.options.set_raw_access()
|
||||
prefix_first = base64.b64decode(first_tenant_token_claim["tenants"][0])
|
||||
assert len(prefix_first) == 8
|
||||
prefix_second = base64.b64decode(first_tenant_token_claim["tenants"][1])
|
||||
assert len(prefix_second) == 8
|
||||
lhs = min(prefix_first, prefix_second)
|
||||
rhs = max(prefix_first, prefix_second)
|
||||
rhs = bytearray(rhs)
|
||||
rhs[-1] += 1 # exclusive end
|
||||
try:
|
||||
value = tr[lhs:bytes(rhs)].to_list()
|
||||
assert False, f"expected permission_denied, but succeeded, value: {value}"
|
||||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied, got {e} instead"
|
||||
|
||||
def test_system_and_special_key_range_disallowed(db, tenant_tr_gen):
|
||||
second_tenant = random_alphanum_bytes(12)
|
||||
try:
|
||||
|
@ -137,7 +231,7 @@ def test_system_and_special_key_range_disallowed(db, tenant_tr_gen):
|
|||
|
||||
def test_public_key_set_rollover(
|
||||
kty, public_key_refresh_interval,
|
||||
cluster, default_tenant, tenant_gen, tenant_tr_gen):
|
||||
cluster, default_tenant, tenant_gen, tenant_tr_gen, token_claim_1h):
|
||||
new_kid = random_alphanum_str(12)
|
||||
new_kty = "EC" if kty == "RSA" else "RSA"
|
||||
new_key = private_key_gen(kty=new_kty, kid=new_kid)
|
||||
|
@ -160,16 +254,16 @@ def test_public_key_set_rollover(
|
|||
with KeyFileReverter(cluster.public_key_json_file, old_key_json, delay):
|
||||
with open(cluster.public_key_json_file, "w") as keyfile:
|
||||
keyfile.write(interim_set)
|
||||
wait_until_tenant_tr_succeeds(second_tenant, new_key, tenant_tr_gen, max_repeat, delay)
|
||||
wait_until_tenant_tr_succeeds(second_tenant, new_key, tenant_tr_gen, max_repeat, delay, token_claim_1h)
|
||||
print("interim key set activated")
|
||||
final_set = public_keyset_from_keys([new_key])
|
||||
print(f"final keyset: {final_set}")
|
||||
with open(cluster.public_key_json_file, "w") as keyfile:
|
||||
keyfile.write(final_set)
|
||||
wait_until_tenant_tr_fails(default_tenant, cluster.private_key, tenant_tr_gen, max_repeat, delay)
|
||||
wait_until_tenant_tr_fails(default_tenant, cluster.private_key, tenant_tr_gen, max_repeat, delay, token_claim_1h)
|
||||
|
||||
def test_public_key_set_broken_file_tolerance(
|
||||
cluster, public_key_refresh_interval, default_tenant, tenant_tr_gen):
|
||||
cluster, public_key_refresh_interval, default_tenant, tenant_tr_gen, token_claim_1h):
|
||||
delay = public_key_refresh_interval
|
||||
# retry limit in waiting for keyset file update to propagate to FDB server's internal keyset
|
||||
max_repeat = 10
|
||||
|
@ -187,10 +281,10 @@ def test_public_key_set_broken_file_tolerance(
|
|||
with open(cluster.public_key_json_file, "w") as keyfile:
|
||||
keyfile.write('{"keys":[]}')
|
||||
# eventually internal key set will become empty and won't accept any new tokens
|
||||
wait_until_tenant_tr_fails(default_tenant, cluster.private_key, tenant_tr_gen, max_repeat, delay)
|
||||
wait_until_tenant_tr_fails(default_tenant, cluster.private_key, tenant_tr_gen, max_repeat, delay, token_claim_1h)
|
||||
|
||||
def test_public_key_set_deletion_tolerance(
|
||||
cluster, public_key_refresh_interval, default_tenant, tenant_tr_gen):
|
||||
cluster, public_key_refresh_interval, default_tenant, tenant_tr_gen, token_claim_1h):
|
||||
delay = public_key_refresh_interval
|
||||
# retry limit in waiting for keyset file update to propagate to FDB server's internal keyset
|
||||
max_repeat = 10
|
||||
|
@ -200,16 +294,16 @@ def test_public_key_set_deletion_tolerance(
|
|||
with open(cluster.public_key_json_file, "w") as keyfile:
|
||||
keyfile.write('{"keys":[]}')
|
||||
time.sleep(delay)
|
||||
wait_until_tenant_tr_fails(default_tenant, cluster.private_key, tenant_tr_gen, max_repeat, delay)
|
||||
wait_until_tenant_tr_fails(default_tenant, cluster.private_key, tenant_tr_gen, max_repeat, delay, token_claim_1h)
|
||||
os.remove(cluster.public_key_json_file)
|
||||
time.sleep(delay * 2)
|
||||
with open(cluster.public_key_json_file, "w") as keyfile:
|
||||
keyfile.write(cluster.public_key_jwks_str)
|
||||
# eventually updated key set should take effect and transaction should be accepted
|
||||
wait_until_tenant_tr_succeeds(default_tenant, cluster.private_key, tenant_tr_gen, max_repeat, delay)
|
||||
wait_until_tenant_tr_succeeds(default_tenant, cluster.private_key, tenant_tr_gen, max_repeat, delay, token_claim_1h)
|
||||
|
||||
def test_public_key_set_empty_file_tolerance(
|
||||
cluster, public_key_refresh_interval, default_tenant, tenant_tr_gen):
|
||||
cluster, public_key_refresh_interval, default_tenant, tenant_tr_gen, token_claim_1h):
|
||||
delay = public_key_refresh_interval
|
||||
# retry limit in waiting for keyset file update to propagate to FDB server's internal keyset
|
||||
max_repeat = 10
|
||||
|
@ -219,7 +313,7 @@ def test_public_key_set_empty_file_tolerance(
|
|||
with open(cluster.public_key_json_file, "w") as keyfile:
|
||||
keyfile.write('{"keys":[]}')
|
||||
# eventually internal key set will become empty and won't accept any new tokens
|
||||
wait_until_tenant_tr_fails(default_tenant, cluster.private_key, tenant_tr_gen, max_repeat, delay)
|
||||
wait_until_tenant_tr_fails(default_tenant, cluster.private_key, tenant_tr_gen, max_repeat, delay, token_claim_1h)
|
||||
# empty the key file
|
||||
with open(cluster.public_key_json_file, "w") as keyfile:
|
||||
pass
|
||||
|
@ -227,9 +321,9 @@ def test_public_key_set_empty_file_tolerance(
|
|||
with open(cluster.public_key_json_file, "w") as keyfile:
|
||||
keyfile.write(cluster.public_key_jwks_str)
|
||||
# eventually key file should update and transactions should go through
|
||||
wait_until_tenant_tr_succeeds(default_tenant, cluster.private_key, tenant_tr_gen, max_repeat, delay)
|
||||
wait_until_tenant_tr_succeeds(default_tenant, cluster.private_key, tenant_tr_gen, max_repeat, delay, token_claim_1h)
|
||||
|
||||
def test_bad_token(cluster, default_tenant, tenant_tr_gen):
|
||||
def test_bad_token(cluster, default_tenant, tenant_tr_gen, token_claim_1h):
|
||||
def del_attr(d, attr):
|
||||
del d[attr]
|
||||
return d
|
||||
|
|
|
@ -22,10 +22,14 @@ import fdb
|
|||
import pytest
|
||||
import subprocess
|
||||
import admin_server
|
||||
import base64
|
||||
import glob
|
||||
import time
|
||||
from local_cluster import TLSConfig
|
||||
from tmp_cluster import TempCluster
|
||||
from typing import Union
|
||||
from util import random_alphanum_str, random_alphanum_bytes, to_str, to_bytes
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
fdb.api_version(720)
|
||||
|
||||
|
@ -48,6 +52,18 @@ def pytest_addoption(parser):
|
|||
default=1,
|
||||
dest="public_key_refresh_interval",
|
||||
help="How frequently server refreshes authorization public key file")
|
||||
parser.addoption(
|
||||
"--force-multi-version-client",
|
||||
action="store_true",
|
||||
default=False,
|
||||
dest="force_multi_version_client",
|
||||
help="Whether to force multi-version client mode")
|
||||
parser.addoption(
|
||||
"--use-grv-cache",
|
||||
action="store_true",
|
||||
default=False,
|
||||
dest="use_grv_cache",
|
||||
help="Whether to make client use cached GRV from database context")
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def build_dir(request):
|
||||
|
@ -65,6 +81,14 @@ def trusted_client(request):
|
|||
def public_key_refresh_interval(request):
|
||||
return request.config.option.public_key_refresh_interval
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def force_multi_version_client(request):
|
||||
return request.config.option.force_multi_version_client
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def use_grv_cache(request):
|
||||
return request.config.option.use_grv_cache
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def kid():
|
||||
return random_alphanum_str(12)
|
||||
|
@ -77,7 +101,7 @@ def admin_ipc():
|
|||
server.join()
|
||||
|
||||
@pytest.fixture(autouse=True, scope=cluster_scope)
|
||||
def cluster(admin_ipc, build_dir, public_key_refresh_interval, trusted_client):
|
||||
def cluster(admin_ipc, build_dir, public_key_refresh_interval, trusted_client, force_multi_version_client, use_grv_cache):
|
||||
with TempCluster(
|
||||
build_dir=build_dir,
|
||||
tls_config=TLSConfig(server_chain_len=3, client_chain_len=2),
|
||||
|
@ -90,19 +114,44 @@ def cluster(admin_ipc, build_dir, public_key_refresh_interval, trusted_client):
|
|||
keyfile = str(cluster.client_key_file)
|
||||
certfile = str(cluster.client_cert_file)
|
||||
cafile = str(cluster.server_ca_file)
|
||||
logdir = str(cluster.log)
|
||||
fdb.options.set_tls_key_path(keyfile if trusted_client else "")
|
||||
fdb.options.set_tls_cert_path(certfile if trusted_client else "")
|
||||
fdb.options.set_tls_ca_path(cafile)
|
||||
fdb.options.set_trace_enable()
|
||||
fdb.options.set_trace_enable(logdir)
|
||||
fdb.options.set_trace_file_identifier("testclient")
|
||||
if force_multi_version_client:
|
||||
fdb.options.set_disable_client_bypass()
|
||||
admin_ipc.request("configure_client", [force_multi_version_client, use_grv_cache, logdir])
|
||||
admin_ipc.request("configure_tls", [keyfile, certfile, cafile])
|
||||
admin_ipc.request("connect", [str(cluster.cluster_file)])
|
||||
yield cluster
|
||||
err_count = {}
|
||||
for file in glob.glob(str(cluster.log.joinpath("*.xml"))):
|
||||
lineno = 1
|
||||
for line in open(file):
|
||||
try:
|
||||
doc = ET.fromstring(line)
|
||||
except:
|
||||
continue
|
||||
if doc.attrib.get("Severity", "") == "40":
|
||||
ev_type = doc.attrib.get("Type", "[unset]")
|
||||
err = doc.attrib.get("Error", "[unset]")
|
||||
tup = (file, ev_type, err)
|
||||
err_count[tup] = err_count.get(tup, 0) + 1
|
||||
lineno += 1
|
||||
print("Sev40 Summary:")
|
||||
if len(err_count) == 0:
|
||||
print(" No errors")
|
||||
else:
|
||||
for tup, count in err_count.items():
|
||||
print(" {}: {}".format(tup, count))
|
||||
|
||||
@pytest.fixture
|
||||
def db(cluster, admin_ipc):
|
||||
db = fdb.open(str(cluster.cluster_file))
|
||||
db.options.set_transaction_timeout(2000) # 2 seconds
|
||||
db.options.set_transaction_retry_limit(3)
|
||||
#db.options.set_transaction_timeout(2000) # 2 seconds
|
||||
db.options.set_transaction_retry_limit(10)
|
||||
yield db
|
||||
admin_ipc.request("cleanup_database")
|
||||
db = None
|
||||
|
@ -129,8 +178,30 @@ def default_tenant(tenant_gen, tenant_del):
|
|||
tenant_del(tenant)
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_tr_gen(db):
|
||||
def tenant_tr_gen(db, use_grv_cache):
|
||||
def fn(tenant):
|
||||
tenant = db.open_tenant(to_bytes(tenant))
|
||||
return tenant.create_transaction()
|
||||
tr = tenant.create_transaction()
|
||||
if use_grv_cache:
|
||||
tr.options.set_use_grv_cache()
|
||||
return tr
|
||||
return fn
|
||||
|
||||
@pytest.fixture
|
||||
def token_claim_1h(db):
|
||||
# JWT claim that is valid for 1 hour since time of invocation
|
||||
def fn(tenant_name: Union[bytes, str]):
|
||||
tenant = db.open_tenant(to_bytes(tenant_name))
|
||||
tenant_id = tenant.get_id().wait()
|
||||
now = time.time()
|
||||
return {
|
||||
"iss": "fdb-authz-tester",
|
||||
"sub": "authz-test",
|
||||
"aud": ["tmp-cluster"],
|
||||
"iat": now,
|
||||
"nbf": now - 1,
|
||||
"exp": now + 60 * 60,
|
||||
"jti": random_alphanum_str(10),
|
||||
"tenants": [to_str(base64.b64encode(tenant_id.to_bytes(8, "big")))],
|
||||
}
|
||||
return fn
|
||||
|
|
|
@ -48,23 +48,9 @@ class KeyFileReverter(object):
|
|||
print(f"key file reverted. waiting {self.refresh_delay * 2} seconds for the update to take effect...")
|
||||
time.sleep(self.refresh_delay * 2)
|
||||
|
||||
# JWT claim that is valid for 1 hour since time of invocation
|
||||
def token_claim_1h(tenant_name):
|
||||
now = time.time()
|
||||
return {
|
||||
"iss": "fdb-authz-tester",
|
||||
"sub": "authz-test",
|
||||
"aud": ["tmp-cluster"],
|
||||
"iat": now,
|
||||
"nbf": now - 1,
|
||||
"exp": now + 60 * 60,
|
||||
"jti": random_alphanum_str(10),
|
||||
"tenants": [to_str(base64.b64encode(tenant_name))],
|
||||
}
|
||||
|
||||
# repeat try-wait loop up to max_repeat times until both read and write tr fails for tenant with permission_denied
|
||||
# important: only use this function if you don't have any data dependencies to key "abc"
|
||||
def wait_until_tenant_tr_fails(tenant, private_key, tenant_tr_gen, max_repeat, delay):
|
||||
def wait_until_tenant_tr_fails(tenant, private_key, tenant_tr_gen, max_repeat, delay, token_claim_1h):
|
||||
repeat = 0
|
||||
read_blocked = False
|
||||
write_blocked = False
|
||||
|
@ -97,7 +83,7 @@ def wait_until_tenant_tr_fails(tenant, private_key, tenant_tr_gen, max_repeat, d
|
|||
|
||||
# repeat try-wait loop up to max_repeat times until both read and write tr succeeds for tenant
|
||||
# important: only use this function if you don't have any data dependencies to key "abc"
|
||||
def wait_until_tenant_tr_succeeds(tenant, private_key, tenant_tr_gen, max_repeat, delay):
|
||||
def wait_until_tenant_tr_succeeds(tenant, private_key, tenant_tr_gen, max_repeat, delay, token_claim_1h):
|
||||
repeat = 0
|
||||
token = token_gen(private_key, token_claim_1h(tenant))
|
||||
while repeat < max_repeat:
|
||||
|
|
|
@ -2,6 +2,10 @@
|
|||
allowDefaultTenant = false
|
||||
tenantModes = ['optional', 'required']
|
||||
|
||||
[[knobs]]
|
||||
audit_logging_enabled = false
|
||||
max_trace_lines = 2000000
|
||||
|
||||
[[test]]
|
||||
testTitle = 'TenantCreation'
|
||||
|
||||
|
|
|
@ -2,6 +2,10 @@
|
|||
allowDefaultTenant = false
|
||||
tenantModes = ['optional', 'required']
|
||||
|
||||
[[knobs]]
|
||||
audit_logging_enabled = false
|
||||
max_trace_lines = 2000000
|
||||
|
||||
[[test]]
|
||||
testTitle = 'TenantCreation'
|
||||
|
||||
|
|
|
@ -4,6 +4,8 @@ tenantModes = ['optional', 'required']
|
|||
|
||||
[[knobs]]
|
||||
allow_tokenless_tenant_access = true
|
||||
audit_logging_enabled = false
|
||||
max_trace_lines = 2000000
|
||||
|
||||
[[test]]
|
||||
testTitle = 'TenantCreation'
|
||||
|
|
Loading…
Reference in New Issue