forked from OSchip/llvm-project
[mlir] m_Constant()
Summary: Introduce m_Constant() which allows matching a constant operation without forcing the user also to capture the attribute value. Differential Revision: https://reviews.llvm.org/D72397
This commit is contained in:
parent
202ab273e6
commit
81e7922e83
|
@ -56,6 +56,8 @@ template <typename AttrT> struct constant_op_binder {
|
||||||
/// Creates a matcher instance that binds the constant attribute value to
|
/// Creates a matcher instance that binds the constant attribute value to
|
||||||
/// bind_value if match succeeds.
|
/// bind_value if match succeeds.
|
||||||
constant_op_binder(AttrT *bind_value) : bind_value(bind_value) {}
|
constant_op_binder(AttrT *bind_value) : bind_value(bind_value) {}
|
||||||
|
/// Creates a matcher instance that doesn't bind if match succeeds.
|
||||||
|
constant_op_binder() : bind_value(nullptr) {}
|
||||||
|
|
||||||
bool match(Operation *op) {
|
bool match(Operation *op) {
|
||||||
if (op->getNumOperands() > 0 || op->getNumResults() != 1)
|
if (op->getNumOperands() > 0 || op->getNumResults() != 1)
|
||||||
|
@ -66,10 +68,13 @@ template <typename AttrT> struct constant_op_binder {
|
||||||
SmallVector<OpFoldResult, 1> foldedOp;
|
SmallVector<OpFoldResult, 1> foldedOp;
|
||||||
if (succeeded(op->fold(/*operands=*/llvm::None, foldedOp))) {
|
if (succeeded(op->fold(/*operands=*/llvm::None, foldedOp))) {
|
||||||
if (auto attr = foldedOp.front().dyn_cast<Attribute>()) {
|
if (auto attr = foldedOp.front().dyn_cast<Attribute>()) {
|
||||||
if ((*bind_value = attr.dyn_cast<AttrT>()))
|
if (auto attrT = attr.dyn_cast<AttrT>()) {
|
||||||
|
if (bind_value)
|
||||||
|
*bind_value = attrT;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -196,6 +201,11 @@ struct RecursivePatternMatcher {
|
||||||
|
|
||||||
} // end namespace detail
|
} // end namespace detail
|
||||||
|
|
||||||
|
/// Matches a constant foldable operation.
|
||||||
|
inline detail::constant_op_binder<Attribute> m_Constant() {
|
||||||
|
return detail::constant_op_binder<Attribute>();
|
||||||
|
}
|
||||||
|
|
||||||
/// Matches a value from a constant foldable operation and writes the value to
|
/// Matches a value from a constant foldable operation and writes the value to
|
||||||
/// bind_value.
|
/// bind_value.
|
||||||
template <typename AttrT>
|
template <typename AttrT>
|
||||||
|
|
|
@ -342,8 +342,7 @@ LogicalResult OpBuilder::tryFold(Operation *op,
|
||||||
};
|
};
|
||||||
|
|
||||||
// If this operation is already a constant, there is nothing to do.
|
// If this operation is already a constant, there is nothing to do.
|
||||||
Attribute unused;
|
if (matchPattern(op, m_Constant()))
|
||||||
if (matchPattern(op, m_Constant(&unused)))
|
|
||||||
return cleanupFailure();
|
return cleanupFailure();
|
||||||
|
|
||||||
// Check to see if any operands to the operation is constant and whether
|
// Check to see if any operands to the operation is constant and whether
|
||||||
|
|
|
@ -40,3 +40,4 @@ func @test2(%a: f32) -> f32 {
|
||||||
|
|
||||||
// CHECK-LABEL: test2
|
// CHECK-LABEL: test2
|
||||||
// CHECK: Pattern add(add(a, constant), a) matched and bound constant to: 1.000000e+00
|
// CHECK: Pattern add(add(a, constant), a) matched and bound constant to: 1.000000e+00
|
||||||
|
// CHECK: Pattern add(add(a, constant), a) matched
|
||||||
|
|
|
@ -126,12 +126,15 @@ void test2(FuncOp f) {
|
||||||
auto a = m_Val(f.getArgument(0));
|
auto a = m_Val(f.getArgument(0));
|
||||||
FloatAttr floatAttr;
|
FloatAttr floatAttr;
|
||||||
auto p = m_Op<MulFOp>(a, m_Op<AddFOp>(a, m_Constant(&floatAttr)));
|
auto p = m_Op<MulFOp>(a, m_Op<AddFOp>(a, m_Constant(&floatAttr)));
|
||||||
|
auto p1 = m_Op<MulFOp>(a, m_Op<AddFOp>(a, m_Constant()));
|
||||||
// Last operation that is not the terminator.
|
// Last operation that is not the terminator.
|
||||||
Operation *lastOp = f.getBody().front().back().getPrevNode();
|
Operation *lastOp = f.getBody().front().back().getPrevNode();
|
||||||
if (p.match(lastOp))
|
if (p.match(lastOp))
|
||||||
llvm::outs()
|
llvm::outs()
|
||||||
<< "Pattern add(add(a, constant), a) matched and bound constant to: "
|
<< "Pattern add(add(a, constant), a) matched and bound constant to: "
|
||||||
<< floatAttr.getValueAsDouble() << "\n";
|
<< floatAttr.getValueAsDouble() << "\n";
|
||||||
|
if (p1.match(lastOp))
|
||||||
|
llvm::outs() << "Pattern add(add(a, constant), a) matched\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestMatchers::runOnFunction() {
|
void TestMatchers::runOnFunction() {
|
||||||
|
|
Loading…
Reference in New Issue