Match multiple pattern nodes as input to rewrite.

* Allow multi input node patterns in the rewrite;
* Use number of nodes matched as benefit;
* Rewrite relu(add(...)) matching using the new pattern;

To allow for undefined ops, do string compare - will address soon!

PiperOrigin-RevId: 227225425
This commit is contained in:
Jacques Pienaar 2018-12-29 07:55:08 -08:00 committed by jpienaar
parent 5b9c3f7cdb
commit 554848d617
2 changed files with 131 additions and 41 deletions

View File

@ -53,7 +53,7 @@ public:
struct Attribute {
llvm::StringInit *name;
llvm::Record *record;
const bool isDerived;
bool isDerived;
};
using attribute_iterator = Attribute *;

View File

@ -35,23 +35,56 @@
using namespace llvm;
using namespace mlir;
namespace {
// Wrapper around dag argument.
struct DagArg {
DagArg(Init *init) : init(init){};
bool isAttr();
Init *init;
};
} // end namespace
bool DagArg::isAttr() {
if (auto defInit = dyn_cast<DefInit>(init))
return defInit->getDef()->isSubClassOf("Attr");
return false;
}
namespace {
class Pattern {
public:
Pattern(Record *pattern) : pattern(pattern){};
// Emit the rewrite pattern named `rewriteName` to `os`.
void emit(StringRef rewriteName, raw_ostream &os);
// Emits the value of constant attribute to `os`.
void emitAttributeValue(RecordVal *value, raw_ostream &os);
static void emit(StringRef rewriteName, Record *p, raw_ostream &os);
private:
Pattern(Record *pattern, raw_ostream &os) : pattern(pattern), os(os){};
// Emit the rewrite pattern named `rewriteName`.
void emit(StringRef rewriteName);
// Emit the matcher.
void emitMatcher(DagInit *tree);
// Emits the value of constant attribute to `os`.
void emitAttributeValue(RecordVal *value);
// Collect bound arguments.
void collectBoundArguments(DagInit *tree);
// Map from bound argument name to DagArg.
StringMap<DagArg> boundArguments;
// Number of the operations in the input pattern.
int numberOfOpsMatched = 0;
Record *pattern;
raw_ostream &os;
};
} // end namespace
void Pattern::emitAttributeValue(RecordVal *value, raw_ostream &os) {
void Pattern::emitAttributeValue(RecordVal *value) {
switch (value->getType()->getRecTyKind()) {
case RecTy::IntRecTyKind:
// TODO(jpienaar): This is using 64-bits for all the bitwidth of the
@ -72,46 +105,99 @@ void Pattern::emitAttributeValue(RecordVal *value, raw_ostream &os) {
}
}
void Pattern::emit(StringRef rewriteName, raw_ostream &os) {
DagInit *tree = pattern->getValueAsDag("PatternToMatch");
void Pattern::collectBoundArguments(DagInit *tree) {
++numberOfOpsMatched;
// TODO(jpienaar): Expand to multiple matches.
for (auto arg : tree->getArgs()) {
if (isa<DagInit>(arg))
PrintFatalError(pattern->getLoc(),
"only single pattern inputs supported");
for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
auto arg = tree->getArg(i);
if (auto argTree = dyn_cast<DagInit>(arg)) {
collectBoundArguments(argTree);
continue;
}
auto name = tree->getArgNameStr(i);
if (name.empty())
continue;
boundArguments.try_emplace(name, arg);
}
}
// Helper function to match patterns.
static void matchOp(DagInit *tree, int depth, raw_ostream &os) {
Operator op(cast<DefInit>(tree->getOperator())->getDef());
int indent = 4 + 2 * depth;
// Skip the operand matching at depth 0 as the pattern rewriter already does.
if (depth != 0) {
// Skip if there is no defining instruction (e.g., arguments to function).
os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth);
// TODO(jpienaar): This is bad, we should not be checking strings here, we
// should be matching using mOp (and helpers). Currently doing this to allow
// for TF ops that aren't registed. Fix it.
os.indent(indent) << formatv(
"if (op{0}->getName().getStringRef() != \"{1}\")",
depth, op.getOperationName())
<< "\n";
os.indent(indent + 2) << "return matchFailure();\n";
}
for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
auto arg = tree->getArg(i);
if (auto argTree = dyn_cast<DagInit>(arg)) {
os.indent(indent) << "{\n";
os.indent(indent + 2) << formatv(
"auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n",
depth + 1, depth, i);
matchOp(argTree, depth + 1, os);
os.indent(indent) << "}\n";
continue;
}
auto name = tree->getArgNameStr(i);
if (name.empty())
continue;
os.indent(indent) << "state->" << name << " = op" << depth
<< "->getOperand(" << i << ");\n";
}
}
void Pattern::emitMatcher(DagInit *tree) {
// Emit the heading.
os << R"(
PatternMatchResult match(OperationInst *op0) const override {
// TODO: This just handle 1 result
if (op0->getNumResults() != 1) return matchFailure();
auto state = std::make_unique<MatchedState>();)"
<< "\n";
matchOp(tree, 0, os);
os.indent(4) << "return matchSuccess(std::move(state));\n }\n";
}
void Pattern::emit(StringRef rewriteName) {
DagInit *tree = pattern->getValueAsDag("PatternToMatch");
// Collect bound arguments and compute number of ops matched.
// TODO(jpienaar): the benefit metric is simply number of ops matched at the
// moment, revise.
collectBoundArguments(tree);
// Emit RewritePattern for Pattern.
DefInit *root = cast<DefInit>(tree->getOperator());
auto *rootName = cast<StringInit>(root->getDef()->getValueInit("opName"));
os << "struct " << rewriteName << " : public RewritePattern {\n"
<< " " << rewriteName << "(MLIRContext *context) : RewritePattern("
<< rootName->getAsString() << ", 1, context) {}\n";
os << formatv(R"(struct {0} : public RewritePattern {
{0}(MLIRContext *context) : RewritePattern({1}, {2}, context) {{})",
rewriteName, rootName->getAsString(), numberOfOpsMatched)
<< "\n";
// Emit matched state.
os << " struct MatchedState : public PatternState {\n";
for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
auto arg = tree->getArgNameStr(i);
if (!arg.empty())
os.indent(6) << "Value* " << arg << ";\n";
for (auto &arg : boundArguments) {
if (arg.second.isAttr()) {
DefInit *defInit = cast<DefInit>(arg.second.init);
os.indent(4) << defInit->getDef()->getValueAsString("storageType").trim()
<< " " << arg.first() << ";\n";
} else {
os.indent(4) << "Value* " << arg.first() << ";\n";
}
}
os << " };\n";
StringSet<> boundArguments;
os << R"(
PatternMatchResult match(OperationInst *op) const override {
// TODO: This just handle 1 result
if (op->getNumResults() != 1) return matchFailure();
auto state = std::make_unique<MatchedState>();)"
<< "\n";
for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
auto arg = tree->getArgNameStr(i);
if (!arg.empty())
os.indent(4) << "state->" << arg << " = op->getOperand(" << i << ");\n";
boundArguments.insert(arg);
}
os.indent(4) << "return matchSuccess(std::move(state));\n }\n";
emitMatcher(tree);
ListInit *resultOps = pattern->getValueAsListInit("ResultOps");
if (resultOps->size() != 1)
PrintFatalError("only single result rules supported");
@ -183,12 +269,17 @@ void Pattern::emit(StringRef rewriteName, raw_ostream &os) {
if (!name.empty())
os << "/*" << name << "=*/";
emitAttributeValue(value, os);
emitAttributeValue(value);
// TODO(jpienaar): verify types
}
os << "\n );\n }\n};\n";
}
void Pattern::emit(StringRef rewriteName, Record *p, raw_ostream &os) {
Pattern pattern(p, os);
pattern.emit(rewriteName);
}
static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
emitSourceFileHeader("Rewriters", os);
const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
@ -197,8 +288,7 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
std::string baseRewriteName = "GeneratedConvert";
int rewritePatternCount = 0;
for (Record *p : patterns) {
Pattern pattern(p);
pattern.emit(baseRewriteName + llvm::utostr(rewritePatternCount++), os);
Pattern::emit(baseRewriteName + llvm::utostr(rewritePatternCount++), p, os);
}
// Emit function to add the generated matchers to the pattern list.