[mlir][SCF] Remove empty else blocks of `scf.if` operations.

Differential Revision: https://reviews.llvm.org/D104273
This commit is contained in:
MaheshRavishankar 2021-06-15 14:15:40 -07:00
parent fad8d4230f
commit 621d93d263
2 changed files with 47 additions and 7 deletions

View File

@ -1371,18 +1371,44 @@ struct CombineIfs : public OpRewritePattern<IfOp> {
}
};
/// Pattern to remove an empty else branch.
struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
using OpRewritePattern<IfOp>::OpRewritePattern;
LogicalResult matchAndRewrite(IfOp ifOp,
PatternRewriter &rewriter) const override {
// Cannot remove else region when there are operation results.
if (ifOp.getNumResults())
return failure();
Block *elseBlock = ifOp.elseBlock();
if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
return failure();
auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
rewriter.inlineRegionBefore(ifOp.thenRegion(), newIfOp.thenRegion(),
newIfOp.thenRegion().begin());
rewriter.eraseOp(ifOp);
return success();
}
};
} // namespace
void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<RemoveUnusedResults, RemoveStaticCondition,
ConvertTrivialIfToSelect, ConditionPropagation,
ReplaceIfYieldWithConditionOrValue, CombineIfs>(context);
results
.add<RemoveUnusedResults, RemoveStaticCondition, ConvertTrivialIfToSelect,
ConditionPropagation, ReplaceIfYieldWithConditionOrValue, CombineIfs,
RemoveEmptyElseBranch>(context);
}
Block *IfOp::thenBlock() { return &thenRegion().back(); }
YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
Block *IfOp::elseBlock() { return &elseRegion().back(); }
Block *IfOp::elseBlock() {
Region &r = elseRegion();
if (r.empty())
return nullptr;
return &r.back();
}
YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
//===----------------------------------------------------------------------===//

View File

@ -250,6 +250,20 @@ func @empty_if2(%cond: i1) {
// CHECK-NOT: scf.if
// CHECK: return
// ----
func @empty_else(%cond: i1, %v : memref<i1>) {
scf.if %cond {
memref.store %cond, %v[] : memref<i1>
} else {
}
return
}
// CHECK-LABEL: func @empty_else
// CHECK: scf.if
// CHECK-NOT: else
// -----
func @to_select1(%cond: i1) -> index {
@ -475,9 +489,9 @@ func @replace_single_iteration_loop_1() {
// CHECK-LABEL: @replace_single_iteration_loop_2
func @replace_single_iteration_loop_2() {
// CHECK: %[[LB:.*]] = constant 5
%c5 = constant 5 : index
%c6 = constant 6 : index
%c11 = constant 11 : index
%c5 = constant 5 : index
%c6 = constant 6 : index
%c11 = constant 11 : index
// CHECK: %[[INIT:.*]] = "test.init"
%init = "test.init"() : () -> i32
// CHECK-NOT: scf.for