forked from OSchip/llvm-project
Support op removal patterns in TableGen
This CL adds a new marker, replaceWithValue, to indicate that no new result op is generated by applying a pattern. Instead, the matched DAG is replaced by an existing SSA value. Converted the tf.Identity converter to use the pattern. PiperOrigin-RevId: 230922323
This commit is contained in:
parent
95f19d558c
commit
2de5e9fd19
|
@ -431,6 +431,7 @@ class Traits<list<string> Traits> {
|
|||
//===----------------------------------------------------------------------===//
|
||||
// Patterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Base class for op+ -> op+ rewrite patterns. These allow declaratively
|
||||
// specifying rewrite patterns.
|
||||
// TODO(jpienaar): Add the constraint list along with the Pattern.
|
||||
|
@ -459,4 +460,8 @@ class tAttr<code transform> {
|
|||
code attrTransform = transform;
|
||||
}
|
||||
|
||||
// Marker used to indicate that no new result op are generated by applying the
|
||||
// rewrite pattern, so to replace the matched DAG with an existing SSA value.
|
||||
def replaceWithValue;
|
||||
|
||||
#endif // OP_BASE
|
||||
|
|
|
@ -66,18 +66,30 @@ public:
|
|||
private:
|
||||
Pattern(Record *pattern, raw_ostream &os) : pattern(pattern), os(os) {}
|
||||
|
||||
// Emit the rewrite pattern named `rewriteName`.
|
||||
// Emits the rewrite pattern named `rewriteName`.
|
||||
void emit(StringRef rewriteName);
|
||||
|
||||
// Emit the matcher.
|
||||
// Emits the matcher.
|
||||
void emitMatcher(DagInit *tree);
|
||||
|
||||
// Emits the rewrite() method.
|
||||
void emitRewriteMethod();
|
||||
|
||||
// Emits the C++ statement to replace the matched DAG with an existing value.
|
||||
void emitReplaceWithExistingValue(DagInit *resultTree);
|
||||
// Emits the C++ statement to replace the matched DAG with a new op.
|
||||
void emitReplaceOpWithNewOp(DagInit *resultTree);
|
||||
|
||||
// Emits the value of constant attribute to `os`.
|
||||
void emitAttributeValue(Record *constAttr);
|
||||
|
||||
// Collect bound arguments.
|
||||
// Collects bound arguments.
|
||||
void collectBoundArguments(DagInit *tree);
|
||||
|
||||
// Checks whether an argument with the given `name` is bound in source
|
||||
// pattern. Prints fatal error if not; does nothing otherwise.
|
||||
void checkArgumentBound(StringRef name) const;
|
||||
|
||||
// Helper function to match patterns.
|
||||
void matchOp(DagInit *tree, int depth);
|
||||
|
||||
|
@ -134,6 +146,12 @@ void Pattern::collectBoundArguments(DagInit *tree) {
|
|||
}
|
||||
}
|
||||
|
||||
void Pattern::checkArgumentBound(StringRef name) const {
|
||||
if (boundArguments.find(name) == boundArguments.end())
|
||||
PrintFatalError(pattern->getLoc(),
|
||||
Twine("referencing unbound variable '") + name + "'");
|
||||
}
|
||||
|
||||
// Helper function to match patterns.
|
||||
void Pattern::matchOp(DagInit *tree, int depth) {
|
||||
Operator &op = getOperator(cast<DefInit>(tree->getOperator())->getDef());
|
||||
|
@ -259,6 +277,12 @@ void Pattern::emit(StringRef rewriteName) {
|
|||
os << " };\n";
|
||||
|
||||
emitMatcher(tree);
|
||||
emitRewriteMethod();
|
||||
|
||||
os << "};\n";
|
||||
}
|
||||
|
||||
void Pattern::emitRewriteMethod() {
|
||||
ListInit *resultOps = pattern->getValueAsListInit("ResultOps");
|
||||
if (resultOps->size() != 1)
|
||||
PrintFatalError("only single result rules supported");
|
||||
|
@ -270,14 +294,38 @@ void Pattern::emit(StringRef rewriteName) {
|
|||
PrintFatalError(pattern->getLoc(), "only single op result supported");
|
||||
}
|
||||
|
||||
DefInit *resultRoot = cast<DefInit>(resultTree->getOperator());
|
||||
Operator &resultOp = getOperator(resultRoot->getDef());
|
||||
auto resultOperands = resultRoot->getDef()->getValueAsDag("arguments");
|
||||
|
||||
os << formatv(R"(
|
||||
os << R"(
|
||||
void rewrite(OperationInst *op, std::unique_ptr<PatternState> state,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto& s = *static_cast<MatchedState *>(state.get());
|
||||
)";
|
||||
|
||||
auto *dagOpDef = cast<DefInit>(resultTree->getOperator())->getDef();
|
||||
if (dagOpDef->getName() == "replaceWithValue")
|
||||
emitReplaceWithExistingValue(resultTree);
|
||||
else
|
||||
emitReplaceOpWithNewOp(resultTree);
|
||||
|
||||
os << " }\n";
|
||||
}
|
||||
|
||||
void Pattern::emitReplaceWithExistingValue(DagInit *resultTree) {
|
||||
if (resultTree->getNumArgs() != 1) {
|
||||
PrintFatalError(pattern->getLoc(),
|
||||
"exactly one argument needed in the result pattern");
|
||||
}
|
||||
|
||||
auto name = resultTree->getArgNameStr(0);
|
||||
checkArgumentBound(name);
|
||||
os.indent(4) << "rewriter.replaceOp(op, {s." << name << "});\n";
|
||||
}
|
||||
|
||||
void Pattern::emitReplaceOpWithNewOp(DagInit *resultTree) {
|
||||
DefInit *dagOperator = cast<DefInit>(resultTree->getOperator());
|
||||
Operator &resultOp = getOperator(dagOperator->getDef());
|
||||
auto resultOperands = dagOperator->getDef()->getValueAsDag("arguments");
|
||||
|
||||
os << formatv(R"(
|
||||
rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
|
||||
resultOp.cppClassName());
|
||||
if (resultOperands->getNumArgs() != resultTree->getNumArgs()) {
|
||||
|
@ -296,9 +344,7 @@ void Pattern::emit(StringRef rewriteName) {
|
|||
(os << ",\n").indent(6);
|
||||
|
||||
auto name = resultTree->getArgNameStr(i);
|
||||
if (boundArguments.find(name) == boundArguments.end())
|
||||
PrintFatalError(pattern->getLoc(),
|
||||
Twine("referencing unbound variable '") + name + "'");
|
||||
checkArgumentBound(name);
|
||||
if (operand.name)
|
||||
os << "/*" << operand.name->getAsUnquotedString() << "=*/";
|
||||
os << "s." << name;
|
||||
|
@ -317,9 +363,7 @@ void Pattern::emit(StringRef rewriteName) {
|
|||
auto defInit = dyn_cast<DefInit>(resultTree->getArg(i));
|
||||
auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
|
||||
if (!value) {
|
||||
if (boundArguments.find(name) == boundArguments.end())
|
||||
PrintFatalError(pattern->getLoc(),
|
||||
Twine("referencing unbound variable '") + name + "'");
|
||||
checkArgumentBound(name);
|
||||
auto result = "s." + name;
|
||||
os << "/*" << opName << "=*/";
|
||||
if (defInit) {
|
||||
|
@ -347,7 +391,7 @@ void Pattern::emit(StringRef rewriteName) {
|
|||
emitAttributeValue(defInit->getDef());
|
||||
// TODO(jpienaar): verify types
|
||||
}
|
||||
os << "\n );\n }\n};\n";
|
||||
os << "\n );\n";
|
||||
}
|
||||
|
||||
void Pattern::emit(StringRef rewriteName, Record *p, raw_ostream &os) {
|
||||
|
|
Loading…
Reference in New Issue