forked from OSchip/llvm-project
Add a canonicalization pattern for conditional branch to fold constant branch conditions.
PiperOrigin-RevId: 229242007
This commit is contained in:
parent
06b0bd9651
commit
ed26dd0421
|
@ -157,6 +157,9 @@ public:
|
|||
void print(OpAsmPrinter *p) const;
|
||||
bool verify() const;
|
||||
|
||||
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context);
|
||||
|
||||
// The condition operand is the first operand in the list.
|
||||
Value *getCondition() { return getOperand(0); }
|
||||
const Value *getCondition() const { return getOperand(0); }
|
||||
|
@ -186,17 +189,21 @@ public:
|
|||
setOperand(getTrueDestOperandIndex() + idx, value);
|
||||
}
|
||||
|
||||
operand_iterator true_operand_begin() { return operand_begin(); }
|
||||
operand_iterator true_operand_begin() {
|
||||
return operand_begin() + getTrueDestOperandIndex();
|
||||
}
|
||||
operand_iterator true_operand_end() {
|
||||
return operand_begin() + getNumTrueOperands();
|
||||
return true_operand_begin() + getNumTrueOperands();
|
||||
}
|
||||
llvm::iterator_range<operand_iterator> getTrueOperands() {
|
||||
return {true_operand_begin(), true_operand_end()};
|
||||
}
|
||||
|
||||
const_operand_iterator true_operand_begin() const { return operand_begin(); }
|
||||
const_operand_iterator true_operand_begin() const {
|
||||
return operand_begin() + getTrueDestOperandIndex();
|
||||
}
|
||||
const_operand_iterator true_operand_end() const {
|
||||
return operand_begin() + getNumTrueOperands();
|
||||
return true_operand_begin() + getNumTrueOperands();
|
||||
}
|
||||
llvm::iterator_range<const_operand_iterator> getTrueOperands() const {
|
||||
return {true_operand_begin(), true_operand_end()};
|
||||
|
@ -245,10 +252,10 @@ public:
|
|||
|
||||
private:
|
||||
/// Get the index of the first true destination operand.
|
||||
unsigned getTrueDestOperandIndex() { return 1; }
|
||||
unsigned getTrueDestOperandIndex() const { return 1; }
|
||||
|
||||
/// Get the index of the first false destination operand.
|
||||
unsigned getFalseDestOperandIndex() {
|
||||
unsigned getFalseDestOperandIndex() const {
|
||||
return getTrueDestOperandIndex() + getNumTrueOperands();
|
||||
}
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
@ -319,6 +320,44 @@ void BranchOp::eraseOperand(unsigned index) {
|
|||
// CondBranchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// cond_br true, ^bb1, ^bb2 -> br ^bb1
|
||||
/// cond_br false, ^bb1, ^bb2 -> br ^bb2
|
||||
///
|
||||
struct SimplifyConstCondBranchPred : public RewritePattern {
|
||||
SimplifyConstCondBranchPred(MLIRContext *context)
|
||||
: RewritePattern(CondBranchOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult match(OperationInst *op) const override {
|
||||
auto condbr = op->cast<CondBranchOp>();
|
||||
if (matchPattern(condbr->getCondition(), m_Op<ConstantOp>()))
|
||||
return matchSuccess();
|
||||
|
||||
return matchFailure();
|
||||
}
|
||||
void rewrite(OperationInst *op, PatternRewriter &rewriter) const override {
|
||||
auto condbr = op->cast<CondBranchOp>();
|
||||
Block *foldedDest;
|
||||
SmallVector<Value *, 4> branchArgs;
|
||||
|
||||
// If the condition is known to evaluate to false we fold to a branch to the
|
||||
// false destination. Otherwise, we fold to a branch to the true
|
||||
// destination.
|
||||
if (matchPattern(condbr->getCondition(), m_Zero())) {
|
||||
foldedDest = condbr->getFalseDest();
|
||||
branchArgs.assign(condbr->false_operand_begin(),
|
||||
condbr->false_operand_end());
|
||||
} else {
|
||||
foldedDest = condbr->getTrueDest();
|
||||
branchArgs.assign(condbr->true_operand_begin(),
|
||||
condbr->true_operand_end());
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(op, foldedDest, branchArgs);
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
||||
void CondBranchOp::build(Builder *builder, OperationState *result,
|
||||
Value *condition, Block *trueDest,
|
||||
ArrayRef<Value *> trueOperands, Block *falseDest,
|
||||
|
@ -372,6 +411,11 @@ bool CondBranchOp::verify() const {
|
|||
return false;
|
||||
}
|
||||
|
||||
void CondBranchOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.push_back(std::make_unique<SimplifyConstCondBranchPred>(context));
|
||||
}
|
||||
|
||||
Block *CondBranchOp::getTrueDest() {
|
||||
return getInstruction()->getSuccessor(trueIndex);
|
||||
}
|
||||
|
|
|
@ -273,3 +273,19 @@ func @simplify_affine_apply(%arg0: memref<index>, %arg1: index, %arg2: index) {
|
|||
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @cond_br_folding
|
||||
func @cond_br_folding(%a : i32) {
|
||||
%false_cond = constant 0 : i1
|
||||
%true_cond = constant 1 : i1
|
||||
|
||||
// CHECK-NEXT: br ^bb1(%arg0 : i32)
|
||||
cond_br %true_cond, ^bb1(%a : i32), ^bb2
|
||||
|
||||
^bb1(%x : i32):
|
||||
// CHECK: br ^bb2
|
||||
cond_br %false_cond, ^bb1(%x : i32), ^bb2
|
||||
|
||||
^bb2:
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue