diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h index b6381f31a085..22bc7b303fa2 100644 --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -205,23 +205,15 @@ public: // Returns the DAG tree root node of the `index`-th result pattern. DagNode getResultPattern(unsigned index) const; - // Returns true if an argument with the given `name` is bound in source - // pattern. - bool isArgBoundInSourcePattern(llvm::StringRef name) const; - - // Returns true if an argument with the given `name` is bound as result of - // op in pattern. - bool isResultBoundInSourcePattern(llvm::StringRef name) const; - - // Checks whether an argument with the given `name` is bound in source - // pattern. Prints fatal error if not; does nothing otherwise. - void ensureArgBoundInSourcePattern(llvm::StringRef name) const; + // Checks whether an argument or op with the given `name` is bound in + // source pattern. Prints fatal error if not; does nothing otherwise. + void ensureBoundInSourcePattern(llvm::StringRef name) const; // Returns a reference to all the bound arguments in the source pattern. llvm::StringMap &getSourcePatternBoundArgs(); - // Returns a reference to all the bound results in the source pattern. - llvm::StringSet<> &getSourcePatternBoundResults(); + // Returns a reference to all the bound ops in the source pattern. + llvm::StringSet<> &getSourcePatternBoundOps(); // Returns the op that the root node of the source pattern matches. const Operator &getSourceRootOp(); @@ -247,11 +239,11 @@ private: // All operators. RecordOperatorMap *recordOpMap; - // All bound arguments. + // All bound op arguments. llvm::StringMap boundArguments; - // All bound results. - llvm::StringSet<> boundResults; + // All bound ops. + llvm::StringSet<> boundOps; }; } // end namespace tblgen diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index 794b772c601a..92267d140357 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -186,17 +186,9 @@ tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const { return tblgen::DagNode(cast(results->getElement(index))); } -bool tblgen::Pattern::isArgBoundInSourcePattern(llvm::StringRef name) const { - return boundArguments.find(name) != boundArguments.end(); -} - -bool tblgen::Pattern::isResultBoundInSourcePattern(llvm::StringRef name) const { - return boundResults.count(name); -} - -void tblgen::Pattern::ensureArgBoundInSourcePattern( - llvm::StringRef name) const { - if (!isArgBoundInSourcePattern(name)) +void tblgen::Pattern::ensureBoundInSourcePattern(llvm::StringRef name) const { + if (boundArguments.find(name) == boundArguments.end() && + boundOps.find(name) == boundOps.end()) PrintFatalError(def.getLoc(), Twine("referencing unbound variable '") + name + "'"); } @@ -206,8 +198,8 @@ tblgen::Pattern::getSourcePatternBoundArgs() { return boundArguments; } -llvm::StringSet<> &tblgen::Pattern::getSourcePatternBoundResults() { - return boundResults; +llvm::StringSet<> &tblgen::Pattern::getSourcePatternBoundOps() { + return boundOps; } const tblgen::Operator &tblgen::Pattern::getSourceRootOp() { @@ -268,7 +260,7 @@ void tblgen::Pattern::collectBoundArguments(DagNode tree) { // results generated from this op. It should be remembered as bound results. auto treeName = tree.getOpName(); if (!treeName.empty()) - boundResults.insert(treeName); + boundOps.insert(treeName); // TODO(jpienaar): Expand to multiple matches. for (unsigned i = 0; i != numTreeArgs; ++i) { diff --git a/mlir/test/mlir-tblgen/pattern-bound-symbol.td b/mlir/test/mlir-tblgen/pattern-bound-symbol.td index 7c2897e560c9..0d1a53eecc1b 100644 --- a/mlir/test/mlir-tblgen/pattern-bound-symbol.td +++ b/mlir/test/mlir-tblgen/pattern-bound-symbol.td @@ -18,7 +18,7 @@ def OpC : Op<"op_c", []> { } def OpD : Op<"op_d", []> { - let arguments = (ins I32:$input1, I32:$input2, I32Attr:$attr); + let arguments = (ins I32:$input1, I32:$input2, I32:$input3, I32Attr:$attr); let results = (outs I32:$result); } @@ -26,7 +26,7 @@ def hasOneUse: ConstrainthasOneUse()">, "has one use">; def : Pattern<(OpA:$res_a $operand, $attr), [(OpC:$res_c (OpB:$res_b $operand)), - (OpD $res_b, $res_c, $attr)], + (OpD $res_b, $res_c, $res_a, $attr)], [(hasOneUse $res_a)]>; // CHECK-LABEL: GeneratedConvert0 @@ -34,20 +34,20 @@ def : Pattern<(OpA:$res_a $operand, $attr), // Test struct for bound arguments // --- // CHECK: struct MatchedState : public PatternState -// CHECK: Value* operand; +// CHECK: Value *operand; // CHECK: IntegerAttr attr; +// CHECK: Operation *res_a; // Test bound arguments/results in source pattern // --- // CHECK: PatternMatchResult match // CHECK: auto state = llvm::make_unique(); // CHECK: auto &s = *state; -// CHECK: mlir::Operation* tblgen_res_a; (void)tblgen_res_a; -// CHECK: tblgen_res_a = op0; +// CHECK: s.res_a = op0; // CHECK: s.operand = op0->getOperand(0); // CHECK: attr = op0->getAttrOfType("attr"); // CHECK: s.attr = attr; -// CHECK: if (!(tblgen_res_a->hasOneUse())) return matchFailure(); +// CHECK: if (!(s.res_a->hasOneUse())) return matchFailure(); // Test bound results in result pattern // --- @@ -59,4 +59,5 @@ def : Pattern<(OpA:$res_a $operand, $attr), // CHECK: auto vOpD0 = rewriter.create( // CHECK: /*input1=*/res_b, // CHECK: /*input2=*/res_c, +// CHECK: /*input3=*/s.res_a, // CHECK: /*attr=*/s.attr diff --git a/mlir/test/mlir-tblgen/pattern.td b/mlir/test/mlir-tblgen/pattern.td index c3c4b201f3b3..1ca9643cff14 100644 --- a/mlir/test/mlir-tblgen/pattern.td +++ b/mlir/test/mlir-tblgen/pattern.td @@ -29,7 +29,7 @@ def : Pat<(OpA $input, $attr), (OpB $input, $attr)>; // CHECK: GeneratedConvert0(MLIRContext *context) : RewritePattern("op_a", 1, context) {} // CHECK: struct MatchedState : public PatternState { -// CHECK: Value* input; +// CHECK: Value *input; // CHECK: IntegerAttr attr; // CHECK: }; diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index b77c93b0aaf9..797d3367efa0 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -41,19 +41,12 @@ using namespace llvm; using namespace mlir; using namespace mlir::tblgen; -static const char *const tblgenNamePrefix = "tblgen_"; - -// Returns the bound value for the given op result `symbol`. -static Twine getBoundResult(const StringRef &symbol) { - return tblgenNamePrefix + symbol; -} - -// Returns the bound value for the given op argument `symbol`. +// Returns the bound symbol for the given op argument or op named `symbol`. // -// Arguments bound in the source pattern are grouped into a transient -// `PatternState` struct. This struct can be accessed in both `match()` and -// `rewrite()` via the local variable named as `s`. -static Twine getBoundArgument(const StringRef &symbol) { +// Arguments and ops bound in the source pattern are grouped into a +// transient `PatternState` struct. This struct can be accessed in both +// `match()` and `rewrite()` via the local variable named as `s`. +static Twine getBoundSymbol(const StringRef &symbol) { return Twine("s.") + symbol; } @@ -85,7 +78,7 @@ namespace { class PatternSymbolResolver { public: PatternSymbolResolver(const StringMap &srcArgs, - const StringSet<> &srcResults); + const StringSet<> &srcOperations); // Marks the given `symbol` as bound. Returns false if the `symbol` is // already bound. @@ -105,8 +98,8 @@ private: } // end anonymous namespace PatternSymbolResolver::PatternSymbolResolver(const StringMap &srcArgs, - const StringSet<> &srcResults) - : sourceArguments(srcArgs), sourceOps(srcResults) {} + const StringSet<> &srcOperations) + : sourceArguments(srcArgs), sourceOps(srcOperations) {} bool PatternSymbolResolver::add(StringRef symbol) { return resultOps.insert(symbol).second; @@ -121,12 +114,12 @@ std::string PatternSymbolResolver::query(StringRef symbol) const { { auto it = sourceArguments.find(symbol); if (it != sourceArguments.end()) - return getBoundArgument(symbol).str(); + return getBoundSymbol(symbol).str(); } { auto it = sourceOps.find(symbol); if (it != sourceOps.end()) - return getBoundResult(symbol).str(); + return getBoundSymbol(symbol).str(); } return {}; } @@ -235,7 +228,7 @@ PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os) : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), symbolResolver(pattern.getSourcePatternBoundArgs(), - pattern.getSourcePatternBoundResults()), + pattern.getSourcePatternBoundOps()), nextValueId(0), os(os) { matchCtx.withBuilder("mlir::Builder(ctx)"); rewriteCtx.withBuilder("rewriter"); @@ -274,7 +267,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) { // If the operand's name is set, set to that variable. auto name = tree.getOpName(); if (!name.empty()) - os.indent(indent) << formatv("{0} = op{1};\n", getBoundResult(name), depth); + os.indent(indent) << formatv("{0} = op{1};\n", getBoundSymbol(name), depth); for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { auto opArg = op.getArg(i); @@ -330,7 +323,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth, // Capture the value auto name = tree.getArgName(index); if (!name.empty()) { - os.indent(indent) << getBoundArgument(name) << " = op" << depth + os.indent(indent) << getBoundSymbol(name) << " = op" << depth << "->getOperand(" << index << ");\n"; } } @@ -380,7 +373,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth, // Capture the value auto name = tree.getArgName(index); if (!name.empty()) { - os.indent(indent) << getBoundArgument(name) << " = attr;\n"; + os.indent(indent) << getBoundSymbol(name) << " = attr;\n"; } indent -= 2; @@ -405,10 +398,6 @@ void PatternEmitter::emitMatchMethod(DagNode tree) { } } - for (auto &res : pattern.getSourcePatternBoundResults()) - os.indent(4) << formatv("mlir::Operation* {0}; (void){0};\n", - getBoundResult(res.first())); - emitOpMatch(tree, 0); for (auto &appliedConstraint : pattern.getConstraints()) { @@ -472,9 +461,12 @@ void PatternEmitter::emit(StringRef rewriteName) { os.indent(4) << namedAttr->attr.getStorageType() << " " << fieldName << ";\n"; } else { - os.indent(4) << "Value* " << fieldName << ";\n"; + os.indent(4) << "Value *" << fieldName << ";\n"; } } + for (const auto &result : pattern.getSourcePatternBoundOps()) { + os.indent(4) << "Operation *" << result.getKey() << ";\n"; + } os << " };\n"; emitMatchMethod(tree); @@ -564,9 +556,9 @@ std::string PatternEmitter::handleReplaceWithValue(DagNode tree) { } auto name = tree.getArgName(0); - pattern.ensureArgBoundInSourcePattern(name); + pattern.ensureBoundInSourcePattern(name); - return getBoundArgument(name).str(); + return getBoundSymbol(name).str(); } void PatternEmitter::handleVerifyUnusedValue(DagNode tree, int index) { @@ -587,8 +579,8 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf, auto enumCase = leaf.getAsEnumAttrCase(); return handleConstantAttr(enumCase, enumCase.getSymbol()); } - pattern.ensureArgBoundInSourcePattern(argName); - std::string result = getBoundArgument(argName).str(); + pattern.ensureBoundInSourcePattern(argName); + std::string result = getBoundSymbol(argName).str(); if (leaf.isUnspecified() || leaf.isOperandMatcher()) { return result; } @@ -758,7 +750,7 @@ std::string PatternEmitter::emitReplaceWithNativeBuilder(DagNode resultTree) { bool printingAttr = false; for (int i = 0, e = resultTree.getNumArgs(); i != e; ++i) { auto name = resultTree.getArgName(i); - pattern.ensureArgBoundInSourcePattern(name); + pattern.ensureBoundInSourcePattern(name); const auto &val = boundedValues.find(name); if (val->second.dyn_cast() && !printingAttr) { os << "}, {"; @@ -767,7 +759,7 @@ std::string PatternEmitter::emitReplaceWithNativeBuilder(DagNode resultTree) { } if (!first) os << ","; - os << getBoundArgument(name); + os << getBoundSymbol(name); first = false; } if (!printingAttr)