[MLIR] Simplify affine.if ops with trivial conditions

The commit simplifies affine.if ops :
The affine if operation gets removed if the condition is universally true or false and then/else block is merged with the parent block.

Signed-off-by: Shashij Gupta shashij.gupta@polymagelabs.com

Reviewed By: bondhugula, pr4tgpt

Differential Revision: https://reviews.llvm.org/D104015
This commit is contained in:
Shashij gupta 2021-06-12 19:28:40 +05:30 committed by Uday Bondhugula
parent b4583a5ad7
commit 466e5aba64
3 changed files with 137 additions and 25 deletions

View File

@ -1896,6 +1896,47 @@ struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
return success();
}
};
/// Removes Affine.If cond if the condition is always true or false in certain
/// trivial cases. Promotes the then/else block in the parent operation block.
struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
using OpRewritePattern<AffineIfOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AffineIfOp op,
PatternRewriter &rewriter) const override {
// If affine.if is returning results then don't remove it.
// TODO: Similar simplication can be done when affine.if return results.
if (op.getNumResults() > 0)
return failure();
IntegerSet conditionSet = op.getIntegerSet();
Block *blockToMove;
if (conditionSet.isEmptyIntegerSet()) {
// If the else region is not there, simply remove the Affine.if
// operation.
if (!op.hasElse()) {
rewriter.eraseOp(op);
return success();
}
blockToMove = op.getElseBlock();
} else if (conditionSet.getNumEqualities() == 1 &&
conditionSet.getNumInequalities() == 0 &&
conditionSet.getConstraint(0) == 0) {
// Condition to check for trivially true condition (0==0).
blockToMove = op.getThenBlock();
} else {
return failure();
}
// Remove the terminator from the block as it already exists in parent
// block.
Operation *blockTerminator = blockToMove->getTerminator();
rewriter.eraseOp(blockTerminator);
rewriter.mergeBlockBefore(blockToMove, op);
rewriter.eraseOp(op);
return success();
}
};
} // end anonymous namespace.
static LogicalResult verify(AffineIfOp op) {
@ -2059,7 +2100,7 @@ LogicalResult AffineIfOp::fold(ArrayRef<Attribute>,
void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<SimplifyDeadElse>(context);
results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
}
//===----------------------------------------------------------------------===//

View File

@ -245,9 +245,7 @@ func @multiple_if(%N : index) {
}
return
}
// CHECK: affine.if
// CHECK-NEXT: call
// CHECK-NEXT: }
// CHECK: call
// CHECK-NEXT: affine.if
// CHECK-NEXT: affine.for
// CHECK-NEXT: call

View File

@ -1,6 +1,5 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -simplify-affine-structures | FileCheck %s
// CHECK-DAG: #[[$SET_EMPTY:.*]] = affine_set<() : (1 == 0)>
// CHECK-DAG: #[[$SET_2D:.*]] = affine_set<(d0, d1) : (d0 - 100 == 0, d1 - 10 == 0, -d0 + 100 >= 0, d1 >= 0)>
// CHECK-DAG: #[[$SET_7_11:.*]] = affine_set<(d0, d1) : (d0 * 7 + d1 * 5 + 88 == 0, d0 * 5 - d1 * 11 + 60 == 0, d0 * 11 + d1 * 7 - 24 == 0, d0 * 7 + d1 * 5 + 88 == 0)>
@ -11,7 +10,7 @@ func private @external() -> ()
func @test_gaussian_elimination_empty_set0() {
affine.for %arg0 = 1 to 10 {
affine.for %arg1 = 1 to 100 {
// CHECK: affine.if #[[$SET_EMPTY]]()
// CHECK-NOT: affine.if
affine.if affine_set<(d0, d1) : (2 == 0)>(%arg0, %arg1) {
call @external() : () -> ()
}
@ -24,7 +23,7 @@ func @test_gaussian_elimination_empty_set0() {
func @test_gaussian_elimination_empty_set1() {
affine.for %arg0 = 1 to 10 {
affine.for %arg1 = 1 to 100 {
// CHECK: affine.if #[[$SET_EMPTY]]()
// CHECK-NOT: affine.if
affine.if affine_set<(d0, d1) : (1 >= 0, -1 >= 0)> (%arg0, %arg1) {
call @external() : () -> ()
}
@ -52,7 +51,7 @@ func @test_gaussian_elimination_empty_set3() {
%c11 = constant 11 : index
affine.for %arg0 = 1 to 10 {
affine.for %arg1 = 1 to 100 {
// CHECK: #[[$SET_EMPTY]]()
// CHECK-NOT: affine.if
affine.if affine_set<(d0, d1)[s0, s1] : (d0 - s0 == 0, d0 + s0 == 0, s0 - 1 == 0)>(%arg0, %arg1)[%c7, %c11] {
call @external() : () -> ()
}
@ -95,7 +94,7 @@ func @test_gaussian_elimination_empty_set5() {
%c11 = constant 11 : index
affine.for %arg0 = 1 to 10 {
affine.for %arg1 = 1 to 100 {
// CHECK: #[[$SET_EMPTY]]()
// CHECK-NOT: affine.if
affine.if #set_2d_empty(%arg0, %arg1)[%c7, %c11] {
call @external() : () -> ()
}
@ -162,33 +161,33 @@ func @test_fuzz_explosion(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : i
func @test_empty_set(%N : index) {
affine.for %i = 0 to 10 {
affine.for %j = 0 to 10 {
// CHECK: affine.if #[[$SET_EMPTY]]()
// CHECK-NOT: affine.if
affine.if affine_set<(d0, d1) : (d0 - d1 >= 0, d1 - d0 - 1 >= 0)>(%i, %j) {
"foo"() : () -> ()
}
// CHECK: affine.if #[[$SET_EMPTY]]()
// CHECK-NOT: affine.if
affine.if affine_set<(d0) : (d0 >= 0, -d0 - 1 >= 0)>(%i) {
"bar"() : () -> ()
}
// CHECK: affine.if #[[$SET_EMPTY]]()
// CHECK-NOT: affine.if
affine.if affine_set<(d0) : (d0 >= 0, -d0 - 1 >= 0)>(%i) {
"foo"() : () -> ()
}
// CHECK: affine.if #[[$SET_EMPTY]]()
// CHECK-NOT: affine.if
affine.if affine_set<(d0)[s0, s1] : (d0 >= 0, -d0 + s0 - 1 >= 0, -s0 >= 0)>(%i)[%N, %N] {
"bar"() : () -> ()
}
// CHECK: affine.if #[[$SET_EMPTY]]()
// CHECK-NOT: affine.if
// The set below implies d0 = d1; so d1 >= d0, but d0 >= d1 + 1.
affine.if affine_set<(d0, d1, d2) : (d0 - d1 == 0, d2 - d0 >= 0, d0 - d1 - 1 >= 0)>(%i, %j, %N) {
"foo"() : () -> ()
}
// CHECK: affine.if #[[$SET_EMPTY]]()
// CHECK-NOT: affine.if
// The set below has rational solutions but no integer solutions; GCD test catches it.
affine.if affine_set<(d0, d1) : (d0*2 -d1*2 - 1 == 0, d0 >= 0, -d0 + 100 >= 0, d1 >= 0, -d1 + 100 >= 0)>(%i, %j) {
"foo"() : () -> ()
}
// CHECK: affine.if #[[$SET_EMPTY]]()
// CHECK-NOT: affine.if
affine.if affine_set<(d0, d1) : (d1 == 0, d0 - 1 >= 0, - d0 - 1 >= 0)>(%i, %j) {
"foo"() : () -> ()
}
@ -198,12 +197,12 @@ func @test_empty_set(%N : index) {
affine.for %k = 0 to 10 {
affine.for %l = 0 to 10 {
// Empty because no multiple of 8 lies between 4 and 7.
// CHECK: affine.if #[[$SET_EMPTY]]()
// CHECK-NOT: affine.if
affine.if affine_set<(d0) : (8*d0 - 4 >= 0, -8*d0 + 7 >= 0)>(%k) {
"foo"() : () -> ()
}
// Same as above but with equalities and inequalities.
// CHECK: affine.if #[[$SET_EMPTY]]()
// CHECK-NOT: affine.if
affine.if affine_set<(d0, d1) : (d0 - 4*d1 == 0, 4*d1 - 5 >= 0, -4*d1 + 7 >= 0)>(%k, %l) {
"foo"() : () -> ()
}
@ -211,12 +210,12 @@ func @test_empty_set(%N : index) {
// 8*d1 here is a multiple of 4, and so can't lie between 9 and 11. GCD
// tightening will tighten constraints to 4*d0 + 8*d1 >= 12 and 4*d0 +
// 8*d1 <= 8; hence infeasible.
// CHECK: affine.if #[[$SET_EMPTY]]()
// CHECK-NOT: affine.if
affine.if affine_set<(d0, d1) : (4*d0 + 8*d1 - 9 >= 0, -4*d0 - 8*d1 + 11 >= 0)>(%k, %l) {
"foo"() : () -> ()
}
// Same as above but with equalities added into the mix.
// CHECK: affine.if #[[$SET_EMPTY]]()
// CHECK-NOT: affine.if
affine.if affine_set<(d0, d1, d2) : (d0 - 4*d2 == 0, d0 + 8*d1 - 9 >= 0, -d0 - 8*d1 + 11 >= 0)>(%k, %k, %l) {
"foo"() : () -> ()
}
@ -224,7 +223,7 @@ func @test_empty_set(%N : index) {
}
affine.for %m = 0 to 10 {
// CHECK: affine.if #[[$SET_EMPTY]]()
// CHECK-NOT: affine.if
affine.if affine_set<(d0) : (d0 mod 2 - 3 == 0)> (%m) {
"foo"() : () -> ()
}
@ -239,8 +238,6 @@ func @test_empty_set(%N : index) {
func private @external() -> ()
// CHECK-DAG: #[[$SET:.*]] = affine_set<()[s0] : (s0 >= 0, -s0 + 50 >= 0)
// CHECK-DAG: #[[$EMPTY_SET:.*]] = affine_set<() : (1 == 0)
// CHECK-DAG: #[[$UNIV_SET:.*]] = affine_set<() : (0 == 0)
// CHECK-LABEL: func @simplify_set
func @simplify_set(%a : index, %b : index) {
@ -248,11 +245,11 @@ func @simplify_set(%a : index, %b : index) {
affine.if affine_set<(d0, d1) : (d0 - d1 + d1 + d0 >= 0, 2 >= 0, d0 >= 0, -d0 + 50 >= 0, -d0 + 100 >= 0)>(%a, %b) {
call @external() : () -> ()
}
// CHECK: affine.if #[[$EMPTY_SET]]
// CHECK-NOT: affine.if
affine.if affine_set<(d0, d1) : (d0 mod 2 - 1 == 0, d0 - 2 * (d0 floordiv 2) == 0)>(%a, %b) {
call @external() : () -> ()
}
// CHECK: affine.if #[[$UNIV_SET]]
// CHECK-NOT: affine.if
affine.if affine_set<(d0, d1) : (1 >= 0, 3 >= 0)>(%a, %b) {
call @external() : () -> ()
}
@ -325,3 +322,79 @@ func @semiaffine_unsimplified_symbol(%arg0: index, %arg1: index) -> index {
// CHECK: %[[CST:.*]] = constant 0
return %a : index
}
// -----
// Two external functions that we will use in bodies to avoid DCE.
func private @external() -> ()
func private @external1() -> ()
// CHECK-LABEL: func @test_always_true_if_elimination() {
func @test_always_true_if_elimination() {
affine.for %arg0 = 1 to 10 {
affine.for %arg1 = 1 to 100 {
affine.if affine_set<(d0, d1) : (1 >= 0)> (%arg0, %arg1) {
call @external() : () -> ()
} else {
call @external1() : () -> ()
}
}
}
return
}
// CHECK: affine.for
// CHECK-NEXT: affine.for
// CHECK-NEXT: call @external()
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-LABEL: func @test_always_false_if_elimination() {
func @test_always_false_if_elimination() {
// CHECK: affine.for
affine.for %arg0 = 1 to 10 {
// CHECK: affine.for
affine.for %arg1 = 1 to 100 {
// CHECK: call @external1()
// CHECK-NOT: affine.if
affine.if affine_set<(d0, d1) : (-1 >= 0)> (%arg0, %arg1) {
call @external() : () -> ()
} else {
call @external1() : () -> ()
}
}
}
return
}
// Testing: Affine.If is not trivially true or false, nothing happens.
// CHECK-LABEL: func @test_dimensional_if_elimination() {
func @test_dimensional_if_elimination() {
affine.for %arg0 = 1 to 10 {
affine.for %arg1 = 1 to 100 {
// CHECK: affine.if
// CHECK: } else {
affine.if affine_set<(d0, d1) : (d0-1 == 0)> (%arg0, %arg1) {
call @external() : () -> ()
} else {
call @external() : () -> ()
}
}
}
return
}
// Testing: Affine.If don't get removed if it is returning results.
// CHECK-LABEL: func @test_num_results_if_elimination
func @test_num_results_if_elimination() -> f32 {
%zero = constant 0.0 : f32
// CHECK: affine.if
%0 = affine.if affine_set<() : ()> () -> f32 {
affine.yield %zero : f32
// CHECK: else {
} else {
affine.yield %zero : f32
}
return %0 : f32
}