diff --git a/CMakeLists.txt b/CMakeLists.txt index d5ce5de14a..2eb243c346 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -172,6 +172,7 @@ endif() include(CompileBoost) include(GetMsgpack) +add_subdirectory(contrib) add_subdirectory(flow) add_subdirectory(fdbrpc) add_subdirectory(fdbclient) @@ -183,7 +184,6 @@ else() add_subdirectory(fdbservice) endif() add_subdirectory(fdbbackup) -add_subdirectory(contrib) add_subdirectory(tests) add_subdirectory(flowbench EXCLUDE_FROM_ALL) if(WITH_PYTHON AND WITH_C_BINDING) diff --git a/fdbclient/BackupAgent.actor.h b/fdbclient/BackupAgent.actor.h index 44fa7921d3..0e72d1da5f 100644 --- a/fdbclient/BackupAgent.actor.h +++ b/fdbclient/BackupAgent.actor.h @@ -568,7 +568,7 @@ ACTOR Future applyMutations(Database cx, Key removePrefix, Version beginVersion, Version* endVersion, - RequestStream commit, + PublicRequestStream commit, NotifiedVersion* committedVersion, Reference> keyVersion); ACTOR Future cleanupBackup(Database cx, DeleteData deleteData); diff --git a/fdbclient/BackupAgentBase.actor.cpp b/fdbclient/BackupAgentBase.actor.cpp index 6033f80992..cb999a6e12 100644 --- a/fdbclient/BackupAgentBase.actor.cpp +++ b/fdbclient/BackupAgentBase.actor.cpp @@ -598,7 +598,7 @@ ACTOR Future dumpData(Database cx, Key uid, Key addPrefix, Key removePrefix, - RequestStream commit, + PublicRequestStream commit, NotifiedVersion* committedVersion, Optional endVersion, Key rangeBegin, @@ -675,7 +675,7 @@ ACTOR Future dumpData(Database cx, ACTOR Future coalesceKeyVersionCache(Key uid, Version endVersion, Reference> keyVersion, - RequestStream commit, + PublicRequestStream commit, NotifiedVersion* committedVersion, PromiseStream> addActor, FlowLock* commitLock) { @@ -725,7 +725,7 @@ ACTOR Future applyMutations(Database cx, Key removePrefix, Version beginVersion, Version* endVersion, - RequestStream commit, + PublicRequestStream commit, NotifiedVersion* committedVersion, Reference> keyVersion) { state FlowLock 
commitLock(CLIENT_KNOBS->BACKUP_LOCK_BYTES); diff --git a/fdbclient/CommitProxyInterface.h b/fdbclient/CommitProxyInterface.h index f5b96b342a..8d068926eb 100644 --- a/fdbclient/CommitProxyInterface.h +++ b/fdbclient/CommitProxyInterface.h @@ -43,13 +43,13 @@ struct CommitProxyInterface { Optional processId; bool provisional; - RequestStream commit; - RequestStream + PublicRequestStream commit; + PublicRequestStream getConsistentReadVersion; // Returns a version which (1) is committed, and (2) is >= the latest version reported // committed (by a commit response) when this request was sent // (at some point between when this request is sent and when its response is // received, the latest version reported committed) - RequestStream getKeyServersLocations; + PublicRequestStream getKeyServersLocations; RequestStream getStorageServerRejoinInfo; RequestStream> waitFailure; @@ -72,9 +72,9 @@ struct CommitProxyInterface { serializer(ar, processId, provisional, commit); if (Archive::isDeserializing) { getConsistentReadVersion = - RequestStream(commit.getEndpoint().getAdjustedEndpoint(1)); + PublicRequestStream(commit.getEndpoint().getAdjustedEndpoint(1)); getKeyServersLocations = - RequestStream(commit.getEndpoint().getAdjustedEndpoint(2)); + PublicRequestStream(commit.getEndpoint().getAdjustedEndpoint(2)); getStorageServerRejoinInfo = RequestStream(commit.getEndpoint().getAdjustedEndpoint(3)); waitFailure = RequestStream>(commit.getEndpoint().getAdjustedEndpoint(4)); diff --git a/fdbclient/CoordinationInterface.h b/fdbclient/CoordinationInterface.h index cc28dd5e25..c8f252088c 100644 --- a/fdbclient/CoordinationInterface.h +++ b/fdbclient/CoordinationInterface.h @@ -33,8 +33,8 @@ const int MAX_CLUSTER_FILE_BYTES = 60000; struct ClientLeaderRegInterface { - RequestStream getLeader; - RequestStream openDatabase; + PublicRequestStream getLeader; + PublicRequestStream openDatabase; RequestStream checkDescriptorMutable; Optional hostname; diff --git 
a/fdbclient/GrvProxyInterface.h b/fdbclient/GrvProxyInterface.h index 10dd81546d..4098a88d2c 100644 --- a/fdbclient/GrvProxyInterface.h +++ b/fdbclient/GrvProxyInterface.h @@ -36,7 +36,7 @@ struct GrvProxyInterface { Optional processId; bool provisional; - RequestStream + PublicRequestStream getConsistentReadVersion; // Returns a version which (1) is committed, and (2) is >= the latest version reported // committed (by a commit response) when this request was sent // (at some point between when this request is sent and when its response is received, the latest version reported diff --git a/fdbclient/NativeAPI.actor.cpp b/fdbclient/NativeAPI.actor.cpp index 96fd7bed78..ab8c957d53 100644 --- a/fdbclient/NativeAPI.actor.cpp +++ b/fdbclient/NativeAPI.actor.cpp @@ -105,11 +105,11 @@ namespace { TransactionLineageCollector transactionLineageCollector; NameLineageCollector nameLineageCollector; -template +template Future loadBalance( DatabaseContext* ctx, const Reference alternatives, - RequestStream Interface::*channel, + RequestStream Interface::*channel, const Request& request = Request(), TaskPriority taskID = TaskPriority::DefaultPromiseEndpoint, AtMostOnce atMostOnce = @@ -2541,7 +2541,7 @@ void stopNetwork() { if (!g_network) throw network_not_setup(); - TraceEvent("ClientStopNetwork"); + TraceEvent("ClientStopNetwork").log(); g_network->stop(); closeTraceFile(); } @@ -3771,7 +3771,7 @@ void transformRangeLimits(GetRangeLimits limits, Reverse reverse, GetKeyValuesFa } template -RequestStream StorageServerInterface::*getRangeRequestStream() { +PublicRequestStream StorageServerInterface::*getRangeRequestStream() { if constexpr (std::is_same::value) { return &StorageServerInterface::getKeyValues; } else if (std::is_same::value) { @@ -4597,9 +4597,9 @@ static Future tssStreamComparison(Request request, // Currently only used for GetKeyValuesStream but could easily be plugged for other stream types // User of the stream has to forward the SS's responses to the returned 
promise stream, if it is set -template +template Optional> -maybeDuplicateTSSStreamFragment(Request& req, QueueModel* model, RequestStream const* ssStream) { +maybeDuplicateTSSStreamFragment(Request& req, QueueModel* model, RequestStream const* ssStream) { if (model) { Optional tssData = model->getTssData(ssStream->getEndpoint().token.first()); diff --git a/fdbclient/StorageServerInterface.h b/fdbclient/StorageServerInterface.h index 7e0f92d3c9..13ba8f1e18 100644 --- a/fdbclient/StorageServerInterface.h +++ b/fdbclient/StorageServerInterface.h @@ -63,13 +63,13 @@ struct StorageServerInterface { UID uniqueID; Optional tssPairID; - RequestStream getValue; - RequestStream getKey; + PublicRequestStream getValue; + PublicRequestStream getKey; // Throws a wrong_shard_server if the keys in the request or result depend on data outside this server OR if a large // selector offset prevents all data from being read in one range read - RequestStream getKeyValues; - RequestStream getMappedKeyValues; + PublicRequestStream getKeyValues; + PublicRequestStream getMappedKeyValues; RequestStream getShardState; RequestStream waitMetrics; @@ -79,17 +79,17 @@ struct StorageServerInterface { RequestStream getQueuingMetrics; RequestStream> getKeyValueStoreType; - RequestStream watchValue; + PublicRequestStream watchValue; RequestStream getReadHotRanges; RequestStream getRangeSplitPoints; - RequestStream getKeyValuesStream; - RequestStream changeFeedStream; - RequestStream overlappingChangeFeeds; - RequestStream changeFeedPop; - RequestStream changeFeedVersionUpdate; - RequestStream checkpoint; - RequestStream fetchCheckpoint; - RequestStream fetchCheckpointKeyValues; + PublicRequestStream getKeyValuesStream; + PublicRequestStream changeFeedStream; + PublicRequestStream overlappingChangeFeeds; + PublicRequestStream changeFeedPop; + PublicRequestStream changeFeedVersionUpdate; + PublicRequestStream checkpoint; + PublicRequestStream fetchCheckpoint; + PublicRequestStream 
fetchCheckpointKeyValues; private: bool acceptingRequests; @@ -123,8 +123,9 @@ public: serializer(ar, uniqueID, locality, getValue); } if (Ar::isDeserializing) { - getKey = RequestStream(getValue.getEndpoint().getAdjustedEndpoint(1)); - getKeyValues = RequestStream(getValue.getEndpoint().getAdjustedEndpoint(2)); + getKey = PublicRequestStream(getValue.getEndpoint().getAdjustedEndpoint(1)); + getKeyValues = + PublicRequestStream(getValue.getEndpoint().getAdjustedEndpoint(2)); getShardState = RequestStream(getValue.getEndpoint().getAdjustedEndpoint(3)); waitMetrics = RequestStream(getValue.getEndpoint().getAdjustedEndpoint(4)); @@ -136,27 +137,29 @@ public: RequestStream(getValue.getEndpoint().getAdjustedEndpoint(8)); getKeyValueStoreType = RequestStream>(getValue.getEndpoint().getAdjustedEndpoint(9)); - watchValue = RequestStream(getValue.getEndpoint().getAdjustedEndpoint(10)); + watchValue = + PublicRequestStream(getValue.getEndpoint().getAdjustedEndpoint(10)); getReadHotRanges = RequestStream(getValue.getEndpoint().getAdjustedEndpoint(11)); getRangeSplitPoints = RequestStream(getValue.getEndpoint().getAdjustedEndpoint(12)); - getKeyValuesStream = - RequestStream(getValue.getEndpoint().getAdjustedEndpoint(13)); - getMappedKeyValues = - RequestStream(getValue.getEndpoint().getAdjustedEndpoint(14)); + getKeyValuesStream = PublicRequestStream( + getValue.getEndpoint().getAdjustedEndpoint(13)); + getMappedKeyValues = PublicRequestStream( + getValue.getEndpoint().getAdjustedEndpoint(14)); changeFeedStream = - RequestStream(getValue.getEndpoint().getAdjustedEndpoint(15)); - overlappingChangeFeeds = - RequestStream(getValue.getEndpoint().getAdjustedEndpoint(16)); + PublicRequestStream(getValue.getEndpoint().getAdjustedEndpoint(15)); + overlappingChangeFeeds = PublicRequestStream( + getValue.getEndpoint().getAdjustedEndpoint(16)); changeFeedPop = - RequestStream(getValue.getEndpoint().getAdjustedEndpoint(17)); - changeFeedVersionUpdate = RequestStream( + 
PublicRequestStream(getValue.getEndpoint().getAdjustedEndpoint(17)); + changeFeedVersionUpdate = PublicRequestStream( getValue.getEndpoint().getAdjustedEndpoint(18)); - checkpoint = RequestStream(getValue.getEndpoint().getAdjustedEndpoint(19)); + checkpoint = + PublicRequestStream(getValue.getEndpoint().getAdjustedEndpoint(19)); fetchCheckpoint = - RequestStream(getValue.getEndpoint().getAdjustedEndpoint(20)); - fetchCheckpointKeyValues = RequestStream( + PublicRequestStream(getValue.getEndpoint().getAdjustedEndpoint(20)); + fetchCheckpointKeyValues = PublicRequestStream( getValue.getEndpoint().getAdjustedEndpoint(21)); } } else { diff --git a/fdbclient/WellKnownEndpoints.h b/fdbclient/WellKnownEndpoints.h index bed5c34935..e42d13e443 100644 --- a/fdbclient/WellKnownEndpoints.h +++ b/fdbclient/WellKnownEndpoints.h @@ -28,28 +28,28 @@ * All well-known endpoints of FDB must be listed here to guarantee their uniqueness */ enum WellKnownEndpoints { - WLTOKEN_CLIENTLEADERREG_GETLEADER = WLTOKEN_FIRST_AVAILABLE, // 2 - WLTOKEN_CLIENTLEADERREG_OPENDATABASE, // 3 - WLTOKEN_LEADERELECTIONREG_CANDIDACY, // 4 - WLTOKEN_LEADERELECTIONREG_ELECTIONRESULT, // 5 - WLTOKEN_LEADERELECTIONREG_LEADERHEARTBEAT, // 6 - WLTOKEN_LEADERELECTIONREG_FORWARD, // 7 - WLTOKEN_GENERATIONREG_READ, // 8 - WLTOKEN_GENERATIONREG_WRITE, // 9 + WLTOKEN_CLIENTLEADERREG_GETLEADER = WLTOKEN_FIRST_AVAILABLE, // 4 + WLTOKEN_CLIENTLEADERREG_OPENDATABASE, // 5 + WLTOKEN_LEADERELECTIONREG_CANDIDACY, // 6 + WLTOKEN_LEADERELECTIONREG_ELECTIONRESULT, // 7 + WLTOKEN_LEADERELECTIONREG_LEADERHEARTBEAT, // 8 + WLTOKEN_LEADERELECTIONREG_FORWARD, // 9 WLTOKEN_PROTOCOL_INFO, // 10 : the value of this endpoint should be stable and not change. 
- WLTOKEN_CLIENTLEADERREG_DESCRIPTOR_MUTABLE, // 11 - WLTOKEN_CONFIGTXN_GETGENERATION, // 12 - WLTOKEN_CONFIGTXN_GET, // 13 - WLTOKEN_CONFIGTXN_GETCLASSES, // 14 - WLTOKEN_CONFIGTXN_GETKNOBS, // 15 - WLTOKEN_CONFIGTXN_COMMIT, // 16 - WLTOKEN_CONFIGFOLLOWER_GETSNAPSHOTANDCHANGES, // 17 - WLTOKEN_CONFIGFOLLOWER_GETCHANGES, // 18 - WLTOKEN_CONFIGFOLLOWER_COMPACT, // 19 - WLTOKEN_CONFIGFOLLOWER_ROLLFORWARD, // 20 - WLTOKEN_CONFIGFOLLOWER_GETCOMMITTEDVERSION, // 21 - WLTOKEN_PROCESS, // 22 - WLTOKEN_RESERVED_COUNT // 23 + WLTOKEN_GENERATIONREG_READ, // 11 + WLTOKEN_GENERATIONREG_WRITE, // 12 + WLTOKEN_CLIENTLEADERREG_DESCRIPTOR_MUTABLE, // 13 + WLTOKEN_CONFIGTXN_GETGENERATION, // 14 + WLTOKEN_CONFIGTXN_GET, // 15 + WLTOKEN_CONFIGTXN_GETCLASSES, // 16 + WLTOKEN_CONFIGTXN_GETKNOBS, // 17 + WLTOKEN_CONFIGTXN_COMMIT, // 18 + WLTOKEN_CONFIGFOLLOWER_GETSNAPSHOTANDCHANGES, // 19 + WLTOKEN_CONFIGFOLLOWER_GETCHANGES, // 20 + WLTOKEN_CONFIGFOLLOWER_COMPACT, // 21 + WLTOKEN_CONFIGFOLLOWER_ROLLFORWARD, // 22 + WLTOKEN_CONFIGFOLLOWER_GETCOMMITTEDVERSION, // 23 + WLTOKEN_PROCESS, // 24 + WLTOKEN_RESERVED_COUNT // 25 }; static_assert(WLTOKEN_PROTOCOL_INFO == diff --git a/fdbrpc/CMakeLists.txt b/fdbrpc/CMakeLists.txt index 59ae21bc9e..1b709e22f0 100644 --- a/fdbrpc/CMakeLists.txt +++ b/fdbrpc/CMakeLists.txt @@ -16,6 +16,7 @@ set(FDBRPC_SRCS genericactors.actor.cpp HealthMonitor.actor.cpp IAsyncFile.actor.cpp + IPAllowList.cpp LoadBalance.actor.cpp LoadBalance.actor.h Locality.cpp diff --git a/fdbrpc/FailureMonitor.actor.cpp b/fdbrpc/FailureMonitor.actor.cpp index dcb4052cf5..c7c5328a1c 100644 --- a/fdbrpc/FailureMonitor.actor.cpp +++ b/fdbrpc/FailureMonitor.actor.cpp @@ -131,11 +131,20 @@ void SimpleFailureMonitor::endpointNotFound(Endpoint const& endpoint) { TraceEvent(SevWarnAlways, "TooManyFailedEndpoints").suppressFor(1.0); failedEndpoints.clear(); } - failedEndpoints.insert(endpoint); + failedEndpoints.emplace(endpoint, FailedReason::NOT_FOUND); } 
endpointKnownFailed.trigger(endpoint); } +void SimpleFailureMonitor::unauthorizedEndpoint(Endpoint const& endpoint) { + TraceEvent(g_network->isSimulated() ? SevWarnAlways : SevError, "TriedAccessPrivateEndpoint") + .suppressFor(1.0) + .detail("Address", endpoint.getPrimaryAddress()) + .detail("Token", endpoint.token); + failedEndpoints.emplace(endpoint, FailedReason::UNAUTHORIZED); + endpointKnownFailed.trigger(endpoint); +} + void SimpleFailureMonitor::notifyDisconnect(NetworkAddress const& address) { //TraceEvent("NotifyDisconnect").detail("Address", address); endpointKnownFailed.triggerRange(Endpoint({ address }, UID()), Endpoint({ address }, UID(-1, -1))); @@ -208,8 +217,13 @@ bool SimpleFailureMonitor::permanentlyFailed(Endpoint const& endpoint) const { return failedEndpoints.count(endpoint); } +bool SimpleFailureMonitor::knownUnauthorized(Endpoint const& endpoint) const { + auto iter = failedEndpoints.find(endpoint); + return iter != failedEndpoints.end() && iter->second == FailedReason::UNAUTHORIZED; +} + void SimpleFailureMonitor::reset() { addressStatus = std::unordered_map(); - failedEndpoints = std::unordered_set(); + failedEndpoints = std::unordered_map(); endpointKnownFailed.resetNoWaiting(); } diff --git a/fdbrpc/FailureMonitor.h b/fdbrpc/FailureMonitor.h index 280f74d8cf..d32d141992 100644 --- a/fdbrpc/FailureMonitor.h +++ b/fdbrpc/FailureMonitor.h @@ -93,6 +93,9 @@ public: // Only use this function when the endpoint is known to be failed virtual void endpointNotFound(Endpoint const&) = 0; + // Inform client that it was trying to send a message to a private endpoint + virtual void unauthorizedEndpoint(Endpoint const&) = 0; + // The next time the known status for the endpoint changes, returns the new status. virtual Future onStateChanged(Endpoint const& endpoint) = 0; @@ -108,6 +111,9 @@ public: // Returns true if the endpoint will never become available. 
virtual bool permanentlyFailed(Endpoint const& endpoint) const = 0; + // Returns true if we know we're not allowed to send messages to the remote endpoint + virtual bool knownUnauthorized(Endpoint const&) const = 0; + // Called by FlowTransport when a connection closes and a prior request or reply might be lost virtual void notifyDisconnect(NetworkAddress const&) = 0; @@ -139,9 +145,11 @@ public: class SimpleFailureMonitor : public IFailureMonitor { public: + enum class FailedReason { NOT_FOUND, UNAUTHORIZED }; SimpleFailureMonitor(); void setStatus(NetworkAddress const& address, FailureStatus const& status) override; void endpointNotFound(Endpoint const&) override; + void unauthorizedEndpoint(Endpoint const&) override; void notifyDisconnect(NetworkAddress const&) override; Future onStateChanged(Endpoint const& endpoint) override; @@ -151,6 +159,7 @@ public: Future onDisconnect(NetworkAddress const& address) override; bool onlyEndpointFailed(Endpoint const& endpoint) const override; bool permanentlyFailed(Endpoint const& endpoint) const override; + bool knownUnauthorized(Endpoint const&) const override; void reset(); @@ -158,7 +167,7 @@ private: std::unordered_map addressStatus; YieldedAsyncMap endpointKnownFailed; AsyncMap disconnectTriggers; - std::unordered_set failedEndpoints; + std::unordered_map failedEndpoints; friend class OnStateChangedActorActor; }; diff --git a/fdbrpc/FlowTransport.actor.cpp b/fdbrpc/FlowTransport.actor.cpp index 9e5a224d4c..d5dac42abe 100644 --- a/fdbrpc/FlowTransport.actor.cpp +++ b/fdbrpc/FlowTransport.actor.cpp @@ -27,10 +27,12 @@ #include #endif +#include "fdbrpc/TenantInfo.h" #include "fdbrpc/fdbrpc.h" #include "fdbrpc/FailureMonitor.h" #include "fdbrpc/HealthMonitor.h" #include "fdbrpc/genericactors.actor.h" +#include "fdbrpc/IPAllowList.h" #include "fdbrpc/simulator.h" #include "flow/ActorCollection.h" #include "flow/Error.h" @@ -50,6 +52,9 @@ static Future g_currentDeliveryPeerDisconnect; constexpr int PACKET_LEN_WIDTH = 
sizeof(uint32_t); const uint64_t TOKEN_STREAM_FLAG = 1; +FDB_BOOLEAN_PARAM(InReadSocket); +FDB_BOOLEAN_PARAM(IsStableConnection); + class EndpointMap : NonCopyable { public: // Reserve space for this many wellKnownEndpoints @@ -205,6 +210,7 @@ struct EndpointNotFoundReceiver final : NetworkMessageReceiver { Endpoint e = FlowTransport::transport().loadedEndpoint(token); IFailureMonitor::failureMonitor().endpointNotFound(e); } + bool isPublic() const override { return true; } }; struct PingRequest { @@ -229,11 +235,52 @@ struct PingReceiver final : NetworkMessageReceiver { PeerCompatibilityPolicy peerCompatibilityPolicy() const override { return PeerCompatibilityPolicy{ RequirePeer::AtLeast, ProtocolVersion::withStableInterfaces() }; } + bool isPublic() const override { return true; } +}; + +struct TenantAuthorizer final : NetworkMessageReceiver { + TenantAuthorizer(EndpointMap& endpoints) { + endpoints.insertWellKnown(this, Endpoint::wellKnownToken(WLTOKEN_AUTH_TENANT), TaskPriority::ReadSocket); + } + void receive(ArenaObjectReader& reader) override { + AuthorizationRequest req; + try { + reader.deserialize(req); + // TODO: verify that token is valid + AuthorizedTenants& auth = reader.variable("AuthorizedTenants"); + for (auto const& t : req.tenants) { + auth.authorizedTenants.insert(TenantInfoRef(auth.arena, t)); + } + req.reply.send(Void()); + } catch (Error& e) { + if (e.code() == error_code_permission_denied) { + req.reply.sendError(e); + } else { + throw; + } + } + } + bool isPublic() const override { return true; } +}; + +struct UnauthorizedEndpointReceiver final : NetworkMessageReceiver { + UnauthorizedEndpointReceiver(EndpointMap& endpoints) { + endpoints.insertWellKnown( + this, Endpoint::wellKnownToken(WLTOKEN_UNAUTHORIZED_ENDPOINT), TaskPriority::ReadSocket); + } + + void receive(ArenaObjectReader& reader) override { + UID token; + reader.deserialize(token); + Endpoint e = FlowTransport::transport().loadedEndpoint(token); + 
IFailureMonitor::failureMonitor().unauthorizedEndpoint(e); + } + bool isPublic() const override { return true; } }; class TransportData { public: - TransportData(uint64_t transportId, int maxWellKnownEndpoints); + TransportData(uint64_t transportId, int maxWellKnownEndpoints, IPAllowList const* allowList); ~TransportData(); @@ -264,6 +311,8 @@ public: EndpointMap endpoints; EndpointNotFoundReceiver endpointNotFoundReceiver{ endpoints }; PingReceiver pingReceiver{ endpoints }; + TenantAuthorizer tenantReceiver{ endpoints }; + UnauthorizedEndpointReceiver unauthorizedEndpointReceiver{ endpoints }; Int64MetricHandle bytesSent; Int64MetricHandle countPacketsReceived; @@ -278,6 +327,8 @@ public: std::map multiVersionConnections; double lastIncompatibleMessage; uint64_t transportId; + IPAllowList allowList; + std::shared_ptr localCVM = std::make_shared(); // for local delivery Future multiVersionCleanup; Future pingLogger; @@ -340,9 +391,10 @@ ACTOR Future pingLatencyLogger(TransportData* self) { } } -TransportData::TransportData(uint64_t transportId, int maxWellKnownEndpoints) +TransportData::TransportData(uint64_t transportId, int maxWellKnownEndpoints, IPAllowList const* allowList) : warnAlwaysForLargePacket(true), endpoints(maxWellKnownEndpoints), endpointNotFoundReceiver(endpoints), - pingReceiver(endpoints), numIncompatibleConnections(0), lastIncompatibleMessage(0), transportId(transportId) { + pingReceiver(endpoints), numIncompatibleConnections(0), lastIncompatibleMessage(0), transportId(transportId), + allowList(allowList == nullptr ? 
IPAllowList() : *allowList) { degraded = makeReference>(false); pingLogger = pingLatencyLogger(this); } @@ -880,7 +932,8 @@ void Peer::onIncomingConnection(Reference self, Reference con .suppressFor(1.0) .detail("FromAddr", conn->getPeerAddress()) .detail("CanonicalAddr", destination) - .detail("IsPublic", destination.isPublic()); + .detail("IsPublic", destination.isPublic()) + .detail("Trusted", self->transport->allowList(conn->getPeerAddress().ip)); connect.cancel(); prependConnectPacket(); @@ -927,7 +980,10 @@ ACTOR static void deliver(TransportData* self, Endpoint destination, TaskPriority priority, ArenaReader reader, - bool inReadSocket, + NetworkAddress peerAddress, + Reference authorizedTenants, + std::shared_ptr cvm, + InReadSocket inReadSocket, Future disconnect) { // We want to run the task at the right priority. If the priority is higher than the current priority (which is // ReadSocket) we can just upgrade. Otherwise we'll context switch so that we don't block other tasks that might run @@ -941,7 +997,7 @@ ACTOR static void deliver(TransportData* self, } auto receiver = self->endpoints.get(destination.token); - if (receiver) { + if (receiver && (authorizedTenants->trusted || receiver->isPublic())) { if (!checkCompatible(receiver->peerCompatibilityPolicy(), reader.protocolVersion())) { return; } @@ -951,6 +1007,7 @@ ACTOR static void deliver(TransportData* self, StringRef data = reader.arenaReadAll(); ASSERT(data.size() > 8); ArenaObjectReader objReader(reader.arena(), reader.arenaReadAll(), AssumeVersion(reader.protocolVersion())); + objReader.setContextVariableMap(cvm); receiver->receive(objReader); g_currentDeliveryPeerAddress = { NetworkAddress() }; g_currentDeliveryPeerDisconnect = Future(); @@ -968,18 +1025,31 @@ ACTOR static void deliver(TransportData* self, } } else if (destination.token.first() & TOKEN_STREAM_FLAG) { // We don't have the (stream) endpoint 'token', notify the remote machine - if (destination.token.first() != -1) { - if 
(self->isLocalAddress(destination.getPrimaryAddress())) { - sendLocal(self, - SerializeSource(destination.token), - Endpoint::wellKnown(destination.addresses, WLTOKEN_ENDPOINT_NOT_FOUND)); - } else { - Reference peer = self->getOrOpenPeer(destination.getPrimaryAddress()); - sendPacket(self, - peer, - SerializeSource(destination.token), - Endpoint::wellKnown(destination.addresses, WLTOKEN_ENDPOINT_NOT_FOUND), - false); + if (receiver) { + TraceEvent(SevWarnAlways, "AttemptedRPCToPrivatePrevented") + .detail("From", peerAddress) + .detail("Token", destination.token); + ASSERT(!self->isLocalAddress(destination.getPrimaryAddress())); + Reference peer = self->getOrOpenPeer(destination.getPrimaryAddress()); + sendPacket(self, + peer, + SerializeSource(destination.token), + Endpoint::wellKnown(destination.addresses, WLTOKEN_UNAUTHORIZED_ENDPOINT), + false); + } else { + if (destination.token.first() != -1) { + if (self->isLocalAddress(destination.getPrimaryAddress())) { + sendLocal(self, + SerializeSource(destination.token), + Endpoint::wellKnown(destination.addresses, WLTOKEN_ENDPOINT_NOT_FOUND)); + } else { + Reference peer = self->getOrOpenPeer(destination.getPrimaryAddress()); + sendPacket(self, + peer, + SerializeSource(destination.token), + Endpoint::wellKnown(destination.addresses, WLTOKEN_ENDPOINT_NOT_FOUND), + false); + } } } } @@ -990,9 +1060,11 @@ static void scanPackets(TransportData* transport, const uint8_t* e, Arena& arena, NetworkAddress const& peerAddress, + Reference const& authorizedTenants, + std::shared_ptr cvm, ProtocolVersion peerProtocolVersion, Future disconnect, - bool isStableConnection) { + IsStableConnection isStableConnection) { // Find each complete packet in the given byte range and queue a ready task to deliver it. // Remove the complete packets from the range by increasing unprocessed_begin. // There won't be more than 64K of data plus one packet, so this shouldn't take a long time. 
@@ -1106,7 +1178,15 @@ static void scanPackets(TransportData* transport, // we have many messages to UnknownEndpoint we want to optimize earlier. As deliver is an actor it // will allocate some state on the heap and this prevents it from doing that. if (priority != TaskPriority::UnknownEndpoint || (token.first() & TOKEN_STREAM_FLAG) != 0) { - deliver(transport, Endpoint({ peerAddress }, token), priority, std::move(reader), true, disconnect); + deliver(transport, + Endpoint({ peerAddress }, token), + priority, + std::move(reader), + peerAddress, + authorizedTenants, + cvm, + InReadSocket::True, + disconnect); } unprocessed_begin = p = p + packetLen; @@ -1152,8 +1232,14 @@ ACTOR static Future connectionReader(TransportData* transport, state bool incompatibleProtocolVersionNewer = false; state NetworkAddress peerAddress; state ProtocolVersion peerProtocolVersion; - + state Reference authorizedTenants = makeReference(); + state std::shared_ptr cvm = std::make_shared(); peerAddress = conn->getPeerAddress(); + authorizedTenants->trusted = transport->allowList(conn->getPeerAddress().ip); + (*cvm)["AuthorizedTenants"] = &authorizedTenants; + (*cvm)["PeerAddress"] = &peerAddress; + + authorizedTenants->trusted = transport->allowList(peerAddress.ip); if (!peer) { ASSERT(!peerAddress.isPublic()); } @@ -1306,9 +1392,11 @@ ACTOR static Future connectionReader(TransportData* transport, unprocessed_end, arena, peerAddress, + authorizedTenants, + cvm, peerProtocolVersion, peer->disconnect.getFuture(), - g_network->isSimulated() && conn->isStableConnection()); + IsStableConnection(g_network->isSimulated() && conn->isStableConnection())); } else { unprocessed_begin = unprocessed_end; peer->resetPing.trigger(); @@ -1452,8 +1540,8 @@ ACTOR static Future multiVersionCleanupWorker(TransportData* self) { } } -FlowTransport::FlowTransport(uint64_t transportId, int maxWellKnownEndpoints) - : self(new TransportData(transportId, maxWellKnownEndpoints)) { 
+FlowTransport::FlowTransport(uint64_t transportId, int maxWellKnownEndpoints, IPAllowList const* allowList) + : self(new TransportData(transportId, maxWellKnownEndpoints, allowList)) { self->multiVersionCleanup = multiVersionCleanupWorker(self); } @@ -1473,6 +1561,10 @@ NetworkAddress FlowTransport::getLocalAddress() const { return self->localAddresses.address; } +void FlowTransport::setLocalAddress(NetworkAddress const& address) { + self->localAddresses.address = address; +} + const std::unordered_map>& FlowTransport::getAllPeers() const { return self->peers; } @@ -1587,11 +1679,16 @@ static void sendLocal(TransportData* self, ISerializeSource const& what, const E ASSERT(copy.size() > 0); TaskPriority priority = self->endpoints.getPriority(destination.token); if (priority != TaskPriority::UnknownEndpoint || (destination.token.first() & TOKEN_STREAM_FLAG) != 0) { + Reference authorizedTenants = makeReference(); + authorizedTenants->trusted = true; deliver(self, destination, priority, ArenaReader(copy.arena(), copy, AssumeVersion(currentProtocolVersion)), - false, + NetworkAddress(), + authorizedTenants, + self->localCVM, + InReadSocket::False, Never()); } } @@ -1792,9 +1889,12 @@ bool FlowTransport::incompatibleOutgoingConnectionsPresent() { return self->numIncompatibleConnections > 0; } -void FlowTransport::createInstance(bool isClient, uint64_t transportId, int maxWellKnownEndpoints) { +void FlowTransport::createInstance(bool isClient, + uint64_t transportId, + int maxWellKnownEndpoints, + IPAllowList const* allowList) { g_network->setGlobal(INetwork::enFlowTransport, - (flowGlobalType) new FlowTransport(transportId, maxWellKnownEndpoints)); + (flowGlobalType) new FlowTransport(transportId, maxWellKnownEndpoints, allowList)); g_network->setGlobal(INetwork::enNetworkAddressFunc, (flowGlobalType)&FlowTransport::getGlobalLocalAddress); g_network->setGlobal(INetwork::enNetworkAddressesFunc, (flowGlobalType)&FlowTransport::getGlobalLocalAddresses); 
g_network->setGlobal(INetwork::enFailureMonitor, (flowGlobalType) new SimpleFailureMonitor()); diff --git a/fdbrpc/FlowTransport.h b/fdbrpc/FlowTransport.h index d70b98c401..8f60a2fc9b 100644 --- a/fdbrpc/FlowTransport.h +++ b/fdbrpc/FlowTransport.h @@ -31,7 +31,13 @@ #include "flow/Net2Packet.h" #include "fdbrpc/ContinuousSample.h" -enum { WLTOKEN_ENDPOINT_NOT_FOUND = 0, WLTOKEN_PING_PACKET, WLTOKEN_FIRST_AVAILABLE }; +enum { + WLTOKEN_ENDPOINT_NOT_FOUND = 0, + WLTOKEN_PING_PACKET, + WLTOKEN_AUTH_TENANT, + WLTOKEN_UNAUTHORIZED_ENDPOINT, + WLTOKEN_FIRST_AVAILABLE +}; #pragma pack(push, 4) class Endpoint { @@ -129,6 +135,7 @@ class NetworkMessageReceiver { public: virtual void receive(ArenaObjectReader&) = 0; virtual bool isStream() const { return false; } + virtual bool isPublic() const = 0; virtual PeerCompatibilityPolicy peerCompatibilityPolicy() const { return { RequirePeer::Exactly, g_network->protocolVersion() }; } @@ -182,14 +189,19 @@ struct Peer : public ReferenceCounted { void onIncomingConnection(Reference self, Reference conn, Future reader); }; +class IPAllowList; + class FlowTransport { public: - FlowTransport(uint64_t transportId, int maxWellKnownEndpoints); + FlowTransport(uint64_t transportId, int maxWellKnownEndpoints, IPAllowList const* allowList); ~FlowTransport(); // Creates a new FlowTransport and makes FlowTransport::transport() return it. This uses g_network->global() // variables, so it will be private to a simulation. - static void createInstance(bool isClient, uint64_t transportId, int maxWellKnownEndpoints); + static void createInstance(bool isClient, + uint64_t transportId, + int maxWellKnownEndpoints, + IPAllowList const* allowList = nullptr); static bool isClient() { return g_network->global(INetwork::enClientFailureMonitor) != nullptr; } @@ -203,6 +215,9 @@ public: // Returns first local NetworkAddress. NetworkAddress getLocalAddress() const; + // Sets the first local NetworkAddress. 
+ void setLocalAddress(NetworkAddress const&); + // Returns all local NetworkAddress. NetworkAddressList getLocalAddresses() const; diff --git a/fdbrpc/IPAllowList.cpp b/fdbrpc/IPAllowList.cpp new file mode 100644 index 0000000000..be894b0e1a --- /dev/null +++ b/fdbrpc/IPAllowList.cpp @@ -0,0 +1,386 @@ +/* + * IPAllowList.cpp + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2013-2022 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "flow/UnitTest.h" +#include "flow/Error.h" +#include "fdbrpc/IPAllowList.h" + +#include +#include +#include + +namespace { + +template +std::string binRep(std::array const& addr) { + return fmt::format("{:02x}", fmt::join(addr, ":")); +} + +template +void printIP(std::array const& addr) { + fmt::print(" {}", binRep(addr)); +} + +template +int netmaskWeightImpl(std::array const& addr) { + int count = 0; + for (int i = 0; i < addr.size() && addr[i] != 0xff; ++i) { + std::bitset<8> b(addr[i]); + count += 8 - b.count(); + } + return count; +} + +} // namespace + +AuthAllowedSubnet::AuthAllowedSubnet(IPAddress const& baseAddress, IPAddress const& addressMask) + : baseAddress(baseAddress), addressMask(addressMask) { + ASSERT(baseAddress.isV4() == addressMask.isV4()); +} + +IPAddress AuthAllowedSubnet::netmask() const { + if (addressMask.isV4()) { + uint32_t res = 0xffffffff ^ addressMask.toV4(); + return IPAddress(res); + } else { + std::array res; + res.fill(0xff); + auto mask = addressMask.toV6(); + for (int i = 0; i < mask.size(); ++i) { + res[i] ^= mask[i]; + } + return IPAddress(res); + } +} + +int AuthAllowedSubnet::netmaskWeight() const { + if (addressMask.isV4()) { + boost::asio::ip::address_v4 addr(netmask().toV4()); + return netmaskWeightImpl(addr.to_bytes()); + } else { + return netmaskWeightImpl(netmask().toV6()); + } +} + +AuthAllowedSubnet AuthAllowedSubnet::fromString(std::string_view addressString) { + auto pos = addressString.find('/'); + if (pos == std::string_view::npos) { + fmt::print("ERROR: {} is not a valid (use Network-Prefix/netmaskWeight syntax)\n"); + throw invalid_option(); + } + auto address = addressString.substr(0, pos); + auto netmaskWeight = std::stoi(std::string(addressString.substr(pos + 1))); + auto addr = boost::asio::ip::make_address(address); + if (addr.is_v4()) { + auto bM = createBitMask(addr.to_v4().to_bytes(), netmaskWeight); + // we typically would expect a base address has been passed, but to be safe we still 
+ // will make the last bits 0 + auto mask = boost::asio::ip::address_v4(bM).to_uint(); + auto baseAddress = addr.to_v4().to_uint() & mask; + return AuthAllowedSubnet(IPAddress(baseAddress), IPAddress(mask)); + } else { + auto mask = createBitMask(addr.to_v6().to_bytes(), netmaskWeight); + auto baseAddress = addr.to_v6().to_bytes(); + for (int i = 0; i < mask.size(); ++i) { + baseAddress[i] &= mask[i]; + } + return AuthAllowedSubnet(IPAddress(baseAddress), IPAddress(mask)); + } +} + +void AuthAllowedSubnet::printIP(std::string_view txt, IPAddress const& address) { + fmt::print("{}:", txt); + if (address.isV4()) { + ::printIP(boost::asio::ip::address_v4(address.toV4()).to_bytes()); + } else { + ::printIP(address.toV6()); + } + fmt::print("\n"); +} + +template +std::array AuthAllowedSubnet::createBitMask(std::array const& addr, + int netmaskWeight) { + std::array res; + res.fill((unsigned char)0xff); + int idx = netmaskWeight / 8; + if (netmaskWeight % 8 > 0) { + // 2^(netmaskWeight % 8) - 1 sets the last (netmaskWeight % 8) number of bits to 1 + // everything else will be zero. 
For example: 2^3 - 1 == 7 == 0b111 + unsigned char bitmask = (1 << (8 - (netmaskWeight % 8))) - ((unsigned char)1); + res[idx] ^= bitmask; + ++idx; + } + for (; idx < res.size(); ++idx) { + res[idx] = (unsigned char)0; + } + return res; +} + +template std::array AuthAllowedSubnet::createBitMask<4>(const std::array& addr, + int netmaskWeight); +template std::array AuthAllowedSubnet::createBitMask<16>(const std::array& addr, + int netmaskWeight); + +// helpers for testing +namespace { +using boost::asio::ip::address_v4; +using boost::asio::ip::address_v6; + +void traceAddress(TraceEvent& evt, const char* name, IPAddress address) { + evt.detail(name, address); + std::string bin; + if (address.isV4()) { + boost::asio::ip::address_v4 a(address.toV4()); + bin = binRep(a.to_bytes()); + } else { + bin = binRep(address.toV6()); + } + evt.detail(fmt::format("{}Binary", name).c_str(), bin); +} + +void subnetAssert(IPAllowList const& allowList, IPAddress addr, bool expectAllowed) { + if (allowList(addr) == expectAllowed) { + return; + } + TraceEvent evt(SevError, expectAllowed ? "ExpectedAddressToBeTrusted" : "ExpectedAddressToBeUntrusted"); + traceAddress(evt, "Address", addr); + auto const& subnets = allowList.subnets(); + for (int i = 0; i < subnets.size(); ++i) { + traceAddress(evt, fmt::format("SubnetBase{}", i).c_str(), subnets[i].baseAddress); + traceAddress(evt, fmt::format("SubnetMask{}", i).c_str(), subnets[i].addressMask); + } +} + +IPAddress parseAddr(std::string const& str) { + auto res = IPAddress::parse(str); + ASSERT(res.present()); + return res.get(); +} + +struct SubNetTest { + AuthAllowedSubnet subnet; + SubNetTest(AuthAllowedSubnet&& subnet) : subnet(std::move(subnet)) {} + SubNetTest(AuthAllowedSubnet const& subnet) : subnet(subnet) {} + template + static SubNetTest randomSubNetImpl() { + constexpr int width = V4 ? 
4 : 16; + std::array binAddr; + unsigned char rnd[4]; + for (int i = 0; i < binAddr.size(); ++i) { + if (i % 4 == 0) { + auto tmp = deterministicRandom()->randomUInt32(); + ::memcpy(rnd, &tmp, 4); + } + binAddr[i] = rnd[i % 4]; + } + auto netmaskWeight = deterministicRandom()->randomInt(1, width); + std::string address; + if constexpr (V4) { + address_v4 a(binAddr); + address = a.to_string(); + } else { + address_v6 a(binAddr); + address = a.to_string(); + } + return SubNetTest(AuthAllowedSubnet::fromString(fmt::format("{}/{}", address, netmaskWeight))); + } + static SubNetTest randomSubNet() { + if (deterministicRandom()->coinflip()) { + return randomSubNetImpl(); + } else { + return randomSubNetImpl(); + } + } + + template + static IPAddress intArrayToAddress(uint32_t* arr) { + if constexpr (V4) { + return IPAddress(arr[0]); + } else { + std::array res; + memcpy(res.data(), arr, res.size()); // fill the whole IPv6 address; copying only 4 bytes left the rest uninitialized + return IPAddress(res); + } + } + + template + I transformIntToSubnet(I val, I subnetMask, I baseAddress) { + return (val & subnetMask) ^ baseAddress; + } + + template + static IPAddress randomAddress() { + constexpr int width = V4 ? 
4 : 16; + uint32_t rnd[width / 4]; + for (int i = 0; i < width / 4; ++i) { + rnd[i] = deterministicRandom()->randomUInt32(); + } + return intArrayToAddress(rnd); + } + + template + IPAddress randomAddress(bool inSubnet) { + ASSERT(V4 == subnet.baseAddress.isV4() || !inSubnet); + for (;;) { + auto res = randomAddress(); + if (V4 != subnet.baseAddress.isV4()) { + return res; + } + if (!inSubnet) { + if (!subnet(res)) { + return res; + } else { + continue; + } + } + // first we make sure the address is in the subnet + if constexpr (V4) { + auto a = res.toV4(); + auto base = subnet.baseAddress.toV4(); + auto netmask = subnet.netmask().toV4(); + auto validAddress = transformIntToSubnet(a, netmask, base); + res = IPAddress(validAddress); + } else { + auto a = res.toV6(); + auto base = subnet.baseAddress.toV6(); + auto netmask = subnet.netmask().toV6(); + for (int i = 0; i < a.size(); ++i) { + a[i] = transformIntToSubnet(a[i], netmask[i], base[i]); + } + res = IPAddress(a); + } + return res; + } + } + + IPAddress randomAddress(bool inSubnet) { + if (!inSubnet && deterministicRandom()->random01() < 0.1) { + // return an address of a different type + if (subnet.baseAddress.isV4()) { + return randomAddress(false); + } else { + return randomAddress(false); + } + } + if (subnet.addressMask.isV4()) { + return randomAddress(inSubnet); + } else { + return randomAddress(inSubnet); + } + } +}; + +} // namespace + +TEST_CASE("/fdbrpc/allow_list") { + // test correct weight calculation + // IPv4 + for (int i = 0; i < 33; ++i) { + auto str = fmt::format("0.0.0.0/{}", i); + auto subnet = AuthAllowedSubnet::fromString(str); + if (i != subnet.netmaskWeight()) { + fmt::print("Wrong calculated weight {} for {}\n", subnet.netmaskWeight(), str); + fmt::print("\tBase address: {}\n", subnet.baseAddress.toString()); + fmt::print("\tAddress Mask: {}\n", subnet.addressMask.toString()); + fmt::print("\tNetmask: {}\n", subnet.netmask().toString()); + ASSERT_EQ(i, subnet.netmaskWeight()); + } + } + 
// IPv6 + for (int i = 0; i < 129; ++i) { + auto subnet = AuthAllowedSubnet::fromString(fmt::format("0::/{}", i)); + ASSERT_EQ(i, subnet.netmaskWeight()); + } + IPAllowList allowList; + // Simulated v4 addresses + allowList.addTrustedSubnet("1.0.0.0/8"); + allowList.addTrustedSubnet("2.0.0.0/4"); + ::subnetAssert(allowList, parseAddr("1.0.1.1"), true); + ::subnetAssert(allowList, parseAddr("1.1.2.2"), true); + ::subnetAssert(allowList, parseAddr("2.2.1.1"), true); + ::subnetAssert(allowList, parseAddr("128.0.1.1"), false); + allowList = IPAllowList(); + allowList.addTrustedSubnet("0.0.0.0/2"); + allowList.addTrustedSubnet("abcd::/16"); + ::subnetAssert(allowList, parseAddr("1.0.1.1"), true); + ::subnetAssert(allowList, parseAddr("1.1.2.2"), true); + ::subnetAssert(allowList, parseAddr("2.2.1.1"), true); + ::subnetAssert(allowList, parseAddr("4.0.1.2"), true); + ::subnetAssert(allowList, parseAddr("5.2.1.1"), true); + ::subnetAssert(allowList, parseAddr("128.0.1.1"), false); + ::subnetAssert(allowList, parseAddr("192.168.3.1"), false); + // Simulated v6 addresses + ::subnetAssert(allowList, parseAddr("abcd::1:2:3:4"), true); + ::subnetAssert(allowList, parseAddr("abcd::2:3:3:4"), true); + ::subnetAssert(allowList, parseAddr("abcd:ab:ab:fdb:2:3:3:4"), true); + ::subnetAssert(allowList, parseAddr("2001:fdb1:fdb2:fdb3:fdb4:fdb5:fdb6:12"), false); + ::subnetAssert(allowList, parseAddr("2001:fdb1:fdb2:fdb3:fdb4:fdb5:fdb6:1"), false); + ::subnetAssert(allowList, parseAddr("2001:fdb1:fdb2:fdb3:fdb4:fdb5:fdb6:fdb"), false); + // Corner Cases + allowList = IPAllowList(); + allowList.addTrustedSubnet("0.0.0.0/0"); + // Random address tests + SubNetTest subnetTest(allowList.subnets()[0]); + for (int i = 0; i < 10; ++i) { + // All IPv4 addresses are in the allow list + ::subnetAssert(allowList, subnetTest.randomAddress(), true); + // No IPv6 addresses are in the allow list + ::subnetAssert(allowList, subnetTest.randomAddress(), false); + } + allowList = IPAllowList(); + 
allowList.addTrustedSubnet("::/0"); + subnetTest = SubNetTest(allowList.subnets()[0]); + for (int i = 0; i < 10; ++i) { + // All IPv6 addresses are in the allow list + ::subnetAssert(allowList, subnetTest.randomAddress(), true); + // No IPv4 addresses are in the allow list + ::subnetAssert(allowList, subnetTest.randomAddress(), false); + } + allowList = IPAllowList(); + IPAddress baseAddress = SubNetTest::randomAddress(); + allowList.addTrustedSubnet(fmt::format("{}/32", baseAddress.toString())); + for (int i = 0; i < 10; ++i) { + auto rnd = SubNetTest::randomAddress(); + ::subnetAssert(allowList, rnd, rnd == baseAddress); + rnd = SubNetTest::randomAddress(); + ::subnetAssert(allowList, rnd, false); + } + allowList = IPAllowList(); + baseAddress = SubNetTest::randomAddress(); + allowList.addTrustedSubnet(fmt::format("{}/128", baseAddress.toString())); + for (int i = 0; i < 10; ++i) { + auto rnd = SubNetTest::randomAddress(); + ::subnetAssert(allowList, rnd, rnd == baseAddress); + rnd = SubNetTest::randomAddress(); + ::subnetAssert(allowList, rnd, false); + } + for (int i = 0; i < 100; ++i) { + SubNetTest subnetTest(SubNetTest::randomSubNet()); + allowList = IPAllowList(); + allowList.addTrustedSubnet(subnetTest.subnet); + for (int j = 0; j < 10; ++j) { + bool inSubnet = deterministicRandom()->random01() < 0.7; + auto addr = subnetTest.randomAddress(inSubnet); + ::subnetAssert(allowList, addr, inSubnet); + } + } + return Void(); +} diff --git a/fdbrpc/IPAllowList.h b/fdbrpc/IPAllowList.h new file mode 100644 index 0000000000..4be11a7f52 --- /dev/null +++ b/fdbrpc/IPAllowList.h @@ -0,0 +1,86 @@ +/* + * IPAllowList.h + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2013-2022 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#ifndef FDBRPC_IP_ALLOW_LIST_H +#define FDBRPC_IP_ALLOW_LIST_H + +#include "flow/network.h" +#include "flow/Arena.h" + +struct AuthAllowedSubnet { + IPAddress baseAddress; + IPAddress addressMask; + + AuthAllowedSubnet(IPAddress const& baseAddress, IPAddress const& addressMask); + + static AuthAllowedSubnet fromString(std::string_view addressString); + + template + static std::array createBitMask(std::array const& addr, int netmaskWeight); + + bool operator()(IPAddress const& address) const { + if (addressMask.isV4() != address.isV4()) { + return false; + } + if (addressMask.isV4()) { + return (addressMask.toV4() & address.toV4()) == baseAddress.toV4(); + } else { + auto res = address.toV6(); + auto const& mask = addressMask.toV6(); + for (int i = 0; i < res.size(); ++i) { + res[i] &= mask[i]; + } + return res == baseAddress.toV6(); + } + } + + IPAddress netmask() const; + + int netmaskWeight() const; + + // some useful helper functions if we need to debug ip masks etc + static void printIP(std::string_view txt, IPAddress const& address); +}; + +class IPAllowList { + std::vector subnetList; + +public: + void addTrustedSubnet(std::string_view str) { subnetList.push_back(AuthAllowedSubnet::fromString(str)); } + + void addTrustedSubnet(AuthAllowedSubnet const& subnet) { subnetList.push_back(subnet); } + + std::vector const& subnets() const { return subnetList; } + + bool operator()(IPAddress address) const { + if (subnetList.empty()) { + return true; + } + for (auto const& subnet : subnetList) { + if (subnet(address)) { + return 
true; + } + } + return false; + } +}; + +#endif // FDBRPC_IP_ALLOW_LIST_H diff --git a/fdbrpc/LoadBalance.actor.h b/fdbrpc/LoadBalance.actor.h index 76326c48a4..39170d75c2 100644 --- a/fdbrpc/LoadBalance.actor.h +++ b/fdbrpc/LoadBalance.actor.h @@ -78,14 +78,14 @@ struct LoadBalancedReply { Optional getLoadBalancedReply(const LoadBalancedReply* reply); Optional getLoadBalancedReply(const void*); -ACTOR template +ACTOR template Future tssComparison(Req req, Future> fSource, Future> fTss, TSSEndpointData tssData, uint64_t srcEndpointId, Reference> ssTeam, - RequestStream Interface::*channel) { + RequestStream Interface::*channel) { state double startTime = now(); state Future>> fTssWithTimeout = timeout(fTss, FLOW_KNOBS->LOAD_BALANCE_TSS_TIMEOUT); state int finished = 0; @@ -157,7 +157,7 @@ Future tssComparison(Req req, state std::vector>> restOfTeamFutures; restOfTeamFutures.reserve(ssTeam->size() - 1); for (int i = 0; i < ssTeam->size(); i++) { - RequestStream const* si = &ssTeam->get(i, channel); + RequestStream const* si = &ssTeam->get(i, channel); if (si->getEndpoint().token.first() != srcEndpointId) { // don't re-request to SS we already have a response from resetReply(req); @@ -242,7 +242,7 @@ FDB_DECLARE_BOOLEAN_PARAM(AtMostOnce); FDB_DECLARE_BOOLEAN_PARAM(TriedAllOptions); // Stores state for a request made by the load balancer -template +template struct RequestData : NonCopyable { typedef ErrorOr Reply; @@ -257,12 +257,12 @@ struct RequestData : NonCopyable { // This is true once setupRequest is called, even though at that point the response is Never(). 
bool isValid() { return response.isValid(); } - static void maybeDuplicateTSSRequest(RequestStream const* stream, + static void maybeDuplicateTSSRequest(RequestStream const* stream, Request& request, QueueModel* model, Future ssResponse, Reference> alternatives, - RequestStream Interface::*channel) { + RequestStream Interface::*channel) { if (model) { // Send parallel request to TSS pair, if it exists Optional tssData = model->getTssData(stream->getEndpoint().token.first()); @@ -271,7 +271,7 @@ struct RequestData : NonCopyable { TEST(true); // duplicating request to TSS resetReply(request); // FIXME: optimize to avoid creating new netNotifiedQueue for each message - RequestStream tssRequestStream(tssData.get().endpoint); + RequestStream tssRequestStream(tssData.get().endpoint); Future> fTssResult = tssRequestStream.tryGetReply(request); model->addActor.send(tssComparison(request, ssResponse, @@ -288,11 +288,11 @@ struct RequestData : NonCopyable { void startRequest( double backoff, TriedAllOptions triedAllOptions, - RequestStream const* stream, + RequestStream const* stream, Request& request, QueueModel* model, Reference> alternatives, // alternatives and channel passed through for TSS check - RequestStream Interface::*channel) { + RequestStream Interface::*channel) { modelHolder = Reference(); requestStarted = false; @@ -438,18 +438,18 @@ struct RequestData : NonCopyable { // list of servers. // When model is set, load balance among alternatives in the same DC aims to balance request queue length on these // interfaces. If too many interfaces in the same DC are bad, try remote interfaces. 
-ACTOR template +ACTOR template Future loadBalance( Reference> alternatives, - RequestStream Interface::*channel, + RequestStream Interface::*channel, Request request = Request(), TaskPriority taskID = TaskPriority::DefaultPromiseEndpoint, AtMostOnce atMostOnce = AtMostOnce::False, // if true, throws request_maybe_delivered() instead of retrying automatically QueueModel* model = nullptr) { - state RequestData firstRequestData; - state RequestData secondRequestData; + state RequestData firstRequestData; + state RequestData secondRequestData; state Optional firstRequestEndpoint; state Future secondDelay = Never(); @@ -488,7 +488,7 @@ Future loadBalance( break; } - RequestStream const* thisStream = &alternatives->get(i, channel); + RequestStream const* thisStream = &alternatives->get(i, channel); if (!IFailureMonitor::failureMonitor().getState(thisStream->getEndpoint()).failed) { auto const& qd = model->getMeasurement(thisStream->getEndpoint().token.first()); if (now() > qd.failedUntil) { @@ -527,7 +527,7 @@ Future loadBalance( // go through all the remote servers again, since we may have // skipped it. 
for (int i = alternatives->countBest(); i < alternatives->size(); i++) { - RequestStream const* thisStream = &alternatives->get(i, channel); + RequestStream const* thisStream = &alternatives->get(i, channel); if (!IFailureMonitor::failureMonitor().getState(thisStream->getEndpoint()).failed) { auto const& qd = model->getMeasurement(thisStream->getEndpoint().token.first()); if (now() > qd.failedUntil) { @@ -574,7 +574,7 @@ Future loadBalance( if (ev.isEnabled()) { ev.log(); for (int alternativeNum = 0; alternativeNum < alternatives->size(); alternativeNum++) { - RequestStream const* thisStream = &alternatives->get(alternativeNum, channel); + RequestStream const* thisStream = &alternatives->get(alternativeNum, channel); TraceEvent(SevWarn, "LoadBalanceTooLongEndpoint") .detail("Addr", thisStream->getEndpoint().getPrimaryAddress()) .detail("Token", thisStream->getEndpoint().token) @@ -586,7 +586,7 @@ Future loadBalance( // Find an alternative, if any, that is not failed, starting with // nextAlt. This logic matters only if model == nullptr. Otherwise, the // bestAlt and nextAlt have been decided. 
- state RequestStream const* stream = nullptr; + state RequestStream const* stream = nullptr; for (int alternativeNum = 0; alternativeNum < alternatives->size(); alternativeNum++) { int useAlt = nextAlt; if (nextAlt == startAlt) @@ -724,9 +724,9 @@ Optional getBasicLoadBalancedReply(const BasicLoadBalanc Optional getBasicLoadBalancedReply(const void*); // A simpler version of LoadBalance that does not send second requests where the list of servers are always fresh -ACTOR template +ACTOR template Future basicLoadBalance(Reference> alternatives, - RequestStream Interface::*channel, + RequestStream Interface::*channel, Request request = Request(), TaskPriority taskID = TaskPriority::DefaultPromiseEndpoint, AtMostOnce atMostOnce = AtMostOnce::False) { @@ -749,7 +749,7 @@ Future basicLoadBalance(Reference> al state int useAlt; loop { // Find an alternative, if any, that is not failed, starting with nextAlt - state RequestStream const* stream = nullptr; + state RequestStream const* stream = nullptr; for (int alternativeNum = 0; alternativeNum < alternatives->size(); alternativeNum++) { useAlt = nextAlt; if (nextAlt == startAlt) diff --git a/fdbrpc/PerfMetric.h b/fdbrpc/PerfMetric.h index 256dbf50f6..80289fbf03 100644 --- a/fdbrpc/PerfMetric.h +++ b/fdbrpc/PerfMetric.h @@ -43,7 +43,7 @@ struct PerfMetric { std::string format_code() const { return m_format_code; } bool averaged() const { return m_averaged; } - PerfMetric withPrefix(const std::string& pre) { + PerfMetric withPrefix(const std::string& pre) const { return PerfMetric(pre + name(), value(), Averaged{ averaged() }, format_code()); } diff --git a/fdbrpc/TenantInfo.h b/fdbrpc/TenantInfo.h new file mode 100644 index 0000000000..3efec71845 --- /dev/null +++ b/fdbrpc/TenantInfo.h @@ -0,0 +1,73 @@ +/* + * TenantInfo.h + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2013-2022 Apple Inc. 
and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#ifndef FDBRPC_TENANTINFO_H_ +#define FDBRPC_TENANTINFO_H_ +#include "flow/Arena.h" +#include "fdbrpc/fdbrpc.h" +#include + +struct TenantInfoRef { + TenantInfoRef() {} + TenantInfoRef(Arena& p, StringRef toCopy) : tenantName(StringRef(p, toCopy)) {} + TenantInfoRef(Arena& p, TenantInfoRef toCopy) + : tenantName(toCopy.tenantName.present() ? Optional(StringRef(p, toCopy.tenantName.get())) + : Optional()) {} + // Empty tenant name means that the peer is trusted + Optional tenantName; + + bool operator<(TenantInfoRef const& other) const { + if (!other.tenantName.present()) { + return false; + } + if (!tenantName.present()) { + return true; + } + return tenantName.get() < other.tenantName.get(); + } + + template + void serialize(Ar& ar) { + serializer(ar, tenantName); + } +}; + +struct AuthorizedTenants : ReferenceCounted { + Arena arena; + std::set authorizedTenants; + bool trusted = false; +}; + +// TODO: receive and validate token instead +struct AuthorizationRequest { + constexpr static FileIdentifier file_identifier = 11499331; + + Arena arena; + VectorRef tenants; + ReplyPromise reply; + + template + void serialize(Ar& ar) { + serializer(ar, tenants, reply, arena); + } +}; + +#endif // FDBRPC_TENANTINFO_H_ diff --git a/fdbrpc/fdbrpc.h b/fdbrpc/fdbrpc.h index b12c41b934..bec32b6696 100644 --- a/fdbrpc/fdbrpc.h +++ b/fdbrpc/fdbrpc.h @@ -110,6 +110,8 
@@ struct NetSAV final : SAV, FlowReceiver, FastAllocated> { SAV::sendAndDelPromiseRef(message.get().asUnderlyingType()); } } + + bool isPublic() const override { return true; } }; template @@ -290,6 +292,8 @@ struct AcknowledgementReceiver final : FlowReceiver, FastAllocated message; reader.deserialize(message); @@ -337,6 +341,8 @@ struct NetNotifiedQueueWithAcknowledgements final : NotifiedQueue, acknowledgements.failures = tagError(FlowTransport::transport().loadedDisconnect(), operation_obsolete()); } + bool isPublic() const override { return true; } + void destroy() override { delete this; } void receive(ArenaObjectReader& reader) override { this->addPromiseRef(); @@ -642,10 +648,10 @@ struct serializable_traits> : std::true_type { } }; -template -struct NetNotifiedQueue final : NotifiedQueue, FlowReceiver, FastAllocated> { - using FastAllocated>::operator new; - using FastAllocated>::operator delete; +template +struct NetNotifiedQueue final : NotifiedQueue, FlowReceiver, FastAllocated> { + using FastAllocated>::operator new; + using FastAllocated>::operator delete; NetNotifiedQueue(int futures, int promises) : NotifiedQueue(futures, promises) {} NetNotifiedQueue(int futures, int promises, const Endpoint& remoteEndpoint) @@ -660,9 +666,10 @@ struct NetNotifiedQueue final : NotifiedQueue, FlowReceiver, FastAllocateddelPromiseRef(); } bool isStream() const override { return true; } + bool isPublic() const override { return IsPublic; } }; -template +template class RequestStream { public: // stream.send( request ) @@ -726,6 +733,9 @@ public: Future disc = makeDependent(IFailureMonitor::failureMonitor()).onDisconnectOrFailure(getEndpoint(taskID)); if (disc.isReady()) { + if (IFailureMonitor::failureMonitor().knownUnauthorized(getEndpoint(taskID))) { + return ErrorOr(unauthorized_attempt()); + } return ErrorOr(request_maybe_delivered()); } Reference peer = @@ -744,6 +754,9 @@ public: Future disc = 
makeDependent(IFailureMonitor::failureMonitor()).onDisconnectOrFailure(getEndpoint()); if (disc.isReady()) { + if (IFailureMonitor::failureMonitor().knownUnauthorized(getEndpoint())) { + return ErrorOr(unauthorized_attempt()); + } return ErrorOr(request_maybe_delivered()); } Reference peer = @@ -821,13 +834,13 @@ public: return getReplyUnlessFailedFor(ReplyPromise(), sustainedFailureDuration, sustainedFailureSlope); } - explicit RequestStream(const Endpoint& endpoint) : queue(new NetNotifiedQueue(0, 1, endpoint)) {} + explicit RequestStream(const Endpoint& endpoint) : queue(new NetNotifiedQueue(0, 1, endpoint)) {} FutureStream getFuture() const { queue->addFutureRef(); return FutureStream(queue); } - RequestStream() : queue(new NetNotifiedQueue(0, 1)) {} + RequestStream() : queue(new NetNotifiedQueue(0, 1)) {} explicit RequestStream(PeerCompatibilityPolicy policy) : RequestStream() { queue->setPeerCompatibilityPolicy(policy); } @@ -861,8 +874,8 @@ public: queue->makeWellKnownEndpoint(Endpoint::Token(-1, wlTokenID), taskID); } - bool operator==(const RequestStream& rhs) const { return queue == rhs.queue; } - bool operator!=(const RequestStream& rhs) const { return !(*this == rhs); } + bool operator==(const RequestStream& rhs) const { return queue == rhs.queue; } + bool operator!=(const RequestStream& rhs) const { return !(*this == rhs); } bool isEmpty() const { return !queue->isReady(); } uint32_t size() const { return queue->size(); } @@ -871,32 +884,37 @@ public: } private: - NetNotifiedQueue* queue; + NetNotifiedQueue* queue; }; -template -void save(Ar& ar, const RequestStream& value) { +template +using PrivateRequestStream = RequestStream; +template +using PublicRequestStream = RequestStream; + +template +void save(Ar& ar, const RequestStream& value) { auto const& ep = value.getEndpoint(); ar << ep; UNSTOPPABLE_ASSERT( ep.getPrimaryAddress().isValid()); // No serializing PromiseStreams on a client with no public address } -template -void load(Ar& ar, 
RequestStream& value) { +template +void load(Ar& ar, RequestStream& value) { Endpoint endpoint; ar >> endpoint; - value = RequestStream(endpoint); + value = RequestStream(endpoint); } -template -struct serializable_traits> : std::true_type { +template +struct serializable_traits> : std::true_type { template - static void serialize(Archiver& ar, RequestStream& stream) { + static void serialize(Archiver& ar, RequestStream& stream) { if constexpr (Archiver::isDeserializing) { Endpoint endpoint; serializer(ar, endpoint); - stream = RequestStream(endpoint); + stream = RequestStream(endpoint); } else { const auto& ep = stream.getEndpoint(); serializer(ar, ep); diff --git a/fdbrpc/genericactors.actor.h b/fdbrpc/genericactors.actor.h index 821c93d5fb..6226676ae1 100644 --- a/fdbrpc/genericactors.actor.h +++ b/fdbrpc/genericactors.actor.h @@ -32,8 +32,8 @@ #include "flow/Hostname.h" #include "flow/actorcompiler.h" // This must be the last #include. -ACTOR template -Future retryBrokenPromise(RequestStream to, Req request) { +ACTOR template +Future retryBrokenPromise(RequestStream to, Req request) { // Like to.getReply(request), except that a broken_promise exception results in retrying request immediately. // Suitable for use with well known endpoints, which are likely to return to existence after the other process // restarts. Not normally useful for ordinary endpoints, which conventionally are permanently destroyed after @@ -52,8 +52,8 @@ Future retryBrokenPromise(RequestStream to, Req request) { } } -ACTOR template -Future retryBrokenPromise(RequestStream to, Req request, TaskPriority taskID) { +ACTOR template +Future retryBrokenPromise(RequestStream to, Req request, TaskPriority taskID) { // Like to.getReply(request), except that a broken_promise exception results in retrying request immediately. // Suitable for use with well known endpoints, which are likely to return to existence after the other process // restarts. 
Not normally useful for ordinary endpoints, which conventionally are permanently destroyed after @@ -350,7 +350,11 @@ Future> waitValueOrSignal(Future value, try { choose { when(X x = wait(value)) { return x; } - when(wait(signal)) { return ErrorOr(request_maybe_delivered()); } + when(wait(signal)) { + return ErrorOr(IFailureMonitor::failureMonitor().knownUnauthorized(endpoint) + ? unauthorized_attempt() + : request_maybe_delivered()); + } } } catch (Error& e) { if (signal.isError()) { @@ -373,12 +377,31 @@ Future> waitValueOrSignal(Future value, ACTOR template Future sendCanceler(ReplyPromise reply, ReliablePacket* send, Endpoint endpoint) { + state bool didCancelReliable = false; try { - T t = wait(reply.getFuture()); - FlowTransport::transport().cancelReliable(send); - return t; + loop { + if (IFailureMonitor::failureMonitor().permanentlyFailed(endpoint)) { + FlowTransport::transport().cancelReliable(send); + didCancelReliable = true; + if (IFailureMonitor::failureMonitor().knownUnauthorized(endpoint)) { + throw unauthorized_attempt(); + } else { + wait(Never()); + } + } + choose { + when(T t = wait(reply.getFuture())) { + FlowTransport::transport().cancelReliable(send); + didCancelReliable = true; + return t; + } + when(wait(IFailureMonitor::failureMonitor().onStateChanged(endpoint))) {} + } + } } catch (Error& e) { - FlowTransport::transport().cancelReliable(send); + if (!didCancelReliable) { + FlowTransport::transport().cancelReliable(send); + } if (e.code() == error_code_broken_promise) { IFailureMonitor::failureMonitor().endpointNotFound(endpoint); } diff --git a/fdbrpc/sim2.actor.cpp b/fdbrpc/sim2.actor.cpp index 64817a5214..05e2f513ce 100644 --- a/fdbrpc/sim2.actor.cpp +++ b/fdbrpc/sim2.actor.cpp @@ -445,6 +445,7 @@ private: TraceEvent(SevError, "LeakedConnection", self->dbgid) .error(connection_leaked()) .detail("MyAddr", self->process->address) + .detail("IsPublic", self->process->address.isPublic()) .detail("PeerAddr", self->peerEndpoint) 
.detail("PeerId", self->peerId) .detail("Opened", self->opened); diff --git a/fdbrpc/simulator.h b/fdbrpc/simulator.h index 8f23db0400..72af7c372c 100644 --- a/fdbrpc/simulator.h +++ b/fdbrpc/simulator.h @@ -87,6 +87,7 @@ public: UID uid; ProtocolVersion protocolVersion; + bool excludeFromRestarts = false; std::vector childs; diff --git a/fdbserver/ApplyMetadataMutation.cpp b/fdbserver/ApplyMetadataMutation.cpp index cd9d15b958..90f987021f 100644 --- a/fdbserver/ApplyMetadataMutation.cpp +++ b/fdbserver/ApplyMetadataMutation.cpp @@ -118,7 +118,7 @@ private: KeyRangeMap* keyInfo = nullptr; KeyRangeMap* cacheInfo = nullptr; std::map* uid_applyMutationsData = nullptr; - RequestStream commit = RequestStream(); + PublicRequestStream commit = PublicRequestStream(); Database cx = Database(); NotifiedVersion* committedVersion = nullptr; std::map>* storageCache = nullptr; diff --git a/fdbserver/CMakeLists.txt b/fdbserver/CMakeLists.txt index 4e53302e46..df049163ac 100644 --- a/fdbserver/CMakeLists.txt +++ b/fdbserver/CMakeLists.txt @@ -198,6 +198,7 @@ set(FDBSERVER_SRCS workloads/ChangeFeeds.actor.cpp workloads/ClearSingleRange.actor.cpp workloads/ClientTransactionProfileCorrectness.actor.cpp + workloads/ClientWorkload.actor.cpp workloads/ClogSingleConnection.actor.cpp workloads/CommitBugCheck.actor.cpp workloads/ConfigIncrement.actor.cpp @@ -251,6 +252,7 @@ set(FDBSERVER_SRCS workloads/PhysicalShardMove.actor.cpp workloads/Ping.actor.cpp workloads/PopulateTPCC.actor.cpp + workloads/PrivateEndpoints.actor.cpp workloads/ProtocolVersion.actor.cpp workloads/PubSubMultiples.actor.cpp workloads/QueuePush.actor.cpp diff --git a/fdbserver/ClusterController.actor.cpp b/fdbserver/ClusterController.actor.cpp index a7b56ee3a2..21e42a5f5c 100644 --- a/fdbserver/ClusterController.actor.cpp +++ b/fdbserver/ClusterController.actor.cpp @@ -2981,7 +2981,8 @@ TEST_CASE("/fdbserver/clustercontroller/shouldTriggerRecoveryDueToDegradedServer 
testDbInfo.logSystemConfig.tLogs.push_back(remoteTLogSet); GrvProxyInterface proxyInterf; - proxyInterf.getConsistentReadVersion = RequestStream(Endpoint({ proxy }, testUID)); + proxyInterf.getConsistentReadVersion = + PublicRequestStream(Endpoint({ proxy }, testUID)); testDbInfo.client.grvProxies.push_back(proxyInterf); ResolverInterface resolverInterf; @@ -3090,11 +3091,12 @@ TEST_CASE("/fdbserver/clustercontroller/shouldTriggerFailoverDueToDegradedServer testDbInfo.logSystemConfig.tLogs.push_back(remoteTLogSet); GrvProxyInterface grvProxyInterf; - grvProxyInterf.getConsistentReadVersion = RequestStream(Endpoint({ proxy }, testUID)); + grvProxyInterf.getConsistentReadVersion = + PublicRequestStream(Endpoint({ proxy }, testUID)); testDbInfo.client.grvProxies.push_back(grvProxyInterf); CommitProxyInterface commitProxyInterf; - commitProxyInterf.commit = RequestStream(Endpoint({ proxy2 }, testUID)); + commitProxyInterf.commit = PublicRequestStream(Endpoint({ proxy2 }, testUID)); testDbInfo.client.commitProxies.push_back(commitProxyInterf); ResolverInterface resolverInterf; diff --git a/fdbserver/GrvProxyServer.actor.cpp b/fdbserver/GrvProxyServer.actor.cpp index d4b418d664..3925025ab7 100644 --- a/fdbserver/GrvProxyServer.actor.cpp +++ b/fdbserver/GrvProxyServer.actor.cpp @@ -239,7 +239,7 @@ struct GrvProxyData { GrvProxyStats stats; MasterInterface master; - RequestStream getConsistentReadVersion; + PublicRequestStream getConsistentReadVersion; Reference logSystem; Database cx; @@ -275,7 +275,7 @@ struct GrvProxyData { GrvProxyData(UID dbgid, MasterInterface master, - RequestStream getConsistentReadVersion, + PublicRequestStream getConsistentReadVersion, Reference const> db) : dbgid(dbgid), stats(dbgid), master(master), getConsistentReadVersion(getConsistentReadVersion), cx(openDBOnServer(db, TaskPriority::DefaultEndpoint, LockAware::True)), db(db), lastStartCommit(0), diff --git a/fdbserver/ProxyCommitData.actor.h b/fdbserver/ProxyCommitData.actor.h index 
0bef13a363..620e58622e 100644 --- a/fdbserver/ProxyCommitData.actor.h +++ b/fdbserver/ProxyCommitData.actor.h @@ -29,6 +29,9 @@ #include "fdbclient/Tenant.h" #include "fdbrpc/Stats.h" #include "fdbserver/Knobs.h" +#include "fdbserver/LogSystem.h" +#include "fdbserver/MasterInterface.h" +#include "fdbserver/ResolverInterface.h" #include "fdbserver/LogSystemDiskQueueAdapter.h" #include "flow/IRandom.h" @@ -193,8 +196,8 @@ struct ProxyCommitData { NotifiedVersion latestLocalCommitBatchResolving; NotifiedVersion latestLocalCommitBatchLogging; - RequestStream getConsistentReadVersion; - RequestStream commit; + PublicRequestStream getConsistentReadVersion; + PublicRequestStream commit; Database cx; Reference const> db; EventMetricHandle singleKeyMutationEvent; @@ -273,9 +276,9 @@ struct ProxyCommitData { ProxyCommitData(UID dbgid, MasterInterface master, - RequestStream getConsistentReadVersion, + PublicRequestStream getConsistentReadVersion, Version recoveryTransactionVersion, - RequestStream commit, + PublicRequestStream commit, Reference const> db, bool firstProxy) : dbgid(dbgid), commitBatchesMemBytesCount(0), diff --git a/fdbserver/SimulatedCluster.actor.cpp b/fdbserver/SimulatedCluster.actor.cpp index 3883dd50d3..518e72061e 100644 --- a/fdbserver/SimulatedCluster.actor.cpp +++ b/fdbserver/SimulatedCluster.actor.cpp @@ -26,6 +26,7 @@ #include #include "fdbrpc/Locality.h" #include "fdbrpc/simulator.h" +#include "fdbrpc/IPAllowList.h" #include "fdbclient/ClusterConnectionFile.h" #include "fdbclient/ClusterConnectionMemoryRecord.h" #include "fdbclient/DatabaseContext.h" @@ -520,6 +521,10 @@ ACTOR Future simulatedFDBDRebooter(ReferencerandomUniqueID(); state int cycles = 0; + state IPAllowList allowList; + + allowList.addTrustedSubnet("0.0.0.0/2"sv); + allowList.addTrustedSubnet("abcd::/16"sv); loop { auto waitTime = @@ -579,7 +584,8 @@ ACTOR Future simulatedFDBDRebooter(Reference> futures; @@ -2334,10 +2340,14 @@ ACTOR void setupAndRun(std::string dataFolder, state 
Standalone startingConfiguration; state int testerCount = 1; state TestConfig testConfig; + state IPAllowList allowList; testConfig.readFromConfig(testFile); g_simulator.hasDiffProtocolProcess = testConfig.startIncompatibleProcess; g_simulator.setDiffProtocol = false; + // Build simulator allow list + allowList.addTrustedSubnet("0.0.0.0/2"sv); + allowList.addTrustedSubnet("abcd::/16"sv); state bool allowDefaultTenant = testConfig.allowDefaultTenant; state bool allowDisablingTenants = testConfig.allowDisablingTenants; @@ -2382,7 +2392,7 @@ ACTOR void setupAndRun(std::string dataFolder, } // TODO (IPv6) Use IPv6? - wait(g_simulator.onProcess( + auto testSystem = g_simulator.newProcess("TestSystem", IPAddress(0x01010101), 1, @@ -2395,10 +2405,11 @@ ACTOR void setupAndRun(std::string dataFolder, ProcessClass(ProcessClass::TesterClass, ProcessClass::CommandLineSource), "", "", - currentProtocolVersion), - TaskPriority::DefaultYield)); + currentProtocolVersion); + testSystem->excludeFromRestarts = true; + wait(g_simulator.onProcess(testSystem, TaskPriority::DefaultYield)); Sim2FileSystem::newFileSystem(); - FlowTransport::createInstance(true, 1, WLTOKEN_RESERVED_COUNT); + FlowTransport::createInstance(true, 1, WLTOKEN_RESERVED_COUNT, &allowList); TEST(true); // Simulation start state Optional defaultTenant; diff --git a/fdbserver/fdbserver.actor.cpp b/fdbserver/fdbserver.actor.cpp index aa40c6abbb..d27d65d97e 100644 --- a/fdbserver/fdbserver.actor.cpp +++ b/fdbserver/fdbserver.actor.cpp @@ -35,6 +35,8 @@ #include #include +#include + #include "fdbclient/ActorLineageProfiler.h" #include "fdbclient/ClusterConnectionFile.h" #include "fdbclient/IKnobCollection.h" @@ -45,6 +47,7 @@ #include "fdbclient/WellKnownEndpoints.h" #include "fdbclient/SimpleIni.h" #include "fdbrpc/AsyncFileCached.actor.h" +#include "fdbrpc/IPAllowList.h" #include "fdbrpc/FlowProcess.actor.h" #include "fdbrpc/Net2FileSystem.h" #include "fdbrpc/PerfMetric.h" @@ -107,7 +110,8 @@ enum { OPT_DCID, 
OPT_MACHINE_CLASS, OPT_BUGGIFY, OPT_VERSION, OPT_BUILD_FLAGS, OPT_CRASHONERROR, OPT_HELP, OPT_NETWORKIMPL, OPT_NOBUFSTDOUT, OPT_BUFSTDOUTERR, OPT_TRACECLOCK, OPT_NUMTESTERS, OPT_DEVHELP, OPT_ROLLSIZE, OPT_MAXLOGS, OPT_MAXLOGSSIZE, OPT_KNOB, OPT_UNITTESTPARAM, OPT_TESTSERVERS, OPT_TEST_ON_SERVERS, OPT_METRICSCONNFILE, OPT_METRICSPREFIX, OPT_LOGGROUP, OPT_LOCALITY, OPT_IO_TRUST_SECONDS, OPT_IO_TRUST_WARN_ONLY, OPT_FILESYSTEM, OPT_PROFILER_RSS_SIZE, OPT_KVFILE, - OPT_TRACE_FORMAT, OPT_WHITELIST_BINPATH, OPT_BLOB_CREDENTIAL_FILE, OPT_CONFIG_PATH, OPT_USE_TEST_CONFIG_DB, OPT_FAULT_INJECTION, OPT_PROFILER, OPT_PRINT_SIMTIME, OPT_FLOW_PROCESS_NAME, OPT_FLOW_PROCESS_ENDPOINT + OPT_TRACE_FORMAT, OPT_WHITELIST_BINPATH, OPT_BLOB_CREDENTIAL_FILE, OPT_CONFIG_PATH, OPT_USE_TEST_CONFIG_DB, OPT_FAULT_INJECTION, OPT_PROFILER, OPT_PRINT_SIMTIME, + OPT_FLOW_PROCESS_NAME, OPT_FLOW_PROCESS_ENDPOINT, OPT_IP_TRUSTED_MASK }; CSimpleOpt::SOption g_rgOptions[] = { @@ -199,6 +203,7 @@ CSimpleOpt::SOption g_rgOptions[] = { { OPT_PRINT_SIMTIME, "--print-sim-time", SO_NONE }, { OPT_FLOW_PROCESS_NAME, "--process-name", SO_REQ_SEP }, { OPT_FLOW_PROCESS_ENDPOINT, "--process-endpoint", SO_REQ_SEP }, + { OPT_IP_TRUSTED_MASK, "--trusted-subnet-", SO_REQ_SEP }, #ifndef TLS_DISABLED TLS_OPTION_FLAGS @@ -311,7 +316,7 @@ UID getSharedMemoryMachineId() { std::string sharedMemoryIdentifier = "fdbserver_shared_memory_id"; loop { try { - // "0" is the default parameter "addr" + // "0" is the default netPrefix "addr" boost::interprocess::managed_shared_memory segment( boost::interprocess::open_or_create, sharedMemoryIdentifier.c_str(), 1000, 0, p.permission); machineId = segment.find_or_construct("machineId")(newUID); @@ -1041,6 +1046,7 @@ struct CLIOptions { std::string flowProcessName; Endpoint flowProcessEndpoint; bool printSimTime = false; + IPAllowList allowList; static CLIOptions parseArgs(int argc, char* argv[]) { CLIOptions opts; @@ -1167,6 +1173,15 @@ private: localities.set(key, 
Standalone(std::string(args.OptionArg()))); break; } + case OPT_IP_TRUSTED_MASK: { + Optional subnetKey = extractPrefixedArgument("--trusted-subnet", args.OptionSyntax()); + if (!subnetKey.present()) { + fprintf(stderr, "ERROR: unable to parse locality key '%s'\n", args.OptionSyntax()); + flushAndExit(FDB_EXIT_ERROR); + } + allowList.addTrustedSubnet(args.OptionArg()); + break; + } case OPT_VERSION: printVersion(); flushAndExit(FDB_EXIT_SUCCESS); @@ -1853,7 +1868,7 @@ int main(int argc, char* argv[]) { } else { g_network = newNet2(opts.tlsConfig, opts.useThreadPool, true); g_network->addStopCallback(Net2FileSystem::stop); - FlowTransport::createInstance(false, 1, WLTOKEN_RESERVED_COUNT); + FlowTransport::createInstance(false, 1, WLTOKEN_RESERVED_COUNT, &opts.allowList); opts.buildNetwork(argv[0]); const bool expectsPublicAddress = (role == ServerRole::FDBD || role == ServerRole::NetworkTestServer || diff --git a/fdbserver/tester.actor.cpp b/fdbserver/tester.actor.cpp index 10aefd4e96..4abd71c494 100644 --- a/fdbserver/tester.actor.cpp +++ b/fdbserver/tester.actor.cpp @@ -48,7 +48,7 @@ WorkloadContext::WorkloadContext() {} WorkloadContext::WorkloadContext(const WorkloadContext& r) : options(r.options), clientId(r.clientId), clientCount(r.clientCount), sharedRandomNumber(r.sharedRandomNumber), - dbInfo(r.dbInfo) {} + dbInfo(r.dbInfo), ccr(r.ccr) {} WorkloadContext::~WorkloadContext() {} @@ -326,34 +326,51 @@ struct CompoundWorkload : TestWorkload { } return allTrue(all); } - void getMetrics(std::vector& m) override { - for (int w = 0; w < workloads.size(); w++) { + + ACTOR static Future> getMetrics(CompoundWorkload* self) { + state std::vector>> results; + for (int w = 0; w < self->workloads.size(); w++) { std::vector p; - workloads[w]->getMetrics(p); - for (int i = 0; i < p.size(); i++) - m.push_back(p[i].withPrefix(workloads[w]->description() + ".")); + results.push_back(self->workloads[w]->getMetrics()); } + wait(waitForAll(results)); + std::vector res; + for (int 
i = 0; i < results.size(); ++i) { + auto const& p = results[i].get(); + for (auto const& m : p) { + res.push_back(m.withPrefix(self->workloads[i]->description() + ".")); + } + } + return res; } + + Future> getMetrics() override { return getMetrics(this); } double getCheckTimeout() const override { double m = 0; for (int w = 0; w < workloads.size(); w++) m = std::max(workloads[w]->getCheckTimeout(), m); return m; } + + void getMetrics(std::vector&) override { ASSERT(false); } }; -Reference getWorkloadIface(WorkloadRequest work, - VectorRef options, - Reference const> dbInfo) { - Value testName = getOption(options, LiteralStringRef("testName"), LiteralStringRef("no-test-specified")); +ACTOR Future> getWorkloadIface(WorkloadRequest work, + Reference ccr, + VectorRef options, + Reference const> dbInfo) { + state Reference workload; + state Value testName = getOption(options, LiteralStringRef("testName"), LiteralStringRef("no-test-specified")); WorkloadContext wcx; wcx.clientId = work.clientId; wcx.clientCount = work.clientCount; + wcx.ccr = ccr; wcx.dbInfo = dbInfo; wcx.options = options; wcx.sharedRandomNumber = work.sharedRandomNumber; - auto workload = IWorkloadFactory::create(testName.toString(), wcx); + workload = IWorkloadFactory::create(testName.toString(), wcx); + wait(workload->initialized()); auto unconsumedOptions = checkAllOptionsConsumed(workload ? 
workload->options : VectorRef()); if (!workload || unconsumedOptions.size()) { @@ -378,24 +395,33 @@ Reference getWorkloadIface(WorkloadRequest work, return workload; } -Reference getWorkloadIface(WorkloadRequest work, Reference const> dbInfo) { +ACTOR Future> getWorkloadIface(WorkloadRequest work, + Reference ccr, + Reference const> dbInfo) { + state WorkloadContext wcx; + state std::vector>> ifaces; if (work.options.size() < 1) { TraceEvent(SevError, "TestCreationError").detail("Reason", "No options provided"); fprintf(stderr, "ERROR: No options were provided for workload.\n"); throw test_specification_invalid(); } - if (work.options.size() == 1) - return getWorkloadIface(work, work.options[0], dbInfo); + if (work.options.size() == 1) { + Reference res = wait(getWorkloadIface(work, ccr, work.options[0], dbInfo)); + return res; + } - WorkloadContext wcx; wcx.clientId = work.clientId; wcx.clientCount = work.clientCount; wcx.sharedRandomNumber = work.sharedRandomNumber; // FIXME: Other stuff not filled in; why isn't this constructed here and passed down to the other // getWorkloadIface()? 
+ for (int i = 0; i < work.options.size(); i++) { + ifaces.push_back(getWorkloadIface(work, ccr, work.options[i], dbInfo)); + } + wait(waitForAll(ifaces)); auto compound = makeReference(wcx); for (int i = 0; i < work.options.size(); i++) { - compound->add(getWorkloadIface(work, work.options[i], dbInfo)); + compound->add(ifaces[i].getValue()); } return compound; } @@ -494,7 +520,7 @@ ACTOR Future testDatabaseLiveness(Database cx, try { state double start = now(); auto traceMsg = "PingingDatabaseLiveness_" + context; - TraceEvent(traceMsg.c_str()); + TraceEvent(traceMsg.c_str()).log(); wait(timeoutError(pingDatabase(cx), databasePingDelay)); double pingTime = now() - start; ASSERT(pingTime > 0); @@ -607,10 +633,9 @@ ACTOR Future runWorkloadAsync(Database cx, when(ReplyPromise> req = waitNext(workIface.metrics.getFuture())) { state ReplyPromise> s_req = req; try { - std::vector m; - workload->getMetrics(m); + std::vector m = wait(workload->getMetrics()); TraceEvent("WorkloadSendMetrics", workIface.id()).detail("Count", m.size()); - req.send(m); + s_req.send(m); } catch (Error& e) { if (e.code() == error_code_please_reboot || e.code() == error_code_please_reboot_delete) throw; @@ -649,7 +674,7 @@ ACTOR Future testerServerWorkload(WorkloadRequest work, // add test for "done" ? 
TraceEvent("WorkloadReceived", workIface.id()).detail("Title", work.title); - auto workload = getWorkloadIface(work, dbInfo); + Reference workload = wait(getWorkloadIface(work, ccr, dbInfo)); if (!workload) { TraceEvent("TestCreationError").detail("Reason", "Workload could not be created"); fprintf(stderr, "ERROR: The workload could not be created.\n"); @@ -704,6 +729,9 @@ ACTOR Future testerServerCore(TesterInterface interf, ACTOR Future clearData(Database cx) { state Transaction tr(cx); + state UID debugID = debugRandom()->randomUniqueID(); + TraceEvent("TesterClearingDatabaseStart", debugID).log(); + tr.debugTransaction(debugID); loop { try { // This transaction needs to be self-conflicting, but not conflict consistently with @@ -712,10 +740,10 @@ ACTOR Future clearData(Database cx) { tr.makeSelfConflicting(); wait(success(tr.getReadVersion())); // required since we use addReadConflictRange but not get wait(tr.commit()); - TraceEvent("TesterClearingDatabase").detail("AtVersion", tr.getCommittedVersion()); + TraceEvent("TesterClearingDatabase", debugID).detail("AtVersion", tr.getCommittedVersion()); break; } catch (Error& e) { - TraceEvent(SevWarn, "TesterClearingDatabaseError").error(e); + TraceEvent(SevWarn, "TesterClearingDatabaseError", debugID).error(e); wait(tr.onError(e)); } } diff --git a/fdbserver/worker.actor.cpp b/fdbserver/worker.actor.cpp index df59a5794b..eb90bbcb6d 100644 --- a/fdbserver/worker.actor.cpp +++ b/fdbserver/worker.actor.cpp @@ -762,14 +762,14 @@ TEST_CASE("/fdbserver/worker/addressInDbAndPrimaryDc") { NetworkAddress grvProxyAddress(IPAddress(0x26262626), 1); GrvProxyInterface grvProxyInterf; grvProxyInterf.getConsistentReadVersion = - RequestStream(Endpoint({ grvProxyAddress }, UID(1, 2))); + PublicRequestStream(Endpoint({ grvProxyAddress }, UID(1, 2))); testDbInfo.client.grvProxies.push_back(grvProxyInterf); ASSERT(addressInDbAndPrimaryDc(grvProxyAddress, makeReference>(testDbInfo))); NetworkAddress 
commitProxyAddress(IPAddress(0x37373737), 1); CommitProxyInterface commitProxyInterf; commitProxyInterf.commit = - RequestStream(Endpoint({ commitProxyAddress }, UID(1, 2))); + PublicRequestStream(Endpoint({ commitProxyAddress }, UID(1, 2))); testDbInfo.client.commitProxies.push_back(commitProxyInterf); ASSERT(addressInDbAndPrimaryDc(commitProxyAddress, makeReference>(testDbInfo))); diff --git a/fdbserver/workloads/ClientWorkload.actor.cpp b/fdbserver/workloads/ClientWorkload.actor.cpp new file mode 100644 index 0000000000..803bdf04b7 --- /dev/null +++ b/fdbserver/workloads/ClientWorkload.actor.cpp @@ -0,0 +1,241 @@ +/* + * ClientWorkload.actor.cpp + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2013-2022 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "fdbserver/ServerDBInfo.actor.h" +#include "fdbserver/workloads/workloads.actor.h" +#include "fdbrpc/simulator.h" + +#include + +#include "flow/actorcompiler.h" // has to be last include + +class WorkloadProcessState { + IPAddress childAddress; + std::string processName; + Future processActor; + Promise init; + + WorkloadProcessState(int clientId) : clientId(clientId) { processActor = processStart(this); } + + ~WorkloadProcessState() { + TraceEvent("ShutdownClientForWorkload", id).log(); + g_simulator.destroyProcess(childProcess); + } + + ACTOR static Future initializationDone(WorkloadProcessState* self, ISimulator::ProcessInfo* parent) { + wait(g_simulator.onProcess(parent, TaskPriority::DefaultYield)); + self->init.send(Void()); + wait(Never()); + ASSERT(false); // does not happen + return Void(); + } + + ACTOR static Future processStart(WorkloadProcessState* self) { + state ISimulator::ProcessInfo* parent = g_simulator.getCurrentProcess(); + state std::vector> futures; + if (parent->address.isV6()) { + self->childAddress = + IPAddress::parse(fmt::format("2001:fdb1:fdb2:fdb3:fdb4:fdb5:fdb6:{:04x}", self->clientId + 2)).get(); + } else { + self->childAddress = IPAddress::parse(fmt::format("192.168.0.{}", self->clientId + 2)).get(); + } + self->processName = fmt::format("TestClient{}", self->clientId); + Standalone newZoneId(deterministicRandom()->randomUniqueID().toString()); + auto locality = LocalityData(Optional>(), newZoneId, newZoneId, parent->locality.dcId()); + auto dataFolder = joinPath(popPath(parent->dataFolder), deterministicRandom()->randomUniqueID().toString()); + platform::createDirectory(dataFolder); + TraceEvent("StartingClientWorkloadProcess", self->id) + .detail("Name", self->processName) + .detail("Address", self->childAddress); + self->childProcess = g_simulator.newProcess(self->processName.c_str(), + self->childAddress, + 1, + parent->address.isTLS(), + 1, + locality, + ProcessClass(ProcessClass::TesterClass, 
ProcessClass::AutoSource), + dataFolder.c_str(), + parent->coordinationFolder, + parent->protocolVersion); + self->childProcess->excludeFromRestarts = true; + wait(g_simulator.onProcess(self->childProcess, TaskPriority::DefaultYield)); + try { + FlowTransport::createInstance(true, 1, WLTOKEN_RESERVED_COUNT); + Sim2FileSystem::newFileSystem(); + auto addr = g_simulator.getCurrentProcess()->address; + futures.push_back(FlowTransport::transport().bind(addr, addr)); + futures.push_back(success((self->childProcess->onShutdown()))); + TraceEvent("ClientWorkloadProcessInitialized", self->id).log(); + futures.push_back(initializationDone(self, parent)); + wait(waitForAny(futures)); + } catch (Error& e) { + if (e.code() == error_code_actor_cancelled) { + return Void(); + } + ASSERT(false); + } + ASSERT(false); + return Void(); + } + + static std::vector& states() { + static std::vector res; + return res; + } + +public: + static WorkloadProcessState* instance(int clientId) { + states().resize(std::max(states().size(), size_t(clientId + 1)), nullptr); + auto& res = states()[clientId]; + if (res == nullptr) { + res = new WorkloadProcessState(clientId); + } + return res; + } + + Future initialized() const { return init.getFuture(); } + + UID id = deterministicRandom()->randomUniqueID(); + int clientId; + ISimulator::ProcessInfo* childProcess; +}; + +struct WorkloadProcess { + WorkloadProcessState* processState; + UID id = deterministicRandom()->randomUniqueID(); + Database cx; + Future databaseOpened; + Reference child; + std::string desc; + + void createDatabase(ClientWorkload::CreateWorkload const& childCreator, WorkloadContext const& wcx) { + try { + child = childCreator(wcx); + TraceEvent("ClientWorkloadOpenDatabase", id).detail("ClusterFileLocation", child->ccr->getLocation()); + cx = Database::createDatabase(child->ccr, -1); + desc = child->description(); + } catch (Error&) { + throw; + } catch (...) 
{ + ASSERT(false); + } + } + + ACTOR static Future openDatabase(WorkloadProcess* self, + ClientWorkload::CreateWorkload childCreator, + WorkloadContext wcx) { + state ISimulator::ProcessInfo* parent = g_simulator.getCurrentProcess(); + state Optional err; + wcx.dbInfo = Reference const>(); + wait(self->processState->initialized()); + wait(g_simulator.onProcess(self->childProcess(), TaskPriority::DefaultYield)); + try { + self->createDatabase(childCreator, wcx); + } catch (Error& e) { + ASSERT(e.code() != error_code_actor_cancelled); + err = e; + } + wait(g_simulator.onProcess(parent, TaskPriority::DefaultYield)); + if (err.present()) { + throw err.get(); + } + return Void(); + } + + ISimulator::ProcessInfo* childProcess() { return processState->childProcess; } + + int clientId() const { return processState->clientId; } + + WorkloadProcess(ClientWorkload::CreateWorkload const& childCreator, WorkloadContext const& wcx) + : processState(WorkloadProcessState::instance(wcx.clientId)) { + TraceEvent("StartingClinetWorkload", id).detail("OnClientProcess", processState->id); + databaseOpened = openDatabase(this, childCreator, wcx); + } + + ACTOR static void destroy(WorkloadProcess* self) { + state ISimulator::ProcessInfo* parent = g_simulator.getCurrentProcess(); + wait(g_simulator.onProcess(self->childProcess(), TaskPriority::DefaultYield)); + delete self; + wait(g_simulator.onProcess(parent, TaskPriority::DefaultYield)); + } + + std::string description() { return desc; } + + ACTOR template + Future runActor(WorkloadProcess* self, Optional defaultTenant, Fun f) { + state Optional err; + state Ret res; + state ISimulator::ProcessInfo* parent = g_simulator.getCurrentProcess(); + wait(self->databaseOpened); + wait(g_simulator.onProcess(self->childProcess(), TaskPriority::DefaultYield)); + self->cx->defaultTenant = defaultTenant; + try { + Ret r = wait(f(self->cx)); + res = r; + } catch (Error& e) { + if (e.code() == error_code_actor_cancelled) { + 
ASSERT(g_simulator.getCurrentProcess() == parent); + throw; + } + err = e; + } + wait(g_simulator.onProcess(parent, TaskPriority::DefaultYield)); + if (err.present()) { + throw err.get(); + } + return res; + } +}; + +ClientWorkload::ClientWorkload(CreateWorkload const& childCreator, WorkloadContext const& wcx) + : TestWorkload(wcx), impl(new WorkloadProcess(childCreator, wcx)) {} + +ClientWorkload::~ClientWorkload() { + WorkloadProcess::destroy(impl); +} + +std::string ClientWorkload::description() const { + return impl->description(); +} + +Future ClientWorkload::initialized() { + return impl->databaseOpened; +} + +Future ClientWorkload::setup(Database const& cx) { + return impl->runActor(impl, cx->defaultTenant, [this](Database const& db) { return impl->child->setup(db); }); +} +Future ClientWorkload::start(Database const& cx) { + return impl->runActor(impl, cx->defaultTenant, [this](Database const& db) { return impl->child->start(db); }); +} +Future ClientWorkload::check(Database const& cx) { + return impl->runActor(impl, cx->defaultTenant, [this](Database const& db) { return impl->child->check(db); }); +} +Future> ClientWorkload::getMetrics() { + return impl->runActor>( + impl, Optional(), [this](Database const& db) { return impl->child->getMetrics(); }); +} +void ClientWorkload::getMetrics(std::vector& m) { + ASSERT(false); +} + +double ClientWorkload::getCheckTimeout() const { + return impl->child->getCheckTimeout(); +} diff --git a/fdbserver/workloads/Cycle.actor.cpp b/fdbserver/workloads/Cycle.actor.cpp index 4e17e82a88..23ce15858c 100644 --- a/fdbserver/workloads/Cycle.actor.cpp +++ b/fdbserver/workloads/Cycle.actor.cpp @@ -268,4 +268,4 @@ struct CycleWorkload : TestWorkload { } }; -WorkloadFactory CycleWorkloadFactory("Cycle"); +WorkloadFactory CycleWorkloadFactory("Cycle", true); diff --git a/fdbserver/workloads/PrivateEndpoints.actor.cpp b/fdbserver/workloads/PrivateEndpoints.actor.cpp new file mode 100644 index 0000000000..5c27adf058 --- /dev/null +++ 
b/fdbserver/workloads/PrivateEndpoints.actor.cpp @@ -0,0 +1,142 @@ +/* + * PrivateEndpoints.actor.cpp + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2013-2022 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "fdbserver/workloads/workloads.actor.h" + +#include "flow/actorcompiler.h" // has to be last include + +namespace { + +struct PrivateEndpoints : TestWorkload { + static constexpr const char* WorkloadName = "PrivateEndpoints"; + bool success = true; + int numSuccesses = 0; + double startAfter; + double runFor; + + std::vector(Reference> const&)>> testFunctions; + + template + static Optional getRandom(std::vector const& v) { + if (v.empty()) { + return Optional(); + } else { + return deterministicRandom()->randomChoice(v); + } + } + + template + static Optional getInterface(Reference> const& clientDBInfo) { + if constexpr (std::is_same_v) { + return getRandom(clientDBInfo->get().grvProxies); + } else if constexpr (std::is_same_v) { + return getRandom(clientDBInfo->get().commitProxies); + } else { + ASSERT(false); // don't know how to handle this type + } + } + + ACTOR template + static Future assumeFailure(Future f) { + try { + T t = wait(f); + (void)t; + ASSERT(false); + } catch (Error& e) { + if (e.code() == error_code_actor_cancelled) { + throw; + } else if (e.code() == error_code_unauthorized_attempt) { + 
TraceEvent("SuccessPrivateEndpoint").log(); + } else if (e.code() == error_code_request_maybe_delivered) { + // this is also fine, because even when calling private endpoints + // we might see connection failures + TraceEvent("SuccessRequestMaybeDelivered").log(); + } else { + TraceEvent(SevError, "WrongErrorCode").error(e); + } + } + return Void(); + } + + template + void addTestFor(RequestStream I::*channel) { + testFunctions.push_back([channel](Reference> const& clientDBInfo) { + auto optintf = getInterface(clientDBInfo); + if (!optintf.present()) { + return clientDBInfo->onChange(); + } + RequestStream s = optintf.get().*channel; + RT req; + return assumeFailure(deterministicRandom()->coinflip() ? throwErrorOr(s.tryGetReply(req)) + : s.getReply(req)); + }); + } + + explicit PrivateEndpoints(WorkloadContext const& wcx) : TestWorkload(wcx) { + startAfter = getOption(options, "startAfter"_sr, 10.0); + runFor = getOption(options, "runFor"_sr, 10.0); + addTestFor(&GrvProxyInterface::waitFailure); + addTestFor(&GrvProxyInterface::getHealthMetrics); + addTestFor(&CommitProxyInterface::getStorageServerRejoinInfo); + addTestFor(&CommitProxyInterface::waitFailure); + addTestFor(&CommitProxyInterface::txnState); + addTestFor(&CommitProxyInterface::getHealthMetrics); + addTestFor(&CommitProxyInterface::proxySnapReq); + addTestFor(&CommitProxyInterface::exclusionSafetyCheckReq); + addTestFor(&CommitProxyInterface::getDDMetrics); + } + std::string description() const override { return WorkloadName; } + Future start(Database const& cx) override { return _start(this, cx); } + Future check(Database const& cx) override { return success; } + void getMetrics(std::vector& m) override { + m.emplace_back("Successes", double(numSuccesses), Averaged::True); + } + + ACTOR static Future _start(PrivateEndpoints* self, Database cx) { + state Reference> clientInfo = cx->clientInfo; + state Future end; + TraceEvent("PrivateEndpointTestStartWait").detail("WaitTime", self->startAfter).log(); + 
wait(delay(self->startAfter)); + TraceEvent("PrivateEndpointTestStart").detail("RunFor", self->runFor).log(); + end = delay(self->runFor); + try { + loop { + auto testFuture = deterministicRandom()->randomChoice(self->testFunctions)(cx->clientInfo); + choose { + when(wait(end)) { + TraceEvent("PrivateEndpointTestDone").log(); + return Void(); + } + when(wait(testFuture)) { ++self->numSuccesses; } + } + wait(delay(0.2)); + } + } catch (Error& e) { + TraceEvent(SevError, "PrivateEndpointTestError").errorUnsuppressed(e); + ASSERT(false); + } + UNREACHABLE(); + return Void(); + } +}; + +} // namespace + +WorkloadFactory PrivateEndpointsFactory(PrivateEndpoints::WorkloadName, true); diff --git a/fdbserver/workloads/SaveAndKill.actor.cpp b/fdbserver/workloads/SaveAndKill.actor.cpp index 3f3aa66f4c..7c61b19ebe 100644 --- a/fdbserver/workloads/SaveAndKill.actor.cpp +++ b/fdbserver/workloads/SaveAndKill.actor.cpp @@ -89,7 +89,7 @@ struct SaveAndKillWorkload : TestWorkload { for (const auto& [_, process] : allProcessesMap) { std::string machineId = printable(process->locality.machineId()); const char* machineIdString = machineId.c_str(); - if (strcmp(process->name, "TestSystem") != 0) { + if (!process->excludeFromRestarts) { if (machines.find(machineId) == machines.end()) { machines.insert(std::pair(machineId, 1)); ini.SetValue("META", format("%d", j).c_str(), machineIdString); diff --git a/fdbserver/workloads/workloads.actor.h b/fdbserver/workloads/workloads.actor.h index c7c0557b68..bdbbd5707c 100644 --- a/fdbserver/workloads/workloads.actor.h +++ b/fdbserver/workloads/workloads.actor.h @@ -30,7 +30,10 @@ #include "fdbserver/KnobProtectiveGroups.h" #include "fdbserver/TesterInterface.actor.h" #include "fdbrpc/simulator.h" -#include "flow/actorcompiler.h" + +#include + +#include "flow/actorcompiler.h" // has to be last import /* * Gets an Value from a list of key/value pairs, using a default value if the key is not present. 
@@ -51,6 +54,7 @@ struct WorkloadContext { int clientId, clientCount; int64_t sharedRandomNumber; Reference const> dbInfo; + Reference ccr; WorkloadContext(); WorkloadContext(const WorkloadContext&); @@ -69,15 +73,40 @@ struct TestWorkload : NonCopyable, WorkloadContext, ReferenceCounted initialized() { return Void(); } virtual std::string description() const = 0; virtual Future setup(Database const& cx) { return Void(); } virtual Future start(Database const& cx) = 0; virtual Future check(Database const& cx) = 0; - virtual void getMetrics(std::vector& m) = 0; + virtual Future> getMetrics() { + std::vector result; + getMetrics(result); + return result; + } virtual double getCheckTimeout() const { return 3000; } enum WorkloadPhase { SETUP = 1, EXECUTION = 2, CHECK = 4, METRICS = 8 }; + +private: + virtual void getMetrics(std::vector& m) = 0; +}; + +struct WorkloadProcess; +struct ClientWorkload : TestWorkload { + WorkloadProcess* impl; + using CreateWorkload = std::function(WorkloadContext const&)>; + ClientWorkload(CreateWorkload const& childCreator, WorkloadContext const& wcx); + ~ClientWorkload(); + Future initialized() override; + std::string description() const override; + Future setup(Database const& cx) override; + Future start(Database const& cx) override; + Future check(Database const& cx) override; + void getMetrics(std::vector& m) override; + Future> getMetrics() override; + + double getCheckTimeout() const override; }; struct KVWorkload : TestWorkload { @@ -122,8 +151,17 @@ struct IWorkloadFactory : ReferenceCounted { template struct WorkloadFactory : IWorkloadFactory { - WorkloadFactory(const char* name) { factories()[name] = Reference::addRef(this); } - Reference create(WorkloadContext const& wcx) override { return makeReference(wcx); } + bool asClient; + WorkloadFactory(const char* name, bool asClient = false) : asClient(asClient) { + factories()[name] = Reference::addRef(this); + } + Reference create(WorkloadContext const& wcx) override { + if 
(g_network->isSimulated() && asClient) { + return makeReference( + [](WorkloadContext const& wcx) { return makeReference(wcx); }, wcx); + } + return makeReference(wcx); + } }; #define REGISTER_WORKLOAD(classname) WorkloadFactory classname##WorkloadFactory(#classname) diff --git a/flow/CMakeLists.txt b/flow/CMakeLists.txt index 5cd37810b5..a2246df1a1 100644 --- a/flow/CMakeLists.txt +++ b/flow/CMakeLists.txt @@ -138,6 +138,7 @@ target_link_libraries(flow PUBLIC fmt::fmt) add_flow_target(STATIC_LIBRARY NAME flow_sampling SRCS ${FLOW_SRCS}) target_link_libraries(flow_sampling PRIVATE stacktrace) +target_link_libraries(flow_sampling PUBLIC fmt::fmt) target_compile_definitions(flow_sampling PRIVATE -DENABLE_SAMPLING) if(WIN32) add_dependencies(flow_sampling_actors flow_actors) diff --git a/flow/ObjectSerializer.h b/flow/ObjectSerializer.h index 84937e31de..bd4cfe33e0 100644 --- a/flow/ObjectSerializer.h +++ b/flow/ObjectSerializer.h @@ -24,10 +24,16 @@ #include "flow/flat_buffers.h" #include "flow/ProtocolVersion.h" +#include + +using ContextVariableMap = std::unordered_map; + template struct LoadContext { Ar* ar; + LoadContext(Ar* ar) : ar(ar) {} + Arena& arena() { return ar->arena(); } ProtocolVersion protocolVersion() const { return ar->protocolVersion(); } @@ -68,20 +74,23 @@ struct SaveContext { template class _ObjectReader { protected: - ProtocolVersion mProtocolVersion; + Optional mProtocolVersion; + std::shared_ptr variables; public: - ProtocolVersion protocolVersion() const { return mProtocolVersion; } + ProtocolVersion protocolVersion() const { return mProtocolVersion.get(); } void setProtocolVersion(ProtocolVersion v) { mProtocolVersion = v; } + void setContextVariableMap(std::shared_ptr const& cvm) { variables = cvm; } template void deserialize(FileIdentifier file_identifier, Items&... 
items) { - const uint8_t* data = static_cast(this)->data(); LoadContext context(static_cast(this)); + const uint8_t* data = static_cast(this)->data(); if (read_file_identifier(data) != file_identifier) { // Some file identifiers are changed in 7.0, so file identifier mismatches // are expected during a downgrade from 7.0 to 6.3 - bool expectMismatch = mProtocolVersion >= ProtocolVersion(0x0FDB00B070000000LL); + bool expectMismatch = mProtocolVersion.get() >= ProtocolVersion(0x0FDB00B070000000LL) && + currentProtocolVersion < ProtocolVersion(0x0FDB00B070000000LL); { TraceEvent te(expectMismatch ? SevInfo : SevError, "MismatchedFileIdentifier"); if (expectMismatch) { @@ -100,6 +109,24 @@ public: void deserialize(Item& item) { deserialize(FileIdentifierFor::value, item); } + + template + bool variable(std::string_view name, T* val) { + auto p = variables->insert(std::make_pair(name, val)); + return p.second; + } + + template + T& variable(std::string_view name) { + auto res = variables->at(name); + return *reinterpret_cast(res); + } + + template + T const& variable(std::string_view name) const { + auto res = variables->at(name); + return *reinterpret_cast(res); + } }; class ObjectReader : public _ObjectReader { diff --git a/flow/error_definitions.h b/flow/error_definitions.h index 1a054ce43e..a46e8dc796 100755 --- a/flow/error_definitions.h +++ b/flow/error_definitions.h @@ -311,6 +311,10 @@ ERROR( encrypt_invalid_id, 2706, "Invalid encryption domainId or encryption ciph ERROR( unknown_error, 4000, "An unknown error occurred" ) // C++ exception not of type Error ERROR( internal_error, 4100, "An internal error occurred" ) ERROR( not_implemented, 4200, "Not implemented yet" ) + +// 6xxx Authorization and authentication error codes +ERROR( permission_denied, 6000, "Client tried to access unauthorized data" ) +ERROR( unauthorized_attempt, 6001, "A untrusted client tried to send a message to a private endpoint" ) // clang-format on #undef ERROR diff --git 
a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d7d03858b1..e8e16777f0 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -161,6 +161,7 @@ if(WITH_PYTHON) add_fdb_test(TEST_FILES fast/MoveKeysCycle.toml) add_fdb_test(TEST_FILES fast/MutationLogReaderCorrectness.toml) add_fdb_test(TEST_FILES fast/GetMappedRange.toml) + add_fdb_test(TEST_FILES fast/PrivateEndpoints.toml) add_fdb_test(TEST_FILES fast/ProtocolVersion.toml) add_fdb_test(TEST_FILES fast/RandomSelector.toml) add_fdb_test(TEST_FILES fast/RandomUnitTests.toml) diff --git a/tests/fast/PrivateEndpoints.toml b/tests/fast/PrivateEndpoints.toml new file mode 100644 index 0000000000..67be2c6b01 --- /dev/null +++ b/tests/fast/PrivateEndpoints.toml @@ -0,0 +1,5 @@ +[[test]] +testTitle = 'PrivateEndpoints' + +[[test.workload]] +testName = 'PrivateEndpoints'