Add support to RewritePattern for specifying the potential operations that can be generated during a rewrite. This will enable analyses to start understanding the possible effects of applying a rewrite pattern.

--

PiperOrigin-RevId: 249936309
This commit is contained in:
River Riddle 2019-05-24 19:35:23 -07:00 committed by Mehdi Amini
parent 09438a412f
commit 647f8cabb9
7 changed files with 63 additions and 6 deletions

View File

@ -184,12 +184,25 @@ public:
return matchFailure();
}
/// Return a list of operations that may be generated when rewriting an
/// operation instance with this pattern.
ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
protected:
/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching.
RewritePattern(StringRef rootName, PatternBenefit benefit,
MLIRContext *context)
: Pattern(rootName, benefit, context) {}
/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching. They can also specify
/// the names of operations that may be generated during a successful rewrite.
RewritePattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
PatternBenefit benefit, MLIRContext *context);
/// A list of the potential operations that may be generated when rewriting
/// an op with this pattern.
llvm::SmallVector<OperationName, 2> generatedOps;
};
//===----------------------------------------------------------------------===//

View File

@ -172,6 +172,9 @@ public:
// Returns true if this DAG node is wrapping native code call.
bool isNativeCodeCall() const;
// Returns true if this DAG node is an operation.
bool isOperation() const;
// Returns the native code call template inside this DAG node.
// Precondition: isNativeCodeCall()
llvm::StringRef getNativeCodeTemplate() const;

View File

@ -60,6 +60,20 @@ PatternMatchResult RewritePattern::match(Operation *op) const {
llvm_unreachable("need to implement either match or matchAndRewrite!");
}
/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching. They can also specify the
/// names of operations that may be generated during a successful rewrite.
RewritePattern::RewritePattern(StringRef rootName,
ArrayRef<StringRef> generatedNames,
PatternBenefit benefit, MLIRContext *context)
: Pattern(rootName, benefit, context) {
generatedOps.reserve(generatedNames.size());
std::transform(generatedNames.begin(), generatedNames.end(),
std::back_inserter(generatedOps), [context](StringRef name) {
return OperationName(name, context);
});
}
PatternRewriter::~PatternRewriter() {
// Out of line to provide a vtable anchor for the class.
}

View File

@ -94,6 +94,10 @@ bool tblgen::DagNode::isNativeCodeCall() const {
return false;
}
bool tblgen::DagNode::isOperation() const {
return !(isNativeCodeCall() || isVerifyUnusedValue() || isReplaceWithValue());
}
llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const {
assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
return cast<llvm::DefInit>(node->getOperator())

View File

@ -26,9 +26,9 @@ def Z_AddOp : NS_Op<"add"> {
def bena : Pat<(X_AddOp (X_AddOp $lhs, $rhs), $rhs), (Y_AddOp $lhs, $rhs, $rhs)>;
// CHECK-LABEL: struct bena
// CHECK: RewritePattern("x.add", 2, context) {}
// CHECK: RewritePattern("x.add", {"x.add"}, 2, context) {}
def benb : Pat<(X_AddOp $lhs, $rhs), (Z_AddOp $lhs), [(IfEqual $lhs, $rhs)], (addBenefit 100)>;
// CHECK-LABEL: struct benb
// CHECK: RewritePattern("x.add", 101, context) {}
// CHECK: RewritePattern("x.add", {"x.add"}, 101, context) {}

View File

@ -34,7 +34,7 @@ def : Pat<(OpA $input, $attr), (OpB $input, $attr)>;
// CHECK: struct GeneratedConvert0 : public RewritePattern
// CHECK: GeneratedConvert0(MLIRContext *context) : RewritePattern("op_a", 1, context) {}
// CHECK: GeneratedConvert0(MLIRContext *context) : RewritePattern("op_a", {"op_b"}, 1, context) {}
// CHECK: struct MatchedState : public PatternState {
// CHECK: Value *input;

View File

@ -153,6 +153,9 @@ private:
// Emits the match() method.
void emitMatchMethod(DagNode tree);
// Collects all of the operations within the given dag tree.
void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
// Emits the rewrite() method.
void emitRewriteMethod();
@ -443,6 +446,18 @@ void PatternEmitter::emitMatchMethod(DagNode tree) {
os.indent(4) << "return matchSuccess(std::move(state));\n }\n";
}
void PatternEmitter::collectOps(DagNode tree,
llvm::SmallPtrSetImpl<const Operator *> &ops) {
// Check if this tree is an operation.
if (tree.isOperation())
ops.insert(&tree.getDialectOp(opMap));
// Recurse the arguments of the tree.
for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
if (auto child = tree.getArgAsNestedDag(i))
collectOps(child, ops);
}
void PatternEmitter::emit(StringRef rewriteName) {
// Get the DAG tree for the source pattern
DagNode tree = pattern.getSourcePattern();
@ -454,14 +469,22 @@ void PatternEmitter::emit(StringRef rewriteName) {
PrintFatalError(
loc, "replacing op with variadic results not supported right now");
// Collect the set of result operations.
llvm::SmallPtrSet<const Operator *, 4> results;
for (unsigned i = 0, e = pattern.getNumResults(); i != e; ++i)
collectOps(pattern.getResultPattern(i), results);
// Emit RewritePattern for Pattern.
auto locs = pattern.getLocation();
os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
make_range(locs.rbegin(), locs.rend()));
os << formatv(R"(struct {0} : public RewritePattern {
{0}(MLIRContext *context) : RewritePattern("{1}", {2}, context) {{})",
rewriteName, rootName, pattern.getBenefit())
<< "\n";
{0}(MLIRContext *context) : RewritePattern("{1}", {{)",
rewriteName, rootName);
interleaveComma(results, os, [&](const Operator *op) {
os << '"' << op->getOperationName() << '"';
});
os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
// Emit matched state.
os << " struct MatchedState : public PatternState {\n";