forked from OSchip/llvm-project
[mlir][scf] Canonicalize nested scf.if's to scf.if + arith.and
Differential Revision: https://reviews.llvm.org/D115930
This commit is contained in:
parent
de90490060
commit
c7f96d5ab1
|
@ -1596,14 +1596,60 @@ struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
|
|||
}
|
||||
};
|
||||
|
||||
/// Convert nested `if`s into `arith.andi` + single `if`.
|
||||
///
|
||||
/// scf.if %arg0 {
|
||||
/// scf.if %arg1 {
|
||||
/// ...
|
||||
/// scf.yield
|
||||
/// }
|
||||
/// scf.yield
|
||||
/// }
|
||||
/// becomes
|
||||
///
|
||||
/// %0 = arith.andi %arg0, %arg1
|
||||
/// scf.if %0 {
|
||||
/// ...
|
||||
/// scf.yield
|
||||
/// }
|
||||
struct CombineNestedIfs : public OpRewritePattern<IfOp> {
|
||||
using OpRewritePattern<IfOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(IfOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Both `if` ops must not yield results and have only `then` block.
|
||||
if (op->getNumResults() != 0 || op.elseBlock())
|
||||
return failure();
|
||||
|
||||
auto nestedOps = op.thenBlock()->without_terminator();
|
||||
// Nested `if` must be the only op in block.
|
||||
if (!llvm::hasSingleElement(nestedOps))
|
||||
return failure();
|
||||
|
||||
auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
|
||||
if (!nestedIf || nestedIf->getNumResults() != 0 || nestedIf.elseBlock())
|
||||
return failure();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
Value newCondition = rewriter.create<arith::AndIOp>(loc, op.condition(),
|
||||
nestedIf.condition());
|
||||
auto newIf = rewriter.create<IfOp>(loc, newCondition);
|
||||
Block *newIfBlock = newIf.thenBlock();
|
||||
rewriter.eraseOp(newIfBlock->getTerminator());
|
||||
rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results
|
||||
.add<RemoveUnusedResults, RemoveStaticCondition, ConvertTrivialIfToSelect,
|
||||
ConditionPropagation, ReplaceIfYieldWithConditionOrValue, CombineIfs,
|
||||
RemoveEmptyElseBranch>(context);
|
||||
results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
|
||||
ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
|
||||
RemoveStaticCondition, RemoveUnusedResults,
|
||||
ReplaceIfYieldWithConditionOrValue>(context);
|
||||
}
|
||||
|
||||
Block *IfOp::thenBlock() { return &getThenRegion().back(); }
|
||||
|
|
|
@ -429,6 +429,24 @@ func @replace_false_if_with_values() {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @merge_nested_if
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
|
||||
func @merge_nested_if(%arg0: i1, %arg1: i1) {
|
||||
// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
|
||||
// CHECK: scf.if %[[COND]] {
|
||||
// CHECK-NEXT: "test.op"()
|
||||
scf.if %arg0 {
|
||||
scf.if %arg1 {
|
||||
"test.op"() : () -> ()
|
||||
scf.yield
|
||||
}
|
||||
scf.yield
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @remove_zero_iteration_loop
|
||||
func @remove_zero_iteration_loop() {
|
||||
%c42 = arith.constant 42 : index
|
||||
|
|
Loading…
Reference in New Issue