Add a canonicalization pattern for conditional branch to fold constant branch conditions.

PiperOrigin-RevId: 229242007
This commit is contained in:
River Riddle 2019-01-14 13:23:18 -08:00 committed by jpienaar
parent 06b0bd9651
commit ed26dd0421
3 changed files with 73 additions and 6 deletions

View File

@ -157,6 +157,9 @@ public:
void print(OpAsmPrinter *p) const;
bool verify() const;
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context);
// The condition operand is the first operand in the list.
Value *getCondition() { return getOperand(0); }
const Value *getCondition() const { return getOperand(0); }
@ -186,17 +189,21 @@ public:
setOperand(getTrueDestOperandIndex() + idx, value);
}
operand_iterator true_operand_begin() { return operand_begin(); }
operand_iterator true_operand_begin() {
return operand_begin() + getTrueDestOperandIndex();
}
operand_iterator true_operand_end() {
return operand_begin() + getNumTrueOperands();
return true_operand_begin() + getNumTrueOperands();
}
llvm::iterator_range<operand_iterator> getTrueOperands() {
return {true_operand_begin(), true_operand_end()};
}
const_operand_iterator true_operand_begin() const { return operand_begin(); }
const_operand_iterator true_operand_begin() const {
return operand_begin() + getTrueDestOperandIndex();
}
const_operand_iterator true_operand_end() const {
return operand_begin() + getNumTrueOperands();
return true_operand_begin() + getNumTrueOperands();
}
llvm::iterator_range<const_operand_iterator> getTrueOperands() const {
return {true_operand_begin(), true_operand_end()};
@ -245,10 +252,10 @@ public:
private:
/// Get the index of the first true destination operand.
unsigned getTrueDestOperandIndex() { return 1; }
unsigned getTrueDestOperandIndex() const { return 1; }
/// Get the index of the first false destination operand.
unsigned getFalseDestOperandIndex() {
unsigned getFalseDestOperandIndex() const {
return getTrueDestOperandIndex() + getNumTrueOperands();
}

View File

@ -19,6 +19,7 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
@ -319,6 +320,44 @@ void BranchOp::eraseOperand(unsigned index) {
// CondBranchOp
//===----------------------------------------------------------------------===//
namespace {
/// cond_br true, ^bb1, ^bb2 -> br ^bb1
/// cond_br false, ^bb1, ^bb2 -> br ^bb2
///
struct SimplifyConstCondBranchPred : public RewritePattern {
SimplifyConstCondBranchPred(MLIRContext *context)
: RewritePattern(CondBranchOp::getOperationName(), 1, context) {}
PatternMatchResult match(OperationInst *op) const override {
auto condbr = op->cast<CondBranchOp>();
if (matchPattern(condbr->getCondition(), m_Op<ConstantOp>()))
return matchSuccess();
return matchFailure();
}
void rewrite(OperationInst *op, PatternRewriter &rewriter) const override {
auto condbr = op->cast<CondBranchOp>();
Block *foldedDest;
SmallVector<Value *, 4> branchArgs;
// If the condition is known to evaluate to false we fold to a branch to the
// false destination. Otherwise, we fold to a branch to the true
// destination.
if (matchPattern(condbr->getCondition(), m_Zero())) {
foldedDest = condbr->getFalseDest();
branchArgs.assign(condbr->false_operand_begin(),
condbr->false_operand_end());
} else {
foldedDest = condbr->getTrueDest();
branchArgs.assign(condbr->true_operand_begin(),
condbr->true_operand_end());
}
rewriter.replaceOpWithNewOp<BranchOp>(op, foldedDest, branchArgs);
}
};
} // end anonymous namespace.
void CondBranchOp::build(Builder *builder, OperationState *result,
Value *condition, Block *trueDest,
ArrayRef<Value *> trueOperands, Block *falseDest,
@ -372,6 +411,11 @@ bool CondBranchOp::verify() const {
return false;
}
void CondBranchOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.push_back(std::make_unique<SimplifyConstCondBranchPred>(context));
}
Block *CondBranchOp::getTrueDest() {
return getInstruction()->getSuccessor(trueIndex);
}

View File

@ -273,3 +273,19 @@ func @simplify_affine_apply(%arg0: memref<index>, %arg1: index, %arg2: index) {
return
}
// CHECK-LABEL: func @cond_br_folding
func @cond_br_folding(%a : i32) {
%false_cond = constant 0 : i1
%true_cond = constant 1 : i1
// CHECK-NEXT: br ^bb1(%arg0 : i32)
cond_br %true_cond, ^bb1(%a : i32), ^bb2
^bb1(%x : i32):
// CHECK: br ^bb2
cond_br %false_cond, ^bb1(%x : i32), ^bb2
^bb2:
return
}