forked from OSchip/llvm-project
[MLIR] Handle in-place folding properly in greedy pattern rewrite driver
OperatioFolder::tryToFold performs both true folding and in a few instances in-place updates through op rewrites. In the latter case, we should still be applying the supplied pattern rewrites in the same iteration; however this wasn't the case since tryToFold returned success() for both true folding and in-place updates, and the patterns for the in-place updated ops were being applied only in the next iteration of the driver's outer loop. This fix would make it converge faster. Differential Revision: https://reviews.llvm.org/D77485
This commit is contained in:
parent
1318ddbc14
commit
cbcb12fd44
|
@ -56,11 +56,12 @@ public:
|
||||||
/// folded results, and returns success. `preReplaceAction` is invoked on `op`
|
/// folded results, and returns success. `preReplaceAction` is invoked on `op`
|
||||||
/// before it is replaced. 'processGeneratedConstants' is invoked for any new
|
/// before it is replaced. 'processGeneratedConstants' is invoked for any new
|
||||||
/// operations generated when folding. If the op was completely folded it is
|
/// operations generated when folding. If the op was completely folded it is
|
||||||
/// erased.
|
/// erased. If it is just updated in place, `inPlaceUpdate` is set to true.
|
||||||
LogicalResult
|
LogicalResult
|
||||||
tryToFold(Operation *op,
|
tryToFold(Operation *op,
|
||||||
function_ref<void(Operation *)> processGeneratedConstants = nullptr,
|
function_ref<void(Operation *)> processGeneratedConstants = nullptr,
|
||||||
function_ref<void(Operation *)> preReplaceAction = nullptr);
|
function_ref<void(Operation *)> preReplaceAction = nullptr,
|
||||||
|
bool *inPlaceUpdate = nullptr);
|
||||||
|
|
||||||
/// Notifies that the given constant `op` should be remove from this
|
/// Notifies that the given constant `op` should be remove from this
|
||||||
/// OperationFolder's internal bookkeeping.
|
/// OperationFolder's internal bookkeeping.
|
||||||
|
|
|
@ -74,7 +74,10 @@ static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
|
||||||
|
|
||||||
LogicalResult OperationFolder::tryToFold(
|
LogicalResult OperationFolder::tryToFold(
|
||||||
Operation *op, function_ref<void(Operation *)> processGeneratedConstants,
|
Operation *op, function_ref<void(Operation *)> processGeneratedConstants,
|
||||||
function_ref<void(Operation *)> preReplaceAction) {
|
function_ref<void(Operation *)> preReplaceAction, bool *inPlaceUpdate) {
|
||||||
|
if (inPlaceUpdate)
|
||||||
|
*inPlaceUpdate = false;
|
||||||
|
|
||||||
// If this is a unique'd constant, return failure as we know that it has
|
// If this is a unique'd constant, return failure as we know that it has
|
||||||
// already been folded.
|
// already been folded.
|
||||||
if (referencedDialects.count(op))
|
if (referencedDialects.count(op))
|
||||||
|
@ -87,8 +90,11 @@ LogicalResult OperationFolder::tryToFold(
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Check to see if the operation was just updated in place.
|
// Check to see if the operation was just updated in place.
|
||||||
if (results.empty())
|
if (results.empty()) {
|
||||||
|
if (inPlaceUpdate)
|
||||||
|
*inPlaceUpdate = true;
|
||||||
return success();
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
// Constant folding succeeded. We will start replacing this op's uses and
|
// Constant folding succeeded. We will start replacing this op's uses and
|
||||||
// erase this op. Invoke the callback provided by the caller to perform any
|
// erase this op. Invoke the callback provided by the caller to perform any
|
||||||
|
|
|
@ -104,7 +104,8 @@ private:
|
||||||
// be re-added to the worklist. This function should be called when an
|
// be re-added to the worklist. This function should be called when an
|
||||||
// operation is modified or removed, as it may trigger further
|
// operation is modified or removed, as it may trigger further
|
||||||
// simplifications.
|
// simplifications.
|
||||||
template <typename Operands> void addToWorklist(Operands &&operands) {
|
template <typename Operands>
|
||||||
|
void addToWorklist(Operands &&operands) {
|
||||||
for (Value operand : operands) {
|
for (Value operand : operands) {
|
||||||
// If the use count of this operand is now < 2, we re-add the defining
|
// If the use count of this operand is now < 2, we re-add the defining
|
||||||
// operation to the worklist.
|
// operation to the worklist.
|
||||||
|
@ -133,7 +134,8 @@ private:
|
||||||
};
|
};
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
||||||
/// Perform the rewrites while folding and erasing any dead ops.
|
/// Performs the rewrites while folding and erasing any dead ops. Returns true
|
||||||
|
/// if the rewrite converges in `maxIterations`.
|
||||||
bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
|
bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
|
||||||
int maxIterations) {
|
int maxIterations) {
|
||||||
// Add the given operation to the worklist.
|
// Add the given operation to the worklist.
|
||||||
|
@ -183,9 +185,12 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Try to fold this op.
|
// Try to fold this op.
|
||||||
if (succeeded(folder.tryToFold(op, collectOps, preReplaceAction))) {
|
bool inPlaceUpdate;
|
||||||
|
if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
|
||||||
|
&inPlaceUpdate)))) {
|
||||||
changed = true;
|
changed = true;
|
||||||
continue;
|
if (!inPlaceUpdate)
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure that any new operations are inserted at this point.
|
// Make sure that any new operations are inserted at this point.
|
||||||
|
|
Loading…
Reference in New Issue