496 lines
18 KiB
C++
496 lines
18 KiB
C++
#include "fdbrpc/Base64Encode.h"
|
|
#include "fdbrpc/Base64Decode.h"
|
|
#include "fdbrpc/FlowTransport.h"
|
|
#include "fdbrpc/TokenCache.h"
|
|
#include "fdbrpc/TokenSign.h"
|
|
#include "fdbrpc/TenantInfo.h"
|
|
#include "flow/MkCert.h"
|
|
#include "flow/ScopeExit.h"
|
|
#include "flow/UnitTest.h"
|
|
#include "flow/network.h"
|
|
|
|
#include <rapidjson/document.h>
|
|
#include <rapidjson/writer.h>
|
|
#include <rapidjson/stringbuffer.h>
|
|
|
|
#include <boost/unordered_map.hpp>
|
|
#include <boost/unordered_set.hpp>
|
|
|
|
#include <fmt/format.h>
|
|
#include <list>
|
|
#include <deque>
|
|
|
|
#include "flow/actorcompiler.h" // has to be last include
|
|
|
|
using authz::TenantId;
|
|
|
|
template <class Key, class Value>
|
|
class LRUCache {
|
|
public:
|
|
using key_type = Key;
|
|
using list_type = std::list<key_type>;
|
|
using mapped_type = Value;
|
|
using map_type = boost::unordered_map<key_type, std::pair<mapped_type, typename list_type::iterator>>;
|
|
using size_type = unsigned;
|
|
|
|
explicit LRUCache(size_type capacity) : _capacity(capacity) { _map.reserve(capacity); }
|
|
|
|
size_type size() const { return _map.size(); }
|
|
size_type capacity() const { return _capacity; }
|
|
bool empty() const { return _map.empty(); }
|
|
|
|
Optional<mapped_type*> get(key_type const& key) {
|
|
auto i = _map.find(key);
|
|
if (i == _map.end()) {
|
|
return Optional<mapped_type*>();
|
|
}
|
|
auto j = i->second.second;
|
|
if (j != _list.begin()) {
|
|
_list.erase(j);
|
|
_list.push_front(i->first);
|
|
i->second.second = _list.begin();
|
|
}
|
|
return &i->second.first;
|
|
}
|
|
|
|
template <class K, class V>
|
|
mapped_type* insert(K&& key, V&& value) {
|
|
auto iter = _map.find(key);
|
|
if (iter != _map.end()) {
|
|
return &iter->second.first;
|
|
}
|
|
if (size() == capacity()) {
|
|
auto i = --_list.end();
|
|
_map.erase(*i);
|
|
_list.erase(i);
|
|
}
|
|
_list.push_front(std::forward<K>(key));
|
|
std::tie(iter, std::ignore) =
|
|
_map.insert(std::make_pair(*_list.begin(), std::make_pair(std::forward<V>(value), _list.begin())));
|
|
return &iter->second.first;
|
|
}
|
|
|
|
private:
|
|
const size_type _capacity;
|
|
map_type _map;
|
|
list_type _list;
|
|
};
|
|
|
|
// Unit test for LRUCache: eviction order of a tiny cache, eviction in a larger cache, and key/value lifetime
// when keys are StringRefs whose bytes are kept alive by the cached Standalone values.
TEST_CASE("/fdbrpc/authz/LRUCache") {
	auto& random = *deterministicRandom();
	{
		// test very small LRU cache
		LRUCache<int, StringRef> smallCache(random.randomInt(2, 10));
		int const cap = smallCache.capacity();
		for (int newest = 0; newest < 200; ++newest) {
			smallCache.insert(newest, "val"_sr);
			if (newest >= cap) {
				int const oldestResident = newest - cap + 1;
				// everything older than the last `cap` inserted keys must have been evicted
				for (int evictedKey = 0; evictedKey < oldestResident; ++evictedKey) {
					ASSERT(!smallCache.get(evictedKey).present());
				}
				// ordering is important so as not to disrupt the LRU order
				for (int residentKey = oldestResident; residentKey <= newest; ++residentKey) {
					ASSERT(smallCache.get(residentKey).present());
				}
			}
		}
	}
	{
		// Test larger cache
		LRUCache<int, StringRef> largeCache(1000);
		for (int key = 0; key < 1000; ++key) {
			largeCache.insert(key, "value"_sr);
		}
		largeCache.insert(1000, "value"_sr); // should evict 0
		ASSERT(!largeCache.get(0).present());
	}
	{
		// memory test -- this is what the boost implementation didn't do correctly
		LRUCache<StringRef, Standalone<StringRef>> stringCache(10);
		std::deque<std::string> residentStrings;
		std::deque<std::string> droppedStrings;
		for (int fill = 0; fill < 10; ++fill) {
			auto text = random.randomAlphaNumeric(random.randomInt(100, 1024));
			Standalone<StringRef> owned(text);
			stringCache.insert(owned, owned);
			residentStrings.push_back(text);
		}
		for (int round = 0; round < 10; ++round) {
			// the most recently inserted string must still be cached and compare equal
			Standalone<StringRef> latest(residentStrings.back());
			auto hit = stringCache.get(latest);
			ASSERT(hit.present());
			ASSERT(*hit.get() == latest);
			if (!droppedStrings.empty()) {
				// a randomly chosen evicted string must be gone
				Standalone<StringRef> dropped(droppedStrings.at(random.randomInt(0, droppedStrings.size())));
				ASSERT(!stringCache.get(dropped).present());
			}
			auto text = random.randomAlphaNumeric(random.randomInt(100, 1024));
			Standalone<StringRef> owned(text);
			// mirror the eviction that the upcoming insert is expected to cause
			droppedStrings.push_back(residentStrings.front());
			residentStrings.pop_front();
			residentStrings.push_back(text);
			stringCache.insert(owned, owned);
		}
	}
	return Void();
}
|
|
|
|
struct CacheEntry {
|
|
Arena arena;
|
|
VectorRef<TenantId> tenants;
|
|
Optional<StringRef> tokenId;
|
|
double expirationTime = 0.0;
|
|
};
|
|
|
|
struct AuditEntry {
|
|
NetworkAddress address;
|
|
TenantId tenantId;
|
|
Optional<Standalone<StringRef>> tokenId;
|
|
bool operator==(const AuditEntry& other) const noexcept = default;
|
|
explicit AuditEntry(NetworkAddress const& address, TenantId tenantId, CacheEntry const& cacheEntry)
|
|
: address(address), tenantId(tenantId),
|
|
tokenId(cacheEntry.tokenId.present() ? Standalone<StringRef>(cacheEntry.tokenId.get(), cacheEntry.arena)
|
|
: Optional<Standalone<StringRef>>()) {}
|
|
};
|
|
|
|
// Hash for AuditEntry: combines address and tenant id, plus the token id when one is present.
std::size_t hash_value(AuditEntry const& value) {
	std::size_t digest = 0;
	boost::hash_combine(digest, value.address);
	boost::hash_combine(digest, value.tenantId);
	if (auto const& id = value.tokenId; id.present()) {
		boost::hash_combine(digest, id.get());
	}
	return digest;
}
|
|
|
|
struct TokenCacheImpl {
|
|
TokenCacheImpl();
|
|
LRUCache<StringRef, CacheEntry> cache;
|
|
boost::unordered_set<AuditEntry> usedTokens;
|
|
double lastResetTime;
|
|
|
|
bool validate(TenantId tenantId, StringRef token);
|
|
bool validateAndAdd(double currentTime, StringRef token, NetworkAddress const& peer);
|
|
void logTokenUsage(double currentTime, AuditEntry&& entry);
|
|
};
|
|
|
|
// Sizes the token cache from the TOKEN_CACHE_SIZE knob; the audit set starts out empty with a reset time of 0.
TokenCacheImpl::TokenCacheImpl() : cache(FLOW_KNOBS->TOKEN_CACHE_SIZE), lastResetTime(0) {}
|
|
|
|
// TokenCache owns its implementation object: allocated on construction ...
TokenCache::TokenCache() : impl(new TokenCacheImpl) {}

// ... and released on destruction.
TokenCache::~TokenCache() {
	delete impl;
}
|
|
|
|
// Allocates a TokenCache and registers it with the network under INetwork::enTokenCache.
void TokenCache::createInstance() {
	auto* tokenCache = new TokenCache();
	g_network->setGlobal(INetwork::enTokenCache, tokenCache);
}
|
|
|
|
// Returns the TokenCache registered under INetwork::enTokenCache.
// The global is dereferenced unconditionally, so createInstance() must have been called beforehand.
TokenCache& TokenCache::instance() {
	auto* registered = reinterpret_cast<TokenCache*>(g_network->global(INetwork::enTokenCache));
	return *registered;
}
|
|
|
|
// Public entry point; forwards to the implementation object.
bool TokenCache::validate(TenantId tenantId, StringRef token) {
	bool const accepted = impl->validate(tenantId, token);
	return accepted;
}
|
|
|
|
// Starts a SevWarn "InvalidToken" trace event (tagged with the _audit suffix) for a token that could be parsed
// but is being rejected for `reason`. The expansion is a single expression, so callers may chain additional
// .detail() calls before the terminating semicolon.
// Relies on `peer`, `currentTime` and `arena` being in scope at the expansion site.
#define TRACE_INVALID_PARSED_TOKEN(reason, token) \
	TraceEvent(SevWarn, "InvalidToken"_audit) \
	    .detail("From", peer) \
	    .detail("Reason", (reason)) \
	    .detail("CurrentTime", currentTime) \
	    .detail("Token", toStringRef(arena, (token)).toStringView())
|
|
|
|
bool TokenCacheImpl::validateAndAdd(double currentTime, StringRef token, NetworkAddress const& peer) {
|
|
Arena arena;
|
|
authz::jwt::TokenRef t;
|
|
StringRef signInput;
|
|
Optional<StringRef> err;
|
|
bool verifyOutcome;
|
|
if ((err = authz::jwt::parseToken(arena, token, t, signInput)).present()) {
|
|
CODE_PROBE(true, "Token can't be parsed");
|
|
TraceEvent te(SevWarn, "InvalidToken");
|
|
te.detail("From", peer);
|
|
te.detail("Reason", "ParseError");
|
|
te.detail("ErrorDetail", err.get());
|
|
if (signInput.empty()) { // unrecognizable token structure
|
|
te.detail("Token", token.toString());
|
|
} else { // trace with signature part taken out
|
|
te.detail("SignInput", signInput.toString());
|
|
}
|
|
return false;
|
|
}
|
|
auto key = FlowTransport::transport().getPublicKeyByName(t.keyId);
|
|
if (!key.present()) {
|
|
CODE_PROBE(true, "Token referencing non-existing key");
|
|
TRACE_INVALID_PARSED_TOKEN("UnknownKey", t);
|
|
return false;
|
|
} else if (!t.issuedAtUnixTime.present()) {
|
|
CODE_PROBE(true, "Token has no issued-at field");
|
|
TRACE_INVALID_PARSED_TOKEN("NoIssuedAt", t);
|
|
return false;
|
|
} else if (!t.expiresAtUnixTime.present()) {
|
|
CODE_PROBE(true, "Token has no expiration time");
|
|
TRACE_INVALID_PARSED_TOKEN("NoExpirationTime", t);
|
|
return false;
|
|
} else if (double(t.expiresAtUnixTime.get()) <= currentTime) {
|
|
CODE_PROBE(true, "Expired token");
|
|
TRACE_INVALID_PARSED_TOKEN("Expired", t);
|
|
return false;
|
|
} else if (!t.notBeforeUnixTime.present()) {
|
|
CODE_PROBE(true, "Token has no not-before field");
|
|
TRACE_INVALID_PARSED_TOKEN("NoNotBefore", t);
|
|
return false;
|
|
} else if (double(t.notBeforeUnixTime.get()) > currentTime) {
|
|
CODE_PROBE(true, "Token's not-before is in the future");
|
|
TRACE_INVALID_PARSED_TOKEN("TokenNotYetValid", t);
|
|
return false;
|
|
} else if (!t.tenants.present()) {
|
|
CODE_PROBE(true, "Token with no tenants");
|
|
TRACE_INVALID_PARSED_TOKEN("NoTenants", t);
|
|
return false;
|
|
}
|
|
std::tie(verifyOutcome, err) = authz::jwt::verifyToken(signInput, t, key.get());
|
|
if (err.present()) {
|
|
CODE_PROBE(true, "Error while verifying token");
|
|
TRACE_INVALID_PARSED_TOKEN("ErrorWhileVerifyingToken", t).detail("ErrorDetail", err.get());
|
|
return false;
|
|
} else if (!verifyOutcome) {
|
|
CODE_PROBE(true, "Token with invalid signature");
|
|
TRACE_INVALID_PARSED_TOKEN("InvalidSignature", t);
|
|
return false;
|
|
} else {
|
|
CacheEntry c;
|
|
c.expirationTime = t.expiresAtUnixTime.get();
|
|
c.tenants.reserve(c.arena, t.tenants.get().size());
|
|
for (auto tenantId : t.tenants.get()) {
|
|
c.tenants.push_back(c.arena, tenantId);
|
|
}
|
|
if (t.tokenId.present()) {
|
|
c.tokenId = StringRef(c.arena, t.tokenId.get());
|
|
}
|
|
cache.insert(StringRef(c.arena, token), c);
|
|
return true;
|
|
}
|
|
}
|
|
|
|
// Decides whether `token` authorizes access to `tenantId` for the peer of the current delivery.
//
// A token not yet in the cache is first parsed, verified and cached by validateAndAdd(). The cached entry is
// then rejected if it has expired or does not list `tenantId`. Rejections are traced as "InvalidToken".
// When the AUDIT_LOGGING_ENABLED knob is set, an accepted use is passed to logTokenUsage().
//
// Returns true iff the token is accepted.
bool TokenCacheImpl::validate(TenantId tenantId, StringRef token) {
	NetworkAddress peer = FlowTransport::transport().currentDeliveryPeerAddress();
	auto cachedEntry = cache.get(token);
	double currentTime = g_network->timer();

	if (!cachedEntry.present()) {
		if (!validateAndAdd(currentTime, token, peer)) {
			return false;
		}
		cachedEntry = cache.get(token);
	}

	ASSERT(cachedEntry.present());

	CacheEntry const& entry = *cachedEntry.get();
	// A token is expired on or after its expiration time. Use the same `<=` boundary as the admission check in
	// validateAndAdd(), so a cached token is not accepted at an instant at which a fresh one would be refused.
	if (entry.expirationTime <= currentTime) {
		CODE_PROBE(true, "Found expired token in cache");
		TraceEvent(SevWarn, "InvalidToken"_audit).detail("From", peer).detail("Reason", "ExpiredInCache");
		return false;
	}
	bool tenantFound = false;
	for (auto const& t : entry.tenants) {
		if (t == tenantId) {
			tenantFound = true;
			break;
		}
	}
	if (!tenantFound) {
		CODE_PROBE(true, "Valid token doesn't reference tenant");
		TraceEvent(SevWarn, "InvalidToken"_audit)
		    .detail("From", peer)
		    .detail("Reason", "TenantTokenMismatch")
		    .detail("RequestedTenant", fmt::format("{:#x}", tenantId))
		    .detail("TenantsInToken", fmt::format("{:#x}", fmt::join(entry.tenants, " ")));
		return false;
	}
	// audit logging
	if (FLOW_KNOBS->AUDIT_LOGGING_ENABLED) {
		logTokenUsage(currentTime, AuditEntry(peer, tenantId, entry));
	}
	return true;
}
|
|
|
|
void TokenCacheImpl::logTokenUsage(double currentTime, AuditEntry&& entry) {
|
|
if (currentTime > lastResetTime + FLOW_KNOBS->AUDIT_TIME_WINDOW) {
|
|
// clear usage cache every AUDIT_TIME_WINDOW seconds
|
|
usedTokens.clear();
|
|
lastResetTime = currentTime;
|
|
}
|
|
auto [iter, inserted] = usedTokens.insert(std::move(entry));
|
|
if (inserted) {
|
|
// access in the context of this (client_ip, tenant, token_id) tuple hasn't been logged in current window. log
|
|
// usage.
|
|
CODE_PROBE(true, "Audit Logging Running");
|
|
TraceEvent("AuditTokenUsed"_audit)
|
|
.detail("Client", iter->address)
|
|
.detail("TenantId", fmt::format("{:#x}", iter->tenantId))
|
|
.detail("TokenId", iter->tokenId)
|
|
.log();
|
|
}
|
|
}
|
|
|
|
namespace authz::jwt {
|
|
extern TokenRef makeRandomTokenSpec(Arena&, IRandom&, authz::Algorithm);
|
|
}
|
|
|
|
// Unit test: every kind of defective token must be rejected by TokenCache::validate().
//
// For 50 random valid token specs, each entry of `badMutations` is applied in turn to a copy of the spec, the
// result is signed with the EC private key and validation is expected to fail. Two more cases per spec do not
// fit the mutation interface: a token with a truncated signature and a token whose payload carries a tenant
// entry with a base64-illegal character. Finally, a few ill-formed token strings are checked.
TEST_CASE("/fdbrpc/authz/TokenCache/BadTokens") {
	auto const pubKeyName = "someEcPublicKey"_sr;
	auto const rsaPubKeyName = "someRsaPublicKey"_sr;
	auto privateKey = mkcert::makeEcP256();
	auto publicKey = privateKey.toPublic();
	auto rsaPrivateKey = mkcert::makeRsa4096Bit(); // to trigger unmatched sign algorithm
	auto rsaPublicKey = rsaPrivateKey.toPublic();
	// Each entry: a mutation that makes a valid token spec (or the key store) invalid, and a description that is
	// printed if validation unexpectedly succeeds.
	std::pair<std::function<void(Arena&, IRandom&, authz::jwt::TokenRef&)>, char const*> badMutations[]{
		{
		    [](Arena&, IRandom&, authz::jwt::TokenRef&) { FlowTransport::transport().removeAllPublicKeys(); },
		    "NoKeyWithSuchName",
		},
		{
		    [](Arena&, IRandom&, authz::jwt::TokenRef& token) { token.expiresAtUnixTime.reset(); },
		    "NoExpirationTime",
		},
		{
		    [](Arena&, IRandom& rng, authz::jwt::TokenRef& token) {
		        token.expiresAtUnixTime = std::max<double>(g_network->timer() - 10 - rng.random01() * 50, 0);
		    },
		    "ExpiredToken",
		},
		{
		    [](Arena&, IRandom&, authz::jwt::TokenRef& token) { token.notBeforeUnixTime.reset(); },
		    "NoNotBefore",
		},
		{
		    [](Arena&, IRandom& rng, authz::jwt::TokenRef& token) {
		        token.notBeforeUnixTime = g_network->timer() + 10 + rng.random01() * 50;
		    },
		    "TokenNotYetValid",
		},
		{
		    [](Arena&, IRandom&, authz::jwt::TokenRef& token) { token.issuedAtUnixTime.reset(); },
		    "NoIssuedAt",
		},
		{
		    [](Arena&, IRandom&, authz::jwt::TokenRef& token) { token.tenants.reset(); },
		    "NoTenants",
		},
		{
		    [](Arena& arena, IRandom&, authz::jwt::TokenRef& token) {
		        TenantId* newTenants = new (arena) TenantId[1];
		        *newTenants = token.tenants.get()[0] + 1;
		        token.tenants = VectorRef<TenantId>(newTenants, 1);
		    },
		    "UnmatchedTenant",
		},
		{
		    [rsaPubKeyName](Arena&, IRandom&, authz::jwt::TokenRef& token) { token.keyId = rsaPubKeyName; },
		    "UnmatchedSignAlgorithm",
		},
	};
	auto const numBadMutations = sizeof(badMutations) / sizeof(badMutations[0]);
	for (auto repeat = 0; repeat < 50; repeat++) {
		auto arena = Arena();
		auto& rng = *deterministicRandom();
		auto validTokenSpec = authz::jwt::makeRandomTokenSpec(arena, rng, authz::Algorithm::ES256);
		validTokenSpec.keyId = pubKeyName;
		// i in [0, numBadMutations): table-driven mutations; numBadMutations: truncated signature;
		// numBadMutations + 1: illegal base64 in the tenants claim. The index is unsigned to match the bound.
		for (std::size_t i = 0; i <= numBadMutations + 1; i++) {
			// (re-)register both keys for every case, since the first mutation removes all public keys
			FlowTransport::transport().addPublicKey(pubKeyName, publicKey);
			FlowTransport::transport().addPublicKey(rsaPubKeyName, rsaPublicKey);
			auto publicKeyClearGuard = ScopeExit([]() { FlowTransport::transport().removeAllPublicKeys(); });
			auto signedToken = StringRef();
			auto tmpArena = Arena();
			if (i < numBadMutations) {
				// bind by reference: copying the std::function on every iteration is unnecessary
				auto const& [mutationFn, mutationDesc] = badMutations[i];
				auto mutatedTokenSpec = validTokenSpec;
				mutationFn(tmpArena, rng, mutatedTokenSpec);
				signedToken = authz::jwt::signToken(tmpArena, mutatedTokenSpec, privateKey);
				if (TokenCache::instance().validate(validTokenSpec.tenants.get()[0], signedToken)) {
					fmt::print("Unexpected successful validation at mutation {}, token spec: {}\n",
					           mutationDesc,
					           toStringRef(tmpArena, mutatedTokenSpec).toStringView());
					ASSERT(false);
				}
			} else if (i == numBadMutations) {
				// squeeze in a bad signature case that does not fit into mutation interface
				signedToken = authz::jwt::signToken(tmpArena, validTokenSpec, privateKey);
				signedToken.popBack();
				if (TokenCache::instance().validate(validTokenSpec.tenants.get()[0], signedToken)) {
					fmt::print("Unexpected successful validation with a token with truncated signature part\n");
					ASSERT(false);
				}
			} else {
				// test if badly base64-encoded tenant name causes validation to fail as expected
				auto signInput = authz::jwt::makeSignInput(tmpArena, validTokenSpec);
				auto b64Header = signInput.eat("."_sr);
				auto payload = base64::url::decode(tmpArena, signInput).get();
				rapidjson::Document d;
				d.Parse(reinterpret_cast<const char*>(payload.begin()), payload.size());
				ASSERT(!d.HasParseError());
				rapidjson::StringBuffer wrBuf;
				rapidjson::Writer<rapidjson::StringBuffer> wr(wrBuf);
				auto tenantsField = d.FindMember("tenants");
				ASSERT(tenantsField != d.MemberEnd());
				tenantsField->value.PushBack("ABC#", d.GetAllocator()); // inject base64-illegal character
				d.Accept(wr);
				auto b64ModifiedPayload = base64::url::encode(
				    tmpArena, StringRef(reinterpret_cast<const uint8_t*>(wrBuf.GetString()), wrBuf.GetSize()));
				signInput = b64Header.withSuffix("."_sr, tmpArena).withSuffix(b64ModifiedPayload, tmpArena);
				signedToken = authz::jwt::signToken(tmpArena, signInput, validTokenSpec.algorithm, privateKey);
				if (TokenCache::instance().validate(validTokenSpec.tenants.get()[0], signedToken)) {
					fmt::print(
					    "Unexpected successful validation of a token with tenant name containing non-base64 chars\n");
					ASSERT(false);
				}
			}
		}
	}
	// Ill-formed token strings. Each check prints its own message so that a failure identifies the case.
	if (TokenCache::instance().validate(TenantInfo::INVALID_TENANT, StringRef())) {
		fmt::print("Unexpected successful validation of ill-formed token (empty token)\n");
		ASSERT(false);
	}
	if (TokenCache::instance().validate(TenantInfo::INVALID_TENANT, "1111.22"_sr)) {
		fmt::print("Unexpected successful validation of ill-formed token (no signature part)\n");
		ASSERT(false);
	}
	if (TokenCache::instance().validate(TenantInfo::INVALID_TENANT, "////.////.////"_sr)) {
		fmt::print("Unexpected successful validation of unparseable token\n");
		ASSERT(false);
	}
	fmt::print("TEST OK\n");
	return Void();
}
|
|
|
|
// Unit test: a correctly signed token that expires 2 seconds from now is accepted immediately, and the same
// token is rejected once 3.5 seconds have passed.
TEST_CASE("/fdbrpc/authz/TokenCache/GoodTokens") {
	// Don't repeat because token expiry is at seconds granularity and sleeps are costly in unit tests
	state Arena arena;
	state PrivateKey privateKey = mkcert::makeEcP256();
	state StringRef pubKeyName = "somePublicKey"_sr;
	// unregisters the public key again when the test ends
	state ScopeExit<std::function<void()>> publicKeyClearGuard(
	    [pubKeyName = pubKeyName]() { FlowTransport::transport().removePublicKey(pubKeyName); });
	state authz::jwt::TokenRef tokenSpec =
	    authz::jwt::makeRandomTokenSpec(arena, *deterministicRandom(), authz::Algorithm::ES256);
	state StringRef signedToken;

	FlowTransport::transport().addPublicKey(pubKeyName, privateKey.toPublic());
	tokenSpec.expiresAtUnixTime = g_network->timer() + 2.0;
	tokenSpec.keyId = pubKeyName;
	signedToken = authz::jwt::signToken(arena, tokenSpec, privateKey);

	// first use: the token is still within its lifetime and must be accepted
	bool const acceptedWhileValid = TokenCache::instance().validate(tokenSpec.tenants.get()[0], signedToken);
	if (!acceptedWhileValid) {
		fmt::print("Unexpected failed token validation, token spec: {}, now: {}\n",
		           toStringRef(arena, tokenSpec).toStringView(),
		           g_network->timer());
		ASSERT(false);
	}

	wait(delay(3.5));

	// second use: the expiration time has passed, so the token must be refused
	bool const acceptedAfterExpiry = TokenCache::instance().validate(tokenSpec.tenants.get()[0], signedToken);
	if (acceptedAfterExpiry) {
		fmt::print(
		    "Unexpected successful token validation after supposedly expiring in cache, token spec: {}, now: {}\n",
		    toStringRef(arena, tokenSpec).toStringView(),
		    g_network->timer());
		ASSERT(false);
	}
	fmt::print("TEST OK\n");
	return Void();
}
|