[mlir][Standard] Add a canonicalization to simplify cond_br when the successors are identical

This revision adds support for canonicalizing the following:

```
cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)

br ^bb1(A, ..., N)
```

 If the operands to the successor are different and the cond_br is the only predecessor, we emit selects for the branch operands.

```
cond_br %cond, ^bb1(A), ^bb1(B)

%select = select %cond, A, B
br ^bb1(%select)
```

Differential Revision: https://reviews.llvm.org/D78682
This commit is contained in:
River Riddle 2020-04-23 04:40:25 -07:00
parent 2f4b303d68
commit af331bc52d
7 changed files with 153 additions and 37 deletions

View File

@ -1123,10 +1123,13 @@ public:
}
/// Compare this range with another.
template <typename OtherT> bool operator==(const OtherT &other) {
template <typename OtherT> bool operator==(const OtherT &other) const {
return size() == std::distance(other.begin(), other.end()) &&
std::equal(begin(), end(), other.begin());
}
template <typename OtherT> bool operator!=(const OtherT &other) const {
return !(*this == other);
}
/// Return the size of this range.
size_t size() const { return count; }

View File

@ -1951,9 +1951,9 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
}];
let arguments = (ins BoolLike:$condition,
SignlessIntegerOrFloatLike:$true_value,
SignlessIntegerOrFloatLike:$false_value);
let results = (outs SignlessIntegerOrFloatLike:$result);
AnyType:$true_value,
AnyType:$false_value);
let results = (outs AnyType:$result);
let verifier = ?;
let builders = [OpBuilder<

View File

@ -248,6 +248,10 @@ public:
/// destinations) is not considered to be a single predecessor.
Block *getSinglePredecessor();
/// If this block has a unique predecessor, i.e., all incoming edges originate
/// from one block, return it. Otherwise, return null.
Block *getUniquePredecessor();
// Indexed successor access.
unsigned getNumSuccessors();
Block *getSuccessor(unsigned i);

View File

@ -684,23 +684,15 @@ void CallIndirectOp::getCanonicalizationPatterns(
//===----------------------------------------------------------------------===//
// Return the type of the same shape (scalar, vector or tensor) containing i1.
static Type getCheckedI1SameShape(Type type) {
static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(1, type.getContext());
if (type.isSignlessIntOrIndexOrFloat())
return i1Type;
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i1Type);
if (type.isa<UnrankedTensorType>())
return UnrankedTensorType::get(i1Type);
if (auto vectorType = type.dyn_cast<VectorType>())
return VectorType::get(vectorType.getShape(), i1Type);
return Type();
}
static Type getI1SameShape(Type type) {
Type res = getCheckedI1SameShape(type);
assert(res && "expected type with valid i1 shape");
return res;
return i1Type;
}
//===----------------------------------------------------------------------===//
@ -840,8 +832,10 @@ OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
namespace {
/// cond_br true, ^bb1, ^bb2 -> br ^bb1
/// cond_br false, ^bb1, ^bb2 -> br ^bb2
/// cond_br true, ^bb1, ^bb2
/// -> br ^bb1
/// cond_br false, ^bb1, ^bb2
/// -> br ^bb2
///
struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
@ -869,7 +863,7 @@ struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
/// ^bb2
/// br ^bbK(...)
///
/// cond_br %cond, ^bbN(...), ^bbK(...)
/// -> cond_br %cond, ^bbN(...), ^bbK(...)
///
struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
@ -943,12 +937,70 @@ struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
return success();
}
};
/// cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
/// -> br ^bb1(A, ..., N)
///
/// cond_br %cond, ^bb1(A), ^bb1(B)
/// -> %select = select %cond, A, B
/// br ^bb1(%select)
///
struct SimplifyCondBranchIdenticalSuccessors
: public OpRewritePattern<CondBranchOp> {
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CondBranchOp condbr,
PatternRewriter &rewriter) const override {
// Check that the true and false destinations are the same and have the same
// operands.
Block *trueDest = condbr.trueDest();
if (trueDest != condbr.falseDest())
return failure();
// If all of the operands match, no selects need to be generated.
OperandRange trueOperands = condbr.getTrueOperands();
OperandRange falseOperands = condbr.getFalseOperands();
if (trueOperands == falseOperands) {
rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
return success();
}
// Otherwise, if the current block is the only predecessor insert selects
// for any mismatched branch operands.
if (trueDest->getUniquePredecessor() != condbr.getOperation()->getBlock())
return failure();
// TODO: ATM Tensor/Vector SelectOp requires that the condition has the same
// shape as the operands. We should relax that to allow an i1 to signify
// that everything is selected.
auto doesntSupportsScalarI1 = [](Type type) {
return type.isa<TensorType>() || type.isa<VectorType>();
};
if (llvm::any_of(trueOperands.getTypes(), doesntSupportsScalarI1))
return failure();
// Generate a select for any operands that differ between the two.
SmallVector<Value, 8> mergedOperands;
mergedOperands.reserve(trueOperands.size());
Value condition = condbr.getCondition();
for (auto it : llvm::zip(trueOperands, falseOperands)) {
if (std::get<0>(it) == std::get<1>(it))
mergedOperands.push_back(std::get<0>(it));
else
mergedOperands.push_back(rewriter.create<SelectOp>(
condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
}
rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
return success();
}
};
} // end anonymous namespace
void CondBranchOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch>(
context);
results.insert<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
SimplifyCondBranchIdenticalSuccessors>(context);
}
Optional<OperandRange> CondBranchOp::getSuccessorOperands(unsigned index) {

View File

@ -229,6 +229,21 @@ Block *Block::getSinglePredecessor() {
return it == pred_end() ? firstPred : nullptr;
}
/// If this block has a unique predecessor, i.e., all incoming edges originate
/// from one block, return it. Otherwise, return null.
Block *Block::getUniquePredecessor() {
auto it = pred_begin(), e = pred_end();
if (it == e)
return nullptr;
// Check for any conflicting predecessors.
auto *firstPred = *it;
for (++it; it != e; ++it)
if (*it != firstPred)
return nullptr;
return firstPred;
}
//===----------------------------------------------------------------------===//
// Other
//===----------------------------------------------------------------------===//

View File

@ -1,6 +1,6 @@
// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s
// Test the folding of BranchOp.
/// Test the folding of BranchOp.
// CHECK-LABEL: func @br_folding(
func @br_folding() -> i32 {
@ -12,11 +12,11 @@ func @br_folding() -> i32 {
return %x : i32
}
// Test the folding of CondBranchOp with a constant condition.
/// Test the folding of CondBranchOp with a constant condition.
// CHECK-LABEL: func @cond_br_folding(
func @cond_br_folding(%cond : i1, %a : i32) {
// CHECK-NEXT: cond_br %{{.*}}, ^bb1, ^bb1
// CHECK-NEXT: return
%false_cond = constant 0 : i1
%true_cond = constant 1 : i1
@ -29,13 +29,62 @@ func @cond_br_folding(%cond : i1, %a : i32) {
cond_br %false_cond, ^bb2(%x : i32), ^bb3
^bb3:
// CHECK: ^bb1:
// CHECK-NEXT: return
return
}
// Test the compound folding of BranchOp and CondBranchOp.
/// Test the folding of CondBranchOp when the successors are identical.
// CHECK-LABEL: func @cond_br_same_successor(
func @cond_br_same_successor(%cond : i1, %a : i32) {
// CHECK-NEXT: return
cond_br %cond, ^bb1(%a : i32), ^bb1(%a : i32)
^bb1(%result : i32):
return
}
/// Test the folding of CondBranchOp when the successors are identical, but the
/// arguments are different.
// CHECK-LABEL: func @cond_br_same_successor_insert_select(
// CHECK-SAME: %[[COND:.*]]: i1, %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
func @cond_br_same_successor_insert_select(%cond : i1, %a : i32, %b : i32) -> i32 {
// CHECK: %[[RES:.*]] = select %[[COND]], %[[ARG0]], %[[ARG1]]
// CHECK: return %[[RES]]
cond_br %cond, ^bb1(%a : i32), ^bb1(%b : i32)
^bb1(%result : i32):
return %result : i32
}
/// Check that we don't generate a select if the type requires a splat.
/// TODO: SelectOp should allow for matching a vector/tensor with i1.
// CHECK-LABEL: func @cond_br_same_successor_no_select_tensor(
func @cond_br_same_successor_no_select_tensor(%cond : i1, %a : tensor<2xi32>,
%b : tensor<2xi32>) -> tensor<2xi32>{
// CHECK: cond_br
cond_br %cond, ^bb1(%a : tensor<2xi32>), ^bb1(%b : tensor<2xi32>)
^bb1(%result : tensor<2xi32>):
return %result : tensor<2xi32>
}
// CHECK-LABEL: func @cond_br_same_successor_no_select_vector(
func @cond_br_same_successor_no_select_vector(%cond : i1, %a : vector<2xi32>,
%b : vector<2xi32>) -> vector<2xi32> {
// CHECK: cond_br
cond_br %cond, ^bb1(%a : vector<2xi32>), ^bb1(%b : vector<2xi32>)
^bb1(%result : vector<2xi32>):
return %result : vector<2xi32>
}
/// Test the compound folding of BranchOp and CondBranchOp.
// CHECK-LABEL: func @cond_br_and_br_folding(
func @cond_br_and_br_folding(%a : i32) {
@ -55,9 +104,11 @@ func @cond_br_and_br_folding(%a : i32) {
/// Test that pass-through successors of CondBranchOp get folded.
// CHECK-LABEL: func @cond_br_pass_through(
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[COND:.*]]: i1
func @cond_br_pass_through(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) {
// CHECK: cond_br %{{.*}}, ^bb1(%[[ARG0]], %[[ARG1]] : i32, i32), ^bb1(%[[ARG2]], %[[ARG2]] : i32, i32)
// CHECK: %[[RES:.*]] = select %[[COND]], %[[ARG0]], %[[ARG2]]
// CHECK: %[[RES2:.*]] = select %[[COND]], %[[ARG1]], %[[ARG2]]
// CHECK: return %[[RES]], %[[RES2]]
cond_br %cond, ^bb1(%arg0 : i32), ^bb2(%arg2, %arg2 : i32, i32)
@ -65,9 +116,6 @@ func @cond_br_pass_through(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) ->
br ^bb2(%arg3, %arg1 : i32, i32)
^bb2(%arg4: i32, %arg5: i32):
// CHECK: ^bb1(%[[RET0:.*]]: i32, %[[RET1:.*]]: i32):
// CHECK-NEXT: return %[[RET0]], %[[RET1]]
return %arg4, %arg5 : i32, i32
}

View File

@ -297,12 +297,6 @@ func @func_with_ops(i1, tensor<42xi32>, tensor<?xi32>) {
// -----
func @invalid_select_shape(%cond : i1, %idx : () -> ()) {
// expected-error@+1 {{'result' must be signless-integer-like or floating-point-like, but got '() -> ()'}}
%sel = select %cond, %idx, %idx : () -> ()
// -----
func @invalid_cmp_shape(%idx : () -> ()) {
// expected-error@+1 {{'lhs' must be signless-integer-like, but got '() -> ()'}}
%cmp = cmpi "eq", %idx, %idx : () -> ()