diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 340e84457f68..60c8255db6f3 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -184,12 +184,25 @@ public: return matchFailure(); } + /// Return a list of operations that may be generated when rewriting an + /// operation instance with this pattern. + ArrayRef getGeneratedOps() const { return generatedOps; } + protected: /// Patterns must specify the root operation name they match against, and can /// also specify the benefit of the pattern matching. RewritePattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context) : Pattern(rootName, benefit, context) {} + /// Patterns must specify the root operation name they match against, and can + /// also specify the benefit of the pattern matching. They can also specify + /// the names of operations that may be generated during a successful rewrite. + RewritePattern(StringRef rootName, ArrayRef generatedNames, + PatternBenefit benefit, MLIRContext *context); + + /// A list of the potential operations that may be generated when rewriting + /// an op with this pattern. + llvm::SmallVector generatedOps; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h index f5eb9a37ef04..79d7e9871eb7 100644 --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -172,6 +172,9 @@ public: // Returns true if this DAG node is wrapping native code call. bool isNativeCodeCall() const; + // Returns true if this DAG node is an operation. + bool isOperation() const; + // Returns the native code call template inside this DAG node. // Precondition: isNativeCodeCall() llvm::StringRef getNativeCodeTemplate() const; diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 194307359972..ac2385135dec 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -60,6 +60,20 @@ PatternMatchResult RewritePattern::match(Operation *op) const { llvm_unreachable("need to implement either match or matchAndRewrite!"); } +/// Patterns must specify the root operation name they match against, and can +/// also specify the benefit of the pattern matching. They can also specify the +/// names of operations that may be generated during a successful rewrite. +RewritePattern::RewritePattern(StringRef rootName, + ArrayRef generatedNames, + PatternBenefit benefit, MLIRContext *context) + : Pattern(rootName, benefit, context) { + generatedOps.reserve(generatedNames.size()); + std::transform(generatedNames.begin(), generatedNames.end(), + std::back_inserter(generatedOps), [context](StringRef name) { + return OperationName(name, context); + }); +} + PatternRewriter::~PatternRewriter() { // Out of line to provide a vtable anchor for the class. } diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index 31bab8172e3b..e2ddcbae076e 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -94,6 +94,10 @@ bool tblgen::DagNode::isNativeCodeCall() const { return false; } +bool tblgen::DagNode::isOperation() const { + return !(isNativeCodeCall() || isVerifyUnusedValue() || isReplaceWithValue()); +} + llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const { assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); return cast(node->getOperator()) diff --git a/mlir/test/mlir-tblgen/pattern-benefit.td b/mlir/test/mlir-tblgen/pattern-benefit.td index 61db84b75999..36bc2c7bd9e0 100644 --- a/mlir/test/mlir-tblgen/pattern-benefit.td +++ b/mlir/test/mlir-tblgen/pattern-benefit.td @@ -26,9 +26,9 @@ def Z_AddOp : NS_Op<"add"> { def bena : Pat<(X_AddOp (X_AddOp $lhs, $rhs), $rhs), (Y_AddOp $lhs, $rhs, $rhs)>; // CHECK-LABEL: struct bena -// CHECK: RewritePattern("x.add", 2, context) {} +// CHECK: RewritePattern("x.add", {"x.add"}, 2, context) {} def benb : Pat<(X_AddOp $lhs, $rhs), (Z_AddOp $lhs), [(IfEqual $lhs, $rhs)], (addBenefit 100)>; // CHECK-LABEL: struct benb -// CHECK: RewritePattern("x.add", 101, context) {} +// CHECK: RewritePattern("x.add", {"x.add"}, 101, context) {} diff --git a/mlir/test/mlir-tblgen/pattern.td b/mlir/test/mlir-tblgen/pattern.td index b5a6c60731c9..66ff381e0b29 100644 --- a/mlir/test/mlir-tblgen/pattern.td +++ b/mlir/test/mlir-tblgen/pattern.td @@ -34,7 +34,7 @@ def : Pat<(OpA $input, $attr), (OpB $input, $attr)>; // CHECK: struct GeneratedConvert0 : public RewritePattern -// CHECK: GeneratedConvert0(MLIRContext *context) : RewritePattern("op_a", 1, context) {} +// CHECK: GeneratedConvert0(MLIRContext *context) : RewritePattern("op_a", {"op_b"}, 1, context) {} // CHECK: struct MatchedState : public PatternState { // CHECK: Value *input; diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index c29cf54cc95c..9103cb0c1f36 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -153,6 +153,9 @@ private: // Emits the match() method. void emitMatchMethod(DagNode tree); + // Collects all of the operations within the given dag tree. + void collectOps(DagNode tree, llvm::SmallPtrSetImpl &ops); + // Emits the rewrite() method. void emitRewriteMethod(); @@ -443,6 +446,18 @@ void PatternEmitter::emitMatchMethod(DagNode tree) { os.indent(4) << "return matchSuccess(std::move(state));\n }\n"; } +void PatternEmitter::collectOps(DagNode tree, + llvm::SmallPtrSetImpl &ops) { + // Check if this tree is an operation. + if (tree.isOperation()) + ops.insert(&tree.getDialectOp(opMap)); + + // Recurse the arguments of the tree. + for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) + if (auto child = tree.getArgAsNestedDag(i)) + collectOps(child, ops); +} + void PatternEmitter::emit(StringRef rewriteName) { // Get the DAG tree for the source pattern DagNode tree = pattern.getSourcePattern(); @@ -454,14 +469,22 @@ void PatternEmitter::emit(StringRef rewriteName) { PrintFatalError( loc, "replacing op with variadic results not supported right now"); + // Collect the set of result operations. + llvm::SmallPtrSet results; + for (unsigned i = 0, e = pattern.getNumResults(); i != e; ++i) + collectOps(pattern.getResultPattern(i), results); + // Emit RewritePattern for Pattern. auto locs = pattern.getLocation(); os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n", make_range(locs.rbegin(), locs.rend())); os << formatv(R"(struct {0} : public RewritePattern { - {0}(MLIRContext *context) : RewritePattern("{1}", {2}, context) {{})", - rewriteName, rootName, pattern.getBenefit()) - << "\n"; + {0}(MLIRContext *context) : RewritePattern("{1}", {{)", + rewriteName, rootName); + interleaveComma(results, os, [&](const Operator *op) { + os << '"' << op->getOperationName() << '"'; + }); + os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n"; // Emit matched state. os << " struct MatchedState : public PatternState {\n";