[InstSimplify] add constant folding for fdiv/frem

Also, add a helper function so we don't have to repeat this code for each binop.

llvm-svn: 299309
This commit is contained in:
Sanjay Patel 2017-04-01 19:05:11 +00:00
parent ee0f5cc41f
commit 8b5ad3f00e
2 changed files with 51 additions and 74 deletions

View File

@ -528,17 +528,26 @@ static Value *ThreadCmpOverPHI(CmpInst::Predicate Pred, Value *LHS, Value *RHS,
return CommonValue;
}
static Constant *foldOrCommuteConstant(Instruction::BinaryOps Opcode,
Value *&Op0, Value *&Op1,
const Query &Q) {
if (auto *CLHS = dyn_cast<Constant>(Op0)) {
if (auto *CRHS = dyn_cast<Constant>(Op1))
return ConstantFoldBinaryOpOperands(Opcode, CLHS, CRHS, Q.DL);
// Canonicalize the constant to the RHS if this is a commutative operation.
if (Instruction::isCommutative(Opcode))
std::swap(Op0, Op1);
}
return nullptr;
}
/// Given operands for an Add, see if we can fold the result.
/// If not, this returns null.
static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
const Query &Q, unsigned MaxRecurse) {
if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
if (Constant *CRHS = dyn_cast<Constant>(Op1))
return ConstantFoldBinaryOpOperands(Instruction::Add, CLHS, CRHS, Q.DL);
// Canonicalize the constant to the RHS.
std::swap(Op0, Op1);
}
if (Constant *C = foldOrCommuteConstant(Instruction::Add, Op0, Op1, Q))
return C;
// X + undef -> undef
if (match(Op1, m_Undef()))
@ -674,9 +683,8 @@ static Constant *computePointerDifference(const DataLayout &DL, Value *LHS,
/// If not, this returns null.
static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
const Query &Q, unsigned MaxRecurse) {
if (Constant *CLHS = dyn_cast<Constant>(Op0))
if (Constant *CRHS = dyn_cast<Constant>(Op1))
return ConstantFoldBinaryOpOperands(Instruction::Sub, CLHS, CRHS, Q.DL);
if (Constant *C = foldOrCommuteConstant(Instruction::Sub, Op0, Op1, Q))
return C;
// X - undef -> undef
// undef - X -> undef
@ -816,13 +824,8 @@ Value *llvm::SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
/// returns null.
static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF,
const Query &Q, unsigned MaxRecurse) {
if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
if (Constant *CRHS = dyn_cast<Constant>(Op1))
return ConstantFoldBinaryOpOperands(Instruction::FAdd, CLHS, CRHS, Q.DL);
// Canonicalize the constant to the RHS.
std::swap(Op0, Op1);
}
if (Constant *C = foldOrCommuteConstant(Instruction::FAdd, Op0, Op1, Q))
return C;
// fadd X, -0 ==> X
if (match(Op1, m_NegZero()))
@ -855,10 +858,8 @@ static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF,
/// returns null.
static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF,
const Query &Q, unsigned MaxRecurse) {
if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
if (Constant *CRHS = dyn_cast<Constant>(Op1))
return ConstantFoldBinaryOpOperands(Instruction::FSub, CLHS, CRHS, Q.DL);
}
if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q))
return C;
// fsub X, 0 ==> X
if (match(Op1, m_Zero()))
@ -889,13 +890,8 @@ static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF,
/// Given the operands for an FMul, see if we can fold the result
static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF,
const Query &Q, unsigned MaxRecurse) {
if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
if (Constant *CRHS = dyn_cast<Constant>(Op1))
return ConstantFoldBinaryOpOperands(Instruction::FMul, CLHS, CRHS, Q.DL);
// Canonicalize the constant to the RHS.
std::swap(Op0, Op1);
}
if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q))
return C;
// fmul X, 1.0 ==> X
if (match(Op1, m_FPOne()))
@ -912,13 +908,8 @@ static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF,
/// If not, this returns null.
static Value *SimplifyMulInst(Value *Op0, Value *Op1, const Query &Q,
unsigned MaxRecurse) {
if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
if (Constant *CRHS = dyn_cast<Constant>(Op1))
return ConstantFoldBinaryOpOperands(Instruction::Mul, CLHS, CRHS, Q.DL);
// Canonicalize the constant to the RHS.
std::swap(Op0, Op1);
}
if (Constant *C = foldOrCommuteConstant(Instruction::Mul, Op0, Op1, Q))
return C;
// X * undef -> 0
if (match(Op1, m_Undef()))
@ -1060,9 +1051,8 @@ static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv) {
/// If not, this returns null.
static Value *SimplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
const Query &Q, unsigned MaxRecurse) {
if (Constant *C0 = dyn_cast<Constant>(Op0))
if (Constant *C1 = dyn_cast<Constant>(Op1))
return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL);
if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q))
return C;
if (Value *V = simplifyDivRem(Op0, Op1, true))
return V;
@ -1162,6 +1152,9 @@ Value *llvm::SimplifyUDivInst(Value *Op0, Value *Op1, const DataLayout &DL,
static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF,
const Query &Q, unsigned) {
if (Constant *C = foldOrCommuteConstant(Instruction::FDiv, Op0, Op1, Q))
return C;
// undef / X -> undef (the undef could be a snan).
if (match(Op0, m_Undef()))
return Op0;
@ -1211,9 +1204,8 @@ Value *llvm::SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF,
/// If not, this returns null.
static Value *SimplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
const Query &Q, unsigned MaxRecurse) {
if (Constant *C0 = dyn_cast<Constant>(Op0))
if (Constant *C1 = dyn_cast<Constant>(Op1))
return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL);
if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q))
return C;
if (Value *V = simplifyDivRem(Op0, Op1, false))
return V;
@ -1287,7 +1279,10 @@ Value *llvm::SimplifyURemInst(Value *Op0, Value *Op1, const DataLayout &DL,
}
static Value *SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF,
const Query &, unsigned) {
const Query &Q, unsigned) {
if (Constant *C = foldOrCommuteConstant(Instruction::FRem, Op0, Op1, Q))
return C;
// undef % X -> undef (the undef could be a snan).
if (match(Op0, m_Undef()))
return Op0;
@ -1343,11 +1338,10 @@ static bool isUndefShift(Value *Amount) {
/// Given operands for an Shl, LShr or AShr, see if we can fold the result.
/// If not, this returns null.
static Value *SimplifyShift(unsigned Opcode, Value *Op0, Value *Op1,
const Query &Q, unsigned MaxRecurse) {
if (Constant *C0 = dyn_cast<Constant>(Op0))
if (Constant *C1 = dyn_cast<Constant>(Op1))
return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL);
static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0,
Value *Op1, const Query &Q, unsigned MaxRecurse) {
if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q))
return C;
// 0 shift by X -> 0
if (match(Op0, m_Zero()))
@ -1394,8 +1388,8 @@ static Value *SimplifyShift(unsigned Opcode, Value *Op0, Value *Op1,
/// \brief Given operands for an Shl, LShr or AShr, see if we can
/// fold the result. If not, this returns null.
static Value *SimplifyRightShift(unsigned Opcode, Value *Op0, Value *Op1,
bool isExact, const Query &Q,
static Value *SimplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0,
Value *Op1, bool isExact, const Query &Q,
unsigned MaxRecurse) {
if (Value *V = SimplifyShift(Opcode, Op0, Op1, Q, MaxRecurse))
return V;
@ -1644,13 +1638,8 @@ static Value *SimplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) {
/// If not, this returns null.
static Value *SimplifyAndInst(Value *Op0, Value *Op1, const Query &Q,
unsigned MaxRecurse) {
if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
if (Constant *CRHS = dyn_cast<Constant>(Op1))
return ConstantFoldBinaryOpOperands(Instruction::And, CLHS, CRHS, Q.DL);
// Canonicalize the constant to the RHS.
std::swap(Op0, Op1);
}
if (Constant *C = foldOrCommuteConstant(Instruction::And, Op0, Op1, Q))
return C;
// X & undef -> 0
if (match(Op1, m_Undef()))
@ -1846,13 +1835,8 @@ static Value *SimplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1) {
/// If not, this returns null.
static Value *SimplifyOrInst(Value *Op0, Value *Op1, const Query &Q,
unsigned MaxRecurse) {
if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
if (Constant *CRHS = dyn_cast<Constant>(Op1))
return ConstantFoldBinaryOpOperands(Instruction::Or, CLHS, CRHS, Q.DL);
// Canonicalize the constant to the RHS.
std::swap(Op0, Op1);
}
if (Constant *C = foldOrCommuteConstant(Instruction::Or, Op0, Op1, Q))
return C;
// X | undef -> -1
if (match(Op1, m_Undef()))
@ -1979,13 +1963,8 @@ Value *llvm::SimplifyOrInst(Value *Op0, Value *Op1, const DataLayout &DL,
/// If not, this returns null.
static Value *SimplifyXorInst(Value *Op0, Value *Op1, const Query &Q,
unsigned MaxRecurse) {
if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
if (Constant *CRHS = dyn_cast<Constant>(Op1))
return ConstantFoldBinaryOpOperands(Instruction::Xor, CLHS, CRHS, Q.DL);
// Canonicalize the constant to the RHS.
std::swap(Op0, Op1);
}
if (Constant *C = foldOrCommuteConstant(Instruction::Xor, Op0, Op1, Q))
return C;
// A ^ undef -> undef
if (match(Op1, m_Undef()))

View File

@ -3,8 +3,7 @@
define float @fdiv_constant_fold() {
; CHECK-LABEL: @fdiv_constant_fold(
; CHECK-NEXT: [[F:%.*]] = fdiv float 3.000000e+00, 2.000000e+00
; CHECK-NEXT: ret float [[F]]
; CHECK-NEXT: ret float 1.500000e+00
;
%f = fdiv float 3.0, 2.0
ret float %f
@ -12,8 +11,7 @@ define float @fdiv_constant_fold() {
define float @frem_constant_fold() {
; CHECK-LABEL: @frem_constant_fold(
; CHECK-NEXT: [[F:%.*]] = frem float 3.000000e+00, 2.000000e+00
; CHECK-NEXT: ret float [[F]]
; CHECK-NEXT: ret float 1.000000e+00
;
%f = frem float 3.0, 2.0
ret float %f