[MLIR] Handle in-place folding properly in greedy pattern rewrite driver

OperatioFolder::tryToFold performs both true folding and in a few
instances in-place updates through op rewrites. In the latter case, we
should still be applying the supplied pattern rewrites in the same
iteration; however this wasn't the case since tryToFold returned
success() for both true folding and in-place updates, and the patterns
for the in-place updated ops were being applied only in the next
iteration of the driver's outer loop. This fix would make it converge
faster.

Differential Revision: https://reviews.llvm.org/D77485
This commit is contained in:
Uday Bondhugula 2020-04-05 10:05:52 +05:30
parent 1318ddbc14
commit cbcb12fd44
3 changed files with 20 additions and 8 deletions

View File

@ -56,11 +56,12 @@ public:
/// folded results, and returns success. `preReplaceAction` is invoked on `op`
/// before it is replaced. 'processGeneratedConstants' is invoked for any new
/// operations generated when folding. If the op was completely folded it is
/// erased.
/// erased. If it is just updated in place, `inPlaceUpdate` is set to true.
LogicalResult
tryToFold(Operation *op,
function_ref<void(Operation *)> processGeneratedConstants = nullptr,
function_ref<void(Operation *)> preReplaceAction = nullptr);
function_ref<void(Operation *)> preReplaceAction = nullptr,
bool *inPlaceUpdate = nullptr);
/// Notifies that the given constant `op` should be remove from this
/// OperationFolder's internal bookkeeping.

View File

@ -74,7 +74,10 @@ static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
LogicalResult OperationFolder::tryToFold(
Operation *op, function_ref<void(Operation *)> processGeneratedConstants,
function_ref<void(Operation *)> preReplaceAction) {
function_ref<void(Operation *)> preReplaceAction, bool *inPlaceUpdate) {
if (inPlaceUpdate)
*inPlaceUpdate = false;
// If this is a unique'd constant, return failure as we know that it has
// already been folded.
if (referencedDialects.count(op))
@ -87,8 +90,11 @@ LogicalResult OperationFolder::tryToFold(
return failure();
// Check to see if the operation was just updated in place.
if (results.empty())
if (results.empty()) {
if (inPlaceUpdate)
*inPlaceUpdate = true;
return success();
}
// Constant folding succeeded. We will start replacing this op's uses and
// erase this op. Invoke the callback provided by the caller to perform any

View File

@ -104,7 +104,8 @@ private:
// be re-added to the worklist. This function should be called when an
// operation is modified or removed, as it may trigger further
// simplifications.
template <typename Operands> void addToWorklist(Operands &&operands) {
template <typename Operands>
void addToWorklist(Operands &&operands) {
for (Value operand : operands) {
// If the use count of this operand is now < 2, we re-add the defining
// operation to the worklist.
@ -133,7 +134,8 @@ private:
};
} // end anonymous namespace
/// Perform the rewrites while folding and erasing any dead ops.
/// Performs the rewrites while folding and erasing any dead ops. Returns true
/// if the rewrite converges in `maxIterations`.
bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
int maxIterations) {
// Add the given operation to the worklist.
@ -183,9 +185,12 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
};
// Try to fold this op.
if (succeeded(folder.tryToFold(op, collectOps, preReplaceAction))) {
bool inPlaceUpdate;
if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
&inPlaceUpdate)))) {
changed = true;
continue;
if (!inPlaceUpdate)
continue;
}
// Make sure that any new operations are inserted at this point.