forked from OSchip/llvm-project
[MLIR][SCF] Canonicalize while statement whose cmp condition is recomputed in the after region
Given a while loop whose condition is given by a cmp, don't recomputed the comparison (or its inverse) in the after region, instead use a constant since the original condition must be true if we branched to the after region. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D117047
This commit is contained in:
parent
ff11cd9550
commit
97567bde5b
|
@ -121,6 +121,8 @@ Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder,
|
|||
/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
|
||||
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
|
||||
Value lhs, Value rhs);
|
||||
|
||||
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred);
|
||||
} // namespace arith
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
|
|||
}
|
||||
|
||||
/// Invert an integer comparison predicate.
|
||||
static arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred) {
|
||||
arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
|
||||
switch (pred) {
|
||||
case arith::CmpIPredicate::eq:
|
||||
return arith::CmpIPredicate::ne;
|
||||
|
|
|
@ -2443,11 +2443,76 @@ struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
|
|||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Replace operations equivalent to the condition in the do block with true,
|
||||
/// since otherwise the block would not be evaluated.
|
||||
///
|
||||
/// scf.while (..) : (i32, ...) -> ... {
|
||||
/// %z = ... : i32
|
||||
/// %condition = cmpi pred %z, %a
|
||||
/// scf.condition(%condition) %z : i32, ...
|
||||
/// } do {
|
||||
/// ^bb0(%arg0: i32, ...):
|
||||
/// %condition2 = cmpi pred %arg0, %a
|
||||
/// use(%condition2)
|
||||
/// ...
|
||||
///
|
||||
/// becomes
|
||||
/// scf.while (..) : (i32, ...) -> ... {
|
||||
/// %z = ... : i32
|
||||
/// %condition = cmpi pred %z, %a
|
||||
/// scf.condition(%condition) %z : i32, ...
|
||||
/// } do {
|
||||
/// ^bb0(%arg0: i32, ...):
|
||||
/// use(%true)
|
||||
/// ...
|
||||
struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
|
||||
using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(scf::WhileOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
using namespace scf;
|
||||
auto cond = op.getConditionOp();
|
||||
auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
|
||||
if (!cmp)
|
||||
return failure();
|
||||
bool changed = false;
|
||||
for (auto tup :
|
||||
llvm::zip(cond.getArgs(), op.getAfter().front().getArguments())) {
|
||||
for (size_t opIdx = 0; opIdx < 2; opIdx++) {
|
||||
if (std::get<0>(tup) != cmp.getOperand(opIdx))
|
||||
continue;
|
||||
for (OpOperand &u :
|
||||
llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
|
||||
auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
|
||||
if (!cmp2)
|
||||
continue;
|
||||
// For a binary operator 1-opIdx gets the other side.
|
||||
if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
|
||||
continue;
|
||||
bool samePredicate;
|
||||
if (cmp2.getPredicate() == cmp.getPredicate())
|
||||
samePredicate = true;
|
||||
else if (cmp2.getPredicate() ==
|
||||
arith::invertPredicate(cmp.getPredicate()))
|
||||
samePredicate = false;
|
||||
else
|
||||
continue;
|
||||
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
|
||||
1);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return success(changed);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<WhileConditionTruth, WhileUnusedResult>(context);
|
||||
results.insert<WhileConditionTruth, WhileUnusedResult, WhileCmpCond>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -872,6 +872,60 @@ func @while_unused_result() -> i32 {
|
|||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %[[res]] : i32
|
||||
|
||||
// CHECK-LABEL: @while_cmp_lhs
|
||||
func @while_cmp_lhs(%arg0 : i32) {
|
||||
%0 = scf.while () : () -> i32 {
|
||||
%val = "test.val"() : () -> i32
|
||||
%condition = arith.cmpi ne, %val, %arg0 : i32
|
||||
scf.condition(%condition) %val : i32
|
||||
} do {
|
||||
^bb0(%val2: i32):
|
||||
%condition2 = arith.cmpi ne, %val2, %arg0 : i32
|
||||
%negcondition2 = arith.cmpi eq, %val2, %arg0 : i32
|
||||
"test.use"(%condition2, %negcondition2, %val2) : (i1, i1, i32) -> ()
|
||||
scf.yield
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-DAG: %[[true:.+]] = arith.constant true
|
||||
// CHECK-DAG: %[[false:.+]] = arith.constant false
|
||||
// CHECK-DAG: %{{.+}} = scf.while : () -> i32 {
|
||||
// CHECK-NEXT: %[[val:.+]] = "test.val"
|
||||
// CHECK-NEXT: %[[cmp:.+]] = arith.cmpi ne, %[[val]], %arg0 : i32
|
||||
// CHECK-NEXT: scf.condition(%[[cmp]]) %[[val]] : i32
|
||||
// CHECK-NEXT: } do {
|
||||
// CHECK-NEXT: ^bb0(%arg1: i32): // no predecessors
|
||||
// CHECK-NEXT: "test.use"(%[[true]], %[[false]], %arg1) : (i1, i1, i32) -> ()
|
||||
// CHECK-NEXT: scf.yield
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-LABEL: @while_cmp_rhs
|
||||
func @while_cmp_rhs(%arg0 : i32) {
|
||||
%0 = scf.while () : () -> i32 {
|
||||
%val = "test.val"() : () -> i32
|
||||
%condition = arith.cmpi ne, %arg0, %val : i32
|
||||
scf.condition(%condition) %val : i32
|
||||
} do {
|
||||
^bb0(%val2: i32):
|
||||
%condition2 = arith.cmpi ne, %arg0, %val2 : i32
|
||||
%negcondition2 = arith.cmpi eq, %arg0, %val2 : i32
|
||||
"test.use"(%condition2, %negcondition2, %val2) : (i1, i1, i32) -> ()
|
||||
scf.yield
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-DAG: %[[true:.+]] = arith.constant true
|
||||
// CHECK-DAG: %[[false:.+]] = arith.constant false
|
||||
// CHECK-DAG: %{{.+}} = scf.while : () -> i32 {
|
||||
// CHECK-NEXT: %[[val:.+]] = "test.val"
|
||||
// CHECK-NEXT: %[[cmp:.+]] = arith.cmpi ne, %arg0, %[[val]] : i32
|
||||
// CHECK-NEXT: scf.condition(%[[cmp]]) %[[val]] : i32
|
||||
// CHECK-NEXT: } do {
|
||||
// CHECK-NEXT: ^bb0(%arg1: i32): // no predecessors
|
||||
// CHECK-NEXT: "test.use"(%[[true]], %[[false]], %arg1) : (i1, i1, i32) -> ()
|
||||
// CHECK-NEXT: scf.yield
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @combineIfs
|
||||
|
|
Loading…
Reference in New Issue