[TableGen] Support multiple result patterns

This CL added the ability to generate multiple ops using multiple result
patterns, with each of them replacing one result of the matched source op.

Specifically, the syntax is

```
def : Pattern<(SourceOp ...),
              [(ResultOp1 ...), (ResultOp2 ...), (ResultOp3 ...)]>;
```

Assuming `SourceOp` has three results.

Currently we require that each result op must generate one result, which
can be lifted later when use cases arise.

To help with cases that certain output is unused and we don't care about it,
this CL also introduces a new directive: `verifyUnusedValue`. Checks will
be emitted in the `match()` method to make sure if the corresponding output
is not unused, `match()` returns with `matchFailure()`.

PiperOrigin-RevId: 237513904
This commit is contained in:
Lei Zhang 2019-03-08 13:56:53 -08:00 committed by jpienaar
parent 87884ab4b6
commit 18fde7c9d8
6 changed files with 142 additions and 17 deletions

View File

@ -541,7 +541,18 @@ class Results<dag rets> {
// Base class for op+ -> op+ rewrite patterns. These allow declaratively
// specifying rewrite patterns.
class Pattern<dag source, list<dag> results, list<dag> preds> {
//
// A rewrite pattern contains two components: a source pattern and one or more
// result patterns. Each pattern is specified as a (recursive) DAG node (tree)
// in the form of `(node arg0, arg1, ...)`.
// The `node` are normally MLIR ops, but it can also be one of the directives
// listed later in this section.
// In the source pattern, `arg*` can be used to specify matchers (e.g., using
// type/attribute types, mAttr, etc.) and bound to a name for later use. In
// the result pattern, `arg*` can be used to refer to a previously bound name,
// with potential transformations (e.g., using tAttr, etc.). `arg*` can itself
// be nested DAG node.
class Pattern<dag source, list<dag> results, list<dag> preds = []> {
dag patternToMatch = source;
list<dag> resultOps = results;
list<dag> constraints = preds;
@ -561,9 +572,8 @@ class mAttrAnyOf<list<AttrConstraint> attrs> :
mAttr<AnyOf<!foldl([]<Pred>, attrs, prev, attr,
!listconcat(prev, [attr.predicate]))>>;
// Attribute transforms. This is the base class to specify a
// transformation of a matched attribute. Used on the output of a rewrite
// rule.
// Attribute transformation. This is the base class to specify a transformation
// of a matched attribute. Used on the output attribute of a rewrite rule.
class tAttr<code transform> {
// Code to transform the attribute.
// Format: {0} represents the attribute.
@ -598,8 +608,14 @@ class mPat<string f> {
string function = f;
}
// Marker used to indicate that no new result op are generated by applying the
// rewrite pattern, so to replace the matched DAG with an existing SSA value.
// Directive used in result pattern to indicate that no new result op are
// generated, so to replace the matched DAG with an existing SSA value.
def replaceWithValue;
// Directive used in result pattern to indicate that no replacement is generated
// for the current result. Predicates are generated to make sure the
// corresponding result in source pattern is unused.
// syntax: (verifyUnusedValue)
def verifyUnusedValue;
#endif // OP_BASE

View File

@ -162,6 +162,9 @@ public:
// value.
bool isReplaceWithValue() const;
// Returns true if this DAG node is the `verifyUnusedValue` directive.
bool isVerifyUnusedValue() const;
// Returns true if this DAG construct is meant to invoke a native code
// constructor.
bool isNativeCodeBuilder() const;

View File

@ -162,6 +162,11 @@ bool tblgen::DagNode::isReplaceWithValue() const {
return dagOpDef->getName() == "replaceWithValue";
}
bool tblgen::DagNode::isVerifyUnusedValue() const {
auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
return dagOpDef->getName() == "verifyUnusedValue";
}
bool tblgen::DagNode::isNativeCodeBuilder() const {
auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
return dagOpDef->isSubClassOf("cOp");

View File

@ -0,0 +1,29 @@
// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s
include "mlir/IR/OpBase.td"
def ThreeResultOp : Op<"three_result_op", []> {
let arguments = (ins I32:$input);
let results = (outs I32:$r1, I32:$r2, I32:$r3);
}
def OneResultOp : Op<"one_result_op", []> {
let arguments = (ins I32:$input);
let results = (outs I32:$r1);
}
def : Pattern<(ThreeResultOp $input), [
(verifyUnusedValue),
(OneResultOp $input),
(verifyUnusedValue)
]>;
// CHECK-LABEL: struct GeneratedConvert0
// CHECK: PatternMatchResult match(
// CHECK: if (!op0->getResult(0)->use_empty()) return matchFailure();
// CHECK: if (!op0->getResult(2)->use_empty()) return matchFailure();
// CHECK: void rewrite(
// CHECK: auto vOneResultOp0 = rewriter.create<OneResultOp>(
// CHECK: rewriter.replaceOp(op, {nullptr, vOneResultOp0, nullptr});

View File

@ -0,0 +1,27 @@
// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s
include "mlir/IR/OpBase.td"
def ThreeResultOp : Op<"three_result_op", []> {
let arguments = (ins I32:$input);
let results = (outs I32:$r1, I32:$r2, I32:$r3);
}
def OneResultOp : Op<"one_result_op", []> {
let arguments = (ins I32:$input);
let results = (outs I32:$r1);
}
def : Pattern<(ThreeResultOp $input), [
(OneResultOp $input),
(OneResultOp $input),
(OneResultOp $input)
]>;
// CHECK-LABEL: struct GeneratedConvert0
// CHECK: void rewrite(
// CHECK: auto vOneResultOp0 = rewriter.create<OneResultOp>(
// CHECK: auto vOneResultOp1 = rewriter.create<OneResultOp>(
// CHECK: auto vOneResultOp2 = rewriter.create<OneResultOp>(
// CHECK: rewriter.replaceOp(op, {vOneResultOp0, vOneResultOp1, vOneResultOp2});

View File

@ -95,6 +95,10 @@ private:
// replacement.
std::string handleReplaceWithValue(DagNode tree);
// Handles the `verifyUnusedValue` directive: emitting C++ statements to check
// the `index`-th result of the source op is not used.
void handleVerifyUnusedValue(DagNode tree, int index);
// Emits the C++ statement to build a new op out of the given DAG `tree` and
// returns the variable name that this op is assigned to. If `treeName` is not
// empty, the created op will be assigned to a variable of the given
@ -267,12 +271,19 @@ void PatternEmitter::emitMatchMethod(DagNode tree) {
// Emit the heading.
os << R"(
PatternMatchResult match(Instruction *op0) const override {
// TODO: This just handle 1 result
if (op0->getNumResults() != 1) return matchFailure();
auto ctx = op0->getContext(); (void)ctx;
auto state = std::make_unique<MatchedState>();)"
<< "\n";
// The rewrite pattern may specify that certain outputs should be unused in
// the source IR. Check it here.
for (int i = 0, e = pattern.getNumResults(); i < e; ++i) {
DagNode resultTree = pattern.getResultPattern(i);
if (resultTree.isVerifyUnusedValue()) {
handleVerifyUnusedValue(resultTree, i);
}
}
for (auto &res : pattern.getSourcePatternBoundResults())
os.indent(4) << formatv("mlir::Instruction* {0}; (void){0};\n",
resultName(res.first()));
@ -311,6 +322,7 @@ void PatternEmitter::emitMatchMethod(DagNode tree) {
"Pattern constraints have to be either a type or native constraint");
}
}
os.indent(4) << "return matchSuccess(std::move(state));\n }\n";
}
@ -351,10 +363,9 @@ void PatternEmitter::emit(StringRef rewriteName) {
}
void PatternEmitter::emitRewriteMethod() {
if (pattern.getNumResults() != 1)
PrintFatalError("only single result rules supported");
DagNode resultTree = pattern.getResultPattern(0);
unsigned numResults = pattern.getNumResults();
if (numResults == 0)
PrintFatalError(loc, "must provide at least one result pattern");
os << R"(
void rewrite(Instruction *op, std::unique_ptr<PatternState> state,
@ -363,10 +374,18 @@ void PatternEmitter::emitRewriteMethod() {
auto loc = op->getLoc(); (void)loc;
)";
std::string resultValue =
handleRewritePattern(resultTree, /*resultIndex=*/0, /*depth=*/0);
// Collect the replacement value for each result
llvm::SmallVector<std::string, 2> resultValues;
for (unsigned i = 0; i < numResults; ++i) {
DagNode resultTree = pattern.getResultPattern(i);
resultValues.push_back(handleRewritePattern(resultTree, i, 0));
}
os.indent(4) << "rewriter.replaceOp(op, {" << resultValue;
// Emit the final replaceOp() statement
os.indent(4) << "rewriter.replaceOp(op, {";
interleave(
resultValues, [&](const std::string &name) { os << name; },
[&]() { os << ", "; });
os << "});\n }\n";
}
@ -380,6 +399,20 @@ std::string PatternEmitter::handleRewritePattern(DagNode resultTree,
if (resultTree.isNativeCodeBuilder())
return emitReplaceWithNativeBuilder(resultTree);
if (resultTree.isVerifyUnusedValue()) {
if (depth > 0) {
// TODO: Revisit this when we have use cases of matching an intermediate
// multi-result op with no uses of its certain results.
PrintFatalError(loc, "verifyUnusedValue directive can only be used to "
"verify top-level result");
}
// The C++ statements to check that this result value is unused are already
// emitted in the match() method. So returning a nullptr here directly
// should be safe because the C++ RewritePattern harness will use it to
// replace nothing.
return "nullptr";
}
if (resultTree.isReplaceWithValue())
return handleReplaceWithValue(resultTree);
@ -400,11 +433,17 @@ std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
return boundArgNameInRewrite(name).str();
}
void PatternEmitter::handleVerifyUnusedValue(DagNode tree, int index) {
assert(tree.isVerifyUnusedValue());
os.indent(4) << "if (!op0->getResult(" << index
<< ")->use_empty()) return matchFailure();\n";
}
std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
int depth, StringRef treeName) {
Operator &resultOp = tree.getDialectOp(opMap);
auto numOpArgs =
resultOp.getNumOperands() + resultOp.getNumNativeAttributes();
auto numOpArgs = resultOp.getNumArgs();
if (numOpArgs != tree.getNumArgs()) {
PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
@ -413,6 +452,12 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
numOpArgs));
}
if (resultOp.getNumResults() > 1) {
PrintFatalError(
loc, formatv("generating multiple-result op '{0}' is unsupported now",
resultOp.getOperationName()));
}
// A map to collect all nested DAG child nodes' names, with operand index as
// the key.
llvm::DenseMap<unsigned, std::string> childNodeNames;