Make transport work

This commit is contained in:
Markus Pilman 2022-04-22 17:06:05 -06:00
parent 1da1f8cc0f
commit ccf97eb187
6 changed files with 46 additions and 18 deletions

View File

@ -251,7 +251,9 @@ struct TenantAuthorizer final : NetworkMessageReceiver {
for (const auto& t : req.tokens) {
auto key = transport.getPublicKeyByName(t.keyName);
if (key.present() && verifyToken(t, key.get())) {
auto token = ObjectReader::fromStringRef<AuthTokenRef>(t.token, Unversioned());
ObjectReader r(t.token.begin(), AssumeVersion(reader.protocolVersion()));
AuthTokenRef token;
r.deserialize(token);
Reference<AuthorizedTenants>& auth =
std::any_cast<Reference<AuthorizedTenants>&>(reader.variable("AuthorizedTenants"));
auth->add(token.expiresAt, token.tenants);
@ -287,10 +289,14 @@ struct UnauthorizedEndpointReceiver final : NetworkMessageReceiver {
bool isPublic() const override { return true; }
};
template <class T, class Container = std::vector<T>, class Cmp = std::less<T>>
class IterablePriorityQueue {
// A priority queue with two additional properties:
// - All values within the queue are unique
// - One can iterate over the queue
template <class T, class Container = std::vector<T>, class Cmp = std::less<T>, class ValEq = std::equal_to<T>>
class IterableUniquePriorityQueue {
Container queue;
Cmp cmp;
ValEq valCmp;
using const_iterator = typename Container::const_iterator;
public:
@ -302,10 +308,18 @@ public:
queue.push_back(std::move(val));
std::push_heap(queue.begin(), queue.end(), cmp);
}
// runs in O(n) -- so this class should only be used if we expect the max size of the queue to be small
template <class... Args>
void emplace(Args&&... args) {
queue.emplace_back(std::forward<Args>(args)...);
bool emplace(Args&&... args) {
T el(std::forward<Args>(args)...);
for (const auto& element : *this) {
if (valCmp(element, el)) {
return false;
}
}
queue.emplace_back(std::move(el));
std::push_heap(queue.begin(), queue.end(), cmp);
return true;
}
const T& front() const { return queue.begin(); }
void pop() {
@ -319,9 +333,14 @@ public:
using SignedAuthTokenTTL = std::pair<double, SignedAuthToken>;
struct SignedAuthTokenTTLCmp {
bool operator()(const SignedAuthTokenTTL& lhs, const SignedAuthTokenTTL& rhs) { return lhs.first > rhs.first; }
constexpr bool operator()(const SignedAuthTokenTTL& lhs, const SignedAuthTokenTTL& rhs) const { return lhs.first > rhs.first; }
};
using TokenQueue = IterablePriorityQueue<SignedAuthTokenTTL, std::vector<SignedAuthTokenTTL>, SignedAuthTokenTTLCmp>;
struct SignedAuthTokenCmp {
constexpr bool operator()(const SignedAuthTokenTTL& lhs, const SignedAuthTokenTTL& rhs) const { return lhs.second.signature == rhs.second.signature; }
};
using TokenQueue = IterableUniquePriorityQueue<SignedAuthTokenTTL, std::vector<SignedAuthTokenTTL>, SignedAuthTokenTTLCmp, SignedAuthTokenCmp>;
class TransportData {
public:
@ -954,7 +973,6 @@ void Peer::prependConnectPacket() {
for (auto t : transport->tokens) {
req.tokens.push_back(req.arena, t.second);
}
SerializeSource<AuthorizationRequest> what(req);
++transport->countPacketsGenerated;
SplitBuffer packetInfoBuffer;
uint32_t len;
@ -969,7 +987,10 @@ void Peer::prependConnectPacket() {
}
wr.writeAhead(packetInfoSize, &packetInfoBuffer);
wr << Endpoint::wellKnownToken(WLTOKEN_AUTH_TENANT);
what.serializePacketWriter(wr);
ObjectWriter writer([&wr](size_t size) { return wr.writeBytes(size); },
AssumeVersion(g_network->protocolVersion()));
writer.serialize(req);
// what.serializePacketWriter(wr);
pb_end = wr.finish();
len = wr.size() - packetInfoSize - pkt.totalPacketSize();
if (checksumEnabled) {
@ -2046,10 +2067,12 @@ HealthMonitor* FlowTransport::healthMonitor() {
}
void FlowTransport::authorizationTokenAdd(StringRef signedToken) {
auto tokenRef = ObjectReader::fromStringRef<SignedAuthTokenRef>(signedToken, Unversioned());
ObjectReader reader(signedToken.begin(), AssumeVersion(g_network->protocolVersion()));
SignedAuthTokenRef tokenRef;
reader.deserialize(tokenRef);
SignedAuthToken token(tokenRef);
// we need the TTL to invalidate tokens on the client side
auto authToken = ObjectReader::fromStringRef<AuthTokenRef>(token.token, Unversioned());
auto authToken = ObjectReader::fromStringRef<AuthTokenRef>(token.token, AssumeVersion(g_network->protocolVersion()));
if (authToken.expiresAt < now()) {
TraceEvent(SevWarnAlways, "AddedExpiredToken").detail("Expired", authToken.expiresAt);
return;

View File

@ -109,10 +109,10 @@ Standalone<KeyPairRef> generateEcdsaKeyPair() {
Standalone<SignedAuthTokenRef> signToken(AuthTokenRef token, StringRef keyName, StringRef privateKeyDer) {
auto ret = Standalone<SignedAuthTokenRef>{};
auto arena = ret.arena();
auto& arena = ret.arena();
auto writer = ObjectWriter([&arena](size_t len) { return new (arena) uint8_t[len]; }, Unversioned());
writer.serialize(token);
auto tokenStr = writer.toStringRef();
auto tokenStr = StringRef(arena, writer.toStringRef());
auto rawPrivKeyDer = privateKeyDer.begin();
auto key = ::d2i_AutoPrivateKey(nullptr, &rawPrivKeyDer, privateKeyDer.size());

View File

@ -207,6 +207,7 @@ set(FDBSERVER_SRCS
workloads/ConflictRange.actor.cpp
workloads/ConsistencyCheck.actor.cpp
workloads/CpuProfiler.actor.cpp
workloads/CreateTenant.actor.cpp
workloads/Cycle.actor.cpp
workloads/DataDistributionMetrics.actor.cpp
workloads/DataLossRecovery.actor.cpp

View File

@ -588,6 +588,9 @@ ACTOR Future<ISimulator::KillType> simulatedFDBDRebooter(Reference<IClusterConne
1,
WLTOKEN_RESERVED_COUNT,
&allowList);
for (const auto& p : g_simulator.authKeys) {
FlowTransport::transport().addPublicKey(p.first, p.second.publicKey);
}
Sim2FileSystem::newFileSystem();
std::vector<Future<Void>> futures;

View File

@ -370,7 +370,9 @@ ACTOR Future<Reference<TestWorkload>> getWorkloadIface(WorkloadRequest work,
wcx.sharedRandomNumber = work.sharedRandomNumber;
workload = IWorkloadFactory::create(testName.toString(), wcx);
wait(workload->initialized());
if (workload) {
wait(workload->initialized());
}
auto unconsumedOptions = checkAllOptionsConsumed(workload ? workload->options : VectorRef<KeyValueRef>());
if (!workload || unconsumedOptions.size()) {

View File

@ -851,10 +851,6 @@ struct PacketWriter {
}
ProtocolVersion protocolVersion() const { return m_protocolVersion; }
void setProtocolVersion(ProtocolVersion pv) { m_protocolVersion = pv; }
private:
void serializeBytesAcrossBoundary(const void* data, int bytes);
void nextBuffer(size_t size = 0 /* downstream it will default to at least 4k minus some padding */);
uint8_t* writeBytes(size_t size) {
if (size > buffer->bytes_unwritten()) {
nextBuffer(size);
@ -865,6 +861,9 @@ private:
return result;
}
private:
void serializeBytesAcrossBoundary(const void* data, int bytes);
void nextBuffer(size_t size = 0 /* downstream it will default to at least 4k minus some padding */);
template <class, class>
friend class MakeSerializeSource;