diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index 6987faf6bca1..6294d195cea4 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -528,17 +528,26 @@ static Value *ThreadCmpOverPHI(CmpInst::Predicate Pred, Value *LHS, Value *RHS, return CommonValue; } +static Constant *foldOrCommuteConstant(Instruction::BinaryOps Opcode, + Value *&Op0, Value *&Op1, + const Query &Q) { + if (auto *CLHS = dyn_cast(Op0)) { + if (auto *CRHS = dyn_cast(Op1)) + return ConstantFoldBinaryOpOperands(Opcode, CLHS, CRHS, Q.DL); + + // Canonicalize the constant to the RHS if this is a commutative operation. + if (Instruction::isCommutative(Opcode)) + std::swap(Op0, Op1); + } + return nullptr; +} + /// Given operands for an Add, see if we can fold the result. /// If not, this returns null. static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, const Query &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast(Op0)) { - if (Constant *CRHS = dyn_cast(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::Add, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } + if (Constant *C = foldOrCommuteConstant(Instruction::Add, Op0, Op1, Q)) + return C; // X + undef -> undef if (match(Op1, m_Undef())) @@ -674,9 +683,8 @@ static Constant *computePointerDifference(const DataLayout &DL, Value *LHS, /// If not, this returns null. static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, const Query &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast(Op0)) - if (Constant *CRHS = dyn_cast(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::Sub, CLHS, CRHS, Q.DL); + if (Constant *C = foldOrCommuteConstant(Instruction::Sub, Op0, Op1, Q)) + return C; // X - undef -> undef // undef - X -> undef @@ -816,13 +824,8 @@ Value *llvm::SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, /// returns null. static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, const Query &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast(Op0)) { - if (Constant *CRHS = dyn_cast(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::FAdd, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } + if (Constant *C = foldOrCommuteConstant(Instruction::FAdd, Op0, Op1, Q)) + return C; // fadd X, -0 ==> X if (match(Op1, m_NegZero())) @@ -855,10 +858,8 @@ static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, /// returns null. static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, const Query &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast(Op0)) { - if (Constant *CRHS = dyn_cast(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::FSub, CLHS, CRHS, Q.DL); - } + if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q)) + return C; // fsub X, 0 ==> X if (match(Op1, m_Zero())) @@ -889,13 +890,8 @@ static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, /// Given the operands for an FMul, see if we can fold the result static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, const Query &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast(Op0)) { - if (Constant *CRHS = dyn_cast(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::FMul, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } + if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q)) + return C; // fmul X, 1.0 ==> X if (match(Op1, m_FPOne())) @@ -912,13 +908,8 @@ static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, /// If not, this returns null. static Value *SimplifyMulInst(Value *Op0, Value *Op1, const Query &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast(Op0)) { - if (Constant *CRHS = dyn_cast(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::Mul, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } + if (Constant *C = foldOrCommuteConstant(Instruction::Mul, Op0, Op1, Q)) + return C; // X * undef -> 0 if (match(Op1, m_Undef())) @@ -1060,9 +1051,8 @@ static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv) { /// If not, this returns null. static Value *SimplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, const Query &Q, unsigned MaxRecurse) { - if (Constant *C0 = dyn_cast(Op0)) - if (Constant *C1 = dyn_cast(Op1)) - return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL); + if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) + return C; if (Value *V = simplifyDivRem(Op0, Op1, true)) return V; @@ -1162,6 +1152,9 @@ Value *llvm::SimplifyUDivInst(Value *Op0, Value *Op1, const DataLayout &DL, static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, const Query &Q, unsigned) { + if (Constant *C = foldOrCommuteConstant(Instruction::FDiv, Op0, Op1, Q)) + return C; + // undef / X -> undef (the undef could be a snan). if (match(Op0, m_Undef())) return Op0; @@ -1211,9 +1204,8 @@ Value *llvm::SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, /// If not, this returns null. static Value *SimplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, const Query &Q, unsigned MaxRecurse) { - if (Constant *C0 = dyn_cast(Op0)) - if (Constant *C1 = dyn_cast(Op1)) - return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL); + if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) + return C; if (Value *V = simplifyDivRem(Op0, Op1, false)) return V; @@ -1287,7 +1279,10 @@ Value *llvm::SimplifyURemInst(Value *Op0, Value *Op1, const DataLayout &DL, } static Value *SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const Query &, unsigned) { + const Query &Q, unsigned) { + if (Constant *C = foldOrCommuteConstant(Instruction::FRem, Op0, Op1, Q)) + return C; + // undef % X -> undef (the undef could be a snan). if (match(Op0, m_Undef())) return Op0; @@ -1343,11 +1338,10 @@ static bool isUndefShift(Value *Amount) { /// Given operands for an Shl, LShr or AShr, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyShift(unsigned Opcode, Value *Op0, Value *Op1, - const Query &Q, unsigned MaxRecurse) { - if (Constant *C0 = dyn_cast(Op0)) - if (Constant *C1 = dyn_cast(Op1)) - return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL); +static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0, + Value *Op1, const Query &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) + return C; // 0 shift by X -> 0 if (match(Op0, m_Zero())) @@ -1394,8 +1388,8 @@ static Value *SimplifyShift(unsigned Opcode, Value *Op0, Value *Op1, /// \brief Given operands for an Shl, LShr or AShr, see if we can /// fold the result. If not, this returns null. -static Value *SimplifyRightShift(unsigned Opcode, Value *Op0, Value *Op1, - bool isExact, const Query &Q, +static Value *SimplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0, + Value *Op1, bool isExact, const Query &Q, unsigned MaxRecurse) { if (Value *V = SimplifyShift(Opcode, Op0, Op1, Q, MaxRecurse)) return V; @@ -1644,13 +1638,8 @@ static Value *SimplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) { /// If not, this returns null. static Value *SimplifyAndInst(Value *Op0, Value *Op1, const Query &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast(Op0)) { - if (Constant *CRHS = dyn_cast(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::And, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } + if (Constant *C = foldOrCommuteConstant(Instruction::And, Op0, Op1, Q)) + return C; // X & undef -> 0 if (match(Op1, m_Undef())) @@ -1846,13 +1835,8 @@ static Value *SimplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1) { /// If not, this returns null. static Value *SimplifyOrInst(Value *Op0, Value *Op1, const Query &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast(Op0)) { - if (Constant *CRHS = dyn_cast(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::Or, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } + if (Constant *C = foldOrCommuteConstant(Instruction::Or, Op0, Op1, Q)) + return C; // X | undef -> -1 if (match(Op1, m_Undef())) @@ -1979,13 +1963,8 @@ Value *llvm::SimplifyOrInst(Value *Op0, Value *Op1, const DataLayout &DL, /// If not, this returns null. static Value *SimplifyXorInst(Value *Op0, Value *Op1, const Query &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast(Op0)) { - if (Constant *CRHS = dyn_cast(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::Xor, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } + if (Constant *C = foldOrCommuteConstant(Instruction::Xor, Op0, Op1, Q)) + return C; // A ^ undef -> undef if (match(Op1, m_Undef())) diff --git a/llvm/test/Transforms/InstSimplify/fdiv.ll b/llvm/test/Transforms/InstSimplify/fdiv.ll index 3499ae530935..6643afd81471 100644 --- a/llvm/test/Transforms/InstSimplify/fdiv.ll +++ b/llvm/test/Transforms/InstSimplify/fdiv.ll @@ -3,8 +3,7 @@ define float @fdiv_constant_fold() { ; CHECK-LABEL: @fdiv_constant_fold( -; CHECK-NEXT: [[F:%.*]] = fdiv float 3.000000e+00, 2.000000e+00 -; CHECK-NEXT: ret float [[F]] +; CHECK-NEXT: ret float 1.500000e+00 ; %f = fdiv float 3.0, 2.0 ret float %f @@ -12,8 +11,7 @@ define float @fdiv_constant_fold() { define float @frem_constant_fold() { ; CHECK-LABEL: @frem_constant_fold( -; CHECK-NEXT: [[F:%.*]] = frem float 3.000000e+00, 2.000000e+00 -; CHECK-NEXT: ret float [[F]] +; CHECK-NEXT: ret float 1.000000e+00 ; %f = frem float 3.0, 2.0 ret float %f