forked from OSchip/llvm-project
[mlir][scf] Canonicalize scf.while with unused results
Differential Revision: https://reviews.llvm.org/D114291
This commit is contained in:
parent
ba4411e7c6
commit
7f5d9bf13a
|
@ -2255,11 +2255,102 @@ struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
|
|||
return success(replaced);
|
||||
}
|
||||
};
|
||||
|
||||
/// Remove WhileOp results that are also unused in 'after' block.
|
||||
///
|
||||
/// %0:2 = scf.while () : () -> (i32, i64) {
|
||||
/// %condition = "test.condition"() : () -> i1
|
||||
/// %v1 = "test.get_some_value"() : () -> i32
|
||||
/// %v2 = "test.get_some_value"() : () -> i64
|
||||
/// scf.condition(%condition) %v1, %v2 : i32, i64
|
||||
/// } do {
|
||||
/// ^bb0(%arg0: i32, %arg1: i64):
|
||||
/// "test.use"(%arg0) : (i32) -> ()
|
||||
/// scf.yield
|
||||
/// }
|
||||
/// return %0#0 : i32
|
||||
///
|
||||
/// becomes
|
||||
/// %0 = scf.while () : () -> (i32) {
|
||||
/// %condition = "test.condition"() : () -> i1
|
||||
/// %v1 = "test.get_some_value"() : () -> i32
|
||||
/// %v2 = "test.get_some_value"() : () -> i64
|
||||
/// scf.condition(%condition) %v1 : i32
|
||||
/// } do {
|
||||
/// ^bb0(%arg0: i32):
|
||||
/// "test.use"(%arg0) : (i32) -> ()
|
||||
/// scf.yield
|
||||
/// }
|
||||
/// return %0 : i32
|
||||
struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
|
||||
using OpRewritePattern<WhileOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(WhileOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto term = op.getConditionOp();
|
||||
auto afterArgs = op.getAfterArguments();
|
||||
auto termArgs = term.args();
|
||||
|
||||
// Collect results mapping, new terminator args and new result types.
|
||||
SmallVector<unsigned> newResultsIndices;
|
||||
SmallVector<Type> newResultTypes;
|
||||
SmallVector<Value> newTermArgs;
|
||||
bool needUpdate = false;
|
||||
for (auto it :
|
||||
llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
|
||||
auto i = static_cast<unsigned>(it.index());
|
||||
Value result = std::get<0>(it.value());
|
||||
Value afterArg = std::get<1>(it.value());
|
||||
Value termArg = std::get<2>(it.value());
|
||||
if (result.use_empty() && afterArg.use_empty()) {
|
||||
needUpdate = true;
|
||||
} else {
|
||||
newResultsIndices.emplace_back(i);
|
||||
newTermArgs.emplace_back(termArg);
|
||||
newResultTypes.emplace_back(result.getType());
|
||||
}
|
||||
}
|
||||
|
||||
if (!needUpdate)
|
||||
return failure();
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard g(rewriter);
|
||||
rewriter.setInsertionPoint(term);
|
||||
rewriter.replaceOpWithNewOp<ConditionOp>(term, term.condition(),
|
||||
newTermArgs);
|
||||
}
|
||||
|
||||
auto newWhile =
|
||||
rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.inits());
|
||||
|
||||
Block &newAfterBlock = *rewriter.createBlock(
|
||||
&newWhile.after(), /*insertPt*/ {}, newResultTypes);
|
||||
|
||||
// Build new results list and new after block args (unused entries will be
|
||||
// null).
|
||||
SmallVector<Value> newResults(op.getNumResults());
|
||||
SmallVector<Value> newAfterBlockArgs(op.getNumResults());
|
||||
for (auto it : llvm::enumerate(newResultsIndices)) {
|
||||
newResults[it.value()] = newWhile.getResult(it.index());
|
||||
newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
|
||||
}
|
||||
|
||||
rewriter.inlineRegionBefore(op.before(), newWhile.before(),
|
||||
newWhile.before().begin());
|
||||
|
||||
Block &afterBlock = op.after().front();
|
||||
rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
|
||||
|
||||
rewriter.replaceOp(op, newResults);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<WhileConditionTruth>(context);
|
||||
results.insert<WhileConditionTruth, WhileUnusedResult>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -782,7 +782,7 @@ func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
|
|||
// -----
|
||||
|
||||
// CHECK-LABEL: @while_cond_true
|
||||
func @while_cond_true() {
|
||||
func @while_cond_true() -> i1 {
|
||||
%0 = scf.while () : () -> i1 {
|
||||
%condition = "test.condition"() : () -> i1
|
||||
scf.condition(%condition) %condition : i1
|
||||
|
@ -791,7 +791,7 @@ func @while_cond_true() {
|
|||
"test.use"(%arg0) : (i1) -> ()
|
||||
scf.yield
|
||||
}
|
||||
return
|
||||
return %0 : i1
|
||||
}
|
||||
// CHECK-NEXT: %[[true:.+]] = arith.constant true
|
||||
// CHECK-NEXT: %{{.+}} = scf.while : () -> i1 {
|
||||
|
@ -805,6 +805,34 @@ func @while_cond_true() {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @while_unused_result
|
||||
func @while_unused_result() -> i32 {
|
||||
%0:2 = scf.while () : () -> (i32, i64) {
|
||||
%condition = "test.condition"() : () -> i1
|
||||
%v1 = "test.get_some_value"() : () -> i32
|
||||
%v2 = "test.get_some_value"() : () -> i64
|
||||
scf.condition(%condition) %v1, %v2 : i32, i64
|
||||
} do {
|
||||
^bb0(%arg0: i32, %arg1: i64):
|
||||
"test.use"(%arg0) : (i32) -> ()
|
||||
scf.yield
|
||||
}
|
||||
return %0#0 : i32
|
||||
}
|
||||
// CHECK-NEXT: %[[res:.*]] = scf.while : () -> i32 {
|
||||
// CHECK-NEXT: %[[cmp:.*]] = "test.condition"() : () -> i1
|
||||
// CHECK-NEXT: %[[val:.*]] = "test.get_some_value"() : () -> i32
|
||||
// CHECK-NEXT: %{{.*}} = "test.get_some_value"() : () -> i64
|
||||
// CHECK-NEXT: scf.condition(%[[cmp]]) %[[val]] : i32
|
||||
// CHECK-NEXT: } do {
|
||||
// CHECK-NEXT: ^bb0(%[[arg:.*]]: i32): // no predecessors
|
||||
// CHECK-NEXT: "test.use"(%[[arg]]) : (i32) -> ()
|
||||
// CHECK-NEXT: scf.yield
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %[[res]] : i32
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @combineIfs
|
||||
func @combineIfs(%arg0 : i1, %arg2: i64) -> (i32, i32) {
|
||||
%res = scf.if %arg0 -> i32 {
|
||||
|
|
Loading…
Reference in New Issue