[MLIR][SCF] Remove loop invariant arguments of scf.while

-- This commit adds a canonicalization pattern on scf.while to remove
   the loop invariant arguments.
-- An argument is considered loop invariant if the iteration argument value is
   the same as the corresponding one being yielded (at the same position) in both
   the before/after block of scf.while.
-- For the arguments removed, their use within scf.while and their corresponding
   scf.while's result are replaced with their corresponding initial value.

Signed-off-by: Abhishek Varma <abhishek.varma@polymagelabs.com>

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D116923
This commit is contained in:
Abhishek Varma 2022-02-03 17:09:51 +01:00 committed by Alex Zinenko
parent 42fc05e09c
commit 59b23c4aec
2 changed files with 362 additions and 2 deletions

View File

@ -2343,6 +2343,297 @@ struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
} }
}; };
/// Remove loop invariant arguments from `before` block of scf.while.
/// A before block argument is considered loop invariant if :-
/// 1. i-th yield operand is equal to the i-th while operand.
/// 2. i-th yield operand is k-th after block argument which is (k+1)-th
/// condition operand AND this (k+1)-th condition operand is equal to i-th
/// iter argument/while operand.
/// For the arguments which are removed, their uses inside scf.while
/// are replaced with their corresponding initial value.
///
/// Eg:
/// INPUT :-
/// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
/// ..., %argN_before = %N)
/// {
/// ...
/// scf.condition(%cond) %arg1_before, %arg0_before,
/// %arg2_before, %arg0_before, ...
/// } do {
/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
/// ..., %argK_after):
/// ...
/// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
/// }
///
/// OUTPUT :-
/// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
/// %N)
/// {
/// ...
/// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
/// } do {
/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
/// ..., %argK_after):
/// ...
/// scf.yield %arg1_after, ..., %argN
/// }
///
/// EXPLANATION:
/// We iterate over each yield operand.
/// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
/// %arg0_before, which in turn is the 0-th iter argument. So we
/// remove 0-th before block argument and yield operand, and replace
/// all uses of the 0-th before block argument with its initial value
/// %a.
/// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
/// value. So we remove this operand and the corresponding before
/// block argument and replace all uses of 1-th before block argument
/// with %b.
struct RemoveLoopInvariantArgsFromBeforeBlock
: public OpRewritePattern<WhileOp> {
using OpRewritePattern<WhileOp>::OpRewritePattern;
LogicalResult matchAndRewrite(WhileOp op,
PatternRewriter &rewriter) const override {
Block &afterBlock = op.getAfter().front();
Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
ConditionOp condOp = op.getConditionOp();
OperandRange condOpArgs = condOp.getArgs();
Operation *yieldOp = afterBlock.getTerminator();
ValueRange yieldOpArgs = yieldOp->getOperands();
bool canSimplify = false;
for (auto it : llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
auto index = static_cast<unsigned>(it.index());
Value initVal, yieldOpArg;
std::tie(initVal, yieldOpArg) = it.value();
// If i-th yield operand is equal to the i-th operand of the scf.while,
// the i-th before block argument is a loop invariant.
if (yieldOpArg == initVal) {
canSimplify = true;
break;
}
// If the i-th yield operand is k-th after block argument, then we check
// if the (k+1)-th condition op operand is equal to either the i-th before
// block argument or the initial value of i-th before block argument. If
// the comparison results `true`, i-th before block argument is a loop
// invariant.
auto yieldOpBlockArg = yieldOpArg.dyn_cast<BlockArgument>();
if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
canSimplify = true;
break;
}
}
}
if (!canSimplify)
return failure();
SmallVector<Value> newInitArgs, newYieldOpArgs;
DenseMap<unsigned, Value> beforeBlockInitValMap;
SmallVector<Location> newBeforeBlockArgLocs;
for (auto it : llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
auto index = static_cast<unsigned>(it.index());
Value initVal, yieldOpArg;
std::tie(initVal, yieldOpArg) = it.value();
// If i-th yield operand is equal to the i-th operand of the scf.while,
// the i-th before block argument is a loop invariant.
if (yieldOpArg == initVal) {
beforeBlockInitValMap.insert({index, initVal});
continue;
} else {
// If the i-th yield operand is k-th after block argument, then we check
// if the (k+1)-th condition op operand is equal to either the i-th
// before block argument or the initial value of i-th before block
// argument. If the comparison results `true`, i-th before block
// argument is a loop invariant.
auto yieldOpBlockArg = yieldOpArg.dyn_cast<BlockArgument>();
if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
beforeBlockInitValMap.insert({index, initVal});
continue;
}
}
}
newInitArgs.emplace_back(initVal);
newYieldOpArgs.emplace_back(yieldOpArg);
newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
}
{
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(yieldOp);
rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
}
auto newWhile =
rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
Block &newBeforeBlock = *rewriter.createBlock(
&newWhile.getBefore(), /*insertPt*/ {},
ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
Block &beforeBlock = op.getBefore().front();
SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
// For each i-th before block argument we find it's replacement value as :-
// 1. If i-th before block argument is a loop invariant, we fetch it's
// initial value from `beforeBlockInitValMap` by querying for key `i`.
// 2. Else we fetch j-th new before block argument as the replacement
// value of i-th before block argument.
for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
// If the index 'i' argument was a loop invariant we fetch it's initial
// value from `beforeBlockInitValMap`.
if (beforeBlockInitValMap.count(i) != 0)
newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
else
newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
}
rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
newWhile.getAfter().begin());
rewriter.replaceOp(op, newWhile.getResults());
return success();
}
};
/// Remove loop invariant value from result (condition op) of scf.while.
/// A value is considered loop invariant if the final value yielded by
/// scf.condition is defined outside of the `before` block. We remove the
/// corresponding argument in `after` block and replace the use with the value.
/// We also replace the use of the corresponding result of scf.while with the
/// value.
///
/// Eg:
/// INPUT :-
/// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
/// %argN_before = %N) {
/// ...
/// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
/// } do {
/// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
/// ...
/// some_func(%arg1_after)
/// ...
/// scf.yield %arg0_after, %arg2_after, ..., %argN_after
/// }
///
/// OUTPUT :-
/// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
/// ...
/// scf.condition(%cond) %arg0, %arg1, ..., %argM
/// } do {
/// ^bb0(%arg0, %arg3, ..., %argM):
/// ...
/// some_func(%a)
/// ...
/// scf.yield %arg0, %b, ..., %argN
/// }
///
/// EXPLANATION:
/// 1. The 1-th and 2-th operand of scf.condition are defined outside the
/// before block of scf.while, so they get removed.
/// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
/// replaced by %b.
/// 3. The corresponding after block argument %arg1_after's uses are
/// replaced by %a and %arg2_after's uses are replaced by %b.
struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
using OpRewritePattern<WhileOp>::OpRewritePattern;
LogicalResult matchAndRewrite(WhileOp op,
PatternRewriter &rewriter) const override {
Block &beforeBlock = op.getBefore().front();
ConditionOp condOp = op.getConditionOp();
OperandRange condOpArgs = condOp.getArgs();
bool canSimplify = false;
for (Value condOpArg : condOpArgs) {
// Those values not defined within `before` block will be considered as
// loop invariant values. We map the corresponding `index` with their
// value.
if (condOpArg.getParentBlock() != &beforeBlock) {
canSimplify = true;
break;
}
}
if (!canSimplify)
return failure();
Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
SmallVector<Value> newCondOpArgs;
SmallVector<Type> newAfterBlockType;
DenseMap<unsigned, Value> condOpInitValMap;
SmallVector<Location> newAfterBlockArgLocs;
for (auto it : llvm::enumerate(condOpArgs)) {
auto index = static_cast<unsigned>(it.index());
Value condOpArg = it.value();
// Those values not defined within `before` block will be considered as
// loop invariant values. We map the corresponding `index` with their
// value.
if (condOpArg.getParentBlock() != &beforeBlock) {
condOpInitValMap.insert({index, condOpArg});
} else {
newCondOpArgs.emplace_back(condOpArg);
newAfterBlockType.emplace_back(condOpArg.getType());
newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
}
}
{
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(condOp);
rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
newCondOpArgs);
}
auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
op.getOperands());
Block &newAfterBlock =
*rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
newAfterBlockType, newAfterBlockArgLocs);
Block &afterBlock = op.getAfter().front();
// Since a new scf.condition op was created, we need to fetch the new
// `after` block arguments which will be used while replacing operations of
// previous scf.while's `after` blocks. We'd also be fetching new result
// values too.
SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
Value afterBlockArg, result;
// If index 'i' argument was loop invariant we fetch it's value from the
// `condOpInitMap` map.
if (condOpInitValMap.count(i) != 0) {
afterBlockArg = condOpInitValMap[i];
result = afterBlockArg;
} else {
afterBlockArg = newAfterBlock.getArgument(j);
result = newWhile.getResult(j);
j++;
}
newAfterBlockArgs[i] = afterBlockArg;
newWhileResults[i] = result;
}
rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
newWhile.getBefore().begin());
rewriter.replaceOp(op, newWhileResults);
return success();
}
};
/// Remove WhileOp results that are also unused in 'after' block. /// Remove WhileOp results that are also unused in 'after' block.
/// ///
/// %0:2 = scf.while () : () -> (i32, i64) { /// %0:2 = scf.while () : () -> (i32, i64) {
@ -2552,8 +2843,9 @@ struct WhileUnusedArg : public OpRewritePattern<WhileOp> {
void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results, void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { MLIRContext *context) {
results.insert<WhileConditionTruth, WhileUnusedResult, WhileCmpCond, results.insert<RemoveLoopInvariantArgsFromBeforeBlock,
WhileUnusedArg>(context); RemoveLoopInvariantValueYielded, WhileConditionTruth,
WhileCmpCond, WhileUnusedResult>(context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -870,6 +870,74 @@ func @while_unused_arg(%x : i32, %y : f64) -> i32 {
// ----- // -----
// CHECK-LABEL: @invariant_loop_args_in_same_order
// CHECK-SAME: (%[[FUNC_ARG0:.*]]: tensor<i32>)
func @invariant_loop_args_in_same_order(%f_arg0: tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
%cst_0 = arith.constant dense<0> : tensor<i32>
%cst_1 = arith.constant dense<1> : tensor<i32>
%cst_42 = arith.constant dense<42> : tensor<i32>
%0:5 = scf.while (%arg0 = %cst_0, %arg1 = %f_arg0, %arg2 = %cst_1, %arg3 = %cst_1, %arg4 = %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
%1 = arith.cmpi slt, %arg0, %cst_42 : tensor<i32>
%2 = tensor.extract %1[] : tensor<i1>
scf.condition(%2) %arg0, %arg1, %arg2, %arg3, %arg4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
} do {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors
// %arg1 here will get replaced by %cst_1
%1 = arith.addi %arg0, %arg1 : tensor<i32>
%2 = arith.addi %arg2, %arg3 : tensor<i32>
scf.yield %1, %arg1, %2, %2, %arg4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}
return %0#0, %0#1, %0#2, %0#3, %0#4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}
// CHECK: %[[CST42:.*]] = arith.constant dense<42>
// CHECK: %[[ONE:.*]] = arith.constant dense<1>
// CHECK: %[[ZERO:.*]] = arith.constant dense<0>
// CHECK: %[[WHILE:.*]]:3 = scf.while (%[[ARG0:.*]] = %[[ZERO]], %[[ARG2:.*]] = %[[ONE]], %[[ARG3:.*]] = %[[ONE]])
// CHECK: arith.cmpi slt, %[[ARG0]], %{{.*}}
// CHECK: tensor.extract %{{.*}}[]
// CHECK: scf.condition(%{{.*}}) %[[ARG0]], %[[ARG2]], %[[ARG3]]
// CHECK: } do {
// CHECK: ^{{.*}}(%[[ARG0:.*]]: tensor<i32>, %[[ARG2:.*]]: tensor<i32>, %[[ARG3:.*]]: tensor<i32>):
// CHECK: %[[VAL0:.*]] = arith.addi %[[ARG0]], %[[FUNC_ARG0]]
// CHECK: %[[VAL1:.*]] = arith.addi %[[ARG2]], %[[ARG3]]
// CHECK: scf.yield %[[VAL0]], %[[VAL1]], %[[VAL1]]
// CHECK: }
// CHECK: return %[[WHILE]]#0, %[[FUNC_ARG0]], %[[WHILE]]#1, %[[WHILE]]#2, %[[ZERO]]
// CHECK-LABEL: @while_loop_invariant_argument_different_order
func @while_loop_invariant_argument_different_order() -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
%cst_0 = arith.constant dense<0> : tensor<i32>
%cst_1 = arith.constant dense<1> : tensor<i32>
%cst_42 = arith.constant dense<42> : tensor<i32>
%0:6 = scf.while (%arg0 = %cst_0, %arg1 = %cst_1, %arg2 = %cst_1, %arg3 = %cst_1, %arg4 = %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
%1 = arith.cmpi slt, %arg0, %cst_42 : tensor<i32>
%2 = tensor.extract %1[] : tensor<i1>
scf.condition(%2) %arg1, %arg0, %arg2, %arg0, %arg3, %arg4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
} do {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>, %arg5: tensor<i32>): // no predecessors
%1 = arith.addi %arg0, %cst_1 : tensor<i32>
%2 = arith.addi %arg2, %arg3 : tensor<i32>
scf.yield %arg3, %arg1, %2, %2, %arg4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}
return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}
// CHECK: %[[CST42:.*]] = arith.constant dense<42>
// CHECK: %[[ONE:.*]] = arith.constant dense<1>
// CHECK: %[[ZERO:.*]] = arith.constant dense<0>
// CHECK: %[[WHILE:.*]]:2 = scf.while (%[[ARG1:.*]] = %[[ONE]], %[[ARG4:.*]] = %[[ZERO]])
// CHECK: arith.cmpi slt, %[[ZERO]], %[[CST42]]
// CHECK: tensor.extract %{{.*}}[]
// CHECK: scf.condition(%{{.*}}) %[[ARG1]], %[[ARG4]]
// CHECK: } do {
// CHECK: ^{{.*}}(%{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>):
// CHECK: scf.yield %[[ZERO]], %[[ONE]]
// CHECK: }
// CHECK: return %[[WHILE]]#0, %[[ZERO]], %[[ONE]], %[[ZERO]], %[[ONE]], %[[WHILE]]#1
// -----
// CHECK-LABEL: @while_unused_result // CHECK-LABEL: @while_unused_result
func @while_unused_result() -> i32 { func @while_unused_result() -> i32 {
%0:2 = scf.while () : () -> (i32, i64) { %0:2 = scf.while () : () -> (i32, i64) {