forked from OSchip/llvm-project
[mlir:MultiOpDriver] Add operands to worklist should be checked
Operand's defining op may not be valid for adding to the worklist under stict mode Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D127180
This commit is contained in:
parent
ff80dc8544
commit
ba3a9f51ff
|
@ -43,7 +43,7 @@ public:
|
|||
bool simplify(MutableArrayRef<Region> regions);
|
||||
|
||||
/// Add the given operation to the worklist.
|
||||
void addToWorklist(Operation *op);
|
||||
virtual void addToWorklist(Operation *op);
|
||||
|
||||
/// Pop the next operation from the worklist.
|
||||
Operation *popFromWorklist();
|
||||
|
@ -60,8 +60,7 @@ protected:
|
|||
// be re-added to the worklist. This function should be called when an
|
||||
// operation is modified or removed, as it may trigger further
|
||||
// simplifications.
|
||||
template <typename Operands>
|
||||
void addToWorklist(Operands &&operands);
|
||||
void addOperandsToWorklist(ValueRange operands);
|
||||
|
||||
// If an operation is about to be removed, make sure it is not in our
|
||||
// worklist anymore because we'd get dangling references to it.
|
||||
|
@ -219,7 +218,7 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
|
|||
originalOperands.assign(op->operand_begin(), op->operand_end());
|
||||
auto preReplaceAction = [&](Operation *op) {
|
||||
// Add the operands to the worklist for visitation.
|
||||
addToWorklist(originalOperands);
|
||||
addOperandsToWorklist(originalOperands);
|
||||
|
||||
// Add all the users of the result to the worklist so we make sure
|
||||
// to revisit them.
|
||||
|
@ -327,8 +326,7 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
|
|||
addToWorklist(op);
|
||||
}
|
||||
|
||||
template <typename Operands>
|
||||
void GreedyPatternRewriteDriver::addToWorklist(Operands &&operands) {
|
||||
void GreedyPatternRewriteDriver::addOperandsToWorklist(ValueRange operands) {
|
||||
for (Value operand : operands) {
|
||||
// If the use count of this operand is now < 2, we re-add the defining
|
||||
// operation to the worklist.
|
||||
|
@ -343,7 +341,7 @@ void GreedyPatternRewriteDriver::addToWorklist(Operands &&operands) {
|
|||
}
|
||||
|
||||
void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
|
||||
addToWorklist(op->getOperands());
|
||||
addOperandsToWorklist(op->getOperands());
|
||||
op->walk([this](Operation *operation) {
|
||||
removeFromWorklist(operation);
|
||||
folder.notifyRemoval(operation);
|
||||
|
@ -523,22 +521,12 @@ public:
|
|||
|
||||
bool simplifyLocally(ArrayRef<Operation *> op);
|
||||
|
||||
private:
|
||||
// Look over the provided operands for any defining operations that should
|
||||
// be re-added to the worklist. This function should be called when an
|
||||
// operation is modified or removed, as it may trigger further
|
||||
// simplifications. If `strict` is set to true, only ops in
|
||||
// `strictModeFilteredOps` are considered.
|
||||
template <typename Operands>
|
||||
void addOperandsToWorklist(Operands &&operands) {
|
||||
for (Value operand : operands) {
|
||||
if (auto *defOp = operand.getDefiningOp()) {
|
||||
if (!strictMode || strictModeFilteredOps.contains(defOp))
|
||||
addToWorklist(defOp);
|
||||
}
|
||||
}
|
||||
void addToWorklist(Operation *op) override {
|
||||
if (!strictMode || strictModeFilteredOps.contains(op))
|
||||
GreedyPatternRewriteDriver::addToWorklist(op);
|
||||
}
|
||||
|
||||
private:
|
||||
void notifyOperationInserted(Operation *op) override {
|
||||
GreedyPatternRewriteDriver::notifyOperationInserted(op);
|
||||
if (strictMode)
|
||||
|
@ -551,15 +539,6 @@ private:
|
|||
strictModeFilteredOps.erase(op);
|
||||
}
|
||||
|
||||
void notifyRootReplaced(Operation *op) override {
|
||||
for (auto result : op->getResults()) {
|
||||
for (auto *user : result.getUsers()) {
|
||||
if (!strictMode || strictModeFilteredOps.contains(user))
|
||||
addToWorklist(user);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// If `strictMode` is true, any pre-existing ops outside of
|
||||
/// `strictModeFilteredOps` remain completely untouched by the rewrite driver.
|
||||
/// If `strictMode` is false, operations that use results of (or supply
|
||||
|
@ -633,22 +612,17 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
|
|||
|
||||
// Add all the users of the result to the worklist so we make sure
|
||||
// to revisit them.
|
||||
for (Value result : op->getResults())
|
||||
for (Operation *userOp : result.getUsers()) {
|
||||
if (!strictMode || strictModeFilteredOps.contains(userOp))
|
||||
addToWorklist(userOp);
|
||||
}
|
||||
for (Value result : op->getResults()) {
|
||||
for (Operation *userOp : result.getUsers())
|
||||
addToWorklist(userOp);
|
||||
}
|
||||
|
||||
notifyOperationRemoved(op);
|
||||
};
|
||||
|
||||
// Add the given operation generated by the folder to the worklist.
|
||||
auto processGeneratedConstants = [this](Operation *op) {
|
||||
// Newly created ops are also simplified -- these are also "local".
|
||||
addToWorklist(op);
|
||||
// When strict mode is off, we don't need to maintain
|
||||
// strictModeFilteredOps.
|
||||
if (strictMode)
|
||||
strictModeFilteredOps.insert(op);
|
||||
notifyOperationInserted(op);
|
||||
};
|
||||
|
||||
// Try to fold this op.
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
// RUN: mlir-opt -allow-unregistered-dialect -test-strict-pattern-driver %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @test_erase
|
||||
func.func @test_erase() {
|
||||
%0 = "test.arg0"() : () -> (i32)
|
||||
%1 = "test.arg1"() : () -> (i32)
|
||||
%erase = "test.erase_op"(%0, %1) : (i32, i32) -> (i32)
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_insert_same_op
|
||||
func.func @test_insert_same_op() {
|
||||
%0 = "test.insert_same_op"() : () -> (i32)
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_replace_with_same_op
|
||||
func.func @test_replace_with_same_op() {
|
||||
%0 = "test.replace_with_same_op"() : () -> (i32)
|
||||
%1 = "test.dummy_user"(%0) : (i32) -> (i32)
|
||||
%2 = "test.dummy_user"(%0) : (i32) -> (i32)
|
||||
return
|
||||
}
|
|
@ -176,6 +176,91 @@ struct TestPatternDriver
|
|||
llvm::cl::desc("Seed the worklist in general top-down order"),
|
||||
llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
|
||||
};
|
||||
|
||||
struct TestStrictPatternDriver
|
||||
: public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
|
||||
public:
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver)
|
||||
|
||||
TestStrictPatternDriver() = default;
|
||||
TestStrictPatternDriver(const TestStrictPatternDriver &other)
|
||||
: PassWrapper(other) {}
|
||||
|
||||
StringRef getArgument() const final { return "test-strict-pattern-driver"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Run strict mode of pattern driver";
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
patterns.add<InsertSameOp, ReplaceWithSameOp, EraseOp>(&getContext());
|
||||
SmallVector<Operation *> ops;
|
||||
getOperation()->walk([&](Operation *op) {
|
||||
StringRef opName = op->getName().getStringRef();
|
||||
if (opName == "test.insert_same_op" ||
|
||||
opName == "test.replace_with_same_op" || opName == "test.erase_op") {
|
||||
ops.push_back(op);
|
||||
}
|
||||
});
|
||||
|
||||
// Check if these transformations introduce visiting of operations that
|
||||
// are not in the `ops` set (The new created ops are valid). An invalid
|
||||
// operation will trigger the assertion while processing.
|
||||
(void)applyOpPatternsAndFold(makeArrayRef(ops), std::move(patterns),
|
||||
/*strict=*/true);
|
||||
}
|
||||
|
||||
private:
|
||||
// New inserted operation is valid for further transformation.
|
||||
class InsertSameOp : public RewritePattern {
|
||||
public:
|
||||
InsertSameOp(MLIRContext *context)
|
||||
: RewritePattern("test.insert_same_op", /*benefit=*/1, context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (op->hasAttr("skip"))
|
||||
return failure();
|
||||
|
||||
Operation *newOp =
|
||||
rewriter.create(op->getLoc(), op->getName().getIdentifier(),
|
||||
op->getOperands(), op->getResultTypes());
|
||||
op->setAttr("skip", rewriter.getBoolAttr(true));
|
||||
newOp->setAttr("skip", rewriter.getBoolAttr(true));
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Replace an operation may introduce the re-visiting of its users.
|
||||
class ReplaceWithSameOp : public RewritePattern {
|
||||
public:
|
||||
ReplaceWithSameOp(MLIRContext *context)
|
||||
: RewritePattern("test.replace_with_same_op", /*benefit=*/1, context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Operation *newOp =
|
||||
rewriter.create(op->getLoc(), op->getName().getIdentifier(),
|
||||
op->getOperands(), op->getResultTypes());
|
||||
rewriter.replaceOp(op, newOp->getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Remove an operation may introduce the re-visiting of its opreands.
|
||||
class EraseOp : public RewritePattern {
|
||||
public:
|
||||
EraseOp(MLIRContext *context)
|
||||
: RewritePattern("test.erase_op", /*benefit=*/1, context) {}
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1471,6 +1556,7 @@ void registerPatternsTestPass() {
|
|||
PassRegistration<TestDerivedAttributeDriver>();
|
||||
|
||||
PassRegistration<TestPatternDriver>();
|
||||
PassRegistration<TestStrictPatternDriver>();
|
||||
|
||||
PassRegistration<TestLegalizePatternDriver>([] {
|
||||
return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
|
||||
|
|
Loading…
Reference in New Issue