forked from OSchip/llvm-project
[InstCombine] Limit FMul constant folding for fma simplifications.
As @reames pointed out post-commit, rL371518 adds additional rounding in some cases when doing constant folding of the multiplication. This breaks a guarantee llvm.fma makes and must be avoided. This patch reapplies rL371518, but splits off the simplifications not requiring rounding from SimplifyFMulInst as SimplifyFMAFMul. Reviewers: spatel, lebedev.ri, reames, scanon Reviewed By: reames Differential Revision: https://reviews.llvm.org/D67434 llvm-svn: 372899
This commit is contained in:
parent
24337db616
commit
f3ab99dcf8
|
@ -142,6 +142,13 @@ Value *SimplifyFSubInst(Value *LHS, Value *RHS, FastMathFlags FMF,
|
|||
Value *SimplifyFMulInst(Value *LHS, Value *RHS, FastMathFlags FMF,
|
||||
const SimplifyQuery &Q);
|
||||
|
||||
/// Given operands for the multiplication of a FMA, fold the result or return
|
||||
/// null. In contrast to SimplifyFMulInst, this function will not perform
|
||||
/// simplifications whose unrounded results differ when rounded to the argument
|
||||
/// type.
|
||||
Value *SimplifyFMAFMul(Value *LHS, Value *RHS, FastMathFlags FMF,
|
||||
const SimplifyQuery &Q);
|
||||
|
||||
/// Given operands for a Mul, fold the result or return null.
|
||||
Value *SimplifyMulInst(Value *LHS, Value *RHS, const SimplifyQuery &Q);
|
||||
|
||||
|
|
|
@ -4576,15 +4576,8 @@ static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF,
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
/// Given the operands for an FMul, see if we can fold the result
|
||||
static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF,
|
||||
const SimplifyQuery &Q, unsigned MaxRecurse) {
|
||||
if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q))
|
||||
return C;
|
||||
|
||||
if (Constant *C = simplifyFPBinop(Op0, Op1))
|
||||
return C;
|
||||
|
||||
static Value *SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF,
|
||||
const SimplifyQuery &Q, unsigned MaxRecurse) {
|
||||
// fmul X, 1.0 ==> X
|
||||
if (match(Op1, m_FPOne()))
|
||||
return Op0;
|
||||
|
@ -4605,6 +4598,19 @@ static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF,
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
/// Given the operands for an FMul, see if we can fold the result
|
||||
static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF,
|
||||
const SimplifyQuery &Q, unsigned MaxRecurse) {
|
||||
if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q))
|
||||
return C;
|
||||
|
||||
if (Constant *C = simplifyFPBinop(Op0, Op1))
|
||||
return C;
|
||||
|
||||
// Now apply simplifications that do not require rounding.
|
||||
return SimplifyFMAFMul(Op0, Op1, FMF, Q, MaxRecurse);
|
||||
}
|
||||
|
||||
Value *llvm::SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF,
|
||||
const SimplifyQuery &Q) {
|
||||
return ::SimplifyFAddInst(Op0, Op1, FMF, Q, RecursionLimit);
|
||||
|
@ -4621,6 +4627,11 @@ Value *llvm::SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF,
|
|||
return ::SimplifyFMulInst(Op0, Op1, FMF, Q, RecursionLimit);
|
||||
}
|
||||
|
||||
Value *llvm::SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF,
|
||||
const SimplifyQuery &Q) {
|
||||
return ::SimplifyFMAFMul(Op0, Op1, FMF, Q, RecursionLimit);
|
||||
}
|
||||
|
||||
static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF,
|
||||
const SimplifyQuery &Q, unsigned) {
|
||||
if (Constant *C = foldOrCommuteConstant(Instruction::FDiv, Op0, Op1, Q))
|
||||
|
|
|
@ -2234,6 +2234,15 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
|
|||
return replaceInstUsesWith(*II, Add);
|
||||
}
|
||||
|
||||
// Try to simplify the underlying FMul.
|
||||
if (Value *V = SimplifyFMulInst(II->getArgOperand(0), II->getArgOperand(1),
|
||||
II->getFastMathFlags(),
|
||||
SQ.getWithInstruction(II))) {
|
||||
auto *FAdd = BinaryOperator::CreateFAdd(V, II->getArgOperand(2));
|
||||
FAdd->copyFastMathFlags(II);
|
||||
return FAdd;
|
||||
}
|
||||
|
||||
LLVM_FALLTHROUGH;
|
||||
}
|
||||
case Intrinsic::fma: {
|
||||
|
@ -2258,9 +2267,12 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
|
|||
return II;
|
||||
}
|
||||
|
||||
// fma x, 1, z -> fadd x, z
|
||||
if (match(Src1, m_FPOne())) {
|
||||
auto *FAdd = BinaryOperator::CreateFAdd(Src0, II->getArgOperand(2));
|
||||
// Try to simplify the underlying FMul. We can only apply simplifications
|
||||
// that do not require rounding.
|
||||
if (Value *V = SimplifyFMAFMul(II->getArgOperand(0), II->getArgOperand(1),
|
||||
II->getFastMathFlags(),
|
||||
SQ.getWithInstruction(II))) {
|
||||
auto *FAdd = BinaryOperator::CreateFAdd(V, II->getArgOperand(2));
|
||||
FAdd->copyFastMathFlags(II);
|
||||
return FAdd;
|
||||
}
|
||||
|
|
|
@ -372,8 +372,7 @@ define float @fmuladd_x_1_z_fast(float %x, float %z) {
|
|||
define <2 x double> @fmuladd_a_0_b(<2 x double> %a, <2 x double> %b) {
|
||||
; CHECK-LABEL: @fmuladd_a_0_b(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[A:%.*]], <2 x double> zeroinitializer, <2 x double> [[B:%.*]])
|
||||
; CHECK-NEXT: ret <2 x double> [[RES]]
|
||||
; CHECK-NEXT: ret <2 x double> [[B:%.*]]
|
||||
;
|
||||
entry:
|
||||
%res = call nnan nsz <2 x double> @llvm.fmuladd.v2f64(<2 x double> %a, <2 x double> zeroinitializer, <2 x double> %b)
|
||||
|
@ -383,8 +382,7 @@ entry:
|
|||
define <2 x double> @fmuladd_0_a_b(<2 x double> %a, <2 x double> %b) {
|
||||
; CHECK-LABEL: @fmuladd_0_a_b(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[A:%.*]], <2 x double> zeroinitializer, <2 x double> [[B:%.*]])
|
||||
; CHECK-NEXT: ret <2 x double> [[RES]]
|
||||
; CHECK-NEXT: ret <2 x double> [[B:%.*]]
|
||||
;
|
||||
entry:
|
||||
%res = call nnan nsz <2 x double> @llvm.fmuladd.v2f64(<2 x double> zeroinitializer, <2 x double> %a, <2 x double> %b)
|
||||
|
@ -407,8 +405,7 @@ declare <2 x double> @llvm.fmuladd.v2f64(<2 x double>, <2 x double>, <2 x double
|
|||
define <2 x double> @fma_a_0_b(<2 x double> %a, <2 x double> %b) {
|
||||
; CHECK-LABEL: @fma_a_0_b(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> [[A:%.*]], <2 x double> zeroinitializer, <2 x double> [[B:%.*]])
|
||||
; CHECK-NEXT: ret <2 x double> [[RES]]
|
||||
; CHECK-NEXT: ret <2 x double> [[B:%.*]]
|
||||
;
|
||||
entry:
|
||||
%res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> %a, <2 x double> zeroinitializer, <2 x double> %b)
|
||||
|
@ -418,8 +415,7 @@ entry:
|
|||
define <2 x double> @fma_0_a_b(<2 x double> %a, <2 x double> %b) {
|
||||
; CHECK-LABEL: @fma_0_a_b(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> [[A:%.*]], <2 x double> zeroinitializer, <2 x double> [[B:%.*]])
|
||||
; CHECK-NEXT: ret <2 x double> [[RES]]
|
||||
; CHECK-NEXT: ret <2 x double> [[B:%.*]]
|
||||
;
|
||||
entry:
|
||||
%res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> zeroinitializer, <2 x double> %a, <2 x double> %b)
|
||||
|
@ -440,8 +436,7 @@ entry:
|
|||
define <2 x double> @fma_sqrt(<2 x double> %a, <2 x double> %b) {
|
||||
; CHECK-LABEL: @fma_sqrt(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: [[SQRT:%.*]] = call fast <2 x double> @llvm.sqrt.v2f64(<2 x double> [[A:%.*]])
|
||||
; CHECK-NEXT: [[RES:%.*]] = call fast <2 x double> @llvm.fma.v2f64(<2 x double> [[SQRT]], <2 x double> [[SQRT]], <2 x double> [[B:%.*]])
|
||||
; CHECK-NEXT: [[RES:%.*]] = fadd fast <2 x double> [[A:%.*]], [[B:%.*]]
|
||||
; CHECK-NEXT: ret <2 x double> [[RES]]
|
||||
;
|
||||
entry:
|
||||
|
@ -450,6 +445,71 @@ entry:
|
|||
ret <2 x double> %res
|
||||
}
|
||||
|
||||
; We do not fold constant multiplies in FMAs, as they could require rounding, unless either constant is 0.0 or 1.0.
|
||||
define <2 x double> @fma_const_fmul(<2 x double> %b) {
|
||||
; CHECK-LABEL: @fma_const_fmul(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> <double 0x4131233302898702, double 0x40C387800000D6C0>, <2 x double> <double 1.291820e-08, double 9.123000e-06>, <2 x double> [[B:%.*]])
|
||||
; CHECK-NEXT: ret <2 x double> [[RES]]
|
||||
;
|
||||
entry:
|
||||
%res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> <double 1123123.0099110012314, double 9999.0000001>, <2 x double> <double 0.0000000129182, double 0.000009123>, <2 x double> %b)
|
||||
ret <2 x double> %res
|
||||
}
|
||||
|
||||
define <2 x double> @fma_const_fmul_zero(<2 x double> %b) {
|
||||
; CHECK-LABEL: @fma_const_fmul_zero(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> zeroinitializer, <2 x double> <double 0x4131233302898702, double 0x40C387800000D6C0>, <2 x double> [[B:%.*]])
|
||||
; CHECK-NEXT: ret <2 x double> [[RES]]
|
||||
;
|
||||
entry:
|
||||
%res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> <double 0.0, double 0.0>, <2 x double> <double 1123123.0099110012314, double 9999.0000001>, <2 x double> %b)
|
||||
ret <2 x double> %res
|
||||
}
|
||||
|
||||
define <2 x double> @fma_const_fmul_zero2(<2 x double> %b) {
|
||||
; CHECK-LABEL: @fma_const_fmul_zero2(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: ret <2 x double> [[B:%.*]]
|
||||
;
|
||||
entry:
|
||||
%res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> <double 1123123.0099110012314, double 9999.0000001>, <2 x double> <double 0.0, double 0.0>, <2 x double> %b)
|
||||
ret <2 x double> %res
|
||||
}
|
||||
|
||||
define <2 x double> @fma_const_fmul_one(<2 x double> %b) {
|
||||
; CHECK-LABEL: @fma_const_fmul_one(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> <double 1.000000e+00, double 1.000000e+00>, <2 x double> <double 0x4131233302898702, double 0x40C387800000D6C0>, <2 x double> [[B:%.*]])
|
||||
; CHECK-NEXT: ret <2 x double> [[RES]]
|
||||
;
|
||||
entry:
|
||||
%res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> <double 1.0, double 1.0>, <2 x double> <double 1123123.0099110012314, double 9999.0000001>, <2 x double> %b)
|
||||
ret <2 x double> %res
|
||||
}
|
||||
|
||||
define <2 x double> @fma_const_fmul_one2(<2 x double> %b) {
|
||||
; CHECK-LABEL: @fma_const_fmul_one2(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: [[RES:%.*]] = fadd nnan nsz <2 x double> [[B:%.*]], <double 0x4131233302898702, double 0x40C387800000D6C0>
|
||||
; CHECK-NEXT: ret <2 x double> [[RES]]
|
||||
;
|
||||
entry:
|
||||
%res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> <double 1123123.0099110012314, double 9999.0000001>, <2 x double> <double 1.0, double 1.0>, <2 x double> %b)
|
||||
ret <2 x double> %res
|
||||
}
|
||||
|
||||
define <2 x double> @fmuladd_const_fmul(<2 x double> %b) {
|
||||
; CHECK-LABEL: @fmuladd_const_fmul(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: [[RES:%.*]] = fadd nnan nsz <2 x double> [[B:%.*]], <double 0x3F8DB6C076AD949B, double 0x3FB75A405B6E6D69>
|
||||
; CHECK-NEXT: ret <2 x double> [[RES]]
|
||||
;
|
||||
entry:
|
||||
%res = call nnan nsz <2 x double> @llvm.fmuladd.v2f64(<2 x double> <double 1123123.0099110012314, double 9999.0000001>, <2 x double> <double 0.0000000129182, double 0.000009123>, <2 x double> %b)
|
||||
ret <2 x double> %res
|
||||
}
|
||||
|
||||
declare <2 x double> @llvm.fma.v2f64(<2 x double>, <2 x double>, <2 x double>)
|
||||
declare <2 x double> @llvm.sqrt.v2f64(<2 x double>)
|
||||
|
|
Loading…
Reference in New Issue