Support op removal patterns in TableGen

This CL adds a new marker, replaceWithValue, to indicate that no new result
op is generated by applying a pattern. Instead, the matched DAG is replaced
by an existing SSA value.

Converted the tf.Identity converter to use the pattern.

PiperOrigin-RevId: 230922323
This commit is contained in:
Lei Zhang 2019-01-25 10:09:15 -08:00 committed by jpienaar
parent 95f19d558c
commit 2de5e9fd19
2 changed files with 64 additions and 15 deletions

View File

@ -431,6 +431,7 @@ class Traits<list<string> Traits> {
//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//
// Base class for op+ -> op+ rewrite patterns. These allow declaratively
// specifying rewrite patterns.
// TODO(jpienaar): Add the constraint list along with the Pattern.
@ -459,4 +460,8 @@ class tAttr<code transform> {
code attrTransform = transform;
}
// Marker used to indicate that no new result op are generated by applying the
// rewrite pattern, so to replace the matched DAG with an existing SSA value.
def replaceWithValue;
#endif // OP_BASE

View File

@ -66,18 +66,30 @@ public:
private:
Pattern(Record *pattern, raw_ostream &os) : pattern(pattern), os(os) {}
// Emit the rewrite pattern named `rewriteName`.
// Emits the rewrite pattern named `rewriteName`.
void emit(StringRef rewriteName);
// Emit the matcher.
// Emits the matcher.
void emitMatcher(DagInit *tree);
// Emits the rewrite() method.
void emitRewriteMethod();
// Emits the C++ statement to replace the matched DAG with an existing value.
void emitReplaceWithExistingValue(DagInit *resultTree);
// Emits the C++ statement to replace the matched DAG with a new op.
void emitReplaceOpWithNewOp(DagInit *resultTree);
// Emits the value of constant attribute to `os`.
void emitAttributeValue(Record *constAttr);
// Collect bound arguments.
// Collects bound arguments.
void collectBoundArguments(DagInit *tree);
// Checks whether an argument with the given `name` is bound in source
// pattern. Prints fatal error if not; does nothing otherwise.
void checkArgumentBound(StringRef name) const;
// Helper function to match patterns.
void matchOp(DagInit *tree, int depth);
@ -134,6 +146,12 @@ void Pattern::collectBoundArguments(DagInit *tree) {
}
}
void Pattern::checkArgumentBound(StringRef name) const {
if (boundArguments.find(name) == boundArguments.end())
PrintFatalError(pattern->getLoc(),
Twine("referencing unbound variable '") + name + "'");
}
// Helper function to match patterns.
void Pattern::matchOp(DagInit *tree, int depth) {
Operator &op = getOperator(cast<DefInit>(tree->getOperator())->getDef());
@ -259,6 +277,12 @@ void Pattern::emit(StringRef rewriteName) {
os << " };\n";
emitMatcher(tree);
emitRewriteMethod();
os << "};\n";
}
void Pattern::emitRewriteMethod() {
ListInit *resultOps = pattern->getValueAsListInit("ResultOps");
if (resultOps->size() != 1)
PrintFatalError("only single result rules supported");
@ -270,14 +294,38 @@ void Pattern::emit(StringRef rewriteName) {
PrintFatalError(pattern->getLoc(), "only single op result supported");
}
DefInit *resultRoot = cast<DefInit>(resultTree->getOperator());
Operator &resultOp = getOperator(resultRoot->getDef());
auto resultOperands = resultRoot->getDef()->getValueAsDag("arguments");
os << formatv(R"(
os << R"(
void rewrite(OperationInst *op, std::unique_ptr<PatternState> state,
PatternRewriter &rewriter) const override {
auto& s = *static_cast<MatchedState *>(state.get());
)";
auto *dagOpDef = cast<DefInit>(resultTree->getOperator())->getDef();
if (dagOpDef->getName() == "replaceWithValue")
emitReplaceWithExistingValue(resultTree);
else
emitReplaceOpWithNewOp(resultTree);
os << " }\n";
}
void Pattern::emitReplaceWithExistingValue(DagInit *resultTree) {
if (resultTree->getNumArgs() != 1) {
PrintFatalError(pattern->getLoc(),
"exactly one argument needed in the result pattern");
}
auto name = resultTree->getArgNameStr(0);
checkArgumentBound(name);
os.indent(4) << "rewriter.replaceOp(op, {s." << name << "});\n";
}
void Pattern::emitReplaceOpWithNewOp(DagInit *resultTree) {
DefInit *dagOperator = cast<DefInit>(resultTree->getOperator());
Operator &resultOp = getOperator(dagOperator->getDef());
auto resultOperands = dagOperator->getDef()->getValueAsDag("arguments");
os << formatv(R"(
rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
resultOp.cppClassName());
if (resultOperands->getNumArgs() != resultTree->getNumArgs()) {
@ -296,9 +344,7 @@ void Pattern::emit(StringRef rewriteName) {
(os << ",\n").indent(6);
auto name = resultTree->getArgNameStr(i);
if (boundArguments.find(name) == boundArguments.end())
PrintFatalError(pattern->getLoc(),
Twine("referencing unbound variable '") + name + "'");
checkArgumentBound(name);
if (operand.name)
os << "/*" << operand.name->getAsUnquotedString() << "=*/";
os << "s." << name;
@ -317,9 +363,7 @@ void Pattern::emit(StringRef rewriteName) {
auto defInit = dyn_cast<DefInit>(resultTree->getArg(i));
auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
if (!value) {
if (boundArguments.find(name) == boundArguments.end())
PrintFatalError(pattern->getLoc(),
Twine("referencing unbound variable '") + name + "'");
checkArgumentBound(name);
auto result = "s." + name;
os << "/*" << opName << "=*/";
if (defInit) {
@ -347,7 +391,7 @@ void Pattern::emit(StringRef rewriteName) {
emitAttributeValue(defInit->getDef());
// TODO(jpienaar): verify types
}
os << "\n );\n }\n};\n";
os << "\n );\n";
}
void Pattern::emit(StringRef rewriteName, Record *p, raw_ostream &os) {