diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h index 80a38329a33b..6329b7386b28 100644 --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -183,6 +183,10 @@ public: // Returns the DAG tree root node of the `index`-th result pattern. DagNode getResultPattern(unsigned index) const; + // Returns true if an argument with the given `name` is bound in source + // pattern. + bool isArgBoundInSourcePattern(llvm::StringRef name) const; + // Checks whether an argument with the given `name` is bound in source // pattern. Prints fatal error if not; does nothing otherwise. void ensureArgBoundInSourcePattern(llvm::StringRef name) const; diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index 5262141e753b..1662fcc867bb 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -183,9 +183,13 @@ tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const { return tblgen::DagNode(cast(results->getElement(index))); } +bool tblgen::Pattern::isArgBoundInSourcePattern(llvm::StringRef name) const { + return boundArguments.find(name) != boundArguments.end(); +} + void tblgen::Pattern::ensureArgBoundInSourcePattern( llvm::StringRef name) const { - if (boundArguments.find(name) == boundArguments.end()) + if (!isArgBoundInSourcePattern(name)) PrintFatalError(def.getLoc(), Twine("referencing unbound variable '") + name + "'"); } diff --git a/mlir/test/mlir-tblgen/one-op-one-result.td b/mlir/test/mlir-tblgen/one-op-one-result.td index 25023c891684..a43226c41de1 100644 --- a/mlir/test/mlir-tblgen/one-op-one-result.td +++ b/mlir/test/mlir-tblgen/one-op-one-result.td @@ -24,6 +24,6 @@ def : Pat<(X_AddOp $lhs, $rhs), (Y_AddOp $lhs, U:$rhs, T_Const_Attr:$x)>; // CHECK: PatternMatchResult match(Instruction * // CHECK: void rewrite(Instruction *op, std::unique_ptr // CHECK: PatternRewriter &rewriter) -// CHECK: rewriter.replaceOpWithNewOp(op, op->getResult(0)->getType() +// CHECK: rewriter.create(loc, op->getResult(0)->getType() // CHECK: void populateWithGenerated // CHECK: patterns->push_back(std::make_unique(context)) diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 1e6eecb3ed7b..9d23b9f24644 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -52,8 +52,7 @@ public: raw_ostream &os); private: - PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os) - : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), os(os) {} + PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os); // Emits the mlir::RewritePattern struct named `rewriteName`. void emit(StringRef rewriteName); @@ -64,14 +63,6 @@ private: // Emits the rewrite() method. void emitRewriteMethod(); - // Emits the C++ statement to replace the matched DAG with a new op. - void emitReplaceOpWithNewOp(DagNode resultTree); - // Emits the C++ statement to replace the matched DAG with an existing value. - void emitReplaceWithExistingValue(DagNode resultTree); - // Emits the C++ statement to replace the matched DAG with a native C++ built - // value. - void emitReplaceWithNativeBuilder(DagNode resultTree); - // Emits the value of constant attribute to `os`. void emitConstantAttr(tblgen::ConstantAttr constAttr); @@ -86,6 +77,30 @@ private: // `tree` as an attribute. void emitAttributeMatch(DagNode tree, int index, int depth, int indent); + // Returns a unique name for an value of the given `op`. + std::string getUniqueValueName(const Operator *op); + + // Entry point for handling a rewrite pattern rooted at `resultTree` and + // dispatches to concrete handlers. The given tree is the `resultIndex`-th + // argument of the enclosing DAG. + std::string handleRewritePattern(DagNode resultTree, int resultIndex, + int depth, llvm::StringRef treeName = ""); + + // Emits the C++ statement to replace the matched DAG with a native C++ built + // value. + std::string emitReplaceWithNativeBuilder(DagNode resultTree); + + // Returns the C++ expression referencing the old value serving as the + // replacement. + std::string handleReplaceWithValue(DagNode tree); + + // Emits the C++ statement to build a new op out of the given DAG `tree` and + // returns the variable name that this op is assigned to. If `treeName` is not + // empty, the created op will be assigned to a variable of the given + // `treeName`. Otherwise, a unique name will be used as the result value name. + std::string emitOpCreate(DagNode tree, int resultIndex, int depth, + llvm::StringRef treeName = ""); + private: // Pattern instantiation location followed by the location of multiclass // prototypes used. This is intended to be used as a whole to @@ -95,10 +110,17 @@ private: RecordOperatorMap *opMap; // Handy wrapper for pattern being emitted tblgen::Pattern pattern; + // The next unused ID for newly created values + unsigned nextValueId; raw_ostream &os; }; } // end namespace +PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, + raw_ostream &os) + : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), nextValueId(0), + os(os) {} + void PatternEmitter::emitConstantAttr(tblgen::ConstantAttr constAttr) { auto attr = constAttr.getAttribute(); @@ -283,31 +305,77 @@ void PatternEmitter::emitRewriteMethod() { void rewrite(Instruction *op, std::unique_ptr state, PatternRewriter &rewriter) const override { auto& s = *static_cast(state.get()); + auto loc = op->getLoc(); (void)loc; )"; - if (resultTree.isNativeCodeBuilder()) - emitReplaceWithNativeBuilder(resultTree); - else if (resultTree.isReplaceWithValue()) - emitReplaceWithExistingValue(resultTree); - else - emitReplaceOpWithNewOp(resultTree); + std::string resultValue = + handleRewritePattern(resultTree, /*resultIndex=*/0, /*depth=*/0); - os << " }\n"; + os.indent(4) << "rewriter.replaceOp(op, {" << resultValue; + os << "});\n }\n"; } -void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) { - Operator &resultOp = resultTree.getDialectOp(opMap); +std::string PatternEmitter::getUniqueValueName(const Operator *op) { + return formatv("v{0}{1}", op->getCppClassName(), nextValueId++); +} + +std::string PatternEmitter::handleRewritePattern(DagNode resultTree, + int resultIndex, int depth, + llvm::StringRef treeName) { + if (resultTree.isNativeCodeBuilder()) + return emitReplaceWithNativeBuilder(resultTree); + + if (resultTree.isReplaceWithValue()) + return handleReplaceWithValue(resultTree); + + return emitOpCreate(resultTree, resultIndex, depth, treeName); +} + +std::string PatternEmitter::handleReplaceWithValue(DagNode tree) { + assert(tree.isReplaceWithValue()); + + if (tree.getNumArgs() != 1) { + PrintFatalError( + loc, "replaceWithValue directive must take exactly one argument"); + } + + auto name = tree.getArgName(0); + pattern.ensureArgBoundInSourcePattern(name); + + // We are referencing some bound value in the source pattern. Those values are + // grouped into a transient struct named as `s`. + return std::string("s.") + name.str(); +} + +std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex, + int depth, StringRef treeName) { + Operator &resultOp = tree.getDialectOp(opMap); auto numOpArgs = resultOp.getNumOperands() + resultOp.getNumNativeAttributes(); - os << formatv(R"( - rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())", - resultOp.getCppClassName()); - if (numOpArgs != resultTree.getNumArgs()) { + if (numOpArgs != tree.getNumArgs()) { PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: " "{1} in pattern vs. {2} in definition", - resultOp.getOperationName(), - resultTree.getNumArgs(), numOpArgs)); + resultOp.getOperationName(), tree.getNumArgs(), + numOpArgs)); + } + + std::string resultValue = + treeName.empty() ? getUniqueValueName(&resultOp) : treeName.str(); + + // TODO: this is a hack to support various constant ops. We are assuming + // all of them have no operands and one attribute here. Figure out a better + // way to do this. + if (resultOp.getNumOperands() == 0 && + resultOp.getNumNativeAttributes() == 1) { + os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc", resultValue, + resultOp.getQualCppClassName()); + } else { + std::string resultType = formatv("op->getResult({0})", resultIndex).str(); + + os.indent(4) << formatv( + "auto {0} = rewriter.create<{1}>(loc, {2}->getType()", resultValue, + resultOp.getQualCppClassName(), resultType); } // Create the builder call for the result. @@ -317,7 +385,7 @@ void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) { // Start each operand on its own line. (os << ",\n").indent(6); - auto name = resultTree.getArgName(i); + auto name = tree.getArgName(i); pattern.ensureArgBoundInSourcePattern(name); if (!operand.name.empty()) os << "/*" << operand.name << "=*/"; @@ -327,13 +395,13 @@ void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) { } // Add attributes. - for (int e = resultTree.getNumArgs(); i != e; ++i) { + for (int e = tree.getNumArgs(); i != e; ++i) { // Start each attribute on its own line. (os << ",\n").indent(6); - auto leaf = resultTree.getArgAsLeaf(i); + auto leaf = tree.getArgAsLeaf(i); // The argument in the result DAG pattern. - auto patArgName = resultTree.getArgName(i); + auto patArgName = tree.getArgName(i); // The argument in the op definition. auto opArgName = resultOp.getArgName(i); @@ -360,20 +428,16 @@ void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) { } } os << "\n );\n"; + + return resultValue; } -void PatternEmitter::emitReplaceWithExistingValue(DagNode resultTree) { - if (resultTree.getNumArgs() != 1) { - PrintFatalError(loc, "exactly one argument needed in the result pattern"); - } +std::string PatternEmitter::emitReplaceWithNativeBuilder(DagNode resultTree) { + // The variable's name for holding the result of this native builder call + std::string value = formatv("v{0}", nextValueId++).str(); - auto name = resultTree.getArgName(0); - pattern.ensureArgBoundInSourcePattern(name); - os.indent(4) << "rewriter.replaceOp(op, {s." << name << "});\n"; -} - -void PatternEmitter::emitReplaceWithNativeBuilder(DagNode resultTree) { - os.indent(4) << resultTree.getNativeCodeBuilder() << "(op, {"; + os.indent(4) << "auto " << value << " = " << resultTree.getNativeCodeBuilder() + << "(op, {"; const auto &boundedValues = pattern.getSourcePatternBoundArgs(); bool first = true; bool printingAttr = false; @@ -394,6 +458,8 @@ void PatternEmitter::emitReplaceWithNativeBuilder(DagNode resultTree) { if (!printingAttr) os << "},{"; os << "}, rewriter);\n"; + + return value; } void PatternEmitter::emit(StringRef rewriteName, Record *p,