Address review coomments

This commit is contained in:
Markus Pilman 2022-02-23 19:02:29 +01:00
parent cf31e14904
commit 20bf3e1599
6 changed files with 101 additions and 42 deletions

View File

@ -258,7 +258,7 @@ struct TenantAuthorizer final : NetworkMessageReceiver {
class TransportData {
public:
TransportData(uint64_t transportId, int maxWellKnownEndpoints);
TransportData(uint64_t transportId, int maxWellKnownEndpoints, IPAllowList const* allowList);
~TransportData();
@ -303,6 +303,7 @@ public:
std::map<uint64_t, double> multiVersionConnections;
double lastIncompatibleMessage;
uint64_t transportId;
IPAllowList allowList;
Future<Void> multiVersionCleanup;
Future<Void> pingLogger;
@ -365,9 +366,10 @@ ACTOR Future<Void> pingLatencyLogger(TransportData* self) {
}
}
TransportData::TransportData(uint64_t transportId, int maxWellKnownEndpoints)
TransportData::TransportData(uint64_t transportId, int maxWellKnownEndpoints, IPAllowList const* allowList)
: warnAlwaysForLargePacket(true), endpoints(maxWellKnownEndpoints), endpointNotFoundReceiver(endpoints),
pingReceiver(endpoints), numIncompatibleConnections(0), lastIncompatibleMessage(0), transportId(transportId) {
pingReceiver(endpoints), numIncompatibleConnections(0), lastIncompatibleMessage(0), transportId(transportId),
allowList(allowList == nullptr ? IPAllowList() : *allowList) {
degraded = makeReference<AsyncVar<bool>>(false);
pingLogger = pingLatencyLogger(this);
}
@ -946,6 +948,7 @@ ACTOR static void deliver(TransportData* self,
ArenaReader reader,
NetworkAddress peerAddress,
AuthorizedTenants authorizedTenants,
ContextVariableMap* cvm,
bool inReadSocket) {
// We want to run the task at the right priority. If the priority is higher than the current priority (which is
// ReadSocket) we can just upgrade. Otherwise we'll context switch so that we don't block other tasks that might run
@ -972,9 +975,7 @@ ACTOR static void deliver(TransportData* self,
StringRef data = reader.arenaReadAll();
ASSERT(data.size() > 8);
ArenaObjectReader objReader(reader.arena(), reader.arenaReadAll(), AssumeVersion(reader.protocolVersion()));
bool didInsert = objReader.variable<AuthorizedTenants>("AuthorizedTenants", &authorizedTenants);
didInsert = didInsert && objReader.variable<NetworkAddress>("PeerAddress", &peerAddress);
ASSERT(didInsert); // check that we could set both context variables
objReader.setContextVariableMap(cvm);
receiver->receive(objReader);
g_currentDeliveryPeerAddress = { NetworkAddress() };
} catch (Error& e) {
@ -1012,7 +1013,8 @@ static void scanPackets(TransportData* transport,
const uint8_t* e,
Arena& arena,
NetworkAddress const& peerAddress,
AuthorizedTenants authorizedTenants,
AuthorizedTenants const& authorizedTenants,
ContextVariableMap* cvm,
ProtocolVersion peerProtocolVersion) {
// Find each complete packet in the given byte range and queue a ready task to deliver it.
// Remove the complete packets from the range by increasing unprocessed_begin.
@ -1132,6 +1134,7 @@ static void scanPackets(TransportData* transport,
std::move(reader),
peerAddress,
authorizedTenants,
cvm,
true);
}
@ -1179,6 +1182,10 @@ ACTOR static Future<Void> connectionReader(TransportData* transport,
state NetworkAddress peerAddress;
state ProtocolVersion peerProtocolVersion;
state AuthorizedTenants authorizedTenants;
authorizedTenants.trusted = transport->allowList(conn->getPeerAddress().ip);
ContextVariableMap cvm;
cvm["AuthorizedTenants"] = &authorizedTenants;
cvm["PeerAddress"] = &peerAddress;
peerAddress = conn->getPeerAddress();
// TODO: check whether peers ip is in trusted range
@ -1336,6 +1343,7 @@ ACTOR static Future<Void> connectionReader(TransportData* transport,
arena,
peerAddress,
authorizedTenants,
&cvm,
peerProtocolVersion);
} else {
unprocessed_begin = unprocessed_end;
@ -1473,8 +1481,8 @@ ACTOR static Future<Void> multiVersionCleanupWorker(TransportData* self) {
}
}
FlowTransport::FlowTransport(uint64_t transportId, int maxWellKnownEndpoints)
: self(new TransportData(transportId, maxWellKnownEndpoints)) {
FlowTransport::FlowTransport(uint64_t transportId, int maxWellKnownEndpoints, IPAllowList const* allowList)
: self(new TransportData(transportId, maxWellKnownEndpoints, allowList)) {
self->multiVersionCleanup = multiVersionCleanupWorker(self);
}
@ -1594,6 +1602,7 @@ static void sendLocal(TransportData* self, ISerializeSource const& what, const E
// SOMEDAY: Would it be better to avoid (de)serialization by doing this check in flow?
Standalone<StringRef> copy;
ContextVariableMap cvm;
ObjectWriter wr(AssumeVersion(g_network->protocolVersion()));
what.serializeObjectWriter(wr);
copy = wr.toStringRef();
@ -1612,6 +1621,7 @@ static void sendLocal(TransportData* self, ISerializeSource const& what, const E
ArenaReader(copy.arena(), copy, AssumeVersion(currentProtocolVersion)),
NetworkAddress(),
authorizedTenants,
&cvm,
false);
}
}
@ -1817,7 +1827,7 @@ void FlowTransport::createInstance(bool isClient,
int maxWellKnownEndpoints,
IPAllowList const* allowList) {
g_network->setGlobal(INetwork::enFlowTransport,
(flowGlobalType) new FlowTransport(transportId, maxWellKnownEndpoints));
(flowGlobalType) new FlowTransport(transportId, maxWellKnownEndpoints, allowList));
g_network->setGlobal(INetwork::enNetworkAddressFunc, (flowGlobalType)&FlowTransport::getGlobalLocalAddress);
g_network->setGlobal(INetwork::enNetworkAddressesFunc, (flowGlobalType)&FlowTransport::getGlobalLocalAddresses);
g_network->setGlobal(INetwork::enFailureMonitor, (flowGlobalType) new SimpleFailureMonitor());

View File

@ -186,7 +186,7 @@ class IPAllowList;
class FlowTransport {
public:
FlowTransport(uint64_t transportId, int maxWellKnownEndpoints);
FlowTransport(uint64_t transportId, int maxWellKnownEndpoints, IPAllowList const* allowList);
~FlowTransport();
// Creates a new FlowTransport and makes FlowTransport::transport() return it. This uses g_network->global()

View File

@ -1,5 +1,5 @@
/*
* IPAllowList.h
* IPAllowList.cpp
*
* This source file is part of the FoundationDB open source project
*
@ -37,7 +37,7 @@ void printIP(std::array<unsigned char, C> const& addr) {
}
template <size_t Sz>
int hostCountImpl(std::array<unsigned char, Sz> const& addr) {
int netmaskWeightImpl(std::array<unsigned char, Sz> const& addr) {
int count = 0;
for (int i = 0; i < addr.size() && addr[i] != 0xff; ++i) {
std::bitset<8> b(addr[i]);
@ -63,26 +63,26 @@ IPAddress AuthAllowedSubnet::netmask() const {
}
}
int AuthAllowedSubnet::hostCount() const {
int AuthAllowedSubnet::netmaskWeight() const {
if (addressMask.isV4()) {
boost::asio::ip::address_v4 addr(addressMask.toV4());
return hostCountImpl(addr.to_bytes());
return netmaskWeightImpl(addr.to_bytes());
} else {
return hostCountImpl(addressMask.toV6());
return netmaskWeightImpl(addressMask.toV6());
}
}
AuthAllowedSubnet AuthAllowedSubnet::fromString(std::string_view addressString) {
auto pos = addressString.find('/');
if (pos == std::string_view::npos) {
fmt::print("ERROR: {} is not a valid (use Network-Prefix/hostcount syntax)\n");
fmt::print("ERROR: {} is not a valid (use Network-Prefix/netmaskWeight syntax)\n");
throw invalid_option();
}
auto address = addressString.substr(0, pos);
auto hostCount = std::stoi(std::string(addressString.substr(pos + 1)));
auto netmaskWeight = std::stoi(std::string(addressString.substr(pos + 1)));
auto addr = boost::asio::ip::make_address(address);
if (addr.is_v4()) {
auto bM = createBitMask(addr.to_v4().to_bytes(), hostCount);
auto bM = createBitMask(addr.to_v4().to_bytes(), netmaskWeight);
// we typically would expect a base address has been passed, but to be safe we still
// will make the last bits 0
auto mask = boost::asio::ip::address_v4(bM).to_uint();
@ -92,7 +92,7 @@ AuthAllowedSubnet AuthAllowedSubnet::fromString(std::string_view addressString)
printIP("Mask:", IPAddress(mask));
return AuthAllowedSubnet(IPAddress(baseAddress), IPAddress(mask));
} else {
auto mask = createBitMask(addr.to_v6().to_bytes(), hostCount);
auto mask = createBitMask(addr.to_v6().to_bytes(), netmaskWeight);
auto baseAddress = addr.to_v6().to_bytes();
for (int i = 0; i < mask.size(); ++i) {
baseAddress[i] &= mask[i];
@ -149,7 +149,8 @@ IPAddress parseAddr(std::string const& str) {
struct SubNetTest {
AuthAllowedSubnet subnet;
SubNetTest(AuthAllowedSubnet&& subnet) : subnet(subnet) {}
SubNetTest(AuthAllowedSubnet&& subnet) : subnet(std::move(subnet)) {}
SubNetTest(AuthAllowedSubnet const& subnet) : subnet(subnet) {}
template <bool V4>
static SubNetTest randomSubNetImpl() {
constexpr int width = V4 ? 4 : 16;
@ -162,7 +163,7 @@ struct SubNetTest {
}
binAddr[i] = rnd[i % 4];
}
auto hostCount = deterministicRandom()->randomInt(1, width);
auto netmaskWeight = deterministicRandom()->randomInt(1, width);
std::string address;
if constexpr (V4) {
address_v4 a(binAddr);
@ -171,7 +172,7 @@ struct SubNetTest {
address_v6 a(binAddr);
address = a.to_string();
}
return SubNetTest(AuthAllowedSubnet::fromString(fmt::format("{}/{}", address, hostCount)));
return SubNetTest(AuthAllowedSubnet::fromString(fmt::format("{}/{}", address, netmaskWeight)));
}
static SubNetTest randomSubNet() {
if (deterministicRandom()->coinflip()) {
@ -182,7 +183,7 @@ struct SubNetTest {
}
template <bool V4>
IPAddress intArrayToAddress(uint32_t* arr) {
static IPAddress intArrayToAddress(uint32_t* arr) {
if constexpr (V4) {
return IPAddress(arr[0]);
} else {
@ -197,16 +198,22 @@ struct SubNetTest {
return (val & subnetMask) ^ baseAddress;
}
template<bool V4>
static IPAddress randomAddress() {
constexpr int width = V4 ? 4 : 16;
uint32_t rnd[width / 4];
for (int i = 0; i < width / 4; ++i) {
rnd[i] = deterministicRandom()->randomUInt32();
}
return intArrayToAddress<V4>(rnd);
}
template <bool V4>
IPAddress randomAddress(bool inSubnet) {
ASSERT(V4 == subnet.baseAddress.isV4() || !inSubnet);
constexpr int width = V4 ? 4 : 16;
for (;;) {
uint32_t rnd[width / 4];
for (int i = 0; i < width / 4; ++i) {
rnd[i] = deterministicRandom()->randomUInt32();
}
auto res = intArrayToAddress<V4>(rnd);
auto res = randomAddress<V4>();
if (V4 != subnet.baseAddress.isV4()) {
return res;
}
@ -272,6 +279,42 @@ TEST_CASE("/fdbrpc/allow_list") {
::subnetAssert(allowList, parseAddr("5.2.1.1"), true);
::subnetAssert(allowList, parseAddr("128.0.1.1"), false);
::subnetAssert(allowList, parseAddr("192.168.3.1"), false);
allowList = IPAllowList();
allowList.addTrustedSubnet("0.0.0.0/0");
SubNetTest subnetTest(allowList.subnets()[0]);
for (int i = 0; i < 10; ++i) {
// All IPv4 addresses are in the allow list
::subnetAssert(allowList, subnetTest.randomAddress<true>(), true);
// No IPv6 addresses are in the allow list
::subnetAssert(allowList, subnetTest.randomAddress<false>(), false);
}
allowList = IPAllowList();
allowList.addTrustedSubnet("::/0");
subnetTest = SubNetTest(allowList.subnets()[0]);
for (int i = 0; i < 10; ++i) {
// All IPv6 addresses are in the allow list
::subnetAssert(allowList, subnetTest.randomAddress<false>(), true);
// No IPv4 addresses are ub the allow list
::subnetAssert(allowList, subnetTest.randomAddress<true>(), false);
}
allowList = IPAllowList();
IPAddress baseAddress = SubNetTest::randomAddress<true>();
allowList.addTrustedSubnet(fmt::format("{}/32", baseAddress.toString()));
for (int i = 0; i < 10; ++i) {
auto rnd = SubNetTest::randomAddress<true>();
::subnetAssert(allowList, rnd, rnd == baseAddress);
rnd = SubNetTest::randomAddress<false>();
::subnetAssert(allowList, rnd, false);
}
allowList = IPAllowList();
baseAddress = SubNetTest::randomAddress<false>();
allowList.addTrustedSubnet(fmt::format("{}/128", baseAddress.toString()));
for (int i = 0; i < 10; ++i) {
auto rnd = SubNetTest::randomAddress<false>();
::subnetAssert(allowList, rnd, rnd == baseAddress);
rnd = SubNetTest::randomAddress<true>();
::subnetAssert(allowList, rnd, false);
}
for (int i = 0; i < 100; ++i) {
SubNetTest subnetTest(SubNetTest::randomSubNet());
allowList = IPAllowList();

View File

@ -37,15 +37,15 @@ struct AuthAllowedSubnet {
static AuthAllowedSubnet fromString(std::string_view addressString);
template <std::size_t sz>
static auto createBitMask(std::array<unsigned char, sz> const& addr, int hostCount)
static auto createBitMask(std::array<unsigned char, sz> const& addr, int netmaskWeight)
-> std::array<unsigned char, sz> {
std::array<unsigned char, sz> res;
res.fill((unsigned char)0xff);
int idx = hostCount / 8;
if (hostCount % 8 > 0) {
// 2^(hostCount % 8) - 1 sets the last (hostCount % 8) number of bits to 1
int idx = netmaskWeight / 8;
if (netmaskWeight % 8 > 0) {
// 2^(netmaskWeight % 8) - 1 sets the last (netmaskWeight % 8) number of bits to 1
// everything else will be zero. For example: 2^3 - 1 == 7 == 0b111
unsigned char bitmask = (1 << (hostCount % 8)) - ((unsigned char)1);
unsigned char bitmask = (1 << (netmaskWeight % 8)) - ((unsigned char)1);
res[idx] ^= bitmask;
++idx;
}
@ -73,7 +73,7 @@ struct AuthAllowedSubnet {
IPAddress netmask() const;
int hostCount() const;
int netmaskWeight() const;
// some useful helper functions if we need to debug ip masks etc
static void printIP(std::string_view txt, IPAddress const& address);

View File

@ -1,9 +1,9 @@
/*
* workloads.actor.h
* ClientWorkload.actor.cpp
*
* This source file is part of the FoundationDB open source project
*
* Copyright 2013-2018 Apple Inc. and the FoundationDB project authors
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -26,12 +26,14 @@
#include <unordered_map>
using ContextVariableMap = std::unordered_map<std::string_view, void*>;
template <class Ar>
struct LoadContext {
Ar* ar;
std::unordered_map<std::string_view, void*> variables;
ContextVariableMap* variables;
LoadContext(Ar* ar) : ar(ar) {}
LoadContext(Ar* ar, ContextVariableMap* variables = nullptr) : ar(ar), variables(variables) {}
Arena& arena() { return ar->arena(); }
ProtocolVersion protocolVersion() const { return ar->protocolVersion(); }
@ -54,19 +56,20 @@ struct LoadContext {
template <class T>
bool variable(std::string_view name, T* val) {
auto p = variables.insert(std::make_pair(name, val));
auto p = variables->insert(std::make_pair(name, val));
return p.second;
}
template <class T>
T& variable(std::string_view name) {
auto res = variables.at(name);
auto res = variables->at(name);
return *reinterpret_cast<T*>(res);
}
template <class T>
T const& variable(std::string_view name) const {
return const_cast<LoadContext<Ar>*>(this)->variable<T>(name);
auto res = variables->at(name);
return *reinterpret_cast<T*>(res);
}
};
@ -96,6 +99,9 @@ public:
_ObjectReader() : context(static_cast<ReaderImpl*>(this)) {}
ProtocolVersion protocolVersion() const { return mProtocolVersion; }
void setProtocolVersion(ProtocolVersion v) { mProtocolVersion = v; }
void setContextVariableMap(ContextVariableMap* cvm) {
context.variables = cvm;
}
template <class... Items>
void deserialize(FileIdentifier file_identifier, Items&... items) {