forked from OSchip/llvm-project
[MLIR][SCF] Remove loop invariant arguments of scf.while
-- This commit adds a canonicalization pattern on scf.while to remove the loop invariant arguments. -- An argument is considered loop invariant if the iteration argument value is the same as the corresponding one being yielded (at the same position) in both the before/after block of scf.while. -- For the arguments removed, their use within scf.while and their corresponding scf.while's result are replaced with their corresponding initial value. Signed-off-by: Abhishek Varma <abhishek.varma@polymagelabs.com> Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D116923
This commit is contained in:
parent
42fc05e09c
commit
59b23c4aec
|
@ -2343,6 +2343,297 @@ struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
|
|||
}
|
||||
};
|
||||
|
||||
/// Remove loop invariant arguments from `before` block of scf.while.
|
||||
/// A before block argument is considered loop invariant if :-
|
||||
/// 1. i-th yield operand is equal to the i-th while operand.
|
||||
/// 2. i-th yield operand is k-th after block argument which is (k+1)-th
|
||||
/// condition operand AND this (k+1)-th condition operand is equal to i-th
|
||||
/// iter argument/while operand.
|
||||
/// For the arguments which are removed, their uses inside scf.while
|
||||
/// are replaced with their corresponding initial value.
|
||||
///
|
||||
/// Eg:
|
||||
/// INPUT :-
|
||||
/// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
|
||||
/// ..., %argN_before = %N)
|
||||
/// {
|
||||
/// ...
|
||||
/// scf.condition(%cond) %arg1_before, %arg0_before,
|
||||
/// %arg2_before, %arg0_before, ...
|
||||
/// } do {
|
||||
/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
|
||||
/// ..., %argK_after):
|
||||
/// ...
|
||||
/// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
|
||||
/// }
|
||||
///
|
||||
/// OUTPUT :-
|
||||
/// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
|
||||
/// %N)
|
||||
/// {
|
||||
/// ...
|
||||
/// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
|
||||
/// } do {
|
||||
/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
|
||||
/// ..., %argK_after):
|
||||
/// ...
|
||||
/// scf.yield %arg1_after, ..., %argN
|
||||
/// }
|
||||
///
|
||||
/// EXPLANATION:
|
||||
/// We iterate over each yield operand.
|
||||
/// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
|
||||
/// %arg0_before, which in turn is the 0-th iter argument. So we
|
||||
/// remove 0-th before block argument and yield operand, and replace
|
||||
/// all uses of the 0-th before block argument with its initial value
|
||||
/// %a.
|
||||
/// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
|
||||
/// value. So we remove this operand and the corresponding before
|
||||
/// block argument and replace all uses of 1-th before block argument
|
||||
/// with %b.
|
||||
struct RemoveLoopInvariantArgsFromBeforeBlock
|
||||
: public OpRewritePattern<WhileOp> {
|
||||
using OpRewritePattern<WhileOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(WhileOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Block &afterBlock = op.getAfter().front();
|
||||
Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
|
||||
ConditionOp condOp = op.getConditionOp();
|
||||
OperandRange condOpArgs = condOp.getArgs();
|
||||
Operation *yieldOp = afterBlock.getTerminator();
|
||||
ValueRange yieldOpArgs = yieldOp->getOperands();
|
||||
|
||||
bool canSimplify = false;
|
||||
for (auto it : llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
|
||||
auto index = static_cast<unsigned>(it.index());
|
||||
Value initVal, yieldOpArg;
|
||||
std::tie(initVal, yieldOpArg) = it.value();
|
||||
// If i-th yield operand is equal to the i-th operand of the scf.while,
|
||||
// the i-th before block argument is a loop invariant.
|
||||
if (yieldOpArg == initVal) {
|
||||
canSimplify = true;
|
||||
break;
|
||||
}
|
||||
// If the i-th yield operand is k-th after block argument, then we check
|
||||
// if the (k+1)-th condition op operand is equal to either the i-th before
|
||||
// block argument or the initial value of i-th before block argument. If
|
||||
// the comparison results `true`, i-th before block argument is a loop
|
||||
// invariant.
|
||||
auto yieldOpBlockArg = yieldOpArg.dyn_cast<BlockArgument>();
|
||||
if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
|
||||
Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
|
||||
if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
|
||||
canSimplify = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!canSimplify)
|
||||
return failure();
|
||||
|
||||
SmallVector<Value> newInitArgs, newYieldOpArgs;
|
||||
DenseMap<unsigned, Value> beforeBlockInitValMap;
|
||||
SmallVector<Location> newBeforeBlockArgLocs;
|
||||
for (auto it : llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
|
||||
auto index = static_cast<unsigned>(it.index());
|
||||
Value initVal, yieldOpArg;
|
||||
std::tie(initVal, yieldOpArg) = it.value();
|
||||
|
||||
// If i-th yield operand is equal to the i-th operand of the scf.while,
|
||||
// the i-th before block argument is a loop invariant.
|
||||
if (yieldOpArg == initVal) {
|
||||
beforeBlockInitValMap.insert({index, initVal});
|
||||
continue;
|
||||
} else {
|
||||
// If the i-th yield operand is k-th after block argument, then we check
|
||||
// if the (k+1)-th condition op operand is equal to either the i-th
|
||||
// before block argument or the initial value of i-th before block
|
||||
// argument. If the comparison results `true`, i-th before block
|
||||
// argument is a loop invariant.
|
||||
auto yieldOpBlockArg = yieldOpArg.dyn_cast<BlockArgument>();
|
||||
if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
|
||||
Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
|
||||
if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
|
||||
beforeBlockInitValMap.insert({index, initVal});
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
newInitArgs.emplace_back(initVal);
|
||||
newYieldOpArgs.emplace_back(yieldOpArg);
|
||||
newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
|
||||
}
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard g(rewriter);
|
||||
rewriter.setInsertionPoint(yieldOp);
|
||||
rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
|
||||
}
|
||||
|
||||
auto newWhile =
|
||||
rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
|
||||
|
||||
Block &newBeforeBlock = *rewriter.createBlock(
|
||||
&newWhile.getBefore(), /*insertPt*/ {},
|
||||
ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
|
||||
|
||||
Block &beforeBlock = op.getBefore().front();
|
||||
SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
|
||||
// For each i-th before block argument we find it's replacement value as :-
|
||||
// 1. If i-th before block argument is a loop invariant, we fetch it's
|
||||
// initial value from `beforeBlockInitValMap` by querying for key `i`.
|
||||
// 2. Else we fetch j-th new before block argument as the replacement
|
||||
// value of i-th before block argument.
|
||||
for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
|
||||
// If the index 'i' argument was a loop invariant we fetch it's initial
|
||||
// value from `beforeBlockInitValMap`.
|
||||
if (beforeBlockInitValMap.count(i) != 0)
|
||||
newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
|
||||
else
|
||||
newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
|
||||
}
|
||||
|
||||
rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
|
||||
rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
|
||||
newWhile.getAfter().begin());
|
||||
|
||||
rewriter.replaceOp(op, newWhile.getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Remove loop invariant value from result (condition op) of scf.while.
|
||||
/// A value is considered loop invariant if the final value yielded by
|
||||
/// scf.condition is defined outside of the `before` block. We remove the
|
||||
/// corresponding argument in `after` block and replace the use with the value.
|
||||
/// We also replace the use of the corresponding result of scf.while with the
|
||||
/// value.
|
||||
///
|
||||
/// Eg:
|
||||
/// INPUT :-
|
||||
/// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
|
||||
/// %argN_before = %N) {
|
||||
/// ...
|
||||
/// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
|
||||
/// } do {
|
||||
/// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
|
||||
/// ...
|
||||
/// some_func(%arg1_after)
|
||||
/// ...
|
||||
/// scf.yield %arg0_after, %arg2_after, ..., %argN_after
|
||||
/// }
|
||||
///
|
||||
/// OUTPUT :-
|
||||
/// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
|
||||
/// ...
|
||||
/// scf.condition(%cond) %arg0, %arg1, ..., %argM
|
||||
/// } do {
|
||||
/// ^bb0(%arg0, %arg3, ..., %argM):
|
||||
/// ...
|
||||
/// some_func(%a)
|
||||
/// ...
|
||||
/// scf.yield %arg0, %b, ..., %argN
|
||||
/// }
|
||||
///
|
||||
/// EXPLANATION:
|
||||
/// 1. The 1-th and 2-th operand of scf.condition are defined outside the
|
||||
/// before block of scf.while, so they get removed.
|
||||
/// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
|
||||
/// replaced by %b.
|
||||
/// 3. The corresponding after block argument %arg1_after's uses are
|
||||
/// replaced by %a and %arg2_after's uses are replaced by %b.
|
||||
struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
|
||||
using OpRewritePattern<WhileOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(WhileOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Block &beforeBlock = op.getBefore().front();
|
||||
ConditionOp condOp = op.getConditionOp();
|
||||
OperandRange condOpArgs = condOp.getArgs();
|
||||
|
||||
bool canSimplify = false;
|
||||
for (Value condOpArg : condOpArgs) {
|
||||
// Those values not defined within `before` block will be considered as
|
||||
// loop invariant values. We map the corresponding `index` with their
|
||||
// value.
|
||||
if (condOpArg.getParentBlock() != &beforeBlock) {
|
||||
canSimplify = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!canSimplify)
|
||||
return failure();
|
||||
|
||||
Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
|
||||
|
||||
SmallVector<Value> newCondOpArgs;
|
||||
SmallVector<Type> newAfterBlockType;
|
||||
DenseMap<unsigned, Value> condOpInitValMap;
|
||||
SmallVector<Location> newAfterBlockArgLocs;
|
||||
for (auto it : llvm::enumerate(condOpArgs)) {
|
||||
auto index = static_cast<unsigned>(it.index());
|
||||
Value condOpArg = it.value();
|
||||
// Those values not defined within `before` block will be considered as
|
||||
// loop invariant values. We map the corresponding `index` with their
|
||||
// value.
|
||||
if (condOpArg.getParentBlock() != &beforeBlock) {
|
||||
condOpInitValMap.insert({index, condOpArg});
|
||||
} else {
|
||||
newCondOpArgs.emplace_back(condOpArg);
|
||||
newAfterBlockType.emplace_back(condOpArg.getType());
|
||||
newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard g(rewriter);
|
||||
rewriter.setInsertionPoint(condOp);
|
||||
rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
|
||||
newCondOpArgs);
|
||||
}
|
||||
|
||||
auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
|
||||
op.getOperands());
|
||||
|
||||
Block &newAfterBlock =
|
||||
*rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
|
||||
newAfterBlockType, newAfterBlockArgLocs);
|
||||
|
||||
Block &afterBlock = op.getAfter().front();
|
||||
// Since a new scf.condition op was created, we need to fetch the new
|
||||
// `after` block arguments which will be used while replacing operations of
|
||||
// previous scf.while's `after` blocks. We'd also be fetching new result
|
||||
// values too.
|
||||
SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
|
||||
SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
|
||||
for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
|
||||
Value afterBlockArg, result;
|
||||
// If index 'i' argument was loop invariant we fetch it's value from the
|
||||
// `condOpInitMap` map.
|
||||
if (condOpInitValMap.count(i) != 0) {
|
||||
afterBlockArg = condOpInitValMap[i];
|
||||
result = afterBlockArg;
|
||||
} else {
|
||||
afterBlockArg = newAfterBlock.getArgument(j);
|
||||
result = newWhile.getResult(j);
|
||||
j++;
|
||||
}
|
||||
newAfterBlockArgs[i] = afterBlockArg;
|
||||
newWhileResults[i] = result;
|
||||
}
|
||||
|
||||
rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
|
||||
rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
|
||||
newWhile.getBefore().begin());
|
||||
|
||||
rewriter.replaceOp(op, newWhileResults);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Remove WhileOp results that are also unused in 'after' block.
|
||||
///
|
||||
/// %0:2 = scf.while () : () -> (i32, i64) {
|
||||
|
@ -2552,8 +2843,9 @@ struct WhileUnusedArg : public OpRewritePattern<WhileOp> {
|
|||
|
||||
void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<WhileConditionTruth, WhileUnusedResult, WhileCmpCond,
|
||||
WhileUnusedArg>(context);
|
||||
results.insert<RemoveLoopInvariantArgsFromBeforeBlock,
|
||||
RemoveLoopInvariantValueYielded, WhileConditionTruth,
|
||||
WhileCmpCond, WhileUnusedResult>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -870,6 +870,74 @@ func @while_unused_arg(%x : i32, %y : f64) -> i32 {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @invariant_loop_args_in_same_order
|
||||
// CHECK-SAME: (%[[FUNC_ARG0:.*]]: tensor<i32>)
|
||||
func @invariant_loop_args_in_same_order(%f_arg0: tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
|
||||
%cst_0 = arith.constant dense<0> : tensor<i32>
|
||||
%cst_1 = arith.constant dense<1> : tensor<i32>
|
||||
%cst_42 = arith.constant dense<42> : tensor<i32>
|
||||
|
||||
%0:5 = scf.while (%arg0 = %cst_0, %arg1 = %f_arg0, %arg2 = %cst_1, %arg3 = %cst_1, %arg4 = %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
|
||||
%1 = arith.cmpi slt, %arg0, %cst_42 : tensor<i32>
|
||||
%2 = tensor.extract %1[] : tensor<i1>
|
||||
scf.condition(%2) %arg0, %arg1, %arg2, %arg3, %arg4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
|
||||
} do {
|
||||
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors
|
||||
// %arg1 here will get replaced by %cst_1
|
||||
%1 = arith.addi %arg0, %arg1 : tensor<i32>
|
||||
%2 = arith.addi %arg2, %arg3 : tensor<i32>
|
||||
scf.yield %1, %arg1, %2, %2, %arg4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
|
||||
}
|
||||
return %0#0, %0#1, %0#2, %0#3, %0#4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
|
||||
}
|
||||
// CHECK: %[[CST42:.*]] = arith.constant dense<42>
|
||||
// CHECK: %[[ONE:.*]] = arith.constant dense<1>
|
||||
// CHECK: %[[ZERO:.*]] = arith.constant dense<0>
|
||||
// CHECK: %[[WHILE:.*]]:3 = scf.while (%[[ARG0:.*]] = %[[ZERO]], %[[ARG2:.*]] = %[[ONE]], %[[ARG3:.*]] = %[[ONE]])
|
||||
// CHECK: arith.cmpi slt, %[[ARG0]], %{{.*}}
|
||||
// CHECK: tensor.extract %{{.*}}[]
|
||||
// CHECK: scf.condition(%{{.*}}) %[[ARG0]], %[[ARG2]], %[[ARG3]]
|
||||
// CHECK: } do {
|
||||
// CHECK: ^{{.*}}(%[[ARG0:.*]]: tensor<i32>, %[[ARG2:.*]]: tensor<i32>, %[[ARG3:.*]]: tensor<i32>):
|
||||
// CHECK: %[[VAL0:.*]] = arith.addi %[[ARG0]], %[[FUNC_ARG0]]
|
||||
// CHECK: %[[VAL1:.*]] = arith.addi %[[ARG2]], %[[ARG3]]
|
||||
// CHECK: scf.yield %[[VAL0]], %[[VAL1]], %[[VAL1]]
|
||||
// CHECK: }
|
||||
// CHECK: return %[[WHILE]]#0, %[[FUNC_ARG0]], %[[WHILE]]#1, %[[WHILE]]#2, %[[ZERO]]
|
||||
|
||||
// CHECK-LABEL: @while_loop_invariant_argument_different_order
|
||||
func @while_loop_invariant_argument_different_order() -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
|
||||
%cst_0 = arith.constant dense<0> : tensor<i32>
|
||||
%cst_1 = arith.constant dense<1> : tensor<i32>
|
||||
%cst_42 = arith.constant dense<42> : tensor<i32>
|
||||
|
||||
%0:6 = scf.while (%arg0 = %cst_0, %arg1 = %cst_1, %arg2 = %cst_1, %arg3 = %cst_1, %arg4 = %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
|
||||
%1 = arith.cmpi slt, %arg0, %cst_42 : tensor<i32>
|
||||
%2 = tensor.extract %1[] : tensor<i1>
|
||||
scf.condition(%2) %arg1, %arg0, %arg2, %arg0, %arg3, %arg4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
|
||||
} do {
|
||||
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>, %arg5: tensor<i32>): // no predecessors
|
||||
%1 = arith.addi %arg0, %cst_1 : tensor<i32>
|
||||
%2 = arith.addi %arg2, %arg3 : tensor<i32>
|
||||
scf.yield %arg3, %arg1, %2, %2, %arg4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
|
||||
}
|
||||
return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
|
||||
}
|
||||
// CHECK: %[[CST42:.*]] = arith.constant dense<42>
|
||||
// CHECK: %[[ONE:.*]] = arith.constant dense<1>
|
||||
// CHECK: %[[ZERO:.*]] = arith.constant dense<0>
|
||||
// CHECK: %[[WHILE:.*]]:2 = scf.while (%[[ARG1:.*]] = %[[ONE]], %[[ARG4:.*]] = %[[ZERO]])
|
||||
// CHECK: arith.cmpi slt, %[[ZERO]], %[[CST42]]
|
||||
// CHECK: tensor.extract %{{.*}}[]
|
||||
// CHECK: scf.condition(%{{.*}}) %[[ARG1]], %[[ARG4]]
|
||||
// CHECK: } do {
|
||||
// CHECK: ^{{.*}}(%{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>):
|
||||
// CHECK: scf.yield %[[ZERO]], %[[ONE]]
|
||||
// CHECK: }
|
||||
// CHECK: return %[[WHILE]]#0, %[[ZERO]], %[[ONE]], %[[ZERO]], %[[ONE]], %[[WHILE]]#1
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @while_unused_result
|
||||
func @while_unused_result() -> i32 {
|
||||
%0:2 = scf.while () : () -> (i32, i64) {
|
||||
|
|
Loading…
Reference in New Issue