foundationdb/flow/IndexedSet.h

1115 lines
38 KiB
C++

/*
* IndexedSet.h
*
* This source file is part of the FoundationDB open source project
*
* Copyright 2013-2018 Apple Inc. and the FoundationDB project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLOW_INDEXEDSET_H
#define FLOW_INDEXEDSET_H
#pragma once
#include "Platform.h"
#include "FastAlloc.h"
#include "Trace.h"
#include "Error.h"
#include <deque>
#include <vector>
// IndexedSet<T, Metric> is similar to a std::set<T>, with the following additional features:
// - Each element in the set is associated with a value of type Metric
// - sumTo() and sumRange() can report the sum of the metric values associated with a
// contiguous range of elements in O(lg N) time
// - index() can be used to find an element having a given sumTo() in O(lg N) time
// - Search functions (find(), lower_bound(), etc) can accept a type comparable to T instead of T
// (e.g. StringRef when T is std::string or Standalone<StringRef>). This can save a lot of needless
// copying at query time for read-mostly sets with string keys.
// - iterators are not const; the responsibility of not changing the order lies with the caller
// - there is no size() function; if the metric being used is a count, sumTo(end()) will do instead
//   (Map::size() is implemented exactly this way)
// A number of STL compatibility features are missing and should be added as needed.
// T must define operator <, which must define a total order. Unlike std::set,
// a user-defined predicate is not currently supported as a template parameter.
// Metric is required to have operators + and - and <, to be constructible from an int
// (e.g. `Metric m = 0;` and `Metric(1)`), and behavior is undefined if
// the sum of metrics for all elements of a set overflows the Metric type.
// Map<Key,Value> is similar to a std::map<Key,Value>, except that it inherits the search key type
// flexibility of IndexedSet<>, uses MapPair<Key,Value> by default instead of pair<Key,Value>
// (use iterator->key instead of iterator->first), and uses FastAllocator for nodes.
// Forward declarations: only the names are needed here, so that the eraseAsync() overloads
// declared inside IndexedSet can name their Future<Void> return type.
template <class T>
class Future;
class Void;
// IndexedSet<T, Metric>: an ordered set stored as an AVL tree in which every node also caches
// the sum of the metrics of its whole subtree (Node::total). This is what lets sumTo()/index()
// run in O(lg N). See the overview comment above for the full contract on T and Metric.
template <class T, class Metric>
struct IndexedSet{
typedef T value_type;
typedef T key_type;
private: // Forward-declare IndexedSet::Node because Clang is much stricter about this ordering.
// A single tree node. A node owns its children: ~Node() deletes both, so deleting a node frees
// its entire subtree (clear() and ~IndexedSet() rely on this).
struct Node : FastAllocated<Node> {
// Here, and throughout all code that indirectly instantiates a Node, we rely on forwarding
// references so that we don't need to maintain the set of 2^arity lvalue and rvalue reference
// combinations, but still take advantage of move constructors when available (or required).
// NOTE(review): the mem-initializer order below differs from the member declaration order;
// members are initialized in declaration order, which is harmless because no initializer
// reads another member.
template <class T_, class Metric_>
Node(T_&& data, Metric_&& m, Node* parent=0) : data(std::forward<T_>(data)), total(std::forward<Metric_>(m)), parent(parent), balance(0) {
child[0] = child[1] = NULL;
}
~Node(){
delete child[0];
delete child[1];
}
T data;
signed char balance; // right height - left height
Metric total; // this + child[0] + child[1]
Node *child[2]; // left, right
Node *parent;
};
public:
// Thin wrapper around a Node pointer; a null pointer represents end().
struct iterator{
typename IndexedSet::Node *i;
iterator() : i(0) {};
iterator(typename IndexedSet::Node *n) : i(n) {};
T& operator*() { return i->data; };
T* operator->() { return &i->data; }
// Advance to the in-order successor; advancing past the last element yields end().
void operator++();
// Step to the in-order predecessor. Must not be called on end() (the null node is
// dereferenced); use IndexedSet::previous() when the iterator may be end().
void decrementNonEnd();
bool operator == ( const iterator& r ) const { return i == r.i; }
bool operator != ( const iterator& r ) const { return i != r.i; }
};
IndexedSet() : root(NULL) {};
~IndexedSet() { delete root; }
// Move operations transfer ownership of the whole tree and leave r empty.
IndexedSet(IndexedSet&& r) noexcept(true) : root(r.root) { r.root = NULL; }
IndexedSet& operator=(IndexedSet&& r) noexcept(true) { delete root; root = r.root; r.root = 0; return *this; }
// First element in order, or end() when the set is empty.
iterator begin() const;
iterator end() const { return iterator(); }
// In-order predecessor of i; previous(end()) is lastItem().
iterator previous(iterator i) const;
// Last element in order, or end() when the set is empty.
iterator lastItem() const;
bool empty() const { return !root; }
// Frees every node (via the recursive ~Node()).
void clear() { delete root; root = NULL; }
void swap( IndexedSet& r ) { std::swap( root, r.root ); }
// Place data in the set with the given metric. If an item equal to data is already in the set and,
// replaceExisting == true, it will be overwritten (and its metric will be replaced)
template <class T_, class Metric_>
iterator insert(T_ &&data, Metric_ &&metric, bool replaceExisting = true);
// Insert all items from data into set. All items will use metric. If an item equal to data is already in the set and,
// replaceExisting == true, it will be overwritten (and its metric will be replaced). returns the number of items inserted.
int insert(const std::vector<std::pair<T,Metric>>& data, bool replaceExisting = true);
// Increase the metric for the given item by the given amount. Inserts data into the set if it
// doesn't exist. Returns the new sum.
template <class T_, class Metric_>
Metric addMetric( T_ && data, Metric_ && metric );
// Remove the data item, if any, which is equal to key
template <class Key>
void erase(const Key &key) { erase( find(key) ); }
// Erase the indicated item. No effect if item == end().
// SOMEDAY: Return ++item
void erase(iterator item);
// Erase all data items x for which begin<=x<end
template <class Key>
void erase(const Key& begin, const Key& end) { erase( lower_bound(begin), lower_bound(end) ); }
// Erase data items with a deferred (async) free process. The data structure has the items removed
// synchronously with the invocation of this method so any subsequent call will see this new state.
template <class Key>
Future<Void> eraseAsync(const Key& begin, const Key& end);
// Erase the items in the indicated range.
void erase(iterator begin, iterator end);
// Erase data items with a deferred (async) free process. The data structure has the items removed
// synchronously with the invocation of this method so any subsequent call will see this new state.
Future<Void> eraseAsync(iterator begin, iterator end);
// Returns the number of items equal to key (either 0 or 1)
template <class Key>
int count(const Key &key) const { return find(key) != end(); }
// Returns x such that key==*x, or end()
template <class Key>
iterator find(const Key &key) const;
// Returns the smallest x such that *x>=key, or end()
template <class Key>
iterator lower_bound(const Key &key) const;
// Returns the smallest x such that *x>key, or end()
template <class Key>
iterator upper_bound(const Key &key) const;
// Returns the largest x such that *x<=key, or end()
template <class Key>
iterator lastLessOrEqual( const Key &key ) const;
// Returns smallest x such that sumTo(x+1) > metric, or end()
template <class M>
iterator index( M const& metric ) const;
// Return the metric inserted with item x
Metric getMetric(iterator x) const;
// Return the sum of getMetric(x) for begin()<=x<to
Metric sumTo(iterator to) const;
// Return the sum of getMetric(x) for begin<=x<end
Metric sumRange(iterator begin, iterator end) const { return sumTo(end) - sumTo(begin); }
// Return the sum of getMetric(x) for all x s.t. begin <= *x && *x < end
template <class Key>
Metric sumRange(const Key& begin, const Key& end) const { return sumRange(lower_bound(begin), lower_bound(end)); }
// Return the amount of memory used by an entry in the IndexedSet
static int getElementBytes() { return sizeof(Node); }
private:
// Copy operations unimplemented. SOMEDAY: Implement and make public.
IndexedSet( const IndexedSet& );
IndexedSet& operator=( const IndexedSet& );
Node *root;
// Range-erase helpers: detach [begin, end) from the tree, appending the roots of fully removed
// subtrees to toFree instead of deleting them (so the caller decides when to free them).
Metric eraseHalf( Node* start, Node* end, int eraseDir, int& heightDelta, std::vector<Node*>& toFree );
void erase( iterator begin, iterator end, std::vector<Node*>& toFree );
// Makes whatever pointed at oldNode (its parent's child slot, or root) point at newNode and
// sets newNode->parent accordingly. oldNode's own pointers are left unchanged.
void replacePointer( Node* oldNode, Node* newNode ) {
if (oldNode->parent)
oldNode->parent->child[ oldNode->parent->child[1] == oldNode ] = newNode;
else
root = newNode;
if (newNode)
newNode->parent = oldNode->parent;
}
// direction 0 = left, 1 = right
// Moves i to its in-order neighbor: direction 1 gives the successor, 0 the predecessor.
// i becomes NULL when walking off either end. i must not be NULL on entry.
template <int direction>
static void moveIterator(Node* &i){
if (i->child[0^direction]) {
i = i->child[0^direction];
while (i->child[1^direction])
i = i->child[1^direction];
} else {
while (i->parent && i->parent->child[0^direction] == i)
i = i->parent;
i = i->parent;
}
}
public: // but testonly
std::pair<int, int> testonly_assertBalanced(Node*n=0, int d=0, bool a=true);
};
// Placeholder Metric for sets and maps that need no per-element weight.
// Every arithmetic result is another (empty) NoMetric, and no value is ever less than another.
class NoMetric {
public:
	NoMetric() {}
	// Deliberately implicit so generic code can write Metric(1) or `Metric m = 0;`.
	NoMetric(int) {}

	NoMetric operator+(NoMetric const&) const {
		NoMetric sum;
		return sum;
	}
	NoMetric operator-(NoMetric const&) const {
		NoMetric difference;
		return difference;
	}
	bool operator<(NoMetric const&) const {
		return false;
	}
};
template <class Key, class Value>
class MapPair {
public:
Key key;
Value value;
template <class Key_, class Value_>
MapPair( Key_&& key, Value_&& value ) : key(std::forward<Key_>(key)), value(std::forward<Value_>(value)) {}
void operator= ( MapPair const& rhs ) { key = rhs.key; value = rhs.value; }
MapPair( MapPair const& rhs ) : key(rhs.key), value(rhs.value) {}
MapPair(MapPair&& r) noexcept(true) : key(std::move(r.key)), value(std::move(r.value)) {}
void operator=(MapPair&& r) noexcept(true) { key = std::move(r.key); value = std::move(r.value); }
bool operator<(MapPair<Key,Value> const& r) const { return key < r.key; }
bool operator==(MapPair<Key,Value> const& r) const { return key == r.key; }
bool operator!=(MapPair<Key,Value> const& r) const { return key != r.key; }
//private: MapPair( const MapPair& );
};
// Convenience factory mirroring std::make_pair: deduces the MapPair's key/value types from
// the (decayed) argument types and forwards the arguments into it.
template <class Key, class Value>
inline MapPair<typename std::decay<Key>::type, typename std::decay<Value>::type> mapPair(Key&& key, Value&& value) {
	typedef typename std::decay<Key>::type StoredKey;
	typedef typename std::decay<Value>::type StoredValue;
	typedef MapPair<StoredKey, StoredValue> Result;
	return Result(std::forward<Key>(key), std::forward<Value>(value));
}
template <class Key, class Value, class CompatibleWithKey>
bool operator<(MapPair<Key, Value> const& l, CompatibleWithKey const& r) { return l.key < r; }
template <class Key, class Value, class CompatibleWithKey>
bool operator<(CompatibleWithKey const& l, MapPair<Key, Value> const& r) { return l < r.key; }
template <class Key, class Value, class Pair = MapPair<Key,Value>, class Metric=NoMetric >
class Map {
public:
typedef typename IndexedSet<Pair,Metric>::iterator iterator;
Map() {}
iterator begin() const { return set.begin(); }
iterator end() const { return set.end(); }
iterator lastItem() const { return set.lastItem(); }
iterator previous(iterator i) const { return set.previous(i); }
bool empty() const { return set.empty(); }
Value& operator[]( const Key& key ) {
iterator i = set.insert( Pair(key, Value()), Metric(1), false );
return i->value;
}
Value& get( const Key& key, Metric m = Metric(1) ) {
iterator i = set.insert( Pair(key, Value()), m, false );
return i->value;
}
iterator insert( const Pair& p, bool replaceExisting = true, Metric m = Metric(1) ) { return set.insert(p, m, replaceExisting); }
iterator insert( Pair && p, bool replaceExisting = true, Metric m = Metric(1) ) { return set.insert(std::move(p), m, replaceExisting); }
int insert( const std::vector<std::pair<MapPair<Key,Value>, Metric>>& pairs, bool replaceExisting = true) { return set.insert(pairs, replaceExisting); }
template <class KeyCompatible>
void erase( KeyCompatible const& k ) { set.erase(k); }
void erase( iterator b, iterator e ) { set.erase(b,e); }
void erase( iterator x ) { set.erase(x); }
void clear() { set.clear(); }
Metric size() const {
static_assert(!std::is_same<Metric, NoMetric>::value, "size() on Map with NoMetric is not valid!");
return sumTo(end());
}
template <class KeyCompatible>
iterator find( KeyCompatible const& k ) const { return set.find(k); }
template <class KeyCompatible>
iterator lower_bound( KeyCompatible const& k ) const { return set.lower_bound(k); }
template <class KeyCompatible>
iterator upper_bound( KeyCompatible const& k ) const { return set.upper_bound(k); }
template <class KeyCompatible>
iterator lastLessOrEqual( KeyCompatible const& k ) const { return set.lastLessOrEqual(k); }
template <class M>
iterator index( M const& metric ) const { return set.index(metric); }
Metric getMetric(iterator x) const { return set.getMetric(x); }
Metric sumTo(iterator to) const { return set.sumTo(to); }
Metric sumRange(iterator begin, iterator end) const { return set.sumRange(begin,end); }
template <class KeyCompatible>
Metric sumRange(const KeyCompatible& begin, const KeyCompatible& end) const { return set.sumRange(begin,end); }
static int getElementBytes() { return IndexedSet< Pair, Metric >::getElementBytes(); }
Map(Map&& r) noexcept(true) : set(std::move(r.set)) {}
void operator=(Map&& r) noexcept(true) { set = std::move(r.set); }
private:
Map( Map<Key,Value,Pair> const& ); // unimplemented
void operator=( Map<Key,Value,Pair> const& ); // unimplemented
IndexedSet< Pair, Metric > set;
};
/////////////////////// implementation //////////////////////////
// Pre-increment: move to the in-order successor. Incrementing the last element yields end()
// (a null node); incrementing end() itself is not allowed.
template <class T, class Metric>
void IndexedSet<T, Metric>::iterator::operator++() {
	moveIterator<1>(this->i);
}
// Move to the in-order predecessor. The iterator must not be end(): the underlying walk
// dereferences the node pointer. Stepping back from the first element yields end().
template <class T, class Metric>
void IndexedSet<T, Metric>::iterator::decrementNonEnd() {
	moveIterator<0>(this->i);
}
// Single rotation of the subtree whose root is stored in `oldRootRef`.
// d == 1 rotates right (the left child is promoted); d == 0 rotates left (the right child is
// promoted). The subtree totals of the two moved nodes are recomputed, parent/child links are
// rewired, and `oldRootRef` is pointed at the promoted node. Balance factors are NOT updated.
// `oldRootRef` should be the parent's child slot (or the tree's root pointer), since the parent
// node is otherwise not touched.
template <class Node>
void ISRotate(Node*& oldRootRef, int d) {
	Node* const demoted = oldRootRef;
	Node* const promoted = demoted->child[1 - d];
	// The promoted node's inner subtree changes parents during the rotation.
	Node* const transferred = promoted->child[d];

	// The promoted node now spans the whole subtree. The demoted node loses the promoted
	// node's subtree but gains the transferred one.
	auto demotedTotal = demoted->total - promoted->total;
	if (transferred)
		demotedTotal = demotedTotal + transferred->total;
	promoted->total = demoted->total;
	demoted->total = demotedTotal;

	// Rewire links (promoted->parent must be read from demoted before demoted->parent changes).
	demoted->child[1 - d] = transferred;
	if (transferred)
		transferred->parent = demoted;
	promoted->child[d] = demoted;
	promoted->parent = demoted->parent;
	demoted->parent = promoted;

	oldRootRef = promoted;
}
// Balance bookkeeping for an AVL double rotation, done before the rotations themselves.
// `root` is the unbalanced node, its child on side d is the pivot, and the pivot's child on the
// opposite side is the grandchild that will become the new subtree root. The final balances of
// root and pivot depend only on the grandchild's current balance; the grandchild ends at 0.
template <class Node>
void ISAdjustBalance(Node* root, int d, int bal) {
	Node* const pivot = root->child[d];
	Node* const grand = pivot->child[1 - d];
	const int g = grand->balance;

	root->balance = (g == bal) ? -bal : 0;
	pivot->balance = (g != 0 && g != bal) ? bal : 0;
	grand->balance = 0;
}
// Restores the AVL invariant for the subtree stored in `root` (a parent's child slot or the tree
// root pointer), which is updated to the new subtree root. Returns the change in the subtree's
// height. Balance factors and metric totals are maintained.
template <class Node>
int ISRebalance( Node*& root ) {
	// Pre: root is a tree having the BST, metric, and balance invariants but not (necessarily) the AVL invariant. root->child[0] and root->child[1] are AVL.
	// Post: root is an AVL tree with the same nodes
	// Returns: the change in height of root
	// rebalance is O(1) if abs(root->balance)<=2, and probably O(log N) otherwise. (The rare "still unbalanced" recursion is hard to analyze)
	//
	// The documentation of this function will be referencing the following tree (where
	// nodes A, C, E, and G represent subtrees of unspecified height). Thus for each node X,
	// we know the value of balance(X), but not height(X).
	//
	// We will assume that balance(F) < 0 (so we will be rotating right).
	// Trees that rotate to the left will perform analogous operations.
	/*
	        F
	       / \
	      B   G
	     / \
	    A   D
	       / \
	      C   E
	*/
	if (!root || (root->balance >= -1 && root->balance <= +1))
		return 0;

	int rebalanceDir = root->balance<0; // 1 if rotating right, 0 if rotating left
	auto* n = root->child[ 1-rebalanceDir ]; // Node B
	int bal = rebalanceDir ? +1 : -1; // 1 if rotating right, -1 if rotating left
	int rootBal = root->balance;

	// Depending on the balance at B, we will be required to do one or two rotations.
	// If balance(B) <= 0, then we do only one rotation (the second of the two).
	//
	// In a tree where balance(B) == +1, we are required to do both rotations.
	// The result of the first rotation will be:
	/*
	          F
	         / \
	        D   G
	       / \
	      B   E
	     / \
	    A   C
	*/
	bool doubleRotation = n->balance == bal;
	if (doubleRotation) {
		int x = n->child[rebalanceDir]->balance; // balance of Node D
		ISRotate( root->child[1-rebalanceDir], 1-rebalanceDir); // Rotate at Node B

		// Change node pointed to by 'n' to prepare for the second rotation
		// After this first rotation, Node D will be the left child of the root
		n = root->child[1-rebalanceDir];

		// Compute the balance at the new root node D' of our rotation
		// We know that height(A) == max(height(C), height(E)) because B had balance of +1
		// If height(E) >= height(C), then height(E) == height(A) and balance(D') = -1
		// Otherwise height(C) == height(E) + 1, and therefore balance(D') = -2
		n->balance = ((x==-bal) ? -2 : -1)*bal;

		// Compute the balance at the old root node B' of our rotation
		// As stated above, height(A) == max(height(C), height(E))
		// If height(C) >= height(E), then height(A) == height(C) and balance(B') = 0
		// Otherwise height(A) == height(E) == height(C) + 1, and therefore balance(B') = -1
		n->child[1-rebalanceDir]->balance = ((x==bal) ? -1 : 0)*bal;
	}

	// At this point, we perform the "second" rotation (which may actually be the first
	// if the "first" rotation was not performed). The rotation that is performed is the
	// same for both trees, but the result will be different depending on which tree we
	// started with:
	/*
	   If unrotated:          If once rotated:

	        B                        D
	       / \                     /   \
	      A   F                   B     F
	         / \                 / \   / \
	        D   G               A   C E   G
	       / \
	      C   E
	*/
	// The documentation for this second rotation will be based on the unrotated original tree.

	// Compute the balance at the new root node B'.
	// balance(B') = 1 + max(height(D), height(G)) - height(A) = 1 + max(height(D) - height(A), height(G) - height(A))
	// balance(B') = 1 + max(balance(B), height(G) - height(A))
	//
	// Now, we must find height(G) - height(A):
	// If height(A) >= height(D) (i.e. balance(B) <= 0), then
	// height(G) - height(A) = height(G) - height(B) + 1 = balance(F) + 1
	//
	// Otherwise, height(A) = height(D) - balance(B) = height(B) - 1 - balance(B), so
	// height(G) - height(A) = height(G) - height(B) + 1 + balance(B) = balance(F) + 1 + balance(B)
	//
	// balance(B') = 1 + max(balance(B), balance(F) + 1 + max(balance(B), 0))
	//
	int nBal = n->balance * bal; // Direction corrected balance at Node B
	int newRootBalance = bal * (1 + std::max(nBal, bal * root->balance + 1 + std::max(nBal, 0)));

	// Compute the balance at the old root node F' (which becomes a child of the new root).
	// balance(F') = height(G) - height(D)
	//
	// If height(D) >= height(A) (i.e. balance(B) >= 0), then height(D) = height(B) - 1, so
	// balance(F') = height(G) - height(B) + 1 = balance(F) + 1
	//
	// Otherwise, height(D) = height(A) + balance(B) = height(B) - 1 + balance(B), so
	// balance(F') = height(G) - height(B) + 1 - balance(B) = balance(F) + 1 - balance(B)
	//
	// balance(F') = balance(F) + 1 - min(balance(B), 0)
	//
	int newChildBalance = root->balance + bal * (1 - std::min(nBal, 0));

	ISRotate( root, rebalanceDir );
	root->balance = newRootBalance;
	root->child[rebalanceDir]->balance = newChildBalance;

	// If the original tree is very unbalanced, the unbalance may have been "pushed" down into this subtree, so recursively rebalance that if necessary.
	int childHeightChange = ISRebalance(root->child[rebalanceDir]);
	root->balance += childHeightChange * bal;

	// Compute the change in height at the root
	// We will look at the single and double rotation cases separately
	//
	// If we did a single rotation, then height(A) >= height(D).
	// As a result, height(A) >= height(G) + 1; otherwise the tree would be balanced and we wouldn't do any rotations.
	//
	// Then the original height of the tree is height(A) + 2,
	// and the new height is max(height(D) + 2 + childHeightChange, height(A) + 1), so
	//
	// heightChange_single = max(height(D) + 2 + childHeightChange, height(A) + 1) - (height(A) + 2)
	// heightChange_single = max(height(D) - height(A) + childHeightChange, -1)
	// heightChange_single = max(balance(B) + childHeightChange, -1)
	//
	// If we did a double rotation, then height(D) = height(A) + 1 in the original tree.
	// As a result, height(D) >= height(G) + 1; otherwise the tree would be balanced and we wouldn't do any rotations.
	//
	// Then the original height of the tree is height(D) + 2,
	// and the new height is max(height(A), height(C), height(E), height(G)) + 2
	//
	// balance(B) == 1, so height(A) == max(height(C), height(E)).
	// Also, height(A) = height(D) - 1 >= height(G)
	// Therefore the new height is height(A) + 2
	//
	// heightChange_double = height(A) + 2 - (height(D) + 2)
	// heightChange_double = height(A) - height(D)
	// heightChange_double = -1
	//
	int heightChange = doubleRotation ? -1 : std::max(nBal + childHeightChange, -1);

	// If the root is still unbalanced, then it should at least be more balanced than before. Recursively rebalance the root until we get a balanced tree.
	if (root->balance <-1 || root->balance > +1) {
		ASSERT(abs(root->balance) < abs(rootBal));
		heightChange += ISRebalance(root);
	}

	return heightChange;
}
// Returns the root of the smallest subtree containing both `first` and `last` (their lowest
// common ancestor, which may be one of them). Returns NULL if they share no ancestor.
template <class Node>
Node* ISCommonSubtreeRoot(Node* first, Node* last) {
	// Number of nodes on the path from n up to the root (0 for NULL).
	auto depthOf = [](Node* n) {
		int depth = 0;
		for (; n; n = n->parent)
			++depth;
		return depth;
	};

	// Lift the deeper node until both sit at the same depth...
	int gap = depthOf(first) - depthOf(last);
	for (; gap > 0; --gap)
		first = first->parent;
	for (; gap < 0; ++gap)
		last = last->parent;

	// ...then climb in lockstep until the paths meet.
	while (first != last) {
		first = first->parent;
		last = last->parent;
	}
	return first;
}
// First element in order (the leftmost node), or end() when the set is empty.
template <class T, class Metric>
typename IndexedSet<T,Metric>::iterator IndexedSet<T,Metric>::begin() const {
	if (!root)
		return end();
	Node* leftmost = root;
	for (; leftmost->child[0]; leftmost = leftmost->child[0]) {
	}
	return iterator(leftmost);
}
// In-order predecessor of i. Unlike iterator::decrementNonEnd(), this accepts end() and maps
// it to the last element. The predecessor of the first element is end().
template <class T, class Metric>
typename IndexedSet<T,Metric>::iterator IndexedSet<T,Metric>::previous(typename IndexedSet<T,Metric>::iterator i) const {
	if (i != end()) {
		moveIterator<0>(i.i);
		return i;
	}
	return lastItem();
}
// Last element in order (the rightmost node), or end() when the set is empty.
template <class T, class Metric>
typename IndexedSet<T,Metric>::iterator IndexedSet<T,Metric>::lastItem() const {
	if (!root)
		return end();
	Node* rightmost = root;
	for (; rightmost->child[1]; rightmost = rightmost->child[1]) {
	}
	return iterator(rightmost);
}
// Adds `metric` to the metric of the element equal to `data`, inserting `data` with that metric
// if it is absent. Returns the element's resulting metric. When the element exists, it is
// re-inserted with replacement, so its stored data is also overwritten by `data`.
template <class T, class Metric> template<class T_, class Metric_>
Metric IndexedSet<T,Metric>::addMetric(T_&& data, Metric_&& metric){
	auto i = find( data );
	if (i == end()) {
		// Copy before forwarding: insert() may move from an rvalue `metric`, so reading
		// `metric` afterwards would return a moved-from value.
		Metric inserted = metric;
		insert( std::forward<T_>(data), std::forward<Metric_>(metric) );
		return inserted;
	}
	Metric m = metric + getMetric(i);
	insert( std::forward<T_>(data), m );
	return m;
}
// Inserts `data` with `metric`, or, when an equal element exists and replaceExisting is true,
// overwrites that element's data and metric. Returns an iterator to the inserted or existing
// element. Ancestor totals are updated and the AVL invariant is restored with at most one
// (single or double) rotation.
template <class T, class Metric> template<class T_, class Metric_>
typename IndexedSet<T,Metric>::iterator IndexedSet<T,Metric>::insert(T_&& data, Metric_&& metric, bool replaceExisting){
	if (root == NULL){
		root = new Node(std::forward<T_>(data), std::forward<Metric_>(metric));
		return root;
	}
	Node *t = root;
	int d; // direction
	// traverse to find insert point
	while (true){
		d = t->data < data;
		if (!d && !(data < t->data)) { // t->data == data
			Node *returnNode = t;
			if(replaceExisting) {
				t->data = std::forward<T_>(data);
				Metric delta = t->total;
				t->total = std::forward<Metric_>(metric);
				if (t->child[0]) t->total = t->total + t->child[0]->total;
				if (t->child[1]) t->total = t->total + t->child[1]->total;
				delta = t->total - delta;
				// Propagate the change in this node's subtree sum to every ancestor.
				while (true) {
					t = t->parent;
					if (!t) break;
					t->total = t->total + delta;
				}
			}
			return returnNode;
		}
		Node *nextT = t->child[d];
		if (!nextT) break;
		t = nextT;
	}
	Node *newNode = new Node(std::forward<T_>(data), std::forward<Metric_>(metric), t);
	// `metric` may just have been moved into newNode, so it must not be read again. Take the
	// amount to add to the ancestors from the node instead, copying it now because a rotation
	// below may change newNode->total.
	Metric addedMetric = newNode->total;
	t->child[d] = newNode;

	// Walk up updating balances and totals until the height change is absorbed (balance 0)
	// or a rotation is needed.
	while (true){
		t->balance += d ? 1 : -1;
		t->total = t->total + addedMetric;
		if (t->balance == 0)
			break;
		if (t->balance != 1 && t->balance != -1){
			Node** parent = t->parent ? &t->parent->child[t->parent->child[1]==t] : &root;
			//assert( *parent == t );

			Node *n = t->child[d];
			int bal = d ? 1 : -1;
			if (n->balance == bal){
				t->balance = n->balance = 0;
			} else {
				ISAdjustBalance(t, d, bal);
				ISRotate(t->child[d], d);
			}
			ISRotate(*parent, 1-d);
			t = *parent;
			break;
		}
		if (!t->parent) break;
		d = t->parent->child[1] == t;
		t = t->parent;
	}
	// The remaining ancestors only need their subtree totals increased.
	while (true) {
		t = t->parent;
		if (!t) break;
		t->total = t->total + addedMetric;
	}

	return newNode;
}
// Bulk insert: inserts every (data, metric) pair, with the same replaceExisting semantics as
// the single-item insert. Returns the number of items inserted, counting items that replaced an
// existing equal element. Input in ascending order takes a fast path that links each item
// directly after the previous one instead of searching from the root. Other input is still
// handled correctly, just via the ordinary search.
template <class T, class Metric>
int IndexedSet<T,Metric>::insert(const std::vector<std::pair<T,Metric>>& dataVector, bool replaceExisting) {
	int num_inserted = 0;
	// blockStart is the node holding the previously processed item. blockEnd is the in-order
	// successor of the gap that item was placed in (NULL if there is none). After an exact
	// match both point at the matched node, which makes the fast-path test below fail.
	Node *blockStart = NULL;
	Node *blockEnd = NULL;
	for(size_t i = 0; i < dataVector.size(); ++i) {
		Metric metric = dataVector[i].second;
		// dataVector is const, so this is necessarily a copy.
		T data = dataVector[i].first;
		int d = 1; // direction
		// The fast path is valid only if blockStart < data < blockEnd. Checking both bounds keeps
		// unsorted or duplicate input from breaking BST order. Only operator< is used, per the
		// contract on T.
		bool inBlock = blockStart != NULL && blockStart->data < data && (blockEnd == NULL || data < blockEnd->data);
		if (!inBlock) {
			blockEnd = NULL;
			if (root == NULL){
				root = new Node(std::move(data), metric);
				num_inserted++;
				blockStart = root;
				continue;
			}

			Node *t = root;
			// traverse to find insert point
			bool foundNode = false;
			while (true){
				d = t->data < data;
				if (!d)
					blockEnd = t; // last left turn = in-order successor of the insert point
				if (!d && !(data < t->data)) { // t->data == data
					Node *returnNode = t;
					if(replaceExisting) {
						num_inserted++;
						t->data = std::move(data);
						Metric delta = t->total;
						t->total = metric;
						if (t->child[0]) t->total = t->total + t->child[0]->total;
						if (t->child[1]) t->total = t->total + t->child[1]->total;
						delta = t->total - delta;
						while (true) {
							t = t->parent;
							if (!t) break;
							t->total = t->total + delta;
						}
					}
					blockStart = returnNode;
					foundNode = true;
					break;
				}
				Node *nextT = t->child[d];
				if (!nextT) {
					blockStart = t;
					break;
				}
				t = nextT;
			}
			if(foundNode)
				continue;
		}

		// Insert immediately after blockStart in order. On the fast path (d == 1) that is the
		// leftmost slot of blockStart's right subtree. After a search, blockStart->child[d] is
		// already the empty slot.
		Node *t = blockStart;
		while(t->child[d]) {
			t = t->child[d];
			d = 0;
		}
		Node *newNode = new Node(std::move(data), metric, t);
		num_inserted++;
		t->child[d] = newNode;
		blockStart = newNode;

		while (true){
			t->balance += d ? 1 : -1;
			t->total = t->total + metric;
			if (t->balance == 0)
				break;
			if (t->balance != 1 && t->balance != -1){
				Node** parent = t->parent ? &t->parent->child[t->parent->child[1]==t] : &root;
				//assert( *parent == t );

				Node *n = t->child[d];
				int bal = d ? 1 : -1;
				if (n->balance == bal){
					t->balance = n->balance = 0;
				} else {
					ISAdjustBalance(t, d, bal);
					ISRotate(t->child[d], d);
				}
				ISRotate(*parent, 1-d);
				t = *parent;
				break;
			}
			if (!t->parent) break;
			d = t->parent->child[1] == t;
			t = t->parent;
		}
		while (true) {
			t = t->parent;
			if (!t) break;
			t->total = t->total + metric;
		}
	}
	return num_inserted;
}
// Helper for range erase: walks from `start` up to (but not including) `end`, unlinking every
// node on the erased side of that path. Detached subtrees are appended to toFree rather than
// deleted. Balance factors along the path are updated and each visited node that keeps its
// position is rebalanced.
template <class T, class Metric>
Metric IndexedSet<T,Metric>::eraseHalf( Node* start, Node* end, int eraseDir, int& heightDelta, std::vector<Node*>& toFree ) {
// Removes all nodes between start (inclusive) and end (exclusive) from the set, where start is equal to end or one of its descendants
// eraseDir 1 means erase the right half (nodes >= start) of the left subtree of end. eraseDir 0 means the left half (nodes <= start) of the right subtree
// toFree is extended with the roots of completely removed subtrees
// heightDelta will be set to the change in height of the end node
// Returns the amount that should be subtracted from end node's metric value (and, by extension, the metric values of all ancestors of the end node).
//
// The end node may be left unbalanced (AVL invariant broken)
// The end node may be left with the incorrect metric total (the correct value is end->total = end->total + metricDelta)
// scare quotes in comments mean the values when eraseDir==1 (when eraseDir==0, "left" means right etc)
// metricDelta measures how much should be subtracted from the current node's metrics
// If start == end the loop does not run and both heightDelta and the returned delta are 0.
Metric metricDelta = 0;
heightDelta = 0;
int fromDir = 1 - eraseDir;
// Begin removing nodes at start continuing up until we get to end
while(start != end) {
start->total = start->total - metricDelta;
IndexedSet<T,Metric>::Node *parent = start->parent;
// Obtain the child pointer to start, which rebalance will update with the new root of the subtree currently rooted at start
IndexedSet<T,Metric>::Node *& node = parent->child[ parent->child[1] == start ];
int nextDir = parent->child[1] == start;
if (fromDir==eraseDir) {
// The "right" subtree has been half-erased, and the "left" subtree doesn't need to be (nor does node).
// But this node might be unbalanced by the shrinking "right" subtree. Rebalance and continue up.
heightDelta += ISRebalance( node );
} else {
// The "left" subtree has been half-erased. `start' and its "right" subtree will be completely erased,
// leaving only the "left" subtree in its place (which is already AVL balanced).
heightDelta += -1 - std::max<int>(0, node->balance * (eraseDir ? +1 : -1));
metricDelta = metricDelta + start->total;
// If there is a surviving subtree of start, then connect it to start->parent
IndexedSet<T,Metric>::Node *n = node->child[fromDir];
node = n; // This updates the appropriate child pointer of start->parent
if (n) {
metricDelta = metricDelta - n->total;
n->parent = start->parent;
}
// Detach the surviving child so freeing `start` (which deletes its remaining children
// recursively) only frees start and its erased "right" subtree.
start->child[fromDir] = NULL;
toFree.push_back( start );
}
int dir = (nextDir ? +1 : -1);
int oldBalance = parent->balance;
// The change in height from removing nodes should never increase our height
ASSERT(heightDelta <= 0);
parent->balance += heightDelta * dir;
// Compute the change in height of start's parent based on its change in balance.
// Because we can only be (possibly) shrinking one subtree of parent:
// If we were originally heavier on the shrunken size (oldBalance * dir > 0), then the change in height is at most abs(oldBalance) == oldBalance * dir.
// If we were lighter on the shrunken side, then height cannot change.
int maxHeightChange = std::max(oldBalance * dir, 0);
int balanceChange = (oldBalance - parent->balance) * dir;
heightDelta = -std::min(maxHeightChange, balanceChange);
start = parent;
fromDir = nextDir;
}
return metricDelta;
}
// Range-erase core: unlinks every element in [begin, end) from the tree. Nodes removed as whole
// subtrees are appended to toFree (not deleted). The common-subtree root is removed with the
// single-node erase(), which deletes it immediately.
template <class T, class Metric>
void IndexedSet<T,Metric>::erase( typename IndexedSet<T,Metric>::iterator begin, typename IndexedSet<T,Metric>::iterator end, std::vector<Node*>& toFree ) {
// Removes all nodes in the set between first and last, inclusive.
// (first == begin and last == previous(end), i.e. the half-open range [begin, end).)
// toFree is extended with the roots of completely removed subtrees.
// The sanity check below requires T to support operator<=.
ASSERT(!end.i || (begin.i && *begin <= *end));
if(begin == end)
return;
IndexedSet<T,Metric>::Node* first = begin.i;
IndexedSet<T,Metric>::Node* last = previous(end).i;
IndexedSet<T,Metric>::Node* subRoot = ISCommonSubtreeRoot(first, last);
Metric metricDelta = 0;
int leftHeightDelta = 0;
int rightHeightDelta = 0;
// Erase all matching nodes that descend from subRoot, by first erasing descendants of subRoot->child[0] and then erasing the descendants of subRoot->child[1]
// subRoot is not removed from the tree at this time
metricDelta = metricDelta + eraseHalf( first, subRoot, 1, leftHeightDelta, toFree );
metricDelta = metricDelta + eraseHalf( last, subRoot, 0, rightHeightDelta, toFree );
// Change in the height of subRoot due to past activity, before subRoot is rebalanced. subRoot->balance already reflects changes in height to its children.
int heightDelta = leftHeightDelta + rightHeightDelta;
// Rebalance and update metrics for all nodes from subRoot up to the root
for(auto p = subRoot; p != NULL; p = p->parent) {
p->total = p->total - metricDelta;
auto& pc = p->parent ? p->parent->child[p->parent->child[1]==p] : root;
heightDelta += ISRebalance(pc);
// Continue from whatever node now occupies p's slot (rebalancing may have rotated it).
p = pc;
// Update the balance and compute heightDelta for p->parent
if (p->parent) {
int oldb = p->parent->balance;
int dir = (p->parent->child[1]==p ? +1 : -1);
p->parent->balance += heightDelta * dir;
heightDelta = (std::max(p->parent->balance*dir, 0) - std::max(oldb*dir, 0));
}
}
// Erase the subRoot using the single node erase implementation
erase( IndexedSet<T,Metric>::iterator(subRoot) );
}
// Removes and deletes the single node at toErase (no effect for end()), keeping ancestor totals
// correct and restoring the AVL invariant on the path back to the root.
template <class T, class Metric>
void IndexedSet<T,Metric>::erase(iterator toErase) {
// rebalanceNode: deepest node whose subtree lost height; rebalanceDir: which of its
// subtrees (0 = left, 1 = right) got shorter.
Node* rebalanceNode;
int rebalanceDir;
{
// Find the node to erase
Node* t = toErase.i;
if (!t) return;
if (!t->child[0] || !t->child[1]) {
// At most one child: splice t out by linking its only child (if any) into t's place.
Metric tMetric = t->total;
if (t->child[0]) tMetric = tMetric - t->child[0]->total;
if (t->child[1]) tMetric = tMetric - t->child[1]->total;
for( Node* p = t->parent; p; p = p->parent )
p->total = p->total - tMetric;
rebalanceNode = t->parent;
if (rebalanceNode) rebalanceDir = rebalanceNode->child[1] == t;
int d = !t->child[0]; // Only one child, on this side (or no children!)
replacePointer(t, t->child[d]);
// Detach the surviving child so ~Node() does not delete it.
t->child[d] = 0;
delete t;
} else { // Remove node with two children
// Replace t with its in-order predecessor (rightmost node of the left subtree).
Node* predecessor = t->child[0];
while ( predecessor->child[1] )
predecessor = predecessor->child[1];
rebalanceNode = predecessor->parent;
if (rebalanceNode == t) rebalanceNode = predecessor;
if (rebalanceNode) rebalanceDir = rebalanceNode->child[1] == predecessor;
Metric tMetric = t->total - t->child[0]->total - t->child[1]->total;
// predecessor has no right child, so this leaves predecessor->total equal to its own metric.
if (predecessor->child[0]) predecessor->total = predecessor->total - predecessor->child[0]->total;
for( Node* p = predecessor->parent; p != t; p = p->parent )
p->total = p->total - predecessor->total;
for( Node* p = t->parent; p; p = p->parent )
p->total = p->total - tMetric;
// Replace t with predecessor
replacePointer( predecessor, predecessor->child[0] );
replacePointer( t, predecessor );
predecessor->balance = t->balance;
for(int i=0; i<2; i++) {
Node* c = predecessor->child[i] = t->child[i];
if (c) {
c->parent = predecessor;
predecessor->total = predecessor->total + c->total;
// Detach so ~Node() on t does not delete the subtrees now owned by predecessor.
t->child[i] = 0;
}
}
delete t;
}
}
if (!rebalanceNode) return;
// Walk up, adjusting balances, until a subtree's height is unchanged.
while (true) {
rebalanceNode->balance += rebalanceDir ? -1 : +1;
if ( rebalanceNode->balance < -1 || rebalanceNode->balance > +1 ) {
Node** parent = rebalanceNode->parent ? &rebalanceNode->parent->child[rebalanceNode->parent->child[1]==rebalanceNode] : &root;
Node* n = rebalanceNode->child[ 1-rebalanceDir ];
int bal = rebalanceDir ? +1 : -1;
if (n->balance == -bal) {
// Single rotation; the subtree gets shorter, so keep walking up.
rebalanceNode->balance = n->balance = 0;
ISRotate( *parent, rebalanceDir );
} else if (n->balance == bal) {
// Double rotation; the subtree gets shorter, so keep walking up.
ISAdjustBalance( rebalanceNode, 1-rebalanceDir, -bal );
ISRotate( rebalanceNode->child[1-rebalanceDir], 1-rebalanceDir);
ISRotate( *parent, rebalanceDir );
} else { // n->balance == 0
// Single rotation that preserves the subtree height, so nothing above changes.
rebalanceNode->balance = -bal;
n->balance = bal;
ISRotate( *parent, rebalanceDir );
break;
}
rebalanceNode = *parent;
} else if ( rebalanceNode->balance ) // +/- 1, we are done
break;
if (!rebalanceNode->parent) break;
rebalanceDir = rebalanceNode->parent->child[1] == rebalanceNode;
rebalanceNode = rebalanceNode->parent;
}
}
// find: returns an iterator to the element equivalent to key (neither key < element
// nor element < key), or end() if there is none. Key may be any type that is
// mutually comparable with T via operator<.
template <class T, class Metric>
template <class Key>
typename IndexedSet<T,Metric>::iterator IndexedSet<T,Metric>::find(const Key &key) const {
	for (Node* node = root; node; ) {
		const bool goRight = node->data < key;
		if (!goRight && !(key < node->data))
			return iterator(node); // equivalent: neither compares less than the other
		node = node->child[goRight];
	}
	return end();
}
// lower_bound: returns the smallest x such that !(*x < key), i.e. *x >= key,
// or end() if every element is less than key.
template <class T, class Metric>
template <class Key>
typename IndexedSet<T,Metric>::iterator IndexedSet<T,Metric>::lower_bound(const Key &key) const {
	Node* node = root;
	if (!node)
		return iterator();
	// Follow the search path until it would leave the tree. The last node reached
	// is either the answer itself or the node just before it in order.
	while (Node* next = node->child[ node->data < key ])
		node = next;
	if (node->data < key)
		moveIterator<1>(node); // advance to the in-order successor
	return iterator(node);
}
// upper_bound: returns the smallest x such that key < *x, or end() if no element
// is greater than key.
template <class T, class Metric>
template <class Key>
typename IndexedSet<T,Metric>::iterator IndexedSet<T,Metric>::upper_bound(const Key &key) const {
	Node* node = root;
	if (!node)
		return iterator();
	// Go right whenever the node is <= key; stop when the next step would leave the tree.
	while (Node* next = node->child[ !(key < node->data) ])
		node = next;
	if (!(key < node->data))
		moveIterator<1>(node); // the last node is <= key, so the answer is its successor
	return iterator(node);
}
// lastLessOrEqual: returns the largest x such that *x <= key, or end() when no
// element is <= key. This is the element just before upper_bound(key).
template <class T, class Metric>
template <class Key>
typename IndexedSet<T,Metric>::iterator IndexedSet<T,Metric>::lastLessOrEqual(const Key &key) const {
	iterator firstGreater = upper_bound(key);
	if (firstGreater == begin())
		return end(); // nothing precedes the first element greater than key
	return previous(firstGreater);
}
// index: returns the first x such that metric < sum of the metrics of [begin(), x]
// (x inclusive), or end() if there is no such element. M must be comparable with,
// and support + / - against, Metric. M() acts as zero.
template <class T, class Metric>
template <class M>
typename IndexedSet<T,Metric>::iterator IndexedSet<T,Metric>::index( M const& metric ) const
{
	M remaining = metric;
	for (Node* node = root; node; ) {
		Node* const left = node->child[0];
		Node* const right = node->child[1];
		if (left && remaining < left->total) {
			node = left; // the target lies inside the left subtree
			continue;
		}
		// Consume the left subtree and this node: node->total - right->total.
		remaining = remaining - node->total;
		if (right)
			remaining = remaining + right->total;
		if (remaining < M())
			return iterator(node);
		node = right;
	}
	return end();
}
// getMetric: returns the metric of the single element at x. It is recovered from
// x's subtree total by subtracting each child's subtree total.
// x must not be end(): x.i is dereferenced unconditionally.
template <class T, class Metric>
Metric IndexedSet<T,Metric>::getMetric(typename IndexedSet<T,Metric>::iterator x) const {
	Node* const node = x.i;
	Metric own = node->total;
	if (node->child[0])
		own = own - node->child[0]->total;
	if (node->child[1])
		own = own - node->child[1]->total;
	return own;
}
// sumTo: returns the sum of the metrics of all elements strictly before `end`.
// For end() this is the whole set's total, or Metric() when the set is empty.
template <class T, class Metric>
Metric IndexedSet<T,Metric>::sumTo(typename IndexedSet<T,Metric>::iterator end) const {
	Node* node = end.i;
	if (!node)
		return root ? root->total : Metric();
	// Everything in the target's left subtree precedes it...
	Metric sum = node->child[0] ? node->child[0]->total : Metric();
	// ...as does each ancestor reached from its right side, together with that
	// ancestor's left subtree (parent->total minus our own subtree's total).
	while (Node* parent = node->parent) {
		if (parent->child[1] == node) {
			sum = sum - node->total;
			sum = sum + parent->total;
		}
		node = parent;
	}
	return sum;
}
#include "flow.h"
#include "IndexedSet.actor.h"
// erase(begin, end): removes the elements in [begin, end). The detached nodes are
// collected by erase(begin, end, toFree) and handed to ISFreeNodes with `true`. The
// async variant passes `false`, so `true` presumably means free immediately; confirm
// against IndexedSet.actor.h. The future ISFreeNodes returns is discarded.
template <class T, class Metric>
void IndexedSet<T,Metric>::erase(typename IndexedSet<T,Metric>::iterator begin, typename IndexedSet<T,Metric>::iterator end) {
	std::vector<typename IndexedSet<T,Metric>::Node*> detached;
	erase(begin, end, detached);
	ISFreeNodes(detached, true);
}
// eraseAsync(begin, end) by key: erases every element x with begin <= *x < end.
// It resolves both keys with lower_bound and delegates to the iterator overload.
template <class T, class Metric>
template <class Key>
Future<Void> IndexedSet<T, Metric>::eraseAsync(const Key &begin, const Key &end) {
	iterator first = lower_bound(begin);
	iterator last = lower_bound(end);
	return eraseAsync(first, last);
}
// eraseAsync(begin, end): removes [begin, end) from the set right away via
// erase(begin, end, toFree). It then hands the collected nodes to
// ISFreeNodes(..., false) and returns that future wrapped in uncancellable().
// NOTE(review): presumably this lets the node freeing finish even if the caller's
// future is cancelled; confirm.
template <class T, class Metric>
Future<Void> IndexedSet<T, Metric>::eraseAsync(typename IndexedSet<T,Metric>::iterator begin, typename IndexedSet<T,Metric>::iterator end) {
	std::vector<typename IndexedSet<T,Metric>::Node*> detached;
	erase(begin, end, detached);
	return uncancellable( ISFreeNodes(detached, false) );
}
#endif