Allow List and first test

This commit is contained in:
Markus Pilman 2022-02-18 16:48:44 +01:00
parent c7899b9d39
commit dc973fb67e
21 changed files with 494 additions and 170 deletions

View File

@ -167,6 +167,7 @@ endif()
include(CompileBoost)
include(GetMsgpack)
add_subdirectory(contrib)
add_subdirectory(flow)
add_subdirectory(fdbrpc)
add_subdirectory(fdbclient)
@ -178,7 +179,6 @@ else()
add_subdirectory(fdbservice)
endif()
add_subdirectory(fdbbackup)
add_subdirectory(contrib)
add_subdirectory(tests)
add_subdirectory(flowbench EXCLUDE_FROM_ALL)
if(WITH_PYTHON AND WITH_C_BINDING)

View File

@ -561,7 +561,7 @@ ACTOR Future<Void> applyMutations(Database cx,
Key removePrefix,
Version beginVersion,
Version* endVersion,
RequestStream<CommitTransactionRequest> commit,
RequestStream<CommitTransactionRequest, true> commit,
NotifiedVersion* committedVersion,
Reference<KeyRangeMap<Version>> keyVersion);
ACTOR Future<Void> cleanupBackup(Database cx, DeleteData deleteData);

View File

@ -598,7 +598,7 @@ ACTOR Future<int> dumpData(Database cx,
Key uid,
Key addPrefix,
Key removePrefix,
RequestStream<CommitTransactionRequest> commit,
RequestStream<CommitTransactionRequest, true> commit,
NotifiedVersion* committedVersion,
Optional<Version> endVersion,
Key rangeBegin,
@ -675,7 +675,7 @@ ACTOR Future<int> dumpData(Database cx,
ACTOR Future<Void> coalesceKeyVersionCache(Key uid,
Version endVersion,
Reference<KeyRangeMap<Version>> keyVersion,
RequestStream<CommitTransactionRequest> commit,
RequestStream<CommitTransactionRequest, true> commit,
NotifiedVersion* committedVersion,
PromiseStream<Future<Void>> addActor,
FlowLock* commitLock) {
@ -725,7 +725,7 @@ ACTOR Future<Void> applyMutations(Database cx,
Key removePrefix,
Version beginVersion,
Version* endVersion,
RequestStream<CommitTransactionRequest> commit,
RequestStream<CommitTransactionRequest, true> commit,
NotifiedVersion* committedVersion,
Reference<KeyRangeMap<Version>> keyVersion) {
state FlowLock commitLock(CLIENT_KNOBS->BACKUP_LOCK_BYTES);

View File

@ -71,9 +71,9 @@ struct CommitProxyInterface {
serializer(ar, processId, provisional, commit);
if (Archive::isDeserializing) {
getConsistentReadVersion =
RequestStream<struct GetReadVersionRequest>(commit.getEndpoint().getAdjustedEndpoint(1));
RequestStream<struct GetReadVersionRequest, true>(commit.getEndpoint().getAdjustedEndpoint(1));
getKeyServersLocations =
RequestStream<struct GetKeyServerLocationsRequest>(commit.getEndpoint().getAdjustedEndpoint(2));
RequestStream<struct GetKeyServerLocationsRequest, true>(commit.getEndpoint().getAdjustedEndpoint(2));
getStorageServerRejoinInfo =
RequestStream<struct GetStorageServerRejoinInfoRequest>(commit.getEndpoint().getAdjustedEndpoint(3));
waitFailure = RequestStream<ReplyPromise<Void>>(commit.getEndpoint().getAdjustedEndpoint(4));

View File

@ -100,11 +100,11 @@ namespace {
TransactionLineageCollector transactionLineageCollector;
NameLineageCollector nameLineageCollector;
template <class Interface, class Request>
template <class Interface, class Request, bool P>
Future<REPLY_TYPE(Request)> loadBalance(
DatabaseContext* ctx,
const Reference<LocationInfo> alternatives,
RequestStream<Request> Interface::*channel,
RequestStream<Request, P> Interface::*channel,
const Request& request = Request(),
TaskPriority taskID = TaskPriority::DefaultPromiseEndpoint,
AtMostOnce atMostOnce =
@ -2223,7 +2223,7 @@ void stopNetwork() {
if (!g_network)
throw network_not_setup();
TraceEvent("ClientStopNetwork");
TraceEvent("ClientStopNetwork").log();
g_network->stop();
closeTraceFile();
}
@ -3176,7 +3176,7 @@ void transformRangeLimits(GetRangeLimits limits, Reverse reverse, GetKeyValuesFa
}
template <class GetKeyValuesFamilyRequest>
RequestStream<GetKeyValuesFamilyRequest> StorageServerInterface::*getRangeRequestStream() {
RequestStream<GetKeyValuesFamilyRequest, true> StorageServerInterface::*getRangeRequestStream() {
if constexpr (std::is_same<GetKeyValuesFamilyRequest, GetKeyValuesRequest>::value) {
return &StorageServerInterface::getKeyValues;
} else if (std::is_same<GetKeyValuesFamilyRequest, GetKeyValuesAndFlatMapRequest>::value) {
@ -3908,9 +3908,9 @@ static Future<Void> tssStreamComparison(Request request,
// Currently only used for GetKeyValuesStream but could easily be plugged for other stream types
// User of the stream has to forward the SS's responses to the returned promise stream, if it is set
template <class Request>
template <class Request, bool P>
Optional<TSSDuplicateStreamData<REPLYSTREAM_TYPE(Request)>>
maybeDuplicateTSSStreamFragment(Request& req, QueueModel* model, RequestStream<Request> const* ssStream) {
maybeDuplicateTSSStreamFragment(Request& req, QueueModel* model, RequestStream<Request, P> const* ssStream) {
if (model) {
Optional<TSSEndpointData> tssData = model->getTssData(ssStream->getEndpoint().token.first());

View File

@ -68,11 +68,11 @@ struct StorageServerInterface {
RequestStream<struct GetKeyValuesRequest, true> getKeyValues;
RequestStream<struct GetKeyValuesAndFlatMapRequest, true> getKeyValuesAndFlatMap;
RequestStream<struct GetShardStateRequest, true> getShardState;
RequestStream<struct GetShardStateRequest> getShardState;
RequestStream<struct WaitMetricsRequest> waitMetrics;
RequestStream<struct SplitMetricsRequest> splitMetrics;
RequestStream<struct GetStorageMetricsRequest> getStorageMetrics;
RequestStream<ReplyPromise<Void>, true> waitFailure;
RequestStream<ReplyPromise<Void>> waitFailure;
RequestStream<struct StorageQueuingMetricsRequest> getQueuingMetrics;
RequestStream<ReplyPromise<KeyValueStoreType>> getKeyValueStoreType;
@ -106,8 +106,8 @@ struct StorageServerInterface {
serializer(ar, uniqueID, locality, getValue);
}
if (Ar::isDeserializing) {
getKey = RequestStream<struct GetKeyRequest>(getValue.getEndpoint().getAdjustedEndpoint(1));
getKeyValues = RequestStream<struct GetKeyValuesRequest>(getValue.getEndpoint().getAdjustedEndpoint(2));
getKey = RequestStream<struct GetKeyRequest, true>(getValue.getEndpoint().getAdjustedEndpoint(1));
getKeyValues = RequestStream<struct GetKeyValuesRequest, true>(getValue.getEndpoint().getAdjustedEndpoint(2));
getShardState =
RequestStream<struct GetShardStateRequest>(getValue.getEndpoint().getAdjustedEndpoint(3));
waitMetrics = RequestStream<struct WaitMetricsRequest>(getValue.getEndpoint().getAdjustedEndpoint(4));
@ -119,22 +119,22 @@ struct StorageServerInterface {
RequestStream<struct StorageQueuingMetricsRequest>(getValue.getEndpoint().getAdjustedEndpoint(8));
getKeyValueStoreType =
RequestStream<ReplyPromise<KeyValueStoreType>>(getValue.getEndpoint().getAdjustedEndpoint(9));
watchValue = RequestStream<struct WatchValueRequest>(getValue.getEndpoint().getAdjustedEndpoint(10));
watchValue = RequestStream<struct WatchValueRequest, true>(getValue.getEndpoint().getAdjustedEndpoint(10));
getReadHotRanges =
RequestStream<struct ReadHotSubRangeRequest>(getValue.getEndpoint().getAdjustedEndpoint(11));
getRangeSplitPoints =
RequestStream<struct SplitRangeRequest>(getValue.getEndpoint().getAdjustedEndpoint(12));
getKeyValuesStream =
RequestStream<struct GetKeyValuesStreamRequest>(getValue.getEndpoint().getAdjustedEndpoint(13));
RequestStream<struct GetKeyValuesStreamRequest, true>(getValue.getEndpoint().getAdjustedEndpoint(13));
getKeyValuesAndFlatMap =
RequestStream<struct GetKeyValuesAndFlatMapRequest>(getValue.getEndpoint().getAdjustedEndpoint(14));
RequestStream<struct GetKeyValuesAndFlatMapRequest, true>(getValue.getEndpoint().getAdjustedEndpoint(14));
changeFeedStream =
RequestStream<struct ChangeFeedStreamRequest>(getValue.getEndpoint().getAdjustedEndpoint(15));
RequestStream<struct ChangeFeedStreamRequest, true>(getValue.getEndpoint().getAdjustedEndpoint(15));
overlappingChangeFeeds =
RequestStream<struct OverlappingChangeFeedsRequest>(getValue.getEndpoint().getAdjustedEndpoint(16));
RequestStream<struct OverlappingChangeFeedsRequest, true>(getValue.getEndpoint().getAdjustedEndpoint(16));
changeFeedPop =
RequestStream<struct ChangeFeedPopRequest>(getValue.getEndpoint().getAdjustedEndpoint(17));
changeFeedVersionUpdate = RequestStream<struct ChangeFeedVersionUpdateRequest>(
RequestStream<struct ChangeFeedPopRequest, true>(getValue.getEndpoint().getAdjustedEndpoint(17));
changeFeedVersionUpdate = RequestStream<struct ChangeFeedVersionUpdateRequest, true>(
getValue.getEndpoint().getAdjustedEndpoint(18));
}
} else {

View File

@ -15,6 +15,7 @@ set(FDBRPC_SRCS
genericactors.actor.cpp
HealthMonitor.actor.cpp
IAsyncFile.actor.cpp
IPAllowList.cpp
LoadBalance.actor.cpp
LoadBalance.actor.h
Locality.cpp

View File

@ -32,6 +32,7 @@
#include "fdbrpc/FailureMonitor.h"
#include "fdbrpc/HealthMonitor.h"
#include "fdbrpc/genericactors.actor.h"
#include "fdbrpc/IPAllowList.h"
#include "fdbrpc/simulator.h"
#include "flow/ActorCollection.h"
#include "flow/Error.h"
@ -1807,7 +1808,7 @@ bool FlowTransport::incompatibleOutgoingConnectionsPresent() {
return self->numIncompatibleConnections > 0;
}
void FlowTransport::createInstance(bool isClient, uint64_t transportId, int maxWellKnownEndpoints) {
void FlowTransport::createInstance(bool isClient, uint64_t transportId, int maxWellKnownEndpoints, IPAllowList const* allowList) {
g_network->setGlobal(INetwork::enFlowTransport,
(flowGlobalType) new FlowTransport(transportId, maxWellKnownEndpoints));
g_network->setGlobal(INetwork::enNetworkAddressFunc, (flowGlobalType)&FlowTransport::getGlobalLocalAddress);

View File

@ -182,6 +182,8 @@ struct Peer : public ReferenceCounted<Peer> {
void onIncomingConnection(Reference<Peer> self, Reference<IConnection> conn, Future<Void> reader);
};
class IPAllowList;
class FlowTransport {
public:
FlowTransport(uint64_t transportId, int maxWellKnownEndpoints);
@ -189,7 +191,7 @@ public:
// Creates a new FlowTransport and makes FlowTransport::transport() return it. This uses g_network->global()
// variables, so it will be private to a simulation.
static void createInstance(bool isClient, uint64_t transportId, int maxWellKnownEndpoints);
static void createInstance(bool isClient, uint64_t transportId, int maxWellKnownEndpoints, IPAllowList const* allowList = nullptr);
static bool isClient() { return g_network->global(INetwork::enClientFailureMonitor) != nullptr; }

290
fdbrpc/IPAllowList.cpp Normal file
View File

@ -0,0 +1,290 @@
/*
* IPAllowList.h
*
* This source file is part of the FoundationDB open source project
*
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fdbrpc/IPAllowList.h"
#include "flow/UnitTest.h"
#include <fmt/printf.h>
#include <fmt/format.h>
namespace {
template <std::size_t C>
std::string binRep(std::array<unsigned char, C> const& addr) {
return fmt::format("{:02x}", fmt::join(addr, ":"));
}
template <std::size_t C>
void printIP(std::array<unsigned char, C> const& addr) {
fmt::print(" {}", binRep(addr));
}
template <size_t Sz>
int hostCountImpl(std::array<unsigned char, Sz> const& addr) {
int count = 0;
for (int i = 0; i < addr.size() && addr[i] != 0xff; ++i) {
std::bitset<8> b(addr[i]);
count += 8 - b.count();
}
return count;
}
} // namespace
IPAddress AuthAllowedSubnet::netmask() const {
if (addressMask.isV4()) {
uint32_t res = 0xffffffff ^ addressMask.toV4();
return IPAddress(res);
} else {
std::array<unsigned char, 16> res;
res.fill(0xff);
auto mask = addressMask.toV6();
for (int i = 0; i < mask.size(); ++i) {
res[i] ^= mask[i];
}
return IPAddress(res);
}
}
int AuthAllowedSubnet::hostCount() const {
if (addressMask.isV4()) {
boost::asio::ip::address_v4 addr(addressMask.toV4());
return hostCountImpl(addr.to_bytes());
} else {
return hostCountImpl(addressMask.toV6());
}
}
AuthAllowedSubnet AuthAllowedSubnet::fromString(std::string_view addressString) {
auto pos = addressString.find('/');
if (pos == std::string_view::npos) {
fmt::print("ERROR: {} is not a valid (use Network-Prefix/hostcount syntax)\n");
throw invalid_option();
}
auto address = addressString.substr(0, pos);
auto hostCount = std::stoi(std::string(addressString.substr(pos + 1)));
auto addr = boost::asio::ip::make_address(address);
if (addr.is_v4()) {
auto bM = createBitMask(addr.to_v4().to_bytes(), hostCount);
// we typically would expect a base address has been passed, but to be safe we still
// will make the last bits 0
auto mask = boost::asio::ip::address_v4(bM).to_uint();
auto baseAddress = addr.to_v4().to_uint() & mask;
fmt::print("For address {}:", addressString);
printIP("Base Address", IPAddress(baseAddress));
printIP("Mask:", IPAddress(mask));
return AuthAllowedSubnet(IPAddress(baseAddress), IPAddress(mask));
} else {
auto mask = createBitMask(addr.to_v6().to_bytes(), hostCount);
auto baseAddress = addr.to_v6().to_bytes();
for (int i = 0; i < mask.size(); ++i) {
baseAddress[i] &= mask[i];
}
return AuthAllowedSubnet(IPAddress(baseAddress), IPAddress(mask));
}
}
void AuthAllowedSubnet::printIP(std::string_view txt, IPAddress const& address) {
fmt::print("{}:", txt);
if (address.isV4()) {
::printIP(boost::asio::ip::address_v4(address.toV4()).to_bytes());
} else {
::printIP(address.toV6());
}
fmt::print("\n");
}
// helpers for testing
namespace {
using boost::asio::ip::address_v4;
using boost::asio::ip::address_v6;
void traceAddress(TraceEvent& evt, const char* name, IPAddress address) {
evt.detail(name, address);
std::string bin;
if (address.isV4()) {
boost::asio::ip::address_v4 a(address.toV4());
bin = binRep(a.to_bytes());
} else {
bin = binRep(address.toV6());
}
evt.detail(fmt::format("{}Binary", name).c_str(), bin);
}
void subnetAssert(IPAllowList const& allowList, IPAddress addr, bool expectAllowed) {
if (allowList(addr) == expectAllowed) {
return;
}
TraceEvent evt(SevError, expectAllowed ? "ExpectedAddressToBeTrusted" : "ExpectedAddressToBeUntrusted");
traceAddress(evt, "Address", addr);
auto const& subnets = allowList.subnets();
for (int i = 0; i < subnets.size(); ++i) {
traceAddress(evt, fmt::format("SubnetBase{}", i).c_str(), subnets[i].baseAddress);
traceAddress(evt, fmt::format("SubnetMask{}", i).c_str(), subnets[i].addressMask);
}
}
IPAddress parseAddr(std::string const& str) {
auto res = IPAddress::parse(str);
ASSERT(res.present());
return res.get();
}
struct SubNetTest {
AuthAllowedSubnet subnet;
SubNetTest(AuthAllowedSubnet&& subnet)
: subnet(subnet)
{
}
template<bool V4>
static SubNetTest randomSubNetImpl() {
constexpr int width = V4 ? 4 : 16;
std::array<unsigned char, width> binAddr;
unsigned char rnd[4];
for (int i = 0; i < binAddr.size(); ++i) {
if (i % 4 == 0) {
auto tmp = deterministicRandom()->randomUInt32();
::memcpy(rnd, &tmp, 4);
}
binAddr[i] = rnd[i % 4];
}
auto hostCount = deterministicRandom()->randomInt(1, width);
std::string address;
if constexpr (V4) {
address_v4 a(binAddr);
address = a.to_string();
} else {
address_v6 a(binAddr);
address = a.to_string();
}
return SubNetTest(AuthAllowedSubnet::fromString(fmt::format("{}/{}", address, hostCount)));
}
static SubNetTest randomSubNet() {
if (deterministicRandom()->coinflip()) {
return randomSubNetImpl<true>();
} else {
return randomSubNetImpl<false>();
}
}
template<bool V4>
IPAddress intArrayToAddress(uint32_t* arr) {
if constexpr (V4) {
return IPAddress(arr[0]);
} else {
std::array<unsigned char, 16> res;
memcpy(res.data(), arr, 4);
return IPAddress(res);
}
}
template<class I>
I transformIntToSubnet(I val, I subnetMask, I baseAddress) {
return (val & subnetMask) ^ baseAddress;
}
template <bool V4>
IPAddress randomAddress(bool inSubnet) {
ASSERT(V4 == subnet.baseAddress.isV4() || !inSubnet);
constexpr int width = V4 ? 4 : 16;
for (;;) {
uint32_t rnd[width / 4];
for (int i = 0; i < width / 4; ++i) {
rnd[i] = deterministicRandom()->randomUInt32();
}
auto res = intArrayToAddress<V4>(rnd);
if (V4 != subnet.baseAddress.isV4()) {
return res;
}
if (!inSubnet) {
if (!subnet(res)) {
return res;
} else {
continue;
}
}
// first we make sure the address is in the subnet
if constexpr (V4) {
auto a = res.toV4();
auto base = subnet.baseAddress.toV4();
auto netmask = subnet.netmask().toV4();
auto validAddress = transformIntToSubnet(a, netmask, base);
res = IPAddress(validAddress);
} else {
auto a = res.toV6();
auto base = subnet.baseAddress.toV6();
auto netmask = subnet.netmask().toV6();
for (int i = 0; i < a.size(); ++i) {
a[i] = transformIntToSubnet(a[i], netmask[i], base[i]);
}
res = IPAddress(a);
}
return res;
}
}
IPAddress randomAddress(bool inSubnet) {
if (!inSubnet && deterministicRandom()->random01() < 0.1) {
// return an address of a different type
if (subnet.baseAddress.isV4()) {
return randomAddress<false>(false);
} else {
return randomAddress<true>(false);
}
}
if (subnet.addressMask.isV4()) {
return randomAddress<true>(inSubnet);
} else {
return randomAddress<true>(inSubnet);
}
}
};
} // namespace
TEST_CASE("/fdbrpc/allow_list") {
IPAllowList allowList;
allowList.addTrustedSubnet("1.0.0.0/8");
allowList.addTrustedSubnet("2.0.0.0/4");
::subnetAssert(allowList, parseAddr("1.0.1.1"), true);
::subnetAssert(allowList, parseAddr("1.1.2.2"), true);
::subnetAssert(allowList, parseAddr("2.2.1.1"), true);
::subnetAssert(allowList, parseAddr("128.0.1.1"), false);
allowList = IPAllowList();
allowList.addTrustedSubnet("0.0.0.0/2");
::subnetAssert(allowList, parseAddr("1.0.1.1"), true);
::subnetAssert(allowList, parseAddr("1.1.2.2"), true);
::subnetAssert(allowList, parseAddr("2.2.1.1"), true);
::subnetAssert(allowList, parseAddr("5.2.1.1"), true);
::subnetAssert(allowList, parseAddr("128.0.1.1"), false);
::subnetAssert(allowList, parseAddr("192.168.3.1"), false);
for (int i = 0; i < 100; ++i) {
SubNetTest subnetTest(SubNetTest::randomSubNet());
allowList = IPAllowList();
allowList.addTrustedSubnet(subnetTest.subnet);
for (int j = 0; j < 10; ++j) {
bool inSubnet = deterministicRandom()->random01() < 0.7;
auto addr = subnetTest.randomAddress(inSubnet);
::subnetAssert(allowList, addr, inSubnet);
}
}
return Void();
}

109
fdbrpc/IPAllowList.h Normal file
View File

@ -0,0 +1,109 @@
/*
* IPAllowList.h
*
* This source file is part of the FoundationDB open source project
*
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifndef FDBRPC_IP_ALLOW_LIST_H
#define FDBRPC_IP_ALLOW_LIST_H
#include "flow/network.h"
#include "flow/Arena.h"
struct AuthAllowedSubnet {
IPAddress baseAddress;
IPAddress addressMask;
AuthAllowedSubnet(IPAddress const& baseAddress, IPAddress const& addressMask)
: baseAddress(baseAddress), addressMask(addressMask) {
ASSERT(baseAddress.isV4() == addressMask.isV4());
}
static AuthAllowedSubnet fromString(std::string_view addressString);
template <std::size_t sz>
static auto createBitMask(std::array<unsigned char, sz> const& addr, int hostCount) -> std::array<unsigned char, sz> {
std::array<unsigned char, sz> res;
res.fill((unsigned char)0xff);
int idx = hostCount / 8;
if (hostCount % 8 > 0) {
// 2^(hostCount % 8) - 1 sets the last (hostCount % 8) number of bits to 1
// everything else will be zero. For example: 2^3 - 1 == 7 == 0b111
unsigned char bitmask = (1 << (hostCount % 8)) - ((unsigned char)1);
res[idx] ^= bitmask;
++idx;
}
for (; idx < res.size(); ++idx) {
res[idx] = (unsigned char)0;
}
return res;
}
bool operator() (IPAddress const& address) const {
if (addressMask.isV4() != address.isV4()) {
return false;
}
if (addressMask.isV4()) {
return (addressMask.toV4() & address.toV4()) == baseAddress.toV4();
} else {
auto res = address.toV6();
auto const& mask = addressMask.toV6();
for (int i = 0; i < res.size(); ++i) {
res[i] &= mask[i];
}
return res == baseAddress.toV6();
}
}
IPAddress netmask() const;
int hostCount() const;
// some useful helper functions if we need to debug ip masks etc
static void printIP(std::string_view txt, IPAddress const& address);
};
class IPAllowList {
std::vector<AuthAllowedSubnet> subnetList;
public:
void addTrustedSubnet(std::string_view str) {
subnetList.push_back(AuthAllowedSubnet::fromString(str));
}
void addTrustedSubnet(AuthAllowedSubnet const& subnet) {
subnetList.push_back(subnet);
}
std::vector<AuthAllowedSubnet> const& subnets() const {
return subnetList;
}
bool operator() (IPAddress address) const {
if (subnetList.empty()) {
return true;
}
for (auto const& subnet : subnetList) {
if (subnet(address)) {
return true;
}
}
return false;
}
};
#endif // FDBRPC_IP_ALLOW_LIST_H

View File

@ -78,14 +78,14 @@ struct LoadBalancedReply {
Optional<LoadBalancedReply> getLoadBalancedReply(const LoadBalancedReply* reply);
Optional<LoadBalancedReply> getLoadBalancedReply(const void*);
ACTOR template <class Req, class Resp, class Interface, class Multi>
ACTOR template <class Req, class Resp, class Interface, class Multi, bool P>
Future<Void> tssComparison(Req req,
Future<ErrorOr<Resp>> fSource,
Future<ErrorOr<Resp>> fTss,
TSSEndpointData tssData,
uint64_t srcEndpointId,
Reference<MultiInterface<Multi>> ssTeam,
RequestStream<Req> Interface::*channel) {
RequestStream<Req, P> Interface::*channel) {
state double startTime = now();
state Future<Optional<ErrorOr<Resp>>> fTssWithTimeout = timeout(fTss, FLOW_KNOBS->LOAD_BALANCE_TSS_TIMEOUT);
state int finished = 0;
@ -157,7 +157,7 @@ Future<Void> tssComparison(Req req,
state std::vector<Future<ErrorOr<Resp>>> restOfTeamFutures;
restOfTeamFutures.reserve(ssTeam->size() - 1);
for (int i = 0; i < ssTeam->size(); i++) {
RequestStream<Req> const* si = &ssTeam->get(i, channel);
RequestStream<Req, P> const* si = &ssTeam->get(i, channel);
if (si->getEndpoint().token.first() !=
srcEndpointId) { // don't re-request to SS we already have a response from
resetReply(req);
@ -242,7 +242,7 @@ FDB_DECLARE_BOOLEAN_PARAM(AtMostOnce);
FDB_DECLARE_BOOLEAN_PARAM(TriedAllOptions);
// Stores state for a request made by the load balancer
template <class Request, class Interface, class Multi>
template <class Request, class Interface, class Multi, bool P>
struct RequestData : NonCopyable {
typedef ErrorOr<REPLY_TYPE(Request)> Reply;
@ -257,12 +257,12 @@ struct RequestData : NonCopyable {
// This is true once setupRequest is called, even though at that point the response is Never().
bool isValid() { return response.isValid(); }
static void maybeDuplicateTSSRequest(RequestStream<Request> const* stream,
static void maybeDuplicateTSSRequest(RequestStream<Request, P> const* stream,
Request& request,
QueueModel* model,
Future<Reply> ssResponse,
Reference<MultiInterface<Multi>> alternatives,
RequestStream<Request> Interface::*channel) {
RequestStream<Request, P> Interface::*channel) {
if (model) {
// Send parallel request to TSS pair, if it exists
Optional<TSSEndpointData> tssData = model->getTssData(stream->getEndpoint().token.first());
@ -271,7 +271,7 @@ struct RequestData : NonCopyable {
TEST(true); // duplicating request to TSS
resetReply(request);
// FIXME: optimize to avoid creating new netNotifiedQueue for each message
RequestStream<Request> tssRequestStream(tssData.get().endpoint);
RequestStream<Request, P> tssRequestStream(tssData.get().endpoint);
Future<ErrorOr<REPLY_TYPE(Request)>> fTssResult = tssRequestStream.tryGetReply(request);
model->addActor.send(tssComparison(request,
ssResponse,
@ -288,11 +288,11 @@ struct RequestData : NonCopyable {
void startRequest(
double backoff,
TriedAllOptions triedAllOptions,
RequestStream<Request> const* stream,
RequestStream<Request, P> const* stream,
Request& request,
QueueModel* model,
Reference<MultiInterface<Multi>> alternatives, // alternatives and channel passed through for TSS check
RequestStream<Request> Interface::*channel) {
RequestStream<Request, P> Interface::*channel) {
modelHolder = Reference<ModelHolder>();
requestStarted = false;
@ -438,18 +438,18 @@ struct RequestData : NonCopyable {
// list of servers.
// When model is set, load balance among alternatives in the same DC aims to balance request queue length on these
// interfaces. If too many interfaces in the same DC are bad, try remote interfaces.
ACTOR template <class Interface, class Request, class Multi>
ACTOR template <class Interface, class Request, class Multi, bool P>
Future<REPLY_TYPE(Request)> loadBalance(
Reference<MultiInterface<Multi>> alternatives,
RequestStream<Request> Interface::*channel,
RequestStream<Request, P> Interface::*channel,
Request request = Request(),
TaskPriority taskID = TaskPriority::DefaultPromiseEndpoint,
AtMostOnce atMostOnce =
AtMostOnce::False, // if true, throws request_maybe_delivered() instead of retrying automatically
QueueModel* model = nullptr) {
state RequestData<Request, Interface, Multi> firstRequestData;
state RequestData<Request, Interface, Multi> secondRequestData;
state RequestData<Request, Interface, Multi, P> firstRequestData;
state RequestData<Request, Interface, Multi, P> secondRequestData;
state Optional<uint64_t> firstRequestEndpoint;
state Future<Void> secondDelay = Never();
@ -488,7 +488,7 @@ Future<REPLY_TYPE(Request)> loadBalance(
break;
}
RequestStream<Request> const* thisStream = &alternatives->get(i, channel);
RequestStream<Request, P> const* thisStream = &alternatives->get(i, channel);
if (!IFailureMonitor::failureMonitor().getState(thisStream->getEndpoint()).failed) {
auto& qd = model->getMeasurement(thisStream->getEndpoint().token.first());
if (now() > qd.failedUntil) {
@ -527,7 +527,7 @@ Future<REPLY_TYPE(Request)> loadBalance(
// go through all the remote servers again, since we may have
// skipped it.
for (int i = alternatives->countBest(); i < alternatives->size(); i++) {
RequestStream<Request> const* thisStream = &alternatives->get(i, channel);
RequestStream<Request, P> const* thisStream = &alternatives->get(i, channel);
if (!IFailureMonitor::failureMonitor().getState(thisStream->getEndpoint()).failed) {
auto& qd = model->getMeasurement(thisStream->getEndpoint().token.first());
if (now() > qd.failedUntil) {
@ -574,7 +574,7 @@ Future<REPLY_TYPE(Request)> loadBalance(
if (ev.isEnabled()) {
ev.log();
for (int alternativeNum = 0; alternativeNum < alternatives->size(); alternativeNum++) {
RequestStream<Request> const* thisStream = &alternatives->get(alternativeNum, channel);
RequestStream<Request, P> const* thisStream = &alternatives->get(alternativeNum, channel);
TraceEvent(SevWarn, "LoadBalanceTooLongEndpoint")
.detail("Addr", thisStream->getEndpoint().getPrimaryAddress())
.detail("Token", thisStream->getEndpoint().token)
@ -586,7 +586,7 @@ Future<REPLY_TYPE(Request)> loadBalance(
// Find an alternative, if any, that is not failed, starting with
// nextAlt. This logic matters only if model == nullptr. Otherwise, the
// bestAlt and nextAlt have been decided.
state RequestStream<Request> const* stream = nullptr;
state RequestStream<Request, P> const* stream = nullptr;
for (int alternativeNum = 0; alternativeNum < alternatives->size(); alternativeNum++) {
int useAlt = nextAlt;
if (nextAlt == startAlt)
@ -724,9 +724,9 @@ Optional<BasicLoadBalancedReply> getBasicLoadBalancedReply(const BasicLoadBalanc
Optional<BasicLoadBalancedReply> getBasicLoadBalancedReply(const void*);
// A simpler version of LoadBalance that does not send second requests where the list of servers are always fresh
ACTOR template <class Interface, class Request, class Multi>
ACTOR template <class Interface, class Request, class Multi, bool P>
Future<REPLY_TYPE(Request)> basicLoadBalance(Reference<ModelInterface<Multi>> alternatives,
RequestStream<Request> Interface::*channel,
RequestStream<Request, P> Interface::*channel,
Request request = Request(),
TaskPriority taskID = TaskPriority::DefaultPromiseEndpoint,
AtMostOnce atMostOnce = AtMostOnce::False) {
@ -749,7 +749,7 @@ Future<REPLY_TYPE(Request)> basicLoadBalance(Reference<ModelInterface<Multi>> al
state int useAlt;
loop {
// Find an alternative, if any, that is not failed, starting with nextAlt
state RequestStream<Request> const* stream = nullptr;
state RequestStream<Request, P> const* stream = nullptr;
for (int alternativeNum = 0; alternativeNum < alternatives->size(); alternativeNum++) {
useAlt = nextAlt;
if (nextAlt == startAlt)

View File

@ -825,8 +825,8 @@ public:
queue->makeWellKnownEndpoint(Endpoint::Token(-1, wlTokenID), taskID);
}
bool operator==(const RequestStream<T>& rhs) const { return queue == rhs.queue; }
bool operator!=(const RequestStream<T>& rhs) const { return !(*this == rhs); }
bool operator==(const RequestStream<T, IsPublic>& rhs) const { return queue == rhs.queue; }
bool operator!=(const RequestStream<T, IsPublic>& rhs) const { return !(*this == rhs); }
bool isEmpty() const { return !queue->isReady(); }
uint32_t size() const { return queue->size(); }
@ -838,29 +838,29 @@ private:
NetNotifiedQueue<T, IsPublic>* queue;
};
template <class Ar, class T>
void save(Ar& ar, const RequestStream<T>& value) {
template <class Ar, class T, bool P>
void save(Ar& ar, const RequestStream<T, P>& value) {
auto const& ep = value.getEndpoint();
ar << ep;
UNSTOPPABLE_ASSERT(
ep.getPrimaryAddress().isValid()); // No serializing PromiseStreams on a client with no public address
}
template <class Ar, class T>
void load(Ar& ar, RequestStream<T>& value) {
template <class Ar, class T, bool P>
void load(Ar& ar, RequestStream<T, P>& value) {
Endpoint endpoint;
ar >> endpoint;
value = RequestStream<T>(endpoint);
value = RequestStream<T, P>(endpoint);
}
template <class T>
struct serializable_traits<RequestStream<T>> : std::true_type {
template <class T, bool P>
struct serializable_traits<RequestStream<T, P>> : std::true_type {
template <class Archiver>
static void serialize(Archiver& ar, RequestStream<T>& stream) {
static void serialize(Archiver& ar, RequestStream<T, P>& stream) {
if constexpr (Archiver::isDeserializing) {
Endpoint endpoint;
serializer(ar, endpoint);
stream = RequestStream<T>(endpoint);
stream = RequestStream<T, P>(endpoint);
} else {
const auto& ep = stream.getEndpoint();
serializer(ar, ep);

View File

@ -30,8 +30,8 @@
#include "fdbrpc/fdbrpc.h"
#include "flow/actorcompiler.h" // This must be the last #include.
ACTOR template <class Req>
Future<REPLY_TYPE(Req)> retryBrokenPromise(RequestStream<Req> to, Req request) {
ACTOR template <class Req, bool P>
Future<REPLY_TYPE(Req)> retryBrokenPromise(RequestStream<Req, P> to, Req request) {
// Like to.getReply(request), except that a broken_promise exception results in retrying request immediately.
// Suitable for use with well known endpoints, which are likely to return to existence after the other process
// restarts. Not normally useful for ordinary endpoints, which conventionally are permanently destroyed after
@ -50,8 +50,8 @@ Future<REPLY_TYPE(Req)> retryBrokenPromise(RequestStream<Req> to, Req request) {
}
}
ACTOR template <class Req>
Future<REPLY_TYPE(Req)> retryBrokenPromise(RequestStream<Req> to, Req request, TaskPriority taskID) {
ACTOR template <class Req, bool P>
Future<REPLY_TYPE(Req)> retryBrokenPromise(RequestStream<Req, P> to, Req request, TaskPriority taskID) {
// Like to.getReply(request), except that a broken_promise exception results in retrying request immediately.
// Suitable for use with well known endpoints, which are likely to return to existence after the other process
// restarts. Not normally useful for ordinary endpoints, which conventionally are permanently destroyed after

View File

@ -107,7 +107,7 @@ private:
KeyRangeMap<ServerCacheInfo>* keyInfo = nullptr;
KeyRangeMap<bool>* cacheInfo = nullptr;
std::map<Key, ApplyMutationsData>* uid_applyMutationsData = nullptr;
RequestStream<CommitTransactionRequest> commit = RequestStream<CommitTransactionRequest>();
RequestStream<CommitTransactionRequest, true> commit = RequestStream<CommitTransactionRequest, true>();
Database cx = Database();
NotifiedVersion* commitVersion = nullptr;
std::map<UID, Reference<StorageInfo>>* storageCache = nullptr;

View File

@ -2919,7 +2919,7 @@ TEST_CASE("/fdbserver/clustercontroller/shouldTriggerRecoveryDueToDegradedServer
testDbInfo.logSystemConfig.tLogs.push_back(remoteTLogSet);
GrvProxyInterface proxyInterf;
proxyInterf.getConsistentReadVersion = RequestStream<struct GetReadVersionRequest>(Endpoint({ proxy }, testUID));
proxyInterf.getConsistentReadVersion = RequestStream<struct GetReadVersionRequest, true>(Endpoint({ proxy }, testUID));
testDbInfo.client.grvProxies.push_back(proxyInterf);
ResolverInterface resolverInterf;
@ -3028,11 +3028,11 @@ TEST_CASE("/fdbserver/clustercontroller/shouldTriggerFailoverDueToDegradedServer
testDbInfo.logSystemConfig.tLogs.push_back(remoteTLogSet);
GrvProxyInterface grvProxyInterf;
grvProxyInterf.getConsistentReadVersion = RequestStream<struct GetReadVersionRequest>(Endpoint({ proxy }, testUID));
grvProxyInterf.getConsistentReadVersion = RequestStream<struct GetReadVersionRequest, true>(Endpoint({ proxy }, testUID));
testDbInfo.client.grvProxies.push_back(grvProxyInterf);
CommitProxyInterface commitProxyInterf;
commitProxyInterf.commit = RequestStream<struct CommitTransactionRequest>(Endpoint({ proxy2 }, testUID));
commitProxyInterf.commit = RequestStream<struct CommitTransactionRequest, true>(Endpoint({ proxy2 }, testUID));
testDbInfo.client.commitProxies.push_back(commitProxyInterf);
ResolverInterface resolverInterf;

View File

@ -229,7 +229,7 @@ struct GrvProxyData {
GrvProxyStats stats;
MasterInterface master;
RequestStream<GetReadVersionRequest> getConsistentReadVersion;
RequestStream<GetReadVersionRequest, true> getConsistentReadVersion;
Reference<ILogSystem> logSystem;
Database cx;
@ -261,7 +261,7 @@ struct GrvProxyData {
GrvProxyData(UID dbgid,
MasterInterface master,
RequestStream<GetReadVersionRequest> getConsistentReadVersion,
RequestStream<GetReadVersionRequest, true> getConsistentReadVersion,
Reference<AsyncVar<ServerDBInfo> const> db)
: dbgid(dbgid), stats(dbgid), master(master), getConsistentReadVersion(getConsistentReadVersion),
cx(openDBOnServer(db, TaskPriority::DefaultEndpoint, LockAware::True)), db(db), lastStartCommit(0),

View File

@ -28,6 +28,9 @@
#include "fdbclient/FDBTypes.h"
#include "fdbrpc/Stats.h"
#include "fdbserver/Knobs.h"
#include "fdbserver/LogSystem.h"
#include "fdbserver/MasterInterface.h"
#include "fdbserver/ResolverInterface.h"
#include "fdbserver/LogSystemDiskQueueAdapter.h"
#include "flow/IRandom.h"
@ -189,8 +192,8 @@ struct ProxyCommitData {
NotifiedVersion latestLocalCommitBatchResolving;
NotifiedVersion latestLocalCommitBatchLogging;
RequestStream<GetReadVersionRequest> getConsistentReadVersion;
RequestStream<CommitTransactionRequest> commit;
RequestStream<GetReadVersionRequest, true> getConsistentReadVersion;
RequestStream<CommitTransactionRequest, true> commit;
Database cx;
Reference<AsyncVar<ServerDBInfo> const> db;
EventMetricHandle<SingleKeyMutation> singleKeyMutationEvent;
@ -267,9 +270,9 @@ struct ProxyCommitData {
ProxyCommitData(UID dbgid,
MasterInterface master,
RequestStream<GetReadVersionRequest> getConsistentReadVersion,
RequestStream<GetReadVersionRequest, true> getConsistentReadVersion,
Version recoveryTransactionVersion,
RequestStream<CommitTransactionRequest> commit,
RequestStream<CommitTransactionRequest, true> commit,
Reference<AsyncVar<ServerDBInfo> const> db,
bool firstProxy)
: dbgid(dbgid), commitBatchesMemBytesCount(0),

View File

@ -35,6 +35,8 @@
#include <boost/algorithm/string.hpp>
#include <boost/interprocess/managed_shared_memory.hpp>
#include <fmt/printf.h>
#include "fdbclient/ActorLineageProfiler.h"
#include "fdbclient/ClusterConnectionFile.h"
#include "fdbclient/IKnobCollection.h"
@ -45,6 +47,7 @@
#include "fdbclient/WellKnownEndpoints.h"
#include "fdbclient/SimpleIni.h"
#include "fdbrpc/AsyncFileCached.actor.h"
#include "fdbrpc/IPAllowList.h"
#include "fdbrpc/Net2FileSystem.h"
#include "fdbrpc/PerfMetric.h"
#include "fdbrpc/simulator.h"
@ -1017,6 +1020,7 @@ struct CLIOptions {
std::map<std::string, std::string> profilerConfig;
bool printSimTime = false;
IPAllowList allowList;
static CLIOptions parseArgs(int argc, char* argv[]) {
CLIOptions opts;
@ -1120,6 +1124,15 @@ private:
localities.set(key, Standalone<StringRef>(std::string(args.OptionArg())));
break;
}
case OPT_IP_TRUSTED_MASK: {
Optional<std::string> subnetKey = extractPrefixedArgument("--trusted-subnet", args.OptionSyntax());
if (!subnetKey.present()) {
fprintf(stderr, "ERROR: unable to parse locality key '%s'\n", args.OptionSyntax());
flushAndExit(FDB_EXIT_ERROR);
}
allowList.addTrustedSubnet(args.OptionArg());
break;
}
case OPT_VERSION:
printVersion();
flushAndExit(FDB_EXIT_SUCCESS);
@ -1668,103 +1681,7 @@ private:
};
} // namespace
#include <fmt/printf.h>
struct AuthAllowedSubnet {
IPAddress baseAddress;
IPAddress addressMask;
AuthAllowedSubnet(IPAddress const& baseAddress, IPAddress const& addressMask)
: baseAddress(baseAddress), addressMask(addressMask) {
ASSERT(baseAddress.isV4() == addressMask.isV4());
}
static AuthAllowedSubnet fromString(std::string_view addressString) {
auto pos = addressString.find('/');
if (pos == std::string_view::npos) {
fmt::print("ERROR: {} is not a valid (use Network-Prefix/hostcount syntax)\n");
throw invalid_option();
}
auto address = addressString.substr(0, pos);
auto hostCount = std::stoi(std::string(addressString.substr(pos + 1)));
auto addr = boost::asio::ip::make_address(address);
if (addr.is_v4()) {
auto bM = createBitMask(addr.to_v4().to_bytes(), hostCount);
// we typically would expect a base address has been passed, but to be safe we still
// will make the last bits 0
auto mask = boost::asio::ip::address_v4(bM).to_uint();
auto baseAddress = addr.to_v4().to_uint() & mask;
return AuthAllowedSubnet(IPAddress(baseAddress), IPAddress(mask));
} else {
auto mask = createBitMask(addr.to_v6().to_bytes(), hostCount);
auto baseAddress = addr.to_v6().to_bytes();
for (int i = 0; i < mask.size(); ++i) {
baseAddress[i] &= mask[i];
}
return AuthAllowedSubnet(IPAddress(baseAddress), IPAddress(mask));
}
}
template <std::size_t sz>
static auto createBitMask(std::array<unsigned char, sz> const& addr, int hostCount) -> std::array<unsigned char, sz> {
std::array<unsigned char, sz> res;
res.fill((unsigned char)0xff);
for (auto idx = (hostCount / 8) - 1; idx < res.size(); ++idx) {
if (hostCount > 0) {
// 2^(hostCount % 8) - 1 sets the last (hostCount % 8) number of bits to 1
// everything else will be zero. For example: 2^3 - 1 == 7 == 0b111
unsigned char bitmask = (1 << (hostCount % 8)) - ((unsigned char)1);
res[idx] ^= bitmask;
} else {
res[idx] = (unsigned char)0;
}
hostCount = 0;
}
return res;
}
bool operator() (IPAddress const& address) const {
if (addressMask.isV4() != address.isV4()) {
return false;
}
if (addressMask.isV4()) {
return (addressMask.toV4() & address.toV4()) == baseAddress.toV4();
} else {
auto res = address.toV6();
auto const& mask = addressMask.toV6();
for (int i = 0; i < res.size(); ++i) {
res[i] &= mask[i];
}
return res == baseAddress.toV6();
}
}
};
template<std::size_t C>
void printIP(std::array<unsigned char, C> const& addr) {
for (auto c : addr) {
fmt::print(" {:02x}", int(c));
}
}
void printIP(std::string_view txt, IPAddress const& address) {
fmt::print("{}:", txt);
if (address.isV4()) {
printIP(boost::asio::ip::address_v4(address.toV4()).to_bytes());
} else {
printIP(address.toV6());
}
fmt::print("\n");
}
int main(int argc, char* argv[]) {
//auto allowed = AuthAllowedSubnet::fromString(argv[1]);
//printIP("Base Address", allowed.baseAddress);
//printIP("Address Mask", allowed.addressMask);
//for (int idx = 1; idx < argc; ++idx) {
// auto addr = IPAddress::parse(argv[idx]);
//}
//return 0;
// TODO: Remove later, this is just to force the statics to be initialized
// otherwise the unit test won't run
#ifdef ENABLE_SAMPLING
@ -1892,7 +1809,7 @@ int main(int argc, char* argv[]) {
} else {
g_network = newNet2(opts.tlsConfig, opts.useThreadPool, true);
g_network->addStopCallback(Net2FileSystem::stop);
FlowTransport::createInstance(false, 1, WLTOKEN_RESERVED_COUNT);
FlowTransport::createInstance(false, 1, WLTOKEN_RESERVED_COUNT, &opts.allowList);
const bool expectsPublicAddress =
(role == ServerRole::FDBD || role == ServerRole::NetworkTestServer || role == ServerRole::Restore);

View File

@ -749,14 +749,14 @@ TEST_CASE("/fdbserver/worker/addressInDbAndPrimaryDc") {
NetworkAddress grvProxyAddress(IPAddress(0x26262626), 1);
GrvProxyInterface grvProxyInterf;
grvProxyInterf.getConsistentReadVersion =
RequestStream<struct GetReadVersionRequest>(Endpoint({ grvProxyAddress }, UID(1, 2)));
RequestStream<struct GetReadVersionRequest, true>(Endpoint({ grvProxyAddress }, UID(1, 2)));
testDbInfo.client.grvProxies.push_back(grvProxyInterf);
ASSERT(addressInDbAndPrimaryDc(grvProxyAddress, makeReference<AsyncVar<ServerDBInfo>>(testDbInfo)));
NetworkAddress commitProxyAddress(IPAddress(0x37373737), 1);
CommitProxyInterface commitProxyInterf;
commitProxyInterf.commit =
RequestStream<struct CommitTransactionRequest>(Endpoint({ commitProxyAddress }, UID(1, 2)));
RequestStream<struct CommitTransactionRequest, true>(Endpoint({ commitProxyAddress }, UID(1, 2)));
testDbInfo.client.commitProxies.push_back(commitProxyInterf);
ASSERT(addressInDbAndPrimaryDc(commitProxyAddress, makeReference<AsyncVar<ServerDBInfo>>(testDbInfo)));

View File

@ -134,6 +134,7 @@ target_link_libraries(flow PUBLIC fmt::fmt)
add_flow_target(STATIC_LIBRARY NAME flow_sampling SRCS ${FLOW_SRCS})
target_link_libraries(flow_sampling PRIVATE stacktrace)
target_link_libraries(flow_sampling PUBLIC fmt::fmt)
target_compile_definitions(flow_sampling PRIVATE -DENABLE_SAMPLING)
if(WIN32)
add_dependencies(flow_sampling_actors flow_actors)