diff --git a/flow/ThreadHelper.actor.h b/flow/ThreadHelper.actor.h index 6627b9e25e..5aa9baebc7 100644 --- a/flow/ThreadHelper.actor.h +++ b/flow/ThreadHelper.actor.h @@ -22,6 +22,8 @@ // When actually compiled (NO_INTELLISENSE), include the generated // version of this file. In intellisense use the source version. +#include "flow/Error.h" +#include #if defined(NO_INTELLISENSE) && !defined(FLOW_THREADHELPER_ACTOR_G_H) #define FLOW_THREADHELPER_ACTOR_G_H #include "flow/ThreadHelper.actor.g.h" @@ -69,20 +71,78 @@ void onMainThreadVoid(F f, Error* err = nullptr, TaskPriority taskID = TaskPrior g_network->onMainThread(std::move(signal), taskID); } +class ThreadMultiCallback; + struct ThreadCallback { virtual bool canFire(int notMadeActive) const = 0; virtual void fire(const Void& unused, int& userParam) = 0; virtual void error(const Error&, int& userParam) = 0; virtual ThreadCallback* addCallback(ThreadCallback* cb); - virtual bool contains(ThreadCallback* cb) const { return false; } - virtual void clearCallback(ThreadCallback* cb) { // If this is the only registered callback this will be called with (possibly) arbitrary pointers } virtual void destroy() { UNSTOPPABLE_ASSERT(false); } virtual bool isMultiCallback() const { return false; } + + // MultiCallbackHolder is a helper object for ThreadMultiCallback which allows it to store its index + // within the callback vector inside the ThreadCallback rather than having a map of pointers or + // some other scheme to store the indices by callback. + // MultiCallbackHolder objects can form a doubly linked list. + struct MultiCallbackHolder : public FastAllocated { + MultiCallbackHolder(ThreadMultiCallback* holder = nullptr, + MultiCallbackHolder* prev = nullptr, + MultiCallbackHolder* next = nullptr) + : holder(holder), previous(prev), next(next) {} + + ThreadMultiCallback* holder; + int index; + MultiCallbackHolder* previous; + MultiCallbackHolder* next; + }; + + // firstHolder is both the inline first record of a MultiCallbackHolder and the head of the + // doubly linked list of MultiCallbackHolder entries. + MultiCallbackHolder firstHolder; + + // Return a MultiCallbackHolder for the given holder, using the firstHolder if free or allocating + // a new one. No check for an existing record for holder is done. + MultiCallbackHolder* addHolder(ThreadMultiCallback* holder) { + if (firstHolder.holder == nullptr) { + firstHolder.holder = holder; + return &firstHolder; + } + firstHolder.next = new MultiCallbackHolder(holder, &firstHolder, firstHolder.next); + return firstHolder.next; + } + + // Get the MultiCallbackHolder for holder if it exists, or nullptr. + MultiCallbackHolder* getHolder(ThreadMultiCallback* holder) { + MultiCallbackHolder* h = &firstHolder; + while (h != nullptr && h->holder != holder) { + h = h->next; + } + return h; + } + + // Destroy the given MultiCallbackHolder, freeing it if it is not firstHolder. + void destroyHolder(MultiCallbackHolder* h) { + UNSTOPPABLE_ASSERT(h != nullptr); + + // If h is the firstHolder just clear its holder pointer to indicate unusedness + if (h == &firstHolder) { + h->holder = nullptr; + } else { + // Otherwise unlink h from the doubly linked list and free it + // h->previous is definitely valid + h->previous->next = h->next; + if (h->next) { + h->next->previous = h->previous; + } + delete h; + } + } }; class ThreadMultiCallback final : public ThreadCallback, public FastAllocated { @@ -90,29 +150,31 @@ public: ThreadMultiCallback() {} ThreadCallback* addCallback(ThreadCallback* callback) override { - UNSTOPPABLE_ASSERT(callbackMap.count(callback) == - 0); // May be triggered by a waitForAll on a vector with the same future in it more than once - callbackMap[callback] = callbacks.size(); + UNSTOPPABLE_ASSERT( + callback->getHolder(this) == + nullptr); // May be triggered by a waitForAll on a vector with the same future in it more than once + callback->addHolder(this)->index = callbacks.size(); callbacks.push_back(callback); return (ThreadCallback*)this; } - bool contains(ThreadCallback* cb) const override { return callbackMap.count(cb) != 0; } - void clearCallback(ThreadCallback* callback) override { - auto it = callbackMap.find(callback); - if (it == callbackMap.end()) + MultiCallbackHolder* h = callback->getHolder(this); + if (h == nullptr) { return; + } - UNSTOPPABLE_ASSERT(it->second < callbacks.size() && it->second >= 0); + UNSTOPPABLE_ASSERT(h->index < callbacks.size() && h->index >= 0); - if (it->second != callbacks.size() - 1) { - callbacks[it->second] = callbacks.back(); - callbackMap[callbacks[it->second]] = it->second; + // Swap callback with last callback if it isn't the last + if (h->index != callbacks.size() - 1) { + callbacks[h->index] = callbacks.back(); + // Update the index of the Holder entry for the moved callback + callbacks[h->index]->getHolder(this)->index = h->index; } callbacks.pop_back(); - callbackMap.erase(it); + callback->destroyHolder(h); } bool canFire(int notMadeActive) const override { return true; } @@ -126,7 +188,7 @@ public: while (callbacks.size()) { auto cb = callbacks.back(); callbacks.pop_back(); - callbackMap.erase(cb); + cb->destroyHolder(cb->getHolder(this)); if (cb->canFire(0)) { int ld = 0; cb->fire(value, ld); @@ -143,7 +205,7 @@ public: while (callbacks.size()) { auto cb = callbacks.back(); callbacks.pop_back(); - callbackMap.erase(cb); + cb->destroyHolder(cb->getHolder(this)); if (cb->canFire(0)) { int ld = 0; cb->error(err, ld); @@ -160,7 +222,6 @@ public: private: std::vector callbacks; - std::unordered_map callbackMap; }; struct SetCallbackResult {