[MLIR][SCF] Remove unused arguments to whileop

Canonicalize away unused arguments to the before region of a whileOp

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D117059
This commit is contained in:
William S. Moses 2022-01-11 18:45:17 -05:00
parent f0b2a1a629
commit d23fa4f2f1
3 changed files with 84 additions and 1 deletions

View File

@ -686,6 +686,8 @@ def WhileOp : SCF_Op<"while",
let extraClassDeclaration = [{
OperandRange getSuccessorEntryOperands(unsigned index);
ConditionOp getConditionOp();
YieldOp getYieldOp();
Block::BlockArgListType getBeforeArguments();
Block::BlockArgListType getAfterArguments();
}];

View File

@ -2171,6 +2171,14 @@ ConditionOp WhileOp::getConditionOp() {
return cast<ConditionOp>(getBefore().front().getTerminator());
}
YieldOp WhileOp::getYieldOp() {
return cast<YieldOp>(getAfter().front().getTerminator());
}
Block::BlockArgListType WhileOp::getBeforeArguments() {
return getBefore().front().getArguments();
}
Block::BlockArgListType WhileOp::getAfterArguments() {
return getAfter().front().getArguments();
}
@ -2508,11 +2516,60 @@ struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
return success(changed);
}
};
struct WhileUnusedArg : public OpRewritePattern<WhileOp> {
using OpRewritePattern<WhileOp>::OpRewritePattern;
LogicalResult matchAndRewrite(WhileOp op,
PatternRewriter &rewriter) const override {
if (!llvm::any_of(op.getBeforeArguments(),
[](Value arg) { return arg.use_empty(); }))
return failure();
YieldOp yield = op.getYieldOp();
// Collect results mapping, new terminator args and new result types.
SmallVector<Value> newYields;
SmallVector<Value> newInits;
SmallVector<unsigned> argsToErase;
for (const auto &it : llvm::enumerate(llvm::zip(
op.getBeforeArguments(), yield.getOperands(), op.getInits()))) {
Value beforeArg = std::get<0>(it.value());
Value yieldValue = std::get<1>(it.value());
Value initValue = std::get<2>(it.value());
if (beforeArg.use_empty()) {
argsToErase.push_back(it.index());
} else {
newYields.emplace_back(yieldValue);
newInits.emplace_back(initValue);
}
}
if (argsToErase.size() == 0)
return failure();
rewriter.startRootUpdate(op);
op.getBefore().front().eraseArguments(argsToErase);
rewriter.finalizeRootUpdate(op);
WhileOp replacement =
rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInits);
replacement.getBefore().takeBody(op.getBefore());
replacement.getAfter().takeBody(op.getAfter());
rewriter.replaceOp(op, replacement.getResults());
rewriter.setInsertionPoint(yield);
rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
return success();
}
};
} // namespace
void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<WhileConditionTruth, WhileUnusedResult, WhileCmpCond>(context);
results.insert<WhileConditionTruth, WhileUnusedResult, WhileCmpCond,
WhileUnusedArg>(context);
}
//===----------------------------------------------------------------------===//

View File

@ -846,6 +846,30 @@ func @while_cond_true() -> i1 {
// -----
// CHECK-LABEL: @while_unused_arg
func @while_unused_arg(%x : i32, %y : f64) -> i32 {
%0 = scf.while (%arg1 = %x, %arg2 = %y) : (i32, f64) -> (i32) {
%condition = "test.condition"(%arg1) : (i32) -> i1
scf.condition(%condition) %arg1 : i32
} do {
^bb0(%arg1: i32):
%next = "test.use"(%arg1) : (i32) -> (i32)
scf.yield %next, %y : i32, f64
}
return %0 : i32
}
// CHECK-NEXT: %[[res:.*]] = scf.while (%[[arg2:.+]] = %{{.*}}) : (i32) -> i32 {
// CHECK-NEXT: %[[cmp:.*]] = "test.condition"(%[[arg2]]) : (i32) -> i1
// CHECK-NEXT: scf.condition(%[[cmp]]) %[[arg2]] : i32
// CHECK-NEXT: } do {
// CHECK-NEXT: ^bb0(%[[post:.+]]: i32): // no predecessors
// CHECK-NEXT: %[[next:.+]] = "test.use"(%[[post]]) : (i32) -> i32
// CHECK-NEXT: scf.yield %[[next]] : i32
// CHECK-NEXT: }
// CHECK-NEXT: return %[[res]] : i32
// -----
// CHECK-LABEL: @while_unused_result
func @while_unused_result() -> i32 {
%0:2 = scf.while () : () -> (i32, i64) {