[mlir][scf] Canonicalize scf.while with unused results

Differential Revision: https://reviews.llvm.org/D114291
This commit is contained in:
Butygin 2021-11-20 01:56:23 +03:00
parent ba4411e7c6
commit 7f5d9bf13a
2 changed files with 122 additions and 3 deletions

View File

@ -2255,11 +2255,102 @@ struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
return success(replaced);
}
};
/// Remove WhileOp results that are also unused in 'after' block.
///
/// %0:2 = scf.while () : () -> (i32, i64) {
/// %condition = "test.condition"() : () -> i1
/// %v1 = "test.get_some_value"() : () -> i32
/// %v2 = "test.get_some_value"() : () -> i64
/// scf.condition(%condition) %v1, %v2 : i32, i64
/// } do {
/// ^bb0(%arg0: i32, %arg1: i64):
/// "test.use"(%arg0) : (i32) -> ()
/// scf.yield
/// }
/// return %0#0 : i32
///
/// becomes
/// %0 = scf.while () : () -> (i32) {
/// %condition = "test.condition"() : () -> i1
/// %v1 = "test.get_some_value"() : () -> i32
/// %v2 = "test.get_some_value"() : () -> i64
/// scf.condition(%condition) %v1 : i32
/// } do {
/// ^bb0(%arg0: i32):
/// "test.use"(%arg0) : (i32) -> ()
/// scf.yield
/// }
/// return %0 : i32
struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
using OpRewritePattern<WhileOp>::OpRewritePattern;
LogicalResult matchAndRewrite(WhileOp op,
PatternRewriter &rewriter) const override {
auto term = op.getConditionOp();
auto afterArgs = op.getAfterArguments();
auto termArgs = term.args();
// Collect results mapping, new terminator args and new result types.
SmallVector<unsigned> newResultsIndices;
SmallVector<Type> newResultTypes;
SmallVector<Value> newTermArgs;
bool needUpdate = false;
for (auto it :
llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
auto i = static_cast<unsigned>(it.index());
Value result = std::get<0>(it.value());
Value afterArg = std::get<1>(it.value());
Value termArg = std::get<2>(it.value());
if (result.use_empty() && afterArg.use_empty()) {
needUpdate = true;
} else {
newResultsIndices.emplace_back(i);
newTermArgs.emplace_back(termArg);
newResultTypes.emplace_back(result.getType());
}
}
if (!needUpdate)
return failure();
{
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(term);
rewriter.replaceOpWithNewOp<ConditionOp>(term, term.condition(),
newTermArgs);
}
auto newWhile =
rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.inits());
Block &newAfterBlock = *rewriter.createBlock(
&newWhile.after(), /*insertPt*/ {}, newResultTypes);
// Build new results list and new after block args (unused entries will be
// null).
SmallVector<Value> newResults(op.getNumResults());
SmallVector<Value> newAfterBlockArgs(op.getNumResults());
for (auto it : llvm::enumerate(newResultsIndices)) {
newResults[it.value()] = newWhile.getResult(it.index());
newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
}
rewriter.inlineRegionBefore(op.before(), newWhile.before(),
newWhile.before().begin());
Block &afterBlock = op.after().front();
rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
rewriter.replaceOp(op, newResults);
return success();
}
};
} // namespace
void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<WhileConditionTruth>(context);
results.insert<WhileConditionTruth, WhileUnusedResult>(context);
}
//===----------------------------------------------------------------------===//

View File

@ -782,7 +782,7 @@ func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
// -----
// CHECK-LABEL: @while_cond_true
func @while_cond_true() {
func @while_cond_true() -> i1 {
%0 = scf.while () : () -> i1 {
%condition = "test.condition"() : () -> i1
scf.condition(%condition) %condition : i1
@ -791,7 +791,7 @@ func @while_cond_true() {
"test.use"(%arg0) : (i1) -> ()
scf.yield
}
return
return %0 : i1
}
// CHECK-NEXT: %[[true:.+]] = arith.constant true
// CHECK-NEXT: %{{.+}} = scf.while : () -> i1 {
@ -805,6 +805,34 @@ func @while_cond_true() {
// -----
// CHECK-LABEL: @while_unused_result
func @while_unused_result() -> i32 {
%0:2 = scf.while () : () -> (i32, i64) {
%condition = "test.condition"() : () -> i1
%v1 = "test.get_some_value"() : () -> i32
%v2 = "test.get_some_value"() : () -> i64
scf.condition(%condition) %v1, %v2 : i32, i64
} do {
^bb0(%arg0: i32, %arg1: i64):
"test.use"(%arg0) : (i32) -> ()
scf.yield
}
return %0#0 : i32
}
// CHECK-NEXT: %[[res:.*]] = scf.while : () -> i32 {
// CHECK-NEXT: %[[cmp:.*]] = "test.condition"() : () -> i1
// CHECK-NEXT: %[[val:.*]] = "test.get_some_value"() : () -> i32
// CHECK-NEXT: %{{.*}} = "test.get_some_value"() : () -> i64
// CHECK-NEXT: scf.condition(%[[cmp]]) %[[val]] : i32
// CHECK-NEXT: } do {
// CHECK-NEXT: ^bb0(%[[arg:.*]]: i32): // no predecessors
// CHECK-NEXT: "test.use"(%[[arg]]) : (i32) -> ()
// CHECK-NEXT: scf.yield
// CHECK-NEXT: }
// CHECK-NEXT: return %[[res]] : i32
// -----
// CHECK-LABEL: @combineIfs
func @combineIfs(%arg0 : i1, %arg2: i64) -> (i32, i32) {
%res = scf.if %arg0 -> i32 {