[MLIR] Handle in-place folding properly in greedy pattern rewrite driver

OperatioFolder::tryToFold performs both true folding and in a few
instances in-place updates through op rewrites. In the latter case, we
should still be applying the supplied pattern rewrites in the same
iteration; however this wasn't the case since tryToFold returned
success() for both true folding and in-place updates, and the patterns
for the in-place updated ops were being applied only in the next
iteration of the driver's outer loop. This fix would make it converge
faster.

Differential Revision: https://reviews.llvm.org/D77485
This commit is contained in:
Uday Bondhugula 2020-04-05 10:05:52 +05:30
parent 1318ddbc14
commit cbcb12fd44
3 changed files with 20 additions and 8 deletions

View File

@ -56,11 +56,12 @@ public:
/// folded results, and returns success. `preReplaceAction` is invoked on `op` /// folded results, and returns success. `preReplaceAction` is invoked on `op`
/// before it is replaced. 'processGeneratedConstants' is invoked for any new /// before it is replaced. 'processGeneratedConstants' is invoked for any new
/// operations generated when folding. If the op was completely folded it is /// operations generated when folding. If the op was completely folded it is
/// erased. /// erased. If it is just updated in place, `inPlaceUpdate` is set to true.
LogicalResult LogicalResult
tryToFold(Operation *op, tryToFold(Operation *op,
function_ref<void(Operation *)> processGeneratedConstants = nullptr, function_ref<void(Operation *)> processGeneratedConstants = nullptr,
function_ref<void(Operation *)> preReplaceAction = nullptr); function_ref<void(Operation *)> preReplaceAction = nullptr,
bool *inPlaceUpdate = nullptr);
/// Notifies that the given constant `op` should be remove from this /// Notifies that the given constant `op` should be remove from this
/// OperationFolder's internal bookkeeping. /// OperationFolder's internal bookkeeping.

View File

@ -74,7 +74,10 @@ static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
LogicalResult OperationFolder::tryToFold( LogicalResult OperationFolder::tryToFold(
Operation *op, function_ref<void(Operation *)> processGeneratedConstants, Operation *op, function_ref<void(Operation *)> processGeneratedConstants,
function_ref<void(Operation *)> preReplaceAction) { function_ref<void(Operation *)> preReplaceAction, bool *inPlaceUpdate) {
if (inPlaceUpdate)
*inPlaceUpdate = false;
// If this is a unique'd constant, return failure as we know that it has // If this is a unique'd constant, return failure as we know that it has
// already been folded. // already been folded.
if (referencedDialects.count(op)) if (referencedDialects.count(op))
@ -87,8 +90,11 @@ LogicalResult OperationFolder::tryToFold(
return failure(); return failure();
// Check to see if the operation was just updated in place. // Check to see if the operation was just updated in place.
if (results.empty()) if (results.empty()) {
if (inPlaceUpdate)
*inPlaceUpdate = true;
return success(); return success();
}
// Constant folding succeeded. We will start replacing this op's uses and // Constant folding succeeded. We will start replacing this op's uses and
// erase this op. Invoke the callback provided by the caller to perform any // erase this op. Invoke the callback provided by the caller to perform any

View File

@ -104,7 +104,8 @@ private:
// be re-added to the worklist. This function should be called when an // be re-added to the worklist. This function should be called when an
// operation is modified or removed, as it may trigger further // operation is modified or removed, as it may trigger further
// simplifications. // simplifications.
template <typename Operands> void addToWorklist(Operands &&operands) { template <typename Operands>
void addToWorklist(Operands &&operands) {
for (Value operand : operands) { for (Value operand : operands) {
// If the use count of this operand is now < 2, we re-add the defining // If the use count of this operand is now < 2, we re-add the defining
// operation to the worklist. // operation to the worklist.
@ -133,7 +134,8 @@ private:
}; };
} // end anonymous namespace } // end anonymous namespace
/// Perform the rewrites while folding and erasing any dead ops. /// Performs the rewrites while folding and erasing any dead ops. Returns true
/// if the rewrite converges in `maxIterations`.
bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions, bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
int maxIterations) { int maxIterations) {
// Add the given operation to the worklist. // Add the given operation to the worklist.
@ -183,9 +185,12 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
}; };
// Try to fold this op. // Try to fold this op.
if (succeeded(folder.tryToFold(op, collectOps, preReplaceAction))) { bool inPlaceUpdate;
if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
&inPlaceUpdate)))) {
changed = true; changed = true;
continue; if (!inPlaceUpdate)
continue;
} }
// Make sure that any new operations are inserted at this point. // Make sure that any new operations are inserted at this point.