[mlir:MultiOpDriver] Add operands to worklist should be checked

Operand's defining op may not be valid for adding to the worklist under
stict mode

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D127180
This commit is contained in:
Chia-hung Duan 2022-06-11 15:56:21 +00:00
parent ff80dc8544
commit ba3a9f51ff
3 changed files with 124 additions and 41 deletions

View File

@ -43,7 +43,7 @@ public:
bool simplify(MutableArrayRef<Region> regions);
/// Add the given operation to the worklist.
void addToWorklist(Operation *op);
virtual void addToWorklist(Operation *op);
/// Pop the next operation from the worklist.
Operation *popFromWorklist();
@ -60,8 +60,7 @@ protected:
// be re-added to the worklist. This function should be called when an
// operation is modified or removed, as it may trigger further
// simplifications.
template <typename Operands>
void addToWorklist(Operands &&operands);
void addOperandsToWorklist(ValueRange operands);
// If an operation is about to be removed, make sure it is not in our
// worklist anymore because we'd get dangling references to it.
@ -219,7 +218,7 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
originalOperands.assign(op->operand_begin(), op->operand_end());
auto preReplaceAction = [&](Operation *op) {
// Add the operands to the worklist for visitation.
addToWorklist(originalOperands);
addOperandsToWorklist(originalOperands);
// Add all the users of the result to the worklist so we make sure
// to revisit them.
@ -327,8 +326,7 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
addToWorklist(op);
}
template <typename Operands>
void GreedyPatternRewriteDriver::addToWorklist(Operands &&operands) {
void GreedyPatternRewriteDriver::addOperandsToWorklist(ValueRange operands) {
for (Value operand : operands) {
// If the use count of this operand is now < 2, we re-add the defining
// operation to the worklist.
@ -343,7 +341,7 @@ void GreedyPatternRewriteDriver::addToWorklist(Operands &&operands) {
}
void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
addToWorklist(op->getOperands());
addOperandsToWorklist(op->getOperands());
op->walk([this](Operation *operation) {
removeFromWorklist(operation);
folder.notifyRemoval(operation);
@ -523,22 +521,12 @@ public:
bool simplifyLocally(ArrayRef<Operation *> op);
private:
// Look over the provided operands for any defining operations that should
// be re-added to the worklist. This function should be called when an
// operation is modified or removed, as it may trigger further
// simplifications. If `strict` is set to true, only ops in
// `strictModeFilteredOps` are considered.
template <typename Operands>
void addOperandsToWorklist(Operands &&operands) {
for (Value operand : operands) {
if (auto *defOp = operand.getDefiningOp()) {
if (!strictMode || strictModeFilteredOps.contains(defOp))
addToWorklist(defOp);
}
}
void addToWorklist(Operation *op) override {
if (!strictMode || strictModeFilteredOps.contains(op))
GreedyPatternRewriteDriver::addToWorklist(op);
}
private:
void notifyOperationInserted(Operation *op) override {
GreedyPatternRewriteDriver::notifyOperationInserted(op);
if (strictMode)
@ -551,15 +539,6 @@ private:
strictModeFilteredOps.erase(op);
}
void notifyRootReplaced(Operation *op) override {
for (auto result : op->getResults()) {
for (auto *user : result.getUsers()) {
if (!strictMode || strictModeFilteredOps.contains(user))
addToWorklist(user);
}
}
}
/// If `strictMode` is true, any pre-existing ops outside of
/// `strictModeFilteredOps` remain completely untouched by the rewrite driver.
/// If `strictMode` is false, operations that use results of (or supply
@ -633,22 +612,17 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
// Add all the users of the result to the worklist so we make sure
// to revisit them.
for (Value result : op->getResults())
for (Operation *userOp : result.getUsers()) {
if (!strictMode || strictModeFilteredOps.contains(userOp))
addToWorklist(userOp);
}
for (Value result : op->getResults()) {
for (Operation *userOp : result.getUsers())
addToWorklist(userOp);
}
notifyOperationRemoved(op);
};
// Add the given operation generated by the folder to the worklist.
auto processGeneratedConstants = [this](Operation *op) {
// Newly created ops are also simplified -- these are also "local".
addToWorklist(op);
// When strict mode is off, we don't need to maintain
// strictModeFilteredOps.
if (strictMode)
strictModeFilteredOps.insert(op);
notifyOperationInserted(op);
};
// Try to fold this op.

View File

@ -0,0 +1,23 @@
// RUN: mlir-opt -allow-unregistered-dialect -test-strict-pattern-driver %s | FileCheck %s
// CHECK-LABEL: @test_erase
func.func @test_erase() {
%0 = "test.arg0"() : () -> (i32)
%1 = "test.arg1"() : () -> (i32)
%erase = "test.erase_op"(%0, %1) : (i32, i32) -> (i32)
return
}
// CHECK-LABEL: @test_insert_same_op
func.func @test_insert_same_op() {
%0 = "test.insert_same_op"() : () -> (i32)
return
}
// CHECK-LABEL: @test_replace_with_same_op
func.func @test_replace_with_same_op() {
%0 = "test.replace_with_same_op"() : () -> (i32)
%1 = "test.dummy_user"(%0) : (i32) -> (i32)
%2 = "test.dummy_user"(%0) : (i32) -> (i32)
return
}

View File

@ -176,6 +176,91 @@ struct TestPatternDriver
llvm::cl::desc("Seed the worklist in general top-down order"),
llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
};
struct TestStrictPatternDriver
: public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver)
TestStrictPatternDriver() = default;
TestStrictPatternDriver(const TestStrictPatternDriver &other)
: PassWrapper(other) {}
StringRef getArgument() const final { return "test-strict-pattern-driver"; }
StringRef getDescription() const final {
return "Run strict mode of pattern driver";
}
void runOnOperation() override {
mlir::RewritePatternSet patterns(&getContext());
patterns.add<InsertSameOp, ReplaceWithSameOp, EraseOp>(&getContext());
SmallVector<Operation *> ops;
getOperation()->walk([&](Operation *op) {
StringRef opName = op->getName().getStringRef();
if (opName == "test.insert_same_op" ||
opName == "test.replace_with_same_op" || opName == "test.erase_op") {
ops.push_back(op);
}
});
// Check if these transformations introduce visiting of operations that
// are not in the `ops` set (The new created ops are valid). An invalid
// operation will trigger the assertion while processing.
(void)applyOpPatternsAndFold(makeArrayRef(ops), std::move(patterns),
/*strict=*/true);
}
private:
// New inserted operation is valid for further transformation.
class InsertSameOp : public RewritePattern {
public:
InsertSameOp(MLIRContext *context)
: RewritePattern("test.insert_same_op", /*benefit=*/1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (op->hasAttr("skip"))
return failure();
Operation *newOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(),
op->getOperands(), op->getResultTypes());
op->setAttr("skip", rewriter.getBoolAttr(true));
newOp->setAttr("skip", rewriter.getBoolAttr(true));
return success();
}
};
// Replace an operation may introduce the re-visiting of its users.
class ReplaceWithSameOp : public RewritePattern {
public:
ReplaceWithSameOp(MLIRContext *context)
: RewritePattern("test.replace_with_same_op", /*benefit=*/1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
Operation *newOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(),
op->getOperands(), op->getResultTypes());
rewriter.replaceOp(op, newOp->getResults());
return success();
}
};
// Remove an operation may introduce the re-visiting of its opreands.
class EraseOp : public RewritePattern {
public:
EraseOp(MLIRContext *context)
: RewritePattern("test.erase_op", /*benefit=*/1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
rewriter.eraseOp(op);
return success();
}
};
};
} // namespace
//===----------------------------------------------------------------------===//
@ -1471,6 +1556,7 @@ void registerPatternsTestPass() {
PassRegistration<TestDerivedAttributeDriver>();
PassRegistration<TestPatternDriver>();
PassRegistration<TestStrictPatternDriver>();
PassRegistration<TestLegalizePatternDriver>([] {
return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);