Adding an m_NonZero constant integer matcher.

This is useful for making matching cases where a non-zero value is required more readable, such as the results of a constant comparison that are expected to be equal.

PiperOrigin-RevId: 278932874
This commit is contained in:
Ben Vanik 2019-11-06 13:51:19 -08:00 committed by A. Unique TensorFlower
parent b5654d1311
commit 68bd355505
2 changed files with 30 additions and 23 deletions

View File

@ -111,16 +111,24 @@ struct constant_int_op_binder {
}
};
// The matcher that matches a given target constant scalar / vector splat /
// tensor splat integer value.
/// The matcher that matches a given target constant scalar / vector splat /
/// tensor splat integer value.
template <int64_t TargetValue> struct constant_int_value_matcher {
bool match(Operation *op) {
APInt value;
return constant_int_op_binder(&value).match(op) && TargetValue == value;
}
};
/// The matcher that matches anything except the given target constant scalar /
/// vector splat / tensor splat integer value.
template <int64_t TargetNotValue> struct constant_int_not_value_matcher {
bool match(Operation *op) {
APInt value;
return constant_int_op_binder(&value).match(op) && TargetNotValue != value;
}
};
/// The matcher that matches a certain kind of op.
template <typename OpClass> struct op_matcher {
bool match(Operation *op) { return isa<OpClass>(op); }
@ -172,6 +180,12 @@ inline detail::constant_int_value_matcher<0> m_Zero() {
return detail::constant_int_value_matcher<0>();
}
/// Matches a constant scalar / vector splat / tensor splat integer that is any
/// non-zero value.
inline detail::constant_int_not_value_matcher<0> m_NonZero() {
return detail::constant_int_not_value_matcher<0>();
}
} // end namespace mlir
#endif // MLIR_MATCHERS_H

View File

@ -1070,27 +1070,20 @@ struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
PatternMatchResult matchAndRewrite(CondBranchOp condbr,
PatternRewriter &rewriter) const override {
// Check that the condition is a constant.
if (!matchPattern(condbr.getCondition(), m_Op<ConstantOp>()))
return matchFailure();
Block *foldedDest;
SmallVector<Value *, 4> branchArgs;
// If the condition is known to evaluate to false we fold to a branch to the
// false destination. Otherwise, we fold to a branch to the true
// destination.
if (matchPattern(condbr.getCondition(), m_Zero())) {
foldedDest = condbr.getFalseDest();
branchArgs.assign(condbr.false_operand_begin(),
condbr.false_operand_end());
} else {
foldedDest = condbr.getTrueDest();
branchArgs.assign(condbr.true_operand_begin(), condbr.true_operand_end());
if (matchPattern(condbr.getCondition(), m_NonZero())) {
// True branch taken.
rewriter.replaceOpWithNewOp<BranchOp>(
condbr, condbr.getTrueDest(),
llvm::to_vector<4>(condbr.getTrueOperands()));
return matchSuccess();
} else if (matchPattern(condbr.getCondition(), m_Zero())) {
// False branch taken.
rewriter.replaceOpWithNewOp<BranchOp>(
condbr, condbr.getFalseDest(),
llvm::to_vector<4>(condbr.getFalseOperands()));
return matchSuccess();
}
rewriter.replaceOpWithNewOp<BranchOp>(condbr, foldedDest, branchArgs);
return matchSuccess();
return matchFailure();
}
};
} // end anonymous namespace.