diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h index d8d3308c7f08..170984b8550e 100644 --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -56,6 +56,8 @@ template struct constant_op_binder { /// Creates a matcher instance that binds the constant attribute value to /// bind_value if match succeeds. constant_op_binder(AttrT *bind_value) : bind_value(bind_value) {} + /// Creates a matcher instance that doesn't bind if match succeeds. + constant_op_binder() : bind_value(nullptr) {} bool match(Operation *op) { if (op->getNumOperands() > 0 || op->getNumResults() != 1) @@ -66,8 +68,11 @@ template struct constant_op_binder { SmallVector foldedOp; if (succeeded(op->fold(/*operands=*/llvm::None, foldedOp))) { if (auto attr = foldedOp.front().dyn_cast()) { - if ((*bind_value = attr.dyn_cast())) + if (auto attrT = attr.dyn_cast()) { + if (bind_value) + *bind_value = attrT; return true; + } } } return false; @@ -196,6 +201,11 @@ struct RecursivePatternMatcher { } // end namespace detail +/// Matches a constant foldable operation. +inline detail::constant_op_binder m_Constant() { + return detail::constant_op_binder(); +} + /// Matches a value from a constant foldable operation and writes the value to /// bind_value. template diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 0c72abf9a5e3..506636682758 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -342,8 +342,7 @@ LogicalResult OpBuilder::tryFold(Operation *op, }; // If this operation is already a constant, there is nothing to do. - Attribute unused; - if (matchPattern(op, m_Constant(&unused))) + if (matchPattern(op, m_Constant())) return cleanupFailure(); // Check to see if any operands to the operation is constant and whether diff --git a/mlir/test/IR/test-matchers.mlir b/mlir/test/IR/test-matchers.mlir index 7808f25a2f8e..60d5bcf7d81b 100644 --- a/mlir/test/IR/test-matchers.mlir +++ b/mlir/test/IR/test-matchers.mlir @@ -40,3 +40,4 @@ func @test2(%a: f32) -> f32 { // CHECK-LABEL: test2 // CHECK: Pattern add(add(a, constant), a) matched and bound constant to: 1.000000e+00 +// CHECK: Pattern add(add(a, constant), a) matched diff --git a/mlir/test/lib/IR/TestMatchers.cpp b/mlir/test/lib/IR/TestMatchers.cpp index b62daa8437c6..6061b251d724 100644 --- a/mlir/test/lib/IR/TestMatchers.cpp +++ b/mlir/test/lib/IR/TestMatchers.cpp @@ -126,12 +126,15 @@ void test2(FuncOp f) { auto a = m_Val(f.getArgument(0)); FloatAttr floatAttr; auto p = m_Op(a, m_Op(a, m_Constant(&floatAttr))); + auto p1 = m_Op(a, m_Op(a, m_Constant())); // Last operation that is not the terminator. Operation *lastOp = f.getBody().front().back().getPrevNode(); if (p.match(lastOp)) llvm::outs() << "Pattern add(add(a, constant), a) matched and bound constant to: " << floatAttr.getValueAsDouble() << "\n"; + if (p1.match(lastOp)) + llvm::outs() << "Pattern add(add(a, constant), a) matched\n"; } void TestMatchers::runOnFunction() {