Merge pull request #6401 from sfc-gh-mpilman/features/private-request-streams

Features/private request streams
This commit is contained in:
Vaidas Gasiunas 2022-04-11 18:29:06 +02:00 committed by GitHub
commit ca563466a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
42 changed files with 1453 additions and 203 deletions

View File

@ -172,6 +172,7 @@ endif()
include(CompileBoost)
include(GetMsgpack)
add_subdirectory(contrib)
add_subdirectory(flow)
add_subdirectory(fdbrpc)
add_subdirectory(fdbclient)
@ -183,7 +184,6 @@ else()
add_subdirectory(fdbservice)
endif()
add_subdirectory(fdbbackup)
add_subdirectory(contrib)
add_subdirectory(tests)
add_subdirectory(flowbench EXCLUDE_FROM_ALL)
if(WITH_PYTHON AND WITH_C_BINDING)

View File

@ -568,7 +568,7 @@ ACTOR Future<Void> applyMutations(Database cx,
Key removePrefix,
Version beginVersion,
Version* endVersion,
RequestStream<CommitTransactionRequest> commit,
PublicRequestStream<CommitTransactionRequest> commit,
NotifiedVersion* committedVersion,
Reference<KeyRangeMap<Version>> keyVersion);
ACTOR Future<Void> cleanupBackup(Database cx, DeleteData deleteData);

View File

@ -598,7 +598,7 @@ ACTOR Future<int> dumpData(Database cx,
Key uid,
Key addPrefix,
Key removePrefix,
RequestStream<CommitTransactionRequest> commit,
PublicRequestStream<CommitTransactionRequest> commit,
NotifiedVersion* committedVersion,
Optional<Version> endVersion,
Key rangeBegin,
@ -675,7 +675,7 @@ ACTOR Future<int> dumpData(Database cx,
ACTOR Future<Void> coalesceKeyVersionCache(Key uid,
Version endVersion,
Reference<KeyRangeMap<Version>> keyVersion,
RequestStream<CommitTransactionRequest> commit,
PublicRequestStream<CommitTransactionRequest> commit,
NotifiedVersion* committedVersion,
PromiseStream<Future<Void>> addActor,
FlowLock* commitLock) {
@ -725,7 +725,7 @@ ACTOR Future<Void> applyMutations(Database cx,
Key removePrefix,
Version beginVersion,
Version* endVersion,
RequestStream<CommitTransactionRequest> commit,
PublicRequestStream<CommitTransactionRequest> commit,
NotifiedVersion* committedVersion,
Reference<KeyRangeMap<Version>> keyVersion) {
state FlowLock commitLock(CLIENT_KNOBS->BACKUP_LOCK_BYTES);

View File

@ -43,13 +43,13 @@ struct CommitProxyInterface {
Optional<Key> processId;
bool provisional;
RequestStream<struct CommitTransactionRequest> commit;
RequestStream<struct GetReadVersionRequest>
PublicRequestStream<struct CommitTransactionRequest> commit;
PublicRequestStream<struct GetReadVersionRequest>
getConsistentReadVersion; // Returns a version which (1) is committed, and (2) is >= the latest version reported
// committed (by a commit response) when this request was sent
// (at some point between when this request is sent and when its response is
// received, the latest version reported committed)
RequestStream<struct GetKeyServerLocationsRequest> getKeyServersLocations;
PublicRequestStream<struct GetKeyServerLocationsRequest> getKeyServersLocations;
RequestStream<struct GetStorageServerRejoinInfoRequest> getStorageServerRejoinInfo;
RequestStream<ReplyPromise<Void>> waitFailure;
@ -72,9 +72,9 @@ struct CommitProxyInterface {
serializer(ar, processId, provisional, commit);
if (Archive::isDeserializing) {
getConsistentReadVersion =
RequestStream<struct GetReadVersionRequest>(commit.getEndpoint().getAdjustedEndpoint(1));
PublicRequestStream<struct GetReadVersionRequest>(commit.getEndpoint().getAdjustedEndpoint(1));
getKeyServersLocations =
RequestStream<struct GetKeyServerLocationsRequest>(commit.getEndpoint().getAdjustedEndpoint(2));
PublicRequestStream<struct GetKeyServerLocationsRequest>(commit.getEndpoint().getAdjustedEndpoint(2));
getStorageServerRejoinInfo =
RequestStream<struct GetStorageServerRejoinInfoRequest>(commit.getEndpoint().getAdjustedEndpoint(3));
waitFailure = RequestStream<ReplyPromise<Void>>(commit.getEndpoint().getAdjustedEndpoint(4));

View File

@ -33,8 +33,8 @@
const int MAX_CLUSTER_FILE_BYTES = 60000;
struct ClientLeaderRegInterface {
RequestStream<struct GetLeaderRequest> getLeader;
RequestStream<struct OpenDatabaseCoordRequest> openDatabase;
PublicRequestStream<struct GetLeaderRequest> getLeader;
PublicRequestStream<struct OpenDatabaseCoordRequest> openDatabase;
RequestStream<struct CheckDescriptorMutableRequest> checkDescriptorMutable;
Optional<Hostname> hostname;

View File

@ -36,7 +36,7 @@ struct GrvProxyInterface {
Optional<Key> processId;
bool provisional;
RequestStream<struct GetReadVersionRequest>
PublicRequestStream<struct GetReadVersionRequest>
getConsistentReadVersion; // Returns a version which (1) is committed, and (2) is >= the latest version reported
// committed (by a commit response) when this request was sent
// (at some point between when this request is sent and when its response is received, the latest version reported

View File

@ -105,11 +105,11 @@ namespace {
TransactionLineageCollector transactionLineageCollector;
NameLineageCollector nameLineageCollector;
template <class Interface, class Request>
template <class Interface, class Request, bool P>
Future<REPLY_TYPE(Request)> loadBalance(
DatabaseContext* ctx,
const Reference<LocationInfo> alternatives,
RequestStream<Request> Interface::*channel,
RequestStream<Request, P> Interface::*channel,
const Request& request = Request(),
TaskPriority taskID = TaskPriority::DefaultPromiseEndpoint,
AtMostOnce atMostOnce =
@ -2541,7 +2541,7 @@ void stopNetwork() {
if (!g_network)
throw network_not_setup();
TraceEvent("ClientStopNetwork");
TraceEvent("ClientStopNetwork").log();
g_network->stop();
closeTraceFile();
}
@ -3771,7 +3771,7 @@ void transformRangeLimits(GetRangeLimits limits, Reverse reverse, GetKeyValuesFa
}
template <class GetKeyValuesFamilyRequest>
RequestStream<GetKeyValuesFamilyRequest> StorageServerInterface::*getRangeRequestStream() {
PublicRequestStream<GetKeyValuesFamilyRequest> StorageServerInterface::*getRangeRequestStream() {
if constexpr (std::is_same<GetKeyValuesFamilyRequest, GetKeyValuesRequest>::value) {
return &StorageServerInterface::getKeyValues;
} else if (std::is_same<GetKeyValuesFamilyRequest, GetMappedKeyValuesRequest>::value) {
@ -4597,9 +4597,9 @@ static Future<Void> tssStreamComparison(Request request,
// Currently only used for GetKeyValuesStream but could easily be plugged for other stream types
// User of the stream has to forward the SS's responses to the returned promise stream, if it is set
template <class Request>
template <class Request, bool P>
Optional<TSSDuplicateStreamData<REPLYSTREAM_TYPE(Request)>>
maybeDuplicateTSSStreamFragment(Request& req, QueueModel* model, RequestStream<Request> const* ssStream) {
maybeDuplicateTSSStreamFragment(Request& req, QueueModel* model, RequestStream<Request, P> const* ssStream) {
if (model) {
Optional<TSSEndpointData> tssData = model->getTssData(ssStream->getEndpoint().token.first());

View File

@ -63,13 +63,13 @@ struct StorageServerInterface {
UID uniqueID;
Optional<UID> tssPairID;
RequestStream<struct GetValueRequest> getValue;
RequestStream<struct GetKeyRequest> getKey;
PublicRequestStream<struct GetValueRequest> getValue;
PublicRequestStream<struct GetKeyRequest> getKey;
// Throws a wrong_shard_server if the keys in the request or result depend on data outside this server OR if a large
// selector offset prevents all data from being read in one range read
RequestStream<struct GetKeyValuesRequest> getKeyValues;
RequestStream<struct GetMappedKeyValuesRequest> getMappedKeyValues;
PublicRequestStream<struct GetKeyValuesRequest> getKeyValues;
PublicRequestStream<struct GetMappedKeyValuesRequest> getMappedKeyValues;
RequestStream<struct GetShardStateRequest> getShardState;
RequestStream<struct WaitMetricsRequest> waitMetrics;
@ -79,17 +79,17 @@ struct StorageServerInterface {
RequestStream<struct StorageQueuingMetricsRequest> getQueuingMetrics;
RequestStream<ReplyPromise<KeyValueStoreType>> getKeyValueStoreType;
RequestStream<struct WatchValueRequest> watchValue;
PublicRequestStream<struct WatchValueRequest> watchValue;
RequestStream<struct ReadHotSubRangeRequest> getReadHotRanges;
RequestStream<struct SplitRangeRequest> getRangeSplitPoints;
RequestStream<struct GetKeyValuesStreamRequest> getKeyValuesStream;
RequestStream<struct ChangeFeedStreamRequest> changeFeedStream;
RequestStream<struct OverlappingChangeFeedsRequest> overlappingChangeFeeds;
RequestStream<struct ChangeFeedPopRequest> changeFeedPop;
RequestStream<struct ChangeFeedVersionUpdateRequest> changeFeedVersionUpdate;
RequestStream<struct GetCheckpointRequest> checkpoint;
RequestStream<struct FetchCheckpointRequest> fetchCheckpoint;
RequestStream<struct FetchCheckpointKeyValuesRequest> fetchCheckpointKeyValues;
PublicRequestStream<struct GetKeyValuesStreamRequest> getKeyValuesStream;
PublicRequestStream<struct ChangeFeedStreamRequest> changeFeedStream;
PublicRequestStream<struct OverlappingChangeFeedsRequest> overlappingChangeFeeds;
PublicRequestStream<struct ChangeFeedPopRequest> changeFeedPop;
PublicRequestStream<struct ChangeFeedVersionUpdateRequest> changeFeedVersionUpdate;
PublicRequestStream<struct GetCheckpointRequest> checkpoint;
PublicRequestStream<struct FetchCheckpointRequest> fetchCheckpoint;
PublicRequestStream<struct FetchCheckpointKeyValuesRequest> fetchCheckpointKeyValues;
private:
bool acceptingRequests;
@ -123,8 +123,9 @@ public:
serializer(ar, uniqueID, locality, getValue);
}
if (Ar::isDeserializing) {
getKey = RequestStream<struct GetKeyRequest>(getValue.getEndpoint().getAdjustedEndpoint(1));
getKeyValues = RequestStream<struct GetKeyValuesRequest>(getValue.getEndpoint().getAdjustedEndpoint(2));
getKey = PublicRequestStream<struct GetKeyRequest>(getValue.getEndpoint().getAdjustedEndpoint(1));
getKeyValues =
PublicRequestStream<struct GetKeyValuesRequest>(getValue.getEndpoint().getAdjustedEndpoint(2));
getShardState =
RequestStream<struct GetShardStateRequest>(getValue.getEndpoint().getAdjustedEndpoint(3));
waitMetrics = RequestStream<struct WaitMetricsRequest>(getValue.getEndpoint().getAdjustedEndpoint(4));
@ -136,27 +137,29 @@ public:
RequestStream<struct StorageQueuingMetricsRequest>(getValue.getEndpoint().getAdjustedEndpoint(8));
getKeyValueStoreType =
RequestStream<ReplyPromise<KeyValueStoreType>>(getValue.getEndpoint().getAdjustedEndpoint(9));
watchValue = RequestStream<struct WatchValueRequest>(getValue.getEndpoint().getAdjustedEndpoint(10));
watchValue =
PublicRequestStream<struct WatchValueRequest>(getValue.getEndpoint().getAdjustedEndpoint(10));
getReadHotRanges =
RequestStream<struct ReadHotSubRangeRequest>(getValue.getEndpoint().getAdjustedEndpoint(11));
getRangeSplitPoints =
RequestStream<struct SplitRangeRequest>(getValue.getEndpoint().getAdjustedEndpoint(12));
getKeyValuesStream =
RequestStream<struct GetKeyValuesStreamRequest>(getValue.getEndpoint().getAdjustedEndpoint(13));
getMappedKeyValues =
RequestStream<struct GetMappedKeyValuesRequest>(getValue.getEndpoint().getAdjustedEndpoint(14));
getKeyValuesStream = PublicRequestStream<struct GetKeyValuesStreamRequest>(
getValue.getEndpoint().getAdjustedEndpoint(13));
getMappedKeyValues = PublicRequestStream<struct GetMappedKeyValuesRequest>(
getValue.getEndpoint().getAdjustedEndpoint(14));
changeFeedStream =
RequestStream<struct ChangeFeedStreamRequest>(getValue.getEndpoint().getAdjustedEndpoint(15));
overlappingChangeFeeds =
RequestStream<struct OverlappingChangeFeedsRequest>(getValue.getEndpoint().getAdjustedEndpoint(16));
PublicRequestStream<struct ChangeFeedStreamRequest>(getValue.getEndpoint().getAdjustedEndpoint(15));
overlappingChangeFeeds = PublicRequestStream<struct OverlappingChangeFeedsRequest>(
getValue.getEndpoint().getAdjustedEndpoint(16));
changeFeedPop =
RequestStream<struct ChangeFeedPopRequest>(getValue.getEndpoint().getAdjustedEndpoint(17));
changeFeedVersionUpdate = RequestStream<struct ChangeFeedVersionUpdateRequest>(
PublicRequestStream<struct ChangeFeedPopRequest>(getValue.getEndpoint().getAdjustedEndpoint(17));
changeFeedVersionUpdate = PublicRequestStream<struct ChangeFeedVersionUpdateRequest>(
getValue.getEndpoint().getAdjustedEndpoint(18));
checkpoint = RequestStream<struct GetCheckpointRequest>(getValue.getEndpoint().getAdjustedEndpoint(19));
checkpoint =
PublicRequestStream<struct GetCheckpointRequest>(getValue.getEndpoint().getAdjustedEndpoint(19));
fetchCheckpoint =
RequestStream<struct FetchCheckpointRequest>(getValue.getEndpoint().getAdjustedEndpoint(20));
fetchCheckpointKeyValues = RequestStream<struct FetchCheckpointKeyValuesRequest>(
PublicRequestStream<struct FetchCheckpointRequest>(getValue.getEndpoint().getAdjustedEndpoint(20));
fetchCheckpointKeyValues = PublicRequestStream<struct FetchCheckpointKeyValuesRequest>(
getValue.getEndpoint().getAdjustedEndpoint(21));
}
} else {

View File

@ -28,28 +28,28 @@
* All well-known endpoints of FDB must be listed here to guarantee their uniqueness
*/
enum WellKnownEndpoints {
WLTOKEN_CLIENTLEADERREG_GETLEADER = WLTOKEN_FIRST_AVAILABLE, // 2
WLTOKEN_CLIENTLEADERREG_OPENDATABASE, // 3
WLTOKEN_LEADERELECTIONREG_CANDIDACY, // 4
WLTOKEN_LEADERELECTIONREG_ELECTIONRESULT, // 5
WLTOKEN_LEADERELECTIONREG_LEADERHEARTBEAT, // 6
WLTOKEN_LEADERELECTIONREG_FORWARD, // 7
WLTOKEN_GENERATIONREG_READ, // 8
WLTOKEN_GENERATIONREG_WRITE, // 9
WLTOKEN_CLIENTLEADERREG_GETLEADER = WLTOKEN_FIRST_AVAILABLE, // 4
WLTOKEN_CLIENTLEADERREG_OPENDATABASE, // 5
WLTOKEN_LEADERELECTIONREG_CANDIDACY, // 6
WLTOKEN_LEADERELECTIONREG_ELECTIONRESULT, // 7
WLTOKEN_LEADERELECTIONREG_LEADERHEARTBEAT, // 8
WLTOKEN_LEADERELECTIONREG_FORWARD, // 9
WLTOKEN_PROTOCOL_INFO, // 10 : the value of this endpoint should be stable and not change.
WLTOKEN_CLIENTLEADERREG_DESCRIPTOR_MUTABLE, // 11
WLTOKEN_CONFIGTXN_GETGENERATION, // 12
WLTOKEN_CONFIGTXN_GET, // 13
WLTOKEN_CONFIGTXN_GETCLASSES, // 14
WLTOKEN_CONFIGTXN_GETKNOBS, // 15
WLTOKEN_CONFIGTXN_COMMIT, // 16
WLTOKEN_CONFIGFOLLOWER_GETSNAPSHOTANDCHANGES, // 17
WLTOKEN_CONFIGFOLLOWER_GETCHANGES, // 18
WLTOKEN_CONFIGFOLLOWER_COMPACT, // 19
WLTOKEN_CONFIGFOLLOWER_ROLLFORWARD, // 20
WLTOKEN_CONFIGFOLLOWER_GETCOMMITTEDVERSION, // 21
WLTOKEN_PROCESS, // 22
WLTOKEN_RESERVED_COUNT // 23
WLTOKEN_GENERATIONREG_READ, // 11
WLTOKEN_GENERATIONREG_WRITE, // 12
WLTOKEN_CLIENTLEADERREG_DESCRIPTOR_MUTABLE, // 13
WLTOKEN_CONFIGTXN_GETGENERATION, // 14
WLTOKEN_CONFIGTXN_GET, // 15
WLTOKEN_CONFIGTXN_GETCLASSES, // 16
WLTOKEN_CONFIGTXN_GETKNOBS, // 17
WLTOKEN_CONFIGTXN_COMMIT, // 18
WLTOKEN_CONFIGFOLLOWER_GETSNAPSHOTANDCHANGES, // 19
WLTOKEN_CONFIGFOLLOWER_GETCHANGES, // 20
WLTOKEN_CONFIGFOLLOWER_COMPACT, // 21
WLTOKEN_CONFIGFOLLOWER_ROLLFORWARD, // 22
WLTOKEN_CONFIGFOLLOWER_GETCOMMITTEDVERSION, // 23
WLTOKEN_PROCESS, // 24
WLTOKEN_RESERVED_COUNT // 25
};
static_assert(WLTOKEN_PROTOCOL_INFO ==

View File

@ -16,6 +16,7 @@ set(FDBRPC_SRCS
genericactors.actor.cpp
HealthMonitor.actor.cpp
IAsyncFile.actor.cpp
IPAllowList.cpp
LoadBalance.actor.cpp
LoadBalance.actor.h
Locality.cpp

View File

@ -131,11 +131,20 @@ void SimpleFailureMonitor::endpointNotFound(Endpoint const& endpoint) {
TraceEvent(SevWarnAlways, "TooManyFailedEndpoints").suppressFor(1.0);
failedEndpoints.clear();
}
failedEndpoints.insert(endpoint);
failedEndpoints.emplace(endpoint, FailedReason::NOT_FOUND);
}
endpointKnownFailed.trigger(endpoint);
}
void SimpleFailureMonitor::unauthorizedEndpoint(Endpoint const& endpoint) {
TraceEvent(g_network->isSimulated() ? SevWarnAlways : SevError, "TriedAccessPrivateEndpoint")
.suppressFor(1.0)
.detail("Address", endpoint.getPrimaryAddress())
.detail("Token", endpoint.token);
failedEndpoints.emplace(endpoint, FailedReason::UNAUTHORIZED);
endpointKnownFailed.trigger(endpoint);
}
void SimpleFailureMonitor::notifyDisconnect(NetworkAddress const& address) {
//TraceEvent("NotifyDisconnect").detail("Address", address);
endpointKnownFailed.triggerRange(Endpoint({ address }, UID()), Endpoint({ address }, UID(-1, -1)));
@ -208,8 +217,13 @@ bool SimpleFailureMonitor::permanentlyFailed(Endpoint const& endpoint) const {
return failedEndpoints.count(endpoint);
}
bool SimpleFailureMonitor::knownUnauthorized(Endpoint const& endpoint) const {
auto iter = failedEndpoints.find(endpoint);
return iter != failedEndpoints.end() && iter->second == FailedReason::UNAUTHORIZED;
}
void SimpleFailureMonitor::reset() {
addressStatus = std::unordered_map<NetworkAddress, FailureStatus>();
failedEndpoints = std::unordered_set<Endpoint>();
failedEndpoints = std::unordered_map<Endpoint, FailedReason>();
endpointKnownFailed.resetNoWaiting();
}

View File

@ -93,6 +93,9 @@ public:
// Only use this function when the endpoint is known to be failed
virtual void endpointNotFound(Endpoint const&) = 0;
// Inform client that it was trying to send a message to a private endpoint
virtual void unauthorizedEndpoint(Endpoint const&) = 0;
// The next time the known status for the endpoint changes, returns the new status.
virtual Future<Void> onStateChanged(Endpoint const& endpoint) = 0;
@ -108,6 +111,9 @@ public:
// Returns true if the endpoint will never become available.
virtual bool permanentlyFailed(Endpoint const& endpoint) const = 0;
// Returns true if we known we're not allowed to send messages to the remote endpoint
virtual bool knownUnauthorized(Endpoint const&) const = 0;
// Called by FlowTransport when a connection closes and a prior request or reply might be lost
virtual void notifyDisconnect(NetworkAddress const&) = 0;
@ -139,9 +145,11 @@ public:
class SimpleFailureMonitor : public IFailureMonitor {
public:
enum class FailedReason { NOT_FOUND, UNAUTHORIZED };
SimpleFailureMonitor();
void setStatus(NetworkAddress const& address, FailureStatus const& status) override;
void endpointNotFound(Endpoint const&) override;
void unauthorizedEndpoint(Endpoint const&) override;
void notifyDisconnect(NetworkAddress const&) override;
Future<Void> onStateChanged(Endpoint const& endpoint) override;
@ -151,6 +159,7 @@ public:
Future<Void> onDisconnect(NetworkAddress const& address) override;
bool onlyEndpointFailed(Endpoint const& endpoint) const override;
bool permanentlyFailed(Endpoint const& endpoint) const override;
bool knownUnauthorized(Endpoint const&) const override;
void reset();
@ -158,7 +167,7 @@ private:
std::unordered_map<NetworkAddress, FailureStatus> addressStatus;
YieldedAsyncMap<Endpoint, bool> endpointKnownFailed;
AsyncMap<NetworkAddress, bool> disconnectTriggers;
std::unordered_set<Endpoint> failedEndpoints;
std::unordered_map<Endpoint, FailedReason> failedEndpoints;
friend class OnStateChangedActorActor;
};

View File

@ -27,10 +27,12 @@
#include <memcheck.h>
#endif
#include "fdbrpc/TenantInfo.h"
#include "fdbrpc/fdbrpc.h"
#include "fdbrpc/FailureMonitor.h"
#include "fdbrpc/HealthMonitor.h"
#include "fdbrpc/genericactors.actor.h"
#include "fdbrpc/IPAllowList.h"
#include "fdbrpc/simulator.h"
#include "flow/ActorCollection.h"
#include "flow/Error.h"
@ -50,6 +52,9 @@ static Future<Void> g_currentDeliveryPeerDisconnect;
constexpr int PACKET_LEN_WIDTH = sizeof(uint32_t);
const uint64_t TOKEN_STREAM_FLAG = 1;
FDB_BOOLEAN_PARAM(InReadSocket);
FDB_BOOLEAN_PARAM(IsStableConnection);
class EndpointMap : NonCopyable {
public:
// Reserve space for this many wellKnownEndpoints
@ -205,6 +210,7 @@ struct EndpointNotFoundReceiver final : NetworkMessageReceiver {
Endpoint e = FlowTransport::transport().loadedEndpoint(token);
IFailureMonitor::failureMonitor().endpointNotFound(e);
}
bool isPublic() const override { return true; }
};
struct PingRequest {
@ -229,11 +235,52 @@ struct PingReceiver final : NetworkMessageReceiver {
PeerCompatibilityPolicy peerCompatibilityPolicy() const override {
return PeerCompatibilityPolicy{ RequirePeer::AtLeast, ProtocolVersion::withStableInterfaces() };
}
bool isPublic() const override { return true; }
};
struct TenantAuthorizer final : NetworkMessageReceiver {
TenantAuthorizer(EndpointMap& endpoints) {
endpoints.insertWellKnown(this, Endpoint::wellKnownToken(WLTOKEN_AUTH_TENANT), TaskPriority::ReadSocket);
}
void receive(ArenaObjectReader& reader) override {
AuthorizationRequest req;
try {
reader.deserialize(req);
// TODO: verify that token is valid
AuthorizedTenants& auth = reader.variable<AuthorizedTenants>("AuthorizedTenants");
for (auto const& t : req.tenants) {
auth.authorizedTenants.insert(TenantInfoRef(auth.arena, t));
}
req.reply.send(Void());
} catch (Error& e) {
if (e.code() == error_code_permission_denied) {
req.reply.sendError(e);
} else {
throw;
}
}
}
bool isPublic() const override { return true; }
};
struct UnauthorizedEndpointReceiver final : NetworkMessageReceiver {
UnauthorizedEndpointReceiver(EndpointMap& endpoints) {
endpoints.insertWellKnown(
this, Endpoint::wellKnownToken(WLTOKEN_UNAUTHORIZED_ENDPOINT), TaskPriority::ReadSocket);
}
void receive(ArenaObjectReader& reader) override {
UID token;
reader.deserialize(token);
Endpoint e = FlowTransport::transport().loadedEndpoint(token);
IFailureMonitor::failureMonitor().unauthorizedEndpoint(e);
}
bool isPublic() const override { return true; }
};
class TransportData {
public:
TransportData(uint64_t transportId, int maxWellKnownEndpoints);
TransportData(uint64_t transportId, int maxWellKnownEndpoints, IPAllowList const* allowList);
~TransportData();
@ -264,6 +311,8 @@ public:
EndpointMap endpoints;
EndpointNotFoundReceiver endpointNotFoundReceiver{ endpoints };
PingReceiver pingReceiver{ endpoints };
TenantAuthorizer tenantReceiver{ endpoints };
UnauthorizedEndpointReceiver unauthorizedEndpointReceiver{ endpoints };
Int64MetricHandle bytesSent;
Int64MetricHandle countPacketsReceived;
@ -278,6 +327,8 @@ public:
std::map<uint64_t, double> multiVersionConnections;
double lastIncompatibleMessage;
uint64_t transportId;
IPAllowList allowList;
std::shared_ptr<ContextVariableMap> localCVM = std::make_shared<ContextVariableMap>(); // for local delivery
Future<Void> multiVersionCleanup;
Future<Void> pingLogger;
@ -340,9 +391,10 @@ ACTOR Future<Void> pingLatencyLogger(TransportData* self) {
}
}
TransportData::TransportData(uint64_t transportId, int maxWellKnownEndpoints)
TransportData::TransportData(uint64_t transportId, int maxWellKnownEndpoints, IPAllowList const* allowList)
: warnAlwaysForLargePacket(true), endpoints(maxWellKnownEndpoints), endpointNotFoundReceiver(endpoints),
pingReceiver(endpoints), numIncompatibleConnections(0), lastIncompatibleMessage(0), transportId(transportId) {
pingReceiver(endpoints), numIncompatibleConnections(0), lastIncompatibleMessage(0), transportId(transportId),
allowList(allowList == nullptr ? IPAllowList() : *allowList) {
degraded = makeReference<AsyncVar<bool>>(false);
pingLogger = pingLatencyLogger(this);
}
@ -880,7 +932,8 @@ void Peer::onIncomingConnection(Reference<Peer> self, Reference<IConnection> con
.suppressFor(1.0)
.detail("FromAddr", conn->getPeerAddress())
.detail("CanonicalAddr", destination)
.detail("IsPublic", destination.isPublic());
.detail("IsPublic", destination.isPublic())
.detail("Trusted", self->transport->allowList(conn->getPeerAddress().ip));
connect.cancel();
prependConnectPacket();
@ -927,7 +980,10 @@ ACTOR static void deliver(TransportData* self,
Endpoint destination,
TaskPriority priority,
ArenaReader reader,
bool inReadSocket,
NetworkAddress peerAddress,
Reference<AuthorizedTenants> authorizedTenants,
std::shared_ptr<ContextVariableMap> cvm,
InReadSocket inReadSocket,
Future<Void> disconnect) {
// We want to run the task at the right priority. If the priority is higher than the current priority (which is
// ReadSocket) we can just upgrade. Otherwise we'll context switch so that we don't block other tasks that might run
@ -941,7 +997,7 @@ ACTOR static void deliver(TransportData* self,
}
auto receiver = self->endpoints.get(destination.token);
if (receiver) {
if (receiver && (authorizedTenants->trusted || receiver->isPublic())) {
if (!checkCompatible(receiver->peerCompatibilityPolicy(), reader.protocolVersion())) {
return;
}
@ -951,6 +1007,7 @@ ACTOR static void deliver(TransportData* self,
StringRef data = reader.arenaReadAll();
ASSERT(data.size() > 8);
ArenaObjectReader objReader(reader.arena(), reader.arenaReadAll(), AssumeVersion(reader.protocolVersion()));
objReader.setContextVariableMap(cvm);
receiver->receive(objReader);
g_currentDeliveryPeerAddress = { NetworkAddress() };
g_currentDeliveryPeerDisconnect = Future<Void>();
@ -968,18 +1025,31 @@ ACTOR static void deliver(TransportData* self,
}
} else if (destination.token.first() & TOKEN_STREAM_FLAG) {
// We don't have the (stream) endpoint 'token', notify the remote machine
if (destination.token.first() != -1) {
if (self->isLocalAddress(destination.getPrimaryAddress())) {
sendLocal(self,
SerializeSource<UID>(destination.token),
Endpoint::wellKnown(destination.addresses, WLTOKEN_ENDPOINT_NOT_FOUND));
} else {
Reference<Peer> peer = self->getOrOpenPeer(destination.getPrimaryAddress());
sendPacket(self,
peer,
SerializeSource<UID>(destination.token),
Endpoint::wellKnown(destination.addresses, WLTOKEN_ENDPOINT_NOT_FOUND),
false);
if (receiver) {
TraceEvent(SevWarnAlways, "AttemptedRPCToPrivatePrevented")
.detail("From", peerAddress)
.detail("Token", destination.token);
ASSERT(!self->isLocalAddress(destination.getPrimaryAddress()));
Reference<Peer> peer = self->getOrOpenPeer(destination.getPrimaryAddress());
sendPacket(self,
peer,
SerializeSource<UID>(destination.token),
Endpoint::wellKnown(destination.addresses, WLTOKEN_UNAUTHORIZED_ENDPOINT),
false);
} else {
if (destination.token.first() != -1) {
if (self->isLocalAddress(destination.getPrimaryAddress())) {
sendLocal(self,
SerializeSource<UID>(destination.token),
Endpoint::wellKnown(destination.addresses, WLTOKEN_ENDPOINT_NOT_FOUND));
} else {
Reference<Peer> peer = self->getOrOpenPeer(destination.getPrimaryAddress());
sendPacket(self,
peer,
SerializeSource<UID>(destination.token),
Endpoint::wellKnown(destination.addresses, WLTOKEN_ENDPOINT_NOT_FOUND),
false);
}
}
}
}
@ -990,9 +1060,11 @@ static void scanPackets(TransportData* transport,
const uint8_t* e,
Arena& arena,
NetworkAddress const& peerAddress,
Reference<AuthorizedTenants> const& authorizedTenants,
std::shared_ptr<ContextVariableMap> cvm,
ProtocolVersion peerProtocolVersion,
Future<Void> disconnect,
bool isStableConnection) {
IsStableConnection isStableConnection) {
// Find each complete packet in the given byte range and queue a ready task to deliver it.
// Remove the complete packets from the range by increasing unprocessed_begin.
// There won't be more than 64K of data plus one packet, so this shouldn't take a long time.
@ -1106,7 +1178,15 @@ static void scanPackets(TransportData* transport,
// we have many messages to UnknownEndpoint we want to optimize earlier. As deliver is an actor it
// will allocate some state on the heap and this prevents it from doing that.
if (priority != TaskPriority::UnknownEndpoint || (token.first() & TOKEN_STREAM_FLAG) != 0) {
deliver(transport, Endpoint({ peerAddress }, token), priority, std::move(reader), true, disconnect);
deliver(transport,
Endpoint({ peerAddress }, token),
priority,
std::move(reader),
peerAddress,
authorizedTenants,
cvm,
InReadSocket::True,
disconnect);
}
unprocessed_begin = p = p + packetLen;
@ -1152,8 +1232,14 @@ ACTOR static Future<Void> connectionReader(TransportData* transport,
state bool incompatibleProtocolVersionNewer = false;
state NetworkAddress peerAddress;
state ProtocolVersion peerProtocolVersion;
state Reference<AuthorizedTenants> authorizedTenants = makeReference<AuthorizedTenants>();
state std::shared_ptr<ContextVariableMap> cvm = std::make_shared<ContextVariableMap>();
peerAddress = conn->getPeerAddress();
authorizedTenants->trusted = transport->allowList(conn->getPeerAddress().ip);
(*cvm)["AuthorizedTenants"] = &authorizedTenants;
(*cvm)["PeerAddress"] = &peerAddress;
authorizedTenants->trusted = transport->allowList(peerAddress.ip);
if (!peer) {
ASSERT(!peerAddress.isPublic());
}
@ -1306,9 +1392,11 @@ ACTOR static Future<Void> connectionReader(TransportData* transport,
unprocessed_end,
arena,
peerAddress,
authorizedTenants,
cvm,
peerProtocolVersion,
peer->disconnect.getFuture(),
g_network->isSimulated() && conn->isStableConnection());
IsStableConnection(g_network->isSimulated() && conn->isStableConnection()));
} else {
unprocessed_begin = unprocessed_end;
peer->resetPing.trigger();
@ -1452,8 +1540,8 @@ ACTOR static Future<Void> multiVersionCleanupWorker(TransportData* self) {
}
}
FlowTransport::FlowTransport(uint64_t transportId, int maxWellKnownEndpoints)
: self(new TransportData(transportId, maxWellKnownEndpoints)) {
FlowTransport::FlowTransport(uint64_t transportId, int maxWellKnownEndpoints, IPAllowList const* allowList)
: self(new TransportData(transportId, maxWellKnownEndpoints, allowList)) {
self->multiVersionCleanup = multiVersionCleanupWorker(self);
}
@ -1473,6 +1561,10 @@ NetworkAddress FlowTransport::getLocalAddress() const {
return self->localAddresses.address;
}
void FlowTransport::setLocalAddress(NetworkAddress const& address) {
self->localAddresses.address = address;
}
const std::unordered_map<NetworkAddress, Reference<Peer>>& FlowTransport::getAllPeers() const {
return self->peers;
}
@ -1587,11 +1679,16 @@ static void sendLocal(TransportData* self, ISerializeSource const& what, const E
ASSERT(copy.size() > 0);
TaskPriority priority = self->endpoints.getPriority(destination.token);
if (priority != TaskPriority::UnknownEndpoint || (destination.token.first() & TOKEN_STREAM_FLAG) != 0) {
Reference<AuthorizedTenants> authorizedTenants = makeReference<AuthorizedTenants>();
authorizedTenants->trusted = true;
deliver(self,
destination,
priority,
ArenaReader(copy.arena(), copy, AssumeVersion(currentProtocolVersion)),
false,
NetworkAddress(),
authorizedTenants,
self->localCVM,
InReadSocket::False,
Never());
}
}
@ -1792,9 +1889,12 @@ bool FlowTransport::incompatibleOutgoingConnectionsPresent() {
return self->numIncompatibleConnections > 0;
}
void FlowTransport::createInstance(bool isClient, uint64_t transportId, int maxWellKnownEndpoints) {
void FlowTransport::createInstance(bool isClient,
uint64_t transportId,
int maxWellKnownEndpoints,
IPAllowList const* allowList) {
g_network->setGlobal(INetwork::enFlowTransport,
(flowGlobalType) new FlowTransport(transportId, maxWellKnownEndpoints));
(flowGlobalType) new FlowTransport(transportId, maxWellKnownEndpoints, allowList));
g_network->setGlobal(INetwork::enNetworkAddressFunc, (flowGlobalType)&FlowTransport::getGlobalLocalAddress);
g_network->setGlobal(INetwork::enNetworkAddressesFunc, (flowGlobalType)&FlowTransport::getGlobalLocalAddresses);
g_network->setGlobal(INetwork::enFailureMonitor, (flowGlobalType) new SimpleFailureMonitor());

View File

@ -31,7 +31,13 @@
#include "flow/Net2Packet.h"
#include "fdbrpc/ContinuousSample.h"
enum { WLTOKEN_ENDPOINT_NOT_FOUND = 0, WLTOKEN_PING_PACKET, WLTOKEN_FIRST_AVAILABLE };
enum {
WLTOKEN_ENDPOINT_NOT_FOUND = 0,
WLTOKEN_PING_PACKET,
WLTOKEN_AUTH_TENANT,
WLTOKEN_UNAUTHORIZED_ENDPOINT,
WLTOKEN_FIRST_AVAILABLE
};
#pragma pack(push, 4)
class Endpoint {
@ -129,6 +135,7 @@ class NetworkMessageReceiver {
public:
virtual void receive(ArenaObjectReader&) = 0;
virtual bool isStream() const { return false; }
virtual bool isPublic() const = 0;
virtual PeerCompatibilityPolicy peerCompatibilityPolicy() const {
return { RequirePeer::Exactly, g_network->protocolVersion() };
}
@ -182,14 +189,19 @@ struct Peer : public ReferenceCounted<Peer> {
void onIncomingConnection(Reference<Peer> self, Reference<IConnection> conn, Future<Void> reader);
};
class IPAllowList;
class FlowTransport {
public:
FlowTransport(uint64_t transportId, int maxWellKnownEndpoints);
FlowTransport(uint64_t transportId, int maxWellKnownEndpoints, IPAllowList const* allowList);
~FlowTransport();
// Creates a new FlowTransport and makes FlowTransport::transport() return it. This uses g_network->global()
// variables, so it will be private to a simulation.
static void createInstance(bool isClient, uint64_t transportId, int maxWellKnownEndpoints);
static void createInstance(bool isClient,
uint64_t transportId,
int maxWellKnownEndpoints,
IPAllowList const* allowList = nullptr);
static bool isClient() { return g_network->global(INetwork::enClientFailureMonitor) != nullptr; }
@ -203,6 +215,9 @@ public:
// Returns first local NetworkAddress.
NetworkAddress getLocalAddress() const;
// Returns first local NetworkAddress.
void setLocalAddress(NetworkAddress const&);
// Returns all local NetworkAddress.
NetworkAddressList getLocalAddresses() const;

386
fdbrpc/IPAllowList.cpp Normal file
View File

@ -0,0 +1,386 @@
/*
* IPAllowList.cpp
*
* This source file is part of the FoundationDB open source project
*
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "flow/UnitTest.h"
#include "flow/Error.h"
#include "fdbrpc/IPAllowList.h"
#include <fmt/printf.h>
#include <fmt/format.h>
#include <bitset>
namespace {
template <std::size_t C>
std::string binRep(std::array<unsigned char, C> const& addr) {
return fmt::format("{:02x}", fmt::join(addr, ":"));
}
template <std::size_t C>
void printIP(std::array<unsigned char, C> const& addr) {
fmt::print(" {}", binRep(addr));
}
template <size_t Sz>
int netmaskWeightImpl(std::array<unsigned char, Sz> const& addr) {
int count = 0;
for (int i = 0; i < addr.size() && addr[i] != 0xff; ++i) {
std::bitset<8> b(addr[i]);
count += 8 - b.count();
}
return count;
}
} // namespace
AuthAllowedSubnet::AuthAllowedSubnet(IPAddress const& baseAddress, IPAddress const& addressMask)
: baseAddress(baseAddress), addressMask(addressMask) {
ASSERT(baseAddress.isV4() == addressMask.isV4());
}
IPAddress AuthAllowedSubnet::netmask() const {
if (addressMask.isV4()) {
uint32_t res = 0xffffffff ^ addressMask.toV4();
return IPAddress(res);
} else {
std::array<unsigned char, 16> res;
res.fill(0xff);
auto mask = addressMask.toV6();
for (int i = 0; i < mask.size(); ++i) {
res[i] ^= mask[i];
}
return IPAddress(res);
}
}
int AuthAllowedSubnet::netmaskWeight() const {
if (addressMask.isV4()) {
boost::asio::ip::address_v4 addr(netmask().toV4());
return netmaskWeightImpl(addr.to_bytes());
} else {
return netmaskWeightImpl(netmask().toV6());
}
}
AuthAllowedSubnet AuthAllowedSubnet::fromString(std::string_view addressString) {
auto pos = addressString.find('/');
if (pos == std::string_view::npos) {
fmt::print("ERROR: {} is not a valid (use Network-Prefix/netmaskWeight syntax)\n");
throw invalid_option();
}
auto address = addressString.substr(0, pos);
auto netmaskWeight = std::stoi(std::string(addressString.substr(pos + 1)));
auto addr = boost::asio::ip::make_address(address);
if (addr.is_v4()) {
auto bM = createBitMask(addr.to_v4().to_bytes(), netmaskWeight);
// we typically would expect a base address has been passed, but to be safe we still
// will make the last bits 0
auto mask = boost::asio::ip::address_v4(bM).to_uint();
auto baseAddress = addr.to_v4().to_uint() & mask;
return AuthAllowedSubnet(IPAddress(baseAddress), IPAddress(mask));
} else {
auto mask = createBitMask(addr.to_v6().to_bytes(), netmaskWeight);
auto baseAddress = addr.to_v6().to_bytes();
for (int i = 0; i < mask.size(); ++i) {
baseAddress[i] &= mask[i];
}
return AuthAllowedSubnet(IPAddress(baseAddress), IPAddress(mask));
}
}
void AuthAllowedSubnet::printIP(std::string_view txt, IPAddress const& address) {
fmt::print("{}:", txt);
if (address.isV4()) {
::printIP(boost::asio::ip::address_v4(address.toV4()).to_bytes());
} else {
::printIP(address.toV6());
}
fmt::print("\n");
}
template <std::size_t sz>
std::array<unsigned char, sz> AuthAllowedSubnet::createBitMask(std::array<unsigned char, sz> const& addr,
int netmaskWeight) {
std::array<unsigned char, sz> res;
res.fill((unsigned char)0xff);
int idx = netmaskWeight / 8;
if (netmaskWeight % 8 > 0) {
// 2^(netmaskWeight % 8) - 1 sets the last (netmaskWeight % 8) number of bits to 1
// everything else will be zero. For example: 2^3 - 1 == 7 == 0b111
unsigned char bitmask = (1 << (8 - (netmaskWeight % 8))) - ((unsigned char)1);
res[idx] ^= bitmask;
++idx;
}
for (; idx < res.size(); ++idx) {
res[idx] = (unsigned char)0;
}
return res;
}
template std::array<unsigned char, 4> AuthAllowedSubnet::createBitMask<4>(const std::array<unsigned char, 4>& addr,
int netmaskWeight);
template std::array<unsigned char, 16> AuthAllowedSubnet::createBitMask<16>(const std::array<unsigned char, 16>& addr,
int netmaskWeight);
// helpers for testing
namespace {
using boost::asio::ip::address_v4;
using boost::asio::ip::address_v6;
void traceAddress(TraceEvent& evt, const char* name, IPAddress address) {
evt.detail(name, address);
std::string bin;
if (address.isV4()) {
boost::asio::ip::address_v4 a(address.toV4());
bin = binRep(a.to_bytes());
} else {
bin = binRep(address.toV6());
}
evt.detail(fmt::format("{}Binary", name).c_str(), bin);
}
void subnetAssert(IPAllowList const& allowList, IPAddress addr, bool expectAllowed) {
if (allowList(addr) == expectAllowed) {
return;
}
TraceEvent evt(SevError, expectAllowed ? "ExpectedAddressToBeTrusted" : "ExpectedAddressToBeUntrusted");
traceAddress(evt, "Address", addr);
auto const& subnets = allowList.subnets();
for (int i = 0; i < subnets.size(); ++i) {
traceAddress(evt, fmt::format("SubnetBase{}", i).c_str(), subnets[i].baseAddress);
traceAddress(evt, fmt::format("SubnetMask{}", i).c_str(), subnets[i].addressMask);
}
}
IPAddress parseAddr(std::string const& str) {
auto res = IPAddress::parse(str);
ASSERT(res.present());
return res.get();
}
struct SubNetTest {
AuthAllowedSubnet subnet;
SubNetTest(AuthAllowedSubnet&& subnet) : subnet(std::move(subnet)) {}
SubNetTest(AuthAllowedSubnet const& subnet) : subnet(subnet) {}
template <bool V4>
static SubNetTest randomSubNetImpl() {
constexpr int width = V4 ? 4 : 16;
std::array<unsigned char, width> binAddr;
unsigned char rnd[4];
for (int i = 0; i < binAddr.size(); ++i) {
if (i % 4 == 0) {
auto tmp = deterministicRandom()->randomUInt32();
::memcpy(rnd, &tmp, 4);
}
binAddr[i] = rnd[i % 4];
}
auto netmaskWeight = deterministicRandom()->randomInt(1, width);
std::string address;
if constexpr (V4) {
address_v4 a(binAddr);
address = a.to_string();
} else {
address_v6 a(binAddr);
address = a.to_string();
}
return SubNetTest(AuthAllowedSubnet::fromString(fmt::format("{}/{}", address, netmaskWeight)));
}
static SubNetTest randomSubNet() {
if (deterministicRandom()->coinflip()) {
return randomSubNetImpl<true>();
} else {
return randomSubNetImpl<false>();
}
}
template <bool V4>
static IPAddress intArrayToAddress(uint32_t* arr) {
if constexpr (V4) {
return IPAddress(arr[0]);
} else {
std::array<unsigned char, 16> res;
memcpy(res.data(), arr, 4);
return IPAddress(res);
}
}
template <class I>
I transformIntToSubnet(I val, I subnetMask, I baseAddress) {
return (val & subnetMask) ^ baseAddress;
}
template <bool V4>
static IPAddress randomAddress() {
constexpr int width = V4 ? 4 : 16;
uint32_t rnd[width / 4];
for (int i = 0; i < width / 4; ++i) {
rnd[i] = deterministicRandom()->randomUInt32();
}
return intArrayToAddress<V4>(rnd);
}
template <bool V4>
IPAddress randomAddress(bool inSubnet) {
ASSERT(V4 == subnet.baseAddress.isV4() || !inSubnet);
for (;;) {
auto res = randomAddress<V4>();
if (V4 != subnet.baseAddress.isV4()) {
return res;
}
if (!inSubnet) {
if (!subnet(res)) {
return res;
} else {
continue;
}
}
// first we make sure the address is in the subnet
if constexpr (V4) {
auto a = res.toV4();
auto base = subnet.baseAddress.toV4();
auto netmask = subnet.netmask().toV4();
auto validAddress = transformIntToSubnet(a, netmask, base);
res = IPAddress(validAddress);
} else {
auto a = res.toV6();
auto base = subnet.baseAddress.toV6();
auto netmask = subnet.netmask().toV6();
for (int i = 0; i < a.size(); ++i) {
a[i] = transformIntToSubnet(a[i], netmask[i], base[i]);
}
res = IPAddress(a);
}
return res;
}
}
IPAddress randomAddress(bool inSubnet) {
if (!inSubnet && deterministicRandom()->random01() < 0.1) {
// return an address of a different type
if (subnet.baseAddress.isV4()) {
return randomAddress<false>(false);
} else {
return randomAddress<true>(false);
}
}
if (subnet.addressMask.isV4()) {
return randomAddress<true>(inSubnet);
} else {
return randomAddress<false>(inSubnet);
}
}
};
} // namespace
TEST_CASE("/fdbrpc/allow_list") {
// test correct weight calculation
// IPv4
for (int i = 0; i < 33; ++i) {
auto str = fmt::format("0.0.0.0/{}", i);
auto subnet = AuthAllowedSubnet::fromString(str);
if (i != subnet.netmaskWeight()) {
fmt::print("Wrong calculated weight {} for {}\n", subnet.netmaskWeight(), str);
fmt::print("\tBase address: {}\n", subnet.baseAddress.toString());
fmt::print("\tAddress Mask: {}\n", subnet.addressMask.toString());
fmt::print("\tNetmask: {}\n", subnet.netmask().toString());
ASSERT_EQ(i, subnet.netmaskWeight());
}
}
// IPv6
for (int i = 0; i < 129; ++i) {
auto subnet = AuthAllowedSubnet::fromString(fmt::format("0::/{}", i));
ASSERT_EQ(i, subnet.netmaskWeight());
}
IPAllowList allowList;
// Simulated v4 addresses
allowList.addTrustedSubnet("1.0.0.0/8");
allowList.addTrustedSubnet("2.0.0.0/4");
::subnetAssert(allowList, parseAddr("1.0.1.1"), true);
::subnetAssert(allowList, parseAddr("1.1.2.2"), true);
::subnetAssert(allowList, parseAddr("2.2.1.1"), true);
::subnetAssert(allowList, parseAddr("128.0.1.1"), false);
allowList = IPAllowList();
allowList.addTrustedSubnet("0.0.0.0/2");
allowList.addTrustedSubnet("abcd::/16");
::subnetAssert(allowList, parseAddr("1.0.1.1"), true);
::subnetAssert(allowList, parseAddr("1.1.2.2"), true);
::subnetAssert(allowList, parseAddr("2.2.1.1"), true);
::subnetAssert(allowList, parseAddr("4.0.1.2"), true);
::subnetAssert(allowList, parseAddr("5.2.1.1"), true);
::subnetAssert(allowList, parseAddr("128.0.1.1"), false);
::subnetAssert(allowList, parseAddr("192.168.3.1"), false);
// Simulated v6 addresses
::subnetAssert(allowList, parseAddr("abcd::1:2:3:4"), true);
::subnetAssert(allowList, parseAddr("abcd::2:3:3:4"), true);
::subnetAssert(allowList, parseAddr("abcd:ab:ab:fdb:2:3:3:4"), true);
::subnetAssert(allowList, parseAddr("2001:fdb1:fdb2:fdb3:fdb4:fdb5:fdb6:12"), false);
::subnetAssert(allowList, parseAddr("2001:fdb1:fdb2:fdb3:fdb4:fdb5:fdb6:1"), false);
::subnetAssert(allowList, parseAddr("2001:fdb1:fdb2:fdb3:fdb4:fdb5:fdb6:fdb"), false);
// Corner Cases
allowList = IPAllowList();
allowList.addTrustedSubnet("0.0.0.0/0");
// Random address tests
SubNetTest subnetTest(allowList.subnets()[0]);
for (int i = 0; i < 10; ++i) {
// All IPv4 addresses are in the allow list
::subnetAssert(allowList, subnetTest.randomAddress<true>(), true);
// No IPv6 addresses are in the allow list
::subnetAssert(allowList, subnetTest.randomAddress<false>(), false);
}
allowList = IPAllowList();
allowList.addTrustedSubnet("::/0");
subnetTest = SubNetTest(allowList.subnets()[0]);
for (int i = 0; i < 10; ++i) {
// All IPv6 addresses are in the allow list
::subnetAssert(allowList, subnetTest.randomAddress<false>(), true);
// No IPv4 addresses are ub the allow list
::subnetAssert(allowList, subnetTest.randomAddress<true>(), false);
}
allowList = IPAllowList();
IPAddress baseAddress = SubNetTest::randomAddress<true>();
allowList.addTrustedSubnet(fmt::format("{}/32", baseAddress.toString()));
for (int i = 0; i < 10; ++i) {
auto rnd = SubNetTest::randomAddress<true>();
::subnetAssert(allowList, rnd, rnd == baseAddress);
rnd = SubNetTest::randomAddress<false>();
::subnetAssert(allowList, rnd, false);
}
allowList = IPAllowList();
baseAddress = SubNetTest::randomAddress<false>();
allowList.addTrustedSubnet(fmt::format("{}/128", baseAddress.toString()));
for (int i = 0; i < 10; ++i) {
auto rnd = SubNetTest::randomAddress<false>();
::subnetAssert(allowList, rnd, rnd == baseAddress);
rnd = SubNetTest::randomAddress<true>();
::subnetAssert(allowList, rnd, false);
}
for (int i = 0; i < 100; ++i) {
SubNetTest subnetTest(SubNetTest::randomSubNet());
allowList = IPAllowList();
allowList.addTrustedSubnet(subnetTest.subnet);
for (int j = 0; j < 10; ++j) {
bool inSubnet = deterministicRandom()->random01() < 0.7;
auto addr = subnetTest.randomAddress(inSubnet);
::subnetAssert(allowList, addr, inSubnet);
}
}
return Void();
}

86
fdbrpc/IPAllowList.h Normal file
View File

@ -0,0 +1,86 @@
/*
* IPAllowList.h
*
* This source file is part of the FoundationDB open source project
*
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifndef FDBRPC_IP_ALLOW_LIST_H
#define FDBRPC_IP_ALLOW_LIST_H
#include "flow/network.h"
#include "flow/Arena.h"
struct AuthAllowedSubnet {
IPAddress baseAddress;
IPAddress addressMask;
AuthAllowedSubnet(IPAddress const& baseAddress, IPAddress const& addressMask);
static AuthAllowedSubnet fromString(std::string_view addressString);
template <std::size_t sz>
static std::array<unsigned char, sz> createBitMask(std::array<unsigned char, sz> const& addr, int netmaskWeight);
bool operator()(IPAddress const& address) const {
if (addressMask.isV4() != address.isV4()) {
return false;
}
if (addressMask.isV4()) {
return (addressMask.toV4() & address.toV4()) == baseAddress.toV4();
} else {
auto res = address.toV6();
auto const& mask = addressMask.toV6();
for (int i = 0; i < res.size(); ++i) {
res[i] &= mask[i];
}
return res == baseAddress.toV6();
}
}
IPAddress netmask() const;
int netmaskWeight() const;
// some useful helper functions if we need to debug ip masks etc
static void printIP(std::string_view txt, IPAddress const& address);
};
class IPAllowList {
std::vector<AuthAllowedSubnet> subnetList;
public:
void addTrustedSubnet(std::string_view str) { subnetList.push_back(AuthAllowedSubnet::fromString(str)); }
void addTrustedSubnet(AuthAllowedSubnet const& subnet) { subnetList.push_back(subnet); }
std::vector<AuthAllowedSubnet> const& subnets() const { return subnetList; }
bool operator()(IPAddress address) const {
if (subnetList.empty()) {
return true;
}
for (auto const& subnet : subnetList) {
if (subnet(address)) {
return true;
}
}
return false;
}
};
#endif // FDBRPC_IP_ALLOW_LIST_H

View File

@ -78,14 +78,14 @@ struct LoadBalancedReply {
Optional<LoadBalancedReply> getLoadBalancedReply(const LoadBalancedReply* reply);
Optional<LoadBalancedReply> getLoadBalancedReply(const void*);
ACTOR template <class Req, class Resp, class Interface, class Multi>
ACTOR template <class Req, class Resp, class Interface, class Multi, bool P>
Future<Void> tssComparison(Req req,
Future<ErrorOr<Resp>> fSource,
Future<ErrorOr<Resp>> fTss,
TSSEndpointData tssData,
uint64_t srcEndpointId,
Reference<MultiInterface<Multi>> ssTeam,
RequestStream<Req> Interface::*channel) {
RequestStream<Req, P> Interface::*channel) {
state double startTime = now();
state Future<Optional<ErrorOr<Resp>>> fTssWithTimeout = timeout(fTss, FLOW_KNOBS->LOAD_BALANCE_TSS_TIMEOUT);
state int finished = 0;
@ -157,7 +157,7 @@ Future<Void> tssComparison(Req req,
state std::vector<Future<ErrorOr<Resp>>> restOfTeamFutures;
restOfTeamFutures.reserve(ssTeam->size() - 1);
for (int i = 0; i < ssTeam->size(); i++) {
RequestStream<Req> const* si = &ssTeam->get(i, channel);
RequestStream<Req, P> const* si = &ssTeam->get(i, channel);
if (si->getEndpoint().token.first() !=
srcEndpointId) { // don't re-request to SS we already have a response from
resetReply(req);
@ -242,7 +242,7 @@ FDB_DECLARE_BOOLEAN_PARAM(AtMostOnce);
FDB_DECLARE_BOOLEAN_PARAM(TriedAllOptions);
// Stores state for a request made by the load balancer
template <class Request, class Interface, class Multi>
template <class Request, class Interface, class Multi, bool P>
struct RequestData : NonCopyable {
typedef ErrorOr<REPLY_TYPE(Request)> Reply;
@ -257,12 +257,12 @@ struct RequestData : NonCopyable {
// This is true once setupRequest is called, even though at that point the response is Never().
bool isValid() { return response.isValid(); }
static void maybeDuplicateTSSRequest(RequestStream<Request> const* stream,
static void maybeDuplicateTSSRequest(RequestStream<Request, P> const* stream,
Request& request,
QueueModel* model,
Future<Reply> ssResponse,
Reference<MultiInterface<Multi>> alternatives,
RequestStream<Request> Interface::*channel) {
RequestStream<Request, P> Interface::*channel) {
if (model) {
// Send parallel request to TSS pair, if it exists
Optional<TSSEndpointData> tssData = model->getTssData(stream->getEndpoint().token.first());
@ -271,7 +271,7 @@ struct RequestData : NonCopyable {
TEST(true); // duplicating request to TSS
resetReply(request);
// FIXME: optimize to avoid creating new netNotifiedQueue for each message
RequestStream<Request> tssRequestStream(tssData.get().endpoint);
RequestStream<Request, P> tssRequestStream(tssData.get().endpoint);
Future<ErrorOr<REPLY_TYPE(Request)>> fTssResult = tssRequestStream.tryGetReply(request);
model->addActor.send(tssComparison(request,
ssResponse,
@ -288,11 +288,11 @@ struct RequestData : NonCopyable {
void startRequest(
double backoff,
TriedAllOptions triedAllOptions,
RequestStream<Request> const* stream,
RequestStream<Request, P> const* stream,
Request& request,
QueueModel* model,
Reference<MultiInterface<Multi>> alternatives, // alternatives and channel passed through for TSS check
RequestStream<Request> Interface::*channel) {
RequestStream<Request, P> Interface::*channel) {
modelHolder = Reference<ModelHolder>();
requestStarted = false;
@ -438,18 +438,18 @@ struct RequestData : NonCopyable {
// list of servers.
// When model is set, load balance among alternatives in the same DC aims to balance request queue length on these
// interfaces. If too many interfaces in the same DC are bad, try remote interfaces.
ACTOR template <class Interface, class Request, class Multi>
ACTOR template <class Interface, class Request, class Multi, bool P>
Future<REPLY_TYPE(Request)> loadBalance(
Reference<MultiInterface<Multi>> alternatives,
RequestStream<Request> Interface::*channel,
RequestStream<Request, P> Interface::*channel,
Request request = Request(),
TaskPriority taskID = TaskPriority::DefaultPromiseEndpoint,
AtMostOnce atMostOnce =
AtMostOnce::False, // if true, throws request_maybe_delivered() instead of retrying automatically
QueueModel* model = nullptr) {
state RequestData<Request, Interface, Multi> firstRequestData;
state RequestData<Request, Interface, Multi> secondRequestData;
state RequestData<Request, Interface, Multi, P> firstRequestData;
state RequestData<Request, Interface, Multi, P> secondRequestData;
state Optional<uint64_t> firstRequestEndpoint;
state Future<Void> secondDelay = Never();
@ -488,7 +488,7 @@ Future<REPLY_TYPE(Request)> loadBalance(
break;
}
RequestStream<Request> const* thisStream = &alternatives->get(i, channel);
RequestStream<Request, P> const* thisStream = &alternatives->get(i, channel);
if (!IFailureMonitor::failureMonitor().getState(thisStream->getEndpoint()).failed) {
auto const& qd = model->getMeasurement(thisStream->getEndpoint().token.first());
if (now() > qd.failedUntil) {
@ -527,7 +527,7 @@ Future<REPLY_TYPE(Request)> loadBalance(
// go through all the remote servers again, since we may have
// skipped it.
for (int i = alternatives->countBest(); i < alternatives->size(); i++) {
RequestStream<Request> const* thisStream = &alternatives->get(i, channel);
RequestStream<Request, P> const* thisStream = &alternatives->get(i, channel);
if (!IFailureMonitor::failureMonitor().getState(thisStream->getEndpoint()).failed) {
auto const& qd = model->getMeasurement(thisStream->getEndpoint().token.first());
if (now() > qd.failedUntil) {
@ -574,7 +574,7 @@ Future<REPLY_TYPE(Request)> loadBalance(
if (ev.isEnabled()) {
ev.log();
for (int alternativeNum = 0; alternativeNum < alternatives->size(); alternativeNum++) {
RequestStream<Request> const* thisStream = &alternatives->get(alternativeNum, channel);
RequestStream<Request, P> const* thisStream = &alternatives->get(alternativeNum, channel);
TraceEvent(SevWarn, "LoadBalanceTooLongEndpoint")
.detail("Addr", thisStream->getEndpoint().getPrimaryAddress())
.detail("Token", thisStream->getEndpoint().token)
@ -586,7 +586,7 @@ Future<REPLY_TYPE(Request)> loadBalance(
// Find an alternative, if any, that is not failed, starting with
// nextAlt. This logic matters only if model == nullptr. Otherwise, the
// bestAlt and nextAlt have been decided.
state RequestStream<Request> const* stream = nullptr;
state RequestStream<Request, P> const* stream = nullptr;
for (int alternativeNum = 0; alternativeNum < alternatives->size(); alternativeNum++) {
int useAlt = nextAlt;
if (nextAlt == startAlt)
@ -724,9 +724,9 @@ Optional<BasicLoadBalancedReply> getBasicLoadBalancedReply(const BasicLoadBalanc
Optional<BasicLoadBalancedReply> getBasicLoadBalancedReply(const void*);
// A simpler version of LoadBalance that does not send second requests where the list of servers are always fresh
ACTOR template <class Interface, class Request, class Multi>
ACTOR template <class Interface, class Request, class Multi, bool P>
Future<REPLY_TYPE(Request)> basicLoadBalance(Reference<ModelInterface<Multi>> alternatives,
RequestStream<Request> Interface::*channel,
RequestStream<Request, P> Interface::*channel,
Request request = Request(),
TaskPriority taskID = TaskPriority::DefaultPromiseEndpoint,
AtMostOnce atMostOnce = AtMostOnce::False) {
@ -749,7 +749,7 @@ Future<REPLY_TYPE(Request)> basicLoadBalance(Reference<ModelInterface<Multi>> al
state int useAlt;
loop {
// Find an alternative, if any, that is not failed, starting with nextAlt
state RequestStream<Request> const* stream = nullptr;
state RequestStream<Request, P> const* stream = nullptr;
for (int alternativeNum = 0; alternativeNum < alternatives->size(); alternativeNum++) {
useAlt = nextAlt;
if (nextAlt == startAlt)

View File

@ -43,7 +43,7 @@ struct PerfMetric {
std::string format_code() const { return m_format_code; }
bool averaged() const { return m_averaged; }
PerfMetric withPrefix(const std::string& pre) {
PerfMetric withPrefix(const std::string& pre) const {
return PerfMetric(pre + name(), value(), Averaged{ averaged() }, format_code());
}

73
fdbrpc/TenantInfo.h Normal file
View File

@ -0,0 +1,73 @@
/*
* TenantInfo.h
*
* This source file is part of the FoundationDB open source project
*
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifndef FDBRPC_TENANTINFO_H_
#define FDBRPC_TENANTINFO_H_
#include "flow/Arena.h"
#include "fdbrpc/fdbrpc.h"
#include <set>
struct TenantInfoRef {
TenantInfoRef() {}
TenantInfoRef(Arena& p, StringRef toCopy) : tenantName(StringRef(p, toCopy)) {}
TenantInfoRef(Arena& p, TenantInfoRef toCopy)
: tenantName(toCopy.tenantName.present() ? Optional<StringRef>(StringRef(p, toCopy.tenantName.get()))
: Optional<StringRef>()) {}
// Empty tenant name means that the peer is trusted
Optional<StringRef> tenantName;
bool operator<(TenantInfoRef const& other) const {
if (!other.tenantName.present()) {
return false;
}
if (!tenantName.present()) {
return true;
}
return tenantName.get() < other.tenantName.get();
}
template <class Ar>
void serialize(Ar& ar) {
serializer(ar, tenantName);
}
};
struct AuthorizedTenants : ReferenceCounted<AuthorizedTenants> {
Arena arena;
std::set<TenantInfoRef> authorizedTenants;
bool trusted = false;
};
// TODO: receive and validate token instead
struct AuthorizationRequest {
constexpr static FileIdentifier file_identifier = 11499331;
Arena arena;
VectorRef<TenantInfoRef> tenants;
ReplyPromise<Void> reply;
template <class Ar>
void serialize(Ar& ar) {
serializer(ar, tenants, reply, arena);
}
};
#endif // FDBRPC_TENANTINFO_H_

View File

@ -110,6 +110,8 @@ struct NetSAV final : SAV<T>, FlowReceiver, FastAllocated<NetSAV<T>> {
SAV<T>::sendAndDelPromiseRef(message.get().asUnderlyingType());
}
}
bool isPublic() const override { return true; }
};
template <class T>
@ -290,6 +292,8 @@ struct AcknowledgementReceiver final : FlowReceiver, FastAllocated<Acknowledgeme
AcknowledgementReceiver() : ready(nullptr) {}
AcknowledgementReceiver(const Endpoint& remoteEndpoint) : FlowReceiver(remoteEndpoint, false), ready(nullptr) {}
bool isPublic() const override { return true; }
void receive(ArenaObjectReader& reader) override {
ErrorOr<AcknowledgementReply> message;
reader.deserialize(message);
@ -337,6 +341,8 @@ struct NetNotifiedQueueWithAcknowledgements final : NotifiedQueue<T>,
acknowledgements.failures = tagError<Void>(FlowTransport::transport().loadedDisconnect(), operation_obsolete());
}
bool isPublic() const override { return true; }
void destroy() override { delete this; }
void receive(ArenaObjectReader& reader) override {
this->addPromiseRef();
@ -642,10 +648,10 @@ struct serializable_traits<ReplyPromiseStream<T>> : std::true_type {
}
};
template <class T>
struct NetNotifiedQueue final : NotifiedQueue<T>, FlowReceiver, FastAllocated<NetNotifiedQueue<T>> {
using FastAllocated<NetNotifiedQueue<T>>::operator new;
using FastAllocated<NetNotifiedQueue<T>>::operator delete;
template <class T, bool IsPublic>
struct NetNotifiedQueue final : NotifiedQueue<T>, FlowReceiver, FastAllocated<NetNotifiedQueue<T, IsPublic>> {
using FastAllocated<NetNotifiedQueue<T, IsPublic>>::operator new;
using FastAllocated<NetNotifiedQueue<T, IsPublic>>::operator delete;
NetNotifiedQueue(int futures, int promises) : NotifiedQueue<T>(futures, promises) {}
NetNotifiedQueue(int futures, int promises, const Endpoint& remoteEndpoint)
@ -660,9 +666,10 @@ struct NetNotifiedQueue final : NotifiedQueue<T>, FlowReceiver, FastAllocated<Ne
this->delPromiseRef();
}
bool isStream() const override { return true; }
bool isPublic() const override { return IsPublic; }
};
template <class T>
template <class T, bool IsPublic = false>
class RequestStream {
public:
// stream.send( request )
@ -726,6 +733,9 @@ public:
Future<Void> disc =
makeDependent<T>(IFailureMonitor::failureMonitor()).onDisconnectOrFailure(getEndpoint(taskID));
if (disc.isReady()) {
if (IFailureMonitor::failureMonitor().knownUnauthorized(getEndpoint(taskID))) {
return ErrorOr<REPLY_TYPE(X)>(unauthorized_attempt());
}
return ErrorOr<REPLY_TYPE(X)>(request_maybe_delivered());
}
Reference<Peer> peer =
@ -744,6 +754,9 @@ public:
Future<Void> disc =
makeDependent<T>(IFailureMonitor::failureMonitor()).onDisconnectOrFailure(getEndpoint());
if (disc.isReady()) {
if (IFailureMonitor::failureMonitor().knownUnauthorized(getEndpoint())) {
return ErrorOr<REPLY_TYPE(X)>(unauthorized_attempt());
}
return ErrorOr<REPLY_TYPE(X)>(request_maybe_delivered());
}
Reference<Peer> peer =
@ -821,13 +834,13 @@ public:
return getReplyUnlessFailedFor(ReplyPromise<X>(), sustainedFailureDuration, sustainedFailureSlope);
}
explicit RequestStream(const Endpoint& endpoint) : queue(new NetNotifiedQueue<T>(0, 1, endpoint)) {}
explicit RequestStream(const Endpoint& endpoint) : queue(new NetNotifiedQueue<T, IsPublic>(0, 1, endpoint)) {}
FutureStream<T> getFuture() const {
queue->addFutureRef();
return FutureStream<T>(queue);
}
RequestStream() : queue(new NetNotifiedQueue<T>(0, 1)) {}
RequestStream() : queue(new NetNotifiedQueue<T, IsPublic>(0, 1)) {}
explicit RequestStream(PeerCompatibilityPolicy policy) : RequestStream() {
queue->setPeerCompatibilityPolicy(policy);
}
@ -861,8 +874,8 @@ public:
queue->makeWellKnownEndpoint(Endpoint::Token(-1, wlTokenID), taskID);
}
bool operator==(const RequestStream<T>& rhs) const { return queue == rhs.queue; }
bool operator!=(const RequestStream<T>& rhs) const { return !(*this == rhs); }
bool operator==(const RequestStream<T, IsPublic>& rhs) const { return queue == rhs.queue; }
bool operator!=(const RequestStream<T, IsPublic>& rhs) const { return !(*this == rhs); }
bool isEmpty() const { return !queue->isReady(); }
uint32_t size() const { return queue->size(); }
@ -871,32 +884,37 @@ public:
}
private:
NetNotifiedQueue<T>* queue;
NetNotifiedQueue<T, IsPublic>* queue;
};
template <class Ar, class T>
void save(Ar& ar, const RequestStream<T>& value) {
template <class T>
using PrivateRequestStream = RequestStream<T, false>;
template <class T>
using PublicRequestStream = RequestStream<T, true>;
template <class Ar, class T, bool P>
void save(Ar& ar, const RequestStream<T, P>& value) {
auto const& ep = value.getEndpoint();
ar << ep;
UNSTOPPABLE_ASSERT(
ep.getPrimaryAddress().isValid()); // No serializing PromiseStreams on a client with no public address
}
template <class Ar, class T>
void load(Ar& ar, RequestStream<T>& value) {
template <class Ar, class T, bool P>
void load(Ar& ar, RequestStream<T, P>& value) {
Endpoint endpoint;
ar >> endpoint;
value = RequestStream<T>(endpoint);
value = RequestStream<T, P>(endpoint);
}
template <class T>
struct serializable_traits<RequestStream<T>> : std::true_type {
template <class T, bool P>
struct serializable_traits<RequestStream<T, P>> : std::true_type {
template <class Archiver>
static void serialize(Archiver& ar, RequestStream<T>& stream) {
static void serialize(Archiver& ar, RequestStream<T, P>& stream) {
if constexpr (Archiver::isDeserializing) {
Endpoint endpoint;
serializer(ar, endpoint);
stream = RequestStream<T>(endpoint);
stream = RequestStream<T, P>(endpoint);
} else {
const auto& ep = stream.getEndpoint();
serializer(ar, ep);

View File

@ -32,8 +32,8 @@
#include "flow/Hostname.h"
#include "flow/actorcompiler.h" // This must be the last #include.
ACTOR template <class Req>
Future<REPLY_TYPE(Req)> retryBrokenPromise(RequestStream<Req> to, Req request) {
ACTOR template <class Req, bool P>
Future<REPLY_TYPE(Req)> retryBrokenPromise(RequestStream<Req, P> to, Req request) {
// Like to.getReply(request), except that a broken_promise exception results in retrying request immediately.
// Suitable for use with well known endpoints, which are likely to return to existence after the other process
// restarts. Not normally useful for ordinary endpoints, which conventionally are permanently destroyed after
@ -52,8 +52,8 @@ Future<REPLY_TYPE(Req)> retryBrokenPromise(RequestStream<Req> to, Req request) {
}
}
ACTOR template <class Req>
Future<REPLY_TYPE(Req)> retryBrokenPromise(RequestStream<Req> to, Req request, TaskPriority taskID) {
ACTOR template <class Req, bool P>
Future<REPLY_TYPE(Req)> retryBrokenPromise(RequestStream<Req, P> to, Req request, TaskPriority taskID) {
// Like to.getReply(request), except that a broken_promise exception results in retrying request immediately.
// Suitable for use with well known endpoints, which are likely to return to existence after the other process
// restarts. Not normally useful for ordinary endpoints, which conventionally are permanently destroyed after
@ -350,7 +350,11 @@ Future<ErrorOr<X>> waitValueOrSignal(Future<X> value,
try {
choose {
when(X x = wait(value)) { return x; }
when(wait(signal)) { return ErrorOr<X>(request_maybe_delivered()); }
when(wait(signal)) {
return ErrorOr<X>(IFailureMonitor::failureMonitor().knownUnauthorized(endpoint)
? unauthorized_attempt()
: request_maybe_delivered());
}
}
} catch (Error& e) {
if (signal.isError()) {
@ -373,12 +377,31 @@ Future<ErrorOr<X>> waitValueOrSignal(Future<X> value,
ACTOR template <class T>
Future<T> sendCanceler(ReplyPromise<T> reply, ReliablePacket* send, Endpoint endpoint) {
state bool didCancelReliable = false;
try {
T t = wait(reply.getFuture());
FlowTransport::transport().cancelReliable(send);
return t;
loop {
if (IFailureMonitor::failureMonitor().permanentlyFailed(endpoint)) {
FlowTransport::transport().cancelReliable(send);
didCancelReliable = true;
if (IFailureMonitor::failureMonitor().knownUnauthorized(endpoint)) {
throw unauthorized_attempt();
} else {
wait(Never());
}
}
choose {
when(T t = wait(reply.getFuture())) {
FlowTransport::transport().cancelReliable(send);
didCancelReliable = true;
return t;
}
when(wait(IFailureMonitor::failureMonitor().onStateChanged(endpoint))) {}
}
}
} catch (Error& e) {
FlowTransport::transport().cancelReliable(send);
if (!didCancelReliable) {
FlowTransport::transport().cancelReliable(send);
}
if (e.code() == error_code_broken_promise) {
IFailureMonitor::failureMonitor().endpointNotFound(endpoint);
}

View File

@ -445,6 +445,7 @@ private:
TraceEvent(SevError, "LeakedConnection", self->dbgid)
.error(connection_leaked())
.detail("MyAddr", self->process->address)
.detail("IsPublic", self->process->address.isPublic())
.detail("PeerAddr", self->peerEndpoint)
.detail("PeerId", self->peerId)
.detail("Opened", self->opened);

View File

@ -87,6 +87,7 @@ public:
UID uid;
ProtocolVersion protocolVersion;
bool excludeFromRestarts = false;
std::vector<ProcessInfo*> childs;

View File

@ -118,7 +118,7 @@ private:
KeyRangeMap<ServerCacheInfo>* keyInfo = nullptr;
KeyRangeMap<bool>* cacheInfo = nullptr;
std::map<Key, ApplyMutationsData>* uid_applyMutationsData = nullptr;
RequestStream<CommitTransactionRequest> commit = RequestStream<CommitTransactionRequest>();
PublicRequestStream<CommitTransactionRequest> commit = PublicRequestStream<CommitTransactionRequest>();
Database cx = Database();
NotifiedVersion* committedVersion = nullptr;
std::map<UID, Reference<StorageInfo>>* storageCache = nullptr;

View File

@ -198,6 +198,7 @@ set(FDBSERVER_SRCS
workloads/ChangeFeeds.actor.cpp
workloads/ClearSingleRange.actor.cpp
workloads/ClientTransactionProfileCorrectness.actor.cpp
workloads/ClientWorkload.actor.cpp
workloads/ClogSingleConnection.actor.cpp
workloads/CommitBugCheck.actor.cpp
workloads/ConfigIncrement.actor.cpp
@ -251,6 +252,7 @@ set(FDBSERVER_SRCS
workloads/PhysicalShardMove.actor.cpp
workloads/Ping.actor.cpp
workloads/PopulateTPCC.actor.cpp
workloads/PrivateEndpoints.actor.cpp
workloads/ProtocolVersion.actor.cpp
workloads/PubSubMultiples.actor.cpp
workloads/QueuePush.actor.cpp

View File

@ -2981,7 +2981,8 @@ TEST_CASE("/fdbserver/clustercontroller/shouldTriggerRecoveryDueToDegradedServer
testDbInfo.logSystemConfig.tLogs.push_back(remoteTLogSet);
GrvProxyInterface proxyInterf;
proxyInterf.getConsistentReadVersion = RequestStream<struct GetReadVersionRequest>(Endpoint({ proxy }, testUID));
proxyInterf.getConsistentReadVersion =
PublicRequestStream<struct GetReadVersionRequest>(Endpoint({ proxy }, testUID));
testDbInfo.client.grvProxies.push_back(proxyInterf);
ResolverInterface resolverInterf;
@ -3090,11 +3091,12 @@ TEST_CASE("/fdbserver/clustercontroller/shouldTriggerFailoverDueToDegradedServer
testDbInfo.logSystemConfig.tLogs.push_back(remoteTLogSet);
GrvProxyInterface grvProxyInterf;
grvProxyInterf.getConsistentReadVersion = RequestStream<struct GetReadVersionRequest>(Endpoint({ proxy }, testUID));
grvProxyInterf.getConsistentReadVersion =
PublicRequestStream<struct GetReadVersionRequest>(Endpoint({ proxy }, testUID));
testDbInfo.client.grvProxies.push_back(grvProxyInterf);
CommitProxyInterface commitProxyInterf;
commitProxyInterf.commit = RequestStream<struct CommitTransactionRequest>(Endpoint({ proxy2 }, testUID));
commitProxyInterf.commit = PublicRequestStream<struct CommitTransactionRequest>(Endpoint({ proxy2 }, testUID));
testDbInfo.client.commitProxies.push_back(commitProxyInterf);
ResolverInterface resolverInterf;

View File

@ -239,7 +239,7 @@ struct GrvProxyData {
GrvProxyStats stats;
MasterInterface master;
RequestStream<GetReadVersionRequest> getConsistentReadVersion;
PublicRequestStream<GetReadVersionRequest> getConsistentReadVersion;
Reference<ILogSystem> logSystem;
Database cx;
@ -275,7 +275,7 @@ struct GrvProxyData {
GrvProxyData(UID dbgid,
MasterInterface master,
RequestStream<GetReadVersionRequest> getConsistentReadVersion,
PublicRequestStream<GetReadVersionRequest> getConsistentReadVersion,
Reference<AsyncVar<ServerDBInfo> const> db)
: dbgid(dbgid), stats(dbgid), master(master), getConsistentReadVersion(getConsistentReadVersion),
cx(openDBOnServer(db, TaskPriority::DefaultEndpoint, LockAware::True)), db(db), lastStartCommit(0),

View File

@ -29,6 +29,9 @@
#include "fdbclient/Tenant.h"
#include "fdbrpc/Stats.h"
#include "fdbserver/Knobs.h"
#include "fdbserver/LogSystem.h"
#include "fdbserver/MasterInterface.h"
#include "fdbserver/ResolverInterface.h"
#include "fdbserver/LogSystemDiskQueueAdapter.h"
#include "flow/IRandom.h"
@ -193,8 +196,8 @@ struct ProxyCommitData {
NotifiedVersion latestLocalCommitBatchResolving;
NotifiedVersion latestLocalCommitBatchLogging;
RequestStream<GetReadVersionRequest> getConsistentReadVersion;
RequestStream<CommitTransactionRequest> commit;
PublicRequestStream<GetReadVersionRequest> getConsistentReadVersion;
PublicRequestStream<CommitTransactionRequest> commit;
Database cx;
Reference<AsyncVar<ServerDBInfo> const> db;
EventMetricHandle<SingleKeyMutation> singleKeyMutationEvent;
@ -273,9 +276,9 @@ struct ProxyCommitData {
ProxyCommitData(UID dbgid,
MasterInterface master,
RequestStream<GetReadVersionRequest> getConsistentReadVersion,
PublicRequestStream<GetReadVersionRequest> getConsistentReadVersion,
Version recoveryTransactionVersion,
RequestStream<CommitTransactionRequest> commit,
PublicRequestStream<CommitTransactionRequest> commit,
Reference<AsyncVar<ServerDBInfo> const> db,
bool firstProxy)
: dbgid(dbgid), commitBatchesMemBytesCount(0),

View File

@ -26,6 +26,7 @@
#include <toml.hpp>
#include "fdbrpc/Locality.h"
#include "fdbrpc/simulator.h"
#include "fdbrpc/IPAllowList.h"
#include "fdbclient/ClusterConnectionFile.h"
#include "fdbclient/ClusterConnectionMemoryRecord.h"
#include "fdbclient/DatabaseContext.h"
@ -520,6 +521,10 @@ ACTOR Future<ISimulator::KillType> simulatedFDBDRebooter(Reference<IClusterConne
state ISimulator::ProcessInfo* simProcess = g_simulator.getCurrentProcess();
state UID randomId = nondeterministicRandom()->randomUniqueID();
state int cycles = 0;
state IPAllowList allowList;
allowList.addTrustedSubnet("0.0.0.0/2"sv);
allowList.addTrustedSubnet("abcd::/16"sv);
loop {
auto waitTime =
@ -579,7 +584,8 @@ ACTOR Future<ISimulator::KillType> simulatedFDBDRebooter(Reference<IClusterConne
// making progress
FlowTransport::createInstance(processClass == ProcessClass::TesterClass || runBackupAgents == AgentOnly,
1,
WLTOKEN_RESERVED_COUNT);
WLTOKEN_RESERVED_COUNT,
&allowList);
Sim2FileSystem::newFileSystem();
std::vector<Future<Void>> futures;
@ -2334,10 +2340,14 @@ ACTOR void setupAndRun(std::string dataFolder,
state Standalone<StringRef> startingConfiguration;
state int testerCount = 1;
state TestConfig testConfig;
state IPAllowList allowList;
testConfig.readFromConfig(testFile);
g_simulator.hasDiffProtocolProcess = testConfig.startIncompatibleProcess;
g_simulator.setDiffProtocol = false;
// Build simulator allow list
allowList.addTrustedSubnet("0.0.0.0/2"sv);
allowList.addTrustedSubnet("abcd::/16"sv);
state bool allowDefaultTenant = testConfig.allowDefaultTenant;
state bool allowDisablingTenants = testConfig.allowDisablingTenants;
@ -2382,7 +2392,7 @@ ACTOR void setupAndRun(std::string dataFolder,
}
// TODO (IPv6) Use IPv6?
wait(g_simulator.onProcess(
auto testSystem =
g_simulator.newProcess("TestSystem",
IPAddress(0x01010101),
1,
@ -2395,10 +2405,11 @@ ACTOR void setupAndRun(std::string dataFolder,
ProcessClass(ProcessClass::TesterClass, ProcessClass::CommandLineSource),
"",
"",
currentProtocolVersion),
TaskPriority::DefaultYield));
currentProtocolVersion);
testSystem->excludeFromRestarts = true;
wait(g_simulator.onProcess(testSystem, TaskPriority::DefaultYield));
Sim2FileSystem::newFileSystem();
FlowTransport::createInstance(true, 1, WLTOKEN_RESERVED_COUNT);
FlowTransport::createInstance(true, 1, WLTOKEN_RESERVED_COUNT, &allowList);
TEST(true); // Simulation start
state Optional<TenantName> defaultTenant;

View File

@ -35,6 +35,8 @@
#include <boost/algorithm/string.hpp>
#include <boost/interprocess/managed_shared_memory.hpp>
#include <fmt/printf.h>
#include "fdbclient/ActorLineageProfiler.h"
#include "fdbclient/ClusterConnectionFile.h"
#include "fdbclient/IKnobCollection.h"
@ -45,6 +47,7 @@
#include "fdbclient/WellKnownEndpoints.h"
#include "fdbclient/SimpleIni.h"
#include "fdbrpc/AsyncFileCached.actor.h"
#include "fdbrpc/IPAllowList.h"
#include "fdbrpc/FlowProcess.actor.h"
#include "fdbrpc/Net2FileSystem.h"
#include "fdbrpc/PerfMetric.h"
@ -107,7 +110,8 @@ enum {
OPT_DCID, OPT_MACHINE_CLASS, OPT_BUGGIFY, OPT_VERSION, OPT_BUILD_FLAGS, OPT_CRASHONERROR, OPT_HELP, OPT_NETWORKIMPL, OPT_NOBUFSTDOUT, OPT_BUFSTDOUTERR,
OPT_TRACECLOCK, OPT_NUMTESTERS, OPT_DEVHELP, OPT_ROLLSIZE, OPT_MAXLOGS, OPT_MAXLOGSSIZE, OPT_KNOB, OPT_UNITTESTPARAM, OPT_TESTSERVERS, OPT_TEST_ON_SERVERS, OPT_METRICSCONNFILE,
OPT_METRICSPREFIX, OPT_LOGGROUP, OPT_LOCALITY, OPT_IO_TRUST_SECONDS, OPT_IO_TRUST_WARN_ONLY, OPT_FILESYSTEM, OPT_PROFILER_RSS_SIZE, OPT_KVFILE,
OPT_TRACE_FORMAT, OPT_WHITELIST_BINPATH, OPT_BLOB_CREDENTIAL_FILE, OPT_CONFIG_PATH, OPT_USE_TEST_CONFIG_DB, OPT_FAULT_INJECTION, OPT_PROFILER, OPT_PRINT_SIMTIME, OPT_FLOW_PROCESS_NAME, OPT_FLOW_PROCESS_ENDPOINT
OPT_TRACE_FORMAT, OPT_WHITELIST_BINPATH, OPT_BLOB_CREDENTIAL_FILE, OPT_CONFIG_PATH, OPT_USE_TEST_CONFIG_DB, OPT_FAULT_INJECTION, OPT_PROFILER, OPT_PRINT_SIMTIME,
OPT_FLOW_PROCESS_NAME, OPT_FLOW_PROCESS_ENDPOINT, OPT_IP_TRUSTED_MASK
};
CSimpleOpt::SOption g_rgOptions[] = {
@ -199,6 +203,7 @@ CSimpleOpt::SOption g_rgOptions[] = {
{ OPT_PRINT_SIMTIME, "--print-sim-time", SO_NONE },
{ OPT_FLOW_PROCESS_NAME, "--process-name", SO_REQ_SEP },
{ OPT_FLOW_PROCESS_ENDPOINT, "--process-endpoint", SO_REQ_SEP },
{ OPT_IP_TRUSTED_MASK, "--trusted-subnet-", SO_REQ_SEP },
#ifndef TLS_DISABLED
TLS_OPTION_FLAGS
@ -311,7 +316,7 @@ UID getSharedMemoryMachineId() {
std::string sharedMemoryIdentifier = "fdbserver_shared_memory_id";
loop {
try {
// "0" is the default parameter "addr"
// "0" is the default netPrefix "addr"
boost::interprocess::managed_shared_memory segment(
boost::interprocess::open_or_create, sharedMemoryIdentifier.c_str(), 1000, 0, p.permission);
machineId = segment.find_or_construct<UID>("machineId")(newUID);
@ -1041,6 +1046,7 @@ struct CLIOptions {
std::string flowProcessName;
Endpoint flowProcessEndpoint;
bool printSimTime = false;
IPAllowList allowList;
static CLIOptions parseArgs(int argc, char* argv[]) {
CLIOptions opts;
@ -1167,6 +1173,15 @@ private:
localities.set(key, Standalone<StringRef>(std::string(args.OptionArg())));
break;
}
case OPT_IP_TRUSTED_MASK: {
Optional<std::string> subnetKey = extractPrefixedArgument("--trusted-subnet", args.OptionSyntax());
if (!subnetKey.present()) {
fprintf(stderr, "ERROR: unable to parse locality key '%s'\n", args.OptionSyntax());
flushAndExit(FDB_EXIT_ERROR);
}
allowList.addTrustedSubnet(args.OptionArg());
break;
}
case OPT_VERSION:
printVersion();
flushAndExit(FDB_EXIT_SUCCESS);
@ -1853,7 +1868,7 @@ int main(int argc, char* argv[]) {
} else {
g_network = newNet2(opts.tlsConfig, opts.useThreadPool, true);
g_network->addStopCallback(Net2FileSystem::stop);
FlowTransport::createInstance(false, 1, WLTOKEN_RESERVED_COUNT);
FlowTransport::createInstance(false, 1, WLTOKEN_RESERVED_COUNT, &opts.allowList);
opts.buildNetwork(argv[0]);
const bool expectsPublicAddress = (role == ServerRole::FDBD || role == ServerRole::NetworkTestServer ||

View File

@ -48,7 +48,7 @@ WorkloadContext::WorkloadContext() {}
WorkloadContext::WorkloadContext(const WorkloadContext& r)
: options(r.options), clientId(r.clientId), clientCount(r.clientCount), sharedRandomNumber(r.sharedRandomNumber),
dbInfo(r.dbInfo) {}
dbInfo(r.dbInfo), ccr(r.ccr) {}
WorkloadContext::~WorkloadContext() {}
@ -326,34 +326,51 @@ struct CompoundWorkload : TestWorkload {
}
return allTrue(all);
}
void getMetrics(std::vector<PerfMetric>& m) override {
for (int w = 0; w < workloads.size(); w++) {
ACTOR static Future<std::vector<PerfMetric>> getMetrics(CompoundWorkload* self) {
state std::vector<Future<std::vector<PerfMetric>>> results;
for (int w = 0; w < self->workloads.size(); w++) {
std::vector<PerfMetric> p;
workloads[w]->getMetrics(p);
for (int i = 0; i < p.size(); i++)
m.push_back(p[i].withPrefix(workloads[w]->description() + "."));
results.push_back(self->workloads[w]->getMetrics());
}
wait(waitForAll(results));
std::vector<PerfMetric> res;
for (int i = 0; i < results.size(); ++i) {
auto const& p = results[i].get();
for (auto const& m : p) {
res.push_back(m.withPrefix(self->workloads[i]->description() + "."));
}
}
return res;
}
Future<std::vector<PerfMetric>> getMetrics() override { return getMetrics(this); }
double getCheckTimeout() const override {
double m = 0;
for (int w = 0; w < workloads.size(); w++)
m = std::max(workloads[w]->getCheckTimeout(), m);
return m;
}
void getMetrics(std::vector<PerfMetric>&) override { ASSERT(false); }
};
Reference<TestWorkload> getWorkloadIface(WorkloadRequest work,
VectorRef<KeyValueRef> options,
Reference<AsyncVar<ServerDBInfo> const> dbInfo) {
Value testName = getOption(options, LiteralStringRef("testName"), LiteralStringRef("no-test-specified"));
ACTOR Future<Reference<TestWorkload>> getWorkloadIface(WorkloadRequest work,
Reference<IClusterConnectionRecord> ccr,
VectorRef<KeyValueRef> options,
Reference<AsyncVar<ServerDBInfo> const> dbInfo) {
state Reference<TestWorkload> workload;
state Value testName = getOption(options, LiteralStringRef("testName"), LiteralStringRef("no-test-specified"));
WorkloadContext wcx;
wcx.clientId = work.clientId;
wcx.clientCount = work.clientCount;
wcx.ccr = ccr;
wcx.dbInfo = dbInfo;
wcx.options = options;
wcx.sharedRandomNumber = work.sharedRandomNumber;
auto workload = IWorkloadFactory::create(testName.toString(), wcx);
workload = IWorkloadFactory::create(testName.toString(), wcx);
wait(workload->initialized());
auto unconsumedOptions = checkAllOptionsConsumed(workload ? workload->options : VectorRef<KeyValueRef>());
if (!workload || unconsumedOptions.size()) {
@ -378,24 +395,33 @@ Reference<TestWorkload> getWorkloadIface(WorkloadRequest work,
return workload;
}
Reference<TestWorkload> getWorkloadIface(WorkloadRequest work, Reference<AsyncVar<ServerDBInfo> const> dbInfo) {
ACTOR Future<Reference<TestWorkload>> getWorkloadIface(WorkloadRequest work,
Reference<IClusterConnectionRecord> ccr,
Reference<AsyncVar<ServerDBInfo> const> dbInfo) {
state WorkloadContext wcx;
state std::vector<Future<Reference<TestWorkload>>> ifaces;
if (work.options.size() < 1) {
TraceEvent(SevError, "TestCreationError").detail("Reason", "No options provided");
fprintf(stderr, "ERROR: No options were provided for workload.\n");
throw test_specification_invalid();
}
if (work.options.size() == 1)
return getWorkloadIface(work, work.options[0], dbInfo);
if (work.options.size() == 1) {
Reference<TestWorkload> res = wait(getWorkloadIface(work, ccr, work.options[0], dbInfo));
return res;
}
WorkloadContext wcx;
wcx.clientId = work.clientId;
wcx.clientCount = work.clientCount;
wcx.sharedRandomNumber = work.sharedRandomNumber;
// FIXME: Other stuff not filled in; why isn't this constructed here and passed down to the other
// getWorkloadIface()?
for (int i = 0; i < work.options.size(); i++) {
ifaces.push_back(getWorkloadIface(work, ccr, work.options[i], dbInfo));
}
wait(waitForAll(ifaces));
auto compound = makeReference<CompoundWorkload>(wcx);
for (int i = 0; i < work.options.size(); i++) {
compound->add(getWorkloadIface(work, work.options[i], dbInfo));
compound->add(ifaces[i].getValue());
}
return compound;
}
@ -494,7 +520,7 @@ ACTOR Future<Void> testDatabaseLiveness(Database cx,
try {
state double start = now();
auto traceMsg = "PingingDatabaseLiveness_" + context;
TraceEvent(traceMsg.c_str());
TraceEvent(traceMsg.c_str()).log();
wait(timeoutError(pingDatabase(cx), databasePingDelay));
double pingTime = now() - start;
ASSERT(pingTime > 0);
@ -607,10 +633,9 @@ ACTOR Future<Void> runWorkloadAsync(Database cx,
when(ReplyPromise<std::vector<PerfMetric>> req = waitNext(workIface.metrics.getFuture())) {
state ReplyPromise<std::vector<PerfMetric>> s_req = req;
try {
std::vector<PerfMetric> m;
workload->getMetrics(m);
std::vector<PerfMetric> m = wait(workload->getMetrics());
TraceEvent("WorkloadSendMetrics", workIface.id()).detail("Count", m.size());
req.send(m);
s_req.send(m);
} catch (Error& e) {
if (e.code() == error_code_please_reboot || e.code() == error_code_please_reboot_delete)
throw;
@ -649,7 +674,7 @@ ACTOR Future<Void> testerServerWorkload(WorkloadRequest work,
// add test for "done" ?
TraceEvent("WorkloadReceived", workIface.id()).detail("Title", work.title);
auto workload = getWorkloadIface(work, dbInfo);
Reference<TestWorkload> workload = wait(getWorkloadIface(work, ccr, dbInfo));
if (!workload) {
TraceEvent("TestCreationError").detail("Reason", "Workload could not be created");
fprintf(stderr, "ERROR: The workload could not be created.\n");
@ -704,6 +729,9 @@ ACTOR Future<Void> testerServerCore(TesterInterface interf,
ACTOR Future<Void> clearData(Database cx) {
state Transaction tr(cx);
state UID debugID = debugRandom()->randomUniqueID();
TraceEvent("TesterClearingDatabaseStart", debugID).log();
tr.debugTransaction(debugID);
loop {
try {
// This transaction needs to be self-conflicting, but not conflict consistently with
@ -712,10 +740,10 @@ ACTOR Future<Void> clearData(Database cx) {
tr.makeSelfConflicting();
wait(success(tr.getReadVersion())); // required since we use addReadConflictRange but not get
wait(tr.commit());
TraceEvent("TesterClearingDatabase").detail("AtVersion", tr.getCommittedVersion());
TraceEvent("TesterClearingDatabase", debugID).detail("AtVersion", tr.getCommittedVersion());
break;
} catch (Error& e) {
TraceEvent(SevWarn, "TesterClearingDatabaseError").error(e);
TraceEvent(SevWarn, "TesterClearingDatabaseError", debugID).error(e);
wait(tr.onError(e));
}
}

View File

@ -762,14 +762,14 @@ TEST_CASE("/fdbserver/worker/addressInDbAndPrimaryDc") {
NetworkAddress grvProxyAddress(IPAddress(0x26262626), 1);
GrvProxyInterface grvProxyInterf;
grvProxyInterf.getConsistentReadVersion =
RequestStream<struct GetReadVersionRequest>(Endpoint({ grvProxyAddress }, UID(1, 2)));
PublicRequestStream<struct GetReadVersionRequest>(Endpoint({ grvProxyAddress }, UID(1, 2)));
testDbInfo.client.grvProxies.push_back(grvProxyInterf);
ASSERT(addressInDbAndPrimaryDc(grvProxyAddress, makeReference<AsyncVar<ServerDBInfo>>(testDbInfo)));
NetworkAddress commitProxyAddress(IPAddress(0x37373737), 1);
CommitProxyInterface commitProxyInterf;
commitProxyInterf.commit =
RequestStream<struct CommitTransactionRequest>(Endpoint({ commitProxyAddress }, UID(1, 2)));
PublicRequestStream<struct CommitTransactionRequest>(Endpoint({ commitProxyAddress }, UID(1, 2)));
testDbInfo.client.commitProxies.push_back(commitProxyInterf);
ASSERT(addressInDbAndPrimaryDc(commitProxyAddress, makeReference<AsyncVar<ServerDBInfo>>(testDbInfo)));

View File

@ -0,0 +1,241 @@
/*
* ClientWorkload.actor.cpp
*
* This source file is part of the FoundationDB open source project
*
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fdbserver/ServerDBInfo.actor.h"
#include "fdbserver/workloads/workloads.actor.h"
#include "fdbrpc/simulator.h"
#include <fmt/format.h>
#include "flow/actorcompiler.h" // has to be last include
class WorkloadProcessState {
IPAddress childAddress;
std::string processName;
Future<Void> processActor;
Promise<Void> init;
WorkloadProcessState(int clientId) : clientId(clientId) { processActor = processStart(this); }
~WorkloadProcessState() {
TraceEvent("ShutdownClientForWorkload", id).log();
g_simulator.destroyProcess(childProcess);
}
ACTOR static Future<Void> initializationDone(WorkloadProcessState* self, ISimulator::ProcessInfo* parent) {
wait(g_simulator.onProcess(parent, TaskPriority::DefaultYield));
self->init.send(Void());
wait(Never());
ASSERT(false); // does not happen
return Void();
}
ACTOR static Future<Void> processStart(WorkloadProcessState* self) {
state ISimulator::ProcessInfo* parent = g_simulator.getCurrentProcess();
state std::vector<Future<Void>> futures;
if (parent->address.isV6()) {
self->childAddress =
IPAddress::parse(fmt::format("2001:fdb1:fdb2:fdb3:fdb4:fdb5:fdb6:{:04x}", self->clientId + 2)).get();
} else {
self->childAddress = IPAddress::parse(fmt::format("192.168.0.{}", self->clientId + 2)).get();
}
self->processName = fmt::format("TestClient{}", self->clientId);
Standalone<StringRef> newZoneId(deterministicRandom()->randomUniqueID().toString());
auto locality = LocalityData(Optional<Standalone<StringRef>>(), newZoneId, newZoneId, parent->locality.dcId());
auto dataFolder = joinPath(popPath(parent->dataFolder), deterministicRandom()->randomUniqueID().toString());
platform::createDirectory(dataFolder);
TraceEvent("StartingClientWorkloadProcess", self->id)
.detail("Name", self->processName)
.detail("Address", self->childAddress);
self->childProcess = g_simulator.newProcess(self->processName.c_str(),
self->childAddress,
1,
parent->address.isTLS(),
1,
locality,
ProcessClass(ProcessClass::TesterClass, ProcessClass::AutoSource),
dataFolder.c_str(),
parent->coordinationFolder,
parent->protocolVersion);
self->childProcess->excludeFromRestarts = true;
wait(g_simulator.onProcess(self->childProcess, TaskPriority::DefaultYield));
try {
FlowTransport::createInstance(true, 1, WLTOKEN_RESERVED_COUNT);
Sim2FileSystem::newFileSystem();
auto addr = g_simulator.getCurrentProcess()->address;
futures.push_back(FlowTransport::transport().bind(addr, addr));
futures.push_back(success((self->childProcess->onShutdown())));
TraceEvent("ClientWorkloadProcessInitialized", self->id).log();
futures.push_back(initializationDone(self, parent));
wait(waitForAny(futures));
} catch (Error& e) {
if (e.code() == error_code_actor_cancelled) {
return Void();
}
ASSERT(false);
}
ASSERT(false);
return Void();
}
static std::vector<WorkloadProcessState*>& states() {
static std::vector<WorkloadProcessState*> res;
return res;
}
public:
static WorkloadProcessState* instance(int clientId) {
states().resize(std::max(states().size(), size_t(clientId + 1)), nullptr);
auto& res = states()[clientId];
if (res == nullptr) {
res = new WorkloadProcessState(clientId);
}
return res;
}
Future<Void> initialized() const { return init.getFuture(); }
UID id = deterministicRandom()->randomUniqueID();
int clientId;
ISimulator::ProcessInfo* childProcess;
};
struct WorkloadProcess {
WorkloadProcessState* processState;
UID id = deterministicRandom()->randomUniqueID();
Database cx;
Future<Void> databaseOpened;
Reference<TestWorkload> child;
std::string desc;
void createDatabase(ClientWorkload::CreateWorkload const& childCreator, WorkloadContext const& wcx) {
try {
child = childCreator(wcx);
TraceEvent("ClientWorkloadOpenDatabase", id).detail("ClusterFileLocation", child->ccr->getLocation());
cx = Database::createDatabase(child->ccr, -1);
desc = child->description();
} catch (Error&) {
throw;
} catch (...) {
ASSERT(false);
}
}
ACTOR static Future<Void> openDatabase(WorkloadProcess* self,
ClientWorkload::CreateWorkload childCreator,
WorkloadContext wcx) {
state ISimulator::ProcessInfo* parent = g_simulator.getCurrentProcess();
state Optional<Error> err;
wcx.dbInfo = Reference<AsyncVar<struct ServerDBInfo> const>();
wait(self->processState->initialized());
wait(g_simulator.onProcess(self->childProcess(), TaskPriority::DefaultYield));
try {
self->createDatabase(childCreator, wcx);
} catch (Error& e) {
ASSERT(e.code() != error_code_actor_cancelled);
err = e;
}
wait(g_simulator.onProcess(parent, TaskPriority::DefaultYield));
if (err.present()) {
throw err.get();
}
return Void();
}
ISimulator::ProcessInfo* childProcess() { return processState->childProcess; }
int clientId() const { return processState->clientId; }
WorkloadProcess(ClientWorkload::CreateWorkload const& childCreator, WorkloadContext const& wcx)
: processState(WorkloadProcessState::instance(wcx.clientId)) {
TraceEvent("StartingClinetWorkload", id).detail("OnClientProcess", processState->id);
databaseOpened = openDatabase(this, childCreator, wcx);
}
ACTOR static void destroy(WorkloadProcess* self) {
state ISimulator::ProcessInfo* parent = g_simulator.getCurrentProcess();
wait(g_simulator.onProcess(self->childProcess(), TaskPriority::DefaultYield));
delete self;
wait(g_simulator.onProcess(parent, TaskPriority::DefaultYield));
}
std::string description() { return desc; }
ACTOR template <class Ret, class Fun>
Future<Ret> runActor(WorkloadProcess* self, Optional<TenantName> defaultTenant, Fun f) {
state Optional<Error> err;
state Ret res;
state ISimulator::ProcessInfo* parent = g_simulator.getCurrentProcess();
wait(self->databaseOpened);
wait(g_simulator.onProcess(self->childProcess(), TaskPriority::DefaultYield));
self->cx->defaultTenant = defaultTenant;
try {
Ret r = wait(f(self->cx));
res = r;
} catch (Error& e) {
if (e.code() == error_code_actor_cancelled) {
ASSERT(g_simulator.getCurrentProcess() == parent);
throw;
}
err = e;
}
wait(g_simulator.onProcess(parent, TaskPriority::DefaultYield));
if (err.present()) {
throw err.get();
}
return res;
}
};
ClientWorkload::ClientWorkload(CreateWorkload const& childCreator, WorkloadContext const& wcx)
: TestWorkload(wcx), impl(new WorkloadProcess(childCreator, wcx)) {}
ClientWorkload::~ClientWorkload() {
WorkloadProcess::destroy(impl);
}
std::string ClientWorkload::description() const {
return impl->description();
}
Future<Void> ClientWorkload::initialized() {
return impl->databaseOpened;
}
Future<Void> ClientWorkload::setup(Database const& cx) {
return impl->runActor<Void>(impl, cx->defaultTenant, [this](Database const& db) { return impl->child->setup(db); });
}
Future<Void> ClientWorkload::start(Database const& cx) {
return impl->runActor<Void>(impl, cx->defaultTenant, [this](Database const& db) { return impl->child->start(db); });
}
Future<bool> ClientWorkload::check(Database const& cx) {
return impl->runActor<bool>(impl, cx->defaultTenant, [this](Database const& db) { return impl->child->check(db); });
}
Future<std::vector<PerfMetric>> ClientWorkload::getMetrics() {
return impl->runActor<std::vector<PerfMetric>>(
impl, Optional<TenantName>(), [this](Database const& db) { return impl->child->getMetrics(); });
}
void ClientWorkload::getMetrics(std::vector<PerfMetric>& m) {
ASSERT(false);
}
double ClientWorkload::getCheckTimeout() const {
return impl->child->getCheckTimeout();
}

View File

@ -268,4 +268,4 @@ struct CycleWorkload : TestWorkload {
}
};
WorkloadFactory<CycleWorkload> CycleWorkloadFactory("Cycle");
WorkloadFactory<CycleWorkload> CycleWorkloadFactory("Cycle", true);

View File

@ -0,0 +1,142 @@
/*
* PrivateEndpoints.actor.cpp
*
* This source file is part of the FoundationDB open source project
*
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fdbserver/workloads/workloads.actor.h"
#include "flow/actorcompiler.h" // has to be last include
namespace {
struct PrivateEndpoints : TestWorkload {
static constexpr const char* WorkloadName = "PrivateEndpoints";
bool success = true;
int numSuccesses = 0;
double startAfter;
double runFor;
std::vector<std::function<Future<Void>(Reference<AsyncVar<ClientDBInfo>> const&)>> testFunctions;
template <class T>
static Optional<T> getRandom(std::vector<T> const& v) {
if (v.empty()) {
return Optional<T>();
} else {
return deterministicRandom()->randomChoice(v);
}
}
template <class T>
static Optional<T> getInterface(Reference<AsyncVar<ClientDBInfo>> const& clientDBInfo) {
if constexpr (std::is_same_v<T, GrvProxyInterface>) {
return getRandom(clientDBInfo->get().grvProxies);
} else if constexpr (std::is_same_v<T, CommitProxyInterface>) {
return getRandom(clientDBInfo->get().commitProxies);
} else {
ASSERT(false); // don't know how to handle this type
}
}
ACTOR template <class T>
static Future<Void> assumeFailure(Future<T> f) {
try {
T t = wait(f);
(void)t;
ASSERT(false);
} catch (Error& e) {
if (e.code() == error_code_actor_cancelled) {
throw;
} else if (e.code() == error_code_unauthorized_attempt) {
TraceEvent("SuccessPrivateEndpoint").log();
} else if (e.code() == error_code_request_maybe_delivered) {
// this is also fine, because even when calling private endpoints
// we might see connection failures
TraceEvent("SuccessRequestMaybeDelivered").log();
} else {
TraceEvent(SevError, "WrongErrorCode").error(e);
}
}
return Void();
}
template <class I, class RT>
void addTestFor(RequestStream<RT, false> I::*channel) {
testFunctions.push_back([channel](Reference<AsyncVar<ClientDBInfo>> const& clientDBInfo) {
auto optintf = getInterface<I>(clientDBInfo);
if (!optintf.present()) {
return clientDBInfo->onChange();
}
RequestStream<RT> s = optintf.get().*channel;
RT req;
return assumeFailure(deterministicRandom()->coinflip() ? throwErrorOr(s.tryGetReply(req))
: s.getReply(req));
});
}
explicit PrivateEndpoints(WorkloadContext const& wcx) : TestWorkload(wcx) {
startAfter = getOption(options, "startAfter"_sr, 10.0);
runFor = getOption(options, "runFor"_sr, 10.0);
addTestFor(&GrvProxyInterface::waitFailure);
addTestFor(&GrvProxyInterface::getHealthMetrics);
addTestFor(&CommitProxyInterface::getStorageServerRejoinInfo);
addTestFor(&CommitProxyInterface::waitFailure);
addTestFor(&CommitProxyInterface::txnState);
addTestFor(&CommitProxyInterface::getHealthMetrics);
addTestFor(&CommitProxyInterface::proxySnapReq);
addTestFor(&CommitProxyInterface::exclusionSafetyCheckReq);
addTestFor(&CommitProxyInterface::getDDMetrics);
}
std::string description() const override { return WorkloadName; }
Future<Void> start(Database const& cx) override { return _start(this, cx); }
Future<bool> check(Database const& cx) override { return success; }
void getMetrics(std::vector<PerfMetric>& m) override {
m.emplace_back("Successes", double(numSuccesses), Averaged::True);
}
ACTOR static Future<Void> _start(PrivateEndpoints* self, Database cx) {
state Reference<AsyncVar<ClientDBInfo>> clientInfo = cx->clientInfo;
state Future<Void> end;
TraceEvent("PrivateEndpointTestStartWait").detail("WaitTime", self->startAfter).log();
wait(delay(self->startAfter));
TraceEvent("PrivateEndpointTestStart").detail("RunFor", self->runFor).log();
end = delay(self->runFor);
try {
loop {
auto testFuture = deterministicRandom()->randomChoice(self->testFunctions)(cx->clientInfo);
choose {
when(wait(end)) {
TraceEvent("PrivateEndpointTestDone").log();
return Void();
}
when(wait(testFuture)) { ++self->numSuccesses; }
}
wait(delay(0.2));
}
} catch (Error& e) {
TraceEvent(SevError, "PrivateEndpointTestError").errorUnsuppressed(e);
ASSERT(false);
}
UNREACHABLE();
return Void();
}
};
} // namespace
WorkloadFactory<PrivateEndpoints> PrivateEndpointsFactory(PrivateEndpoints::WorkloadName, true);

View File

@ -89,7 +89,7 @@ struct SaveAndKillWorkload : TestWorkload {
for (const auto& [_, process] : allProcessesMap) {
std::string machineId = printable(process->locality.machineId());
const char* machineIdString = machineId.c_str();
if (strcmp(process->name, "TestSystem") != 0) {
if (!process->excludeFromRestarts) {
if (machines.find(machineId) == machines.end()) {
machines.insert(std::pair<std::string, int>(machineId, 1));
ini.SetValue("META", format("%d", j).c_str(), machineIdString);

View File

@ -30,7 +30,10 @@
#include "fdbserver/KnobProtectiveGroups.h"
#include "fdbserver/TesterInterface.actor.h"
#include "fdbrpc/simulator.h"
#include "flow/actorcompiler.h"
#include <functional>
#include "flow/actorcompiler.h" // has to be last import
/*
* Gets an Value from a list of key/value pairs, using a default value if the key is not present.
@ -51,6 +54,7 @@ struct WorkloadContext {
int clientId, clientCount;
int64_t sharedRandomNumber;
Reference<AsyncVar<struct ServerDBInfo> const> dbInfo;
Reference<IClusterConnectionRecord> ccr;
WorkloadContext();
WorkloadContext(const WorkloadContext&);
@ -69,15 +73,40 @@ struct TestWorkload : NonCopyable, WorkloadContext, ReferenceCounted<TestWorkloa
phases |= TestWorkload::SETUP;
}
virtual ~TestWorkload(){};
virtual Future<Void> initialized() { return Void(); }
virtual std::string description() const = 0;
virtual Future<Void> setup(Database const& cx) { return Void(); }
virtual Future<Void> start(Database const& cx) = 0;
virtual Future<bool> check(Database const& cx) = 0;
virtual void getMetrics(std::vector<PerfMetric>& m) = 0;
virtual Future<std::vector<PerfMetric>> getMetrics() {
std::vector<PerfMetric> result;
getMetrics(result);
return result;
}
virtual double getCheckTimeout() const { return 3000; }
enum WorkloadPhase { SETUP = 1, EXECUTION = 2, CHECK = 4, METRICS = 8 };
private:
virtual void getMetrics(std::vector<PerfMetric>& m) = 0;
};
struct WorkloadProcess;
struct ClientWorkload : TestWorkload {
WorkloadProcess* impl;
using CreateWorkload = std::function<Reference<TestWorkload>(WorkloadContext const&)>;
ClientWorkload(CreateWorkload const& childCreator, WorkloadContext const& wcx);
~ClientWorkload();
Future<Void> initialized() override;
std::string description() const override;
Future<Void> setup(Database const& cx) override;
Future<Void> start(Database const& cx) override;
Future<bool> check(Database const& cx) override;
void getMetrics(std::vector<PerfMetric>& m) override;
Future<std::vector<PerfMetric>> getMetrics() override;
double getCheckTimeout() const override;
};
struct KVWorkload : TestWorkload {
@ -122,8 +151,17 @@ struct IWorkloadFactory : ReferenceCounted<IWorkloadFactory> {
template <class WorkloadType>
struct WorkloadFactory : IWorkloadFactory {
WorkloadFactory(const char* name) { factories()[name] = Reference<IWorkloadFactory>::addRef(this); }
Reference<TestWorkload> create(WorkloadContext const& wcx) override { return makeReference<WorkloadType>(wcx); }
bool asClient;
WorkloadFactory(const char* name, bool asClient = false) : asClient(asClient) {
factories()[name] = Reference<IWorkloadFactory>::addRef(this);
}
Reference<TestWorkload> create(WorkloadContext const& wcx) override {
if (g_network->isSimulated() && asClient) {
return makeReference<ClientWorkload>(
[](WorkloadContext const& wcx) { return makeReference<WorkloadType>(wcx); }, wcx);
}
return makeReference<WorkloadType>(wcx);
}
};
#define REGISTER_WORKLOAD(classname) WorkloadFactory<classname> classname##WorkloadFactory(#classname)

View File

@ -138,6 +138,7 @@ target_link_libraries(flow PUBLIC fmt::fmt)
add_flow_target(STATIC_LIBRARY NAME flow_sampling SRCS ${FLOW_SRCS})
target_link_libraries(flow_sampling PRIVATE stacktrace)
target_link_libraries(flow_sampling PUBLIC fmt::fmt)
target_compile_definitions(flow_sampling PRIVATE -DENABLE_SAMPLING)
if(WIN32)
add_dependencies(flow_sampling_actors flow_actors)

View File

@ -24,10 +24,16 @@
#include "flow/flat_buffers.h"
#include "flow/ProtocolVersion.h"
#include <unordered_map>
using ContextVariableMap = std::unordered_map<std::string_view, void*>;
template <class Ar>
struct LoadContext {
Ar* ar;
LoadContext(Ar* ar) : ar(ar) {}
Arena& arena() { return ar->arena(); }
ProtocolVersion protocolVersion() const { return ar->protocolVersion(); }
@ -68,20 +74,23 @@ struct SaveContext {
template <class ReaderImpl>
class _ObjectReader {
protected:
ProtocolVersion mProtocolVersion;
Optional<ProtocolVersion> mProtocolVersion;
std::shared_ptr<ContextVariableMap> variables;
public:
ProtocolVersion protocolVersion() const { return mProtocolVersion; }
ProtocolVersion protocolVersion() const { return mProtocolVersion.get(); }
void setProtocolVersion(ProtocolVersion v) { mProtocolVersion = v; }
void setContextVariableMap(std::shared_ptr<ContextVariableMap> const& cvm) { variables = cvm; }
template <class... Items>
void deserialize(FileIdentifier file_identifier, Items&... items) {
const uint8_t* data = static_cast<ReaderImpl*>(this)->data();
LoadContext<ReaderImpl> context(static_cast<ReaderImpl*>(this));
const uint8_t* data = static_cast<ReaderImpl*>(this)->data();
if (read_file_identifier(data) != file_identifier) {
// Some file identifiers are changed in 7.0, so file identifier mismatches
// are expected during a downgrade from 7.0 to 6.3
bool expectMismatch = mProtocolVersion >= ProtocolVersion(0x0FDB00B070000000LL);
bool expectMismatch = mProtocolVersion.get() >= ProtocolVersion(0x0FDB00B070000000LL) &&
currentProtocolVersion < ProtocolVersion(0x0FDB00B070000000LL);
{
TraceEvent te(expectMismatch ? SevInfo : SevError, "MismatchedFileIdentifier");
if (expectMismatch) {
@ -100,6 +109,24 @@ public:
void deserialize(Item& item) {
deserialize(FileIdentifierFor<Item>::value, item);
}
template <class T>
bool variable(std::string_view name, T* val) {
auto p = variables->insert(std::make_pair(name, val));
return p.second;
}
template <class T>
T& variable(std::string_view name) {
auto res = variables->at(name);
return *reinterpret_cast<T*>(res);
}
template <class T>
T const& variable(std::string_view name) const {
auto res = variables->at(name);
return *reinterpret_cast<T*>(res);
}
};
class ObjectReader : public _ObjectReader<ObjectReader> {

View File

@ -311,6 +311,10 @@ ERROR( encrypt_invalid_id, 2706, "Invalid encryption domainId or encryption ciph
ERROR( unknown_error, 4000, "An unknown error occurred" ) // C++ exception not of type Error
ERROR( internal_error, 4100, "An internal error occurred" )
ERROR( not_implemented, 4200, "Not implemented yet" )
// 6xxx Authorization and authentication error codes
ERROR( permission_denied, 6000, "Client tried to access unauthorized data" )
ERROR( unauthorized_attempt, 6001, "A untrusted client tried to send a message to a private endpoint" )
// clang-format on
#undef ERROR

View File

@ -161,6 +161,7 @@ if(WITH_PYTHON)
add_fdb_test(TEST_FILES fast/MoveKeysCycle.toml)
add_fdb_test(TEST_FILES fast/MutationLogReaderCorrectness.toml)
add_fdb_test(TEST_FILES fast/GetMappedRange.toml)
add_fdb_test(TEST_FILES fast/PrivateEndpoints.toml)
add_fdb_test(TEST_FILES fast/ProtocolVersion.toml)
add_fdb_test(TEST_FILES fast/RandomSelector.toml)
add_fdb_test(TEST_FILES fast/RandomUnitTests.toml)

View File

@ -0,0 +1,5 @@
[[test]]
testTitle = 'PrivateEndpoints'
[[test.workload]]
testName = 'PrivateEndpoints'