forked from OSchip/llvm-project
[mlir] Move the operation equivalence out of CSE and into OperationSupport
This provides a general hash and comparison for checking if two operations are equivalent. This revision also optimizes the handling of result types to take advantage of how result types are stored on the operation. Differential Revision: https://reviews.llvm.org/D79029
This commit is contained in:
parent
108abd2f2e
commit
df00e466da
|
@ -814,6 +814,20 @@ private:
|
|||
/// Allow access to `offset_base` and `dereference_iterator`.
|
||||
friend RangeBaseT;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operation Equivalency
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class provides utilities for computing if two operations are
|
||||
/// equivalent.
|
||||
struct OperationEquivalence {
|
||||
/// Compute a hash for the given operation.
|
||||
static llvm::hash_code computeHash(Operation *op);
|
||||
|
||||
/// Compare two operations and return if they are equivalent.
|
||||
static bool isEquivalentTo(Operation *lhs, Operation *rhs);
|
||||
};
|
||||
} // end namespace mlir
|
||||
|
||||
namespace llvm {
|
||||
|
|
|
@ -395,3 +395,78 @@ Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
|
|||
Operation *operation = reinterpret_cast<Operation *>(owner.ptr.get<void *>());
|
||||
return operation->getResult(owner.startIndex + index);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operation Equivalency
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
llvm::hash_code OperationEquivalence::computeHash(Operation *op) {
|
||||
// Hash operations based upon their:
|
||||
// - Operation Name
|
||||
// - Attributes
|
||||
llvm::hash_code hash = llvm::hash_combine(
|
||||
op->getName(), op->getMutableAttrDict().getDictionary());
|
||||
|
||||
// - Result Types
|
||||
ArrayRef<Type> resultTypes = op->getResultTypes();
|
||||
switch (resultTypes.size()) {
|
||||
case 0:
|
||||
// We don't need to add anything to the hash.
|
||||
break;
|
||||
case 1:
|
||||
// Add in the result type.
|
||||
hash = llvm::hash_combine(hash, resultTypes.front());
|
||||
break;
|
||||
default:
|
||||
// Use the type buffer as the hash, as we can guarantee it is the same for
|
||||
// any given range of result types. This takes advantage of the fact the
|
||||
// result types >1 are stored in a TupleType and uniqued.
|
||||
hash = llvm::hash_combine(hash, resultTypes.data());
|
||||
break;
|
||||
}
|
||||
|
||||
// - Operands
|
||||
// TODO: Allow commutative operations to have different ordering.
|
||||
return llvm::hash_combine(
|
||||
hash, llvm::hash_combine_range(op->operand_begin(), op->operand_end()));
|
||||
}
|
||||
|
||||
bool OperationEquivalence::isEquivalentTo(Operation *lhs, Operation *rhs) {
|
||||
if (lhs == rhs)
|
||||
return true;
|
||||
|
||||
// Compare the operation name.
|
||||
if (lhs->getName() != rhs->getName())
|
||||
return false;
|
||||
// Check operand counts.
|
||||
if (lhs->getNumOperands() != rhs->getNumOperands())
|
||||
return false;
|
||||
// Compare attributes.
|
||||
if (lhs->getMutableAttrDict() != rhs->getMutableAttrDict())
|
||||
return false;
|
||||
// Compare result types.
|
||||
ArrayRef<Type> lhsResultTypes = lhs->getResultTypes();
|
||||
ArrayRef<Type> rhsResultTypes = rhs->getResultTypes();
|
||||
if (lhsResultTypes.size() != rhsResultTypes.size())
|
||||
return false;
|
||||
switch (lhsResultTypes.size()) {
|
||||
case 0:
|
||||
break;
|
||||
case 1:
|
||||
// Compare the single result type.
|
||||
if (lhsResultTypes.front() != rhsResultTypes.front())
|
||||
return false;
|
||||
break;
|
||||
default:
|
||||
// Use the type buffer for the comparison, as we can guarantee it is the
|
||||
// same for any given range of result types. This takes advantage of the
|
||||
// fact the result types >1 are stored in a TupleType and uniqued.
|
||||
if (lhsResultTypes.data() != rhsResultTypes.data())
|
||||
return false;
|
||||
break;
|
||||
}
|
||||
// Compare operands.
|
||||
// TODO: Allow commutative operations to have different ordering.
|
||||
return std::equal(lhs->operand_begin(), lhs->operand_end(),
|
||||
rhs->operand_begin());
|
||||
}
|
||||
|
|
|
@ -26,19 +26,9 @@
|
|||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
// TODO(riverriddle) Handle commutative operations.
|
||||
struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
|
||||
static unsigned getHashValue(const Operation *opC) {
|
||||
auto *op = const_cast<Operation *>(opC);
|
||||
// Hash the operations based upon their:
|
||||
// - Operation Name
|
||||
// - Attributes
|
||||
// - Result Types
|
||||
// - Operands
|
||||
return llvm::hash_combine(
|
||||
op->getName(), op->getMutableAttrDict().getDictionary(),
|
||||
op->getResultTypes(),
|
||||
llvm::hash_combine_range(op->operand_begin(), op->operand_end()));
|
||||
return OperationEquivalence::computeHash(const_cast<Operation *>(opC));
|
||||
}
|
||||
static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
|
||||
auto *lhs = const_cast<Operation *>(lhsC);
|
||||
|
@ -48,24 +38,8 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
|
|||
if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
|
||||
rhs == getTombstoneKey() || rhs == getEmptyKey())
|
||||
return false;
|
||||
|
||||
// Compare the operation name.
|
||||
if (lhs->getName() != rhs->getName())
|
||||
return false;
|
||||
// Check operand and result type counts.
|
||||
if (lhs->getNumOperands() != rhs->getNumOperands() ||
|
||||
lhs->getNumResults() != rhs->getNumResults())
|
||||
return false;
|
||||
// Compare attributes.
|
||||
if (lhs->getMutableAttrDict() != rhs->getMutableAttrDict())
|
||||
return false;
|
||||
// Compare operands.
|
||||
if (!std::equal(lhs->operand_begin(), lhs->operand_end(),
|
||||
rhs->operand_begin()))
|
||||
return false;
|
||||
// Compare result types.
|
||||
return std::equal(lhs->result_type_begin(), lhs->result_type_end(),
|
||||
rhs->result_type_begin());
|
||||
return OperationEquivalence::isEquivalentTo(const_cast<Operation *>(lhsC),
|
||||
const_cast<Operation *>(rhsC));
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
|
Loading…
Reference in New Issue