[TableGen] Make sure op in pattern has the same number of arguments as definition

When an op in the source pattern specifies more arguments than its definition, we
    will have out-of-bound query for op arguments from the definition. That will cause
    crashes. This change fixes it.

--

PiperOrigin-RevId: 242548415
This commit is contained in:
Lei Zhang 2019-04-08 15:14:59 -07:00 committed by Mehdi Amini
parent 1ee07e7fde
commit 04b6d2f3c1
2 changed files with 37 additions and 35 deletions

View File

@ -160,10 +160,6 @@ public:
// Precondition: isNativeCodeBuilder. // Precondition: isNativeCodeBuilder.
llvm::StringRef getNativeCodeBuilder() const; llvm::StringRef getNativeCodeBuilder() const;
// Collects all recursively bound arguments involved in the DAG tree rooted
// from this node.
void collectBoundArguments(Pattern *pattern) const;
// Returns true if this DAG construct means to replace with an existing SSA // Returns true if this DAG construct means to replace with an existing SSA
// value. // value.
bool isReplaceWithValue() const; bool isReplaceWithValue() const;
@ -235,6 +231,10 @@ public:
int getBenefit() const; int getBenefit() const;
private: private:
// Recursively collects all bound arguments inside the DAG tree rooted
// at `tree`.
void collectBoundArguments(DagNode tree);
// The TableGen definition of this pattern. // The TableGen definition of this pattern.
const llvm::Record &def; const llvm::Record &def;

View File

@ -22,6 +22,7 @@
#include "mlir/TableGen/Pattern.h" #include "mlir/TableGen/Pattern.h"
#include "llvm/ADT/Twine.h" #include "llvm/ADT/Twine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h" #include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h" #include "llvm/TableGen/Record.h"
@ -141,36 +142,6 @@ StringRef tblgen::DagNode::getArgName(unsigned index) const {
return node->getArgNameStr(index); return node->getArgNameStr(index);
} }
static void collectBoundArguments(const llvm::DagInit *tree,
tblgen::Pattern *pattern) {
auto &op = pattern->getDialectOp(tblgen::DagNode(tree));
if (llvm::StringInit *si = tree->getName()) {
auto name = si->getAsUnquotedString();
if (!name.empty())
pattern->getSourcePatternBoundResults().insert(name);
}
// TODO(jpienaar): Expand to multiple matches.
for (unsigned i = 0, e = tree->getNumArgs(); i != e; ++i) {
auto *arg = tree->getArg(i);
if (auto *argTree = dyn_cast<llvm::DagInit>(arg)) {
collectBoundArguments(argTree, pattern);
continue;
}
StringRef name = tree->getArgNameStr(i);
if (name.empty())
continue;
pattern->getSourcePatternBoundArgs().try_emplace(name, op.getArg(i));
}
}
void tblgen::DagNode::collectBoundArguments(tblgen::Pattern *pattern) const {
::collectBoundArguments(node, pattern);
}
bool tblgen::DagNode::isReplaceWithValue() const { bool tblgen::DagNode::isReplaceWithValue() const {
auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
return dagOpDef->getName() == "replaceWithValue"; return dagOpDef->getName() == "replaceWithValue";
@ -194,7 +165,7 @@ llvm::StringRef tblgen::DagNode::getNativeCodeBuilder() const {
tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
: def(*def), recordOpMap(mapper) { : def(*def), recordOpMap(mapper) {
getSourcePattern().collectBoundArguments(this); collectBoundArguments(getSourcePattern());
} }
tblgen::DagNode tblgen::Pattern::getSourcePattern() const { tblgen::DagNode tblgen::Pattern::getSourcePattern() const {
@ -276,3 +247,34 @@ int tblgen::Pattern::getBenefit() const {
} }
return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue(); return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
} }
void tblgen::Pattern::collectBoundArguments(DagNode tree) {
auto &op = getDialectOp(tree);
auto numOpArgs = op.getNumArgs();
auto numTreeArgs = tree.getNumArgs();
if (numOpArgs != numTreeArgs) {
PrintFatalError(def.getLoc(),
formatv("op '{0}' argument number mismatch: "
"{1} in pattern vs. {2} in definition",
op.getOperationName(), numTreeArgs, numOpArgs));
}
// The name attached to the DAG node's operator is for representing the
// results generated from this op. It should be remembered as bound results.
auto treeName = tree.getOpName();
if (!treeName.empty())
boundResults.insert(treeName);
// TODO(jpienaar): Expand to multiple matches.
for (unsigned i = 0; i != numTreeArgs; ++i) {
if (auto treeArg = tree.getArgAsNestedDag(i)) {
// This DAG node argument is a DAG node itself. Go inside recursively.
collectBoundArguments(treeArg);
} else {
auto treeArgName = tree.getArgName(i);
if (!treeArgName.empty())
boundArguments.try_emplace(treeArgName, op.getArg(i));
}
}
}