diff --git a/fdbrpc/TokenCache.cpp b/fdbrpc/TokenCache.cpp index dc38a23702..19b89d7234 100644 --- a/fdbrpc/TokenCache.cpp +++ b/fdbrpc/TokenCache.cpp @@ -1,9 +1,120 @@ #include "fdbrpc/FlowTransport.h" #include "fdbrpc/TokenCache.h" #include "fdbrpc/TokenSign.h" +#include "flow/UnitTest.h" #include "flow/network.h" -#include +#include + +#include +#include + +template +class LRUCache { +public: + using key_type = Key; + using list_type = std::list; + using mapped_type = Value; + using map_type = boost::unordered_map>; + using size_type = unsigned; + + explicit LRUCache(size_type capacity) : _capacity(capacity) { _map.reserve(capacity); } + + size_type size() const { return _map.size(); } + size_type capacity() const { return _capacity; } + bool empty() const { return _map.empty(); } + + Optional get(key_type const& key) { + auto i = _map.find(key); + if (i == _map.end()) { + return Optional(); + } + auto j = i->second.second; + if (j != _list.begin()) { + _list.erase(j); + _list.push_front(i->first); + i->second.second = _list.begin(); + } + return &i->second.first; + } + + template + mapped_type* insert(K&& key, V&& value) { + auto iter = _map.find(key); + if (iter != _map.end()) { + return &iter->second.first; + } + if (size() == capacity()) { + auto i = --_list.end(); + _map.erase(*i); + _list.erase(i); + } + std::tie(iter, std::ignore) = + _map.insert(std::make_pair(std::forward(key), std::make_pair(std::forward(value), _list.begin()))); + _list.push_back(iter->first); + iter->second.second = _list.begin(); + return &iter->second.first; + } + +private: + const size_type _capacity; + map_type _map; + list_type _list; +}; + +TEST_CASE("/fdbrpc/authz/LRUCache") { + { + // test very small LRU cache + LRUCache cache(2); + for (int i = 0; i < 200; ++i) { + cache.insert(i, "val"_sr); + if (i > cache.capacity()) { + ASSERT(cache.get(i - cache.capacity() + 1).present()); + ASSERT(!cache.get(i - cache.capacity()).present()); + } + } + } + { + // Test larger cache + LRUCache cache(1000); + int last = 0; + for (; last < 1000; ++last) { + cache.insert(last, "value"_sr); + } + cache.insert(0, "value"); // should evict 1 + ASSERT(!cache.get(1).present()); + } + { + // memory test -- this is what the boost implementation didn't do correctly + LRUCache> cache(10); + std::deque cachedStrings; + std::deque evictedStrings; + for (int i = 0; i < 10; ++i) { + auto str = deterministicRandom()->randomAlphaNumeric(deterministicRandom()->randomInt(100, 1024)); + Standalone sref(str); + cache.insert(sref, sref); + cachedStrings.push_back(str); + } + for (int i = 0; i < 10; ++i) { + Standalone existingStr(cachedStrings.back()); + auto cachedStr = cache.get(existingStr); + ASSERT(cachedStr.present()); + ASSERT(*cachedStr.get() == existingStr); + if (!evictedStrings.empty()) { + Standalone nonexisting( + evictedStrings.at(deterministicRandom()->randomInt(0, evictedStrings.size()))); + ASSERT(!cache.get(nonexisting).present()); + } + auto str = deterministicRandom()->randomAlphaNumeric(deterministicRandom()->randomInt(100, 1024)); + Standalone sref(str); + evictedStrings.push_back(cachedStrings.front()); + cachedStrings.pop_front(); + cachedStrings.push_back(str); + cache.insert(sref, sref); + } + } + return Void(); +} struct TokenCacheImpl { struct CacheEntry { @@ -12,7 +123,7 @@ struct TokenCacheImpl { double expirationTime = 0.0; }; - boost::compute::detail::lru_cache cache; + LRUCache cache; TokenCacheImpl() : cache(FLOW_KNOBS->TOKEN_CACHE_SIZE) {} bool validate(TenantNameRef tenant, StringRef token); @@ -69,7 +180,7 @@ bool TokenCacheImpl::validateAndAdd(double currentTime, TraceEvent(SevWarn, "InvalidToken").detail("From", peer).detail("Reason", "NoNotBefore"); return false; } else if (double(t.notBeforeUnixTime.get()) > currentTime) { - TEST(true); // Token has no not-before field + TEST(true); // Tokens not-before is in the future TraceEvent(SevWarn, "InvalidToken").detail("From", peer).detail("Reason", "TokenNotYetValid"); return false; } else if (!t.tenants.present()) { @@ -98,7 +209,7 @@ bool TokenCacheImpl::validate(TenantNameRef name, StringRef token) { double currentTime = g_network->timer(); NetworkAddress peer = FlowTransport::transport().currentDeliveryPeerAddress(); - if (!cachedEntry.has_value()) { + if (!cachedEntry.present()) { if (validateAndAdd(currentTime, sig, token, peer)) { cachedEntry = cache.get(sig); } else { @@ -106,16 +217,16 @@ bool TokenCacheImpl::validate(TenantNameRef name, StringRef token) { } } - ASSERT(cachedEntry.has_value()); + ASSERT(cachedEntry.present()); auto& entry = cachedEntry.get(); - if (entry.expirationTime < currentTime) { + if (entry->expirationTime < currentTime) { TEST(true); // Read expired token from cache TraceEvent(SevWarn, "InvalidToken").detail("From", peer).detail("Reason", "Expired"); return false; } bool tenantFound = false; - for (auto const& t : entry.tenants) { + for (auto const& t : entry->tenants) { if (t == name) { tenantFound = true; break;