diff --git a/CMakeLists.txt b/CMakeLists.txt index e07a9c9ac3..5d43cdc201 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -167,6 +167,7 @@ endif() include(CompileBoost) include(GetMsgpack) +add_subdirectory(contrib) add_subdirectory(flow) add_subdirectory(fdbrpc) add_subdirectory(fdbclient) @@ -178,7 +179,6 @@ else() add_subdirectory(fdbservice) endif() add_subdirectory(fdbbackup) -add_subdirectory(contrib) add_subdirectory(tests) add_subdirectory(flowbench EXCLUDE_FROM_ALL) if(WITH_PYTHON AND WITH_C_BINDING) diff --git a/fdbclient/BackupAgent.actor.h b/fdbclient/BackupAgent.actor.h index 5669f5d9d7..1f18b10a4c 100644 --- a/fdbclient/BackupAgent.actor.h +++ b/fdbclient/BackupAgent.actor.h @@ -561,7 +561,7 @@ ACTOR Future applyMutations(Database cx, Key removePrefix, Version beginVersion, Version* endVersion, - RequestStream commit, + RequestStream commit, NotifiedVersion* committedVersion, Reference> keyVersion); ACTOR Future cleanupBackup(Database cx, DeleteData deleteData); diff --git a/fdbclient/BackupAgentBase.actor.cpp b/fdbclient/BackupAgentBase.actor.cpp index 74f40743d2..75445d52f3 100644 --- a/fdbclient/BackupAgentBase.actor.cpp +++ b/fdbclient/BackupAgentBase.actor.cpp @@ -598,7 +598,7 @@ ACTOR Future dumpData(Database cx, Key uid, Key addPrefix, Key removePrefix, - RequestStream commit, + RequestStream commit, NotifiedVersion* committedVersion, Optional endVersion, Key rangeBegin, @@ -675,7 +675,7 @@ ACTOR Future dumpData(Database cx, ACTOR Future coalesceKeyVersionCache(Key uid, Version endVersion, Reference> keyVersion, - RequestStream commit, + RequestStream commit, NotifiedVersion* committedVersion, PromiseStream> addActor, FlowLock* commitLock) { @@ -725,7 +725,7 @@ ACTOR Future applyMutations(Database cx, Key removePrefix, Version beginVersion, Version* endVersion, - RequestStream commit, + RequestStream commit, NotifiedVersion* committedVersion, Reference> keyVersion) { state FlowLock commitLock(CLIENT_KNOBS->BACKUP_LOCK_BYTES); diff --git a/fdbclient/CommitProxyInterface.h b/fdbclient/CommitProxyInterface.h index 102b5a2088..af8f733ec3 100644 --- a/fdbclient/CommitProxyInterface.h +++ b/fdbclient/CommitProxyInterface.h @@ -71,9 +71,9 @@ struct CommitProxyInterface { serializer(ar, processId, provisional, commit); if (Archive::isDeserializing) { getConsistentReadVersion = - RequestStream(commit.getEndpoint().getAdjustedEndpoint(1)); + RequestStream(commit.getEndpoint().getAdjustedEndpoint(1)); getKeyServersLocations = - RequestStream(commit.getEndpoint().getAdjustedEndpoint(2)); + RequestStream(commit.getEndpoint().getAdjustedEndpoint(2)); getStorageServerRejoinInfo = RequestStream(commit.getEndpoint().getAdjustedEndpoint(3)); waitFailure = RequestStream>(commit.getEndpoint().getAdjustedEndpoint(4)); diff --git a/fdbclient/NativeAPI.actor.cpp b/fdbclient/NativeAPI.actor.cpp index 79165dbf4c..ca5436e72c 100644 --- a/fdbclient/NativeAPI.actor.cpp +++ b/fdbclient/NativeAPI.actor.cpp @@ -100,11 +100,11 @@ namespace { TransactionLineageCollector transactionLineageCollector; NameLineageCollector nameLineageCollector; -template +template Future loadBalance( DatabaseContext* ctx, const Reference alternatives, - RequestStream Interface::*channel, + RequestStream Interface::*channel, const Request& request = Request(), TaskPriority taskID = TaskPriority::DefaultPromiseEndpoint, AtMostOnce atMostOnce = @@ -2223,7 +2223,7 @@ void stopNetwork() { if (!g_network) throw network_not_setup(); - TraceEvent("ClientStopNetwork"); + TraceEvent("ClientStopNetwork").log(); g_network->stop(); closeTraceFile(); } @@ -3176,7 +3176,7 @@ void transformRangeLimits(GetRangeLimits limits, Reverse reverse, GetKeyValuesFa } template -RequestStream StorageServerInterface::*getRangeRequestStream() { +RequestStream StorageServerInterface::*getRangeRequestStream() { if constexpr (std::is_same::value) { return &StorageServerInterface::getKeyValues; } else if (std::is_same::value) { @@ -3908,9 +3908,9 @@ static Future tssStreamComparison(Request request, // Currently only used for GetKeyValuesStream but could easily be plugged for other stream types // User of the stream has to forward the SS's responses to the returned promise stream, if it is set -template +template Optional> -maybeDuplicateTSSStreamFragment(Request& req, QueueModel* model, RequestStream const* ssStream) { +maybeDuplicateTSSStreamFragment(Request& req, QueueModel* model, RequestStream const* ssStream) { if (model) { Optional tssData = model->getTssData(ssStream->getEndpoint().token.first()); diff --git a/fdbclient/StorageServerInterface.h b/fdbclient/StorageServerInterface.h index 3da2ad402b..9636367ec9 100644 --- a/fdbclient/StorageServerInterface.h +++ b/fdbclient/StorageServerInterface.h @@ -68,11 +68,11 @@ struct StorageServerInterface { RequestStream getKeyValues; RequestStream getKeyValuesAndFlatMap; - RequestStream getShardState; + RequestStream getShardState; RequestStream waitMetrics; RequestStream splitMetrics; RequestStream getStorageMetrics; - RequestStream, true> waitFailure; + RequestStream> waitFailure; RequestStream getQueuingMetrics; RequestStream> getKeyValueStoreType; @@ -106,8 +106,8 @@ struct StorageServerInterface { serializer(ar, uniqueID, locality, getValue); } if (Ar::isDeserializing) { - getKey = RequestStream(getValue.getEndpoint().getAdjustedEndpoint(1)); - getKeyValues = RequestStream(getValue.getEndpoint().getAdjustedEndpoint(2)); + getKey = RequestStream(getValue.getEndpoint().getAdjustedEndpoint(1)); + getKeyValues = RequestStream(getValue.getEndpoint().getAdjustedEndpoint(2)); getShardState = RequestStream(getValue.getEndpoint().getAdjustedEndpoint(3)); waitMetrics = RequestStream(getValue.getEndpoint().getAdjustedEndpoint(4)); @@ -119,22 +119,22 @@ struct StorageServerInterface { RequestStream(getValue.getEndpoint().getAdjustedEndpoint(8)); getKeyValueStoreType = RequestStream>(getValue.getEndpoint().getAdjustedEndpoint(9)); - watchValue = RequestStream(getValue.getEndpoint().getAdjustedEndpoint(10)); + watchValue = RequestStream(getValue.getEndpoint().getAdjustedEndpoint(10)); getReadHotRanges = RequestStream(getValue.getEndpoint().getAdjustedEndpoint(11)); getRangeSplitPoints = RequestStream(getValue.getEndpoint().getAdjustedEndpoint(12)); getKeyValuesStream = - RequestStream(getValue.getEndpoint().getAdjustedEndpoint(13)); + RequestStream(getValue.getEndpoint().getAdjustedEndpoint(13)); getKeyValuesAndFlatMap = - RequestStream(getValue.getEndpoint().getAdjustedEndpoint(14)); + RequestStream(getValue.getEndpoint().getAdjustedEndpoint(14)); changeFeedStream = - RequestStream(getValue.getEndpoint().getAdjustedEndpoint(15)); + RequestStream(getValue.getEndpoint().getAdjustedEndpoint(15)); overlappingChangeFeeds = - RequestStream(getValue.getEndpoint().getAdjustedEndpoint(16)); + RequestStream(getValue.getEndpoint().getAdjustedEndpoint(16)); changeFeedPop = - RequestStream(getValue.getEndpoint().getAdjustedEndpoint(17)); - changeFeedVersionUpdate = RequestStream( + RequestStream(getValue.getEndpoint().getAdjustedEndpoint(17)); + changeFeedVersionUpdate = RequestStream( getValue.getEndpoint().getAdjustedEndpoint(18)); } } else { diff --git a/fdbrpc/CMakeLists.txt b/fdbrpc/CMakeLists.txt index 046ba4ff46..00c149aec5 100644 --- a/fdbrpc/CMakeLists.txt +++ b/fdbrpc/CMakeLists.txt @@ -15,6 +15,7 @@ set(FDBRPC_SRCS genericactors.actor.cpp HealthMonitor.actor.cpp IAsyncFile.actor.cpp + IPAllowList.cpp LoadBalance.actor.cpp LoadBalance.actor.h Locality.cpp diff --git a/fdbrpc/FlowTransport.actor.cpp b/fdbrpc/FlowTransport.actor.cpp index 851e3c084f..a9dc7d8c1f 100644 --- a/fdbrpc/FlowTransport.actor.cpp +++ b/fdbrpc/FlowTransport.actor.cpp @@ -32,6 +32,7 @@ #include "fdbrpc/FailureMonitor.h" #include "fdbrpc/HealthMonitor.h" #include "fdbrpc/genericactors.actor.h" +#include "fdbrpc/IPAllowList.h" #include "fdbrpc/simulator.h" #include "flow/ActorCollection.h" #include "flow/Error.h" @@ -1807,7 +1808,7 @@ bool FlowTransport::incompatibleOutgoingConnectionsPresent() { return self->numIncompatibleConnections > 0; } -void FlowTransport::createInstance(bool isClient, uint64_t transportId, int maxWellKnownEndpoints) { +void FlowTransport::createInstance(bool isClient, uint64_t transportId, int maxWellKnownEndpoints, IPAllowList const* allowList) { g_network->setGlobal(INetwork::enFlowTransport, (flowGlobalType) new FlowTransport(transportId, maxWellKnownEndpoints)); g_network->setGlobal(INetwork::enNetworkAddressFunc, (flowGlobalType)&FlowTransport::getGlobalLocalAddress); diff --git a/fdbrpc/FlowTransport.h b/fdbrpc/FlowTransport.h index ab132d78f4..9ff39a5596 100644 --- a/fdbrpc/FlowTransport.h +++ b/fdbrpc/FlowTransport.h @@ -182,6 +182,8 @@ struct Peer : public ReferenceCounted { void onIncomingConnection(Reference self, Reference conn, Future reader); }; +class IPAllowList; + class FlowTransport { public: FlowTransport(uint64_t transportId, int maxWellKnownEndpoints); @@ -189,7 +191,7 @@ public: // Creates a new FlowTransport and makes FlowTransport::transport() return it. This uses g_network->global() // variables, so it will be private to a simulation. - static void createInstance(bool isClient, uint64_t transportId, int maxWellKnownEndpoints); + static void createInstance(bool isClient, uint64_t transportId, int maxWellKnownEndpoints, IPAllowList const* allowList = nullptr); static bool isClient() { return g_network->global(INetwork::enClientFailureMonitor) != nullptr; } diff --git a/fdbrpc/IPAllowList.cpp b/fdbrpc/IPAllowList.cpp new file mode 100644 index 0000000000..dbe7cfa012 --- /dev/null +++ b/fdbrpc/IPAllowList.cpp @@ -0,0 +1,290 @@ +/* + * IPAllowList.h + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2013-2022 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "fdbrpc/IPAllowList.h" +#include "flow/UnitTest.h" + +#include +#include + +namespace { + +template +std::string binRep(std::array const& addr) { + return fmt::format("{:02x}", fmt::join(addr, ":")); +} + +template +void printIP(std::array const& addr) { + fmt::print(" {}", binRep(addr)); +} + +template +int hostCountImpl(std::array const& addr) { + int count = 0; + for (int i = 0; i < addr.size() && addr[i] != 0xff; ++i) { + std::bitset<8> b(addr[i]); + count += 8 - b.count(); + } + return count; +} + +} // namespace + +IPAddress AuthAllowedSubnet::netmask() const { + if (addressMask.isV4()) { + uint32_t res = 0xffffffff ^ addressMask.toV4(); + return IPAddress(res); + } else { + std::array res; + res.fill(0xff); + auto mask = addressMask.toV6(); + for (int i = 0; i < mask.size(); ++i) { + res[i] ^= mask[i]; + } + return IPAddress(res); + } +} + + +int AuthAllowedSubnet::hostCount() const { + if (addressMask.isV4()) { + boost::asio::ip::address_v4 addr(addressMask.toV4()); + return hostCountImpl(addr.to_bytes()); + } else { + return hostCountImpl(addressMask.toV6()); + } +} + +AuthAllowedSubnet AuthAllowedSubnet::fromString(std::string_view addressString) { + auto pos = addressString.find('/'); + if (pos == std::string_view::npos) { + fmt::print("ERROR: {} is not a valid (use Network-Prefix/hostcount syntax)\n"); + throw invalid_option(); + } + auto address = addressString.substr(0, pos); + auto hostCount = std::stoi(std::string(addressString.substr(pos + 1))); + auto addr = boost::asio::ip::make_address(address); + if (addr.is_v4()) { + auto bM = createBitMask(addr.to_v4().to_bytes(), hostCount); + // we typically would expect a base address has been passed, but to be safe we still + // will make the last bits 0 + auto mask = boost::asio::ip::address_v4(bM).to_uint(); + auto baseAddress = addr.to_v4().to_uint() & mask; + fmt::print("For address {}:", addressString); + printIP("Base Address", IPAddress(baseAddress)); + printIP("Mask:", IPAddress(mask)); + return AuthAllowedSubnet(IPAddress(baseAddress), IPAddress(mask)); + } else { + auto mask = createBitMask(addr.to_v6().to_bytes(), hostCount); + auto baseAddress = addr.to_v6().to_bytes(); + for (int i = 0; i < mask.size(); ++i) { + baseAddress[i] &= mask[i]; + } + return AuthAllowedSubnet(IPAddress(baseAddress), IPAddress(mask)); + } +} + +void AuthAllowedSubnet::printIP(std::string_view txt, IPAddress const& address) { + fmt::print("{}:", txt); + if (address.isV4()) { + ::printIP(boost::asio::ip::address_v4(address.toV4()).to_bytes()); + } else { + ::printIP(address.toV6()); + } + fmt::print("\n"); +} + +// helpers for testing +namespace { +using boost::asio::ip::address_v4; +using boost::asio::ip::address_v6; + +void traceAddress(TraceEvent& evt, const char* name, IPAddress address) { + evt.detail(name, address); + std::string bin; + if (address.isV4()) { + boost::asio::ip::address_v4 a(address.toV4()); + bin = binRep(a.to_bytes()); + } else { + bin = binRep(address.toV6()); + } + evt.detail(fmt::format("{}Binary", name).c_str(), bin); +} + +void subnetAssert(IPAllowList const& allowList, IPAddress addr, bool expectAllowed) { + if (allowList(addr) == expectAllowed) { + return; + } + TraceEvent evt(SevError, expectAllowed ? "ExpectedAddressToBeTrusted" : "ExpectedAddressToBeUntrusted"); + traceAddress(evt, "Address", addr); + auto const& subnets = allowList.subnets(); + for (int i = 0; i < subnets.size(); ++i) { + traceAddress(evt, fmt::format("SubnetBase{}", i).c_str(), subnets[i].baseAddress); + traceAddress(evt, fmt::format("SubnetMask{}", i).c_str(), subnets[i].addressMask); + } +} + +IPAddress parseAddr(std::string const& str) { + auto res = IPAddress::parse(str); + ASSERT(res.present()); + return res.get(); +} + +struct SubNetTest { + AuthAllowedSubnet subnet; + SubNetTest(AuthAllowedSubnet&& subnet) + : subnet(subnet) + { + } + template + static SubNetTest randomSubNetImpl() { + constexpr int width = V4 ? 4 : 16; + std::array binAddr; + unsigned char rnd[4]; + for (int i = 0; i < binAddr.size(); ++i) { + if (i % 4 == 0) { + auto tmp = deterministicRandom()->randomUInt32(); + ::memcpy(rnd, &tmp, 4); + } + binAddr[i] = rnd[i % 4]; + } + auto hostCount = deterministicRandom()->randomInt(1, width); + std::string address; + if constexpr (V4) { + address_v4 a(binAddr); + address = a.to_string(); + } else { + address_v6 a(binAddr); + address = a.to_string(); + } + return SubNetTest(AuthAllowedSubnet::fromString(fmt::format("{}/{}", address, hostCount))); + } + static SubNetTest randomSubNet() { + if (deterministicRandom()->coinflip()) { + return randomSubNetImpl(); + } else { + return randomSubNetImpl(); + } + } + + template + IPAddress intArrayToAddress(uint32_t* arr) { + if constexpr (V4) { + return IPAddress(arr[0]); + } else { + std::array res; + memcpy(res.data(), arr, 4); + return IPAddress(res); + } + } + + template + I transformIntToSubnet(I val, I subnetMask, I baseAddress) { + return (val & subnetMask) ^ baseAddress; + } + + template + IPAddress randomAddress(bool inSubnet) { + ASSERT(V4 == subnet.baseAddress.isV4() || !inSubnet); + constexpr int width = V4 ? 4 : 16; + for (;;) { + uint32_t rnd[width / 4]; + for (int i = 0; i < width / 4; ++i) { + rnd[i] = deterministicRandom()->randomUInt32(); + } + auto res = intArrayToAddress(rnd); + if (V4 != subnet.baseAddress.isV4()) { + return res; + } + if (!inSubnet) { + if (!subnet(res)) { + return res; + } else { + continue; + } + } + // first we make sure the address is in the subnet + if constexpr (V4) { + auto a = res.toV4(); + auto base = subnet.baseAddress.toV4(); + auto netmask = subnet.netmask().toV4(); + auto validAddress = transformIntToSubnet(a, netmask, base); + res = IPAddress(validAddress); + } else { + auto a = res.toV6(); + auto base = subnet.baseAddress.toV6(); + auto netmask = subnet.netmask().toV6(); + for (int i = 0; i < a.size(); ++i) { + a[i] = transformIntToSubnet(a[i], netmask[i], base[i]); + } + res = IPAddress(a); + } + return res; + } + } + + IPAddress randomAddress(bool inSubnet) { + if (!inSubnet && deterministicRandom()->random01() < 0.1) { + // return an address of a different type + if (subnet.baseAddress.isV4()) { + return randomAddress(false); + } else { + return randomAddress(false); + } + } + if (subnet.addressMask.isV4()) { + return randomAddress(inSubnet); + } else { + return randomAddress(inSubnet); + } + } +}; + +} // namespace + +TEST_CASE("/fdbrpc/allow_list") { + IPAllowList allowList; + allowList.addTrustedSubnet("1.0.0.0/8"); + allowList.addTrustedSubnet("2.0.0.0/4"); + ::subnetAssert(allowList, parseAddr("1.0.1.1"), true); + ::subnetAssert(allowList, parseAddr("1.1.2.2"), true); + ::subnetAssert(allowList, parseAddr("2.2.1.1"), true); + ::subnetAssert(allowList, parseAddr("128.0.1.1"), false); + allowList = IPAllowList(); + allowList.addTrustedSubnet("0.0.0.0/2"); + ::subnetAssert(allowList, parseAddr("1.0.1.1"), true); + ::subnetAssert(allowList, parseAddr("1.1.2.2"), true); + ::subnetAssert(allowList, parseAddr("2.2.1.1"), true); + ::subnetAssert(allowList, parseAddr("5.2.1.1"), true); + ::subnetAssert(allowList, parseAddr("128.0.1.1"), false); + ::subnetAssert(allowList, parseAddr("192.168.3.1"), false); + for (int i = 0; i < 100; ++i) { + SubNetTest subnetTest(SubNetTest::randomSubNet()); + allowList = IPAllowList(); + allowList.addTrustedSubnet(subnetTest.subnet); + for (int j = 0; j < 10; ++j) { + bool inSubnet = deterministicRandom()->random01() < 0.7; + auto addr = subnetTest.randomAddress(inSubnet); + ::subnetAssert(allowList, addr, inSubnet); + } + } + return Void(); +} diff --git a/fdbrpc/IPAllowList.h b/fdbrpc/IPAllowList.h new file mode 100644 index 0000000000..9860029e4a --- /dev/null +++ b/fdbrpc/IPAllowList.h @@ -0,0 +1,109 @@ +/* + * IPAllowList.h + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2013-2022 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#ifndef FDBRPC_IP_ALLOW_LIST_H +#define FDBRPC_IP_ALLOW_LIST_H + +#include "flow/network.h" +#include "flow/Arena.h" + +struct AuthAllowedSubnet { + IPAddress baseAddress; + IPAddress addressMask; + + AuthAllowedSubnet(IPAddress const& baseAddress, IPAddress const& addressMask) + : baseAddress(baseAddress), addressMask(addressMask) { + ASSERT(baseAddress.isV4() == addressMask.isV4()); + } + + static AuthAllowedSubnet fromString(std::string_view addressString); + + template + static auto createBitMask(std::array const& addr, int hostCount) -> std::array { + std::array res; + res.fill((unsigned char)0xff); + int idx = hostCount / 8; + if (hostCount % 8 > 0) { + // 2^(hostCount % 8) - 1 sets the last (hostCount % 8) number of bits to 1 + // everything else will be zero. For example: 2^3 - 1 == 7 == 0b111 + unsigned char bitmask = (1 << (hostCount % 8)) - ((unsigned char)1); + res[idx] ^= bitmask; + ++idx; + } + for (; idx < res.size(); ++idx) { + res[idx] = (unsigned char)0; + } + return res; + } + + bool operator() (IPAddress const& address) const { + if (addressMask.isV4() != address.isV4()) { + return false; + } + if (addressMask.isV4()) { + return (addressMask.toV4() & address.toV4()) == baseAddress.toV4(); + } else { + auto res = address.toV6(); + auto const& mask = addressMask.toV6(); + for (int i = 0; i < res.size(); ++i) { + res[i] &= mask[i]; + } + return res == baseAddress.toV6(); + } + } + + IPAddress netmask() const; + + int hostCount() const; + + // some useful helper functions if we need to debug ip masks etc + static void printIP(std::string_view txt, IPAddress const& address); +}; + +class IPAllowList { + std::vector subnetList; +public: + void addTrustedSubnet(std::string_view str) { + subnetList.push_back(AuthAllowedSubnet::fromString(str)); + } + + void addTrustedSubnet(AuthAllowedSubnet const& subnet) { + subnetList.push_back(subnet); + } + + std::vector const& subnets() const { + return subnetList; + } + + bool operator() (IPAddress address) const { + if (subnetList.empty()) { + return true; + } + for (auto const& subnet : subnetList) { + if (subnet(address)) { + return true; + } + } + return false; + } +}; + +#endif // FDBRPC_IP_ALLOW_LIST_H diff --git a/fdbrpc/LoadBalance.actor.h b/fdbrpc/LoadBalance.actor.h index 838a1762b8..d150fded11 100644 --- a/fdbrpc/LoadBalance.actor.h +++ b/fdbrpc/LoadBalance.actor.h @@ -78,14 +78,14 @@ struct LoadBalancedReply { Optional getLoadBalancedReply(const LoadBalancedReply* reply); Optional getLoadBalancedReply(const void*); -ACTOR template +ACTOR template Future tssComparison(Req req, Future> fSource, Future> fTss, TSSEndpointData tssData, uint64_t srcEndpointId, Reference> ssTeam, - RequestStream Interface::*channel) { + RequestStream Interface::*channel) { state double startTime = now(); state Future>> fTssWithTimeout = timeout(fTss, FLOW_KNOBS->LOAD_BALANCE_TSS_TIMEOUT); state int finished = 0; @@ -157,7 +157,7 @@ Future tssComparison(Req req, state std::vector>> restOfTeamFutures; restOfTeamFutures.reserve(ssTeam->size() - 1); for (int i = 0; i < ssTeam->size(); i++) { - RequestStream const* si = &ssTeam->get(i, channel); + RequestStream const* si = &ssTeam->get(i, channel); if (si->getEndpoint().token.first() != srcEndpointId) { // don't re-request to SS we already have a response from resetReply(req); @@ -242,7 +242,7 @@ FDB_DECLARE_BOOLEAN_PARAM(AtMostOnce); FDB_DECLARE_BOOLEAN_PARAM(TriedAllOptions); // Stores state for a request made by the load balancer -template +template struct RequestData : NonCopyable { typedef ErrorOr Reply; @@ -257,12 +257,12 @@ struct RequestData : NonCopyable { // This is true once setupRequest is called, even though at that point the response is Never(). bool isValid() { return response.isValid(); } - static void maybeDuplicateTSSRequest(RequestStream const* stream, + static void maybeDuplicateTSSRequest(RequestStream const* stream, Request& request, QueueModel* model, Future ssResponse, Reference> alternatives, - RequestStream Interface::*channel) { + RequestStream Interface::*channel) { if (model) { // Send parallel request to TSS pair, if it exists Optional tssData = model->getTssData(stream->getEndpoint().token.first()); @@ -271,7 +271,7 @@ struct RequestData : NonCopyable { TEST(true); // duplicating request to TSS resetReply(request); // FIXME: optimize to avoid creating new netNotifiedQueue for each message - RequestStream tssRequestStream(tssData.get().endpoint); + RequestStream tssRequestStream(tssData.get().endpoint); Future> fTssResult = tssRequestStream.tryGetReply(request); model->addActor.send(tssComparison(request, ssResponse, @@ -288,11 +288,11 @@ struct RequestData : NonCopyable { void startRequest( double backoff, TriedAllOptions triedAllOptions, - RequestStream const* stream, + RequestStream const* stream, Request& request, QueueModel* model, Reference> alternatives, // alternatives and channel passed through for TSS check - RequestStream Interface::*channel) { + RequestStream Interface::*channel) { modelHolder = Reference(); requestStarted = false; @@ -438,18 +438,18 @@ struct RequestData : NonCopyable { // list of servers. // When model is set, load balance among alternatives in the same DC aims to balance request queue length on these // interfaces. If too many interfaces in the same DC are bad, try remote interfaces. -ACTOR template +ACTOR template Future loadBalance( Reference> alternatives, - RequestStream Interface::*channel, + RequestStream Interface::*channel, Request request = Request(), TaskPriority taskID = TaskPriority::DefaultPromiseEndpoint, AtMostOnce atMostOnce = AtMostOnce::False, // if true, throws request_maybe_delivered() instead of retrying automatically QueueModel* model = nullptr) { - state RequestData firstRequestData; - state RequestData secondRequestData; + state RequestData firstRequestData; + state RequestData secondRequestData; state Optional firstRequestEndpoint; state Future secondDelay = Never(); @@ -488,7 +488,7 @@ Future loadBalance( break; } - RequestStream const* thisStream = &alternatives->get(i, channel); + RequestStream const* thisStream = &alternatives->get(i, channel); if (!IFailureMonitor::failureMonitor().getState(thisStream->getEndpoint()).failed) { auto& qd = model->getMeasurement(thisStream->getEndpoint().token.first()); if (now() > qd.failedUntil) { @@ -527,7 +527,7 @@ Future loadBalance( // go through all the remote servers again, since we may have // skipped it. for (int i = alternatives->countBest(); i < alternatives->size(); i++) { - RequestStream const* thisStream = &alternatives->get(i, channel); + RequestStream const* thisStream = &alternatives->get(i, channel); if (!IFailureMonitor::failureMonitor().getState(thisStream->getEndpoint()).failed) { auto& qd = model->getMeasurement(thisStream->getEndpoint().token.first()); if (now() > qd.failedUntil) { @@ -574,7 +574,7 @@ Future loadBalance( if (ev.isEnabled()) { ev.log(); for (int alternativeNum = 0; alternativeNum < alternatives->size(); alternativeNum++) { - RequestStream const* thisStream = &alternatives->get(alternativeNum, channel); + RequestStream const* thisStream = &alternatives->get(alternativeNum, channel); TraceEvent(SevWarn, "LoadBalanceTooLongEndpoint") .detail("Addr", thisStream->getEndpoint().getPrimaryAddress()) .detail("Token", thisStream->getEndpoint().token) @@ -586,7 +586,7 @@ Future loadBalance( // Find an alternative, if any, that is not failed, starting with // nextAlt. This logic matters only if model == nullptr. Otherwise, the // bestAlt and nextAlt have been decided. - state RequestStream const* stream = nullptr; + state RequestStream const* stream = nullptr; for (int alternativeNum = 0; alternativeNum < alternatives->size(); alternativeNum++) { int useAlt = nextAlt; if (nextAlt == startAlt) @@ -724,9 +724,9 @@ Optional getBasicLoadBalancedReply(const BasicLoadBalanc Optional getBasicLoadBalancedReply(const void*); // A simpler version of LoadBalance that does not send second requests where the list of servers are always fresh -ACTOR template +ACTOR template Future basicLoadBalance(Reference> alternatives, - RequestStream Interface::*channel, + RequestStream Interface::*channel, Request request = Request(), TaskPriority taskID = TaskPriority::DefaultPromiseEndpoint, AtMostOnce atMostOnce = AtMostOnce::False) { @@ -749,7 +749,7 @@ Future basicLoadBalance(Reference> al state int useAlt; loop { // Find an alternative, if any, that is not failed, starting with nextAlt - state RequestStream const* stream = nullptr; + state RequestStream const* stream = nullptr; for (int alternativeNum = 0; alternativeNum < alternatives->size(); alternativeNum++) { useAlt = nextAlt; if (nextAlt == startAlt) diff --git a/fdbrpc/fdbrpc.h b/fdbrpc/fdbrpc.h index c7c5b21480..d78816910a 100644 --- a/fdbrpc/fdbrpc.h +++ b/fdbrpc/fdbrpc.h @@ -825,8 +825,8 @@ public: queue->makeWellKnownEndpoint(Endpoint::Token(-1, wlTokenID), taskID); } - bool operator==(const RequestStream& rhs) const { return queue == rhs.queue; } - bool operator!=(const RequestStream& rhs) const { return !(*this == rhs); } + bool operator==(const RequestStream& rhs) const { return queue == rhs.queue; } + bool operator!=(const RequestStream& rhs) const { return !(*this == rhs); } bool isEmpty() const { return !queue->isReady(); } uint32_t size() const { return queue->size(); } @@ -838,29 +838,29 @@ private: NetNotifiedQueue* queue; }; -template -void save(Ar& ar, const RequestStream& value) { +template +void save(Ar& ar, const RequestStream& value) { auto const& ep = value.getEndpoint(); ar << ep; UNSTOPPABLE_ASSERT( ep.getPrimaryAddress().isValid()); // No serializing PromiseStreams on a client with no public address } -template -void load(Ar& ar, RequestStream& value) { +template +void load(Ar& ar, RequestStream& value) { Endpoint endpoint; ar >> endpoint; - value = RequestStream(endpoint); + value = RequestStream(endpoint); } -template -struct serializable_traits> : std::true_type { +template +struct serializable_traits> : std::true_type { template - static void serialize(Archiver& ar, RequestStream& stream) { + static void serialize(Archiver& ar, RequestStream& stream) { if constexpr (Archiver::isDeserializing) { Endpoint endpoint; serializer(ar, endpoint); - stream = RequestStream(endpoint); + stream = RequestStream(endpoint); } else { const auto& ep = stream.getEndpoint(); serializer(ar, ep); diff --git a/fdbrpc/genericactors.actor.h b/fdbrpc/genericactors.actor.h index 46a79d29cf..da476bf339 100644 --- a/fdbrpc/genericactors.actor.h +++ b/fdbrpc/genericactors.actor.h @@ -30,8 +30,8 @@ #include "fdbrpc/fdbrpc.h" #include "flow/actorcompiler.h" // This must be the last #include. -ACTOR template -Future retryBrokenPromise(RequestStream to, Req request) { +ACTOR template +Future retryBrokenPromise(RequestStream to, Req request) { // Like to.getReply(request), except that a broken_promise exception results in retrying request immediately. // Suitable for use with well known endpoints, which are likely to return to existence after the other process // restarts. Not normally useful for ordinary endpoints, which conventionally are permanently destroyed after @@ -50,8 +50,8 @@ Future retryBrokenPromise(RequestStream to, Req request) { } } -ACTOR template -Future retryBrokenPromise(RequestStream to, Req request, TaskPriority taskID) { +ACTOR template +Future retryBrokenPromise(RequestStream to, Req request, TaskPriority taskID) { // Like to.getReply(request), except that a broken_promise exception results in retrying request immediately. // Suitable for use with well known endpoints, which are likely to return to existence after the other process // restarts. Not normally useful for ordinary endpoints, which conventionally are permanently destroyed after diff --git a/fdbserver/ApplyMetadataMutation.cpp b/fdbserver/ApplyMetadataMutation.cpp index 47871915f9..96ca476648 100644 --- a/fdbserver/ApplyMetadataMutation.cpp +++ b/fdbserver/ApplyMetadataMutation.cpp @@ -107,7 +107,7 @@ private: KeyRangeMap* keyInfo = nullptr; KeyRangeMap* cacheInfo = nullptr; std::map* uid_applyMutationsData = nullptr; - RequestStream commit = RequestStream(); + RequestStream commit = RequestStream(); Database cx = Database(); NotifiedVersion* commitVersion = nullptr; std::map>* storageCache = nullptr; diff --git a/fdbserver/ClusterController.actor.cpp b/fdbserver/ClusterController.actor.cpp index 40456dba38..eacef60897 100644 --- a/fdbserver/ClusterController.actor.cpp +++ b/fdbserver/ClusterController.actor.cpp @@ -2919,7 +2919,7 @@ TEST_CASE("/fdbserver/clustercontroller/shouldTriggerRecoveryDueToDegradedServer testDbInfo.logSystemConfig.tLogs.push_back(remoteTLogSet); GrvProxyInterface proxyInterf; - proxyInterf.getConsistentReadVersion = RequestStream(Endpoint({ proxy }, testUID)); + proxyInterf.getConsistentReadVersion = RequestStream(Endpoint({ proxy }, testUID)); testDbInfo.client.grvProxies.push_back(proxyInterf); ResolverInterface resolverInterf; @@ -3028,11 +3028,11 @@ TEST_CASE("/fdbserver/clustercontroller/shouldTriggerFailoverDueToDegradedServer testDbInfo.logSystemConfig.tLogs.push_back(remoteTLogSet); GrvProxyInterface grvProxyInterf; - grvProxyInterf.getConsistentReadVersion = RequestStream(Endpoint({ proxy }, testUID)); + grvProxyInterf.getConsistentReadVersion = RequestStream(Endpoint({ proxy }, testUID)); testDbInfo.client.grvProxies.push_back(grvProxyInterf); CommitProxyInterface commitProxyInterf; - commitProxyInterf.commit = RequestStream(Endpoint({ proxy2 }, testUID)); + commitProxyInterf.commit = RequestStream(Endpoint({ proxy2 }, testUID)); testDbInfo.client.commitProxies.push_back(commitProxyInterf); ResolverInterface resolverInterf; diff --git a/fdbserver/GrvProxyServer.actor.cpp b/fdbserver/GrvProxyServer.actor.cpp index 5f427dd0b0..3ae922a980 100644 --- a/fdbserver/GrvProxyServer.actor.cpp +++ b/fdbserver/GrvProxyServer.actor.cpp @@ -229,7 +229,7 @@ struct GrvProxyData { GrvProxyStats stats; MasterInterface master; - RequestStream getConsistentReadVersion; + RequestStream getConsistentReadVersion; Reference logSystem; Database cx; @@ -261,7 +261,7 @@ struct GrvProxyData { GrvProxyData(UID dbgid, MasterInterface master, - RequestStream getConsistentReadVersion, + RequestStream getConsistentReadVersion, Reference const> db) : dbgid(dbgid), stats(dbgid), master(master), getConsistentReadVersion(getConsistentReadVersion), cx(openDBOnServer(db, TaskPriority::DefaultEndpoint, LockAware::True)), db(db), lastStartCommit(0), diff --git a/fdbserver/ProxyCommitData.actor.h b/fdbserver/ProxyCommitData.actor.h index d1f9ebccf4..715aa6ddcc 100644 --- a/fdbserver/ProxyCommitData.actor.h +++ b/fdbserver/ProxyCommitData.actor.h @@ -28,6 +28,9 @@ #include "fdbclient/FDBTypes.h" #include "fdbrpc/Stats.h" #include "fdbserver/Knobs.h" +#include "fdbserver/LogSystem.h" +#include "fdbserver/MasterInterface.h" +#include "fdbserver/ResolverInterface.h" #include "fdbserver/LogSystemDiskQueueAdapter.h" #include "flow/IRandom.h" @@ -189,8 +192,8 @@ struct ProxyCommitData { NotifiedVersion latestLocalCommitBatchResolving; NotifiedVersion latestLocalCommitBatchLogging; - RequestStream getConsistentReadVersion; - RequestStream commit; + RequestStream getConsistentReadVersion; + RequestStream commit; Database cx; Reference const> db; EventMetricHandle singleKeyMutationEvent; @@ -267,9 +270,9 @@ struct ProxyCommitData { ProxyCommitData(UID dbgid, MasterInterface master, - RequestStream getConsistentReadVersion, + RequestStream getConsistentReadVersion, Version recoveryTransactionVersion, - RequestStream commit, + RequestStream commit, Reference const> db, bool firstProxy) : dbgid(dbgid), commitBatchesMemBytesCount(0), diff --git a/fdbserver/fdbserver.actor.cpp b/fdbserver/fdbserver.actor.cpp index 4f5d7a26a9..2c2915f3f4 100644 --- a/fdbserver/fdbserver.actor.cpp +++ b/fdbserver/fdbserver.actor.cpp @@ -35,6 +35,8 @@ #include #include +#include + #include "fdbclient/ActorLineageProfiler.h" #include "fdbclient/ClusterConnectionFile.h" #include "fdbclient/IKnobCollection.h" @@ -45,6 +47,7 @@ #include "fdbclient/WellKnownEndpoints.h" #include "fdbclient/SimpleIni.h" #include "fdbrpc/AsyncFileCached.actor.h" +#include "fdbrpc/IPAllowList.h" #include "fdbrpc/Net2FileSystem.h" #include "fdbrpc/PerfMetric.h" #include "fdbrpc/simulator.h" @@ -1017,6 +1020,7 @@ struct CLIOptions { std::map profilerConfig; bool printSimTime = false; + IPAllowList allowList; static CLIOptions parseArgs(int argc, char* argv[]) { CLIOptions opts; @@ -1120,6 +1124,15 @@ private: localities.set(key, Standalone(std::string(args.OptionArg()))); break; } + case OPT_IP_TRUSTED_MASK: { + Optional subnetKey = extractPrefixedArgument("--trusted-subnet", args.OptionSyntax()); + if (!subnetKey.present()) { + fprintf(stderr, "ERROR: unable to parse locality key '%s'\n", args.OptionSyntax()); + flushAndExit(FDB_EXIT_ERROR); + } + allowList.addTrustedSubnet(args.OptionArg()); + break; + } case OPT_VERSION: printVersion(); flushAndExit(FDB_EXIT_SUCCESS); @@ -1668,103 +1681,7 @@ private: }; } // namespace -#include - -struct AuthAllowedSubnet { - IPAddress baseAddress; - IPAddress addressMask; - - AuthAllowedSubnet(IPAddress const& baseAddress, IPAddress const& addressMask) - : baseAddress(baseAddress), addressMask(addressMask) { - ASSERT(baseAddress.isV4() == addressMask.isV4()); - } - - static AuthAllowedSubnet fromString(std::string_view addressString) { - auto pos = addressString.find('/'); - if (pos == std::string_view::npos) { - fmt::print("ERROR: {} is not a valid (use Network-Prefix/hostcount syntax)\n"); - throw invalid_option(); - } - auto address = addressString.substr(0, pos); - auto hostCount = std::stoi(std::string(addressString.substr(pos + 1))); - auto addr = boost::asio::ip::make_address(address); - if (addr.is_v4()) { - auto bM = createBitMask(addr.to_v4().to_bytes(), hostCount); - // we typically would expect a base address has been passed, but to be safe we still - // will make the last bits 0 - auto mask = boost::asio::ip::address_v4(bM).to_uint(); - auto baseAddress = addr.to_v4().to_uint() & mask; - return AuthAllowedSubnet(IPAddress(baseAddress), IPAddress(mask)); - } else { - auto mask = createBitMask(addr.to_v6().to_bytes(), hostCount); - auto baseAddress = addr.to_v6().to_bytes(); - for (int i = 0; i < mask.size(); ++i) { - baseAddress[i] &= mask[i]; - } - return AuthAllowedSubnet(IPAddress(baseAddress), IPAddress(mask)); - } - } - - template - static auto createBitMask(std::array const& addr, int hostCount) -> std::array { - std::array res; - res.fill((unsigned char)0xff); - for (auto idx = (hostCount / 8) - 1; idx < res.size(); ++idx) { - if (hostCount > 0) { - // 2^(hostCount % 8) - 1 sets the last (hostCount % 8) number of bits to 1 - // everything else will be zero. For example: 2^3 - 1 == 7 == 0b111 - unsigned char bitmask = (1 << (hostCount % 8)) - ((unsigned char)1); - res[idx] ^= bitmask; - } else { - res[idx] = (unsigned char)0; - } - hostCount = 0; - } - return res; - } - - bool operator() (IPAddress const& address) const { - if (addressMask.isV4() != address.isV4()) { - return false; - } - if (addressMask.isV4()) { - return (addressMask.toV4() & address.toV4()) == baseAddress.toV4(); - } else { - auto res = address.toV6(); - auto const& mask = addressMask.toV6(); - for (int i = 0; i < res.size(); ++i) { - res[i] &= mask[i]; - } - return res == baseAddress.toV6(); - } - } -}; - -template -void printIP(std::array const& addr) { - for (auto c : addr) { - fmt::print(" {:02x}", int(c)); - } -} - -void printIP(std::string_view txt, IPAddress const& address) { - fmt::print("{}:", txt); - if (address.isV4()) { - printIP(boost::asio::ip::address_v4(address.toV4()).to_bytes()); - } else { - printIP(address.toV6()); - } - fmt::print("\n"); -} - int main(int argc, char* argv[]) { - //auto allowed = AuthAllowedSubnet::fromString(argv[1]); - //printIP("Base Address", allowed.baseAddress); - //printIP("Address Mask", allowed.addressMask); - //for (int idx = 1; idx < argc; ++idx) { - // auto addr = IPAddress::parse(argv[idx]); - //} - //return 0; // TODO: Remove later, this is just to force the statics to be initialized // otherwise the unit test won't run #ifdef ENABLE_SAMPLING @@ -1892,7 +1809,7 @@ int main(int argc, char* argv[]) { } else { g_network = newNet2(opts.tlsConfig, opts.useThreadPool, true); g_network->addStopCallback(Net2FileSystem::stop); - FlowTransport::createInstance(false, 1, WLTOKEN_RESERVED_COUNT); + FlowTransport::createInstance(false, 1, WLTOKEN_RESERVED_COUNT, &opts.allowList); const bool expectsPublicAddress = (role == ServerRole::FDBD || role == ServerRole::NetworkTestServer || role == ServerRole::Restore); diff --git a/fdbserver/worker.actor.cpp b/fdbserver/worker.actor.cpp index 9c19e13512..e2715f15b1 100644 --- a/fdbserver/worker.actor.cpp +++ b/fdbserver/worker.actor.cpp @@ -749,14 +749,14 @@ TEST_CASE("/fdbserver/worker/addressInDbAndPrimaryDc") { NetworkAddress grvProxyAddress(IPAddress(0x26262626), 1); GrvProxyInterface grvProxyInterf; grvProxyInterf.getConsistentReadVersion = - RequestStream(Endpoint({ grvProxyAddress }, UID(1, 2))); + RequestStream(Endpoint({ grvProxyAddress }, UID(1, 2))); testDbInfo.client.grvProxies.push_back(grvProxyInterf); ASSERT(addressInDbAndPrimaryDc(grvProxyAddress, makeReference>(testDbInfo))); NetworkAddress commitProxyAddress(IPAddress(0x37373737), 1); CommitProxyInterface commitProxyInterf; commitProxyInterf.commit = - RequestStream(Endpoint({ commitProxyAddress }, UID(1, 2))); + RequestStream(Endpoint({ commitProxyAddress }, UID(1, 2))); testDbInfo.client.commitProxies.push_back(commitProxyInterf); ASSERT(addressInDbAndPrimaryDc(commitProxyAddress, makeReference>(testDbInfo))); diff --git a/flow/CMakeLists.txt b/flow/CMakeLists.txt index a64b028974..990632f290 100644 --- a/flow/CMakeLists.txt +++ b/flow/CMakeLists.txt @@ -134,6 +134,7 @@ target_link_libraries(flow PUBLIC fmt::fmt) add_flow_target(STATIC_LIBRARY NAME flow_sampling SRCS ${FLOW_SRCS}) target_link_libraries(flow_sampling PRIVATE stacktrace) +target_link_libraries(flow_sampling PUBLIC fmt::fmt) target_compile_definitions(flow_sampling PRIVATE -DENABLE_SAMPLING) if(WIN32) add_dependencies(flow_sampling_actors flow_actors)