Merge pull request #7731 from sfc-gh-jshim/authz-general-tls-and-integration-test
Authz general tls and integration test
This commit is contained in:
commit
ac6889286c
|
@ -622,6 +622,13 @@ func (o TransactionOptions) SetUseGrvCache() error {
|
|||
return o.setOpt(1101, nil)
|
||||
}
|
||||
|
||||
// Attach given authorization token to the transaction such that subsequent tenant-aware requests are authorized
|
||||
//
|
||||
// Parameter: A JSON Web Token authorized to access data belonging to one or more tenants, indicated by 'tenants' claim of the token's payload.
|
||||
func (o TransactionOptions) SetAuthorizationToken(param string) error {
|
||||
return o.setOpt(2000, []byte(param))
|
||||
}
|
||||
|
||||
type StreamingMode int
|
||||
|
||||
const (
|
||||
|
|
|
@ -306,9 +306,8 @@ description is not currently required but encouraged.
|
|||
description="Specifically instruct this transaction to NOT use cached GRV. Primarily used for the read version cache's background updater to avoid attempting to read a cached entry in specific situations."
|
||||
hidden="true"/>
|
||||
<Option name="authorization_token" code="2000"
|
||||
description="Add a given authorization token to the network thread so that future requests are authorized"
|
||||
paramType="String" paramDescription="A signed token serialized using flatbuffers"
|
||||
hidden="true" />
|
||||
description="Attach given authorization token to the transaction such that subsequent tenant-aware requests are authorized"
|
||||
paramType="String" paramDescription="A JSON Web Token authorized to access data belonging to one or more tenants, indicated by 'tenants' claim of the token's payload."/>
|
||||
</Scope>
|
||||
|
||||
<!-- The enumeration values matter - do not change them without
|
||||
|
|
|
@ -80,3 +80,5 @@ target_compile_definitions(fdbrpc_sampling PRIVATE -DENABLE_SAMPLING)
|
|||
if(WIN32)
|
||||
add_dependencies(fdbrpc_sampling_actors fdbrpc_actors)
|
||||
endif()
|
||||
|
||||
add_subdirectory(tests)
|
||||
|
|
|
@ -34,6 +34,7 @@
|
|||
#include "fdbrpc/fdbrpc.h"
|
||||
#include "fdbrpc/FailureMonitor.h"
|
||||
#include "fdbrpc/HealthMonitor.h"
|
||||
#include "fdbrpc/JsonWebKeySet.h"
|
||||
#include "fdbrpc/genericactors.actor.h"
|
||||
#include "fdbrpc/IPAllowList.h"
|
||||
#include "fdbrpc/TokenCache.h"
|
||||
|
@ -44,8 +45,10 @@
|
|||
#include "flow/Net2Packet.h"
|
||||
#include "flow/TDMetric.actor.h"
|
||||
#include "flow/ObjectSerializer.h"
|
||||
#include "flow/Platform.h"
|
||||
#include "flow/ProtocolVersion.h"
|
||||
#include "flow/UnitTest.h"
|
||||
#include "flow/WatchFile.actor.h"
|
||||
#define XXH_INLINE_ALL
|
||||
#include "flow/xxhash.h"
|
||||
#include "flow/actorcompiler.h" // This must be the last #include.
|
||||
|
@ -309,6 +312,7 @@ public:
|
|||
|
||||
// Returns true if given network address 'address' is one of the address we are listening on.
|
||||
bool isLocalAddress(const NetworkAddress& address) const;
|
||||
void applyPublicKeySet(StringRef jwkSetString);
|
||||
|
||||
NetworkAddressCachedString localAddresses;
|
||||
std::vector<Future<Void>> listeners;
|
||||
|
@ -341,6 +345,7 @@ public:
|
|||
|
||||
Future<Void> multiVersionCleanup;
|
||||
Future<Void> pingLogger;
|
||||
Future<Void> publicKeyFileWatch;
|
||||
|
||||
std::unordered_map<Standalone<StringRef>, PublicKey> publicKeys;
|
||||
};
|
||||
|
@ -958,7 +963,7 @@ void Peer::onIncomingConnection(Reference<Peer> self, Reference<IConnection> con
|
|||
.detail("FromAddr", conn->getPeerAddress())
|
||||
.detail("CanonicalAddr", destination)
|
||||
.detail("IsPublic", destination.isPublic())
|
||||
.detail("Trusted", self->transport->allowList(conn->getPeerAddress().ip));
|
||||
.detail("Trusted", self->transport->allowList(conn->getPeerAddress().ip) && conn->hasTrustedPeer());
|
||||
|
||||
connect.cancel();
|
||||
prependConnectPacket();
|
||||
|
@ -1257,7 +1262,7 @@ ACTOR static Future<Void> connectionReader(TransportData* transport,
|
|||
state bool incompatiblePeerCounted = false;
|
||||
state NetworkAddress peerAddress;
|
||||
state ProtocolVersion peerProtocolVersion;
|
||||
state bool trusted = transport->allowList(conn->getPeerAddress().ip);
|
||||
state bool trusted = transport->allowList(conn->getPeerAddress().ip) && conn->hasTrustedPeer();
|
||||
peerAddress = conn->getPeerAddress();
|
||||
|
||||
if (!peer) {
|
||||
|
@ -1529,6 +1534,27 @@ bool TransportData::isLocalAddress(const NetworkAddress& address) const {
|
|||
address == localAddresses.getAddressList().secondaryAddress.get());
|
||||
}
|
||||
|
||||
void TransportData::applyPublicKeySet(StringRef jwkSetString) {
|
||||
auto jwks = JsonWebKeySet::parse(jwkSetString, {});
|
||||
if (!jwks.present())
|
||||
throw pkey_decode_error();
|
||||
const auto& keySet = jwks.get().keys;
|
||||
publicKeys.clear();
|
||||
int numPrivateKeys = 0;
|
||||
for (auto [keyName, key] : keySet) {
|
||||
// ignore private keys
|
||||
if (key.isPublic()) {
|
||||
publicKeys[keyName] = key.getPublic();
|
||||
} else {
|
||||
numPrivateKeys++;
|
||||
}
|
||||
}
|
||||
TraceEvent(SevInfo, "AuthzPublicKeySetApply").detail("NumPublicKeys", publicKeys.size());
|
||||
if (numPrivateKeys > 0) {
|
||||
TraceEvent(SevWarnAlways, "AuthzPublicKeySetContainsPrivateKeys").detail("NumPrivateKeys", numPrivateKeys);
|
||||
}
|
||||
}
|
||||
|
||||
ACTOR static Future<Void> multiVersionCleanupWorker(TransportData* self) {
|
||||
loop {
|
||||
wait(delay(FLOW_KNOBS->CONNECTION_CLEANUP_DELAY));
|
||||
|
@ -1967,3 +1993,62 @@ void FlowTransport::removePublicKey(StringRef name) {
|
|||
void FlowTransport::removeAllPublicKeys() {
|
||||
self->publicKeys.clear();
|
||||
}
|
||||
|
||||
void FlowTransport::loadPublicKeyFile(const std::string& filePath) {
|
||||
if (!fileExists(filePath)) {
|
||||
throw file_not_found();
|
||||
}
|
||||
int64_t const len = fileSize(filePath);
|
||||
if (len <= 0) {
|
||||
TraceEvent(SevWarn, "AuthzPublicKeySetEmpty").detail("Path", filePath);
|
||||
} else if (len > FLOW_KNOBS->PUBLIC_KEY_FILE_MAX_SIZE) {
|
||||
throw file_too_large();
|
||||
} else {
|
||||
auto json = readFileBytes(filePath, len);
|
||||
self->applyPublicKeySet(StringRef(json));
|
||||
}
|
||||
}
|
||||
|
||||
ACTOR static Future<Void> watchPublicKeyJwksFile(std::string filePath, TransportData* self) {
|
||||
state AsyncTrigger fileChanged;
|
||||
state Future<Void> fileWatch;
|
||||
state unsigned errorCount = 0; // error since watch start or last successful refresh
|
||||
|
||||
// Make sure this watch setup does not break due to async file system initialization not having been called
|
||||
loop {
|
||||
if (IAsyncFileSystem::filesystem())
|
||||
break;
|
||||
wait(delay(1.0));
|
||||
}
|
||||
const int& intervalSeconds = FLOW_KNOBS->PUBLIC_KEY_FILE_REFRESH_INTERVAL_SECONDS;
|
||||
fileWatch = watchFileForChanges(filePath, &fileChanged, &intervalSeconds, "AuthzPublicKeySetRefreshStatError");
|
||||
loop {
|
||||
try {
|
||||
wait(fileChanged.onTrigger());
|
||||
state Reference<IAsyncFile> file = wait(IAsyncFileSystem::filesystem()->open(
|
||||
filePath, IAsyncFile::OPEN_READONLY | IAsyncFile::OPEN_UNCACHED, 0));
|
||||
state int64_t filesize = wait(file->size());
|
||||
state std::string json(filesize, '\0');
|
||||
if (filesize > FLOW_KNOBS->PUBLIC_KEY_FILE_MAX_SIZE)
|
||||
throw file_too_large();
|
||||
if (filesize <= 0) {
|
||||
TraceEvent(SevWarn, "AuthzPublicKeySetEmpty").suppressFor(60);
|
||||
continue;
|
||||
}
|
||||
wait(success(file->read(&json[0], filesize, 0)));
|
||||
self->applyPublicKeySet(StringRef(json));
|
||||
errorCount = 0;
|
||||
} catch (Error& e) {
|
||||
if (e.code() == error_code_actor_cancelled) {
|
||||
throw;
|
||||
}
|
||||
// parse/read error
|
||||
errorCount++;
|
||||
TraceEvent(SevWarn, "AuthzPublicKeySetRefreshError").error(e).detail("ErrorCount", errorCount);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void FlowTransport::watchPublicKeyFile(const std::string& publicKeyFilePath) {
|
||||
self->publicKeyFileWatch = watchPublicKeyJwksFile(publicKeyFilePath, self);
|
||||
}
|
||||
|
|
|
@ -830,11 +830,18 @@ TEST_CASE("/fdbrpc/JsonWebKeySet/EC/PrivateKey") {
|
|||
}
|
||||
|
||||
TEST_CASE("/fdbrpc/JsonWebKeySet/RSA/PublicKey") {
|
||||
testPublicKey(&mkcert::makeRsa2048Bit);
|
||||
testPublicKey(&mkcert::makeRsa4096Bit);
|
||||
return Void();
|
||||
}
|
||||
|
||||
TEST_CASE("/fdbrpc/JsonWebKeySet/RSA/PrivateKey") {
|
||||
testPrivateKey(&mkcert::makeRsa2048Bit);
|
||||
testPrivateKey(&mkcert::makeRsa4096Bit);
|
||||
return Void();
|
||||
}
|
||||
|
||||
TEST_CASE("/fdbrpc/JsonWebKeySet/Empty") {
|
||||
auto keyset = JsonWebKeySet::parse("{\"keys\":[]}"_sr, {});
|
||||
ASSERT(keyset.present());
|
||||
ASSERT(keyset.get().keys.empty());
|
||||
return Void();
|
||||
}
|
||||
|
|
|
@ -125,6 +125,10 @@ NetworkAddress SimExternalConnection::getPeerAddress() const {
|
|||
}
|
||||
}
|
||||
|
||||
bool SimExternalConnection::hasTrustedPeer() const {
|
||||
return true;
|
||||
}
|
||||
|
||||
UID SimExternalConnection::getDebugID() const {
|
||||
return dbgid;
|
||||
}
|
||||
|
|
|
@ -177,6 +177,10 @@ bool TokenCacheImpl::validateAndAdd(double currentTime, StringRef token, Network
|
|||
CODE_PROBE(true, "Token referencing non-existing key");
|
||||
TRACE_INVALID_PARSED_TOKEN("UnknownKey", t);
|
||||
return false;
|
||||
} else if (!t.issuedAtUnixTime.present()) {
|
||||
CODE_PROBE(true, "Token has no issued-at field");
|
||||
TRACE_INVALID_PARSED_TOKEN("NoIssuedAt", t);
|
||||
return false;
|
||||
} else if (!t.expiresAtUnixTime.present()) {
|
||||
CODE_PROBE(true, "Token has no expiration time");
|
||||
TRACE_INVALID_PARSED_TOKEN("NoExpirationTime", t);
|
||||
|
@ -203,7 +207,7 @@ bool TokenCacheImpl::validateAndAdd(double currentTime, StringRef token, Network
|
|||
return false;
|
||||
} else {
|
||||
CacheEntry c;
|
||||
c.expirationTime = double(t.expiresAtUnixTime.get());
|
||||
c.expirationTime = t.expiresAtUnixTime.get();
|
||||
c.tenants.reserve(c.arena, t.tenants.get().size());
|
||||
for (auto tenant : t.tenants.get()) {
|
||||
c.tenants.push_back_deep(c.arena, tenant);
|
||||
|
@ -265,7 +269,7 @@ TEST_CASE("/fdbrpc/authz/TokenCache/BadTokens") {
|
|||
},
|
||||
{
|
||||
[](Arena&, IRandom& rng, authz::jwt::TokenRef& token) {
|
||||
token.expiresAtUnixTime = uint64_t(std::max<double>(g_network->timer() - 10 - rng.random01() * 50, 0));
|
||||
token.expiresAtUnixTime = std::max<double>(g_network->timer() - 10 - rng.random01() * 50, 0);
|
||||
},
|
||||
"ExpiredToken",
|
||||
},
|
||||
|
@ -275,10 +279,15 @@ TEST_CASE("/fdbrpc/authz/TokenCache/BadTokens") {
|
|||
},
|
||||
{
|
||||
[](Arena&, IRandom& rng, authz::jwt::TokenRef& token) {
|
||||
token.notBeforeUnixTime = uint64_t(g_network->timer() + 10 + rng.random01() * 50);
|
||||
token.notBeforeUnixTime = g_network->timer() + 10 + rng.random01() * 50;
|
||||
},
|
||||
"TokenNotYetValid",
|
||||
},
|
||||
{
|
||||
[](Arena&, IRandom&, authz::jwt::TokenRef& token) { token.issuedAtUnixTime.reset(); },
|
||||
"NoIssuedAt",
|
||||
},
|
||||
|
||||
{
|
||||
[](Arena& arena, IRandom&, authz::jwt::TokenRef& token) { token.tenants.reset(); },
|
||||
"NoTenants",
|
||||
|
@ -336,7 +345,7 @@ TEST_CASE("/fdbrpc/authz/TokenCache/GoodTokens") {
|
|||
authz::jwt::makeRandomTokenSpec(arena, *deterministicRandom(), authz::Algorithm::ES256);
|
||||
state StringRef signedToken;
|
||||
FlowTransport::transport().addPublicKey(pubKeyName, privateKey.toPublic());
|
||||
tokenSpec.expiresAtUnixTime = static_cast<uint64_t>(g_network->timer() + 2.0);
|
||||
tokenSpec.expiresAtUnixTime = g_network->timer() + 2.0;
|
||||
tokenSpec.keyId = pubKeyName;
|
||||
signedToken = authz::jwt::signToken(arena, tokenSpec, privateKey);
|
||||
if (!TokenCache::instance().validate(tokenSpec.tenants.get()[0], signedToken)) {
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "flow/network.h"
|
||||
#include "flow/serialize.h"
|
||||
#include "flow/Arena.h"
|
||||
#include "flow/AutoCPointer.h"
|
||||
#include "flow/Error.h"
|
||||
#include "flow/IRandom.h"
|
||||
#include "flow/MkCert.h"
|
||||
|
@ -30,6 +31,7 @@
|
|||
#include "flow/Trace.h"
|
||||
#include "flow/UnitTest.h"
|
||||
#include <fmt/format.h>
|
||||
#include <cmath>
|
||||
#include <iterator>
|
||||
#include <string_view>
|
||||
#include <type_traits>
|
||||
|
@ -87,6 +89,51 @@ bool checkSignAlgorithm(PKeyAlgorithm algo, PrivateKey key) {
|
|||
}
|
||||
}
|
||||
|
||||
Optional<StringRef> convertEs256P1363ToDer(Arena& arena, StringRef p1363) {
|
||||
const int SIGLEN = p1363.size();
|
||||
const int HALF_SIGLEN = SIGLEN / 2;
|
||||
auto r = AutoCPointer(BN_bin2bn(p1363.begin(), HALF_SIGLEN, nullptr), &::BN_free);
|
||||
auto s = AutoCPointer(BN_bin2bn(p1363.begin() + HALF_SIGLEN, HALF_SIGLEN, nullptr), &::BN_free);
|
||||
if (!r || !s)
|
||||
return {};
|
||||
auto sig = AutoCPointer(::ECDSA_SIG_new(), &ECDSA_SIG_free);
|
||||
if (!sig)
|
||||
return {};
|
||||
::ECDSA_SIG_set0(sig, r.release(), s.release());
|
||||
auto const derLen = ::i2d_ECDSA_SIG(sig, nullptr);
|
||||
if (derLen < 0)
|
||||
return {};
|
||||
auto buf = new (arena) uint8_t[derLen];
|
||||
auto bufPtr = buf;
|
||||
::i2d_ECDSA_SIG(sig, &bufPtr);
|
||||
return StringRef(buf, derLen);
|
||||
}
|
||||
|
||||
Optional<StringRef> convertEs256DerToP1363(Arena& arena, StringRef der) {
|
||||
uint8_t const* derPtr = der.begin();
|
||||
auto sig = AutoCPointer(::d2i_ECDSA_SIG(nullptr, &derPtr, der.size()), &::ECDSA_SIG_free);
|
||||
if (!sig) {
|
||||
return {};
|
||||
}
|
||||
// ES256-specific constant. Adapt as needed
|
||||
constexpr const int SIGLEN = 64;
|
||||
constexpr const int HALF_SIGLEN = SIGLEN / 2;
|
||||
auto buf = new (arena) uint8_t[SIGLEN];
|
||||
::memset(buf, 0, SIGLEN);
|
||||
auto bufr = buf;
|
||||
auto bufs = bufr + HALF_SIGLEN;
|
||||
auto r = std::add_pointer_t<BIGNUM const>();
|
||||
auto s = std::add_pointer_t<BIGNUM const>();
|
||||
ECDSA_SIG_get0(sig, &r, &s);
|
||||
auto const lenr = BN_num_bytes(r);
|
||||
auto const lens = BN_num_bytes(s);
|
||||
if (lenr > HALF_SIGLEN || lens > HALF_SIGLEN)
|
||||
return {};
|
||||
BN_bn2bin(r, bufr + (HALF_SIGLEN - lenr));
|
||||
BN_bn2bin(s, bufs + (HALF_SIGLEN - lens));
|
||||
return StringRef(buf, SIGLEN);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace authz {
|
||||
|
@ -130,11 +177,7 @@ SignedTokenRef signToken(Arena& arena, TokenRef token, StringRef keyName, Privat
|
|||
auto writer = ObjectWriter([&arena](size_t len) { return new (arena) uint8_t[len]; }, IncludeVersion());
|
||||
writer.serialize(token);
|
||||
auto tokenStr = writer.toStringRef();
|
||||
auto [signAlgo, digest] = getMethod(Algorithm::ES256);
|
||||
if (!checkSignAlgorithm(signAlgo, privateKey)) {
|
||||
throw digital_signature_ops_error();
|
||||
}
|
||||
auto sig = privateKey.sign(arena, tokenStr, *digest);
|
||||
auto sig = privateKey.sign(arena, tokenStr, *::EVP_sha256());
|
||||
ret.token = tokenStr;
|
||||
ret.signature = sig;
|
||||
ret.keyName = StringRef(arena, keyName);
|
||||
|
@ -142,10 +185,7 @@ SignedTokenRef signToken(Arena& arena, TokenRef token, StringRef keyName, Privat
|
|||
}
|
||||
|
||||
bool verifyToken(SignedTokenRef signedToken, PublicKey publicKey) {
|
||||
auto [keyAlg, digest] = getMethod(Algorithm::ES256);
|
||||
if (!checkVerifyAlgorithm(keyAlg, publicKey))
|
||||
return false;
|
||||
return publicKey.verify(signedToken.token, signedToken.signature, *digest);
|
||||
return publicKey.verify(signedToken.token, signedToken.signature, *::EVP_sha256());
|
||||
}
|
||||
|
||||
TokenRef makeRandomTokenSpec(Arena& arena, IRandom& rng) {
|
||||
|
@ -268,6 +308,17 @@ StringRef signToken(Arena& arena, TokenRef tokenSpec, PrivateKey privateKey) {
|
|||
throw digital_signature_ops_error();
|
||||
}
|
||||
auto plainSig = privateKey.sign(tmpArena, tokenPart, *digest);
|
||||
if (tokenSpec.algorithm == Algorithm::ES256) {
|
||||
// Need to convert ASN.1/DER signature to IEEE-P1363
|
||||
auto convertedSig = convertEs256DerToP1363(tmpArena, plainSig);
|
||||
if (!convertedSig.present()) {
|
||||
auto tmpArena = Arena();
|
||||
TraceEvent(SevWarn, "TokenSigConversionFailure")
|
||||
.detail("TokenSpec", tokenSpec.toStringRef(tmpArena).toString());
|
||||
throw digital_signature_ops_error();
|
||||
}
|
||||
plainSig = convertedSig.get();
|
||||
}
|
||||
auto const sigPartLen = base64url::encodedLength(plainSig.size());
|
||||
auto const totalLen = tokenPart.size() + 1 + sigPartLen;
|
||||
auto out = new (arena) uint8_t[totalLen];
|
||||
|
@ -335,9 +386,9 @@ bool parseField(Arena& arena, Optional<FieldType>& out, const rapidjson::Documen
|
|||
return false;
|
||||
out = StringRef(arena, reinterpret_cast<const uint8_t*>(field.GetString()), field.GetStringLength());
|
||||
} else if constexpr (std::is_same_v<FieldType, uint64_t>) {
|
||||
if (!field.IsUint64())
|
||||
if (!field.IsNumber())
|
||||
return false;
|
||||
out = field.GetUint64();
|
||||
out = static_cast<uint64_t>(field.GetDouble());
|
||||
} else {
|
||||
if (!field.IsArray())
|
||||
return false;
|
||||
|
@ -442,13 +493,17 @@ bool verifyToken(StringRef signedToken, PublicKey publicKey) {
|
|||
auto [verifyAlgo, digest] = getMethod(parsedToken.algorithm);
|
||||
if (!checkVerifyAlgorithm(verifyAlgo, publicKey))
|
||||
return false;
|
||||
if (parsedToken.algorithm == Algorithm::ES256) {
|
||||
// Need to convert IEEE-P1363 signature to ASN.1/DER
|
||||
auto convertedSig = convertEs256P1363ToDer(arena, sig);
|
||||
if (!convertedSig.present())
|
||||
return false;
|
||||
sig = convertedSig.get();
|
||||
}
|
||||
return publicKey.verify(b64urlTokenPart, sig, *digest);
|
||||
}
|
||||
|
||||
TokenRef makeRandomTokenSpec(Arena& arena, IRandom& rng, Algorithm alg) {
|
||||
if (alg != Algorithm::ES256) {
|
||||
throw unsupported_operation();
|
||||
}
|
||||
auto ret = TokenRef{};
|
||||
ret.algorithm = alg;
|
||||
ret.keyId = genRandomAlphanumStringRef(arena, rng, MaxKeyNameLenPlus1);
|
||||
|
@ -460,7 +515,7 @@ TokenRef makeRandomTokenSpec(Arena& arena, IRandom& rng, Algorithm alg) {
|
|||
for (auto i = 0; i < numAudience; i++)
|
||||
aud[i] = genRandomAlphanumStringRef(arena, rng, MaxTenantNameLenPlus1);
|
||||
ret.audience = VectorRef<StringRef>(aud, numAudience);
|
||||
ret.issuedAtUnixTime = uint64_t(std::floor(g_network->timer()));
|
||||
ret.issuedAtUnixTime = g_network->timer();
|
||||
ret.notBeforeUnixTime = ret.issuedAtUnixTime.get();
|
||||
ret.expiresAtUnixTime = ret.issuedAtUnixTime.get() + rng.randomInt(360, 1080 + 1);
|
||||
auto numTenants = rng.randomInt(1, 3);
|
||||
|
@ -569,51 +624,68 @@ TEST_CASE("/fdbrpc/TokenSign/JWT/ToStringRef") {
|
|||
}
|
||||
|
||||
TEST_CASE("/fdbrpc/TokenSign/bench") {
|
||||
constexpr auto repeat = 5;
|
||||
constexpr auto numSamples = 10000;
|
||||
auto keys = std::vector<PrivateKey>(numSamples);
|
||||
auto pubKeys = std::vector<PublicKey>(numSamples);
|
||||
for (auto i = 0; i < numSamples; i++) {
|
||||
keys[i] = mkcert::makeEcP256();
|
||||
pubKeys[i] = keys[i].toPublic();
|
||||
}
|
||||
fmt::print("{} keys generated\n", numSamples);
|
||||
auto& rng = *deterministicRandom();
|
||||
auto arena = Arena();
|
||||
auto jwts = new (arena) StringRef[numSamples];
|
||||
auto fbs = new (arena) StringRef[numSamples];
|
||||
{
|
||||
auto tmpArena = Arena();
|
||||
auto keyTypes = std::array<StringRef, 2>{ "EC"_sr, "RSA"_sr };
|
||||
for (auto kty : keyTypes) {
|
||||
constexpr auto repeat = 5;
|
||||
constexpr auto numSamples = 10000;
|
||||
fmt::print("=== {} keys case\n", kty.toString());
|
||||
auto key = kty == "EC"_sr ? mkcert::makeEcP256() : mkcert::makeRsa4096Bit();
|
||||
auto pubKey = key.toPublic();
|
||||
auto& rng = *deterministicRandom();
|
||||
auto arena = Arena();
|
||||
auto jwtSpecs = new (arena) authz::jwt::TokenRef[numSamples];
|
||||
auto fbSpecs = new (arena) authz::flatbuffers::TokenRef[numSamples];
|
||||
auto jwts = new (arena) StringRef[numSamples];
|
||||
auto fbs = new (arena) StringRef[numSamples];
|
||||
for (auto i = 0; i < numSamples; i++) {
|
||||
auto jwtSpec = authz::jwt::makeRandomTokenSpec(tmpArena, rng, authz::Algorithm::ES256);
|
||||
jwts[i] = authz::jwt::signToken(arena, jwtSpec, keys[i]);
|
||||
auto fbSpec = authz::flatbuffers::makeRandomTokenSpec(tmpArena, rng);
|
||||
auto fbToken = authz::flatbuffers::signToken(tmpArena, fbSpec, "defaultKey"_sr, keys[i]);
|
||||
auto wr = ObjectWriter([&arena](size_t len) { return new (arena) uint8_t[len]; }, Unversioned());
|
||||
wr.serialize(fbToken);
|
||||
fbs[i] = wr.toStringRef();
|
||||
jwtSpecs[i] = authz::jwt::makeRandomTokenSpec(
|
||||
arena, rng, kty == "EC"_sr ? authz::Algorithm::ES256 : authz::Algorithm::RS256);
|
||||
fbSpecs[i] = authz::flatbuffers::makeRandomTokenSpec(arena, rng);
|
||||
}
|
||||
{
|
||||
auto const jwtSignBegin = timer_monotonic();
|
||||
for (auto i = 0; i < numSamples; i++) {
|
||||
jwts[i] = authz::jwt::signToken(arena, jwtSpecs[i], key);
|
||||
}
|
||||
auto const jwtSignEnd = timer_monotonic();
|
||||
fmt::print("JWT Sign : {:.2f} OPS\n", numSamples / (jwtSignEnd - jwtSignBegin));
|
||||
}
|
||||
{
|
||||
auto const jwtVerifyBegin = timer_monotonic();
|
||||
for (auto rep = 0; rep < repeat; rep++) {
|
||||
for (auto i = 0; i < numSamples; i++) {
|
||||
auto verifyOk = authz::jwt::verifyToken(jwts[i], pubKey);
|
||||
ASSERT(verifyOk);
|
||||
}
|
||||
}
|
||||
auto const jwtVerifyEnd = timer_monotonic();
|
||||
fmt::print("JWT Verify : {:.2f} OPS\n", repeat * numSamples / (jwtVerifyEnd - jwtVerifyBegin));
|
||||
}
|
||||
{
|
||||
auto tmpArena = Arena();
|
||||
auto const fbSignBegin = timer_monotonic();
|
||||
for (auto i = 0; i < numSamples; i++) {
|
||||
auto fbToken = authz::flatbuffers::signToken(tmpArena, fbSpecs[i], "defaultKey"_sr, key);
|
||||
auto wr = ObjectWriter([&arena](size_t len) { return new (arena) uint8_t[len]; }, Unversioned());
|
||||
wr.serialize(fbToken);
|
||||
fbs[i] = wr.toStringRef();
|
||||
}
|
||||
auto const fbSignEnd = timer_monotonic();
|
||||
fmt::print("FlatBuffers Sign : {:.2f} OPS\n", numSamples / (fbSignEnd - fbSignBegin));
|
||||
}
|
||||
{
|
||||
auto const fbVerifyBegin = timer_monotonic();
|
||||
for (auto rep = 0; rep < repeat; rep++) {
|
||||
for (auto i = 0; i < numSamples; i++) {
|
||||
auto signedToken = ObjectReader::fromStringRef<Standalone<authz::flatbuffers::SignedTokenRef>>(
|
||||
fbs[i], Unversioned());
|
||||
auto verifyOk = authz::flatbuffers::verifyToken(signedToken, pubKey);
|
||||
ASSERT(verifyOk);
|
||||
}
|
||||
}
|
||||
auto const fbVerifyEnd = timer_monotonic();
|
||||
fmt::print("FlatBuffers Verify : {:.2f} OPS\n", repeat * numSamples / (fbVerifyEnd - fbVerifyBegin));
|
||||
}
|
||||
}
|
||||
fmt::print("{} FB/JWT tokens generated\n", numSamples);
|
||||
auto jwtBegin = timer_monotonic();
|
||||
for (auto rep = 0; rep < repeat; rep++) {
|
||||
for (auto i = 0; i < numSamples; i++) {
|
||||
auto verifyOk = authz::jwt::verifyToken(jwts[i], pubKeys[i]);
|
||||
ASSERT(verifyOk);
|
||||
}
|
||||
}
|
||||
auto jwtEnd = timer_monotonic();
|
||||
fmt::print("JWT: {:.2f} OPS\n", repeat * numSamples / (jwtEnd - jwtBegin));
|
||||
auto fbBegin = timer_monotonic();
|
||||
for (auto rep = 0; rep < repeat; rep++) {
|
||||
for (auto i = 0; i < numSamples; i++) {
|
||||
auto signedToken =
|
||||
ObjectReader::fromStringRef<Standalone<authz::flatbuffers::SignedTokenRef>>(fbs[i], Unversioned());
|
||||
auto verifyOk = authz::flatbuffers::verifyToken(signedToken, pubKeys[i]);
|
||||
ASSERT(verifyOk);
|
||||
}
|
||||
}
|
||||
auto fbEnd = timer_monotonic();
|
||||
fmt::print("FlatBuffers: {:.2f} OPS\n", repeat * numSamples / (fbEnd - fbBegin));
|
||||
return Void();
|
||||
}
|
||||
|
|
|
@ -298,6 +298,12 @@ public:
|
|||
void removePublicKey(StringRef name);
|
||||
void removeAllPublicKeys();
|
||||
|
||||
// Synchronously load and apply JWKS (RFC 7517) public key file with which to verify authorization tokens.
|
||||
void loadPublicKeyFile(const std::string& publicKeyFilePath);
|
||||
|
||||
// Periodically read JWKS (RFC 7517) public key file to refresh public key set.
|
||||
void watchPublicKeyFile(const std::string& publicKeyFilePath);
|
||||
|
||||
private:
|
||||
class TransportData* self;
|
||||
};
|
||||
|
|
|
@ -47,6 +47,7 @@ public:
|
|||
int read(uint8_t* begin, uint8_t* end) override;
|
||||
int write(SendBuffer const* buffer, int limit) override;
|
||||
NetworkAddress getPeerAddress() const override;
|
||||
bool hasTrustedPeer() const override;
|
||||
UID getDebugID() const override;
|
||||
boost::asio::ip::tcp::socket& getSocket() override { return socket; }
|
||||
static Future<std::vector<NetworkAddress>> resolveTCPEndpoint(const std::string& host,
|
||||
|
|
|
@ -208,7 +208,7 @@ SimClogging g_clogging;
|
|||
|
||||
struct Sim2Conn final : IConnection, ReferenceCounted<Sim2Conn> {
|
||||
Sim2Conn(ISimulator::ProcessInfo* process)
|
||||
: opened(false), closedByCaller(false), stableConnection(false), process(process),
|
||||
: opened(false), closedByCaller(false), stableConnection(false), trustedPeer(true), process(process),
|
||||
dbgid(deterministicRandom()->randomUniqueID()), stopReceive(Never()) {
|
||||
pipes = sender(this) && receiver(this);
|
||||
}
|
||||
|
@ -259,6 +259,8 @@ struct Sim2Conn final : IConnection, ReferenceCounted<Sim2Conn> {
|
|||
|
||||
bool isPeerGone() const { return !peer || peerProcess->failed; }
|
||||
|
||||
bool hasTrustedPeer() const override { return trustedPeer; }
|
||||
|
||||
bool isStableConnection() const override { return stableConnection; }
|
||||
|
||||
void peerClosed() {
|
||||
|
@ -327,7 +329,7 @@ struct Sim2Conn final : IConnection, ReferenceCounted<Sim2Conn> {
|
|||
|
||||
boost::asio::ip::tcp::socket& getSocket() override { throw operation_failed(); }
|
||||
|
||||
bool opened, closedByCaller, stableConnection;
|
||||
bool opened, closedByCaller, stableConnection, trustedPeer;
|
||||
|
||||
private:
|
||||
ISimulator::ProcessInfo *process, *peerProcess;
|
||||
|
|
|
@ -0,0 +1,357 @@
|
|||
/*
|
||||
* AuthzTlsTest.cpp
|
||||
*
|
||||
* This source file is part of the FoundationDB open source project
|
||||
*
|
||||
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef _WIN32
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <fmt/format.h>
|
||||
#include <unistd.h>
|
||||
#include <string_view>
|
||||
#include <signal.h>
|
||||
#include <sys/wait.h>
|
||||
#include "flow/Arena.h"
|
||||
#include "flow/MkCert.h"
|
||||
#include "flow/ScopeExit.h"
|
||||
#include "flow/TLSConfig.actor.h"
|
||||
#include "fdbrpc/fdbrpc.h"
|
||||
#include "fdbrpc/FlowTransport.h"
|
||||
#include "flow/actorcompiler.h" // This must be the last #include.
|
||||
|
||||
std::FILE* outp = stdout;
|
||||
|
||||
template <class... Args>
|
||||
void log(Args&&... args) {
|
||||
auto buf = fmt::memory_buffer{};
|
||||
fmt::format_to(std::back_inserter(buf), std::forward<Args>(args)...);
|
||||
fmt::print(outp, "{}\n", std::string_view(buf.data(), buf.size()));
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
void logc(Args&&... args) {
|
||||
auto buf = fmt::memory_buffer{};
|
||||
fmt::format_to(std::back_inserter(buf), "[CLIENT] ");
|
||||
fmt::format_to(std::back_inserter(buf), std::forward<Args>(args)...);
|
||||
fmt::print(outp, "{}\n", std::string_view(buf.data(), buf.size()));
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
void logs(Args&&... args) {
|
||||
auto buf = fmt::memory_buffer{};
|
||||
fmt::format_to(std::back_inserter(buf), "[SERVER] ");
|
||||
fmt::format_to(std::back_inserter(buf), std::forward<Args>(args)...);
|
||||
fmt::print(outp, "{}\n", std::string_view(buf.data(), buf.size()));
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
void logm(Args&&... args) {
|
||||
auto buf = fmt::memory_buffer{};
|
||||
fmt::format_to(std::back_inserter(buf), "[ MAIN ] ");
|
||||
fmt::format_to(std::back_inserter(buf), std::forward<Args>(args)...);
|
||||
fmt::print(outp, "{}\n", std::string_view(buf.data(), buf.size()));
|
||||
}
|
||||
|
||||
struct TLSCreds {
|
||||
std::string certBytes;
|
||||
std::string keyBytes;
|
||||
std::string caBytes;
|
||||
};
|
||||
|
||||
TLSCreds makeCreds(int chainLen, mkcert::ESide side) {
|
||||
if (chainLen == 0)
|
||||
return {};
|
||||
auto arena = Arena();
|
||||
auto ret = TLSCreds{};
|
||||
auto specs = mkcert::makeCertChainSpec(arena, std::labs(chainLen), side);
|
||||
if (chainLen < 0) {
|
||||
specs[0].offsetNotBefore = -60l * 60 * 24 * 365;
|
||||
specs[0].offsetNotAfter = -10l; // cert that expired 10 seconds ago
|
||||
}
|
||||
auto chain = mkcert::makeCertChain(arena, specs, {} /* create root CA cert from spec*/);
|
||||
if (chain.size() == 1) {
|
||||
ret.certBytes = concatCertChain(arena, chain).toString();
|
||||
} else {
|
||||
auto nonRootChain = chain;
|
||||
nonRootChain.pop_back();
|
||||
ret.certBytes = concatCertChain(arena, nonRootChain).toString();
|
||||
}
|
||||
ret.caBytes = chain.back().certPem.toString();
|
||||
ret.keyBytes = chain.front().privateKeyPem.toString();
|
||||
return ret;
|
||||
}
|
||||
|
||||
enum class Result : int {
|
||||
TRUSTED = 0,
|
||||
UNTRUSTED,
|
||||
ERROR,
|
||||
};
|
||||
|
||||
template <>
|
||||
struct fmt::formatter<Result> {
|
||||
constexpr auto parse(format_parse_context& ctx) -> decltype(ctx.begin()) { return ctx.begin(); }
|
||||
|
||||
template <class FormatContext>
|
||||
auto format(const Result& r, FormatContext& ctx) -> decltype(ctx.out()) {
|
||||
if (r == Result::TRUSTED)
|
||||
return fmt::format_to(ctx.out(), "TRUSTED");
|
||||
else if (r == Result::UNTRUSTED)
|
||||
return fmt::format_to(ctx.out(), "UNTRUSTED");
|
||||
else
|
||||
return fmt::format_to(ctx.out(), "ERROR");
|
||||
}
|
||||
};
|
||||
|
||||
ACTOR template <class T>
|
||||
Future<T> stopNetworkAfter(Future<T> what) {
|
||||
T t = wait(what);
|
||||
g_network->stop();
|
||||
return t;
|
||||
}
|
||||
|
||||
// Reflective struct containing information about the requester from a server PoV
|
||||
struct SessionInfo {
|
||||
constexpr static FileIdentifier file_identifier = 1578312;
|
||||
bool isPeerTrusted = false;
|
||||
NetworkAddress peerAddress;
|
||||
|
||||
template <class Ar>
|
||||
void serialize(Ar& ar) {
|
||||
serializer(ar, isPeerTrusted, peerAddress);
|
||||
}
|
||||
};
|
||||
|
||||
struct SessionProbeRequest {
|
||||
constexpr static FileIdentifier file_identifier = 1559713;
|
||||
ReplyPromise<SessionInfo> reply{ PeerCompatibilityPolicy{ RequirePeer::AtLeast,
|
||||
ProtocolVersion::withStableInterfaces() } };
|
||||
|
||||
bool verify() const { return true; }
|
||||
|
||||
template <class Ar>
|
||||
void serialize(Ar& ar) {
|
||||
serializer(ar, reply);
|
||||
}
|
||||
};
|
||||
|
||||
struct SessionProbeReceiver final : NetworkMessageReceiver {
|
||||
SessionProbeReceiver() {}
|
||||
void receive(ArenaObjectReader& reader) override {
|
||||
SessionProbeRequest req;
|
||||
reader.deserialize(req);
|
||||
SessionInfo res;
|
||||
res.isPeerTrusted = FlowTransport::transport().currentDeliveryPeerIsTrusted();
|
||||
res.peerAddress = FlowTransport::transport().currentDeliveryPeerAddress();
|
||||
req.reply.send(res);
|
||||
}
|
||||
PeerCompatibilityPolicy peerCompatibilityPolicy() const override {
|
||||
return PeerCompatibilityPolicy{ RequirePeer::AtLeast, ProtocolVersion::withStableInterfaces() };
|
||||
}
|
||||
bool isPublic() const override { return true; }
|
||||
};
|
||||
|
||||
Future<Void> runServer(Future<Void> listenFuture, const Endpoint& endpoint, int addrPipe, int completionPipe) {
|
||||
auto realAddr = FlowTransport::transport().getLocalAddresses().address;
|
||||
logs("Listening at {}", realAddr.toString());
|
||||
logs("Endpoint token is {}", endpoint.token.toString());
|
||||
// below writes/reads would block, but this is good enough for a test.
|
||||
if (sizeof(realAddr) != ::write(addrPipe, &realAddr, sizeof(realAddr))) {
|
||||
logs("Failed to write server addr to pipe: {}", strerror(errno));
|
||||
return Void();
|
||||
}
|
||||
if (sizeof(endpoint.token) != ::write(addrPipe, &endpoint.token, sizeof(endpoint.token))) {
|
||||
logs("Failed to write server endpoint to pipe: {}", strerror(errno));
|
||||
return Void();
|
||||
}
|
||||
auto done = false;
|
||||
if (sizeof(done) != ::read(completionPipe, &done, sizeof(done))) {
|
||||
logs("Failed to read completion flag from pipe: {}", strerror(errno));
|
||||
return Void();
|
||||
}
|
||||
return Void();
|
||||
}
|
||||
|
||||
ACTOR Future<Void> waitAndPrintResponse(Future<SessionInfo> response, Result* rc) {
|
||||
try {
|
||||
SessionInfo info = wait(response);
|
||||
logc("Probe response: trusted={} peerAddress={}", info.isPeerTrusted, info.peerAddress.toString());
|
||||
*rc = info.isPeerTrusted ? Result::TRUSTED : Result::UNTRUSTED;
|
||||
} catch (Error& err) {
|
||||
logc("Error: {}", err.what());
|
||||
*rc = Result::ERROR;
|
||||
}
|
||||
return Void();
|
||||
}
|
||||
|
||||
template <bool IsServer>
|
||||
int runHost(TLSCreds creds, int addrPipe, int completionPipe, Result expect) {
|
||||
auto tlsConfig = TLSConfig(IsServer ? TLSEndpointType::SERVER : TLSEndpointType::CLIENT);
|
||||
tlsConfig.setCertificateBytes(creds.certBytes);
|
||||
tlsConfig.setCABytes(creds.caBytes);
|
||||
tlsConfig.setKeyBytes(creds.keyBytes);
|
||||
g_network = newNet2(tlsConfig);
|
||||
openTraceFile(NetworkAddress(),
|
||||
10 << 20,
|
||||
10 << 20,
|
||||
".",
|
||||
IsServer ? "authz_tls_unittest_server" : "authz_tls_unittest_client");
|
||||
FlowTransport::createInstance(!IsServer, 1, WLTOKEN_RESERVED_COUNT);
|
||||
auto& transport = FlowTransport::transport();
|
||||
if constexpr (IsServer) {
|
||||
auto addr = NetworkAddress::parse("127.0.0.1:0:tls");
|
||||
auto thread = std::thread([]() {
|
||||
g_network->run();
|
||||
flushTraceFileVoid();
|
||||
});
|
||||
auto endpoint = Endpoint();
|
||||
auto receiver = SessionProbeReceiver();
|
||||
transport.addEndpoint(endpoint, &receiver, TaskPriority::ReadSocket);
|
||||
runServer(transport.bind(addr, addr), endpoint, addrPipe, completionPipe);
|
||||
auto cleanupGuard = ScopeExit([&thread]() {
|
||||
g_network->stop();
|
||||
thread.join();
|
||||
});
|
||||
} else {
|
||||
auto dest = Endpoint();
|
||||
auto& serverAddr = dest.addresses.address;
|
||||
if (sizeof(serverAddr) != ::read(addrPipe, &serverAddr, sizeof(serverAddr))) {
|
||||
logc("Failed to read server addr from pipe: {}", strerror(errno));
|
||||
return 1;
|
||||
}
|
||||
auto& token = dest.token;
|
||||
if (sizeof(token) != ::read(addrPipe, &token, sizeof(token))) {
|
||||
logc("Failed to read server endpoint token from pipe: {}", strerror(errno));
|
||||
return 2;
|
||||
}
|
||||
logc("Server address is {}", serverAddr.toString());
|
||||
logc("Server endpoint token is {}", token.toString());
|
||||
auto sessionProbeReq = SessionProbeRequest{};
|
||||
transport.sendUnreliable(SerializeSource(sessionProbeReq), dest, true /*openConnection*/);
|
||||
logc("Request is sent");
|
||||
auto probeResponse = sessionProbeReq.reply.getFuture();
|
||||
auto result = Result::TRUSTED;
|
||||
auto timeout = delay(5);
|
||||
auto complete = waitAndPrintResponse(probeResponse, &result);
|
||||
auto f = stopNetworkAfter(complete || timeout);
|
||||
auto rc = 0;
|
||||
g_network->run();
|
||||
if (!complete.isReady()) {
|
||||
logc("Error: Probe request timed out");
|
||||
rc = 3;
|
||||
}
|
||||
auto done = true;
|
||||
if (sizeof(done) != ::write(completionPipe, &done, sizeof(done))) {
|
||||
logc("Failed to signal server to terminate: {}", strerror(errno));
|
||||
rc = 4;
|
||||
}
|
||||
if (rc == 0) {
|
||||
if (expect != result) {
|
||||
logc("Test failed: expected {}, got {}", expect, result);
|
||||
rc = 5;
|
||||
} else {
|
||||
logc("Response OK: got {} as expected", result);
|
||||
}
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int runTlsTest(int serverChainLen, int clientChainLen) {
|
||||
log("==== BEGIN TESTCASE ====");
|
||||
auto expect = Result::ERROR;
|
||||
if (serverChainLen > 0) {
|
||||
if (clientChainLen > 0)
|
||||
expect = Result::TRUSTED;
|
||||
else if (clientChainLen == 0)
|
||||
expect = Result::UNTRUSTED;
|
||||
}
|
||||
log("Cert chain length: server={} client={}", serverChainLen, clientChainLen);
|
||||
auto arena = Arena();
|
||||
auto serverCreds = makeCreds(serverChainLen, mkcert::ESide::Server);
|
||||
auto clientCreds = makeCreds(clientChainLen, mkcert::ESide::Client);
|
||||
// make server and client trust each other
|
||||
std::swap(serverCreds.caBytes, clientCreds.caBytes);
|
||||
auto clientPid = pid_t{};
|
||||
auto serverPid = pid_t{};
|
||||
int addrPipe[2];
|
||||
int completionPipe[2];
|
||||
if (::pipe(addrPipe) || ::pipe(completionPipe)) {
|
||||
logm("Pipe open failed: {}", strerror(errno));
|
||||
return 1;
|
||||
}
|
||||
auto pipeCleanup = ScopeExit([&addrPipe, &completionPipe]() {
|
||||
::close(addrPipe[0]);
|
||||
::close(addrPipe[1]);
|
||||
::close(completionPipe[0]);
|
||||
::close(completionPipe[1]);
|
||||
});
|
||||
serverPid = fork();
|
||||
if (serverPid == 0) {
|
||||
_exit(runHost<true>(std::move(serverCreds), addrPipe[1], completionPipe[0], expect));
|
||||
}
|
||||
clientPid = fork();
|
||||
if (clientPid == 0) {
|
||||
_exit(runHost<false>(std::move(clientCreds), addrPipe[0], completionPipe[1], expect));
|
||||
}
|
||||
auto pid = pid_t{};
|
||||
auto status = int{};
|
||||
pid = waitpid(clientPid, &status, 0);
|
||||
auto ok = true;
|
||||
if (pid < 0) {
|
||||
logm("waitpid() for client failed with {}", strerror(errno));
|
||||
ok = false;
|
||||
} else {
|
||||
if (status != 0) {
|
||||
logm("Client error: rc={}", status);
|
||||
ok = false;
|
||||
} else {
|
||||
logm("Client OK");
|
||||
}
|
||||
}
|
||||
pid = waitpid(serverPid, &status, 0);
|
||||
if (pid < 0) {
|
||||
logm("waitpid() for server failed with {}", strerror(errno));
|
||||
ok = false;
|
||||
} else {
|
||||
if (status != 0) {
|
||||
logm("Server error: rc={}", status);
|
||||
ok = false;
|
||||
} else {
|
||||
logm("Server OK");
|
||||
}
|
||||
}
|
||||
log(ok ? "OK" : "FAILED");
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main() {
|
||||
std::pair<int, int> inputs[] = { { 3, 2 }, { 4, 0 }, { 1, 3 }, { 1, 0 }, { 2, 0 }, { 3, 3 }, { 3, 0 } };
|
||||
for (auto input : inputs) {
|
||||
auto [serverChainLen, clientChainLen] = input;
|
||||
if (auto rc = runTlsTest(serverChainLen, clientChainLen))
|
||||
return rc;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
#else // _WIN32
|
||||
|
||||
int main() {
|
||||
return 0;
|
||||
}
|
||||
#endif // _WIN32
|
|
@ -0,0 +1,6 @@
|
|||
if(NOT WIN32)
|
||||
add_flow_target(EXECUTABLE NAME authz_tls_unittest SRCS AuthzTlsTest.actor.cpp)
|
||||
target_link_libraries(authz_tls_unittest PRIVATE flow fdbrpc fmt::fmt)
|
||||
add_test(NAME authorization_tls_unittest
|
||||
COMMAND $<TARGET_FILE:authz_tls_unittest>)
|
||||
endif()
|
|
@ -113,7 +113,7 @@ enum {
|
|||
OPT_METRICSPREFIX, OPT_LOGGROUP, OPT_LOCALITY, OPT_IO_TRUST_SECONDS, OPT_IO_TRUST_WARN_ONLY, OPT_FILESYSTEM, OPT_PROFILER_RSS_SIZE, OPT_KVFILE,
|
||||
OPT_TRACE_FORMAT, OPT_WHITELIST_BINPATH, OPT_BLOB_CREDENTIAL_FILE, OPT_CONFIG_PATH, OPT_USE_TEST_CONFIG_DB, OPT_FAULT_INJECTION, OPT_PROFILER, OPT_PRINT_SIMTIME,
|
||||
OPT_FLOW_PROCESS_NAME, OPT_FLOW_PROCESS_ENDPOINT, OPT_IP_TRUSTED_MASK, OPT_KMS_CONN_DISCOVERY_URL_FILE, OPT_KMS_CONNECTOR_TYPE, OPT_KMS_CONN_VALIDATION_TOKEN_DETAILS,
|
||||
OPT_KMS_CONN_GET_ENCRYPTION_KEYS_ENDPOINT, OPT_NEW_CLUSTER_KEY, OPT_USE_FUTURE_PROTOCOL_VERSION
|
||||
OPT_KMS_CONN_GET_ENCRYPTION_KEYS_ENDPOINT, OPT_NEW_CLUSTER_KEY, OPT_AUTHZ_PUBLIC_KEY_FILE, OPT_USE_FUTURE_PROTOCOL_VERSION
|
||||
};
|
||||
|
||||
CSimpleOpt::SOption g_rgOptions[] = {
|
||||
|
@ -128,8 +128,8 @@ CSimpleOpt::SOption g_rgOptions[] = {
|
|||
{ OPT_LISTEN, "-l", SO_REQ_SEP },
|
||||
{ OPT_LISTEN, "--listen-address", SO_REQ_SEP },
|
||||
#ifdef __linux__
|
||||
{ OPT_FILESYSTEM, "--data-filesystem", SO_REQ_SEP },
|
||||
{ OPT_PROFILER_RSS_SIZE, "--rsssize", SO_REQ_SEP },
|
||||
{ OPT_FILESYSTEM, "--data-filesystem", SO_REQ_SEP },
|
||||
{ OPT_PROFILER_RSS_SIZE, "--rsssize", SO_REQ_SEP },
|
||||
#endif
|
||||
{ OPT_DATAFOLDER, "-d", SO_REQ_SEP },
|
||||
{ OPT_DATAFOLDER, "--datadir", SO_REQ_SEP },
|
||||
|
@ -208,6 +208,7 @@ CSimpleOpt::SOption g_rgOptions[] = {
|
|||
{ OPT_FLOW_PROCESS_ENDPOINT, "--process-endpoint", SO_REQ_SEP },
|
||||
{ OPT_IP_TRUSTED_MASK, "--trusted-subnet-", SO_REQ_SEP },
|
||||
{ OPT_NEW_CLUSTER_KEY, "--new-cluster-key", SO_REQ_SEP },
|
||||
{ OPT_AUTHZ_PUBLIC_KEY_FILE, "--authorization-public-key-file", SO_REQ_SEP },
|
||||
{ OPT_KMS_CONN_DISCOVERY_URL_FILE, "--discover-kms-conn-url-file", SO_REQ_SEP },
|
||||
{ OPT_KMS_CONNECTOR_TYPE, "--kms-connector-type", SO_REQ_SEP },
|
||||
{ OPT_KMS_CONN_VALIDATION_TOKEN_DETAILS, "--kms-conn-validation-token-details", SO_REQ_SEP },
|
||||
|
@ -1022,8 +1023,8 @@ enum class ServerRole {
|
|||
};
|
||||
struct CLIOptions {
|
||||
std::string commandLine;
|
||||
std::string fileSystemPath, dataFolder, connFile, seedConnFile, seedConnString, logFolder = ".", metricsConnFile,
|
||||
metricsPrefix, newClusterKey;
|
||||
std::string fileSystemPath, dataFolder, connFile, seedConnFile, seedConnString,
|
||||
logFolder = ".", metricsConnFile, metricsPrefix, newClusterKey, authzPublicKeyFile;
|
||||
std::string logGroup = "default";
|
||||
uint64_t rollsize = TRACE_DEFAULT_ROLL_SIZE;
|
||||
uint64_t maxLogsSize = TRACE_DEFAULT_MAX_LOGS_SIZE;
|
||||
|
@ -1713,6 +1714,10 @@ private:
|
|||
}
|
||||
break;
|
||||
}
|
||||
case OPT_AUTHZ_PUBLIC_KEY_FILE: {
|
||||
authzPublicKeyFile = args.OptionArg();
|
||||
break;
|
||||
}
|
||||
case OPT_USE_FUTURE_PROTOCOL_VERSION: {
|
||||
if (!strcmp(args.OptionArg(), "true")) {
|
||||
::useFutureProtocolVersion();
|
||||
|
@ -2029,6 +2034,16 @@ int main(int argc, char* argv[]) {
|
|||
openTraceFile(
|
||||
opts.publicAddresses.address, opts.rollsize, opts.maxLogsSize, opts.logFolder, "trace", opts.logGroup);
|
||||
g_network->initTLS();
|
||||
if (!opts.authzPublicKeyFile.empty()) {
|
||||
try {
|
||||
FlowTransport::transport().loadPublicKeyFile(opts.authzPublicKeyFile);
|
||||
} catch (Error& e) {
|
||||
TraceEvent("AuthzPublicKeySetLoadError").error(e);
|
||||
}
|
||||
FlowTransport::transport().watchPublicKeyFile(opts.authzPublicKeyFile);
|
||||
} else {
|
||||
TraceEvent(SevInfo, "AuthzPublicKeyFileNotSet");
|
||||
}
|
||||
|
||||
if (expectsPublicAddress) {
|
||||
for (int ii = 0; ii < (opts.publicAddresses.secondaryAddress.present() ? 2 : 1); ++ii) {
|
||||
|
|
|
@ -88,8 +88,3 @@ endif()
|
|||
|
||||
add_executable(mkcert MkCertCli.cpp)
|
||||
target_link_libraries(mkcert PUBLIC flow)
|
||||
|
||||
add_executable(mtls_unittest TLSTest.cpp)
|
||||
target_link_libraries(mtls_unittest PUBLIC flow)
|
||||
add_test(NAME mutual_tls_unittest
|
||||
COMMAND $<TARGET_FILE:mtls_unittest>)
|
||||
|
|
|
@ -129,6 +129,10 @@ void FlowKnobs::initialize(Randomize randomize, IsSimulated isSimulated) {
|
|||
init( NETWORK_TEST_REQUEST_COUNT, 0 ); // 0 -> run forever
|
||||
init( NETWORK_TEST_REQUEST_SIZE, 1 );
|
||||
init( NETWORK_TEST_SCRIPT_MODE, false );
|
||||
|
||||
//Authorization
|
||||
init( PUBLIC_KEY_FILE_MAX_SIZE, 1024 * 1024 );
|
||||
init( PUBLIC_KEY_FILE_REFRESH_INTERVAL_SECONDS, 30 );
|
||||
init( MAX_CACHED_EXPIRED_TOKENS, 1024 );
|
||||
|
||||
//AsyncFileCached
|
||||
|
|
|
@ -166,13 +166,13 @@ PrivateKey makeEcP256() {
|
|||
return PrivateKey(DerEncoded{}, StringRef(buf, len));
|
||||
}
|
||||
|
||||
PrivateKey makeRsa2048Bit() {
|
||||
PrivateKey makeRsa4096Bit() {
|
||||
auto kctx = AutoCPointer(::EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, nullptr), &::EVP_PKEY_CTX_free);
|
||||
OSSL_ASSERT(kctx);
|
||||
auto key = AutoCPointer(nullptr, &::EVP_PKEY_free);
|
||||
auto keyRaw = std::add_pointer_t<EVP_PKEY>();
|
||||
OSSL_ASSERT(0 < ::EVP_PKEY_keygen_init(kctx));
|
||||
OSSL_ASSERT(0 < ::EVP_PKEY_CTX_set_rsa_keygen_bits(kctx, 2048));
|
||||
OSSL_ASSERT(0 < ::EVP_PKEY_CTX_set_rsa_keygen_bits(kctx, 4096));
|
||||
OSSL_ASSERT(0 < ::EVP_PKEY_keygen(kctx, &keyRaw));
|
||||
OSSL_ASSERT(keyRaw);
|
||||
key.reset(keyRaw);
|
||||
|
|
|
@ -50,6 +50,7 @@
|
|||
#include "flow/ProtocolVersion.h"
|
||||
#include "flow/SendBufferIterator.h"
|
||||
#include "flow/TLSConfig.actor.h"
|
||||
#include "flow/WatchFile.actor.h"
|
||||
#include "flow/genericactors.actor.h"
|
||||
#include "flow/Util.h"
|
||||
#include "flow/UnitTest.h"
|
||||
|
@ -238,6 +239,7 @@ public:
|
|||
int sslHandshakerThreadsStarted;
|
||||
int sslPoolHandshakesInProgress;
|
||||
TLSConfig tlsConfig;
|
||||
Reference<TLSPolicy> activeTlsPolicy;
|
||||
Future<Void> backgroundCertRefresh;
|
||||
ETLSInitState tlsInitializedState;
|
||||
|
||||
|
@ -507,6 +509,8 @@ public:
|
|||
|
||||
NetworkAddress getPeerAddress() const override { return peer_address; }
|
||||
|
||||
bool hasTrustedPeer() const override { return true; }
|
||||
|
||||
UID getDebugID() const override { return id; }
|
||||
|
||||
tcp::socket& getSocket() override { return socket; }
|
||||
|
@ -839,7 +843,7 @@ public:
|
|||
explicit SSLConnection(boost::asio::io_service& io_service,
|
||||
Reference<ReferencedObject<boost::asio::ssl::context>> context)
|
||||
: id(nondeterministicRandom()->randomUniqueID()), socket(io_service), ssl_sock(socket, context->mutate()),
|
||||
sslContext(context) {}
|
||||
sslContext(context), has_trusted_peer(false) {}
|
||||
|
||||
explicit SSLConnection(Reference<ReferencedObject<boost::asio::ssl::context>> context, tcp::socket* existingSocket)
|
||||
: id(nondeterministicRandom()->randomUniqueID()), socket(std::move(*existingSocket)),
|
||||
|
@ -900,6 +904,9 @@ public:
|
|||
|
||||
try {
|
||||
Future<Void> onHandshook;
|
||||
ConfigureSSLStream(N2::g_net2->activeTlsPolicy, self->ssl_sock, [this](bool verifyOk) {
|
||||
self->has_trusted_peer = verifyOk;
|
||||
});
|
||||
|
||||
// If the background handshakers are not all busy, use one
|
||||
if (N2::g_net2->sslPoolHandshakesInProgress < N2::g_net2->sslHandshakerThreadsStarted) {
|
||||
|
@ -975,6 +982,10 @@ public:
|
|||
|
||||
try {
|
||||
Future<Void> onHandshook;
|
||||
ConfigureSSLStream(N2::g_net2->activeTlsPolicy, self->ssl_sock, [this](bool verifyOk) {
|
||||
self->has_trusted_peer = verifyOk;
|
||||
});
|
||||
|
||||
// If the background handshakers are not all busy, use one
|
||||
if (N2::g_net2->sslPoolHandshakesInProgress < N2::g_net2->sslHandshakerThreadsStarted) {
|
||||
holder = Hold(&N2::g_net2->sslPoolHandshakesInProgress);
|
||||
|
@ -1108,6 +1119,8 @@ public:
|
|||
|
||||
NetworkAddress getPeerAddress() const override { return peer_address; }
|
||||
|
||||
bool hasTrustedPeer() const override { return has_trusted_peer; }
|
||||
|
||||
UID getDebugID() const override { return id; }
|
||||
|
||||
tcp::socket& getSocket() override { return socket; }
|
||||
|
@ -1120,6 +1133,7 @@ private:
|
|||
ssl_socket ssl_sock;
|
||||
NetworkAddress peer_address;
|
||||
Reference<ReferencedObject<boost::asio::ssl::context>> sslContext;
|
||||
bool has_trusted_peer;
|
||||
|
||||
void init() {
|
||||
// Socket settings that have to be set after connect or accept succeeds
|
||||
|
@ -1165,6 +1179,16 @@ public:
|
|||
NetworkAddress listenAddress)
|
||||
: io_service(io_service), listenAddress(listenAddress), acceptor(io_service, tcpEndpoint(listenAddress)),
|
||||
contextVar(contextVar) {
|
||||
// when port 0 is passed in, a random port will be opened
|
||||
// set listenAddress as the address with the actual port opened instead of port 0
|
||||
if (listenAddress.port == 0) {
|
||||
this->listenAddress = NetworkAddress::parse(acceptor.local_endpoint()
|
||||
.address()
|
||||
.to_string()
|
||||
.append(":")
|
||||
.append(std::to_string(acceptor.local_endpoint().port()))
|
||||
.append(listenAddress.isTLS() ? ":tls" : ""));
|
||||
}
|
||||
platform::setCloseOnExec(acceptor.native_handle());
|
||||
}
|
||||
|
||||
|
@ -1240,45 +1264,11 @@ Net2::Net2(const TLSConfig& tlsConfig, bool useThreadPool, bool useMetrics)
|
|||
updateNow();
|
||||
}
|
||||
|
||||
ACTOR static Future<Void> watchFileForChanges(std::string filename, AsyncTrigger* fileChanged) {
|
||||
if (filename == "") {
|
||||
return Never();
|
||||
}
|
||||
state bool firstRun = true;
|
||||
state bool statError = false;
|
||||
state std::time_t lastModTime = 0;
|
||||
loop {
|
||||
try {
|
||||
std::time_t modtime = wait(IAsyncFileSystem::filesystem()->lastWriteTime(filename));
|
||||
if (firstRun) {
|
||||
lastModTime = modtime;
|
||||
firstRun = false;
|
||||
}
|
||||
if (lastModTime != modtime || statError) {
|
||||
lastModTime = modtime;
|
||||
statError = false;
|
||||
fileChanged->trigger();
|
||||
}
|
||||
} catch (Error& e) {
|
||||
if (e.code() == error_code_io_error) {
|
||||
// EACCES, ELOOP, ENOENT all come out as io_error(), but are more of a system
|
||||
// configuration issue than an FDB problem. If we managed to load valid
|
||||
// certificates, then there's no point in crashing, but we should complain
|
||||
// loudly. IAsyncFile will log the error, but not necessarily as a warning.
|
||||
TraceEvent(SevWarnAlways, "TLSCertificateRefreshStatError").detail("File", filename);
|
||||
statError = true;
|
||||
} else {
|
||||
throw;
|
||||
}
|
||||
}
|
||||
wait(delay(FLOW_KNOBS->TLS_CERT_REFRESH_DELAY_SECONDS));
|
||||
}
|
||||
}
|
||||
|
||||
ACTOR static Future<Void> reloadCertificatesOnChange(
|
||||
TLSConfig config,
|
||||
std::function<void()> onPolicyFailure,
|
||||
AsyncVar<Reference<ReferencedObject<boost::asio::ssl::context>>>* contextVar) {
|
||||
AsyncVar<Reference<ReferencedObject<boost::asio::ssl::context>>>* contextVar,
|
||||
Reference<TLSPolicy>* policy) {
|
||||
if (FLOW_KNOBS->TLS_CERT_REFRESH_DELAY_SECONDS <= 0) {
|
||||
return Void();
|
||||
}
|
||||
|
@ -1292,9 +1282,13 @@ ACTOR static Future<Void> reloadCertificatesOnChange(
|
|||
state int mismatches = 0;
|
||||
state AsyncTrigger fileChanged;
|
||||
state std::vector<Future<Void>> lifetimes;
|
||||
lifetimes.push_back(watchFileForChanges(config.getCertificatePathSync(), &fileChanged));
|
||||
lifetimes.push_back(watchFileForChanges(config.getKeyPathSync(), &fileChanged));
|
||||
lifetimes.push_back(watchFileForChanges(config.getCAPathSync(), &fileChanged));
|
||||
const int& intervalSeconds = FLOW_KNOBS->TLS_CERT_REFRESH_DELAY_SECONDS;
|
||||
lifetimes.push_back(watchFileForChanges(
|
||||
config.getCertificatePathSync(), &fileChanged, &intervalSeconds, "TLSCertificateRefreshStatError"));
|
||||
lifetimes.push_back(
|
||||
watchFileForChanges(config.getKeyPathSync(), &fileChanged, &intervalSeconds, "TLSKeyRefreshStatError"));
|
||||
lifetimes.push_back(
|
||||
watchFileForChanges(config.getCAPathSync(), &fileChanged, &intervalSeconds, "TLSCARefreshStatError"));
|
||||
loop {
|
||||
wait(fileChanged.onTrigger());
|
||||
TraceEvent("TLSCertificateRefreshBegin").log();
|
||||
|
@ -1302,7 +1296,8 @@ ACTOR static Future<Void> reloadCertificatesOnChange(
|
|||
try {
|
||||
LoadedTLSConfig loaded = wait(config.loadAsync());
|
||||
boost::asio::ssl::context context(boost::asio::ssl::context::tls);
|
||||
ConfigureSSLContext(loaded, &context, onPolicyFailure);
|
||||
ConfigureSSLContext(loaded, context);
|
||||
*policy = makeReference<TLSPolicy>(loaded, onPolicyFailure);
|
||||
TraceEvent(SevInfo, "TLSCertificateRefreshSucceeded").log();
|
||||
mismatches = 0;
|
||||
contextVar->set(ReferencedObject<boost::asio::ssl::context>::from(std::move(context)));
|
||||
|
@ -1334,12 +1329,15 @@ void Net2::initTLS(ETLSInitState targetState) {
|
|||
.detail("KeyPath", tlsConfig.getKeyPathSync())
|
||||
.detail("HasPassword", !loaded.getPassword().empty())
|
||||
.detail("VerifyPeers", boost::algorithm::join(loaded.getVerifyPeers(), "|"));
|
||||
ConfigureSSLContext(tlsConfig.loadSync(), &newContext, onPolicyFailure);
|
||||
auto loadedTlsConfig = tlsConfig.loadSync();
|
||||
ConfigureSSLContext(loadedTlsConfig, newContext);
|
||||
activeTlsPolicy = makeReference<TLSPolicy>(loadedTlsConfig, onPolicyFailure);
|
||||
sslContextVar.set(ReferencedObject<boost::asio::ssl::context>::from(std::move(newContext)));
|
||||
} catch (Error& e) {
|
||||
TraceEvent("Net2TLSInitError").error(e);
|
||||
}
|
||||
backgroundCertRefresh = reloadCertificatesOnChange(tlsConfig, onPolicyFailure, &sslContextVar);
|
||||
backgroundCertRefresh =
|
||||
reloadCertificatesOnChange(tlsConfig, onPolicyFailure, &sslContextVar, &activeTlsPolicy);
|
||||
}
|
||||
|
||||
// If a TLS connection is actually going to be used then start background threads if configured
|
||||
|
|
|
@ -81,7 +81,7 @@ void LoadedTLSConfig::print(FILE* fp) {
|
|||
int num_certs = 0;
|
||||
boost::asio::ssl::context context(boost::asio::ssl::context::tls);
|
||||
try {
|
||||
ConfigureSSLContext(*this, &context);
|
||||
ConfigureSSLContext(*this, context);
|
||||
} catch (Error& e) {
|
||||
fprintf(fp, "There was an error in loading the certificate chain.\n");
|
||||
throw;
|
||||
|
@ -109,51 +109,58 @@ void LoadedTLSConfig::print(FILE* fp) {
|
|||
X509_STORE_CTX_free(store_ctx);
|
||||
}
|
||||
|
||||
void ConfigureSSLContext(const LoadedTLSConfig& loaded,
|
||||
boost::asio::ssl::context* context,
|
||||
std::function<void()> onPolicyFailure) {
|
||||
void ConfigureSSLContext(const LoadedTLSConfig& loaded, boost::asio::ssl::context& context) {
|
||||
try {
|
||||
context->set_options(boost::asio::ssl::context::default_workarounds);
|
||||
context->set_verify_mode(boost::asio::ssl::context::verify_peer |
|
||||
boost::asio::ssl::verify_fail_if_no_peer_cert);
|
||||
context.set_options(boost::asio::ssl::context::default_workarounds);
|
||||
auto verifyFailIfNoPeerCert = boost::asio::ssl::verify_fail_if_no_peer_cert;
|
||||
// Servers get to accept connections without peer certs as "untrusted" clients
|
||||
if (loaded.getEndpointType() == TLSEndpointType::SERVER)
|
||||
verifyFailIfNoPeerCert = 0;
|
||||
context.set_verify_mode(boost::asio::ssl::context::verify_peer | verifyFailIfNoPeerCert);
|
||||
|
||||
if (loaded.isTLSEnabled()) {
|
||||
auto tlsPolicy = makeReference<TLSPolicy>(loaded.getEndpointType());
|
||||
tlsPolicy->set_verify_peers({ loaded.getVerifyPeers() });
|
||||
|
||||
context->set_verify_callback(
|
||||
[policy = tlsPolicy, onPolicyFailure](bool preverified, boost::asio::ssl::verify_context& ctx) {
|
||||
bool success = policy->verify_peer(preverified, ctx.native_handle());
|
||||
if (!success) {
|
||||
onPolicyFailure();
|
||||
}
|
||||
return success;
|
||||
});
|
||||
} else {
|
||||
// Insecurely always except if TLS is not enabled.
|
||||
context->set_verify_callback([](bool, boost::asio::ssl::verify_context&) { return true; });
|
||||
}
|
||||
|
||||
context->set_password_callback([password = loaded.getPassword()](
|
||||
size_t, boost::asio::ssl::context::password_purpose) { return password; });
|
||||
context.set_password_callback([password = loaded.getPassword()](
|
||||
size_t, boost::asio::ssl::context::password_purpose) { return password; });
|
||||
|
||||
const std::string& CABytes = loaded.getCABytes();
|
||||
if (CABytes.size()) {
|
||||
context->add_certificate_authority(boost::asio::buffer(CABytes.data(), CABytes.size()));
|
||||
context.add_certificate_authority(boost::asio::buffer(CABytes.data(), CABytes.size()));
|
||||
}
|
||||
|
||||
const std::string& keyBytes = loaded.getKeyBytes();
|
||||
if (keyBytes.size()) {
|
||||
context->use_private_key(boost::asio::buffer(keyBytes.data(), keyBytes.size()),
|
||||
boost::asio::ssl::context::pem);
|
||||
context.use_private_key(boost::asio::buffer(keyBytes.data(), keyBytes.size()),
|
||||
boost::asio::ssl::context::pem);
|
||||
}
|
||||
|
||||
const std::string& certBytes = loaded.getCertificateBytes();
|
||||
if (certBytes.size()) {
|
||||
context->use_certificate_chain(boost::asio::buffer(certBytes.data(), certBytes.size()));
|
||||
context.use_certificate_chain(boost::asio::buffer(certBytes.data(), certBytes.size()));
|
||||
}
|
||||
} catch (boost::system::system_error& e) {
|
||||
TraceEvent("TLSConfigureError")
|
||||
TraceEvent("TLSContextConfigureError")
|
||||
.detail("What", e.what())
|
||||
.detail("Value", e.code().value())
|
||||
.detail("WhichMeans", TLSPolicy::ErrorString(e.code()));
|
||||
throw tls_error();
|
||||
}
|
||||
}
|
||||
|
||||
void ConfigureSSLStream(Reference<TLSPolicy> policy,
|
||||
boost::asio::ssl::stream<boost::asio::ip::tcp::socket&>& stream,
|
||||
std::function<void(bool)> callback) {
|
||||
try {
|
||||
stream.set_verify_callback([policy, callback](bool preverified, boost::asio::ssl::verify_context& ctx) {
|
||||
bool success = policy->verify_peer(preverified, ctx.native_handle());
|
||||
if (!success) {
|
||||
if (policy->on_failure)
|
||||
policy->on_failure();
|
||||
}
|
||||
if (callback)
|
||||
callback(success);
|
||||
return success;
|
||||
});
|
||||
} catch (boost::system::system_error& e) {
|
||||
TraceEvent("TLSStreamConfigureError")
|
||||
.detail("What", e.what())
|
||||
.detail("Value", e.code().value())
|
||||
.detail("WhichMeans", TLSPolicy::ErrorString(e.code()));
|
||||
|
@ -261,6 +268,11 @@ LoadedTLSConfig TLSConfig::loadSync() const {
|
|||
return loaded;
|
||||
}
|
||||
|
||||
TLSPolicy::TLSPolicy(const LoadedTLSConfig& loaded, std::function<void()> on_failure)
|
||||
: rules(), on_failure(std::move(on_failure)), is_client(loaded.getEndpointType() == TLSEndpointType::CLIENT) {
|
||||
set_verify_peers(loaded.getVerifyPeers());
|
||||
}
|
||||
|
||||
// And now do the same thing, but async...
|
||||
|
||||
ACTOR static Future<Void> readEntireFile(std::string filename, std::string* destination) {
|
||||
|
|
|
@ -195,6 +195,9 @@ public:
|
|||
int NETWORK_TEST_REQUEST_SIZE;
|
||||
bool NETWORK_TEST_SCRIPT_MODE;
|
||||
|
||||
// Authorization
|
||||
int PUBLIC_KEY_FILE_MAX_SIZE;
|
||||
int PUBLIC_KEY_FILE_REFRESH_INTERVAL_SECONDS;
|
||||
int MAX_CACHED_EXPIRED_TOKENS;
|
||||
|
||||
// AsyncFileCached
|
||||
|
|
|
@ -39,7 +39,7 @@ void printPrivateKey(FILE* out, StringRef privateKeyPem);
|
|||
|
||||
PrivateKey makeEcP256();
|
||||
|
||||
PrivateKey makeRsa2048Bit();
|
||||
PrivateKey makeRsa4096Bit();
|
||||
|
||||
struct Asn1EntryRef {
|
||||
// field must match one of ASN.1 object short/long names: e.g. "C", "countryName", "CN", "commonName",
|
||||
|
|
|
@ -320,7 +320,7 @@ std::string readFileBytes(std::string const& filename, int maxSize);
|
|||
|
||||
// Read a file into memory supplied by the caller
|
||||
// If 'len' is greater than file size, then read the filesize bytes.
|
||||
void readFileBytes(std::string const& filename, uint8_t* buff, int64_t len);
|
||||
size_t readFileBytes(std::string const& filename, uint8_t* buff, int64_t len);
|
||||
|
||||
// Write data buffer into file
|
||||
void writeFileBytes(std::string const& filename, const char* data, size_t count);
|
||||
|
|
|
@ -33,6 +33,8 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
#include <boost/system/system_error.hpp>
|
||||
#include <boost/asio/ip/tcp.hpp>
|
||||
#include <boost/asio/ssl.hpp>
|
||||
#include "flow/FastRef.h"
|
||||
#include "flow/Knobs.h"
|
||||
#include "flow/flow.h"
|
||||
|
@ -201,21 +203,23 @@ private:
|
|||
TLSEndpointType endpointType = TLSEndpointType::UNSET;
|
||||
};
|
||||
|
||||
namespace boost {
|
||||
namespace asio {
|
||||
namespace ssl {
|
||||
struct context;
|
||||
}
|
||||
} // namespace asio
|
||||
} // namespace boost
|
||||
void ConfigureSSLContext(
|
||||
const LoadedTLSConfig& loaded,
|
||||
boost::asio::ssl::context* context,
|
||||
std::function<void()> onPolicyFailure = []() {});
|
||||
class TLSPolicy;
|
||||
|
||||
void ConfigureSSLContext(const LoadedTLSConfig& loaded, boost::asio::ssl::context& context);
|
||||
|
||||
// Set up SSL for stream object based on policy.
|
||||
// Optionally arm a callback that gets called with verify-outcome of each cert in peer certificate chain:
|
||||
// e.g. for peer with a valid, trusted length-3 certificate chain (root CA, intermediate CA, and server certs),
|
||||
// callback(true) will be called 3 times.
|
||||
void ConfigureSSLStream(Reference<TLSPolicy> policy,
|
||||
boost::asio::ssl::stream<boost::asio::ip::tcp::socket&>& stream,
|
||||
std::function<void(bool)> callback);
|
||||
|
||||
class TLSPolicy : ReferenceCounted<TLSPolicy> {
|
||||
void set_verify_peers(std::vector<std::string> verify_peers);
|
||||
|
||||
public:
|
||||
TLSPolicy(TLSEndpointType client) : is_client(client == TLSEndpointType::CLIENT) {}
|
||||
TLSPolicy(const LoadedTLSConfig& loaded, std::function<void()> on_failure);
|
||||
virtual ~TLSPolicy();
|
||||
|
||||
virtual void addref() { ReferenceCounted<TLSPolicy>::addref(); }
|
||||
|
@ -223,7 +227,6 @@ public:
|
|||
|
||||
static std::string ErrorString(boost::system::error_code e);
|
||||
|
||||
void set_verify_peers(std::vector<std::string> verify_peers);
|
||||
bool verify_peer(bool preverified, X509_STORE_CTX* store_ctx);
|
||||
|
||||
std::string toString() const;
|
||||
|
@ -242,6 +245,7 @@ public:
|
|||
};
|
||||
|
||||
std::vector<Rule> rules;
|
||||
std::function<void()> on_failure;
|
||||
bool is_client;
|
||||
};
|
||||
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
/*
|
||||
* WatchFile.actor.h
|
||||
*
|
||||
* This source file is part of the FoundationDB open source project
|
||||
*
|
||||
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
// When actually compiled (NO_INTELLISENSE), include the generated
|
||||
// version of this file. In intellisense use the source version.
|
||||
#if defined(NO_INTELLISENSE) && !defined(FLOW_WATCH_FILE_ACTOR_G_H)
|
||||
#define FLOW_WATCH_FILE_ACTOR_G_H
|
||||
#include "flow/WatchFile.actor.g.h"
|
||||
#elif !defined(FLOW_WATCH_FILE_ACTOR_H)
|
||||
#define FLOW_WATCH_FILE_ACTOR_H
|
||||
|
||||
#include <ctime>
|
||||
#include <string>
|
||||
#include "flow/IAsyncFile.h"
|
||||
#include "flow/genericactors.actor.h"
|
||||
#include "flow/actorcompiler.h"
|
||||
|
||||
ACTOR static Future<Void> watchFileForChanges(std::string filename,
|
||||
AsyncTrigger* fileChanged,
|
||||
const int* intervalSeconds,
|
||||
const char* errorType) {
|
||||
if (filename == "") {
|
||||
return Never();
|
||||
}
|
||||
state bool firstRun = true;
|
||||
state bool statError = false;
|
||||
state std::time_t lastModTime = 0;
|
||||
loop {
|
||||
try {
|
||||
std::time_t modtime = wait(IAsyncFileSystem::filesystem()->lastWriteTime(filename));
|
||||
if (firstRun) {
|
||||
lastModTime = modtime;
|
||||
firstRun = false;
|
||||
}
|
||||
if (lastModTime != modtime || statError) {
|
||||
lastModTime = modtime;
|
||||
statError = false;
|
||||
fileChanged->trigger();
|
||||
}
|
||||
} catch (Error& e) {
|
||||
if (e.code() == error_code_io_error) {
|
||||
// EACCES, ELOOP, ENOENT all come out as io_error(), but are more of a system
|
||||
// configuration issue than an FDB problem. If we managed to load valid
|
||||
// certificates, then there's no point in crashing, but we should complain
|
||||
// loudly. IAsyncFile will log the error, but not necessarily as a warning.
|
||||
TraceEvent(SevWarnAlways, errorType).detail("File", filename);
|
||||
statError = true;
|
||||
} else {
|
||||
throw;
|
||||
}
|
||||
}
|
||||
wait(delay(*intervalSeconds));
|
||||
}
|
||||
}
|
||||
|
||||
#include "flow/unactorcompiler.h"
|
||||
|
||||
#endif // FLOW_WATCH_FILE_ACTOR_H
|
|
@ -467,6 +467,11 @@ public:
|
|||
// this may not be an address we can connect to!
|
||||
virtual NetworkAddress getPeerAddress() const = 0;
|
||||
|
||||
// Returns whether the peer is trusted.
|
||||
// For TLS-enabled connections, this is true if the peer has presented a valid chain of certificates trusted by the
|
||||
// local endpoint. For non-TLS connections this is always true for any valid open connection.
|
||||
virtual bool hasTrustedPeer() const = 0;
|
||||
|
||||
virtual UID getDebugID() const = 0;
|
||||
|
||||
// At present, implemented by Sim2Conn where we want to disable bits flip for connections between parent process and
|
||||
|
|
|
@ -40,7 +40,7 @@ if(WITH_PYTHON)
|
|||
|
||||
configure_testing(TEST_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}"
|
||||
ERROR_ON_ADDITIONAL_FILES
|
||||
IGNORE_PATTERNS ".*/CMakeLists.txt")
|
||||
IGNORE_PATTERNS ".*/CMakeLists.txt" ".*/requirements.txt")
|
||||
|
||||
add_fdb_test(TEST_FILES AsyncFileCorrectness.txt UNIT IGNORE)
|
||||
add_fdb_test(TEST_FILES AsyncFileMix.txt UNIT IGNORE)
|
||||
|
@ -396,6 +396,39 @@ if(WITH_PYTHON)
|
|||
create_valgrind_correctness_package()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (NOT WIN32)
|
||||
# setup venv for testing token-based authorization
|
||||
set(authz_venv_dir ${CMAKE_CURRENT_BINARY_DIR}/authorization_test_venv)
|
||||
set(authz_venv_activate ". ${authz_venv_dir}/bin/activate")
|
||||
set(authz_venv_stamp_file ${authz_venv_dir}/venv.ready)
|
||||
set(authz_venv_cmd "")
|
||||
string(APPEND authz_venv_cmd "[[ ! -f ${authz_venv_stamp_file} ]] && ")
|
||||
string(APPEND authz_venv_cmd "${Python3_EXECUTABLE} -m venv ${authz_venv_dir} ")
|
||||
string(APPEND authz_venv_cmd "&& ${authz_venv_activate} ")
|
||||
string(APPEND authz_venv_cmd "&& pip install --upgrade pip ")
|
||||
string(APPEND authz_venv_cmd "&& pip install --upgrade -r ${CMAKE_SOURCE_DIR}/tests/authorization/requirements.txt ")
|
||||
string(APPEND authz_venv_cmd "&& (cd ${CMAKE_BINARY_DIR}/bindings/python && python3 setup.py install) ")
|
||||
string(APPEND authz_venv_cmd "&& touch ${authz_venv_stamp_file} ")
|
||||
string(APPEND authz_venv_cmd "|| echo 'venv already set up'")
|
||||
add_test(
|
||||
NAME authorization_venv_setup
|
||||
COMMAND bash -c ${authz_venv_cmd}
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
|
||||
set_tests_properties(authorization_venv_setup PROPERTIES FIXTURES_SETUP authz_virtual_env TIMEOUT 60)
|
||||
|
||||
set(authz_script_dir ${CMAKE_SOURCE_DIR}/tests/authorization)
|
||||
set(authz_test_cmd "")
|
||||
string(APPEND authz_test_cmd "${authz_venv_activate} && ")
|
||||
string(APPEND authz_test_cmd "LD_LIBRARY_PATH=${CMAKE_BINARY_DIR}/lib pytest ${authz_script_dir}/authz_test.py -rA --build-dir ${CMAKE_BINARY_DIR} -vvv")
|
||||
add_test(
|
||||
NAME token_based_tenant_authorization
|
||||
WORKING_DIRECTORY ${authz_script_dir}
|
||||
COMMAND bash -c ${authz_test_cmd})
|
||||
set_tests_properties(token_based_tenant_authorization PROPERTIES ENVIRONMENT PYTHONPATH=${CMAKE_SOURCE_DIR}/tests/TestRunner) # (local|tmp)_cluster.py
|
||||
set_tests_properties(token_based_tenant_authorization PROPERTIES FIXTURES_REQUIRED authz_virtual_env)
|
||||
set_tests_properties(token_based_tenant_authorization PROPERTIES TIMEOUT 120)
|
||||
endif()
|
||||
else()
|
||||
message(WARNING "Python not found, won't configure ctest")
|
||||
endif()
|
||||
|
|
|
@ -86,6 +86,8 @@ datadir = {datadir}/$ID
|
|||
logdir = {logdir}
|
||||
{bg_knob_line}
|
||||
{tls_config}
|
||||
{authz_public_key_config}
|
||||
{custom_config}
|
||||
{use_future_protocol_version}
|
||||
# logsize = 10MiB
|
||||
# maxlogssize = 100MiB
|
||||
|
@ -117,6 +119,8 @@ logdir = {logdir}
|
|||
redundancy: str = "single",
|
||||
tls_config: TLSConfig = None,
|
||||
mkcert_binary: str = "",
|
||||
custom_config: dict = {},
|
||||
public_key_json_str: str = "",
|
||||
):
|
||||
self.basedir = Path(basedir)
|
||||
self.etc = self.basedir.joinpath("etc")
|
||||
|
@ -137,6 +141,7 @@ logdir = {logdir}
|
|||
self.redundancy = redundancy
|
||||
self.ip_address = "127.0.0.1" if ip_address is None else ip_address
|
||||
self.first_port = port
|
||||
self.custom_config = custom_config
|
||||
self.blob_granules_enabled = blob_granules_enabled
|
||||
if blob_granules_enabled:
|
||||
# add extra process for blob_worker
|
||||
|
@ -158,6 +163,7 @@ logdir = {logdir}
|
|||
self.coordinators = set()
|
||||
self.active_servers = set(self.server_ports.keys())
|
||||
self.tls_config = tls_config
|
||||
self.public_key_json_file = None
|
||||
self.mkcert_binary = Path(mkcert_binary)
|
||||
self.server_cert_file = self.cert.joinpath("server_cert.pem")
|
||||
self.client_cert_file = self.cert.joinpath("client_cert.pem")
|
||||
|
@ -166,6 +172,11 @@ logdir = {logdir}
|
|||
self.server_ca_file = self.cert.joinpath("server_ca.pem")
|
||||
self.client_ca_file = self.cert.joinpath("client_ca.pem")
|
||||
|
||||
if public_key_json_str:
|
||||
self.public_key_json_file = self.etc.joinpath("public_keys.json")
|
||||
with open(self.public_key_json_file, "w") as pubkeyfile:
|
||||
pubkeyfile.write(public_key_json_str)
|
||||
|
||||
if create_config:
|
||||
self.create_cluster_file()
|
||||
self.save_config()
|
||||
|
@ -173,6 +184,8 @@ logdir = {logdir}
|
|||
if self.tls_config is not None:
|
||||
self.create_tls_cert()
|
||||
|
||||
self.cluster_file = self.etc.joinpath("fdb.cluster")
|
||||
|
||||
def __next_port(self):
|
||||
if self.first_port is None:
|
||||
return get_free_port()
|
||||
|
@ -198,10 +211,10 @@ logdir = {logdir}
|
|||
ip_address=self.ip_address,
|
||||
bg_knob_line=bg_knob_line,
|
||||
tls_config=self.tls_conf_string(),
|
||||
authz_public_key_config=self.authz_public_key_conf_string(),
|
||||
optional_tls=":tls" if self.tls_config is not None else "",
|
||||
use_future_protocol_version="use-future-protocol-version = true"
|
||||
if self.use_future_protocol_version
|
||||
else "",
|
||||
custom_config='\n'.join(["{} = {}".format(key, value) for key, value in self.custom_config.items()]),
|
||||
use_future_protocol_version="use-future-protocol-version = true" if self.use_future_protocol_version else "",
|
||||
)
|
||||
)
|
||||
# By default, the cluster only has one process
|
||||
|
@ -369,6 +382,12 @@ logdir = {logdir}
|
|||
}
|
||||
return "\n".join("{} = {}".format(k, v) for k, v in conf_map.items())
|
||||
|
||||
def authz_public_key_conf_string(self):
|
||||
if self.public_key_json_file is not None:
|
||||
return "authorization-public-key-file = {}".format(self.public_key_json_file)
|
||||
else:
|
||||
return ""
|
||||
|
||||
# Get cluster status using fdbcli
|
||||
def get_status(self):
|
||||
status_output = self.fdbcli_exec_and_get("status json")
|
||||
|
|
|
@ -18,6 +18,9 @@ class TempCluster(LocalCluster):
|
|||
port: str = None,
|
||||
blob_granules_enabled: bool = False,
|
||||
tls_config: TLSConfig = None,
|
||||
public_key_json_str: str = None,
|
||||
remove_at_exit: bool = True,
|
||||
custom_config: dict = {},
|
||||
enable_tenants: bool = True,
|
||||
):
|
||||
self.build_dir = Path(build_dir).resolve()
|
||||
|
@ -26,6 +29,7 @@ class TempCluster(LocalCluster):
|
|||
tmp_dir = self.build_dir.joinpath("tmp", random_secret_string(16))
|
||||
tmp_dir.mkdir(parents=True)
|
||||
self.tmp_dir = tmp_dir
|
||||
self.remove_at_exit = remove_at_exit
|
||||
self.enable_tenants = enable_tenants
|
||||
super().__init__(
|
||||
tmp_dir,
|
||||
|
@ -37,6 +41,8 @@ class TempCluster(LocalCluster):
|
|||
blob_granules_enabled=blob_granules_enabled,
|
||||
tls_config=tls_config,
|
||||
mkcert_binary=self.build_dir.joinpath("bin", "mkcert"),
|
||||
public_key_json_str=public_key_json_str,
|
||||
custom_config=custom_config,
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
|
@ -49,11 +55,13 @@ class TempCluster(LocalCluster):
|
|||
|
||||
def __exit__(self, xc_type, exc_value, traceback):
|
||||
super().__exit__(xc_type, exc_value, traceback)
|
||||
shutil.rmtree(self.tmp_dir)
|
||||
if self.remove_at_exit:
|
||||
shutil.rmtree(self.tmp_dir)
|
||||
|
||||
def close(self):
|
||||
super().__exit__(None, None, None)
|
||||
shutil.rmtree(self.tmp_dir)
|
||||
if self.remove_at_exit:
|
||||
shutil.rmtree(self.tmp_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -147,11 +155,11 @@ if __name__ == "__main__":
|
|||
print("log-dir: {}".format(cluster.log))
|
||||
print("etc-dir: {}".format(cluster.etc))
|
||||
print("data-dir: {}".format(cluster.data))
|
||||
print("cluster-file: {}".format(cluster.etc.joinpath("fdb.cluster")))
|
||||
print("cluster-file: {}".format(cluster.cluster_file))
|
||||
cmd_args = []
|
||||
for cmd in args.cmd:
|
||||
if cmd == "@CLUSTER_FILE@":
|
||||
cmd_args.append(str(cluster.etc.joinpath("fdb.cluster")))
|
||||
cmd_args.append(str(cluster.cluster_file))
|
||||
elif cmd == "@DATA_DIR@":
|
||||
cmd_args.append(str(cluster.data))
|
||||
elif cmd == "@LOG_DIR@":
|
||||
|
@ -178,7 +186,7 @@ if __name__ == "__main__":
|
|||
cmd_args.append(cmd)
|
||||
env = dict(**os.environ)
|
||||
env["FDB_CLUSTER_FILE"] = env.get(
|
||||
"FDB_CLUSTER_FILE", cluster.etc.joinpath("fdb.cluster")
|
||||
"FDB_CLUSTER_FILE", cluster.cluster_file
|
||||
)
|
||||
errcode = subprocess.run(
|
||||
cmd_args, stdout=sys.stdout, stderr=sys.stderr, env=env
|
||||
|
|
|
@ -0,0 +1,135 @@
|
|||
#!/usr/bin/python
|
||||
#
|
||||
# admin_server.py
|
||||
#
|
||||
# This source file is part of the FoundationDB open source project
|
||||
#
|
||||
# Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import fdb
|
||||
from multiprocessing import Pipe, Process
|
||||
from typing import Union, List
|
||||
from util import to_str, to_bytes, cleanup_tenant
|
||||
|
||||
class _admin_request(object):
|
||||
def __init__(self, op: str, args: List[Union[str, bytes]]=[]):
|
||||
self.op = op
|
||||
self.args = args
|
||||
|
||||
def __str__(self):
|
||||
return f"admin_request({self.op}, {self.args})"
|
||||
|
||||
def __repr__(self):
|
||||
return f"admin_request({self.op}, {self.args})"
|
||||
|
||||
def main_loop(main_pipe, pipe):
|
||||
main_pipe.close()
|
||||
db = None
|
||||
while True:
|
||||
try:
|
||||
req = pipe.recv()
|
||||
except EOFError:
|
||||
return
|
||||
if not isinstance(req, _admin_request):
|
||||
pipe.send(TypeError("unexpected type {}".format(type(req))))
|
||||
continue
|
||||
op = req.op
|
||||
args = req.args
|
||||
resp = True
|
||||
try:
|
||||
if op == "connect":
|
||||
db = fdb.open(req.args[0])
|
||||
elif op == "configure_tls":
|
||||
keyfile, certfile, cafile = req.args[:3]
|
||||
fdb.options.set_tls_key_path(keyfile)
|
||||
fdb.options.set_tls_cert_path(certfile)
|
||||
fdb.options.set_tls_ca_path(cafile)
|
||||
elif op == "create_tenant":
|
||||
if db is None:
|
||||
resp = Exception("db not open")
|
||||
else:
|
||||
for tenant in req.args:
|
||||
tenant_str = to_str(tenant)
|
||||
tenant_bytes = to_bytes(tenant)
|
||||
fdb.tenant_management.create_tenant(db, tenant_bytes)
|
||||
elif op == "delete_tenant":
|
||||
if db is None:
|
||||
resp = Exception("db not open")
|
||||
else:
|
||||
for tenant in req.args:
|
||||
tenant_str = to_str(tenant)
|
||||
tenant_bytes = to_bytes(tenant)
|
||||
cleanup_tenant(db, tenant_bytes)
|
||||
elif op == "cleanup_database":
|
||||
if db is None:
|
||||
resp = Exception("db not open")
|
||||
else:
|
||||
tr = db.create_transaction()
|
||||
del tr[b'':b'\xff']
|
||||
tr.commit().wait()
|
||||
tenants = list(map(lambda x: x.key, list(fdb.tenant_management.list_tenants(db, b'', b'\xff', 0).to_list())))
|
||||
for tenant in tenants:
|
||||
fdb.tenant_management.delete_tenant(db, tenant)
|
||||
elif op == "terminate":
|
||||
pipe.send(True)
|
||||
return
|
||||
else:
|
||||
resp = ValueError("unknown operation: {}".format(req))
|
||||
except Exception as e:
|
||||
resp = e
|
||||
pipe.send(resp)
|
||||
|
||||
_admin_server = None
|
||||
|
||||
def get():
|
||||
return _admin_server
|
||||
|
||||
# server needs to be a singleton running in subprocess, because FDB network layer (including active TLS config) is a global var
|
||||
class Server(object):
|
||||
def __init__(self):
|
||||
global _admin_server
|
||||
assert _admin_server is None, "admin server may be setup once per process"
|
||||
_admin_server = self
|
||||
self._main_pipe, self._admin_pipe = Pipe(duplex=True)
|
||||
self._admin_proc = Process(target=main_loop, args=(self._main_pipe, self._admin_pipe))
|
||||
|
||||
def start(self):
|
||||
self._admin_proc.start()
|
||||
|
||||
def join(self):
|
||||
self._main_pipe.close()
|
||||
self._admin_pipe.close()
|
||||
self._admin_proc.join()
|
||||
|
||||
def __enter__(self):
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||
self.join()
|
||||
|
||||
def request(self, op, args=[]):
|
||||
req = _admin_request(op, args)
|
||||
try:
|
||||
self._main_pipe.send(req)
|
||||
resp = self._main_pipe.recv()
|
||||
if resp != True:
|
||||
print("{} failed: {}".format(req, resp))
|
||||
raise resp
|
||||
else:
|
||||
print("{} succeeded".format(req))
|
||||
except Exception as e:
|
||||
print("{} failed by exception: {}".format(req, e))
|
||||
raise
|
|
@ -0,0 +1,297 @@
|
|||
#!/usr/bin/python
|
||||
#
|
||||
# authz_test.py
|
||||
#
|
||||
# This source file is part of the FoundationDB open source project
|
||||
#
|
||||
# Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import admin_server
|
||||
import argparse
|
||||
import authlib
|
||||
import fdb
|
||||
import os
|
||||
import pytest
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
from multiprocessing import Process, Pipe
|
||||
from typing import Union
|
||||
from util import alg_from_kty, public_keyset_from_keys, random_alphanum_str, random_alphanum_bytes, to_str, to_bytes, KeyFileReverter, token_claim_1h, wait_until_tenant_tr_succeeds, wait_until_tenant_tr_fails
|
||||
|
||||
special_key_ranges = [
|
||||
("transaction description", b"/description", b"/description\x00"),
|
||||
("global knobs", b"/globalKnobs", b"/globalKnobs\x00"),
|
||||
("knobs", b"/knobs0", b"/knobs0\x00"),
|
||||
("conflicting keys", b"/transaction/conflicting_keys/", b"/transaction/conflicting_keys/\xff\xff"),
|
||||
("read conflict range", b"/transaction/read_conflict_range/", b"/transaction/read_conflict_range/\xff\xff"),
|
||||
("conflicting keys", b"/transaction/write_conflict_range/", b"/transaction/write_conflict_range/\xff\xff"),
|
||||
("data distribution stats", b"/metrics/data_distribution_stats/", b"/metrics/data_distribution_stats/\xff\xff"),
|
||||
("kill storage", b"/globals/killStorage", b"/globals/killStorage\x00"),
|
||||
]
|
||||
|
||||
def test_simple_tenant_access(private_key, token_gen, default_tenant, tenant_tr_gen):
|
||||
token = token_gen(private_key, token_claim_1h(default_tenant))
|
||||
tr = tenant_tr_gen(default_tenant)
|
||||
tr.options.set_authorization_token(token)
|
||||
tr[b"abc"] = b"def"
|
||||
tr.commit().wait()
|
||||
tr = tenant_tr_gen(default_tenant)
|
||||
tr.options.set_authorization_token(token)
|
||||
assert tr[b"abc"] == b"def", "tenant write transaction not visible"
|
||||
|
||||
def test_cross_tenant_access_disallowed(private_key, default_tenant, token_gen, tenant_gen, tenant_tr_gen):
|
||||
# use default tenant token with second tenant transaction and see it fail
|
||||
second_tenant = random_alphanum_bytes(12)
|
||||
tenant_gen(second_tenant)
|
||||
token_second = token_gen(private_key, token_claim_1h(second_tenant))
|
||||
tr_second = tenant_tr_gen(second_tenant)
|
||||
tr_second.options.set_authorization_token(token_second)
|
||||
tr_second[b"abc"] = b"def"
|
||||
tr_second.commit().wait()
|
||||
token_default = token_gen(private_key, token_claim_1h(default_tenant))
|
||||
tr_second = tenant_tr_gen(second_tenant)
|
||||
tr_second.options.set_authorization_token(token_default)
|
||||
# test that read transaction fails
|
||||
try:
|
||||
value = tr_second[b"abc"].value
|
||||
assert False, f"expected permission denied, but read transaction went through, value: {value}"
|
||||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied, got {e} instead"
|
||||
# test that write transaction fails
|
||||
tr_second = tenant_tr_gen(second_tenant)
|
||||
tr_second.options.set_authorization_token(token_default)
|
||||
try:
|
||||
tr_second[b"def"] = b"ghi"
|
||||
tr_second.commit().wait()
|
||||
assert False, "expected permission denied, but write transaction went through"
|
||||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied, got {e} instead"
|
||||
|
||||
def test_system_and_special_key_range_disallowed(db, tenant_tr_gen, token_gen):
|
||||
second_tenant = random_alphanum_bytes(12)
|
||||
try:
|
||||
fdb.tenant_management.create_tenant(db, second_tenant)
|
||||
assert False, "disallowed create_tenant has succeeded"
|
||||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied, got {e} instead"
|
||||
|
||||
try:
|
||||
tr = db.create_transaction()
|
||||
tr.options.set_access_system_keys()
|
||||
kvs = tr.get_range(b"\xff", b"\xff\xff", limit=1).to_list()
|
||||
assert False, f"disallowed system keyspace read has succeeded. found item: {kvs}"
|
||||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied, got {e} instead"
|
||||
|
||||
for range_name, special_range_begin, special_range_end in special_key_ranges:
|
||||
tr = db.create_transaction()
|
||||
tr.options.set_access_system_keys()
|
||||
tr.options.set_special_key_space_relaxed()
|
||||
try:
|
||||
kvs = tr.get_range(special_range_begin, special_range_end, limit=1).to_list()
|
||||
assert False, f"disallowed special keyspace read for range {range_name} has succeeded. found item {kvs}"
|
||||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied from attempted read to range {range_name}, got {e} instead"
|
||||
|
||||
try:
|
||||
tr = db.create_transaction()
|
||||
tr.options.set_access_system_keys()
|
||||
del tr[b"\xff":b"\xff\xff"]
|
||||
tr.commit().wait()
|
||||
assert False, f"disallowed system keyspace write has succeeded"
|
||||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied, got {e} instead"
|
||||
|
||||
for range_name, special_range_begin, special_range_end in special_key_ranges:
|
||||
tr = db.create_transaction()
|
||||
tr.options.set_access_system_keys()
|
||||
tr.options.set_special_key_space_relaxed()
|
||||
try:
|
||||
del tr[special_range_begin:special_range_end]
|
||||
tr.commit().wait()
|
||||
assert False, f"write to disallowed special keyspace range {range_name} has succeeded"
|
||||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied from attempted write to range {range_name}, got {e} instead"
|
||||
|
||||
try:
|
||||
tr = db.create_transaction()
|
||||
tr.options.set_access_system_keys()
|
||||
kvs = tr.get_range(b"", b"\xff", limit=1).to_list()
|
||||
assert False, f"disallowed normal keyspace read has succeeded. found item {kvs}"
|
||||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied, got {e} instead"
|
||||
|
||||
def test_public_key_set_rollover(
|
||||
kty, private_key_gen, private_key, public_key_refresh_interval,
|
||||
cluster, default_tenant, token_gen, tenant_gen, tenant_tr_gen):
|
||||
new_kid = random_alphanum_str(12)
|
||||
new_kty = "EC" if kty == "RSA" else "RSA"
|
||||
new_key = private_key_gen(kty=new_kty, kid=new_kid)
|
||||
token_default = token_gen(private_key, token_claim_1h(default_tenant))
|
||||
|
||||
second_tenant = random_alphanum_bytes(12)
|
||||
tenant_gen(second_tenant)
|
||||
token_second = token_gen(new_key, token_claim_1h(second_tenant))
|
||||
|
||||
interim_set = public_keyset_from_keys([new_key, private_key])
|
||||
max_repeat = 10
|
||||
|
||||
print(f"interim keyset: {interim_set}")
|
||||
old_key_json = None
|
||||
with open(cluster.public_key_json_file, "r") as keyfile:
|
||||
old_key_json = keyfile.read()
|
||||
|
||||
delay = public_key_refresh_interval
|
||||
|
||||
with KeyFileReverter(cluster.public_key_json_file, old_key_json, delay):
|
||||
with open(cluster.public_key_json_file, "w") as keyfile:
|
||||
keyfile.write(interim_set)
|
||||
wait_until_tenant_tr_succeeds(second_tenant, new_key, tenant_tr_gen, token_gen, max_repeat, delay)
|
||||
print("interim key set activated")
|
||||
final_set = public_keyset_from_keys([new_key])
|
||||
print(f"final keyset: {final_set}")
|
||||
with open(cluster.public_key_json_file, "w") as keyfile:
|
||||
keyfile.write(final_set)
|
||||
wait_until_tenant_tr_fails(default_tenant, private_key, tenant_tr_gen, token_gen, max_repeat, delay)
|
||||
|
||||
def test_public_key_set_broken_file_tolerance(
|
||||
private_key, public_key_refresh_interval,
|
||||
cluster, public_key_jwks_str, default_tenant, token_gen, tenant_tr_gen):
|
||||
delay = public_key_refresh_interval
|
||||
# retry limit in waiting for keyset file update to propagate to FDB server's internal keyset
|
||||
max_repeat = 10
|
||||
|
||||
with KeyFileReverter(cluster.public_key_json_file, public_key_jwks_str, delay):
|
||||
# key file update should take effect even after witnessing broken key file
|
||||
with open(cluster.public_key_json_file, "w") as keyfile:
|
||||
keyfile.write(public_key_jwks_str.strip()[:10]) # make the file partial, injecting parse error
|
||||
time.sleep(delay * 2)
|
||||
# should still work; internal key set only clears with a valid, empty key set file
|
||||
tr_default = tenant_tr_gen(default_tenant)
|
||||
tr_default.options.set_authorization_token(token_gen(private_key, token_claim_1h(default_tenant)))
|
||||
tr_default[b"abc"] = b"def"
|
||||
tr_default.commit().wait()
|
||||
with open(cluster.public_key_json_file, "w") as keyfile:
|
||||
keyfile.write('{"keys":[]}')
|
||||
# eventually internal key set will become empty and won't accept any new tokens
|
||||
wait_until_tenant_tr_fails(default_tenant, private_key, tenant_tr_gen, token_gen, max_repeat, delay)
|
||||
|
||||
def test_public_key_set_deletion_tolerance(
|
||||
private_key, public_key_refresh_interval,
|
||||
cluster, public_key_jwks_str, default_tenant, token_gen, tenant_tr_gen):
|
||||
delay = public_key_refresh_interval
|
||||
# retry limit in waiting for keyset file update to propagate to FDB server's internal keyset
|
||||
max_repeat = 10
|
||||
|
||||
with KeyFileReverter(cluster.public_key_json_file, public_key_jwks_str, delay):
|
||||
# key file update should take effect even after witnessing deletion of key file
|
||||
with open(cluster.public_key_json_file, "w") as keyfile:
|
||||
keyfile.write('{"keys":[]}')
|
||||
time.sleep(delay)
|
||||
wait_until_tenant_tr_fails(default_tenant, private_key, tenant_tr_gen, token_gen, max_repeat, delay)
|
||||
os.remove(cluster.public_key_json_file)
|
||||
time.sleep(delay * 2)
|
||||
with open(cluster.public_key_json_file, "w") as keyfile:
|
||||
keyfile.write(public_key_jwks_str)
|
||||
# eventually updated key set should take effect and transaction should be accepted
|
||||
wait_until_tenant_tr_succeeds(default_tenant, private_key, tenant_tr_gen, token_gen, max_repeat, delay)
|
||||
|
||||
def test_public_key_set_empty_file_tolerance(
|
||||
private_key, public_key_refresh_interval,
|
||||
cluster, public_key_jwks_str, default_tenant, token_gen, tenant_tr_gen):
|
||||
delay = public_key_refresh_interval
|
||||
# retry limit in waiting for keyset file update to propagate to FDB server's internal keyset
|
||||
max_repeat = 10
|
||||
|
||||
with KeyFileReverter(cluster.public_key_json_file, public_key_jwks_str, delay):
|
||||
# key file update should take effect even after witnessing an empty file
|
||||
with open(cluster.public_key_json_file, "w") as keyfile:
|
||||
keyfile.write('{"keys":[]}')
|
||||
# eventually internal key set will become empty and won't accept any new tokens
|
||||
wait_until_tenant_tr_fails(default_tenant, private_key, tenant_tr_gen, token_gen, max_repeat, delay)
|
||||
# empty the key file
|
||||
with open(cluster.public_key_json_file, "w") as keyfile:
|
||||
pass
|
||||
time.sleep(delay * 2)
|
||||
with open(cluster.public_key_json_file, "w") as keyfile:
|
||||
keyfile.write(public_key_jwks_str)
|
||||
# eventually key file should update and transactions should go through
|
||||
wait_until_tenant_tr_succeeds(default_tenant, private_key, tenant_tr_gen, token_gen, max_repeat, delay)
|
||||
|
||||
def test_bad_token(private_key, token_gen, default_tenant, tenant_tr_gen):
|
||||
def del_attr(d, attr):
|
||||
del d[attr]
|
||||
return d
|
||||
|
||||
def set_attr(d, attr, value):
|
||||
d[attr] = value
|
||||
return d
|
||||
|
||||
claim_mutations = [
|
||||
("no nbf", lambda claim: del_attr(claim, "nbf")),
|
||||
("no exp", lambda claim: del_attr(claim, "exp")),
|
||||
("no iat", lambda claim: del_attr(claim, "iat")),
|
||||
("too early", lambda claim: set_attr(claim, "nbf", time.time() + 30)),
|
||||
("too late", lambda claim: set_attr(claim, "exp", time.time() - 10)),
|
||||
("no tenants", lambda claim: del_attr(claim, "tenants")),
|
||||
("empty tenants", lambda claim: set_attr(claim, "tenants", [])),
|
||||
]
|
||||
for case_name, mutation in claim_mutations:
|
||||
tr = tenant_tr_gen(default_tenant)
|
||||
tr.options.set_authorization_token(token_gen(private_key, mutation(token_claim_1h(default_tenant))))
|
||||
try:
|
||||
value = tr[b"abc"].value
|
||||
assert False, f"expected permission_denied for case {case_name}, but read transaction went through"
|
||||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied for case {case_name}, got {e} instead"
|
||||
tr = tenant_tr_gen(default_tenant)
|
||||
tr.options.set_authorization_token(token_gen(private_key, mutation(token_claim_1h(default_tenant))))
|
||||
tr[b"abc"] = b"def"
|
||||
try:
|
||||
tr.commit().wait()
|
||||
assert False, f"expected permission_denied for case {case_name}, but write transaction went through"
|
||||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied for case {case_name}, got {e} instead"
|
||||
|
||||
# unknown key case: override "kid" field in header
|
||||
# first, update only the kid field of key with export-update-import
|
||||
key_dict = private_key.as_dict(is_private=True)
|
||||
key_dict["kid"] = random_alphanum_str(10)
|
||||
renamed_key = authlib.jose.JsonWebKey.import_key(key_dict)
|
||||
unknown_key_token = token_gen(
|
||||
renamed_key,
|
||||
token_claim_1h(default_tenant),
|
||||
headers={
|
||||
"typ": "JWT",
|
||||
"kty": renamed_key.kty,
|
||||
"alg": alg_from_kty(renamed_key.kty),
|
||||
"kid": renamed_key.kid,
|
||||
})
|
||||
tr = tenant_tr_gen(default_tenant)
|
||||
tr.options.set_authorization_token(unknown_key_token)
|
||||
try:
|
||||
value = tr[b"abc"].value
|
||||
assert False, f"expected permission_denied for 'unknown key' case, but read transaction went through"
|
||||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied for 'unknown key' case, got {e} instead"
|
||||
tr = tenant_tr_gen(default_tenant)
|
||||
tr.options.set_authorization_token(unknown_key_token)
|
||||
tr[b"abc"] = b"def"
|
||||
try:
|
||||
tr.commit().wait()
|
||||
assert False, f"expected permission_denied for 'unknown key' case, but write transaction went through"
|
||||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied for 'unknown key' case, got {e} instead"
|
|
@ -0,0 +1,173 @@
|
|||
#!/usr/bin/python
|
||||
#
|
||||
# conftest.py
|
||||
#
|
||||
# This source file is part of the FoundationDB open source project
|
||||
#
|
||||
# Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import fdb
|
||||
import pytest
|
||||
import subprocess
|
||||
import admin_server
|
||||
from authlib.jose import JsonWebKey, KeySet, jwt
|
||||
from local_cluster import TLSConfig
|
||||
from tmp_cluster import TempCluster
|
||||
from typing import Union
|
||||
from util import alg_from_kty, public_keyset_from_keys, random_alphanum_str, random_alphanum_bytes, to_str, to_bytes
|
||||
|
||||
fdb.api_version(720)
|
||||
|
||||
cluster_scope = "module"
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
"--build-dir", action="store", dest="build_dir", help="FDB build directory", required=True)
|
||||
parser.addoption(
|
||||
"--kty", action="store", choices=["EC", "RSA"], default="EC", dest="kty", help="Token signature algorithm")
|
||||
parser.addoption(
|
||||
"--trusted-client",
|
||||
action="store_true",
|
||||
default=False,
|
||||
dest="trusted_client",
|
||||
help="Whether client shall be configured trusted, i.e. mTLS-ready")
|
||||
parser.addoption(
|
||||
"--public-key-refresh-interval",
|
||||
action="store",
|
||||
default=1,
|
||||
dest="public_key_refresh_interval",
|
||||
help="How frequently server refreshes authorization public key file")
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def build_dir(request):
|
||||
return request.config.option.build_dir
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def kty(request):
|
||||
return request.config.option.kty
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def trusted_client(request):
|
||||
return request.config.option.trusted_client
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def public_key_refresh_interval(request):
|
||||
return request.config.option.public_key_refresh_interval
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def alg(kty):
|
||||
if kty == "EC":
|
||||
return "ES256"
|
||||
else:
|
||||
return "RS256"
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def kid():
|
||||
return random_alphanum_str(12)
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def private_key_gen():
|
||||
def fn(kty: str, kid: str):
|
||||
if kty == "EC":
|
||||
return JsonWebKey.generate_key(kty=kty, crv_or_size="P-256", is_private=True, options={"kid": kid})
|
||||
else:
|
||||
return JsonWebKey.generate_key(kty=kty, crv_or_size=4096, is_private=True, options={"kid": kid})
|
||||
return fn
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def private_key(kty, kid, private_key_gen):
|
||||
return private_key_gen(kty, kid)
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def public_key_jwks_str(private_key):
|
||||
return public_keyset_from_keys([private_key])
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def token_gen():
|
||||
def fn(private_key, claims, headers={}):
|
||||
if not headers:
|
||||
headers = {
|
||||
"typ": "JWT",
|
||||
"kty": private_key.kty,
|
||||
"alg": alg_from_kty(private_key.kty),
|
||||
"kid": private_key.kid,
|
||||
}
|
||||
return jwt.encode(headers, claims, private_key)
|
||||
return fn
|
||||
|
||||
@pytest.fixture(scope=cluster_scope)
|
||||
def admin_ipc():
|
||||
server = admin_server.Server()
|
||||
server.start()
|
||||
yield server
|
||||
server.join()
|
||||
|
||||
@pytest.fixture(autouse=True, scope=cluster_scope)
|
||||
def cluster(admin_ipc, build_dir, public_key_jwks_str, public_key_refresh_interval, trusted_client):
|
||||
with TempCluster(
|
||||
build_dir=build_dir,
|
||||
tls_config=TLSConfig(server_chain_len=3, client_chain_len=2),
|
||||
public_key_json_str=public_key_jwks_str,
|
||||
remove_at_exit=True,
|
||||
custom_config={
|
||||
"knob-public-key-file-refresh-interval-seconds": public_key_refresh_interval,
|
||||
}) as cluster:
|
||||
keyfile = str(cluster.client_key_file)
|
||||
certfile = str(cluster.client_cert_file)
|
||||
cafile = str(cluster.server_ca_file)
|
||||
fdb.options.set_tls_key_path(keyfile if trusted_client else "")
|
||||
fdb.options.set_tls_cert_path(certfile if trusted_client else "")
|
||||
fdb.options.set_tls_ca_path(cafile)
|
||||
fdb.options.set_trace_enable()
|
||||
admin_ipc.request("configure_tls", [keyfile, certfile, cafile])
|
||||
admin_ipc.request("connect", [str(cluster.cluster_file)])
|
||||
yield cluster
|
||||
|
||||
@pytest.fixture
|
||||
def db(cluster, admin_ipc):
|
||||
db = fdb.open(str(cluster.cluster_file))
|
||||
db.options.set_transaction_timeout(2000) # 2 seconds
|
||||
db.options.set_transaction_retry_limit(3)
|
||||
yield db
|
||||
admin_ipc.request("cleanup_database")
|
||||
db = None
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_gen(db, admin_ipc):
|
||||
def fn(tenant):
|
||||
tenant = to_bytes(tenant)
|
||||
admin_ipc.request("create_tenant", [tenant])
|
||||
return fn
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_del(db, admin_ipc):
|
||||
def fn(tenant):
|
||||
tenant = to_str(tenant)
|
||||
admin_ipc.request("delete_tenant", [tenant])
|
||||
return fn
|
||||
|
||||
@pytest.fixture
|
||||
def default_tenant(tenant_gen, tenant_del):
|
||||
tenant = random_alphanum_bytes(8)
|
||||
tenant_gen(tenant)
|
||||
yield tenant
|
||||
tenant_del(tenant)
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_tr_gen(db):
|
||||
def fn(tenant):
|
||||
tenant = db.open_tenant(to_bytes(tenant))
|
||||
return tenant.create_transaction()
|
||||
return fn
|
|
@ -0,0 +1,12 @@
|
|||
attrs==22.1.0
|
||||
Authlib==1.0.1
|
||||
cffi==1.15.1
|
||||
cryptography==37.0.4
|
||||
iniconfig==1.1.1
|
||||
packaging==21.3
|
||||
pluggy==1.0.0
|
||||
py==1.11.0
|
||||
pycparser==2.21
|
||||
pyparsing==3.0.9
|
||||
pytest==7.1.2
|
||||
tomli==2.0.1
|
|
@ -0,0 +1,124 @@
|
|||
import fdb
|
||||
import json
|
||||
import random
|
||||
import string
|
||||
import time
|
||||
from typing import Union, List
|
||||
|
||||
def to_str(s: Union[str, bytes]):
|
||||
if isinstance(s, bytes):
|
||||
s = s.decode("utf8")
|
||||
return s
|
||||
|
||||
def to_bytes(s: Union[str, bytes]):
|
||||
if isinstance(s, str):
|
||||
s = s.encode("utf8")
|
||||
return s
|
||||
|
||||
def random_alphanum_str(k: int):
|
||||
return ''.join(random.choices(string.ascii_letters + string.digits, k=k))
|
||||
|
||||
def random_alphanum_bytes(k: int):
|
||||
return random_alphanum_str(k).encode("ascii")
|
||||
|
||||
def cleanup_tenant(db, tenant_name):
|
||||
try:
|
||||
tenant = db.open_tenant(tenant_name)
|
||||
del tenant[:]
|
||||
fdb.tenant_management.delete_tenant(db, tenant_name)
|
||||
except fdb.FDBError as e:
|
||||
if e.code == 2131: # tenant not found
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
|
||||
def alg_from_kty(kty: str):
|
||||
if kty == "EC":
|
||||
return "ES256"
|
||||
else:
|
||||
return "RS256"
|
||||
|
||||
def public_keyset_from_keys(keys: List):
|
||||
keys = list(map(lambda key: key.as_dict(is_private=False, alg=alg_from_kty(key.kty)), keys))
|
||||
return json.dumps({ "keys": keys })
|
||||
|
||||
class KeyFileReverter(object):
|
||||
def __init__(self, filename: str, content: str, refresh_delay: int):
|
||||
self.filename = filename
|
||||
self.content = content
|
||||
self.refresh_delay = refresh_delay
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||
with open(self.filename, "w") as keyfile:
|
||||
keyfile.write(self.content)
|
||||
print(f"key file reverted. waiting {self.refresh_delay * 2} seconds for the update to take effect...")
|
||||
time.sleep(self.refresh_delay * 2)
|
||||
|
||||
# JWT claim that is valid for 1 hour since time of invocation
|
||||
def token_claim_1h(tenant_name):
|
||||
now = time.time()
|
||||
return {
|
||||
"iss": "fdb-authz-tester",
|
||||
"sub": "authz-test",
|
||||
"aud": ["tmp-cluster"],
|
||||
"iat": now,
|
||||
"nbf": now - 1,
|
||||
"exp": now + 60 * 60,
|
||||
"jti": random_alphanum_str(10),
|
||||
"tenants": [to_str(tenant_name)],
|
||||
}
|
||||
|
||||
# repeat try-wait loop up to max_repeat times until both read and write tr fails for tenant with permission_denied
|
||||
# important: only use this function if you don't have any data dependencies to key "abc"
|
||||
def wait_until_tenant_tr_fails(tenant, private_key, tenant_tr_gen, token_gen, max_repeat, delay):
|
||||
repeat = 0
|
||||
read_blocked = False
|
||||
write_blocked = False
|
||||
while (not read_blocked or not write_blocked) and repeat < max_repeat:
|
||||
time.sleep(delay)
|
||||
tr = tenant_tr_gen(tenant)
|
||||
# a token needs to be generated at every iteration because once it is accepted/cached,
|
||||
# it will pass verification by caching until it expires
|
||||
tr.options.set_authorization_token(token_gen(private_key, token_claim_1h(tenant)))
|
||||
try:
|
||||
if not read_blocked:
|
||||
value = tr[b"abc"].value
|
||||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied, got {e} instead"
|
||||
read_blocked = True
|
||||
if not read_blocked:
|
||||
repeat += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
if not write_blocked:
|
||||
tr[b"abc"] = b"def"
|
||||
tr.commit().wait()
|
||||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied, got {e} instead"
|
||||
write_blocked = True
|
||||
if not write_blocked:
|
||||
repeat += 1
|
||||
assert repeat < max_repeat, f"tenant transaction did not start to fail in {max_repeat * delay} seconds"
|
||||
|
||||
# repeat try-wait loop up to max_repeat times until both read and write tr succeeds for tenant
|
||||
# important: only use this function if you don't have any data dependencies to key "abc"
|
||||
def wait_until_tenant_tr_succeeds(tenant, private_key, tenant_tr_gen, token_gen, max_repeat, delay):
|
||||
repeat = 0
|
||||
token = token_gen(private_key, token_claim_1h(tenant))
|
||||
while repeat < max_repeat:
|
||||
try:
|
||||
time.sleep(delay)
|
||||
tr = tenant_tr_gen(tenant)
|
||||
tr.options.set_authorization_token(token)
|
||||
value = tr[b"abc"].value
|
||||
tr[b"abc"] = b"qwe"
|
||||
tr.commit().wait()
|
||||
break
|
||||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied, got {e} instead"
|
||||
repeat += 1
|
||||
assert repeat < max_repeat, f"tenant transaction did not start to succeed in {max_repeat * delay} seconds"
|
Loading…
Reference in New Issue