[ADT] Add Compare template param to EquivalenceClasses

This makes the class usable with types that do not provide their own operator<.

Update MLIR Linalg ComprehensiveBufferize to take advantage of the new template param.

Differential Revision: https://reviews.llvm.org/D112052
This commit is contained in:
Matthias Springer 2021-11-01 17:15:12 +09:00
parent 1f9fa54984
commit 0118a8044f
3 changed files with 38 additions and 25 deletions

View File

@ -30,7 +30,8 @@ namespace llvm {
///
/// This implementation is an efficient implementation that only stores one copy
/// of the element being indexed per entry in the set, and allows any arbitrary
/// type to be indexed (as long as it can be ordered with operator<).
/// type to be indexed (as long as it can be ordered with operator< or a
/// comparator is provided).
///
/// Here is a simple example using integers:
///
@ -54,7 +55,7 @@ namespace llvm {
/// 4
/// 5 1 2
///
template <class ElemTy>
template <class ElemTy, class Compare = std::less<ElemTy>>
class EquivalenceClasses {
/// ECValue - The EquivalenceClasses data structure is just a set of these.
/// Each of these represents a relation for a value. First it stores the
@ -101,22 +102,40 @@ class EquivalenceClasses {
assert(RHS.isLeader() && RHS.getNext() == nullptr && "Not a singleton!");
}
bool operator<(const ECValue &UFN) const { return Data < UFN.Data; }
bool isLeader() const { return (intptr_t)Next & 1; }
const ElemTy &getData() const { return Data; }
const ECValue *getNext() const {
return (ECValue*)((intptr_t)Next & ~(intptr_t)1);
}
};
template<typename T>
bool operator<(const T &Val) const { return Data < Val; }
/// A wrapper of the comparator, to be passed to the set.
struct ECValueComparator {
using is_transparent = void;
ECValueComparator() : compare(Compare()) {}
bool operator()(const ECValue &lhs, const ECValue &rhs) const {
return compare(lhs.Data, rhs.Data);
}
template <typename T>
bool operator()(const T &lhs, const ECValue &rhs) const {
return compare(lhs, rhs.Data);
}
template <typename T>
bool operator()(const ECValue &lhs, const T &rhs) const {
return compare(lhs.Data, rhs);
}
const Compare compare;
};
/// TheMapping - This implicitly provides a mapping from ElemTy values to the
/// ECValues, it just keeps the key as part of the value.
std::set<ECValue> TheMapping;
std::set<ECValue, ECValueComparator> TheMapping;
public:
EquivalenceClasses() = default;

View File

@ -122,23 +122,17 @@ public:
void dumpEquivalences() const;
private:
/// llvm::EquivalenceClasses wants comparable elements because it uses
/// std::set as the underlying impl.
/// ValueWrapper wraps Value and uses pointer comparison on the defining op.
/// This is a poor man's comparison but it's not like UnionFind needs ordering
/// anyway ..
struct ValueWrapper {
ValueWrapper(Value val) : v(val) {}
operator Value() const { return v; }
bool operator<(const ValueWrapper &wrap) const {
return v.getImpl() < wrap.v.getImpl();
/// llvm::EquivalenceClasses wants comparable elements. This comparator uses
/// uses pointer comparison on the defining op. This is a poor man's
/// comparison but it's not like UnionFind needs ordering anyway.
struct ValueComparator {
bool operator()(const Value &lhs, const Value &rhs) const {
return lhs.getImpl() < rhs.getImpl();
}
bool operator==(const ValueWrapper &wrap) const { return v == wrap.v; }
Value v;
};
using EquivalenceClassRangeType = llvm::iterator_range<
llvm::EquivalenceClasses<ValueWrapper>::member_iterator>;
llvm::EquivalenceClasses<Value, ValueComparator>::member_iterator>;
/// Check that aliasInfo for `v` exists and return a reference to it.
EquivalenceClassRangeType getAliases(Value v) const;
@ -164,10 +158,10 @@ private:
/// Auxiliary structure to store all the values a given value aliases with.
/// These are the conservative cases that can further decompose into
/// "equivalent" buffer relationships.
llvm::EquivalenceClasses<ValueWrapper> aliasInfo;
llvm::EquivalenceClasses<Value, ValueComparator> aliasInfo;
/// Auxiliary structure to store all the equivalent buffer classes.
llvm::EquivalenceClasses<ValueWrapper> equivalentInfo;
llvm::EquivalenceClasses<Value, ValueComparator> equivalentInfo;
};
/// Analyze the `ops` to determine which OpResults are inplaceable.

View File

@ -1213,11 +1213,11 @@ bool BufferizationAliasInfo::isSourceEquivalentToAMatchingInplaceExtractSliceOp(
for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
++mit) {
auto extractSliceOp =
dyn_cast_or_null<ExtractSliceOp>(mit->v.getDefiningOp());
dyn_cast_or_null<ExtractSliceOp>(mit->getDefiningOp());
if (extractSliceOp &&
areEquivalentExtractSliceOps(extractSliceOp, insertSliceOp) &&
getInPlace(extractSliceOp.result()) == InPlaceSpec::True) {
LDBG("\tfound: " << *mit->v.getDefiningOp() << '\n');
LDBG("\tfound: " << *mit->getDefiningOp() << '\n');
return true;
}
}
@ -1231,7 +1231,7 @@ void BufferizationAliasInfo::applyOnEquivalenceClass(
auto leaderIt = equivalentInfo.findLeader(v);
for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
++mit) {
fun(mit->v);
fun(*mit);
}
}