diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index a301d9eef60e..0b9128a9f5a1 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -124,6 +124,50 @@ static Constant *getLogBase2(Type *Ty, Constant *C) { return ConstantVector::get(Elts); } +// TODO: This is a specific form of a much more general pattern. +// We could detect a select with any binop identity constant, or we +// could use SimplifyBinOp to see if either arm of the select reduces. +// But that needs to be done carefully and/or while removing potential +// reverse canonicalizations as in InstCombiner::foldSelectIntoOp(). +static Value *foldMulSelectToNegate(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + Value *Cond, *OtherOp; + + // mul (select Cond, 1, -1), OtherOp --> select Cond, OtherOp, -OtherOp + // mul OtherOp, (select Cond, 1, -1) --> select Cond, OtherOp, -OtherOp + if (match(&I, m_c_Mul(m_OneUse(m_Select(m_Value(Cond), m_One(), m_AllOnes())), + m_Value(OtherOp)))) + return Builder.CreateSelect(Cond, OtherOp, Builder.CreateNeg(OtherOp)); + + // mul (select Cond, -1, 1), OtherOp --> select Cond, -OtherOp, OtherOp + // mul OtherOp, (select Cond, -1, 1) --> select Cond, -OtherOp, OtherOp + if (match(&I, m_c_Mul(m_OneUse(m_Select(m_Value(Cond), m_AllOnes(), m_One())), + m_Value(OtherOp)))) + return Builder.CreateSelect(Cond, Builder.CreateNeg(OtherOp), OtherOp); + + // fmul (select Cond, 1.0, -1.0), OtherOp --> select Cond, OtherOp, -OtherOp + // fmul OtherOp, (select Cond, 1.0, -1.0) --> select Cond, OtherOp, -OtherOp + if (match(&I, m_c_FMul(m_OneUse(m_Select(m_Value(Cond), m_SpecificFP(1.0), + m_SpecificFP(-1.0))), + m_Value(OtherOp)))) { + IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); + Builder.setFastMathFlags(I.getFastMathFlags()); + return Builder.CreateSelect(Cond, OtherOp, Builder.CreateFNeg(OtherOp)); + } + + // fmul (select Cond, -1.0, 1.0), OtherOp --> select Cond, -OtherOp, OtherOp + // fmul OtherOp, (select Cond, -1.0, 1.0) --> select Cond, -OtherOp, OtherOp + if (match(&I, m_c_FMul(m_OneUse(m_Select(m_Value(Cond), m_SpecificFP(-1.0), + m_SpecificFP(1.0))), + m_Value(OtherOp)))) { + IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); + Builder.setFastMathFlags(I.getFastMathFlags()); + return Builder.CreateSelect(Cond, Builder.CreateFNeg(OtherOp), OtherOp); + } + + return nullptr; +} + Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (Value *V = SimplifyMulInst(I.getOperand(0), I.getOperand(1), SQ.getWithInstruction(&I))) @@ -213,24 +257,8 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (Instruction *FoldedMul = foldBinOpIntoSelectOrPhi(I)) return FoldedMul; - // TODO: This is a specific form of a much more general pattern. - // We could detect a select with any binop identity constant, or we - // could use SimplifyBinOp to see if either arm of the select reduces. - // But that needs to be done carefully and/or while removing potential - // reverse canonicalizations as in InstCombiner::foldSelectIntoOp(). - // mul (select Cond, 1, -1), Op1 --> select Cond, Op1, -Op1 - // mul (select Cond, -1, 1), Op1 --> select Cond, -Op1, Op1 - // mul Op0, (select Cond, 1, -1) --> select Cond, Op0, -Op0 - // mul Op0, (select Cond, -1, 1) --> select Cond, -Op0, Op0 - Value *Cond; - if (match(Op0, m_OneUse(m_Select(m_Value(Cond), m_One(), m_AllOnes())))) - return SelectInst::Create(Cond, Op1, Builder.CreateNeg(Op1)); - if (match(Op0, m_OneUse(m_Select(m_Value(Cond), m_AllOnes(), m_One())))) - return SelectInst::Create(Cond, Builder.CreateNeg(Op1), Op1); - if (match(Op1, m_OneUse(m_Select(m_Value(Cond), m_One(), m_AllOnes())))) - return SelectInst::Create(Cond, Op0, Builder.CreateNeg(Op0)); - if (match(Op1, m_OneUse(m_Select(m_Value(Cond), m_AllOnes(), m_One())))) - return SelectInst::Create(Cond, Builder.CreateNeg(Op0), Op0); + if (Value *FoldedMul = foldMulSelectToNegate(I, Builder)) + return replaceInstUsesWith(I, FoldedMul); // Simplify mul instructions with a constant RHS. if (isa(Op1)) { @@ -377,6 +405,9 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { if (Instruction *FoldedMul = foldBinOpIntoSelectOrPhi(I)) return FoldedMul; + if (Value *FoldedMul = foldMulSelectToNegate(I, Builder)) + return replaceInstUsesWith(I, FoldedMul); + // X * -1.0 --> -X Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (match(Op1, m_SpecificFP(-1.0))) diff --git a/llvm/test/Transforms/InstCombine/fmul.ll b/llvm/test/Transforms/InstCombine/fmul.ll index 1bcca95d0453..89c957b9d083 100644 --- a/llvm/test/Transforms/InstCombine/fmul.ll +++ b/llvm/test/Transforms/InstCombine/fmul.ll @@ -994,9 +994,9 @@ define double @fmul_negated_constant_expression(double %x) { define float @negate_if_true(float %x, i1 %cond) { ; CHECK-LABEL: @negate_if_true( -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND:%.*]], float -1.000000e+00, float 1.000000e+00 -; CHECK-NEXT: [[R:%.*]] = fmul float [[SEL]], [[X:%.*]] -; CHECK-NEXT: ret float [[R]] +; CHECK-NEXT: [[TMP1:%.*]] = fsub float -0.000000e+00, [[X:%.*]] +; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[COND:%.*]], float [[TMP1]], float [[X]] +; CHECK-NEXT: ret float [[TMP2]] ; %sel = select i1 %cond, float -1.0, float 1.0 %r = fmul float %sel, %x @@ -1005,9 +1005,9 @@ define float @negate_if_true(float %x, i1 %cond) { define float @negate_if_false(float %x, i1 %cond) { ; CHECK-LABEL: @negate_if_false( -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND:%.*]], float 1.000000e+00, float -1.000000e+00 -; CHECK-NEXT: [[R:%.*]] = fmul arcp float [[SEL]], [[X:%.*]] -; CHECK-NEXT: ret float [[R]] +; CHECK-NEXT: [[TMP1:%.*]] = fsub arcp float -0.000000e+00, [[X:%.*]] +; CHECK-NEXT: [[TMP2:%.*]] = select arcp i1 [[COND:%.*]], float [[X]], float [[TMP1]] +; CHECK-NEXT: ret float [[TMP2]] ; %sel = select i1 %cond, float 1.0, float -1.0 %r = fmul arcp float %sel, %x @@ -1017,9 +1017,9 @@ define float @negate_if_false(float %x, i1 %cond) { define <2 x double> @negate_if_true_commute(<2 x double> %px, i1 %cond) { ; CHECK-LABEL: @negate_if_true_commute( ; CHECK-NEXT: [[X:%.*]] = fdiv <2 x double> , [[PX:%.*]] -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND:%.*]], <2 x double> , <2 x double> -; CHECK-NEXT: [[R:%.*]] = fmul ninf <2 x double> [[X]], [[SEL]] -; CHECK-NEXT: ret <2 x double> [[R]] +; CHECK-NEXT: [[TMP1:%.*]] = fsub ninf <2 x double> , [[X]] +; CHECK-NEXT: [[TMP2:%.*]] = select ninf i1 [[COND:%.*]], <2 x double> [[TMP1]], <2 x double> [[X]] +; CHECK-NEXT: ret <2 x double> [[TMP2]] ; %x = fdiv <2 x double> , %px ; thwart complexity-based canonicalization %sel = select i1 %cond, <2 x double> , <2 x double> @@ -1030,9 +1030,9 @@ define <2 x double> @negate_if_true_commute(<2 x double> %px, i1 %cond) { define <2 x double> @negate_if_false_commute(<2 x double> %px, <2 x i1> %cond) { ; CHECK-LABEL: @negate_if_false_commute( ; CHECK-NEXT: [[X:%.*]] = fdiv <2 x double> , [[PX:%.*]] -; CHECK-NEXT: [[SEL:%.*]] = select <2 x i1> [[COND:%.*]], <2 x double> , <2 x double> -; CHECK-NEXT: [[R:%.*]] = fmul <2 x double> [[X]], [[SEL]] -; CHECK-NEXT: ret <2 x double> [[R]] +; CHECK-NEXT: [[TMP1:%.*]] = fsub <2 x double> , [[X]] +; CHECK-NEXT: [[TMP2:%.*]] = select <2 x i1> [[COND:%.*]], <2 x double> [[X]], <2 x double> [[TMP1]] +; CHECK-NEXT: ret <2 x double> [[TMP2]] ; %x = fdiv <2 x double> , %px ; thwart complexity-based canonicalization %sel = select <2 x i1> %cond, <2 x double> , <2 x double> @@ -1040,6 +1040,8 @@ define <2 x double> @negate_if_false_commute(<2 x double> %px, <2 x i1> %cond) { ret <2 x double> %r } +; Negative test + define float @negate_if_true_extra_use(float %x, i1 %cond) { ; CHECK-LABEL: @negate_if_true_extra_use( ; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND:%.*]], float -1.000000e+00, float 1.000000e+00 @@ -1053,6 +1055,8 @@ define float @negate_if_true_extra_use(float %x, i1 %cond) { ret float %r } +; Negative test + define <2 x double> @negate_if_true_wrong_constant(<2 x double> %px, i1 %cond) { ; CHECK-LABEL: @negate_if_true_wrong_constant( ; CHECK-NEXT: [[X:%.*]] = fdiv <2 x double> , [[PX:%.*]]