foundationdb/flow/IndexedSet.h

1417 lines
46 KiB
C++

/*
* IndexedSet.h
*
* This source file is part of the FoundationDB open source project
*
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLOW_INDEXEDSET_H
#define FLOW_INDEXEDSET_H
#pragma once
#include "flow/Arena.h"
#include "flow/Platform.h"
#include "flow/FastAlloc.h"
#include "flow/Trace.h"
#include "flow/Error.h"
#include <deque>
#include <type_traits>
#include <vector>
// IndexedSet<T, Metric> is similar to a std::set<T>, with the following additional features:
// - Each element in the set is associated with a value of type Metric
// - sumTo() and sumRange() can report the sum of the metric values associated with a
// contiguous range of elements in O(lg N) time
// - index() can be used to find an element having a given sumTo() in O(lg N) time
// - Search functions (find(), lower_bound(), etc) can accept a type comparable to T instead of T
// (e.g. StringRef when T is std::string or Standalone<StringRef>). This can save a lot of needless
// copying at query time for read-mostly sets with string keys.
// - there is no size() function; if the metric being used is a count, sumTo(end()) will do instead
// A number of STL compatibility features are missing and should be added as needed.
// T must define operator <, which must define a total order. Unlike std::set,
// a user-defined predicate is not currently supported as a template parameter.
// Metric is required to have operators + and - and <, and behavior is undefined if
// the sum of metrics for all elements of a set overflows the Metric type.
// Map<Key,Value> is similar to a std::map<Key,Value>, except that it inherits the search key type
// flexibility of IndexedSet<>, uses MapPair<Key,Value> by default instead of pair<Key,Value>
// (use iterator->key instead of iterator->first), and uses FastAllocator for nodes.
// Forward declarations: this header only needs these names, not their definitions.
// Future<Void> is the return type of the eraseAsync()/clearAsync() declarations below, and
// StringRef is the return type of the iterator's getKey()/getValue() accessors.
template <class T>
class Future;
class Void;
class StringRef;
// An ordered set of T stored as a balanced binary tree in which every node also carries the sum of
// the Metric values of its subtree (Node::total). See the overview comment near the top of this
// file for the feature list and the requirements on T and Metric.
template <class T, class Metric>
struct IndexedSet {
	using value_type = T;
	using key_type = T;

private: // Forward-declare IndexedSet::Node because Clang is much stricter about this ordering.
	// One tree node. A node owns both children: destroying a node destroys its whole subtree.
	struct Node : FastAllocated<Node> {
		// Here, and throughout all code that indirectly instantiates a Node, we rely on forwarding
		// references so that we don't need to maintain the set of 2^arity lvalue and rvalue reference
		// combinations, but still take advantage of move constructors when available (or required).
		template <class T_, class Metric_>
		Node(T_&& data, Metric_&& m, Node* parent = nullptr)
		  : data(std::forward<T_>(data)), balance(0), total(std::forward<Metric_>(m)), parent(parent) {
			child[0] = child[1] = nullptr;
		}
		Node(Node const&) = delete;
		Node& operator=(Node const&) = delete;
		~Node() {
			delete child[0];
			delete child[1];
		}

		T data;
		signed char balance; // right height - left height
		Metric total; // this + child[0] + child[1]
		Node* child[2]; // left, right
		Node* parent;
	};

	// Cursor over the elements of the set. A null node pointer is the end() position.
	template <bool isConst>
	struct IteratorImpl {
		typename std::conditional_t<isConst, const IndexedSet::Node, IndexedSet::Node>* node;

		// Converts a non-const iterator into a const one; the static_assert rejects the reverse
		// direction if it is ever instantiated. The constructor is named with the injected class
		// name rather than `IteratorImpl<isConst>`, because C++20 no longer accepts a template-id
		// as a constructor name.
		explicit IteratorImpl(const IteratorImpl<!isConst>& nonConstIter) : node(nonConstIter.node) {
			static_assert(isConst);
		}
		explicit IteratorImpl(decltype(node) n = nullptr) : node(n) {}

		typename std::conditional_t<isConst, const T, T>& operator*() const { return node->data; }
		typename std::conditional_t<isConst, const T, T>* operator->() const { return &node->data; }

		// Both movers dereference `node` immediately, so they must not be called on end().
		void operator++();
		void decrementNonEnd();

		bool operator==(const IteratorImpl<isConst>& r) const { return node == r.node; }
		bool operator!=(const IteratorImpl<isConst>& r) const { return node != r.node; }

		// following two methods are for memory storage engine(KeyValueStoreMemory class) use only
		// in order to have same interface as radixtree
		typename std::conditional_t<isConst, const StringRef, StringRef>& getKey(uint8_t* dummyContent) const {
			return node->data.key;
		}
		typename std::conditional_t<isConst, const StringRef, StringRef>& getValue() const { return node->data.value; }
	};

	// Search routines written once and instantiated for both the const and the non-const set, so
	// that each public accessor pair below can share a single implementation.
	template <bool isConst>
	struct Impl {
		using NodeT = std::conditional_t<isConst, const Node, Node>;
		using IteratorT = IteratorImpl<isConst>;
		using SetT = std::conditional_t<isConst, const IndexedSet<T, Metric>, IndexedSet<T, Metric>>;

		static IteratorT begin(SetT&);

		template <bool constIterator>
		static IteratorImpl<isConst || constIterator> previous(SetT&, IteratorImpl<constIterator>);

		template <class M>
		static IteratorT index(SetT&, const M&);

		template <class Key>
		static IteratorT find(SetT&, const Key&);

		template <class Key>
		static IteratorT upper_bound(SetT&, const Key&);

		template <class Key>
		static IteratorT lower_bound(SetT&, const Key&);

		template <class Key>
		static IteratorT lastLessOrEqual(SetT&, const Key&);

		static IteratorT lastItem(SetT&);
	};

	using ConstImpl = Impl<true>;
	using NonConstImpl = Impl<false>;

public:
	using iterator = IteratorImpl<false>;
	using const_iterator = IteratorImpl<true>;

	IndexedSet() : root(nullptr) {}
	~IndexedSet() { delete root; }

	// Copy operations unimplemented. SOMEDAY: Implement and make public.
	IndexedSet(const IndexedSet&) = delete;
	IndexedSet& operator=(const IndexedSet&) = delete;

	IndexedSet(IndexedSet&& r) noexcept : root(r.root) { r.root = nullptr; }
	IndexedSet& operator=(IndexedSet&& r) noexcept {
		// Guard against self-move, which would otherwise destroy the contents of this set.
		if (this != &r) {
			delete root;
			root = r.root;
			r.root = nullptr;
		}
		return *this;
	}

	const_iterator begin() const { return ConstImpl::begin(*this); }
	iterator begin() { return NonConstImpl::begin(*this); }
	const_iterator cbegin() const { return begin(); }

	const_iterator end() const { return const_iterator{}; }
	iterator end() { return iterator{}; }
	const_iterator cend() const { return end(); }

	// Returns the element before i; previous(end()) is the last element.
	const_iterator previous(const_iterator i) const { return ConstImpl::previous(*this, i); }
	const_iterator previous(iterator i) const { return ConstImpl::previous(*this, const_iterator{ i }); }
	iterator previous(iterator i) { return NonConstImpl::previous(*this, i); }

	const_iterator lastItem() const { return ConstImpl::lastItem(*this); }
	iterator lastItem() { return NonConstImpl::lastItem(*this); }

	bool empty() const { return !root; }
	void clear() {
		delete root;
		root = nullptr;
	}
	void swap(IndexedSet& r) noexcept { std::swap(root, r.root); }

	// Place data in the set with the given metric. If an item equal to data is already in the set
	// and replaceExisting == true, it will be overwritten (and its metric will be replaced).
	template <class T_, class Metric_>
	iterator insert(T_&& data, Metric_&& metric, bool replaceExisting = true);

	// Insert all items from data into the set, each with its paired metric. If an item equal to one
	// of them is already in the set and replaceExisting == true, it will be overwritten (and its
	// metric will be replaced). Returns the number of items inserted.
	int insert(const std::vector<std::pair<T, Metric>>& data, bool replaceExisting = true);

	// Increase the metric for the given item by the given amount. Inserts data into the set if it
	// doesn't exist. Returns the new sum.
	template <class T_, class Metric_>
	Metric addMetric(T_&& data, Metric_&& metric);

	// Remove the data item, if any, which is equal to key
	template <class Key>
	void erase(const Key& key) {
		erase(find(key));
	}

	// Erase the indicated item. No effect if item == end().
	// SOMEDAY: Return ++item
	void erase(iterator item);

	// Erase all data items x for which begin<=x<end
	template <class Key>
	void erase(const Key& begin, const Key& end) {
		erase(lower_bound(begin), lower_bound(end));
	}

	// Erase data items with a deferred (async) free process. The data structure has the items removed
	// synchronously with the invocation of this method so any subsequent call will see this new state.
	template <class Key>
	Future<Void> eraseAsync(const Key& begin, const Key& end);

	// Erase the items in the indicated range.
	void erase(iterator begin, iterator end);

	// Erase data items with a deferred (async) free process. The data structure has the items removed
	// synchronously with the invocation of this method so any subsequent call will see this new state.
	Future<Void> eraseAsync(iterator begin, iterator end);

	// Returns the number of items equal to key (either 0 or 1)
	template <class Key>
	int count(const Key& key) const {
		return find(key) != end();
	}

	// Returns x such that key==*x, or end()
	template <class Key>
	const_iterator find(const Key& key) const {
		return ConstImpl::find(*this, key);
	}
	template <class Key>
	iterator find(const Key& key) {
		return NonConstImpl::find(*this, key);
	}

	// Returns the smallest x such that *x>=key, or end()
	template <class Key>
	const_iterator lower_bound(const Key& key) const {
		return ConstImpl::lower_bound(*this, key);
	}
	template <class Key>
	iterator lower_bound(const Key& key) {
		return NonConstImpl::lower_bound(*this, key);
	}

	// Returns the smallest x such that *x>key, or end()
	template <class Key>
	const_iterator upper_bound(const Key& key) const {
		return ConstImpl::upper_bound(*this, key);
	}
	template <class Key>
	iterator upper_bound(const Key& key) {
		return NonConstImpl::upper_bound(*this, key);
	}

	// Returns the largest x such that *x<=key, or end()
	template <class Key>
	const_iterator lastLessOrEqual(const Key& key) const {
		return ConstImpl::lastLessOrEqual(*this, key);
	}
	template <class Key>
	iterator lastLessOrEqual(const Key& key) {
		return NonConstImpl::lastLessOrEqual(*this, key);
	}

	// Returns smallest x such that sumTo(x+1) > metric, or end()
	template <class M>
	const_iterator index(M const& metric) const {
		return ConstImpl::index(*this, metric);
	}
	template <class M>
	iterator index(M const& metric) {
		return NonConstImpl::index(*this, metric);
	}

	// Return the metric inserted with item x
	Metric getMetric(const_iterator x) const;
	Metric getMetric(iterator x) const { return getMetric(const_iterator{ x }); }

	// Return the sum of getMetric(x) for begin()<=x<to
	Metric sumTo(const_iterator to) const;
	Metric sumTo(iterator to) const { return sumTo(const_iterator{ to }); }

	// Return the sum of getMetric(x) for begin<=x<end
	Metric sumRange(const_iterator begin, const_iterator end) const { return sumTo(end) - sumTo(begin); }
	Metric sumRange(iterator begin, iterator end) const {
		return sumTo(const_iterator{ end }) - sumTo(const_iterator{ begin });
	}

	// Return the sum of getMetric(x) for all x s.t. begin <= *x && *x < end
	template <class Key>
	Metric sumRange(const Key& begin, const Key& end) const {
		return sumRange(lower_bound(begin), lower_bound(end));
	}

	// Return the amount of memory used by an entry in the IndexedSet
	constexpr static int getElementBytes() { return sizeof(Node); }

private:
	Node* root;

	Metric eraseHalf(Node* start, Node* end, int eraseDir, int& heightDelta, std::vector<Node*>& toFree);
	void erase(iterator begin, iterator end, std::vector<Node*>& toFree);

	// Makes whatever pointed at oldNode (its parent's child slot, or `root` when it has no parent)
	// point at newNode instead, and gives newNode (if non-null) oldNode's parent.
	void replacePointer(Node* oldNode, Node* newNode) {
		if (oldNode->parent)
			oldNode->parent->child[oldNode->parent->child[1] == oldNode] = newNode;
		else
			root = newNode;
		if (newNode)
			newNode->parent = oldNode->parent;
	}

	// Moves `node` to its in-order neighbour: the successor when direction == 1, the predecessor
	// when direction == 0. `node` becomes null when there is no such neighbour. `node` must not be
	// null on entry.
	template <int direction, bool isConst>
	static void moveIteratorImpl(std::conditional_t<isConst, const Node, Node>*& node) {
		if (node->child[0 ^ direction]) {
			// Step into the subtree on the `direction` side, then go as far back as possible.
			node = node->child[0 ^ direction];
			while (node->child[1 ^ direction])
				node = node->child[1 ^ direction];
		} else {
			// Climb while we are the `direction`-side child; the next parent is the neighbour.
			while (node->parent && node->parent->child[0 ^ direction] == node)
				node = node->parent;
			node = node->parent;
		}
	}

	// direction 0 = left, 1 = right
	template <int direction>
	static void moveIterator(Node const*& node) {
		moveIteratorImpl<direction, true>(node);
	}
	template <int direction>
	static void moveIterator(Node*& node) {
		moveIteratorImpl<direction, false>(node);
	}

public: // but testonly
	std::pair<int, int> testonly_assertBalanced(Node* n = nullptr, int d = 0, bool a = true);
};
// Stateless stand-in for a metric, for sets and maps that do not track any per-element quantity.
// It supplies the +, - and < operations that IndexedSet requires of a Metric; every value behaves
// the same, so sums carry no information and no value ever orders before another.
class NoMetric {
public:
	NoMetric() {}
	NoMetric(int) {} // lets callers write NoMetric(1) or initialise from an integer; the value is ignored

	NoMetric operator+(const NoMetric&) const { return {}; }
	NoMetric operator-(const NoMetric&) const { return {}; }
	bool operator<(const NoMetric&) const { return false; }
};
// Key/value element type used by Map. All ordering and equality operations look at `key` only;
// `value` never takes part in a comparison.
template <class Key, class Value>
class MapPair {
public:
	Key key;
	Value value;

	// Builds the pair by perfect-forwarding each argument into the corresponding member.
	template <class Key_, class Value_>
	MapPair(Key_&& key, Value_&& value) : key(std::forward<Key_>(key)), value(std::forward<Value_>(value)) {}

	MapPair(MapPair const& rhs) : key(rhs.key), value(rhs.value) {}
	MapPair(MapPair&& r) noexcept : key(std::move(r.key)), value(std::move(r.value)) {}

	// The assignment operators return *this (they used to return void) so that the type follows
	// the usual assignment contract and assignments can be chained.
	MapPair& operator=(MapPair const& rhs) {
		key = rhs.key;
		value = rhs.value;
		return *this;
	}
	MapPair& operator=(MapPair&& r) noexcept {
		key = std::move(r.key);
		value = std::move(r.value);
		return *this;
	}

	// Three-way comparison by key, delegated to the global ::compare() declared elsewhere.
	int compare(MapPair<Key, Value> const& r) const { return ::compare(key, r.key); }

	// Same, against any value that ::compare() can compare with a Key.
	template <class CompatibleWithKey>
	int compare(CompatibleWithKey const& r) const {
		return ::compare(key, r);
	}

	bool operator<(MapPair<Key, Value> const& r) const { return key < r.key; }
	bool operator>(MapPair<Key, Value> const& r) const { return key > r.key; }
	bool operator<=(MapPair<Key, Value> const& r) const { return key <= r.key; }
	bool operator>=(MapPair<Key, Value> const& r) const { return key >= r.key; }
	bool operator==(MapPair<Key, Value> const& r) const { return key == r.key; }
	bool operator!=(MapPair<Key, Value> const& r) const { return key != r.key; }
};
template <class Key, class Value, class CompatibleWithKey>
inline int compare(CompatibleWithKey const& l, MapPair<Key, Value> const& r) {
return compare(l, r.key);
}
// Convenience factory: builds a MapPair whose member types are the decayed argument types, with
// each argument perfect-forwarded into its member.
template <class Key, class Value>
inline MapPair<std::decay_t<Key>, std::decay_t<Value>> mapPair(Key&& key, Value&& value) {
	using PairT = MapPair<std::decay_t<Key>, std::decay_t<Value>>;
	return PairT(std::forward<Key>(key), std::forward<Value>(value));
}
template <class Key, class Value, class CompatibleWithKey>
bool operator<(MapPair<Key, Value> const& l, CompatibleWithKey const& r) {
return l.key < r;
}
template <class Key, class Value, class CompatibleWithKey>
bool operator<(CompatibleWithKey const& l, MapPair<Key, Value> const& r) {
return l < r.key;
}
template <class Key, class Value, class Pair = MapPair<Key, Value>, class Metric = NoMetric>
class Map {
public:
typedef typename IndexedSet<Pair, Metric>::iterator iterator;
typedef typename IndexedSet<Pair, Metric>::const_iterator const_iterator;
Map() {}
const_iterator begin() const { return set.begin(); }
iterator begin() { return set.begin(); }
const_iterator cbegin() const { return begin(); }
const_iterator end() const { return set.end(); }
iterator end() { return set.end(); }
const_iterator cend() const { return end(); }
const_iterator lastItem() const { return set.lastItem(); }
iterator lastItem() { return set.lastItem(); }
const_iterator previous(const_iterator i) const { return set.previous(i); }
iterator previous(iterator i) { return set.previous(i); }
bool empty() const { return set.empty(); }
Value& operator[](const Key& key) {
iterator i = set.insert(Pair(key, Value()), Metric(1), false);
return i->value;
}
Value& get(const Key& key, Metric m = Metric(1)) {
iterator i = set.insert(Pair(key, Value()), m, false);
return i->value;
}
iterator insert(const Pair& p, bool replaceExisting = true, Metric m = Metric(1)) {
return set.insert(p, m, replaceExisting);
}
iterator insert(Pair&& p, bool replaceExisting = true, Metric m = Metric(1)) {
return set.insert(std::move(p), m, replaceExisting);
}
int insert(const std::vector<std::pair<MapPair<Key, Value>, Metric>>& pairs, bool replaceExisting = true) {
return set.insert(pairs, replaceExisting);
}
template <class KeyCompatible>
void erase(KeyCompatible const& k) {
set.erase(k);
}
void erase(iterator b, iterator e) { set.erase(b, e); }
void erase(iterator x) { set.erase(x); }
void clear() { set.clear(); }
Metric size() const {
static_assert(!std::is_same<Metric, NoMetric>::value, "size() on Map with NoMetric is not valid!");
return sumTo(end());
}
template <class KeyCompatible>
const_iterator find(KeyCompatible const& k) const {
return set.find(k);
}
template <class KeyCompatible>
iterator find(KeyCompatible const& k) {
return set.find(k);
}
template <class KeyCompatible>
const_iterator lower_bound(KeyCompatible const& k) const {
return set.lower_bound(k);
}
template <class KeyCompatible>
iterator lower_bound(KeyCompatible const& k) {
return set.lower_bound(k);
}
template <class KeyCompatible>
const_iterator upper_bound(KeyCompatible const& k) const {
return set.upper_bound(k);
}
template <class KeyCompatible>
iterator upper_bound(KeyCompatible const& k) {
return set.upper_bound(k);
}
template <class KeyCompatible>
const_iterator lastLessOrEqual(KeyCompatible const& k) const {
return set.lastLessOrEqual(k);
}
template <class KeyCompatible>
iterator lastLessOrEqual(KeyCompatible const& k) {
return set.lastLessOrEqual(k);
}
template <class M>
const_iterator index(M const& metric) const {
return set.index(metric);
}
template <class M>
iterator index(M const& metric) {
return set.index(metric);
}
Metric getMetric(const_iterator x) const { return set.getMetric(x); }
Metric getMetric(iterator x) const { return getMetric(const_iterator{ x }); }
Metric sumTo(const_iterator to) const { return set.sumTo(to); }
Metric sumTo(iterator to) const { return sumTo(const_iterator{ to }); }
Metric sumRange(const_iterator begin, const_iterator end) const { return set.sumRange(begin, end); }
Metric sumRange(iterator begin, iterator end) const { return set.sumRange(begin, end); }
template <class KeyCompatible>
Metric sumRange(const KeyCompatible& begin, const KeyCompatible& end) const {
return set.sumRange(begin, end);
}
static int getElementBytes() { return IndexedSet<Pair, Metric>::getElementBytes(); }
Map(Map&& r) noexcept : set(std::move(r.set)) {}
void operator=(Map&& r) noexcept { set = std::move(r.set); }
Future<Void> clearAsync();
private:
Map(Map<Key, Value, Pair> const&); // unimplemented
void operator=(Map<Key, Value, Pair> const&); // unimplemented
IndexedSet<Pair, Metric> set;
};
/////////////////////// implementation //////////////////////////
// Advances the iterator to the next element in order (direction 1 = right). The last element
// advances to end(). Must not be called on end(): the mover dereferences the node.
template <class T, class Metric>
template <bool isConst>
void IndexedSet<T, Metric>::IteratorImpl<isConst>::operator++() {
	IndexedSet<T, Metric>::template moveIterator<1>(node);
}
// Steps the iterator back to the previous element in order (direction 0 = left). As the name
// says, the iterator must not be end(): the mover dereferences the node.
template <class T, class Metric>
template <bool isConst>
void IndexedSet<T, Metric>::IteratorImpl<isConst>::decrementNonEnd() {
	IndexedSet<T, Metric>::template moveIterator<0>(node);
}
// Rotates the subtree whose root is referenced by oldRootRef in direction d (0 = left, 1 = right).
// The child on the opposite side (1 - d) is promoted to subtree root and the old root becomes its
// child on side d. Subtree metric totals and parent pointers are kept consistent; `balance` fields
// are NOT touched and must be fixed up by the caller. oldRootRef is updated to the new root.
template <class Node>
void ISRotate(Node*& oldRootRef, int d) {
	const int opposite = 1 - d;
	Node* const demoted = oldRootRef;
	Node* const promoted = demoted->child[opposite];
	Node* const transferred = promoted->child[d]; // subtree that changes owner; may be null

	// Metrics: the promoted node takes over the whole subtree's total. The demoted node keeps what
	// it had minus the promoted subtree, plus the subtree it is about to adopt.
	auto demotedTotal = demoted->total - promoted->total;
	if (transferred)
		demotedTotal = demotedTotal + transferred->total;
	promoted->total = demoted->total;
	demoted->total = demotedTotal;

	// Pointers
	demoted->child[opposite] = transferred;
	if (transferred)
		transferred->parent = demoted;
	promoted->child[d] = demoted;
	promoted->parent = demoted->parent;
	demoted->parent = promoted;
	oldRootRef = promoted;
}
// Sets the balance factors of the three nodes involved in a double rotation: root, its child n on
// side d, and n's child nn on the opposite side (1 - d). `bal` is the sign that corresponds to
// side d (+1 for d == 1, -1 for d == 0). nn always ends up with balance 0; root and n depend on
// nn's balance before the call.
template <class Node>
void ISAdjustBalance(Node* root, int d, int bal) {
	Node* const mid = root->child[d];
	Node* const low = mid->child[1 - d];

	int rootBalance = 0;
	int midBalance = 0;
	if (low->balance != 0) {
		if (low->balance == bal)
			rootBalance = -bal;
		else
			midBalance = bal;
	}

	root->balance = rootBalance;
	mid->balance = midBalance;
	low->balance = 0;
}
template <class Node>
int ISRebalance(Node*& root) {
// Restores the AVL invariant at `root` using one or two rotations, recursing where the imbalance
// is larger than a single step can repair. `root` is updated to the new subtree root.
//
// Pre: root is a tree having the BST, metric, and balance invariants but not (necessarily) the AVL invariant.
//      root->child[0] and root->child[1] are AVL.
// Post: root is an AVL tree with the same nodes.
// Returns: the change in height of root.
//
// Rebalance is O(1) if abs(root->balance) <= 2, and probably O(log N) otherwise. (The rare "still
// unbalanced" recursion is hard to analyze.)
//
// The documentation of this function will be referencing the following tree (where
// nodes A, C, E, and G represent subtrees of unspecified height). Thus for each node X,
// we know the value of balance(X), but not height(X).
//
// We will assume that balance(F) < 0 (so we will be rotating right).
// Trees that rotate to the left will perform analogous operations.
//
//            F
//          /   \
//         B     G
//        / \
//       A   D
//          / \
//         C   E
//
// Nothing to do for an empty tree or one that already satisfies the AVL invariant.
if (!root || (root->balance >= -1 && root->balance <= +1))
return 0;
int rebalanceDir = root->balance < 0; // 1 if rotating right, 0 if rotating left
auto* n = root->child[1 - rebalanceDir]; // Node B
int bal = rebalanceDir ? +1 : -1; // 1 if rotating right, -1 if rotating left
int rootBal = root->balance;
// Depending on the balance at B, we will be required to do one or two rotations.
// If balance(B) <= 0, then we do only one rotation (the second of the two).
//
// In a tree where balance(B) == +1, we are required to do both rotations.
// The result of the first rotation will be:
//
//            F
//          /   \
//         D     G
//        / \
//       B   E
//      / \
//     A   C
//
bool doubleRotation = n->balance == bal;
if (doubleRotation) {
int x = n->child[rebalanceDir]->balance; // balance of Node D
ISRotate(root->child[1 - rebalanceDir], 1 - rebalanceDir); // Rotate at Node B
// Change node pointed to by 'n' to prepare for the second rotation
// After this first rotation, Node D will be the left child of the root
n = root->child[1 - rebalanceDir];
// Compute the balance at the new root node D' of our rotation
// We know that height(A) == max(height(C), height(E)) because B had balance of +1
// If height(E) >= height(C), then height(E) == height(A) and balance(D') = -1
// Otherwise height(C) == height(E) + 1, and therefore balance(D') = -2
n->balance = ((x == -bal) ? -2 : -1) * bal;
// Compute the balance at the old root node B' of our rotation
// As stated above, height(A) == max(height(C), height(E))
// If height(C) >= height(E), then height(A) == height(C) and balance(B') = 0
// Otherwise height(A) == height(E) == height(C) + 1, and therefore balance(B') = -1
n->child[1 - rebalanceDir]->balance = ((x == bal) ? -1 : 0) * bal;
}
// At this point, we perform the "second" rotation (which may actually be the first
// if the "first" rotation was not performed). The rotation that is performed is the
// same for both trees, but the result will be different depending on which tree we
// started with:
//
//     If unrotated:            If once rotated:
//
//          B                          D
//         / \                       /   \
//        A   F                     B     F
//           / \                   / \   / \
//          D   G                 A   C E   G
//         / \
//        C   E
//
// The documentation for this second rotation will be based on the unrotated original tree.
// Compute the balance at the new root node B'.
// balance(B') = 1 + max(height(D), height(G)) - height(A) = 1 + max(height(D) - height(A), height(G) - height(A))
// balance(B') = 1 + max(balance(B), height(G) - height(A))
//
// Now, we must find height(G) - height(A):
// If height(A) >= height(D) (i.e. balance(B) <= 0), then
// height(G) - height(A) = height(G) - height(B) + 1 = balance(F) + 1
//
// Otherwise, height(A) = height(D) - balance(B) = height(B) - 1 - balance(B), so
// height(G) - height(A) = height(G) - height(B) + 1 + balance(B) = balance(F) + 1 + balance(B)
//
// balance(B') = 1 + max(balance(B), balance(F) + 1 + max(balance(B), 0))
//
int nBal = n->balance * bal; // Direction corrected balance at Node B
int newRootBalance = bal * (1 + std::max(nBal, bal * root->balance + 1 + std::max(nBal, 0)));
// Compute the balance at the old root node F' (which becomes a child of the new root).
// balance(F') = height(G) - height(D)
//
// If height(D) >= height(A) (i.e. balance(B) >= 0), then height(D) = height(B) - 1, so
// balance(F') = height(G) - height(B) + 1 = balance(F) + 1
//
// Otherwise, height(D) = height(A) + balance(B) = height(B) - 1 + balance(B), so
// balance(F') = height(G) - height(B) + 1 - balance(B) = balance(F) + 1 - balance(B)
//
// balance(F') = balance(F) + 1 - min(balance(B), 0)
//
int newChildBalance = root->balance + bal * (1 - std::min(nBal, 0));
// ISRotate fixes up pointers and metric totals only, so the balances computed above are stored
// explicitly once the rotation has replaced `root` with the new subtree root.
ISRotate(root, rebalanceDir);
root->balance = newRootBalance;
root->child[rebalanceDir]->balance = newChildBalance;
// If the original tree is very unbalanced, the unbalance may have been "pushed" down into this subtree, so
// recursively rebalance that if necessary.
int childHeightChange = ISRebalance(root->child[rebalanceDir]);
root->balance += childHeightChange * bal;
// NOTE(review): newRootBalance is not read again after this statement.
newRootBalance *= bal;
// Compute the change in height at the root
// We will look at the single and double rotation cases separately
//
// If we did a single rotation, then height(A) >= height(D).
// As a result, height(A) >= height(G) + 1; otherwise the tree would be balanced and we wouldn't do any rotations.
//
// Then the original height of the tree is height(A) + 2,
// and the new height is max(height(D) + 2 + childHeightChange, height(A) + 1), so
//
// heightChange_single = max(height(D) + 2 + childHeightChange, height(A) + 1) - (height(A) + 2)
// heightChange_single = max(height(D) - height(A) + childHeightChange, -1)
// heightChange_single = max(balance(B) + childHeightChange, -1)
//
// If we did a double rotation, then height(D) = height(A) + 1 in the original tree.
// As a result, height(D) >= height(G) + 1; otherwise the tree would be balanced and we wouldn't do any rotations.
//
// Then the original height of the tree is height(D) + 2,
// and the new height is max(height(A), height(C), height(E), height(G)) + 2
//
// balance(B) == 1, so height(A) == max(height(C), height(E)).
// Also, height(A) = height(D) - 1 >= height(G)
// Therefore the new height is height(A) + 2
//
// heightChange_double = height(A) + 2 - (height(D) + 2)
// heightChange_double = height(A) - height(D)
// heightChange_double = -1
//
int heightChange = doubleRotation ? -1 : std::max(nBal + childHeightChange, -1);
// If the root is still unbalanced, then it should at least be more balanced than before. Recursively rebalance the
// root until we get a balanced tree.
if (root->balance < -1 || root->balance > +1) {
// The assertion guarantees progress: each recursion starts from a strictly smaller imbalance.
ASSERT(abs(root->balance) < abs(rootBal));
heightChange += ISRebalance(root);
}
return heightChange;
}
// Returns the root of the smallest subtree that contains both `first` and `last`, i.e. their
// lowest common ancestor (a node counts as its own ancestor). Returns null if the two nodes share
// no ancestor or if either argument is null.
template <class Node>
Node* ISCommonSubtreeRoot(Node* first, Node* last) {
	// Number of nodes on the path from n up to (and including) the topmost ancestor.
	const auto depthOf = [](Node* n) {
		int depth = 0;
		while (n) {
			++depth;
			n = n->parent;
		}
		return depth;
	};

	int firstDepth = depthOf(first);
	int lastDepth = depthOf(last);

	// Lift whichever node is deeper until both are at the same depth.
	Node* a = first;
	Node* b = last;
	while (firstDepth > lastDepth) {
		a = a->parent;
		--firstDepth;
	}
	while (lastDepth > firstDepth) {
		b = b->parent;
		--lastDepth;
	}

	// Climb in lockstep; the first node both paths share is the answer.
	while (a != b) {
		a = a->parent;
		b = b->parent;
	}
	return a;
}
// Returns an iterator to the smallest element: the node reached by following left children from
// the root. An empty set yields a null iterator, which is end().
template <class T, class Metric>
template <bool isConst>
typename IndexedSet<T, Metric>::template Impl<isConst>::IteratorT IndexedSet<T, Metric>::Impl<isConst>::begin(
    IndexedSet<T, Metric>::Impl<isConst>::SetT& self) {
	NodeT* leftmost = self.root;
	if (leftmost != nullptr) {
		while (NodeT* smaller = leftmost->child[0])
			leftmost = smaller;
	}
	return IteratorT{ leftmost };
}
// Returns the iterator that precedes `iter`. Stepping back from end() yields the last element;
// stepping back from the first element yields a null iterator.
template <class T, class Metric>
template <bool isConst>
template <bool constIterator>
typename IndexedSet<T, Metric>::template IteratorImpl<isConst || constIterator>
IndexedSet<T, Metric>::Impl<isConst>::previous(IndexedSet<T, Metric>::Impl<isConst>::SetT& self,
                                               IndexedSet<T, Metric>::IteratorImpl<constIterator> iter) {
	if (iter != self.end()) {
		// Ordinary case: move our by-value copy to its in-order predecessor (direction 0 = left).
		moveIterator<0>(iter.node);
		return iter;
	}
	return self.lastItem();
}
// Returns an iterator to the largest element: the node reached by following right children from
// the root. An empty set yields a null iterator, which is end().
template <class T, class Metric>
template <bool isConst>
typename IndexedSet<T, Metric>::template Impl<isConst>::IteratorT IndexedSet<T, Metric>::Impl<isConst>::lastItem(
    IndexedSet<T, Metric>::Impl<isConst>::SetT& self) {
	NodeT* rightmost = self.root;
	if (rightmost != nullptr) {
		while (NodeT* larger = rightmost->child[1])
			rightmost = larger;
	}
	return IteratorT{ rightmost };
}
// Adds `metric` to the metric stored for `data`, inserting `data` with that metric if it is not
// in the set yet. When the element already exists it is re-inserted with replaceExisting == true,
// so the stored element is also overwritten by `data`.
// Returns the element's new metric.
template <class T, class Metric>
template <class T_, class Metric_>
Metric IndexedSet<T, Metric>::addMetric(T_&& data, Metric_&& metric) {
	auto existing = find(data);
	if (existing == end()) {
		// Copy the result before forwarding: insert() may move from `metric`, and the previous
		// code returned `metric` after that move (use-after-move for rvalue metrics).
		Metric inserted = metric;
		insert(std::forward<T_>(data), std::forward<Metric_>(metric));
		return inserted;
	}

	Metric sum = metric + getMetric(existing);
	insert(std::forward<T_>(data), sum);
	return sum;
}
// Inserts `data` with the given metric and returns an iterator to the element.
// If an equal element already exists, it is overwritten (data and metric) only when
// replaceExisting is true; either way the returned iterator refers to the existing node.
// Subtree totals on the path to the root are updated and the tree is rebalanced after a new node
// is added.
template <class T, class Metric>
template <class T_, class Metric_>
typename IndexedSet<T, Metric>::iterator IndexedSet<T, Metric>::insert(T_&& data,
                                                                       Metric_&& metric,
                                                                       bool replaceExisting) {
	if (root == nullptr) {
		root = new Node(std::forward<T_>(data), std::forward<Metric_>(metric));
		return iterator{ root };
	}

	Node* t = root;
	int d = 0; // direction taken at t: 0 = left, 1 = right

	// traverse to find insert point
	while (true) {
		int cmp = compare(data, t->data);
		if (cmp == 0) {
			Node* returnNode = t;
			if (replaceExisting) {
				t->data = std::forward<T_>(data);

				// Recompute this node's total from the new metric and its children, then push
				// the difference up to every ancestor.
				Metric delta = t->total;
				t->total = std::forward<Metric_>(metric);
				if (t->child[0])
					t->total = t->total + t->child[0]->total;
				if (t->child[1])
					t->total = t->total + t->child[1]->total;
				delta = t->total - delta;
				for (Node* ancestor = t->parent; ancestor; ancestor = ancestor->parent)
					ancestor->total = ancestor->total + delta;
			}
			return iterator{ returnNode };
		}
		d = cmp > 0;
		Node* nextT = t->child[d];
		if (!nextT)
			break;
		t = nextT;
	}

	Node* newNode = new Node(std::forward<T_>(data), std::forward<Metric_>(metric), t);
	// `metric` may have been moved into the node, so it must not be read again. Keep a copy of
	// the stored value for the total updates below; newNode->total itself can change when a
	// rotation promotes the new node.
	Metric added = newNode->total;
	t->child[d] = newNode;

	// Walk up from the insertion point, adjusting balances and totals, until the height change
	// is absorbed (balance becomes 0) or a rotation restores the AVL invariant.
	while (true) {
		t->balance += d ? 1 : -1;
		t->total = t->total + added;
		if (t->balance == 0)
			break;
		if (t->balance != 1 && t->balance != -1) {
			// The slot that points at t: its parent's child pointer, or the set's root.
			Node** parent = t->parent ? &t->parent->child[t->parent->child[1] == t] : &root;
			// assert( *parent == t );

			Node* n = t->child[d];
			int bal = d ? 1 : -1;
			if (n->balance == bal) {
				// Single rotation.
				t->balance = n->balance = 0;
			} else {
				// Double rotation: fix balances first, then rotate the child.
				ISAdjustBalance(t, d, bal);
				ISRotate(t->child[d], d);
			}
			ISRotate(*parent, 1 - d);
			t = *parent;
			break;
		}
		if (!t->parent)
			break;
		d = t->parent->child[1] == t;
		t = t->parent;
	}

	// The remaining ancestors keep their balance but still gain the new element's metric.
	for (Node* ancestor = t->parent; ancestor; ancestor = ancestor->parent)
		ancestor->total = ancestor->total + added;

	return iterator{ newNode };
}
// Inserts every (item, metric) entry of dataVector. An entry equal to an element already in the
// set overwrites it (data and metric) only when replaceExisting is true.
// Returns the number of entries that were added or that replaced an existing element.
//
// Ascending input is handled with a fast path: after placing an element, the next one is attached
// directly after it (blockStart) as long as it is still smaller than the closest larger element
// found by the last full search (blockEnd), which avoids a search from the root.
template <class T, class Metric>
int IndexedSet<T, Metric>::insert(const std::vector<std::pair<T, Metric>>& dataVector, bool replaceExisting) {
	int num_inserted = 0;
	Node* blockStart = nullptr; // node holding the previously processed element
	Node* blockEnd = nullptr; // last node at which the most recent search turned left (or matched)

	for (const auto& entry : dataVector) {
		Metric metric = entry.second;
		T data = entry.first; // dataVector is const, so this is necessarily a copy

		int d = 1; // direction

		// A full search from the root is needed when there is no previous element, when this
		// element reaches or passes blockEnd, or when the input is not strictly ascending
		// (this element is not greater than the previous one). The last check is new: without
		// it the fast path attached such an element to the right of blockStart and broke the
		// ordering of the tree.
		const bool needSearch = blockStart == nullptr || compare(data, blockStart->data) <= 0 ||
		                        (blockEnd != nullptr && data >= blockEnd->data);
		if (needSearch) {
			blockEnd = nullptr;
			if (root == nullptr) {
				root = new Node(std::move(data), metric);
				num_inserted++;
				blockStart = root;
				continue;
			}

			Node* t = root;
			// traverse to find insert point
			bool foundNode = false;
			while (true) {
				int cmp = compare(data, t->data);
				d = cmp > 0;
				if (d == 0)
					blockEnd = t;
				if (cmp == 0) {
					Node* returnNode = t;
					if (replaceExisting) {
						num_inserted++;
						t->data = std::move(data);

						// Recompute this node's total and push the difference to all ancestors.
						Metric delta = t->total;
						t->total = metric;
						if (t->child[0])
							t->total = t->total + t->child[0]->total;
						if (t->child[1])
							t->total = t->total + t->child[1]->total;
						delta = t->total - delta;
						for (Node* ancestor = t->parent; ancestor; ancestor = ancestor->parent)
							ancestor->total = ancestor->total + delta;
					}
					blockStart = returnNode;
					foundNode = true;
					break;
				}
				Node* nextT = t->child[d];
				if (!nextT) {
					blockStart = t;
					break;
				}
				t = nextT;
			}
			if (foundNode)
				continue;
		}

		// Find the free slot. After a search, blockStart->child[d] is already empty. On the fast
		// path (d == 1) this walks to the in-order successor position of blockStart: one step
		// right, then left as far as possible.
		Node* t = blockStart;
		while (t->child[d]) {
			t = t->child[d];
			d = 0;
		}

		Node* newNode = new Node(std::move(data), metric, t);
		num_inserted++;
		t->child[d] = newNode;
		blockStart = newNode;

		// Walk up from the insertion point, adjusting balances and totals, until the height
		// change is absorbed (balance becomes 0) or a rotation restores the AVL invariant.
		while (true) {
			t->balance += d ? 1 : -1;
			t->total = t->total + metric;
			if (t->balance == 0)
				break;
			if (t->balance != 1 && t->balance != -1) {
				// The slot that points at t: its parent's child pointer, or the set's root.
				Node** parent = t->parent ? &t->parent->child[t->parent->child[1] == t] : &root;
				// assert( *parent == t );

				Node* n = t->child[d];
				int bal = d ? 1 : -1;
				if (n->balance == bal) {
					t->balance = n->balance = 0;
				} else {
					ISAdjustBalance(t, d, bal);
					ISRotate(t->child[d], d);
				}
				ISRotate(*parent, 1 - d);
				t = *parent;
				break;
			}
			if (!t->parent)
				break;
			d = t->parent->child[1] == t;
			t = t->parent;
		}

		// The remaining ancestors keep their balance but still gain the new element's metric.
		for (Node* ancestor = t->parent; ancestor; ancestor = ancestor->parent)
			ancestor->total = ancestor->total + metric;
	}
	return num_inserted;
}
template <class T, class Metric>
Metric IndexedSet<T, Metric>::eraseHalf(Node* start,
Node* end,
int eraseDir,
int& heightDelta,
std::vector<Node*>& toFree) {
// Removes all nodes between start (inclusive) and end (exclusive) from the set, where start is equal to end or
// one of its descendants.
//   - eraseDir 1 means erase the right half (the nodes at or after start) of the left subtree of end.
//   - eraseDir 0 means erase the left half (the nodes at or before start) of the right subtree of end.
//   - toFree is extended with the roots of completely removed subtrees.
//   - heightDelta will be set to the change in height of the end node.
// Returns the amount that should be subtracted from end node's metric value (and, by extension, the metric values
// of all ancestors of the end node).
//
// The end node may be left unbalanced (AVL invariant broken).
// The end node may be left with the incorrect metric total (the correct value is end->total = end->total -
// metricDelta).
// Scare quotes in comments mean the values when eraseDir==1 (when eraseDir==0, "left" means right etc).
// metricDelta measures how much should be subtracted from the current node's metrics
Metric metricDelta = 0;
heightDelta = 0;
// fromDir is the child index we arrived from. It starts as the opposite of eraseDir so that the first iteration
// takes the "erase this node" branch for start itself.
int fromDir = 1 - eraseDir;
// Begin removing nodes at start continuing up until we get to end
while (start != end) {
// Apply the metrics removed so far (below this node) to this node's subtree total.
start->total = start->total - metricDelta;
IndexedSet<T, Metric>::Node* parent = start->parent;
// Obtain the child pointer to start, which rebalance will update with the new root of the subtree currently
// rooted at start
IndexedSet<T, Metric>::Node*& node = parent->child[parent->child[1] == start];
// Remember which side of parent we are on before `node` is possibly reassigned below.
int nextDir = parent->child[1] == start;
if (fromDir == eraseDir) {
// The "right" subtree has been half-erased, and the "left" subtree doesn't need to be (nor does node).
// But this node might be unbalanced by the shrinking "right" subtree. Rebalance and continue up.
heightDelta += ISRebalance(node);
} else {
// The "left" subtree has been half-erased. `start' and its "right" subtree will be completely erased,
// leaving only the "left" subtree in its place (which is already AVL balanced).
// Removing start costs one level, plus one more if start was leaning towards the erased side.
heightDelta += -1 - std::max<int>(0, node->balance * (eraseDir ? +1 : -1));
metricDelta = metricDelta + start->total;
// If there is a surviving subtree of start, then connect it to start->parent
IndexedSet<T, Metric>::Node* n = node->child[fromDir];
node = n; // This updates the appropriate child pointer of start->parent
if (n) {
// The surviving subtree's metrics stay in the set, so they are not part of what was removed.
metricDelta = metricDelta - n->total;
n->parent = start->parent;
}
// Detach the surviving subtree from start so that only start and its erased side go to toFree.
start->child[fromDir] = nullptr;
toFree.push_back(start);
}
// dir is the sign that parent's balance uses for the side start is on.
int dir = (nextDir ? +1 : -1);
int oldBalance = parent->balance;
// The change in height from removing nodes should never increase our height
ASSERT(heightDelta <= 0);
parent->balance += heightDelta * dir;
// Compute the change in height of start's parent based on its change in balance.
// Because we can only be (possibly) shrinking one subtree of parent:
// If we were originally heavier on the shrunken side (oldBalance * dir > 0), then the change in height is at
// most abs(oldBalance) == oldBalance * dir. If we were lighter on the shrunken side, then height cannot
// change.
int maxHeightChange = std::max(oldBalance * dir, 0);
int balanceChange = (oldBalance - parent->balance) * dir;
heightDelta = -std::min(maxHeightChange, balanceChange);
// Move one level up; the next iteration sees this step as having come from nextDir.
start = parent;
fromDir = nextDir;
}
return metricDelta;
}
template <class T, class Metric>
void IndexedSet<T, Metric>::erase(typename IndexedSet<T, Metric>::iterator begin,
typename IndexedSet<T, Metric>::iterator end,
std::vector<Node*>& toFree) {
// Removes all nodes in the half-open iterator range [begin, end), i.e. the nodes between `first` (= begin.node)
// and `last` (= previous(end).node), inclusive.
// toFree is extended with the roots of completely removed subtrees.
// The common subtree root of first and last is removed at the very end through the single-node erase overload
// and is not added to toFree.
// Asserted precondition: when end.node is non-null, begin.node must be non-null and *begin must not compare
// greater than *end.
ASSERT(!end.node || (begin.node && (::compare(*begin, *end) <= 0)));
// Empty range: nothing to erase.
if (begin == end)
return;
IndexedSet<T, Metric>::Node* first = begin.node;
IndexedSet<T, Metric>::Node* last = previous(end).node;
IndexedSet<T, Metric>::Node* subRoot = ISCommonSubtreeRoot(first, last);
Metric metricDelta = 0;
int leftHeightDelta = 0;
int rightHeightDelta = 0;
// Erase all matching nodes that descend from subRoot, by first erasing descendants of subRoot->child[0] and then
// erasing the descendants of subRoot->child[1]. subRoot is not removed from the tree at this time.
metricDelta = metricDelta + eraseHalf(first, subRoot, 1, leftHeightDelta, toFree);
metricDelta = metricDelta + eraseHalf(last, subRoot, 0, rightHeightDelta, toFree);
// Change in the height of subRoot due to past activity, before subRoot is rebalanced. subRoot->balance already
// reflects changes in height to its children.
int heightDelta = leftHeightDelta + rightHeightDelta;
// Rebalance and update metrics for all nodes from subRoot up to the root
for (auto p = subRoot; p != nullptr; p = p->parent) {
// Every node on this path contained the erased nodes in its subtree total.
p->total = p->total - metricDelta;
// pc is the pointer that refers to p (its parent's child slot, or root), so ISRebalance can repoint it.
auto& pc = p->parent ? p->parent->child[p->parent->child[1] == p] : root;
heightDelta += ISRebalance(pc);
// Continue from whatever node pc refers to after rebalancing (it may no longer be the old p).
p = pc;
// Update the balance and compute heightDelta for p->parent
if (p->parent) {
int oldb = p->parent->balance;
int dir = (p->parent->child[1] == p ? +1 : -1);
p->parent->balance += heightDelta * dir;
// The parent's height changes only by the amount its lean towards this side changed.
heightDelta = (std::max(p->parent->balance * dir, 0) - std::max(oldb * dir, 0));
}
}
// Erase the subRoot using the single node erase implementation
erase(IndexedSet<T, Metric>::iterator(subRoot));
}
// Removes the single element referenced by toErase: subtracts its metric from the totals of its ancestors,
// unlinks and deletes its node, then walks up the tree restoring the AVL balance. Does nothing if toErase.node
// is null.
template <class T, class Metric>
void IndexedSet<T, Metric>::erase(iterator toErase) {
// rebalanceNode is the lowest node whose subtree lost a node; rebalanceDir is the child index of the side that
// shrank. rebalanceDir is only assigned (and only read) when rebalanceNode is non-null.
Node* rebalanceNode;
int rebalanceDir;
{
// Find the node to erase
Node* t = toErase.node;
if (!t)
return;
if (!t->child[0] || !t->child[1]) {
// t has at most one child. Its own metric is its subtree total minus its children's totals.
Metric tMetric = t->total;
if (t->child[0])
tMetric = tMetric - t->child[0]->total;
if (t->child[1])
tMetric = tMetric - t->child[1]->total;
// Every ancestor's total included t's own metric.
for (Node* p = t->parent; p; p = p->parent)
p->total = p->total - tMetric;
rebalanceNode = t->parent;
if (rebalanceNode)
rebalanceDir = rebalanceNode->child[1] == t;
int d = !t->child[0]; // Only one child, on this side (or no children!)
// NOTE(review): replacePointer presumably makes whatever pointed at t point at the given node instead
// (and fixes that node's parent link) -- confirm against its definition.
replacePointer(t, t->child[d]);
// Clear the link before deleting t, presumably so that destroying t does not touch the child that
// stays in the tree -- confirm against Node's destructor.
t->child[d] = 0;
delete t;
} else { // Remove node with two children
// The in-order predecessor is the rightmost node of t's left subtree; it will take t's place.
Node* predecessor = t->child[0];
while (predecessor->child[1])
predecessor = predecessor->child[1];
rebalanceNode = predecessor->parent;
// If predecessor is t's direct left child, the shrinking starts at predecessor itself once it has
// replaced t (the comparison below then yields rebalanceDir == 0, the left side).
if (rebalanceNode == t)
rebalanceNode = predecessor;
if (rebalanceNode)
rebalanceDir = rebalanceNode->child[1] == predecessor;
Metric tMetric = t->total - t->child[0]->total - t->child[1]->total;
// Reduce predecessor->total to predecessor's own metric (it has no right child).
if (predecessor->child[0])
predecessor->total = predecessor->total - predecessor->child[0]->total;
// Nodes strictly between predecessor and t lose predecessor from their subtree.
for (Node* p = predecessor->parent; p != t; p = p->parent)
p->total = p->total - predecessor->total;
// Nodes above t lose t's own metric.
for (Node* p = t->parent; p; p = p->parent)
p->total = p->total - tMetric;
// Replace t with predecessor
replacePointer(predecessor, predecessor->child[0]);
replacePointer(t, predecessor);
predecessor->balance = t->balance;
// Hand t's children over to predecessor, adding their totals on top of predecessor's own metric.
for (int i = 0; i < 2; i++) {
Node* c = predecessor->child[i] = t->child[i];
if (c) {
c->parent = predecessor;
predecessor->total = predecessor->total + c->total;
t->child[i] = 0;
}
}
delete t;
}
}
// The erased node was the root with at most one child: no ancestor needs rebalancing.
if (!rebalanceNode)
return;
// Walk up from the lowest affected node, adjusting balances and rotating where the AVL invariant is violated.
while (true) {
// The rebalanceDir side shrank, so the balance shifts towards the other side.
rebalanceNode->balance += rebalanceDir ? -1 : +1;
if (rebalanceNode->balance < -1 || rebalanceNode->balance > +1) {
// Rotations go through the pointer that refers to rebalanceNode so that it can be repointed.
Node** parent = rebalanceNode->parent
? &rebalanceNode->parent->child[rebalanceNode->parent->child[1] == rebalanceNode]
: &root;
// n is the child on the heavier side (the side that did not shrink).
Node* n = rebalanceNode->child[1 - rebalanceDir];
int bal = rebalanceDir ? +1 : -1;
if (n->balance == -bal) {
// n leans away from the shrunken side: a single rotation leaves both nodes balanced.
rebalanceNode->balance = n->balance = 0;
ISRotate(*parent, rebalanceDir);
} else if (n->balance == bal) {
// n leans towards the shrunken side: adjust balances, rotate n's subtree, then rotate here.
ISAdjustBalance(rebalanceNode, 1 - rebalanceDir, -bal);
ISRotate(rebalanceNode->child[1 - rebalanceDir], 1 - rebalanceDir);
ISRotate(*parent, rebalanceDir);
} else { // n->balance == 0
// A single rotation leaves both nodes leaning by one; the loop stops here without visiting ancestors.
rebalanceNode->balance = -bal;
n->balance = bal;
ISRotate(*parent, rebalanceDir);
break;
}
// Continue from the node that now occupies this position.
rebalanceNode = *parent;
} else if (rebalanceNode->balance) // +/- 1, we are done
break;
if (!rebalanceNode->parent)
break;
rebalanceDir = rebalanceNode->parent->child[1] == rebalanceNode;
rebalanceNode = rebalanceNode->parent;
}
}
// Returns an iterator x such that key==*x, or end() if the set holds no such element.
// Key may be any type that compare() accepts against the stored T.
template <class T, class Metric>
template <bool isConst>
template <class Key>
typename IndexedSet<T, Metric>::template Impl<isConst>::IteratorT IndexedSet<T, Metric>::Impl<isConst>::find(
    IndexedSet<T, Metric>::Impl<isConst>::SetT& self,
    const Key& key) {
    // Plain binary-search descent from the root.
    for (NodeT* cur = self.root; cur != nullptr;) {
        const int order = compare(key, cur->data);
        if (order == 0) {
            return IteratorT{ cur };
        }
        // A positive result means key sorts after cur: continue in child[1], otherwise in child[0].
        cur = cur->child[order > 0];
    }
    return self.end();
}
// Returns an iterator to the smallest x such that *x>=key, or end() if every element is less than key.
template <class T, class Metric>
template <bool isConst>
template <class Key>
typename IndexedSet<T, Metric>::template Impl<isConst>::IteratorT IndexedSet<T, Metric>::Impl<isConst>::lower_bound(
    IndexedSet<T, Metric>::Impl<isConst>::SetT& self,
    const Key& key) {
    // Descend until the child we would step into is missing; `cur` is then the last node visited and
    // `wentRight` records whether that node is less than key.
    NodeT* cur = nullptr;
    bool wentRight = false;
    for (NodeT* next = self.root; next != nullptr; next = cur->child[wentRight]) {
        cur = next;
        wentRight = cur->data < key;
    }
    if (cur == nullptr) {
        // Empty set.
        return self.end();
    }
    if (wentRight) {
        // The last node is still below key, so the answer is its in-order successor.
        moveIterator<1>(cur);
    }
    return IteratorT{ cur };
}
// Returns an iterator to the smallest x such that *x>key, or end() if no element is greater than key.
template <class T, class Metric>
template <bool isConst>
template <class Key>
typename IndexedSet<T, Metric>::template Impl<isConst>::IteratorT IndexedSet<T, Metric>::Impl<isConst>::upper_bound(
    IndexedSet<T, Metric>::Impl<isConst>::SetT& self,
    const Key& key) {
    NodeT* cur = self.root;
    if (cur == nullptr) {
        return self.end();
    }
    // Descend until the child we would step into is missing. `steppedRight` is true whenever the visited
    // node is not greater than key.
    bool steppedRight = false;
    NodeT* next = cur;
    do {
        cur = next;
        steppedRight = !(key < cur->data);
        next = cur->child[steppedRight];
    } while (next != nullptr);
    if (steppedRight) {
        // The last node is not greater than key, so the answer is its in-order successor.
        moveIterator<1>(cur);
    }
    return IteratorT{ cur };
}
// Returns the largest x such that *x<=key, or end() if every element is greater than key
template <class T, class Metric>
template <bool isConst>
template <class Key>
typename IndexedSet<T, Metric>::template Impl<isConst>::IteratorT IndexedSet<T, Metric>::Impl<isConst>::lastLessOrEqual(
    IndexedSet<T, Metric>::Impl<isConst>::SetT& self,
    const Key& key) {
    // The element just before the first one greater than key is the last one that is <= key.
    auto firstGreater = self.upper_bound(key);
    if (firstGreater == self.begin()) {
        // Nothing precedes the first greater element, so no element is <= key.
        return self.end();
    }
    return self.previous(firstGreater);
}
// Returns the first x such that metric < sum(begin(), x+1), or end() if metric is not below the sum over the
// whole set.
template <class T, class Metric>
template <bool isConst>
template <class M>
typename IndexedSet<T, Metric>::template Impl<isConst>::IteratorT IndexedSet<T, Metric>::Impl<isConst>::index(
    IndexedSet<T, Metric>::Impl<isConst>::SetT& self,
    const M& metric) {
    // `remaining` is the target expressed relative to the start of the subtree rooted at `cur`.
    M remaining = metric;
    NodeT* cur = self.root;
    while (cur != nullptr) {
        NodeT* const left = cur->child[0];
        if (left && remaining < left->total) {
            // The target falls inside the left subtree.
            cur = left;
            continue;
        }
        // Subtract the left subtree and cur's own metric, written as (subtree total - right subtree total).
        NodeT* const right = cur->child[1];
        remaining = remaining - cur->total;
        if (right) {
            remaining = remaining + right->total;
        }
        if (remaining < M()) {
            // The target falls inside cur's own metric.
            return IteratorT{ cur };
        }
        cur = right;
    }
    return self.end();
}
// Returns the metric of the single element x refers to: the node's total minus the totals of its children.
// x must refer to an element; x.node is dereferenced without a null check.
template <class T, class Metric>
Metric IndexedSet<T, Metric>::getMetric(typename IndexedSet<T, Metric>::const_iterator x) const {
    const auto* node = x.node;
    Metric own = node->total;
    if (node->child[0]) {
        own = own - node->child[0]->total;
    }
    if (node->child[1]) {
        own = own - node->child[1]->total;
    }
    return own;
}
// Returns the sum of the metrics of all elements that precede `end`. When end.node is null the result is the
// total of the whole set (Metric() for an empty set).
template <class T, class Metric>
Metric IndexedSet<T, Metric>::sumTo(typename IndexedSet<T, Metric>::const_iterator end) const {
    if (end.node == nullptr) {
        if (root) {
            return root->total;
        }
        return Metric();
    }
    const Node* cur = end.node;
    // Everything in the left subtree of `end` precedes it.
    Metric sum = Metric();
    if (cur->child[0]) {
        sum = cur->child[0]->total;
    }
    // Each time we climb out of a right child, the parent and the parent's left subtree also precede `end`;
    // together they are the parent's total minus the total of the subtree we came from.
    while (const Node* up = cur->parent) {
        if (up->child[1] == cur) {
            sum = sum - cur->total;
            sum = sum + up->total;
        }
        cur = up;
    }
    return sum;
}
#include "flow/flow.h"
#include "flow/IndexedSet.actor.h"
// Erases the elements in [begin, end): unlinks them from the tree, then passes the detached subtree roots to
// ISFreeNodes.
template <class T, class Metric>
void IndexedSet<T, Metric>::erase(typename IndexedSet<T, Metric>::iterator begin,
                                  typename IndexedSet<T, Metric>::iterator end) {
    std::vector<Node*> detached;
    erase(begin, end, detached);
    // NOTE(review): the `true` flag presumably asks ISFreeNodes to free synchronously (eraseAsync passes
    // false) -- confirm in IndexedSet.actor.h.
    ISFreeNodes(detached, true);
}
// Key-based overload: erases the elements from lower_bound(begin) up to (not including) lower_bound(end) by
// forwarding to the iterator overload of eraseAsync.
template <class T, class Metric>
template <class Key>
Future<Void> IndexedSet<T, Metric>::eraseAsync(const Key& begin, const Key& end) {
    auto first = lower_bound(begin);
    auto last = lower_bound(end);
    return eraseAsync(first, last);
}
// Erases the elements in [begin, end). The nodes are unlinked from the tree before this function returns; the
// detached subtree roots are handed to ISFreeNodes, whose result is wrapped in uncancellable() and returned.
template <class T, class Metric>
Future<Void> IndexedSet<T, Metric>::eraseAsync(typename IndexedSet<T, Metric>::iterator begin,
                                               typename IndexedSet<T, Metric>::iterator end) {
    std::vector<Node*> detached;
    erase(begin, end, detached);
    // NOTE(review): the `false` flag presumably selects the non-blocking freeing mode of ISFreeNodes --
    // confirm in IndexedSet.actor.h.
    return uncancellable(ISFreeNodes(detached, false));
}
// Removes every entry of the map by erasing the whole range of the underlying set through eraseAsync, and
// returns the future that eraseAsync produces.
template <class Key, class Value, class Pair, class Metric>
Future<Void> Map<Key, Value, Pair, Metric>::clearAsync() {
    auto first = set.begin();
    auto last = set.end();
    return set.eraseAsync(first, last);
}
#endif