forked from OSchip/llvm-project
[mlir][Standard] Add a canonicalization to simplify cond_br when the successors are identical
This revision adds support for canonicalizing the following: ``` cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N) br ^bb1(A, ..., N) ``` If the operands to the successor are different and the cond_br is the only predecessor, we emit selects for the branch operands. ``` cond_br %cond, ^bb1(A), ^bb1(B) %select = select %cond, A, B br ^bb1(%select) ``` Differential Revision: https://reviews.llvm.org/D78682
This commit is contained in:
parent
2f4b303d68
commit
af331bc52d
|
@ -1123,10 +1123,13 @@ public:
|
|||
}
|
||||
|
||||
/// Compare this range with another.
|
||||
template <typename OtherT> bool operator==(const OtherT &other) {
|
||||
template <typename OtherT> bool operator==(const OtherT &other) const {
|
||||
return size() == std::distance(other.begin(), other.end()) &&
|
||||
std::equal(begin(), end(), other.begin());
|
||||
}
|
||||
template <typename OtherT> bool operator!=(const OtherT &other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
|
||||
/// Return the size of this range.
|
||||
size_t size() const { return count; }
|
||||
|
|
|
@ -1951,9 +1951,9 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
|
|||
}];
|
||||
|
||||
let arguments = (ins BoolLike:$condition,
|
||||
SignlessIntegerOrFloatLike:$true_value,
|
||||
SignlessIntegerOrFloatLike:$false_value);
|
||||
let results = (outs SignlessIntegerOrFloatLike:$result);
|
||||
AnyType:$true_value,
|
||||
AnyType:$false_value);
|
||||
let results = (outs AnyType:$result);
|
||||
let verifier = ?;
|
||||
|
||||
let builders = [OpBuilder<
|
||||
|
|
|
@ -248,6 +248,10 @@ public:
|
|||
/// destinations) is not considered to be a single predecessor.
|
||||
Block *getSinglePredecessor();
|
||||
|
||||
/// If this block has a unique predecessor, i.e., all incoming edges originate
|
||||
/// from one block, return it. Otherwise, return null.
|
||||
Block *getUniquePredecessor();
|
||||
|
||||
// Indexed successor access.
|
||||
unsigned getNumSuccessors();
|
||||
Block *getSuccessor(unsigned i);
|
||||
|
|
|
@ -684,23 +684,15 @@ void CallIndirectOp::getCanonicalizationPatterns(
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Return the type of the same shape (scalar, vector or tensor) containing i1.
|
||||
static Type getCheckedI1SameShape(Type type) {
|
||||
static Type getI1SameShape(Type type) {
|
||||
auto i1Type = IntegerType::get(1, type.getContext());
|
||||
if (type.isSignlessIntOrIndexOrFloat())
|
||||
return i1Type;
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), i1Type);
|
||||
if (type.isa<UnrankedTensorType>())
|
||||
return UnrankedTensorType::get(i1Type);
|
||||
if (auto vectorType = type.dyn_cast<VectorType>())
|
||||
return VectorType::get(vectorType.getShape(), i1Type);
|
||||
return Type();
|
||||
}
|
||||
|
||||
static Type getI1SameShape(Type type) {
|
||||
Type res = getCheckedI1SameShape(type);
|
||||
assert(res && "expected type with valid i1 shape");
|
||||
return res;
|
||||
return i1Type;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -840,8 +832,10 @@ OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// cond_br true, ^bb1, ^bb2 -> br ^bb1
|
||||
/// cond_br false, ^bb1, ^bb2 -> br ^bb2
|
||||
/// cond_br true, ^bb1, ^bb2
|
||||
/// -> br ^bb1
|
||||
/// cond_br false, ^bb1, ^bb2
|
||||
/// -> br ^bb2
|
||||
///
|
||||
struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
|
||||
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
||||
|
@ -869,7 +863,7 @@ struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
|
|||
/// ^bb2
|
||||
/// br ^bbK(...)
|
||||
///
|
||||
/// cond_br %cond, ^bbN(...), ^bbK(...)
|
||||
/// -> cond_br %cond, ^bbN(...), ^bbK(...)
|
||||
///
|
||||
struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
|
||||
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
||||
|
@ -943,12 +937,70 @@ struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
|
|||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
|
||||
/// -> br ^bb1(A, ..., N)
|
||||
///
|
||||
/// cond_br %cond, ^bb1(A), ^bb1(B)
|
||||
/// -> %select = select %cond, A, B
|
||||
/// br ^bb1(%select)
|
||||
///
|
||||
struct SimplifyCondBranchIdenticalSuccessors
|
||||
: public OpRewritePattern<CondBranchOp> {
|
||||
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Check that the true and false destinations are the same and have the same
|
||||
// operands.
|
||||
Block *trueDest = condbr.trueDest();
|
||||
if (trueDest != condbr.falseDest())
|
||||
return failure();
|
||||
|
||||
// If all of the operands match, no selects need to be generated.
|
||||
OperandRange trueOperands = condbr.getTrueOperands();
|
||||
OperandRange falseOperands = condbr.getFalseOperands();
|
||||
if (trueOperands == falseOperands) {
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Otherwise, if the current block is the only predecessor insert selects
|
||||
// for any mismatched branch operands.
|
||||
if (trueDest->getUniquePredecessor() != condbr.getOperation()->getBlock())
|
||||
return failure();
|
||||
|
||||
// TODO: ATM Tensor/Vector SelectOp requires that the condition has the same
|
||||
// shape as the operands. We should relax that to allow an i1 to signify
|
||||
// that everything is selected.
|
||||
auto doesntSupportsScalarI1 = [](Type type) {
|
||||
return type.isa<TensorType>() || type.isa<VectorType>();
|
||||
};
|
||||
if (llvm::any_of(trueOperands.getTypes(), doesntSupportsScalarI1))
|
||||
return failure();
|
||||
|
||||
// Generate a select for any operands that differ between the two.
|
||||
SmallVector<Value, 8> mergedOperands;
|
||||
mergedOperands.reserve(trueOperands.size());
|
||||
Value condition = condbr.getCondition();
|
||||
for (auto it : llvm::zip(trueOperands, falseOperands)) {
|
||||
if (std::get<0>(it) == std::get<1>(it))
|
||||
mergedOperands.push_back(std::get<0>(it));
|
||||
else
|
||||
mergedOperands.push_back(rewriter.create<SelectOp>(
|
||||
condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
void CondBranchOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch>(
|
||||
context);
|
||||
results.insert<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
|
||||
SimplifyCondBranchIdenticalSuccessors>(context);
|
||||
}
|
||||
|
||||
Optional<OperandRange> CondBranchOp::getSuccessorOperands(unsigned index) {
|
||||
|
|
|
@ -229,6 +229,21 @@ Block *Block::getSinglePredecessor() {
|
|||
return it == pred_end() ? firstPred : nullptr;
|
||||
}
|
||||
|
||||
/// If this block has a unique predecessor, i.e., all incoming edges originate
|
||||
/// from one block, return it. Otherwise, return null.
|
||||
Block *Block::getUniquePredecessor() {
|
||||
auto it = pred_begin(), e = pred_end();
|
||||
if (it == e)
|
||||
return nullptr;
|
||||
|
||||
// Check for any conflicting predecessors.
|
||||
auto *firstPred = *it;
|
||||
for (++it; it != e; ++it)
|
||||
if (*it != firstPred)
|
||||
return nullptr;
|
||||
return firstPred;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Other
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s
|
||||
|
||||
// Test the folding of BranchOp.
|
||||
/// Test the folding of BranchOp.
|
||||
|
||||
// CHECK-LABEL: func @br_folding(
|
||||
func @br_folding() -> i32 {
|
||||
|
@ -12,11 +12,11 @@ func @br_folding() -> i32 {
|
|||
return %x : i32
|
||||
}
|
||||
|
||||
// Test the folding of CondBranchOp with a constant condition.
|
||||
/// Test the folding of CondBranchOp with a constant condition.
|
||||
|
||||
// CHECK-LABEL: func @cond_br_folding(
|
||||
func @cond_br_folding(%cond : i1, %a : i32) {
|
||||
// CHECK-NEXT: cond_br %{{.*}}, ^bb1, ^bb1
|
||||
// CHECK-NEXT: return
|
||||
|
||||
%false_cond = constant 0 : i1
|
||||
%true_cond = constant 1 : i1
|
||||
|
@ -29,13 +29,62 @@ func @cond_br_folding(%cond : i1, %a : i32) {
|
|||
cond_br %false_cond, ^bb2(%x : i32), ^bb3
|
||||
|
||||
^bb3:
|
||||
// CHECK: ^bb1:
|
||||
// CHECK-NEXT: return
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Test the compound folding of BranchOp and CondBranchOp.
|
||||
/// Test the folding of CondBranchOp when the successors are identical.
|
||||
|
||||
// CHECK-LABEL: func @cond_br_same_successor(
|
||||
func @cond_br_same_successor(%cond : i1, %a : i32) {
|
||||
// CHECK-NEXT: return
|
||||
|
||||
cond_br %cond, ^bb1(%a : i32), ^bb1(%a : i32)
|
||||
|
||||
^bb1(%result : i32):
|
||||
return
|
||||
}
|
||||
|
||||
/// Test the folding of CondBranchOp when the successors are identical, but the
|
||||
/// arguments are different.
|
||||
|
||||
// CHECK-LABEL: func @cond_br_same_successor_insert_select(
|
||||
// CHECK-SAME: %[[COND:.*]]: i1, %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
|
||||
func @cond_br_same_successor_insert_select(%cond : i1, %a : i32, %b : i32) -> i32 {
|
||||
// CHECK: %[[RES:.*]] = select %[[COND]], %[[ARG0]], %[[ARG1]]
|
||||
// CHECK: return %[[RES]]
|
||||
|
||||
cond_br %cond, ^bb1(%a : i32), ^bb1(%b : i32)
|
||||
|
||||
^bb1(%result : i32):
|
||||
return %result : i32
|
||||
}
|
||||
|
||||
/// Check that we don't generate a select if the type requires a splat.
|
||||
/// TODO: SelectOp should allow for matching a vector/tensor with i1.
|
||||
|
||||
// CHECK-LABEL: func @cond_br_same_successor_no_select_tensor(
|
||||
func @cond_br_same_successor_no_select_tensor(%cond : i1, %a : tensor<2xi32>,
|
||||
%b : tensor<2xi32>) -> tensor<2xi32>{
|
||||
// CHECK: cond_br
|
||||
|
||||
cond_br %cond, ^bb1(%a : tensor<2xi32>), ^bb1(%b : tensor<2xi32>)
|
||||
|
||||
^bb1(%result : tensor<2xi32>):
|
||||
return %result : tensor<2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @cond_br_same_successor_no_select_vector(
|
||||
func @cond_br_same_successor_no_select_vector(%cond : i1, %a : vector<2xi32>,
|
||||
%b : vector<2xi32>) -> vector<2xi32> {
|
||||
// CHECK: cond_br
|
||||
|
||||
cond_br %cond, ^bb1(%a : vector<2xi32>), ^bb1(%b : vector<2xi32>)
|
||||
|
||||
^bb1(%result : vector<2xi32>):
|
||||
return %result : vector<2xi32>
|
||||
}
|
||||
|
||||
/// Test the compound folding of BranchOp and CondBranchOp.
|
||||
|
||||
// CHECK-LABEL: func @cond_br_and_br_folding(
|
||||
func @cond_br_and_br_folding(%a : i32) {
|
||||
|
@ -55,9 +104,11 @@ func @cond_br_and_br_folding(%a : i32) {
|
|||
/// Test that pass-through successors of CondBranchOp get folded.
|
||||
|
||||
// CHECK-LABEL: func @cond_br_pass_through(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32
|
||||
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[COND:.*]]: i1
|
||||
func @cond_br_pass_through(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) {
|
||||
// CHECK: cond_br %{{.*}}, ^bb1(%[[ARG0]], %[[ARG1]] : i32, i32), ^bb1(%[[ARG2]], %[[ARG2]] : i32, i32)
|
||||
// CHECK: %[[RES:.*]] = select %[[COND]], %[[ARG0]], %[[ARG2]]
|
||||
// CHECK: %[[RES2:.*]] = select %[[COND]], %[[ARG1]], %[[ARG2]]
|
||||
// CHECK: return %[[RES]], %[[RES2]]
|
||||
|
||||
cond_br %cond, ^bb1(%arg0 : i32), ^bb2(%arg2, %arg2 : i32, i32)
|
||||
|
||||
|
@ -65,9 +116,6 @@ func @cond_br_pass_through(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) ->
|
|||
br ^bb2(%arg3, %arg1 : i32, i32)
|
||||
|
||||
^bb2(%arg4: i32, %arg5: i32):
|
||||
// CHECK: ^bb1(%[[RET0:.*]]: i32, %[[RET1:.*]]: i32):
|
||||
// CHECK-NEXT: return %[[RET0]], %[[RET1]]
|
||||
|
||||
return %arg4, %arg5 : i32, i32
|
||||
}
|
||||
|
||||
|
|
|
@ -297,12 +297,6 @@ func @func_with_ops(i1, tensor<42xi32>, tensor<?xi32>) {
|
|||
|
||||
// -----
|
||||
|
||||
func @invalid_select_shape(%cond : i1, %idx : () -> ()) {
|
||||
// expected-error@+1 {{'result' must be signless-integer-like or floating-point-like, but got '() -> ()'}}
|
||||
%sel = select %cond, %idx, %idx : () -> ()
|
||||
|
||||
// -----
|
||||
|
||||
func @invalid_cmp_shape(%idx : () -> ()) {
|
||||
// expected-error@+1 {{'lhs' must be signless-integer-like, but got '() -> ()'}}
|
||||
%cmp = cmpi "eq", %idx, %idx : () -> ()
|
||||
|
|
Loading…
Reference in New Issue