Address review coomments
This commit is contained in:
parent
cf31e14904
commit
20bf3e1599
|
@ -258,7 +258,7 @@ struct TenantAuthorizer final : NetworkMessageReceiver {
|
|||
|
||||
class TransportData {
|
||||
public:
|
||||
TransportData(uint64_t transportId, int maxWellKnownEndpoints);
|
||||
TransportData(uint64_t transportId, int maxWellKnownEndpoints, IPAllowList const* allowList);
|
||||
|
||||
~TransportData();
|
||||
|
||||
|
@ -303,6 +303,7 @@ public:
|
|||
std::map<uint64_t, double> multiVersionConnections;
|
||||
double lastIncompatibleMessage;
|
||||
uint64_t transportId;
|
||||
IPAllowList allowList;
|
||||
|
||||
Future<Void> multiVersionCleanup;
|
||||
Future<Void> pingLogger;
|
||||
|
@ -365,9 +366,10 @@ ACTOR Future<Void> pingLatencyLogger(TransportData* self) {
|
|||
}
|
||||
}
|
||||
|
||||
TransportData::TransportData(uint64_t transportId, int maxWellKnownEndpoints)
|
||||
TransportData::TransportData(uint64_t transportId, int maxWellKnownEndpoints, IPAllowList const* allowList)
|
||||
: warnAlwaysForLargePacket(true), endpoints(maxWellKnownEndpoints), endpointNotFoundReceiver(endpoints),
|
||||
pingReceiver(endpoints), numIncompatibleConnections(0), lastIncompatibleMessage(0), transportId(transportId) {
|
||||
pingReceiver(endpoints), numIncompatibleConnections(0), lastIncompatibleMessage(0), transportId(transportId),
|
||||
allowList(allowList == nullptr ? IPAllowList() : *allowList) {
|
||||
degraded = makeReference<AsyncVar<bool>>(false);
|
||||
pingLogger = pingLatencyLogger(this);
|
||||
}
|
||||
|
@ -946,6 +948,7 @@ ACTOR static void deliver(TransportData* self,
|
|||
ArenaReader reader,
|
||||
NetworkAddress peerAddress,
|
||||
AuthorizedTenants authorizedTenants,
|
||||
ContextVariableMap* cvm,
|
||||
bool inReadSocket) {
|
||||
// We want to run the task at the right priority. If the priority is higher than the current priority (which is
|
||||
// ReadSocket) we can just upgrade. Otherwise we'll context switch so that we don't block other tasks that might run
|
||||
|
@ -972,9 +975,7 @@ ACTOR static void deliver(TransportData* self,
|
|||
StringRef data = reader.arenaReadAll();
|
||||
ASSERT(data.size() > 8);
|
||||
ArenaObjectReader objReader(reader.arena(), reader.arenaReadAll(), AssumeVersion(reader.protocolVersion()));
|
||||
bool didInsert = objReader.variable<AuthorizedTenants>("AuthorizedTenants", &authorizedTenants);
|
||||
didInsert = didInsert && objReader.variable<NetworkAddress>("PeerAddress", &peerAddress);
|
||||
ASSERT(didInsert); // check that we could set both context variables
|
||||
objReader.setContextVariableMap(cvm);
|
||||
receiver->receive(objReader);
|
||||
g_currentDeliveryPeerAddress = { NetworkAddress() };
|
||||
} catch (Error& e) {
|
||||
|
@ -1012,7 +1013,8 @@ static void scanPackets(TransportData* transport,
|
|||
const uint8_t* e,
|
||||
Arena& arena,
|
||||
NetworkAddress const& peerAddress,
|
||||
AuthorizedTenants authorizedTenants,
|
||||
AuthorizedTenants const& authorizedTenants,
|
||||
ContextVariableMap* cvm,
|
||||
ProtocolVersion peerProtocolVersion) {
|
||||
// Find each complete packet in the given byte range and queue a ready task to deliver it.
|
||||
// Remove the complete packets from the range by increasing unprocessed_begin.
|
||||
|
@ -1132,6 +1134,7 @@ static void scanPackets(TransportData* transport,
|
|||
std::move(reader),
|
||||
peerAddress,
|
||||
authorizedTenants,
|
||||
cvm,
|
||||
true);
|
||||
}
|
||||
|
||||
|
@ -1179,6 +1182,10 @@ ACTOR static Future<Void> connectionReader(TransportData* transport,
|
|||
state NetworkAddress peerAddress;
|
||||
state ProtocolVersion peerProtocolVersion;
|
||||
state AuthorizedTenants authorizedTenants;
|
||||
authorizedTenants.trusted = transport->allowList(conn->getPeerAddress().ip);
|
||||
ContextVariableMap cvm;
|
||||
cvm["AuthorizedTenants"] = &authorizedTenants;
|
||||
cvm["PeerAddress"] = &peerAddress;
|
||||
|
||||
peerAddress = conn->getPeerAddress();
|
||||
// TODO: check whether peers ip is in trusted range
|
||||
|
@ -1336,6 +1343,7 @@ ACTOR static Future<Void> connectionReader(TransportData* transport,
|
|||
arena,
|
||||
peerAddress,
|
||||
authorizedTenants,
|
||||
&cvm,
|
||||
peerProtocolVersion);
|
||||
} else {
|
||||
unprocessed_begin = unprocessed_end;
|
||||
|
@ -1473,8 +1481,8 @@ ACTOR static Future<Void> multiVersionCleanupWorker(TransportData* self) {
|
|||
}
|
||||
}
|
||||
|
||||
FlowTransport::FlowTransport(uint64_t transportId, int maxWellKnownEndpoints)
|
||||
: self(new TransportData(transportId, maxWellKnownEndpoints)) {
|
||||
FlowTransport::FlowTransport(uint64_t transportId, int maxWellKnownEndpoints, IPAllowList const* allowList)
|
||||
: self(new TransportData(transportId, maxWellKnownEndpoints, allowList)) {
|
||||
self->multiVersionCleanup = multiVersionCleanupWorker(self);
|
||||
}
|
||||
|
||||
|
@ -1594,6 +1602,7 @@ static void sendLocal(TransportData* self, ISerializeSource const& what, const E
|
|||
// SOMEDAY: Would it be better to avoid (de)serialization by doing this check in flow?
|
||||
|
||||
Standalone<StringRef> copy;
|
||||
ContextVariableMap cvm;
|
||||
ObjectWriter wr(AssumeVersion(g_network->protocolVersion()));
|
||||
what.serializeObjectWriter(wr);
|
||||
copy = wr.toStringRef();
|
||||
|
@ -1612,6 +1621,7 @@ static void sendLocal(TransportData* self, ISerializeSource const& what, const E
|
|||
ArenaReader(copy.arena(), copy, AssumeVersion(currentProtocolVersion)),
|
||||
NetworkAddress(),
|
||||
authorizedTenants,
|
||||
&cvm,
|
||||
false);
|
||||
}
|
||||
}
|
||||
|
@ -1817,7 +1827,7 @@ void FlowTransport::createInstance(bool isClient,
|
|||
int maxWellKnownEndpoints,
|
||||
IPAllowList const* allowList) {
|
||||
g_network->setGlobal(INetwork::enFlowTransport,
|
||||
(flowGlobalType) new FlowTransport(transportId, maxWellKnownEndpoints));
|
||||
(flowGlobalType) new FlowTransport(transportId, maxWellKnownEndpoints, allowList));
|
||||
g_network->setGlobal(INetwork::enNetworkAddressFunc, (flowGlobalType)&FlowTransport::getGlobalLocalAddress);
|
||||
g_network->setGlobal(INetwork::enNetworkAddressesFunc, (flowGlobalType)&FlowTransport::getGlobalLocalAddresses);
|
||||
g_network->setGlobal(INetwork::enFailureMonitor, (flowGlobalType) new SimpleFailureMonitor());
|
||||
|
|
|
@ -186,7 +186,7 @@ class IPAllowList;
|
|||
|
||||
class FlowTransport {
|
||||
public:
|
||||
FlowTransport(uint64_t transportId, int maxWellKnownEndpoints);
|
||||
FlowTransport(uint64_t transportId, int maxWellKnownEndpoints, IPAllowList const* allowList);
|
||||
~FlowTransport();
|
||||
|
||||
// Creates a new FlowTransport and makes FlowTransport::transport() return it. This uses g_network->global()
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* IPAllowList.h
|
||||
* IPAllowList.cpp
|
||||
*
|
||||
* This source file is part of the FoundationDB open source project
|
||||
*
|
||||
|
@ -37,7 +37,7 @@ void printIP(std::array<unsigned char, C> const& addr) {
|
|||
}
|
||||
|
||||
template <size_t Sz>
|
||||
int hostCountImpl(std::array<unsigned char, Sz> const& addr) {
|
||||
int netmaskWeightImpl(std::array<unsigned char, Sz> const& addr) {
|
||||
int count = 0;
|
||||
for (int i = 0; i < addr.size() && addr[i] != 0xff; ++i) {
|
||||
std::bitset<8> b(addr[i]);
|
||||
|
@ -63,26 +63,26 @@ IPAddress AuthAllowedSubnet::netmask() const {
|
|||
}
|
||||
}
|
||||
|
||||
int AuthAllowedSubnet::hostCount() const {
|
||||
int AuthAllowedSubnet::netmaskWeight() const {
|
||||
if (addressMask.isV4()) {
|
||||
boost::asio::ip::address_v4 addr(addressMask.toV4());
|
||||
return hostCountImpl(addr.to_bytes());
|
||||
return netmaskWeightImpl(addr.to_bytes());
|
||||
} else {
|
||||
return hostCountImpl(addressMask.toV6());
|
||||
return netmaskWeightImpl(addressMask.toV6());
|
||||
}
|
||||
}
|
||||
|
||||
AuthAllowedSubnet AuthAllowedSubnet::fromString(std::string_view addressString) {
|
||||
auto pos = addressString.find('/');
|
||||
if (pos == std::string_view::npos) {
|
||||
fmt::print("ERROR: {} is not a valid (use Network-Prefix/hostcount syntax)\n");
|
||||
fmt::print("ERROR: {} is not a valid (use Network-Prefix/netmaskWeight syntax)\n");
|
||||
throw invalid_option();
|
||||
}
|
||||
auto address = addressString.substr(0, pos);
|
||||
auto hostCount = std::stoi(std::string(addressString.substr(pos + 1)));
|
||||
auto netmaskWeight = std::stoi(std::string(addressString.substr(pos + 1)));
|
||||
auto addr = boost::asio::ip::make_address(address);
|
||||
if (addr.is_v4()) {
|
||||
auto bM = createBitMask(addr.to_v4().to_bytes(), hostCount);
|
||||
auto bM = createBitMask(addr.to_v4().to_bytes(), netmaskWeight);
|
||||
// we typically would expect a base address has been passed, but to be safe we still
|
||||
// will make the last bits 0
|
||||
auto mask = boost::asio::ip::address_v4(bM).to_uint();
|
||||
|
@ -92,7 +92,7 @@ AuthAllowedSubnet AuthAllowedSubnet::fromString(std::string_view addressString)
|
|||
printIP("Mask:", IPAddress(mask));
|
||||
return AuthAllowedSubnet(IPAddress(baseAddress), IPAddress(mask));
|
||||
} else {
|
||||
auto mask = createBitMask(addr.to_v6().to_bytes(), hostCount);
|
||||
auto mask = createBitMask(addr.to_v6().to_bytes(), netmaskWeight);
|
||||
auto baseAddress = addr.to_v6().to_bytes();
|
||||
for (int i = 0; i < mask.size(); ++i) {
|
||||
baseAddress[i] &= mask[i];
|
||||
|
@ -149,7 +149,8 @@ IPAddress parseAddr(std::string const& str) {
|
|||
|
||||
struct SubNetTest {
|
||||
AuthAllowedSubnet subnet;
|
||||
SubNetTest(AuthAllowedSubnet&& subnet) : subnet(subnet) {}
|
||||
SubNetTest(AuthAllowedSubnet&& subnet) : subnet(std::move(subnet)) {}
|
||||
SubNetTest(AuthAllowedSubnet const& subnet) : subnet(subnet) {}
|
||||
template <bool V4>
|
||||
static SubNetTest randomSubNetImpl() {
|
||||
constexpr int width = V4 ? 4 : 16;
|
||||
|
@ -162,7 +163,7 @@ struct SubNetTest {
|
|||
}
|
||||
binAddr[i] = rnd[i % 4];
|
||||
}
|
||||
auto hostCount = deterministicRandom()->randomInt(1, width);
|
||||
auto netmaskWeight = deterministicRandom()->randomInt(1, width);
|
||||
std::string address;
|
||||
if constexpr (V4) {
|
||||
address_v4 a(binAddr);
|
||||
|
@ -171,7 +172,7 @@ struct SubNetTest {
|
|||
address_v6 a(binAddr);
|
||||
address = a.to_string();
|
||||
}
|
||||
return SubNetTest(AuthAllowedSubnet::fromString(fmt::format("{}/{}", address, hostCount)));
|
||||
return SubNetTest(AuthAllowedSubnet::fromString(fmt::format("{}/{}", address, netmaskWeight)));
|
||||
}
|
||||
static SubNetTest randomSubNet() {
|
||||
if (deterministicRandom()->coinflip()) {
|
||||
|
@ -182,7 +183,7 @@ struct SubNetTest {
|
|||
}
|
||||
|
||||
template <bool V4>
|
||||
IPAddress intArrayToAddress(uint32_t* arr) {
|
||||
static IPAddress intArrayToAddress(uint32_t* arr) {
|
||||
if constexpr (V4) {
|
||||
return IPAddress(arr[0]);
|
||||
} else {
|
||||
|
@ -197,16 +198,22 @@ struct SubNetTest {
|
|||
return (val & subnetMask) ^ baseAddress;
|
||||
}
|
||||
|
||||
template<bool V4>
|
||||
static IPAddress randomAddress() {
|
||||
constexpr int width = V4 ? 4 : 16;
|
||||
uint32_t rnd[width / 4];
|
||||
for (int i = 0; i < width / 4; ++i) {
|
||||
rnd[i] = deterministicRandom()->randomUInt32();
|
||||
}
|
||||
return intArrayToAddress<V4>(rnd);
|
||||
}
|
||||
|
||||
template <bool V4>
|
||||
IPAddress randomAddress(bool inSubnet) {
|
||||
ASSERT(V4 == subnet.baseAddress.isV4() || !inSubnet);
|
||||
constexpr int width = V4 ? 4 : 16;
|
||||
for (;;) {
|
||||
uint32_t rnd[width / 4];
|
||||
for (int i = 0; i < width / 4; ++i) {
|
||||
rnd[i] = deterministicRandom()->randomUInt32();
|
||||
}
|
||||
auto res = intArrayToAddress<V4>(rnd);
|
||||
auto res = randomAddress<V4>();
|
||||
if (V4 != subnet.baseAddress.isV4()) {
|
||||
return res;
|
||||
}
|
||||
|
@ -272,6 +279,42 @@ TEST_CASE("/fdbrpc/allow_list") {
|
|||
::subnetAssert(allowList, parseAddr("5.2.1.1"), true);
|
||||
::subnetAssert(allowList, parseAddr("128.0.1.1"), false);
|
||||
::subnetAssert(allowList, parseAddr("192.168.3.1"), false);
|
||||
allowList = IPAllowList();
|
||||
allowList.addTrustedSubnet("0.0.0.0/0");
|
||||
SubNetTest subnetTest(allowList.subnets()[0]);
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
// All IPv4 addresses are in the allow list
|
||||
::subnetAssert(allowList, subnetTest.randomAddress<true>(), true);
|
||||
// No IPv6 addresses are in the allow list
|
||||
::subnetAssert(allowList, subnetTest.randomAddress<false>(), false);
|
||||
}
|
||||
allowList = IPAllowList();
|
||||
allowList.addTrustedSubnet("::/0");
|
||||
subnetTest = SubNetTest(allowList.subnets()[0]);
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
// All IPv6 addresses are in the allow list
|
||||
::subnetAssert(allowList, subnetTest.randomAddress<false>(), true);
|
||||
// No IPv4 addresses are ub the allow list
|
||||
::subnetAssert(allowList, subnetTest.randomAddress<true>(), false);
|
||||
}
|
||||
allowList = IPAllowList();
|
||||
IPAddress baseAddress = SubNetTest::randomAddress<true>();
|
||||
allowList.addTrustedSubnet(fmt::format("{}/32", baseAddress.toString()));
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
auto rnd = SubNetTest::randomAddress<true>();
|
||||
::subnetAssert(allowList, rnd, rnd == baseAddress);
|
||||
rnd = SubNetTest::randomAddress<false>();
|
||||
::subnetAssert(allowList, rnd, false);
|
||||
}
|
||||
allowList = IPAllowList();
|
||||
baseAddress = SubNetTest::randomAddress<false>();
|
||||
allowList.addTrustedSubnet(fmt::format("{}/128", baseAddress.toString()));
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
auto rnd = SubNetTest::randomAddress<false>();
|
||||
::subnetAssert(allowList, rnd, rnd == baseAddress);
|
||||
rnd = SubNetTest::randomAddress<true>();
|
||||
::subnetAssert(allowList, rnd, false);
|
||||
}
|
||||
for (int i = 0; i < 100; ++i) {
|
||||
SubNetTest subnetTest(SubNetTest::randomSubNet());
|
||||
allowList = IPAllowList();
|
||||
|
|
|
@ -37,15 +37,15 @@ struct AuthAllowedSubnet {
|
|||
static AuthAllowedSubnet fromString(std::string_view addressString);
|
||||
|
||||
template <std::size_t sz>
|
||||
static auto createBitMask(std::array<unsigned char, sz> const& addr, int hostCount)
|
||||
static auto createBitMask(std::array<unsigned char, sz> const& addr, int netmaskWeight)
|
||||
-> std::array<unsigned char, sz> {
|
||||
std::array<unsigned char, sz> res;
|
||||
res.fill((unsigned char)0xff);
|
||||
int idx = hostCount / 8;
|
||||
if (hostCount % 8 > 0) {
|
||||
// 2^(hostCount % 8) - 1 sets the last (hostCount % 8) number of bits to 1
|
||||
int idx = netmaskWeight / 8;
|
||||
if (netmaskWeight % 8 > 0) {
|
||||
// 2^(netmaskWeight % 8) - 1 sets the last (netmaskWeight % 8) number of bits to 1
|
||||
// everything else will be zero. For example: 2^3 - 1 == 7 == 0b111
|
||||
unsigned char bitmask = (1 << (hostCount % 8)) - ((unsigned char)1);
|
||||
unsigned char bitmask = (1 << (netmaskWeight % 8)) - ((unsigned char)1);
|
||||
res[idx] ^= bitmask;
|
||||
++idx;
|
||||
}
|
||||
|
@ -73,7 +73,7 @@ struct AuthAllowedSubnet {
|
|||
|
||||
IPAddress netmask() const;
|
||||
|
||||
int hostCount() const;
|
||||
int netmaskWeight() const;
|
||||
|
||||
// some useful helper functions if we need to debug ip masks etc
|
||||
static void printIP(std::string_view txt, IPAddress const& address);
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
/*
|
||||
* workloads.actor.h
|
||||
* ClientWorkload.actor.cpp
|
||||
*
|
||||
* This source file is part of the FoundationDB open source project
|
||||
*
|
||||
* Copyright 2013-2018 Apple Inc. and the FoundationDB project authors
|
||||
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -26,12 +26,14 @@
|
|||
|
||||
#include <unordered_map>
|
||||
|
||||
using ContextVariableMap = std::unordered_map<std::string_view, void*>;
|
||||
|
||||
template <class Ar>
|
||||
struct LoadContext {
|
||||
Ar* ar;
|
||||
std::unordered_map<std::string_view, void*> variables;
|
||||
ContextVariableMap* variables;
|
||||
|
||||
LoadContext(Ar* ar) : ar(ar) {}
|
||||
LoadContext(Ar* ar, ContextVariableMap* variables = nullptr) : ar(ar), variables(variables) {}
|
||||
Arena& arena() { return ar->arena(); }
|
||||
|
||||
ProtocolVersion protocolVersion() const { return ar->protocolVersion(); }
|
||||
|
@ -54,19 +56,20 @@ struct LoadContext {
|
|||
|
||||
template <class T>
|
||||
bool variable(std::string_view name, T* val) {
|
||||
auto p = variables.insert(std::make_pair(name, val));
|
||||
auto p = variables->insert(std::make_pair(name, val));
|
||||
return p.second;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
T& variable(std::string_view name) {
|
||||
auto res = variables.at(name);
|
||||
auto res = variables->at(name);
|
||||
return *reinterpret_cast<T*>(res);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
T const& variable(std::string_view name) const {
|
||||
return const_cast<LoadContext<Ar>*>(this)->variable<T>(name);
|
||||
auto res = variables->at(name);
|
||||
return *reinterpret_cast<T*>(res);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -96,6 +99,9 @@ public:
|
|||
_ObjectReader() : context(static_cast<ReaderImpl*>(this)) {}
|
||||
ProtocolVersion protocolVersion() const { return mProtocolVersion; }
|
||||
void setProtocolVersion(ProtocolVersion v) { mProtocolVersion = v; }
|
||||
void setContextVariableMap(ContextVariableMap* cvm) {
|
||||
context.variables = cvm;
|
||||
}
|
||||
|
||||
template <class... Items>
|
||||
void deserialize(FileIdentifier file_identifier, Items&... items) {
|
||||
|
|
Loading…
Reference in New Issue