[MLIR][SCF] Canonicalize while statement whose cmp condition is recomputed in the after region

Given a while loop whose condition is given by a cmp, don't recomputed the comparison (or its inverse) in the after region, instead use a constant since  the original condition must be true if we branched to the after region.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D117047
This commit is contained in:
William S. Moses 2022-01-11 15:39:18 -05:00
parent ff11cd9550
commit 97567bde5b
4 changed files with 123 additions and 2 deletions

View File

@ -121,6 +121,8 @@ Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder,
/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
Value lhs, Value rhs);
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred);
} // namespace arith
} // namespace mlir

View File

@ -40,7 +40,7 @@ static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
}
/// Invert an integer comparison predicate.
static arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred) {
arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
switch (pred) {
case arith::CmpIPredicate::eq:
return arith::CmpIPredicate::ne;

View File

@ -2443,11 +2443,76 @@ struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
return success();
}
};
/// Replace operations equivalent to the condition in the do block with true,
/// since otherwise the block would not be evaluated.
///
/// scf.while (..) : (i32, ...) -> ... {
/// %z = ... : i32
/// %condition = cmpi pred %z, %a
/// scf.condition(%condition) %z : i32, ...
/// } do {
/// ^bb0(%arg0: i32, ...):
/// %condition2 = cmpi pred %arg0, %a
/// use(%condition2)
/// ...
///
/// becomes
/// scf.while (..) : (i32, ...) -> ... {
/// %z = ... : i32
/// %condition = cmpi pred %z, %a
/// scf.condition(%condition) %z : i32, ...
/// } do {
/// ^bb0(%arg0: i32, ...):
/// use(%true)
/// ...
struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
LogicalResult matchAndRewrite(scf::WhileOp op,
PatternRewriter &rewriter) const override {
using namespace scf;
auto cond = op.getConditionOp();
auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
if (!cmp)
return failure();
bool changed = false;
for (auto tup :
llvm::zip(cond.getArgs(), op.getAfter().front().getArguments())) {
for (size_t opIdx = 0; opIdx < 2; opIdx++) {
if (std::get<0>(tup) != cmp.getOperand(opIdx))
continue;
for (OpOperand &u :
llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
if (!cmp2)
continue;
// For a binary operator 1-opIdx gets the other side.
if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
continue;
bool samePredicate;
if (cmp2.getPredicate() == cmp.getPredicate())
samePredicate = true;
else if (cmp2.getPredicate() ==
arith::invertPredicate(cmp.getPredicate()))
samePredicate = false;
else
continue;
rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
1);
changed = true;
}
}
}
return success(changed);
}
};
} // namespace
void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<WhileConditionTruth, WhileUnusedResult>(context);
results.insert<WhileConditionTruth, WhileUnusedResult, WhileCmpCond>(context);
}
//===----------------------------------------------------------------------===//

View File

@ -872,6 +872,60 @@ func @while_unused_result() -> i32 {
// CHECK-NEXT: }
// CHECK-NEXT: return %[[res]] : i32
// CHECK-LABEL: @while_cmp_lhs
func @while_cmp_lhs(%arg0 : i32) {
%0 = scf.while () : () -> i32 {
%val = "test.val"() : () -> i32
%condition = arith.cmpi ne, %val, %arg0 : i32
scf.condition(%condition) %val : i32
} do {
^bb0(%val2: i32):
%condition2 = arith.cmpi ne, %val2, %arg0 : i32
%negcondition2 = arith.cmpi eq, %val2, %arg0 : i32
"test.use"(%condition2, %negcondition2, %val2) : (i1, i1, i32) -> ()
scf.yield
}
return
}
// CHECK-DAG: %[[true:.+]] = arith.constant true
// CHECK-DAG: %[[false:.+]] = arith.constant false
// CHECK-DAG: %{{.+}} = scf.while : () -> i32 {
// CHECK-NEXT: %[[val:.+]] = "test.val"
// CHECK-NEXT: %[[cmp:.+]] = arith.cmpi ne, %[[val]], %arg0 : i32
// CHECK-NEXT: scf.condition(%[[cmp]]) %[[val]] : i32
// CHECK-NEXT: } do {
// CHECK-NEXT: ^bb0(%arg1: i32): // no predecessors
// CHECK-NEXT: "test.use"(%[[true]], %[[false]], %arg1) : (i1, i1, i32) -> ()
// CHECK-NEXT: scf.yield
// CHECK-NEXT: }
// CHECK-LABEL: @while_cmp_rhs
func @while_cmp_rhs(%arg0 : i32) {
%0 = scf.while () : () -> i32 {
%val = "test.val"() : () -> i32
%condition = arith.cmpi ne, %arg0, %val : i32
scf.condition(%condition) %val : i32
} do {
^bb0(%val2: i32):
%condition2 = arith.cmpi ne, %arg0, %val2 : i32
%negcondition2 = arith.cmpi eq, %arg0, %val2 : i32
"test.use"(%condition2, %negcondition2, %val2) : (i1, i1, i32) -> ()
scf.yield
}
return
}
// CHECK-DAG: %[[true:.+]] = arith.constant true
// CHECK-DAG: %[[false:.+]] = arith.constant false
// CHECK-DAG: %{{.+}} = scf.while : () -> i32 {
// CHECK-NEXT: %[[val:.+]] = "test.val"
// CHECK-NEXT: %[[cmp:.+]] = arith.cmpi ne, %arg0, %[[val]] : i32
// CHECK-NEXT: scf.condition(%[[cmp]]) %[[val]] : i32
// CHECK-NEXT: } do {
// CHECK-NEXT: ^bb0(%arg1: i32): // no predecessors
// CHECK-NEXT: "test.use"(%[[true]], %[[false]], %arg1) : (i1, i1, i32) -> ()
// CHECK-NEXT: scf.yield
// CHECK-NEXT: }
// -----
// CHECK-LABEL: @combineIfs