forked from OSchip/llvm-project
[MLIR][SCF] Remove unused arguments to whileop
Canonicalize away unused arguments to the before region of a whileOp Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D117059
This commit is contained in:
parent
f0b2a1a629
commit
d23fa4f2f1
|
@ -686,6 +686,8 @@ def WhileOp : SCF_Op<"while",
|
|||
let extraClassDeclaration = [{
|
||||
OperandRange getSuccessorEntryOperands(unsigned index);
|
||||
ConditionOp getConditionOp();
|
||||
YieldOp getYieldOp();
|
||||
Block::BlockArgListType getBeforeArguments();
|
||||
Block::BlockArgListType getAfterArguments();
|
||||
}];
|
||||
|
||||
|
|
|
@ -2171,6 +2171,14 @@ ConditionOp WhileOp::getConditionOp() {
|
|||
return cast<ConditionOp>(getBefore().front().getTerminator());
|
||||
}
|
||||
|
||||
YieldOp WhileOp::getYieldOp() {
|
||||
return cast<YieldOp>(getAfter().front().getTerminator());
|
||||
}
|
||||
|
||||
Block::BlockArgListType WhileOp::getBeforeArguments() {
|
||||
return getBefore().front().getArguments();
|
||||
}
|
||||
|
||||
Block::BlockArgListType WhileOp::getAfterArguments() {
|
||||
return getAfter().front().getArguments();
|
||||
}
|
||||
|
@ -2508,11 +2516,60 @@ struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
|
|||
return success(changed);
|
||||
}
|
||||
};
|
||||
|
||||
struct WhileUnusedArg : public OpRewritePattern<WhileOp> {
|
||||
using OpRewritePattern<WhileOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(WhileOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
if (!llvm::any_of(op.getBeforeArguments(),
|
||||
[](Value arg) { return arg.use_empty(); }))
|
||||
return failure();
|
||||
|
||||
YieldOp yield = op.getYieldOp();
|
||||
|
||||
// Collect results mapping, new terminator args and new result types.
|
||||
SmallVector<Value> newYields;
|
||||
SmallVector<Value> newInits;
|
||||
SmallVector<unsigned> argsToErase;
|
||||
for (const auto &it : llvm::enumerate(llvm::zip(
|
||||
op.getBeforeArguments(), yield.getOperands(), op.getInits()))) {
|
||||
Value beforeArg = std::get<0>(it.value());
|
||||
Value yieldValue = std::get<1>(it.value());
|
||||
Value initValue = std::get<2>(it.value());
|
||||
if (beforeArg.use_empty()) {
|
||||
argsToErase.push_back(it.index());
|
||||
} else {
|
||||
newYields.emplace_back(yieldValue);
|
||||
newInits.emplace_back(initValue);
|
||||
}
|
||||
}
|
||||
|
||||
if (argsToErase.size() == 0)
|
||||
return failure();
|
||||
|
||||
rewriter.startRootUpdate(op);
|
||||
op.getBefore().front().eraseArguments(argsToErase);
|
||||
rewriter.finalizeRootUpdate(op);
|
||||
|
||||
WhileOp replacement =
|
||||
rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInits);
|
||||
replacement.getBefore().takeBody(op.getBefore());
|
||||
replacement.getAfter().takeBody(op.getAfter());
|
||||
rewriter.replaceOp(op, replacement.getResults());
|
||||
|
||||
rewriter.setInsertionPoint(yield);
|
||||
rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<WhileConditionTruth, WhileUnusedResult, WhileCmpCond>(context);
|
||||
results.insert<WhileConditionTruth, WhileUnusedResult, WhileCmpCond,
|
||||
WhileUnusedArg>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -846,6 +846,30 @@ func @while_cond_true() -> i1 {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @while_unused_arg
|
||||
func @while_unused_arg(%x : i32, %y : f64) -> i32 {
|
||||
%0 = scf.while (%arg1 = %x, %arg2 = %y) : (i32, f64) -> (i32) {
|
||||
%condition = "test.condition"(%arg1) : (i32) -> i1
|
||||
scf.condition(%condition) %arg1 : i32
|
||||
} do {
|
||||
^bb0(%arg1: i32):
|
||||
%next = "test.use"(%arg1) : (i32) -> (i32)
|
||||
scf.yield %next, %y : i32, f64
|
||||
}
|
||||
return %0 : i32
|
||||
}
|
||||
// CHECK-NEXT: %[[res:.*]] = scf.while (%[[arg2:.+]] = %{{.*}}) : (i32) -> i32 {
|
||||
// CHECK-NEXT: %[[cmp:.*]] = "test.condition"(%[[arg2]]) : (i32) -> i1
|
||||
// CHECK-NEXT: scf.condition(%[[cmp]]) %[[arg2]] : i32
|
||||
// CHECK-NEXT: } do {
|
||||
// CHECK-NEXT: ^bb0(%[[post:.+]]: i32): // no predecessors
|
||||
// CHECK-NEXT: %[[next:.+]] = "test.use"(%[[post]]) : (i32) -> i32
|
||||
// CHECK-NEXT: scf.yield %[[next]] : i32
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %[[res]] : i32
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @while_unused_result
|
||||
func @while_unused_result() -> i32 {
|
||||
%0:2 = scf.while () : () -> (i32, i64) {
|
||||
|
|
Loading…
Reference in New Issue