Allow unthrottled, unsuppressed traces for security-related events (#9459)
* Define API for unsuppressable TraceEvent types; add trace-checking tests for authz trace events
* Revert temporary configurations used for debugging
* Simplify/modernize the flow audit-logging API:
  - Do event-type whitelist checks at compile time
  - Use the ""_audit literal API instead of a tag struct
  - Replace int with a lightweight struct for tracking/modifying TraceEvent enablement
* Revert installing a signal handler for SIGTERM and refactor the test script; move the trace checker to local_cluster.py
* Lengthen the public-key refresh interval and add more audited events
* Try to make the MSVC and Mac builds happy
* consteval > constexpr: 'inline consteval' still causes link errors in Mac builds
This commit is contained in:
parent
80eb84de3c
commit
b811881f41
|
@ -971,7 +971,7 @@ void Peer::onIncomingConnection(Reference<Peer> self, Reference<IConnection> con
|
|||
if (!destination.isPublic() || outgoingConnectionIdle || destination > compatibleAddr ||
|
||||
(lastConnectTime > 1.0 && now() - lastConnectTime > FLOW_KNOBS->ALWAYS_ACCEPT_DELAY)) {
|
||||
// Keep the new connection
|
||||
TraceEvent("IncomingConnection", conn->getDebugID())
|
||||
TraceEvent("IncomingConnection"_audit, conn->getDebugID())
|
||||
.suppressFor(1.0)
|
||||
.detail("FromAddr", conn->getPeerAddress())
|
||||
.detail("CanonicalAddr", destination)
|
||||
|
@ -1072,7 +1072,7 @@ ACTOR static void deliver(TransportData* self,
|
|||
} else if (destination.token.first() & TOKEN_STREAM_FLAG) {
|
||||
// We don't have the (stream) endpoint 'token', notify the remote machine
|
||||
if (receiver) {
|
||||
TraceEvent(SevWarnAlways, "AttemptedRPCToPrivatePrevented")
|
||||
TraceEvent(SevWarnAlways, "AttemptedRPCToPrivatePrevented"_audit)
|
||||
.detail("From", peerAddress)
|
||||
.detail("Token", destination.token)
|
||||
.detail("Receiver", typeid(*receiver).name());
|
||||
|
@ -1574,7 +1574,7 @@ void TransportData::applyPublicKeySet(StringRef jwkSetString) {
|
|||
numPrivateKeys++;
|
||||
}
|
||||
}
|
||||
TraceEvent(SevInfo, "AuthzPublicKeySetApply").detail("NumPublicKeys", publicKeys.size());
|
||||
TraceEvent(SevInfo, "AuthzPublicKeySetApply"_audit).detail("NumPublicKeys", publicKeys.size());
|
||||
if (numPrivateKeys > 0) {
|
||||
TraceEvent(SevWarnAlways, "AuthzPublicKeySetContainsPrivateKeys").detail("NumPrivateKeys", numPrivateKeys);
|
||||
}
|
||||
|
@ -2066,7 +2066,7 @@ ACTOR static Future<Void> watchPublicKeyJwksFile(std::string filePath, Transport
|
|||
}
|
||||
// parse/read error
|
||||
errorCount++;
|
||||
TraceEvent(SevWarn, "AuthzPublicKeySetRefreshError").error(e).detail("ErrorCount", errorCount);
|
||||
TraceEvent(SevWarn, "AuthzPublicKeySetRefreshError"_audit).error(e).detail("ErrorCount", errorCount);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -191,7 +191,7 @@ bool TokenCache::validate(TenantId tenantId, StringRef token) {
|
|||
}
|
||||
|
||||
#define TRACE_INVALID_PARSED_TOKEN(reason, token) \
|
||||
TraceEvent(SevWarn, "InvalidToken") \
|
||||
TraceEvent(SevWarn, "InvalidToken"_audit) \
|
||||
.detail("From", peer) \
|
||||
.detail("Reason", reason) \
|
||||
.detail("CurrentTime", currentTime) \
|
||||
|
@ -288,7 +288,7 @@ bool TokenCacheImpl::validate(TenantId tenantId, StringRef token) {
|
|||
auto& entry = cachedEntry.get();
|
||||
if (entry->expirationTime < currentTime) {
|
||||
CODE_PROBE(true, "Found expired token in cache");
|
||||
TraceEvent(SevWarn, "InvalidToken").detail("From", peer).detail("Reason", "ExpiredInCache");
|
||||
TraceEvent(SevWarn, "InvalidToken"_audit).detail("From", peer).detail("Reason", "ExpiredInCache");
|
||||
return false;
|
||||
}
|
||||
bool tenantFound = false;
|
||||
|
@ -300,8 +300,9 @@ bool TokenCacheImpl::validate(TenantId tenantId, StringRef token) {
|
|||
}
|
||||
if (!tenantFound) {
|
||||
CODE_PROBE(true, "Valid token doesn't reference tenant");
|
||||
TraceEvent(SevWarn, "TenantTokenMismatch")
|
||||
TraceEvent(SevWarn, "InvalidToken"_audit)
|
||||
.detail("From", peer)
|
||||
.detail("Reason", "TenantTokenMismatch")
|
||||
.detail("RequestedTenant", fmt::format("{:#x}", tenantId))
|
||||
.detail("TenantsInToken", fmt::format("{:#x}", fmt::join(entry->tenants, " ")));
|
||||
return false;
|
||||
|
@ -323,7 +324,7 @@ void TokenCacheImpl::logTokenUsage(double currentTime, AuditEntry&& entry) {
|
|||
// access in the context of this (client_ip, tenant, token_id) tuple hasn't been logged in current window. log
|
||||
// usage.
|
||||
CODE_PROBE(true, "Audit Logging Running");
|
||||
TraceEvent("AuditTokenUsed")
|
||||
TraceEvent("AuditTokenUsed"_audit)
|
||||
.detail("Client", iter->address)
|
||||
.detail("TenantId", fmt::format("{:#x}", iter->tenantId))
|
||||
.detail("TokenId", iter->tokenId)
|
||||
|
|
|
@ -695,11 +695,11 @@ struct NetNotifiedQueue final : NotifiedQueue<T>, FlowReceiver, FastAllocated<Ne
|
|||
if constexpr (IsPublic) {
|
||||
if (!message.verify()) {
|
||||
if constexpr (HasReply<T>) {
|
||||
message.reply.sendError(permission_denied());
|
||||
TraceEvent(SevWarnAlways, "UnauthorizedAccessPrevented")
|
||||
TraceEvent(SevWarnAlways, "UnauthorizedAccessPrevented"_audit)
|
||||
.detail("RequestType", typeid(T).name())
|
||||
.detail("ClientIP", FlowTransport::transport().currentDeliveryPeerAddress())
|
||||
.log();
|
||||
message.reply.sendError(permission_denied());
|
||||
}
|
||||
} else {
|
||||
this->send(std::move(message));
|
||||
|
|
|
@ -2050,6 +2050,8 @@ int main(int argc, char* argv[]) {
|
|||
} else {
|
||||
TraceEvent(SevInfo, "AuthzPublicKeyFileNotSet");
|
||||
}
|
||||
if (FLOW_KNOBS->ALLOW_TOKENLESS_TENANT_ACCESS)
|
||||
TraceEvent(SevWarnAlways, "AuthzTokenlessAccessEnabled");
|
||||
|
||||
if (expectsPublicAddress) {
|
||||
for (int ii = 0; ii < (opts.publicAddresses.secondaryAddress.present() ? 2 : 1); ++ii) {
|
||||
|
|
|
@ -142,7 +142,7 @@ void FlowKnobs::initialize(Randomize randomize, IsSimulated isSimulated) {
|
|||
init( ALLOW_TOKENLESS_TENANT_ACCESS, false );
|
||||
init( AUDIT_LOGGING_ENABLED, true );
|
||||
init( PUBLIC_KEY_FILE_MAX_SIZE, 1024 * 1024 );
|
||||
init( PUBLIC_KEY_FILE_REFRESH_INTERVAL_SECONDS, 30 );
|
||||
init( PUBLIC_KEY_FILE_REFRESH_INTERVAL_SECONDS, 300 );
|
||||
init( AUDIT_TIME_WINDOW, 5.0 );
|
||||
init( TOKEN_CACHE_SIZE, 2000 );
|
||||
|
||||
|
|
|
@ -339,11 +339,12 @@ static udp::endpoint udpEndpoint(NetworkAddress const& n) {
|
|||
|
||||
class BindPromise {
|
||||
Promise<Void> p;
|
||||
const char* errContext;
|
||||
std::variant<const char*, AuditedEvent> errContext;
|
||||
UID errID;
|
||||
|
||||
public:
|
||||
BindPromise(const char* errContext, UID errID) : errContext(errContext), errID(errID) {}
|
||||
BindPromise(AuditedEvent auditedEvent, UID errID) : errContext(auditedEvent), errID(errID) {}
|
||||
BindPromise(BindPromise const& r) : p(r.p), errContext(r.errContext), errID(r.errID) {}
|
||||
BindPromise(BindPromise&& r) noexcept : p(std::move(r.p)), errContext(r.errContext), errID(r.errID) {}
|
||||
|
||||
|
@ -354,7 +355,12 @@ public:
|
|||
if (error) {
|
||||
// Log the error...
|
||||
{
|
||||
TraceEvent evt(SevWarn, errContext, errID);
|
||||
std::optional<TraceEvent> traceEvent;
|
||||
if (std::holds_alternative<AuditedEvent>(errContext))
|
||||
traceEvent.emplace(SevWarn, std::get<AuditedEvent>(errContext), errID);
|
||||
else
|
||||
traceEvent.emplace(SevWarn, std::get<const char*>(errContext), errID);
|
||||
TraceEvent& evt = *traceEvent;
|
||||
evt.suppressFor(1.0).detail("ErrorCode", error.value()).detail("Message", error.message());
|
||||
// There is no function in OpenSSL to use to check if an error code is from OpenSSL,
|
||||
// but all OpenSSL errors have a non-zero "library" code set in bits 24-32, and linux
|
||||
|
@ -800,8 +806,8 @@ struct SSLHandshakerThread final : IThreadPoolReceiver {
|
|||
}
|
||||
if (h.err.failed()) {
|
||||
TraceEvent(SevWarn,
|
||||
h.type == ssl_socket::handshake_type::client ? "N2_ConnectHandshakeError"
|
||||
: "N2_AcceptHandshakeError")
|
||||
h.type == ssl_socket::handshake_type::client ? "N2_ConnectHandshakeError"_audit
|
||||
: "N2_AcceptHandshakeError"_audit)
|
||||
.detail("ErrorCode", h.err.value())
|
||||
.detail("ErrorMsg", h.err.message().c_str())
|
||||
.detail("BackgroundThread", true);
|
||||
|
@ -811,8 +817,8 @@ struct SSLHandshakerThread final : IThreadPoolReceiver {
|
|||
}
|
||||
} catch (...) {
|
||||
TraceEvent(SevWarn,
|
||||
h.type == ssl_socket::handshake_type::client ? "N2_ConnectHandshakeUnknownError"
|
||||
: "N2_AcceptHandshakeUnknownError")
|
||||
h.type == ssl_socket::handshake_type::client ? "N2_ConnectHandshakeUnknownError"_audit
|
||||
: "N2_AcceptHandshakeUnknownError"_audit)
|
||||
.detail("BackgroundThread", true);
|
||||
h.done.sendError(connection_failed());
|
||||
}
|
||||
|
@ -903,7 +909,7 @@ public:
|
|||
N2::g_net2->sslHandshakerPool->post(handshake);
|
||||
} else {
|
||||
// Otherwise use flow network thread
|
||||
BindPromise p("N2_AcceptHandshakeError", UID());
|
||||
BindPromise p("N2_AcceptHandshakeError"_audit, UID());
|
||||
onHandshook = p.getFuture();
|
||||
self->ssl_sock.async_handshake(boost::asio::ssl::stream_base::server, std::move(p));
|
||||
}
|
||||
|
@ -985,7 +991,7 @@ public:
|
|||
N2::g_net2->sslHandshakerPool->post(handshake);
|
||||
} else {
|
||||
// Otherwise use flow network thread
|
||||
BindPromise p("N2_ConnectHandshakeError", self->id);
|
||||
BindPromise p("N2_ConnectHandshakeError"_audit, self->id);
|
||||
onHandshook = p.getFuture();
|
||||
self->ssl_sock.async_handshake(boost::asio::ssl::stream_base::client, std::move(p));
|
||||
}
|
||||
|
|
|
@ -25,12 +25,15 @@
|
|||
#include "flow/JsonTraceLogFormatter.h"
|
||||
#include "flow/flow.h"
|
||||
#include "flow/DeterministicRandom.h"
|
||||
#include "flow/UnitTest.h"
|
||||
#include <exception>
|
||||
#include <stdlib.h>
|
||||
#include <stdarg.h>
|
||||
#include <cctype>
|
||||
#include <time.h>
|
||||
#include <set>
|
||||
#include <unordered_set>
|
||||
#include <string_view>
|
||||
#include <iomanip>
|
||||
#include "flow/IThreadPool.h"
|
||||
#include "flow/ThreadHelper.actor.h"
|
||||
|
@ -101,6 +104,7 @@ SuppressionMap suppressedEvents;
|
|||
static TransientThresholdMetricSample<Standalone<StringRef>>* traceEventThrottlerCache;
|
||||
static const char* TRACE_EVENT_THROTTLE_STARTING_TYPE = "TraceEventThrottle_";
|
||||
static const char* TRACE_EVENT_INVALID_SUPPRESSION = "InvalidSuppression_";
|
||||
static const char* TRACE_EVENT_INVALID_AUDIT_LOG_TYPE = "InvalidAuditLogType_";
|
||||
static int TRACE_LOG_MAX_PREOPEN_BUFFER = 1000000;
|
||||
|
||||
struct TraceLog {
|
||||
|
@ -859,13 +863,15 @@ std::string getTraceFormatExtension() {
|
|||
return std::string(g_traceLog.formatter->getExtension());
|
||||
}
|
||||
|
||||
BaseTraceEvent::BaseTraceEvent() : initialized(true), enabled(false), logged(true) {}
|
||||
BaseTraceEvent::State::State(Severity severity) noexcept
|
||||
: value((g_network == nullptr || FLOW_KNOBS->MIN_TRACE_SEVERITY <= severity) ? Type::ENABLED : Type::DISABLED) {}
|
||||
|
||||
BaseTraceEvent::BaseTraceEvent() : enabled(), initialized(true), logged(true) {}
|
||||
BaseTraceEvent::BaseTraceEvent(Severity severity, const char* type, UID id)
|
||||
: initialized(false), enabled(g_network == nullptr || FLOW_KNOBS->MIN_TRACE_SEVERITY <= severity), logged(false),
|
||||
severity(severity), type(type), id(id) {}
|
||||
: enabled(severity), initialized(false), logged(false), severity(severity), type(type), id(id) {}
|
||||
|
||||
BaseTraceEvent::BaseTraceEvent(BaseTraceEvent&& ev) {
|
||||
enabled = ev.enabled;
|
||||
enabled = std::move(ev.enabled);
|
||||
err = ev.err;
|
||||
fields = std::move(ev.fields);
|
||||
id = ev.id;
|
||||
|
@ -886,13 +892,12 @@ BaseTraceEvent::BaseTraceEvent(BaseTraceEvent&& ev) {
|
|||
networkThread = ev.networkThread;
|
||||
|
||||
ev.initialized = true;
|
||||
ev.enabled = false;
|
||||
ev.logged = true;
|
||||
}
|
||||
|
||||
BaseTraceEvent& BaseTraceEvent::operator=(BaseTraceEvent&& ev) {
|
||||
// Note: still broken if ev and this are the same memory address.
|
||||
enabled = ev.enabled;
|
||||
enabled = std::move(ev.enabled);
|
||||
err = ev.err;
|
||||
fields = std::move(ev.fields);
|
||||
id = ev.id;
|
||||
|
@ -913,7 +918,6 @@ BaseTraceEvent& BaseTraceEvent::operator=(BaseTraceEvent&& ev) {
|
|||
networkThread = ev.networkThread;
|
||||
|
||||
ev.initialized = true;
|
||||
ev.enabled = false;
|
||||
ev.logged = true;
|
||||
|
||||
return *this;
|
||||
|
@ -949,8 +953,26 @@ TraceEvent::TraceEvent(Severity severity, TraceInterval& interval, UID id)
|
|||
init(interval);
|
||||
}
|
||||
|
||||
bool BaseTraceEvent::init(TraceInterval& interval) {
|
||||
bool result = init();
|
||||
// Constructs a security/audit trace event from a compile-time whitelisted AuditedEvent
// (produced by the ""_audit literal). When the AUDIT_LOGGING_ENABLED knob is on, a
// whitelisted event is promoted to the FORCED state so later suppression/throttling
// checks cannot drop it; a non-whitelisted event additionally emits an
// "InvalidAuditLogType_<type>" trace (SevError in simulation, SevWarnAlways otherwise).
TraceEvent::TraceEvent(Severity severity, AuditedEvent auditedEvent, UID id)
  : BaseTraceEvent(severity, auditedEvent.type(), id) {
	// NOTE(review): presumably 0 lifts the field/event length limits for audit events
	// so security details are never truncated — confirm against setMaxFieldLength/setMaxEventLength
	setMaxFieldLength(0);
	setMaxEventLength(0);
	if (FLOW_KNOBS->AUDIT_LOGGING_ENABLED) {
		if (!auditedEvent) {
			// Event is not whitelisted. Trace error in simulation and warning in real deployment
			TraceEvent(g_network && g_network->isSimulated() ? SevError : SevWarnAlways,
			           std::string(TRACE_EVENT_INVALID_AUDIT_LOG_TYPE).append(auditedEvent.typeSv()).c_str())
			    .suppressFor(5);
		} else {
			// Whitelisted: upgrade ENABLED -> FORCED so suppress() calls become no-ops
			enabled.promoteToForcedIfEnabled();
		}
	}
}

// Convenience overload: audited events default to SevInfo severity
TraceEvent::TraceEvent(AuditedEvent auditedEvent, UID id) : TraceEvent(SevInfo, auditedEvent, id) {}
|
||||
|
||||
void BaseTraceEvent::init(TraceInterval& interval) {
|
||||
init();
|
||||
switch (interval.count++) {
|
||||
case 0: {
|
||||
detail("BeginPair", interval.pairID);
|
||||
|
@ -963,10 +985,9 @@ bool BaseTraceEvent::init(TraceInterval& interval) {
|
|||
default:
|
||||
ASSERT(false);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool BaseTraceEvent::init() {
|
||||
BaseTraceEvent::State BaseTraceEvent::init() {
|
||||
ASSERT(!logged);
|
||||
if (initialized) {
|
||||
return enabled;
|
||||
|
@ -977,12 +998,16 @@ bool BaseTraceEvent::init() {
|
|||
|
||||
++g_allocation_tracing_disabled;
|
||||
|
||||
enabled = enabled && (!g_network || severity >= FLOW_KNOBS->MIN_TRACE_SEVERITY);
|
||||
if (g_network && severity < FLOW_KNOBS->MIN_TRACE_SEVERITY)
|
||||
enabled = BaseTraceEvent::State::disabled();
|
||||
|
||||
std::string_view typeSv(type);
|
||||
|
||||
// Backstop to throttle very spammy trace events
|
||||
if (enabled && g_network && !g_network->isSimulated() && severity > SevDebug && isNetworkThread()) {
|
||||
if (enabled.isSuppressible() && g_network && !g_network->isSimulated() && severity > SevDebug &&
|
||||
isNetworkThread()) {
|
||||
if (traceEventThrottlerCache->isAboveThreshold(StringRef((uint8_t*)type, strlen(type)))) {
|
||||
enabled = false;
|
||||
enabled.suppress();
|
||||
TraceEvent(SevWarnAlways, std::string(TRACE_EVENT_THROTTLE_STARTING_TYPE).append(type).c_str())
|
||||
.suppressFor(5);
|
||||
} else {
|
||||
|
@ -1054,7 +1079,8 @@ TraceEvent& TraceEvent::errorImpl(class Error const& error, bool includeCancelle
|
|||
std::string(TRACE_EVENT_INVALID_SUPPRESSION).append(type).c_str())
|
||||
.suppressFor(5);
|
||||
} else {
|
||||
enabled = false;
|
||||
// even force-enabled events should respect suppression by error type
|
||||
enabled = BaseTraceEvent::State::disabled();
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
|
@ -1078,7 +1104,7 @@ BaseTraceEvent& BaseTraceEvent::detailImpl(std::string&& key, std::string&& valu
|
|||
TraceEvent(g_network && g_network->isSimulated() ? SevError : SevWarnAlways, "TraceEventOverflow")
|
||||
.setMaxEventLength(1000)
|
||||
.detail("TraceFirstBytes", fields.toString().substr(0, 300));
|
||||
enabled = false;
|
||||
enabled = BaseTraceEvent::State::disabled();
|
||||
}
|
||||
--g_allocation_tracing_disabled;
|
||||
}
|
||||
|
@ -1149,7 +1175,8 @@ BaseTraceEvent& TraceEvent::sample(double sampleRate, bool logSampleRate) {
|
|||
return *this;
|
||||
}
|
||||
|
||||
enabled = enabled && deterministicRandom()->random01() < sampleRate;
|
||||
if (deterministicRandom()->random01() >= sampleRate)
|
||||
enabled.suppress();
|
||||
|
||||
if (enabled && logSampleRate) {
|
||||
detail("SampleRate", sampleRate);
|
||||
|
@ -1161,7 +1188,7 @@ BaseTraceEvent& TraceEvent::sample(double sampleRate, bool logSampleRate) {
|
|||
|
||||
BaseTraceEvent& TraceEvent::suppressFor(double duration, bool logSuppressedEventCount) {
|
||||
ASSERT(!logged);
|
||||
if (enabled) {
|
||||
if (enabled.isSuppressible()) {
|
||||
if (initialized) {
|
||||
TraceEvent(g_network && g_network->isSimulated() ? SevError : SevWarnAlways,
|
||||
std::string(TRACE_EVENT_INVALID_SUPPRESSION).append(type).c_str())
|
||||
|
@ -1172,7 +1199,8 @@ BaseTraceEvent& TraceEvent::suppressFor(double duration, bool logSuppressedEvent
|
|||
if (g_network) {
|
||||
if (isNetworkThread()) {
|
||||
int64_t suppressedEventCount = suppressedEvents.checkAndInsertSuppression(type, duration);
|
||||
enabled = enabled && suppressedEventCount >= 0;
|
||||
if (suppressedEventCount < 0)
|
||||
enabled.suppress();
|
||||
if (enabled && logSuppressedEventCount) {
|
||||
detail("SuppressedEventCount", suppressedEventCount);
|
||||
}
|
||||
|
@ -1290,7 +1318,6 @@ void BaseTraceEvent::log() {
|
|||
if (isNetworkThread()) {
|
||||
TraceEvent::eventCounts[severity / 10]++;
|
||||
}
|
||||
|
||||
g_traceLog.writeEvent(fields, trackingKey, severity > SevWarnAlways);
|
||||
|
||||
if (g_traceLog.isOpen()) {
|
||||
|
@ -1764,3 +1791,8 @@ std::string traceableStringToString(const char* value, size_t S) {
|
|||
|
||||
return std::string(value, S - 1); // Exclude trailing \0 byte
|
||||
}
|
||||
|
||||
// AuditedEvent unit test: make sure that whitelist-checks for AuditedEvent gets evaluated at compile time, and has a
// correct outcome
static_assert("InvalidToken"_audit, "Either AuditedEvent has a bug or whitelisting for this event type has changed");
// "nvalidToken" is deliberately one character off a whitelisted name: it must NOT validate
static_assert(!"nvalidToken"_audit, "AuditedEvent has a bug");
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#define FLOW_TRACE_H
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <stdarg.h>
|
||||
#include <stdint.h>
|
||||
|
@ -210,6 +211,50 @@ struct SpecialTraceMetricType
|
|||
|
||||
TRACE_METRIC_TYPE(double, double);
|
||||
|
||||
class AuditedEvent;

inline constexpr AuditedEvent operator""_audit(const char*, size_t) noexcept;

// Wrapper for the type name of a security-relevant TraceEvent that may bypass
// throttling or suppression. Instances can only be produced by the ""_audit
// literal, which makes the whitelist lookup a constexpr computation.
class AuditedEvent {
	// Whitelist of event types permitted to bypass throttling or suppression
	static constexpr std::string_view auditTopics[]{
		"AttemptedRPCToPrivatePrevented",
		"AuditTokenUsed",
		"AuthzPublicKeySetApply",
		"AuthzPublicKeySetRefreshError",
		"IncomingConnection",
		"InvalidToken",
		"N2_ConnectHandshakeError",
		"N2_ConnectHandshakeUnknownError",
		"N2_AcceptHandshakeError",
		"N2_AcceptHandshakeUnknownError",
		"UnauthorizedAccessPrevented",
	};
	const char* eventType;
	int len;
	bool valid;

	explicit constexpr AuditedEvent(const char* type, int typeLen) noexcept
	  : eventType(type), len(typeLen), valid(false) {
		// Linear whitelist scan; folds to a constant when evaluated at compile time
		const std::string_view candidate(type, typeLen);
		for (const auto& topic : auditTopics) {
			if (topic == candidate) {
				valid = true;
				break;
			}
		}
	}

	friend constexpr AuditedEvent operator""_audit(const char*, size_t) noexcept;

public:
	// Raw, NUL-terminated event type (guaranteed by literal-only construction)
	constexpr const char* type() const noexcept { return eventType; }

	// Length-aware view of the event type
	constexpr std::string_view typeSv() const noexcept { return std::string_view(eventType, len); }

	// True iff the event type is on the audit whitelist
	explicit constexpr operator bool() const noexcept { return valid; }
};

// This, along with the private AuditedEvent constructor, guarantees that AuditedEvent is always created with a string
// literal
inline constexpr AuditedEvent operator""_audit(const char* eventType, size_t len) noexcept {
	return AuditedEvent(eventType, len);
}
|
||||
|
||||
// The BaseTraceEvent class is the parent class of TraceEvent and provides all functionality on the TraceEvent except
|
||||
// for the functionality that can be used to suppress the trace event.
|
||||
//
|
||||
|
@ -258,6 +303,52 @@ struct BaseTraceEvent {
|
|||
BaseTraceEvent& detailf(std::string key, const char* valueFormat, ...);
|
||||
|
||||
protected:
|
||||
// Tri-state enablement tracker for a trace event, replacing the former 'bool enabled'.
// DISABLED events are not logged; ENABLED events are logged unless suppressed or
// throttled later; FORCED events (whitelisted audit events) ignore suppress() calls.
class State {
	enum class Type {
		DISABLED = 0, // event will not be logged
		ENABLED, // event will be logged unless suppressed later
		FORCED, // event will be logged; suppress() is a no-op
	};
	Type value;

public:
	// Default-constructed state is DISABLED
	constexpr State() noexcept : value(Type::DISABLED) {}
	// Defined out of line: ENABLED iff there is no network yet or severity passes
	// the MIN_TRACE_SEVERITY knob
	State(Severity severity) noexcept;
	// Audit events that pass the severity check are promoted to FORCED immediately
	State(Severity severity, AuditedEvent) noexcept : State(severity) {
		if (*this)
			value = Type::FORCED;
	}

	State(const State& other) noexcept = default;
	// Moved-from state resets to DISABLED (mirrors the old 'ev.enabled = false' semantics)
	State(State&& other) noexcept : value(other.value) { other.value = Type::DISABLED; }
	State& operator=(const State& other) noexcept = default;
	State& operator=(State&& other) noexcept {
		if (this != &other) {
			value = other.value;
			other.value = Type::DISABLED;
		}
		return *this;
	}
	bool operator==(const State& other) const noexcept = default;
	bool operator!=(const State& other) const noexcept = default;

	// True when the event would currently be logged (ENABLED or FORCED)
	explicit operator bool() const noexcept { return value == Type::ENABLED || value == Type::FORCED; }

	// Disable the event unless it has been promoted to FORCED
	void suppress() noexcept {
		if (value == Type::ENABLED)
			value = Type::DISABLED;
	}

	// Only plain ENABLED events (not FORCED ones) respond to throttling/suppression
	bool isSuppressible() const noexcept { return value == Type::ENABLED; }

	// Upgrade ENABLED -> FORCED; used for whitelisted audit events
	void promoteToForcedIfEnabled() noexcept {
		if (value == Type::ENABLED)
			value = Type::FORCED;
	}

	// Canonical disabled state
	static constexpr State disabled() noexcept { return State(); }
};
|
||||
|
||||
BaseTraceEvent();
|
||||
BaseTraceEvent(Severity, const char* type, UID id = UID());
|
||||
|
||||
|
@ -303,15 +394,15 @@ public:
|
|||
|
||||
BaseTraceEvent& GetLastError();
|
||||
|
||||
bool isEnabled() const { return enabled; }
|
||||
bool isEnabled() const { return static_cast<bool>(enabled); }
|
||||
|
||||
BaseTraceEvent& setErrorKind(ErrorKind errorKind);
|
||||
|
||||
explicit operator bool() const { return enabled; }
|
||||
explicit operator bool() const { return static_cast<bool>(enabled); }
|
||||
|
||||
void log();
|
||||
|
||||
void disable() { enabled = false; } // Disables the trace event so it doesn't get logged
|
||||
void disable() { enabled.suppress(); } // Disables the trace event so it doesn't get logged
|
||||
|
||||
virtual ~BaseTraceEvent(); // Actually logs the event
|
||||
|
||||
|
@ -323,8 +414,8 @@ public:
|
|||
const TraceEventFields& getFields() const { return fields; }
|
||||
|
||||
protected:
|
||||
State enabled;
|
||||
bool initialized;
|
||||
bool enabled;
|
||||
bool logged;
|
||||
std::string trackingKey;
|
||||
TraceEventFields fields;
|
||||
|
@ -344,8 +435,8 @@ protected:
|
|||
static unsigned long eventCounts[NUM_MAJOR_LEVELS_OF_EVENTS];
|
||||
static thread_local bool networkThread;
|
||||
|
||||
bool init();
|
||||
bool init(struct TraceInterval&);
|
||||
State init();
|
||||
void init(struct TraceInterval&);
|
||||
};
|
||||
|
||||
// The TraceEvent class provides the implementation for BaseTraceEvent. The only functions that should be implemented
|
||||
|
@ -356,6 +447,8 @@ struct TraceEvent : public BaseTraceEvent {
|
|||
TraceEvent(Severity, const char* type, UID id = UID());
|
||||
TraceEvent(struct TraceInterval&, UID id = UID());
|
||||
TraceEvent(Severity severity, struct TraceInterval& interval, UID id = UID());
|
||||
TraceEvent(AuditedEvent, UID id = UID());
|
||||
TraceEvent(Severity, AuditedEvent, UID id = UID());
|
||||
|
||||
BaseTraceEvent& error(const class Error& e) {
|
||||
if (enabled) {
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import glob
|
||||
import json
|
||||
from pathlib import Path
|
||||
import random
|
||||
|
@ -8,6 +9,7 @@ import socket
|
|||
import time
|
||||
import fcntl
|
||||
import sys
|
||||
import xml.etree.ElementTree as ET
|
||||
import tempfile
|
||||
from authz_util import private_key_gen, public_keyset_from_keys
|
||||
from test_util import random_alphanum_string
|
||||
|
@ -177,6 +179,7 @@ logdir = {logdir}
|
|||
self.custom_config = custom_config
|
||||
self.blob_granules_enabled = blob_granules_enabled
|
||||
self.enable_encryption_at_rest = enable_encryption_at_rest
|
||||
self.trace_check_entries = []
|
||||
if blob_granules_enabled:
|
||||
# add extra process for blob_worker
|
||||
self.process_number += 1
|
||||
|
@ -375,8 +378,12 @@ logdir = {logdir}
|
|||
return self
|
||||
|
||||
def __exit__(self, xc_type, exc_value, traceback):
    """Context-manager exit: stop the cluster, free its ports, then run trace checks."""
    if self.trace_check_entries:
        # sleep a while before checking trace to make sure everything has flushed out
        time.sleep(3)
    self.stop_cluster()
    self.release_ports()
    # Runs every check registered via add_trace_check*(); no-op when none were registered
    self.check_trace()
|
||||
|
||||
def release_ports(self):
|
||||
self.port_provider.release_locks()
|
||||
|
@ -694,3 +701,45 @@ logdir = {logdir}
|
|||
else:
|
||||
print("No errors found in logs")
|
||||
return err_cnt == 0
|
||||
|
||||
# Register a trace-check callback to be invoked once the cluster terminates.
# The _from() and _from_to() variants pre-filter events by a time window given in epoch seconds.
# Consider using ScopedTraceChecker to simplify timestamp management.
# Caveat: the checker assumes the traces to be in XML and to have .xml file extensions,
# which prevents fdbmonitor.log from being considered and parsed.
def add_trace_check(self, check_func, filename_substr: str = ""):
    """Register check_func over all events, optionally restricted to files containing filename_substr."""
    entry = (check_func, None, None, filename_substr)
    self.trace_check_entries.append(entry)
|
||||
|
||||
def add_trace_check_from(self, check_func, time_begin, filename_substr: str = ""):
    """Register check_func for events at or after time_begin (epoch seconds)."""
    entry = (check_func, time_begin, None, filename_substr)
    self.trace_check_entries.append(entry)
|
||||
|
||||
def add_trace_check_from_to(self, check_func, time_begin, time_end, filename_substr: str = ""):
    """Register check_func for events within [time_begin, time_end] (epoch seconds)."""
    entry = (check_func, time_begin, time_end, filename_substr)
    self.trace_check_entries.append(entry)
|
||||
|
||||
# generator function that yields (filename, event_type, XML_trace_entry) that matches the parameter
|
||||
def __loop_through_trace(self, time_begin, time_end, filename_substr: str):
|
||||
glob_pattern = str(self.log.joinpath("*.xml"))
|
||||
for file in glob.glob(glob_pattern):
|
||||
if filename_substr and file.find(filename_substr) == -1:
|
||||
continue
|
||||
print(f"### considering file {file}")
|
||||
for line in open(file):
|
||||
try:
|
||||
entry = ET.fromstring(line)
|
||||
# Below fields always exist. If not, their access throws to be skipped over
|
||||
ev_type = entry.attrib["Type"]
|
||||
ts = float(entry.attrib["Time"])
|
||||
if time_begin != None and ts < time_begin:
|
||||
continue
|
||||
if time_end != None and time_end < ts:
|
||||
break # no need to look further in this file
|
||||
yield (file, ev_type, entry)
|
||||
except ET.ParseError as e:
|
||||
pass # ignore header, footer, or broken line
|
||||
|
||||
# applies user-provided check_func that takes a trace entry generator as the parameter
def check_trace(self):
    """Invoke every registered trace checker on its filtered stream of trace entries."""
    for checker, begin, end, substr in self.trace_check_entries:
        checker(self.__loop_through_trace(begin, end, substr))
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,23 @@
|
|||
import random
|
||||
import string
|
||||
import time
|
||||
|
||||
# Pool of characters for random alphanumeric strings: A-Z, a-z, 0-9
alphanum_letters = string.ascii_letters + string.digits


def random_alphanum_string(length):
    """Return a random string of `length` characters drawn from alphanum_letters."""
    picks = (random.choice(alphanum_letters) for _ in range(length))
    return "".join(picks)
|
||||
|
||||
# attach a post-run trace checker to cluster that runs for events between the time of scope entry and exit
class ScopedTraceChecker:
    """Context manager that, on scope exit, registers checker_func with the cluster
    restricted to trace events emitted while the scope was active."""

    def __init__(self, cluster, checker_func, filename_substr: str = ""):
        self.cluster = cluster
        self.checker_func = checker_func
        self.filename_substr = filename_substr
        self.begin = None  # populated on __enter__

    def __enter__(self):
        self.begin = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Register for the [begin, now] window; runs even when the body raised
        end = time.time()
        self.cluster.add_trace_check_from_to(self.checker_func, self.begin, end, self.filename_substr)
|
||||
|
|
|
@ -70,7 +70,6 @@ class TempCluster(LocalCluster):
|
|||
if self.remove_at_exit:
|
||||
shutil.rmtree(self.tmp_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
script_desc = """
|
||||
This script automatically configures a temporary local cluster on the machine
|
||||
|
|
|
@ -23,6 +23,7 @@ import argparse
|
|||
import authlib
|
||||
import base64
|
||||
import fdb
|
||||
import functools
|
||||
import os
|
||||
import pytest
|
||||
import random
|
||||
|
@ -33,16 +34,20 @@ from multiprocessing import Process, Pipe
|
|||
from typing import Union
|
||||
from authz_util import token_gen, private_key_gen, public_keyset_from_keys, alg_from_kty
|
||||
from util import random_alphanum_str, random_alphanum_bytes, to_str, to_bytes, KeyFileReverter, wait_until_tenant_tr_succeeds, wait_until_tenant_tr_fails
|
||||
from test_util import ScopedTraceChecker
|
||||
from local_cluster import TLSConfig
|
||||
from tmp_cluster import TempCluster
|
||||
|
||||
special_key_ranges = [
|
||||
("transaction description", b"/description", b"/description\x00"),
|
||||
("global knobs", b"/globalKnobs", b"/globalKnobs\x00"),
|
||||
("knobs", b"/knobs0", b"/knobs0\x00"),
|
||||
("conflicting keys", b"/transaction/conflicting_keys/", b"/transaction/conflicting_keys/\xff\xff"),
|
||||
("read conflict range", b"/transaction/read_conflict_range/", b"/transaction/read_conflict_range/\xff\xff"),
|
||||
("conflicting keys", b"/transaction/write_conflict_range/", b"/transaction/write_conflict_range/\xff\xff"),
|
||||
("data distribution stats", b"/metrics/data_distribution_stats/", b"/metrics/data_distribution_stats/\xff\xff"),
|
||||
("kill storage", b"/globals/killStorage", b"/globals/killStorage\x00"),
|
||||
# (description, range_begin, range_end, readable, writable)
|
||||
("transaction description", b"\xff\xff/description", b"\xff\xff/description\x00", True, False),
|
||||
("global knobs", b"\xff\xff/globalKnobs", b"\xff\xff/globalKnobs\x00", True, False),
|
||||
("knobs", b"\xff\xff/knobs/", b"\xff\xff/knobs0\x00", True, False),
|
||||
("conflicting keys", b"\xff\xff/transaction/conflicting_keys/", b"\xff\xff/transaction/conflicting_keys/\xff\xff", True, False),
|
||||
("read conflict range", b"\xff\xff/transaction/read_conflict_range/", b"\xff\xff/transaction/read_conflict_range/\xff\xff", True, False),
|
||||
("conflicting keys", b"\xff\xff/transaction/write_conflict_range/", b"\xff\xff/transaction/write_conflict_range/\xff\xff", True, False),
|
||||
("data distribution stats", b"\xff\xff/metrics/data_distribution_stats/", b"\xff\xff/metrics/data_distribution_stats/\xff\xff", False, False),
|
||||
("kill storage", b"\xff\xff/globals/killStorage", b"\xff\xff/globals/killStorage\x00", True, False),
|
||||
]
|
||||
|
||||
# handler for when looping is assumed with usage
|
||||
|
@ -91,21 +96,43 @@ def test_token_option(cluster, default_tenant, tenant_tr_gen, token_claim_1h):
|
|||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied, got {e} instead"
|
||||
|
||||
def test_simple_tenant_access(cluster, default_tenant, tenant_tr_gen, token_claim_1h):
|
||||
token = token_gen(cluster.private_key, token_claim_1h(default_tenant))
|
||||
tr = tenant_tr_gen(default_tenant)
|
||||
tr.options.set_authorization_token(token)
|
||||
def commit_some_value(tr):
|
||||
tr[b"abc"] = b"def"
|
||||
tr.commit().wait()
|
||||
def test_simple_tenant_access(cluster, default_tenant, tenant_tr_gen, token_claim_1h, tenant_id_from_name):
|
||||
def check_token_usage_trace(trace_entries, token_claim, token_signature_part):
|
||||
found = False
|
||||
for filename, ev_type, entry in trace_entries:
|
||||
if ev_type == "AuditTokenUsed":
|
||||
jti_actual = entry.attrib["TokenId"]
|
||||
jti_expect = token_claim["jti"]
|
||||
tenantid_actual = entry.attrib["TenantId"]
|
||||
tenantid_expect_bytes = base64.b64decode(token_claim["tenants"][0])
|
||||
tenantid_expect = hex(int.from_bytes(tenantid_expect_bytes, "big"))
|
||||
if jti_actual == jti_expect and tenantid_actual == tenantid_expect:
|
||||
found = True
|
||||
else:
|
||||
print(f"found unknown tenant in token usage audit log; tokenid={jti_actual} vs. {jti_expect}, tenantid={tenantid_actual} vs. {tenantid_expect}")
|
||||
for k, v in entry.items():
|
||||
if k.find(token_signature_part) != -1 or v.find(token_signature_part) != -1:
|
||||
pytest.fail(f"token usage trace includes sensitive token signature: key={k} value={v}")
|
||||
if not found:
|
||||
pytest.fail("failed to find any AuditTokenUsed entry matching token from the testcase")
|
||||
|
||||
loop_until_success(tr, commit_some_value)
|
||||
tr = tenant_tr_gen(default_tenant)
|
||||
tr.options.set_authorization_token(token)
|
||||
def read_back_value(tr):
|
||||
return tr[b"abc"].value
|
||||
value = loop_until_success(tr, read_back_value)
|
||||
assert value == b"def", "tenant write transaction not visible"
|
||||
token_claim = token_claim_1h(default_tenant)
|
||||
token = token_gen(cluster.private_key, token_claim)
|
||||
token_sig_part = to_str(token[token.rfind(b".") + 1:])
|
||||
with ScopedTraceChecker(cluster, functools.partial(check_token_usage_trace, token_claim=token_claim, token_signature_part=token_sig_part)):
|
||||
tr = tenant_tr_gen(default_tenant)
|
||||
tr.options.set_authorization_token(token)
|
||||
def commit_some_value(tr):
|
||||
tr[b"abc"] = b"def"
|
||||
tr.commit().wait()
|
||||
|
||||
loop_until_success(tr, commit_some_value)
|
||||
tr = tenant_tr_gen(default_tenant)
|
||||
tr.options.set_authorization_token(token)
|
||||
def read_back_value(tr):
|
||||
return tr[b"abc"].value
|
||||
value = loop_until_success(tr, read_back_value)
|
||||
assert value == b"def", "tenant write transaction not visible"
|
||||
|
||||
def test_cross_tenant_access_disallowed(cluster, default_tenant, tenant_gen, tenant_tr_gen, token_claim_1h):
|
||||
# use default tenant token with second tenant transaction and see it fail
|
||||
|
@ -191,15 +218,18 @@ def test_system_and_special_key_range_disallowed(db, tenant_tr_gen):
|
|||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied, got {e} instead"
|
||||
|
||||
for range_name, special_range_begin, special_range_end in special_key_ranges:
|
||||
for range_name, special_range_begin, special_range_end, readable, _ in special_key_ranges:
|
||||
tr = db.create_transaction()
|
||||
tr.options.set_access_system_keys()
|
||||
tr.options.set_special_key_space_relaxed()
|
||||
try:
|
||||
kvs = tr.get_range(special_range_begin, special_range_end, limit=1).to_list()
|
||||
assert False, f"disallowed special keyspace read for range {range_name} has succeeded. found item {kvs}"
|
||||
if readable:
|
||||
pass
|
||||
else:
|
||||
pytest.fail(f"disallowed special keyspace read for range '{range_name}' has succeeded. found item {kvs}")
|
||||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied from attempted read to range {range_name}, got {e} instead"
|
||||
assert e.code in [6000, 6001], f"expected authz error from attempted read to range '{range_name}', got {e} instead"
|
||||
|
||||
try:
|
||||
tr = db.create_transaction()
|
||||
|
@ -210,16 +240,21 @@ def test_system_and_special_key_range_disallowed(db, tenant_tr_gen):
|
|||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied, got {e} instead"
|
||||
|
||||
for range_name, special_range_begin, special_range_end in special_key_ranges:
|
||||
for range_name, special_range_begin, special_range_end, _, writable in special_key_ranges:
|
||||
tr = db.create_transaction()
|
||||
tr.options.set_access_system_keys()
|
||||
tr.options.set_special_key_space_relaxed()
|
||||
tr.options.set_special_key_space_enable_writes()
|
||||
try:
|
||||
del tr[special_range_begin:special_range_end]
|
||||
tr.commit().wait()
|
||||
assert False, f"write to disallowed special keyspace range {range_name} has succeeded"
|
||||
if writable:
|
||||
pass
|
||||
else:
|
||||
pytest.fail(f"write to disallowed special keyspace range '{range_name}' has succeeded")
|
||||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied from attempted write to range {range_name}, got {e} instead"
|
||||
error_range = [6000, 6001, 2115] if not writable else []
|
||||
assert e.code in error_range, f"expected errors {error_range} from attempted write to range '{range_name}', got {e} instead"
|
||||
|
||||
try:
|
||||
tr = db.create_transaction()
|
||||
|
@ -333,57 +368,108 @@ def test_bad_token(cluster, default_tenant, tenant_tr_gen, token_claim_1h):
|
|||
return d
|
||||
|
||||
claim_mutations = [
|
||||
("no nbf", lambda claim: del_attr(claim, "nbf")),
|
||||
("no exp", lambda claim: del_attr(claim, "exp")),
|
||||
("no iat", lambda claim: del_attr(claim, "iat")),
|
||||
("too early", lambda claim: set_attr(claim, "nbf", time.time() + 30)),
|
||||
("too late", lambda claim: set_attr(claim, "exp", time.time() - 10)),
|
||||
("no tenants", lambda claim: del_attr(claim, "tenants")),
|
||||
("empty tenants", lambda claim: set_attr(claim, "tenants", [])),
|
||||
("no nbf", lambda claim: del_attr(claim, "nbf"), "NoNotBefore"),
|
||||
("no exp", lambda claim: del_attr(claim, "exp"), "NoExpirationTime"),
|
||||
("no iat", lambda claim: del_attr(claim, "iat"), "NoIssuedAt"),
|
||||
("too early", lambda claim: set_attr(claim, "nbf", time.time() + 30), "TokenNotYetValid"),
|
||||
("too late", lambda claim: set_attr(claim, "exp", time.time() - 10), "Expired"),
|
||||
("no tenants", lambda claim: del_attr(claim, "tenants"), "NoTenants"),
|
||||
("empty tenants", lambda claim: set_attr(claim, "tenants", []), "TenantTokenMismatch"),
|
||||
]
|
||||
for case_name, mutation in claim_mutations:
|
||||
|
||||
def check_invalid_token_trace(trace_entries, expected_reason, case_name):
|
||||
invalid_token_found = False
|
||||
unauthorized_access_found = False
|
||||
for filename, ev_type, entry in trace_entries:
|
||||
if ev_type == "InvalidToken":
|
||||
actual_reason = entry.attrib["Reason"]
|
||||
if actual_reason == expected_reason:
|
||||
invalid_token_found = True
|
||||
else:
|
||||
print("InvalidToken reason mismatch: expected '{}' got '{}'".format(expected_reason, actual_reason))
|
||||
print("trace entry: {}".format(entry.items()))
|
||||
elif ev_type == "UnauthorizedAccessPrevented":
|
||||
unauthorized_access_found = True
|
||||
if not invalid_token_found:
|
||||
pytest.fail("Failed to find invalid token reason '{}' in trace for case '{}'".format(expected_reason, case_name))
|
||||
if not unauthorized_access_found:
|
||||
pytest.fail("Failed to find 'UnauthorizedAccessPrevented' event in trace for case '{}'".format(case_name))
|
||||
|
||||
for case_name, mutation, expected_failure_reason in claim_mutations:
|
||||
with ScopedTraceChecker(cluster, functools.partial(check_invalid_token_trace, expected_reason=expected_failure_reason, case_name=case_name)) as checker:
|
||||
tr = tenant_tr_gen(default_tenant)
|
||||
tr.options.set_authorization_token(token_gen(cluster.private_key, mutation(token_claim_1h(default_tenant))))
|
||||
print(f"Trace check begin for '{case_name}': {checker.begin}")
|
||||
try:
|
||||
value = tr[b"abc"].value
|
||||
assert False, f"expected permission_denied for case '{case_name}', but read transaction went through"
|
||||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied for case '{case_name}', got {e} instead"
|
||||
tr = tenant_tr_gen(default_tenant)
|
||||
tr.options.set_authorization_token(token_gen(cluster.private_key, mutation(token_claim_1h(default_tenant))))
|
||||
tr[b"abc"] = b"def"
|
||||
try:
|
||||
tr.commit().wait()
|
||||
assert False, f"expected permission_denied for case '{case_name}', but write transaction went through"
|
||||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied for case '{case_name}', got {e} instead"
|
||||
print(f"Trace check end for '{case_name}': {time.time()}")
|
||||
|
||||
with ScopedTraceChecker(cluster, functools.partial(check_invalid_token_trace, expected_reason="UnknownKey", case_name="untrusted key")):
|
||||
# unknown key case: override "kid" field in header
|
||||
# first, update only the kid field of key with export-update-import
|
||||
key_dict = cluster.private_key.as_dict(is_private=True)
|
||||
key_dict["kid"] = random_alphanum_str(10)
|
||||
renamed_key = authlib.jose.JsonWebKey.import_key(key_dict)
|
||||
unknown_key_token = token_gen(
|
||||
renamed_key,
|
||||
token_claim_1h(default_tenant),
|
||||
headers={
|
||||
"typ": "JWT",
|
||||
"kty": renamed_key.kty,
|
||||
"alg": alg_from_kty(renamed_key.kty),
|
||||
"kid": renamed_key.kid,
|
||||
})
|
||||
tr = tenant_tr_gen(default_tenant)
|
||||
tr.options.set_authorization_token(token_gen(cluster.private_key, mutation(token_claim_1h(default_tenant))))
|
||||
tr.options.set_authorization_token(unknown_key_token)
|
||||
try:
|
||||
value = tr[b"abc"].value
|
||||
assert False, f"expected permission_denied for case {case_name}, but read transaction went through"
|
||||
assert False, f"expected permission_denied for 'unknown key' case, but read transaction went through"
|
||||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied for case {case_name}, got {e} instead"
|
||||
assert e.code == 6000, f"expected permission_denied for 'unknown key' case, got {e} instead"
|
||||
tr = tenant_tr_gen(default_tenant)
|
||||
tr.options.set_authorization_token(token_gen(cluster.private_key, mutation(token_claim_1h(default_tenant))))
|
||||
tr.options.set_authorization_token(unknown_key_token)
|
||||
tr[b"abc"] = b"def"
|
||||
try:
|
||||
tr.commit().wait()
|
||||
assert False, f"expected permission_denied for case {case_name}, but write transaction went through"
|
||||
assert False, f"expected permission_denied for 'unknown key' case, but write transaction went through"
|
||||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied for case {case_name}, got {e} instead"
|
||||
assert e.code == 6000, f"expected permission_denied for 'unknown key' case, got {e} instead"
|
||||
|
||||
# unknown key case: override "kid" field in header
|
||||
# first, update only the kid field of key with export-update-import
|
||||
key_dict = cluster.private_key.as_dict(is_private=True)
|
||||
key_dict["kid"] = random_alphanum_str(10)
|
||||
renamed_key = authlib.jose.JsonWebKey.import_key(key_dict)
|
||||
unknown_key_token = token_gen(
|
||||
renamed_key,
|
||||
token_claim_1h(default_tenant),
|
||||
headers={
|
||||
"typ": "JWT",
|
||||
"kty": renamed_key.kty,
|
||||
"alg": alg_from_kty(renamed_key.kty),
|
||||
"kid": renamed_key.kid,
|
||||
})
|
||||
tr = tenant_tr_gen(default_tenant)
|
||||
tr.options.set_authorization_token(unknown_key_token)
|
||||
try:
|
||||
value = tr[b"abc"].value
|
||||
assert False, f"expected permission_denied for 'unknown key' case, but read transaction went through"
|
||||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied for 'unknown key' case, got {e} instead"
|
||||
tr = tenant_tr_gen(default_tenant)
|
||||
tr.options.set_authorization_token(unknown_key_token)
|
||||
tr[b"abc"] = b"def"
|
||||
try:
|
||||
tr.commit().wait()
|
||||
assert False, f"expected permission_denied for 'unknown key' case, but write transaction went through"
|
||||
except fdb.FDBError as e:
|
||||
assert e.code == 6000, f"expected permission_denied for 'unknown key' case, got {e} instead"
|
||||
def test_authz_not_enabled_trace(build_dir):
|
||||
# spin up a cluster without authz and see it logs as expected
|
||||
def check_authz_disablement_traces(trace_entries):
|
||||
keyfile_unset_ev = "AuthzPublicKeyFileNotSet"
|
||||
tokenless_mode_ev = "AuthzTokenlessAccessEnabled"
|
||||
keyfile_unset_found = False
|
||||
tokenless_mode_found = False
|
||||
for _, ev_type, _ in trace_entries:
|
||||
if ev_type == keyfile_unset_ev:
|
||||
keyfile_unset_found = True
|
||||
elif ev_type == tokenless_mode_ev:
|
||||
tokenless_mode_found = True
|
||||
if not keyfile_unset_found:
|
||||
pytest.fail(f"failed to locate keyfile unset trace '{keyfile_unset_ev}'")
|
||||
if not tokenless_mode_found:
|
||||
pytest.fail(f"failed to locate tokenless mode trace '{keyfile_unset_ev}'")
|
||||
|
||||
with TempCluster(
|
||||
build_dir=build_dir,
|
||||
tls_config = TLSConfig(server_chain_len=3, client_chain_len=2),
|
||||
authorization_kty = "", # this ensures that no public key files are generated and produces AuthzPublicKeyFileNotSet
|
||||
remove_at_exit=True,
|
||||
custom_config={
|
||||
"knob-allow-tokenless-tenant-access": "true",
|
||||
}) as cluster:
|
||||
cluster.add_trace_check(check_authz_disablement_traces)
|
||||
# safe to drop cluster immediately. TempCluster.__enter__ returns only after fdbcli "create database" succeeds.
|
||||
|
|
|
@ -19,12 +19,14 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
import fdb
|
||||
import functools
|
||||
import pytest
|
||||
import subprocess
|
||||
import admin_server
|
||||
import base64
|
||||
import glob
|
||||
import time
|
||||
import ipaddress
|
||||
from local_cluster import TLSConfig
|
||||
from tmp_cluster import TempCluster
|
||||
from typing import Union
|
||||
|
@ -102,6 +104,7 @@ def admin_ipc():
|
|||
|
||||
@pytest.fixture(autouse=True, scope=cluster_scope)
|
||||
def cluster(admin_ipc, build_dir, public_key_refresh_interval, trusted_client, force_multi_version_client, use_grv_cache):
|
||||
cluster_creation_time = time.time()
|
||||
with TempCluster(
|
||||
build_dir=build_dir,
|
||||
tls_config=TLSConfig(server_chain_len=3, client_chain_len=2),
|
||||
|
@ -125,32 +128,67 @@ def cluster(admin_ipc, build_dir, public_key_refresh_interval, trusted_client, f
|
|||
admin_ipc.request("configure_client", [force_multi_version_client, use_grv_cache, logdir])
|
||||
admin_ipc.request("configure_tls", [keyfile, certfile, cafile])
|
||||
admin_ipc.request("connect", [str(cluster.cluster_file)])
|
||||
|
||||
def check_no_invalid_traces(entries):
|
||||
for filename, ev_type, entry in entries:
|
||||
if ev_type.startswith("InvalidAuditLogType_"):
|
||||
pytest.fail("Invalid audit log detected in file {}: {}".format(filename, entry.items()))
|
||||
|
||||
cluster.add_trace_check(check_no_invalid_traces)
|
||||
|
||||
def check_public_keyset_apply(entries, cluster_creation_time):
|
||||
keyset_apply_ev_type = "AuthzPublicKeySetApply"
|
||||
bad_ev_type = "AuthzPublicKeyFileNotSet"
|
||||
apply_trace_time = None
|
||||
bad_trace_time = None
|
||||
for filename, ev_type, entry in entries:
|
||||
if apply_trace_time is None and ev_type == keyset_apply_ev_type and int(entry.attrib["NumPublicKeys"]) > 0:
|
||||
apply_trace_time = float(entry.attrib["Time"])
|
||||
if bad_trace_time is None and ev_type == bad_ev_type:
|
||||
bad_trace_found = float(entry.attrib["Time"])
|
||||
if apply_trace_time is None:
|
||||
pytest.fail(f"failed to find '{keyset_apply_ev_type}' event with >0 public keys")
|
||||
else:
|
||||
print(f"'{keyset_apply_ev_type}' found at {apply_trace_time - cluster_creation_time}s since cluster creation")
|
||||
if bad_trace_time is not None:
|
||||
pytest.fail(f"unexpected '{bad_ev_type}' trace found at {bad_trace_time}")
|
||||
|
||||
cluster.add_trace_check(functools.partial(check_public_keyset_apply, cluster_creation_time=cluster_creation_time))
|
||||
|
||||
def check_connection_traces(entries, look_for_untrusted):
|
||||
trusted_conns_traced = False # admin connections
|
||||
untrusted_conns_traced = False
|
||||
ev_target = "IncomingConnection"
|
||||
for _, ev_type, entry in entries:
|
||||
if ev_type == ev_target:
|
||||
trusted = entry.attrib["Trusted"]
|
||||
from_addr = entry.attrib["FromAddr"]
|
||||
client_ip, port, tls_suffix = from_addr.split(":")
|
||||
if tls_suffix != "tls":
|
||||
pytest.fail(f"{ev_target} trace entry's FromAddr does not have a valid ':tls' suffix: found '{tls_suffix}'")
|
||||
try:
|
||||
ip = ipaddress.ip_address(client_ip)
|
||||
except ValueError as e:
|
||||
pytest.fail(f"{ev_target} trace entry's FromAddr '{client_ip}' has an invalid IP format: {e}")
|
||||
|
||||
if trusted == "1":
|
||||
trusted_conns_traced = True
|
||||
elif trusted == "0":
|
||||
untrusted_conns_traced = True
|
||||
else:
|
||||
pytest.fail(f"{ev_target} trace entry's Trusted field has an unexpected value: {trusted}")
|
||||
if look_for_untrusted and not untrusted_conns_traced:
|
||||
pytest.fail("failed to find any 'IncomingConnection' traces for untrusted clients")
|
||||
if not trusted_conns_traced:
|
||||
pytest.fail("failed to find any 'IncomingConnection' traces for trusted clients")
|
||||
|
||||
cluster.add_trace_check(functools.partial(check_connection_traces, look_for_untrusted=not trusted_client))
|
||||
|
||||
yield cluster
|
||||
err_count = {}
|
||||
for file in glob.glob(str(cluster.log.joinpath("*.xml"))):
|
||||
lineno = 1
|
||||
for line in open(file):
|
||||
try:
|
||||
doc = ET.fromstring(line)
|
||||
except:
|
||||
continue
|
||||
if doc.attrib.get("Severity", "") == "40":
|
||||
ev_type = doc.attrib.get("Type", "[unset]")
|
||||
err = doc.attrib.get("Error", "[unset]")
|
||||
tup = (file, ev_type, err)
|
||||
err_count[tup] = err_count.get(tup, 0) + 1
|
||||
lineno += 1
|
||||
print("Sev40 Summary:")
|
||||
if len(err_count) == 0:
|
||||
print(" No errors")
|
||||
else:
|
||||
for tup, count in err_count.items():
|
||||
print(" {}: {}".format(tup, count))
|
||||
|
||||
@pytest.fixture
|
||||
def db(cluster, admin_ipc):
|
||||
db = fdb.open(str(cluster.cluster_file))
|
||||
#db.options.set_transaction_timeout(2000) # 2 seconds
|
||||
db.options.set_transaction_retry_limit(10)
|
||||
yield db
|
||||
admin_ipc.request("cleanup_database")
|
||||
|
@ -188,11 +226,17 @@ def tenant_tr_gen(db, use_grv_cache):
|
|||
return fn
|
||||
|
||||
@pytest.fixture
|
||||
def token_claim_1h(db):
|
||||
def tenant_id_from_name(db):
|
||||
def fn(tenant_name):
|
||||
tenant = db.open_tenant(to_bytes(tenant_name))
|
||||
return tenant.get_id().wait() # returns int
|
||||
return fn
|
||||
|
||||
@pytest.fixture
|
||||
def token_claim_1h(tenant_id_from_name):
|
||||
# JWT claim that is valid for 1 hour since time of invocation
|
||||
def fn(tenant_name: Union[bytes, str]):
|
||||
tenant = db.open_tenant(to_bytes(tenant_name))
|
||||
tenant_id = tenant.get_id().wait()
|
||||
tenant_id = tenant_id_from_name(tenant_name)
|
||||
now = time.time()
|
||||
return {
|
||||
"iss": "fdb-authz-tester",
|
||||
|
|
|
@ -3,7 +3,7 @@ allowDefaultTenant = false
|
|||
tenantModes = ['optional', 'required']
|
||||
|
||||
[[knobs]]
|
||||
audit_logging_enabled = false
|
||||
audit_logging_enabled = true
|
||||
max_trace_lines = 2000000
|
||||
|
||||
[[test]]
|
||||
|
|
Loading…
Reference in New Issue