[mlir][scf] Canonicalize nested scf.if's to scf.if + arith.and

Differential Revision: https://reviews.llvm.org/D115930
This commit is contained in:
Butygin 2021-10-28 19:04:35 +03:00
parent de90490060
commit c7f96d5ab1
2 changed files with 68 additions and 4 deletions

View File

@ -1596,14 +1596,60 @@ struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
}
};
/// Convert nested `if`s into `arith.andi` + single `if`.
///
/// scf.if %arg0 {
/// scf.if %arg1 {
/// ...
/// scf.yield
/// }
/// scf.yield
/// }
/// becomes
///
/// %0 = arith.andi %arg0, %arg1
/// scf.if %0 {
/// ...
/// scf.yield
/// }
struct CombineNestedIfs : public OpRewritePattern<IfOp> {
using OpRewritePattern<IfOp>::OpRewritePattern;
LogicalResult matchAndRewrite(IfOp op,
PatternRewriter &rewriter) const override {
// Both `if` ops must not yield results and have only `then` block.
if (op->getNumResults() != 0 || op.elseBlock())
return failure();
auto nestedOps = op.thenBlock()->without_terminator();
// Nested `if` must be the only op in block.
if (!llvm::hasSingleElement(nestedOps))
return failure();
auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
if (!nestedIf || nestedIf->getNumResults() != 0 || nestedIf.elseBlock())
return failure();
Location loc = op.getLoc();
Value newCondition = rewriter.create<arith::AndIOp>(loc, op.condition(),
nestedIf.condition());
auto newIf = rewriter.create<IfOp>(loc, newCondition);
Block *newIfBlock = newIf.thenBlock();
rewriter.eraseOp(newIfBlock->getTerminator());
rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
rewriter.eraseOp(op);
return success();
}
};
} // namespace
void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results
.add<RemoveUnusedResults, RemoveStaticCondition, ConvertTrivialIfToSelect,
ConditionPropagation, ReplaceIfYieldWithConditionOrValue, CombineIfs,
RemoveEmptyElseBranch>(context);
results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
RemoveStaticCondition, RemoveUnusedResults,
ReplaceIfYieldWithConditionOrValue>(context);
}
Block *IfOp::thenBlock() { return &getThenRegion().back(); }

View File

@ -429,6 +429,24 @@ func @replace_false_if_with_values() {
// -----
// CHECK-LABEL: @merge_nested_if
// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
func @merge_nested_if(%arg0: i1, %arg1: i1) {
// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
// CHECK: scf.if %[[COND]] {
// CHECK-NEXT: "test.op"()
scf.if %arg0 {
scf.if %arg1 {
"test.op"() : () -> ()
scf.yield
}
scf.yield
}
return
}
// -----
// CHECK-LABEL: @remove_zero_iteration_loop
func @remove_zero_iteration_loop() {
%c42 = arith.constant 42 : index