[mlir] Move the operation equivalence out of CSE and into OperationSupport

This provides a general hash and comparison for checking if two operations are equivalent. This revision also optimizes the handling of result types to take advantage of how result types are stored on the operation.

Differential Revision: https://reviews.llvm.org/D79029
This commit is contained in:
River Riddle 2020-04-29 16:09:20 -07:00
parent 108abd2f2e
commit df00e466da
3 changed files with 92 additions and 29 deletions

View File

@ -814,6 +814,20 @@ private:
/// Allow access to `offset_base` and `dereference_iterator`.
friend RangeBaseT;
};
//===----------------------------------------------------------------------===//
// Operation Equivalency
//===----------------------------------------------------------------------===//
/// This class provides utilities for computing if two operations are
/// equivalent.
struct OperationEquivalence {
/// Compute a hash for the given operation.
static llvm::hash_code computeHash(Operation *op);
/// Compare two operations and return if they are equivalent.
static bool isEquivalentTo(Operation *lhs, Operation *rhs);
};
} // end namespace mlir
namespace llvm {

View File

@ -395,3 +395,78 @@ Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
Operation *operation = reinterpret_cast<Operation *>(owner.ptr.get<void *>());
return operation->getResult(owner.startIndex + index);
}
//===----------------------------------------------------------------------===//
// Operation Equivalency
//===----------------------------------------------------------------------===//
llvm::hash_code OperationEquivalence::computeHash(Operation *op) {
// Hash operations based upon their:
// - Operation Name
// - Attributes
llvm::hash_code hash = llvm::hash_combine(
op->getName(), op->getMutableAttrDict().getDictionary());
// - Result Types
ArrayRef<Type> resultTypes = op->getResultTypes();
switch (resultTypes.size()) {
case 0:
// We don't need to add anything to the hash.
break;
case 1:
// Add in the result type.
hash = llvm::hash_combine(hash, resultTypes.front());
break;
default:
// Use the type buffer as the hash, as we can guarantee it is the same for
// any given range of result types. This takes advantage of the fact the
// result types >1 are stored in a TupleType and uniqued.
hash = llvm::hash_combine(hash, resultTypes.data());
break;
}
// - Operands
// TODO: Allow commutative operations to have different ordering.
return llvm::hash_combine(
hash, llvm::hash_combine_range(op->operand_begin(), op->operand_end()));
}
bool OperationEquivalence::isEquivalentTo(Operation *lhs, Operation *rhs) {
if (lhs == rhs)
return true;
// Compare the operation name.
if (lhs->getName() != rhs->getName())
return false;
// Check operand counts.
if (lhs->getNumOperands() != rhs->getNumOperands())
return false;
// Compare attributes.
if (lhs->getMutableAttrDict() != rhs->getMutableAttrDict())
return false;
// Compare result types.
ArrayRef<Type> lhsResultTypes = lhs->getResultTypes();
ArrayRef<Type> rhsResultTypes = rhs->getResultTypes();
if (lhsResultTypes.size() != rhsResultTypes.size())
return false;
switch (lhsResultTypes.size()) {
case 0:
break;
case 1:
// Compare the single result type.
if (lhsResultTypes.front() != rhsResultTypes.front())
return false;
break;
default:
// Use the type buffer for the comparison, as we can guarantee it is the
// same for any given range of result types. This takes advantage of the
// fact the result types >1 are stored in a TupleType and uniqued.
if (lhsResultTypes.data() != rhsResultTypes.data())
return false;
break;
}
// Compare operands.
// TODO: Allow commutative operations to have different ordering.
return std::equal(lhs->operand_begin(), lhs->operand_end(),
rhs->operand_begin());
}

View File

@ -26,19 +26,9 @@
using namespace mlir;
namespace {
// TODO(riverriddle) Handle commutative operations.
struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
static unsigned getHashValue(const Operation *opC) {
auto *op = const_cast<Operation *>(opC);
// Hash the operations based upon their:
// - Operation Name
// - Attributes
// - Result Types
// - Operands
return llvm::hash_combine(
op->getName(), op->getMutableAttrDict().getDictionary(),
op->getResultTypes(),
llvm::hash_combine_range(op->operand_begin(), op->operand_end()));
return OperationEquivalence::computeHash(const_cast<Operation *>(opC));
}
static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
auto *lhs = const_cast<Operation *>(lhsC);
@ -48,24 +38,8 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
rhs == getTombstoneKey() || rhs == getEmptyKey())
return false;
// Compare the operation name.
if (lhs->getName() != rhs->getName())
return false;
// Check operand and result type counts.
if (lhs->getNumOperands() != rhs->getNumOperands() ||
lhs->getNumResults() != rhs->getNumResults())
return false;
// Compare attributes.
if (lhs->getMutableAttrDict() != rhs->getMutableAttrDict())
return false;
// Compare operands.
if (!std::equal(lhs->operand_begin(), lhs->operand_end(),
rhs->operand_begin()))
return false;
// Compare result types.
return std::equal(lhs->result_type_begin(), lhs->result_type_end(),
rhs->result_type_begin());
return OperationEquivalence::isEquivalentTo(const_cast<Operation *>(lhsC),
const_cast<Operation *>(rhsC));
}
};
} // end anonymous namespace