diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h index e7856e615034..cacf400c2d58 100644 --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -160,10 +160,6 @@ public: // Precondition: isNativeCodeBuilder. llvm::StringRef getNativeCodeBuilder() const; - // Collects all recursively bound arguments involved in the DAG tree rooted - // from this node. - void collectBoundArguments(Pattern *pattern) const; - // Returns true if this DAG construct means to replace with an existing SSA // value. bool isReplaceWithValue() const; @@ -235,6 +231,10 @@ public: int getBenefit() const; private: + // Recursively collects all bound arguments inside the DAG tree rooted + // at `tree`. + void collectBoundArguments(DagNode tree); + // The TableGen definition of this pattern. const llvm::Record &def; diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index bc95c2f95927..2200346ceed4 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -22,6 +22,7 @@ #include "mlir/TableGen/Pattern.h" #include "llvm/ADT/Twine.h" +#include "llvm/Support/FormatVariadic.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" @@ -141,36 +142,6 @@ StringRef tblgen::DagNode::getArgName(unsigned index) const { return node->getArgNameStr(index); } -static void collectBoundArguments(const llvm::DagInit *tree, - tblgen::Pattern *pattern) { - auto &op = pattern->getDialectOp(tblgen::DagNode(tree)); - if (llvm::StringInit *si = tree->getName()) { - auto name = si->getAsUnquotedString(); - if (!name.empty()) - pattern->getSourcePatternBoundResults().insert(name); - } - - // TODO(jpienaar): Expand to multiple matches. - for (unsigned i = 0, e = tree->getNumArgs(); i != e; ++i) { - auto *arg = tree->getArg(i); - - if (auto *argTree = dyn_cast(arg)) { - collectBoundArguments(argTree, pattern); - continue; - } - - StringRef name = tree->getArgNameStr(i); - if (name.empty()) - continue; - - pattern->getSourcePatternBoundArgs().try_emplace(name, op.getArg(i)); - } -} - -void tblgen::DagNode::collectBoundArguments(tblgen::Pattern *pattern) const { - ::collectBoundArguments(node, pattern); -} - bool tblgen::DagNode::isReplaceWithValue() const { auto *dagOpDef = cast(node->getOperator())->getDef(); return dagOpDef->getName() == "replaceWithValue"; @@ -194,7 +165,7 @@ llvm::StringRef tblgen::DagNode::getNativeCodeBuilder() const { tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) : def(*def), recordOpMap(mapper) { - getSourcePattern().collectBoundArguments(this); + collectBoundArguments(getSourcePattern()); } tblgen::DagNode tblgen::Pattern::getSourcePattern() const { @@ -276,3 +247,34 @@ int tblgen::Pattern::getBenefit() const { } return initBenefit + dyn_cast(delta->getArg(0))->getValue(); } + +void tblgen::Pattern::collectBoundArguments(DagNode tree) { + auto &op = getDialectOp(tree); + auto numOpArgs = op.getNumArgs(); + auto numTreeArgs = tree.getNumArgs(); + + if (numOpArgs != numTreeArgs) { + PrintFatalError(def.getLoc(), + formatv("op '{0}' argument number mismatch: " + "{1} in pattern vs. {2} in definition", + op.getOperationName(), numTreeArgs, numOpArgs)); + } + + // The name attached to the DAG node's operator is for representing the + // results generated from this op. It should be remembered as bound results. + auto treeName = tree.getOpName(); + if (!treeName.empty()) + boundResults.insert(treeName); + + // TODO(jpienaar): Expand to multiple matches. + for (unsigned i = 0; i != numTreeArgs; ++i) { + if (auto treeArg = tree.getArgAsNestedDag(i)) { + // This DAG node argument is a DAG node itself. Go inside recursively. + collectBoundArguments(treeArg); + } else { + auto treeArgName = tree.getArgName(i); + if (!treeArgName.empty()) + boundArguments.try_emplace(treeArgName, op.getArg(i)); + } + } +}