forked from OSchip/llvm-project
Add support to RewritePattern for specifying the potential operations that can be generated during a rewrite. This will enable analyses to start understanding the possible effects of applying a rewrite pattern.
-- PiperOrigin-RevId: 249936309
This commit is contained in:
parent
09438a412f
commit
647f8cabb9
|
@ -184,12 +184,25 @@ public:
|
|||
return matchFailure();
|
||||
}
|
||||
|
||||
/// Return a list of operations that may be generated when rewriting an
|
||||
/// operation instance with this pattern.
|
||||
ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
|
||||
|
||||
protected:
|
||||
/// Patterns must specify the root operation name they match against, and can
|
||||
/// also specify the benefit of the pattern matching.
|
||||
RewritePattern(StringRef rootName, PatternBenefit benefit,
|
||||
MLIRContext *context)
|
||||
: Pattern(rootName, benefit, context) {}
|
||||
/// Patterns must specify the root operation name they match against, and can
|
||||
/// also specify the benefit of the pattern matching. They can also specify
|
||||
/// the names of operations that may be generated during a successful rewrite.
|
||||
RewritePattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
|
||||
PatternBenefit benefit, MLIRContext *context);
|
||||
|
||||
/// A list of the potential operations that may be generated when rewriting
|
||||
/// an op with this pattern.
|
||||
llvm::SmallVector<OperationName, 2> generatedOps;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -172,6 +172,9 @@ public:
|
|||
// Returns true if this DAG node is wrapping native code call.
|
||||
bool isNativeCodeCall() const;
|
||||
|
||||
// Returns true if this DAG node is an operation.
|
||||
bool isOperation() const;
|
||||
|
||||
// Returns the native code call template inside this DAG node.
|
||||
// Precondition: isNativeCodeCall()
|
||||
llvm::StringRef getNativeCodeTemplate() const;
|
||||
|
|
|
@ -60,6 +60,20 @@ PatternMatchResult RewritePattern::match(Operation *op) const {
|
|||
llvm_unreachable("need to implement either match or matchAndRewrite!");
|
||||
}
|
||||
|
||||
/// Patterns must specify the root operation name they match against, and can
|
||||
/// also specify the benefit of the pattern matching. They can also specify the
|
||||
/// names of operations that may be generated during a successful rewrite.
|
||||
RewritePattern::RewritePattern(StringRef rootName,
|
||||
ArrayRef<StringRef> generatedNames,
|
||||
PatternBenefit benefit, MLIRContext *context)
|
||||
: Pattern(rootName, benefit, context) {
|
||||
generatedOps.reserve(generatedNames.size());
|
||||
std::transform(generatedNames.begin(), generatedNames.end(),
|
||||
std::back_inserter(generatedOps), [context](StringRef name) {
|
||||
return OperationName(name, context);
|
||||
});
|
||||
}
|
||||
|
||||
PatternRewriter::~PatternRewriter() {
|
||||
// Out of line to provide a vtable anchor for the class.
|
||||
}
|
||||
|
|
|
@ -94,6 +94,10 @@ bool tblgen::DagNode::isNativeCodeCall() const {
|
|||
return false;
|
||||
}
|
||||
|
||||
bool tblgen::DagNode::isOperation() const {
|
||||
return !(isNativeCodeCall() || isVerifyUnusedValue() || isReplaceWithValue());
|
||||
}
|
||||
|
||||
llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const {
|
||||
assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
|
||||
return cast<llvm::DefInit>(node->getOperator())
|
||||
|
|
|
@ -26,9 +26,9 @@ def Z_AddOp : NS_Op<"add"> {
|
|||
def bena : Pat<(X_AddOp (X_AddOp $lhs, $rhs), $rhs), (Y_AddOp $lhs, $rhs, $rhs)>;
|
||||
|
||||
// CHECK-LABEL: struct bena
|
||||
// CHECK: RewritePattern("x.add", 2, context) {}
|
||||
// CHECK: RewritePattern("x.add", {"x.add"}, 2, context) {}
|
||||
|
||||
def benb : Pat<(X_AddOp $lhs, $rhs), (Z_AddOp $lhs), [(IfEqual $lhs, $rhs)], (addBenefit 100)>;
|
||||
|
||||
// CHECK-LABEL: struct benb
|
||||
// CHECK: RewritePattern("x.add", 101, context) {}
|
||||
// CHECK: RewritePattern("x.add", {"x.add"}, 101, context) {}
|
||||
|
|
|
@ -34,7 +34,7 @@ def : Pat<(OpA $input, $attr), (OpB $input, $attr)>;
|
|||
|
||||
// CHECK: struct GeneratedConvert0 : public RewritePattern
|
||||
|
||||
// CHECK: GeneratedConvert0(MLIRContext *context) : RewritePattern("op_a", 1, context) {}
|
||||
// CHECK: GeneratedConvert0(MLIRContext *context) : RewritePattern("op_a", {"op_b"}, 1, context) {}
|
||||
|
||||
// CHECK: struct MatchedState : public PatternState {
|
||||
// CHECK: Value *input;
|
||||
|
|
|
@ -153,6 +153,9 @@ private:
|
|||
// Emits the match() method.
|
||||
void emitMatchMethod(DagNode tree);
|
||||
|
||||
// Collects all of the operations within the given dag tree.
|
||||
void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
|
||||
|
||||
// Emits the rewrite() method.
|
||||
void emitRewriteMethod();
|
||||
|
||||
|
@ -443,6 +446,18 @@ void PatternEmitter::emitMatchMethod(DagNode tree) {
|
|||
os.indent(4) << "return matchSuccess(std::move(state));\n }\n";
|
||||
}
|
||||
|
||||
void PatternEmitter::collectOps(DagNode tree,
|
||||
llvm::SmallPtrSetImpl<const Operator *> &ops) {
|
||||
// Check if this tree is an operation.
|
||||
if (tree.isOperation())
|
||||
ops.insert(&tree.getDialectOp(opMap));
|
||||
|
||||
// Recurse the arguments of the tree.
|
||||
for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
|
||||
if (auto child = tree.getArgAsNestedDag(i))
|
||||
collectOps(child, ops);
|
||||
}
|
||||
|
||||
void PatternEmitter::emit(StringRef rewriteName) {
|
||||
// Get the DAG tree for the source pattern
|
||||
DagNode tree = pattern.getSourcePattern();
|
||||
|
@ -454,14 +469,22 @@ void PatternEmitter::emit(StringRef rewriteName) {
|
|||
PrintFatalError(
|
||||
loc, "replacing op with variadic results not supported right now");
|
||||
|
||||
// Collect the set of result operations.
|
||||
llvm::SmallPtrSet<const Operator *, 4> results;
|
||||
for (unsigned i = 0, e = pattern.getNumResults(); i != e; ++i)
|
||||
collectOps(pattern.getResultPattern(i), results);
|
||||
|
||||
// Emit RewritePattern for Pattern.
|
||||
auto locs = pattern.getLocation();
|
||||
os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
|
||||
make_range(locs.rbegin(), locs.rend()));
|
||||
os << formatv(R"(struct {0} : public RewritePattern {
|
||||
{0}(MLIRContext *context) : RewritePattern("{1}", {2}, context) {{})",
|
||||
rewriteName, rootName, pattern.getBenefit())
|
||||
<< "\n";
|
||||
{0}(MLIRContext *context) : RewritePattern("{1}", {{)",
|
||||
rewriteName, rootName);
|
||||
interleaveComma(results, os, [&](const Operator *op) {
|
||||
os << '"' << op->getOperationName() << '"';
|
||||
});
|
||||
os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
|
||||
|
||||
// Emit matched state.
|
||||
os << " struct MatchedState : public PatternState {\n";
|
||||
|
|
Loading…
Reference in New Issue