forked from OSchip/llvm-project
Adding an m_NonZero constant integer matcher.
This is useful for making matching cases where a non-zero value is required more readable, such as the results of a constant comparison that are expected to be equal. PiperOrigin-RevId: 278932874
This commit is contained in:
parent
b5654d1311
commit
68bd355505
|
@ -111,16 +111,24 @@ struct constant_int_op_binder {
|
|||
}
|
||||
};
|
||||
|
||||
// The matcher that matches a given target constant scalar / vector splat /
|
||||
// tensor splat integer value.
|
||||
/// The matcher that matches a given target constant scalar / vector splat /
|
||||
/// tensor splat integer value.
|
||||
template <int64_t TargetValue> struct constant_int_value_matcher {
|
||||
bool match(Operation *op) {
|
||||
APInt value;
|
||||
|
||||
return constant_int_op_binder(&value).match(op) && TargetValue == value;
|
||||
}
|
||||
};
|
||||
|
||||
/// The matcher that matches anything except the given target constant scalar /
|
||||
/// vector splat / tensor splat integer value.
|
||||
template <int64_t TargetNotValue> struct constant_int_not_value_matcher {
|
||||
bool match(Operation *op) {
|
||||
APInt value;
|
||||
return constant_int_op_binder(&value).match(op) && TargetNotValue != value;
|
||||
}
|
||||
};
|
||||
|
||||
/// The matcher that matches a certain kind of op.
|
||||
template <typename OpClass> struct op_matcher {
|
||||
bool match(Operation *op) { return isa<OpClass>(op); }
|
||||
|
@ -172,6 +180,12 @@ inline detail::constant_int_value_matcher<0> m_Zero() {
|
|||
return detail::constant_int_value_matcher<0>();
|
||||
}
|
||||
|
||||
/// Matches a constant scalar / vector splat / tensor splat integer that is any
|
||||
/// non-zero value.
|
||||
inline detail::constant_int_not_value_matcher<0> m_NonZero() {
|
||||
return detail::constant_int_not_value_matcher<0>();
|
||||
}
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_MATCHERS_H
|
||||
|
|
|
@ -1070,27 +1070,20 @@ struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
|
|||
|
||||
PatternMatchResult matchAndRewrite(CondBranchOp condbr,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Check that the condition is a constant.
|
||||
if (!matchPattern(condbr.getCondition(), m_Op<ConstantOp>()))
|
||||
return matchFailure();
|
||||
|
||||
Block *foldedDest;
|
||||
SmallVector<Value *, 4> branchArgs;
|
||||
|
||||
// If the condition is known to evaluate to false we fold to a branch to the
|
||||
// false destination. Otherwise, we fold to a branch to the true
|
||||
// destination.
|
||||
if (matchPattern(condbr.getCondition(), m_Zero())) {
|
||||
foldedDest = condbr.getFalseDest();
|
||||
branchArgs.assign(condbr.false_operand_begin(),
|
||||
condbr.false_operand_end());
|
||||
} else {
|
||||
foldedDest = condbr.getTrueDest();
|
||||
branchArgs.assign(condbr.true_operand_begin(), condbr.true_operand_end());
|
||||
if (matchPattern(condbr.getCondition(), m_NonZero())) {
|
||||
// True branch taken.
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(
|
||||
condbr, condbr.getTrueDest(),
|
||||
llvm::to_vector<4>(condbr.getTrueOperands()));
|
||||
return matchSuccess();
|
||||
} else if (matchPattern(condbr.getCondition(), m_Zero())) {
|
||||
// False branch taken.
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(
|
||||
condbr, condbr.getFalseDest(),
|
||||
llvm::to_vector<4>(condbr.getFalseOperands()));
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(condbr, foldedDest, branchArgs);
|
||||
return matchSuccess();
|
||||
return matchFailure();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
|
Loading…
Reference in New Issue