forked from OSchip/llvm-project
Match multiple pattern nodes as input to rewrite.
* Allow multi input node patterns in the rewrite; * Use number of nodes matched as benefit; * Rewrite relu(add(...)) matching using the new pattern; To allow for undefined ops, do string compare - will address soon! PiperOrigin-RevId: 227225425
This commit is contained in:
parent
5b9c3f7cdb
commit
554848d617
|
@ -53,7 +53,7 @@ public:
|
|||
struct Attribute {
|
||||
llvm::StringInit *name;
|
||||
llvm::Record *record;
|
||||
const bool isDerived;
|
||||
bool isDerived;
|
||||
};
|
||||
|
||||
using attribute_iterator = Attribute *;
|
||||
|
|
|
@ -35,23 +35,56 @@
|
|||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
// Wrapper around dag argument.
|
||||
struct DagArg {
|
||||
DagArg(Init *init) : init(init){};
|
||||
bool isAttr();
|
||||
|
||||
Init *init;
|
||||
};
|
||||
|
||||
} // end namespace
|
||||
|
||||
bool DagArg::isAttr() {
|
||||
if (auto defInit = dyn_cast<DefInit>(init))
|
||||
return defInit->getDef()->isSubClassOf("Attr");
|
||||
return false;
|
||||
}
|
||||
|
||||
namespace {
|
||||
class Pattern {
|
||||
public:
|
||||
Pattern(Record *pattern) : pattern(pattern){};
|
||||
|
||||
// Emit the rewrite pattern named `rewriteName` to `os`.
|
||||
void emit(StringRef rewriteName, raw_ostream &os);
|
||||
|
||||
// Emits the value of constant attribute to `os`.
|
||||
void emitAttributeValue(RecordVal *value, raw_ostream &os);
|
||||
static void emit(StringRef rewriteName, Record *p, raw_ostream &os);
|
||||
|
||||
private:
|
||||
Pattern(Record *pattern, raw_ostream &os) : pattern(pattern), os(os){};
|
||||
|
||||
// Emit the rewrite pattern named `rewriteName`.
|
||||
void emit(StringRef rewriteName);
|
||||
|
||||
// Emit the matcher.
|
||||
void emitMatcher(DagInit *tree);
|
||||
|
||||
// Emits the value of constant attribute to `os`.
|
||||
void emitAttributeValue(RecordVal *value);
|
||||
|
||||
// Collect bound arguments.
|
||||
void collectBoundArguments(DagInit *tree);
|
||||
|
||||
// Map from bound argument name to DagArg.
|
||||
StringMap<DagArg> boundArguments;
|
||||
|
||||
// Number of the operations in the input pattern.
|
||||
int numberOfOpsMatched = 0;
|
||||
|
||||
Record *pattern;
|
||||
raw_ostream &os;
|
||||
};
|
||||
} // end namespace
|
||||
|
||||
void Pattern::emitAttributeValue(RecordVal *value, raw_ostream &os) {
|
||||
void Pattern::emitAttributeValue(RecordVal *value) {
|
||||
switch (value->getType()->getRecTyKind()) {
|
||||
case RecTy::IntRecTyKind:
|
||||
// TODO(jpienaar): This is using 64-bits for all the bitwidth of the
|
||||
|
@ -72,46 +105,99 @@ void Pattern::emitAttributeValue(RecordVal *value, raw_ostream &os) {
|
|||
}
|
||||
}
|
||||
|
||||
void Pattern::emit(StringRef rewriteName, raw_ostream &os) {
|
||||
DagInit *tree = pattern->getValueAsDag("PatternToMatch");
|
||||
|
||||
void Pattern::collectBoundArguments(DagInit *tree) {
|
||||
++numberOfOpsMatched;
|
||||
// TODO(jpienaar): Expand to multiple matches.
|
||||
for (auto arg : tree->getArgs()) {
|
||||
if (isa<DagInit>(arg))
|
||||
PrintFatalError(pattern->getLoc(),
|
||||
"only single pattern inputs supported");
|
||||
for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
|
||||
auto arg = tree->getArg(i);
|
||||
if (auto argTree = dyn_cast<DagInit>(arg)) {
|
||||
collectBoundArguments(argTree);
|
||||
continue;
|
||||
}
|
||||
auto name = tree->getArgNameStr(i);
|
||||
if (name.empty())
|
||||
continue;
|
||||
boundArguments.try_emplace(name, arg);
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to match patterns.
|
||||
static void matchOp(DagInit *tree, int depth, raw_ostream &os) {
|
||||
Operator op(cast<DefInit>(tree->getOperator())->getDef());
|
||||
int indent = 4 + 2 * depth;
|
||||
// Skip the operand matching at depth 0 as the pattern rewriter already does.
|
||||
if (depth != 0) {
|
||||
// Skip if there is no defining instruction (e.g., arguments to function).
|
||||
os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth);
|
||||
// TODO(jpienaar): This is bad, we should not be checking strings here, we
|
||||
// should be matching using mOp (and helpers). Currently doing this to allow
|
||||
// for TF ops that aren't registed. Fix it.
|
||||
os.indent(indent) << formatv(
|
||||
"if (op{0}->getName().getStringRef() != \"{1}\")",
|
||||
depth, op.getOperationName())
|
||||
<< "\n";
|
||||
os.indent(indent + 2) << "return matchFailure();\n";
|
||||
}
|
||||
for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
|
||||
auto arg = tree->getArg(i);
|
||||
if (auto argTree = dyn_cast<DagInit>(arg)) {
|
||||
os.indent(indent) << "{\n";
|
||||
os.indent(indent + 2) << formatv(
|
||||
"auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n",
|
||||
depth + 1, depth, i);
|
||||
matchOp(argTree, depth + 1, os);
|
||||
os.indent(indent) << "}\n";
|
||||
continue;
|
||||
}
|
||||
auto name = tree->getArgNameStr(i);
|
||||
if (name.empty())
|
||||
continue;
|
||||
os.indent(indent) << "state->" << name << " = op" << depth
|
||||
<< "->getOperand(" << i << ");\n";
|
||||
}
|
||||
}
|
||||
|
||||
void Pattern::emitMatcher(DagInit *tree) {
|
||||
// Emit the heading.
|
||||
os << R"(
|
||||
PatternMatchResult match(OperationInst *op0) const override {
|
||||
// TODO: This just handle 1 result
|
||||
if (op0->getNumResults() != 1) return matchFailure();
|
||||
auto state = std::make_unique<MatchedState>();)"
|
||||
<< "\n";
|
||||
matchOp(tree, 0, os);
|
||||
os.indent(4) << "return matchSuccess(std::move(state));\n }\n";
|
||||
}
|
||||
|
||||
void Pattern::emit(StringRef rewriteName) {
|
||||
DagInit *tree = pattern->getValueAsDag("PatternToMatch");
|
||||
// Collect bound arguments and compute number of ops matched.
|
||||
// TODO(jpienaar): the benefit metric is simply number of ops matched at the
|
||||
// moment, revise.
|
||||
collectBoundArguments(tree);
|
||||
|
||||
// Emit RewritePattern for Pattern.
|
||||
DefInit *root = cast<DefInit>(tree->getOperator());
|
||||
auto *rootName = cast<StringInit>(root->getDef()->getValueInit("opName"));
|
||||
os << "struct " << rewriteName << " : public RewritePattern {\n"
|
||||
<< " " << rewriteName << "(MLIRContext *context) : RewritePattern("
|
||||
<< rootName->getAsString() << ", 1, context) {}\n";
|
||||
os << formatv(R"(struct {0} : public RewritePattern {
|
||||
{0}(MLIRContext *context) : RewritePattern({1}, {2}, context) {{})",
|
||||
rewriteName, rootName->getAsString(), numberOfOpsMatched)
|
||||
<< "\n";
|
||||
|
||||
// Emit matched state.
|
||||
os << " struct MatchedState : public PatternState {\n";
|
||||
for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
|
||||
auto arg = tree->getArgNameStr(i);
|
||||
if (!arg.empty())
|
||||
os.indent(6) << "Value* " << arg << ";\n";
|
||||
for (auto &arg : boundArguments) {
|
||||
if (arg.second.isAttr()) {
|
||||
DefInit *defInit = cast<DefInit>(arg.second.init);
|
||||
os.indent(4) << defInit->getDef()->getValueAsString("storageType").trim()
|
||||
<< " " << arg.first() << ";\n";
|
||||
} else {
|
||||
os.indent(4) << "Value* " << arg.first() << ";\n";
|
||||
}
|
||||
}
|
||||
os << " };\n";
|
||||
|
||||
StringSet<> boundArguments;
|
||||
os << R"(
|
||||
PatternMatchResult match(OperationInst *op) const override {
|
||||
// TODO: This just handle 1 result
|
||||
if (op->getNumResults() != 1) return matchFailure();
|
||||
auto state = std::make_unique<MatchedState>();)"
|
||||
<< "\n";
|
||||
for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
|
||||
auto arg = tree->getArgNameStr(i);
|
||||
if (!arg.empty())
|
||||
os.indent(4) << "state->" << arg << " = op->getOperand(" << i << ");\n";
|
||||
boundArguments.insert(arg);
|
||||
}
|
||||
os.indent(4) << "return matchSuccess(std::move(state));\n }\n";
|
||||
|
||||
emitMatcher(tree);
|
||||
ListInit *resultOps = pattern->getValueAsListInit("ResultOps");
|
||||
if (resultOps->size() != 1)
|
||||
PrintFatalError("only single result rules supported");
|
||||
|
@ -183,12 +269,17 @@ void Pattern::emit(StringRef rewriteName, raw_ostream &os) {
|
|||
|
||||
if (!name.empty())
|
||||
os << "/*" << name << "=*/";
|
||||
emitAttributeValue(value, os);
|
||||
emitAttributeValue(value);
|
||||
// TODO(jpienaar): verify types
|
||||
}
|
||||
os << "\n );\n }\n};\n";
|
||||
}
|
||||
|
||||
void Pattern::emit(StringRef rewriteName, Record *p, raw_ostream &os) {
|
||||
Pattern pattern(p, os);
|
||||
pattern.emit(rewriteName);
|
||||
}
|
||||
|
||||
static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||
emitSourceFileHeader("Rewriters", os);
|
||||
const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
|
||||
|
@ -197,8 +288,7 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
|||
std::string baseRewriteName = "GeneratedConvert";
|
||||
int rewritePatternCount = 0;
|
||||
for (Record *p : patterns) {
|
||||
Pattern pattern(p);
|
||||
pattern.emit(baseRewriteName + llvm::utostr(rewritePatternCount++), os);
|
||||
Pattern::emit(baseRewriteName + llvm::utostr(rewritePatternCount++), p, os);
|
||||
}
|
||||
|
||||
// Emit function to add the generated matchers to the pattern list.
|
||||
|
|
Loading…
Reference in New Issue