From cbcb12fd44dfdb51bbf4489d213d96f17be3091f Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Sun, 5 Apr 2020 10:05:52 +0530 Subject: [PATCH] [MLIR] Handle in-place folding properly in greedy pattern rewrite driver OperatioFolder::tryToFold performs both true folding and in a few instances in-place updates through op rewrites. In the latter case, we should still be applying the supplied pattern rewrites in the same iteration; however this wasn't the case since tryToFold returned success() for both true folding and in-place updates, and the patterns for the in-place updated ops were being applied only in the next iteration of the driver's outer loop. This fix would make it converge faster. Differential Revision: https://reviews.llvm.org/D77485 --- mlir/include/mlir/Transforms/FoldUtils.h | 5 +++-- mlir/lib/Transforms/Utils/FoldUtils.cpp | 10 ++++++++-- .../Transforms/Utils/GreedyPatternRewriteDriver.cpp | 13 +++++++++---- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h index 0bab87c5e4e3..d2ba43339ce3 100644 --- a/mlir/include/mlir/Transforms/FoldUtils.h +++ b/mlir/include/mlir/Transforms/FoldUtils.h @@ -56,11 +56,12 @@ public: /// folded results, and returns success. `preReplaceAction` is invoked on `op` /// before it is replaced. 'processGeneratedConstants' is invoked for any new /// operations generated when folding. If the op was completely folded it is - /// erased. + /// erased. If it is just updated in place, `inPlaceUpdate` is set to true. LogicalResult tryToFold(Operation *op, function_ref processGeneratedConstants = nullptr, - function_ref preReplaceAction = nullptr); + function_ref preReplaceAction = nullptr, + bool *inPlaceUpdate = nullptr); /// Notifies that the given constant `op` should be remove from this /// OperationFolder's internal bookkeeping. diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp index f2099bca75ea..9e67c2b6b348 100644 --- a/mlir/lib/Transforms/Utils/FoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp @@ -74,7 +74,10 @@ static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder, LogicalResult OperationFolder::tryToFold( Operation *op, function_ref processGeneratedConstants, - function_ref preReplaceAction) { + function_ref preReplaceAction, bool *inPlaceUpdate) { + if (inPlaceUpdate) + *inPlaceUpdate = false; + // If this is a unique'd constant, return failure as we know that it has // already been folded. if (referencedDialects.count(op)) @@ -87,8 +90,11 @@ LogicalResult OperationFolder::tryToFold( return failure(); // Check to see if the operation was just updated in place. - if (results.empty()) + if (results.empty()) { + if (inPlaceUpdate) + *inPlaceUpdate = true; return success(); + } // Constant folding succeeded. We will start replacing this op's uses and // erase this op. Invoke the callback provided by the caller to perform any diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 80ad143ce0d3..53c8e9fbd1c2 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -104,7 +104,8 @@ private: // be re-added to the worklist. This function should be called when an // operation is modified or removed, as it may trigger further // simplifications. - template void addToWorklist(Operands &&operands) { + template + void addToWorklist(Operands &&operands) { for (Value operand : operands) { // If the use count of this operand is now < 2, we re-add the defining // operation to the worklist. @@ -133,7 +134,8 @@ private: }; } // end anonymous namespace -/// Perform the rewrites while folding and erasing any dead ops. +/// Performs the rewrites while folding and erasing any dead ops. Returns true +/// if the rewrite converges in `maxIterations`. bool GreedyPatternRewriteDriver::simplify(MutableArrayRef regions, int maxIterations) { // Add the given operation to the worklist. @@ -183,9 +185,12 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef regions, }; // Try to fold this op. - if (succeeded(folder.tryToFold(op, collectOps, preReplaceAction))) { + bool inPlaceUpdate; + if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction, + &inPlaceUpdate)))) { changed = true; - continue; + if (!inPlaceUpdate) + continue; } // Make sure that any new operations are inserted at this point.