[SLC] Expand simplification of pow() for vector types

Also consider vector constants when simplifying `pow()`.

Differential revision: https://reviews.llvm.org/D50035

llvm-svn: 339578
This commit is contained in:
Evandro Menezes 2018-08-13 16:12:37 +00:00
parent 2c6cbc8bb2
commit 5ecd6c1a46
4 changed files with 96 additions and 98 deletions

View File

@ -1211,6 +1211,10 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) {
Value *Shrunk = nullptr;
bool Ignored;
// Bail out if simplifying libcalls to pow() is disabled.
if (!hasUnaryFloatFn(TLI, Ty, LibFunc_pow, LibFunc_powf, LibFunc_powl))
return nullptr;
// Propagate the math semantics from the call to any created instructions.
IRBuilder<>::FastMathFlagGuard Guard(B);
B.setFastMathFlags(Pow->getFastMathFlags());
@ -1252,9 +1256,6 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) {
Function *CalleeFn = BaseFn->getCalledFunction();
if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) &&
(LibFn == LibFunc_exp || LibFn == LibFunc_exp2) && TLI->has(LibFn)) {
IRBuilder<>::FastMathFlagGuard Guard(B);
B.setFastMathFlags(Pow->getFastMathFlags());
Value *FMul = B.CreateFMul(BaseFn->getArgOperand(0), Expo, "mul");
return emitUnaryFloatFnCall(FMul, CalleeFn->getName(), B,
CalleeFn->getAttributes());
@ -1263,31 +1264,28 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) {
// Evaluate special cases related to the exponent.
if (Value *Sqrt = replacePowWithSqrt(Pow, B))
return Sqrt;
ConstantFP *ExpoC = dyn_cast<ConstantFP>(Expo);
if (!ExpoC)
return Shrunk;
// pow(x, -1.0) -> 1.0 / x
if (ExpoC->isExactlyValue(-1.0))
if (match(Expo, m_SpecificFP(-1.0)))
return B.CreateFDiv(ConstantFP::get(Ty, 1.0), Base, "reciprocal");
// pow(x, 0.0) -> 1.0
if (ExpoC->getValueAPF().isZero())
return ConstantFP::get(Ty, 1.0);
if (match(Expo, m_SpecificFP(0.0)))
return ConstantFP::get(Ty, 1.0);
// pow(x, 1.0) -> x
if (ExpoC->isExactlyValue(1.0))
if (match(Expo, m_FPOne()))
return Base;
// pow(x, 2.0) -> x * x
if (ExpoC->isExactlyValue(2.0))
if (match(Expo, m_SpecificFP(2.0)))
return B.CreateFMul(Base, Base, "square");
if (Value *Sqrt = replacePowWithSqrt(Pow, B))
return Sqrt;
// FIXME: Correct the transforms and pull this into replacePowWithSqrt().
if (ExpoC->isExactlyValue(0.5) &&
ConstantFP *ExpoC = dyn_cast<ConstantFP>(Expo);
if (ExpoC && ExpoC->isExactlyValue(0.5) &&
hasUnaryFloatFn(TLI, Ty, LibFunc_sqrt, LibFunc_sqrtf, LibFunc_sqrtl)) {
// Expand pow(x, 0.5) to (x == -infinity ? +infinity : fabs(sqrt(x))).
// This is faster than calling pow(), and still handles -0.0 and
@ -1307,30 +1305,29 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) {
return Sqrt;
}
// pow(x, n) -> x * x * x * ....
if (Pow->isFast()) {
APFloat ExpoA = abs(ExpoC->getValueAPF());
// We limit to a max of 7 fmul(s). Thus the maximum exponent is 32.
// This transformation applies to integer exponents only.
if (!ExpoA.isInteger() ||
ExpoA.compare
(APFloat(ExpoA.getSemantics(), 32.0)) == APFloat::cmpGreaterThan)
return Shrunk;
// pow(x, n) -> x * x * x * ...
const APFloat *ExpoF;
if (Pow->isFast() && match(Expo, m_APFloat(ExpoF))) {
// We limit to a max of 7 multiplications, thus the maximum exponent is 32.
APFloat LimF(ExpoF->getSemantics(), 33.0),
ExpoA(abs(*ExpoF));
if (ExpoA.isInteger() && ExpoA.compare(LimF) == APFloat::cmpLessThan) {
// We will memoize intermediate products of the Addition Chain.
Value *InnerChain[33] = {nullptr};
InnerChain[1] = Base;
InnerChain[2] = B.CreateFMul(Base, Base, "square");
// We will memoize intermediate products of the Addition Chain.
Value *InnerChain[33] = {nullptr};
InnerChain[1] = Base;
InnerChain[2] = B.CreateFMul(Base, Base, "square");
// We cannot readily convert a non-double type (like float) to a double.
// So we first convert it to something which could be converted to double.
ExpoA.convert(APFloat::IEEEdouble(), APFloat::rmTowardZero, &Ignored);
Value *FMul = getPow(InnerChain, ExpoA.convertToDouble(), B);
// We cannot readily convert a non-double type (like float) to a double.
// So we first convert it to something which could be converted to double.
ExpoA.convert(APFloat::IEEEdouble(), APFloat::rmTowardZero, &Ignored);
Value *FMul = getPow(InnerChain, ExpoA.convertToDouble(), B);
// If the exponent is negative, then get the reciprocal.
if (ExpoF->isNegative())
FMul = B.CreateFDiv(ConstantFP::get(Ty, 1.0), FMul, "reciprocal");
// If the exponent is negative, then get the reciprocal.
if (ExpoC->isNegative())
FMul = B.CreateFDiv(ConstantFP::get(Ty, 1.0), FMul, "reciprocal");
return FMul;
return FMul;
}
}
return Shrunk;

View File

@ -95,7 +95,7 @@ define <2 x float> @test_simplify5v(<2 x float> %x) {
; CHECK-LABEL: @test_simplify5v(
%retval = call <2 x float> @llvm.pow.v2f32(<2 x float> %x, <2 x float> <float 0.0, float 0.0>)
ret <2 x float> %retval
; CHECK-NEXT: %retval = call <2 x float> @llvm.pow.v2f32(<2 x float> %x, <2 x float> zeroinitializer)
; CHECK-NEXT: ret <2 x float> <float 1.000000e+00, float 1.000000e+00>
}
define double @test_simplify6(double %x) {
@ -109,7 +109,7 @@ define <2 x double> @test_simplify6v(<2 x double> %x) {
; CHECK-LABEL: @test_simplify6v(
%retval = call <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double 0.0, double 0.0>)
ret <2 x double> %retval
; CHECK-NEXT: %retval = call <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> zeroinitializer)
; CHECK-NEXT: ret <2 x double> <double 1.000000e+00, double 1.000000e+00>
}
; Check pow(x, 0.5) -> fabs(sqrt(x)), where x != -infinity.
@ -165,7 +165,7 @@ define <2 x float> @test_simplify11v(<2 x float> %x) {
; CHECK-LABEL: @test_simplify11v(
%retval = call <2 x float> @llvm.pow.v2f32(<2 x float> %x, <2 x float> <float 1.0, float 1.0>)
ret <2 x float> %retval
; CHECK-NEXT: %retval = call <2 x float> @llvm.pow.v2f32(<2 x float> %x, <2 x float> <float 1.000000e+00, float 1.000000e+00>)
; CHECK-NEXT: ret <2 x float> %x
}
define double @test_simplify12(double %x) {
@ -179,7 +179,7 @@ define <2 x double> @test_simplify12v(<2 x double> %x) {
; CHECK-LABEL: @test_simplify12v(
%retval = call <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double 1.0, double 1.0>)
ret <2 x double> %retval
; CHECK-NEXT: %retval = call <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double 1.000000e+00, double 1.000000e+00>)
; CHECK-NEXT: ret <2 x double> %x
}
; Check pow(x, 2.0) -> x*x.
@ -195,7 +195,7 @@ define float @pow2_strict(float %x) {
define <2 x float> @pow2_strictv(<2 x float> %x) {
; CHECK-LABEL: @pow2_strictv(
; CHECK-NEXT: [[POW2:%.*]] = call <2 x float> @llvm.pow.v2f32(<2 x float> %x, <2 x float> <float 2.000000e+00, float 2.000000e+00>)
; CHECK-NEXT: [[POW2:%.*]] = fmul <2 x float> %x, %x
; CHECK-NEXT: ret <2 x float> [[POW2]]
;
%r = call <2 x float> @llvm.pow.v2f32(<2 x float> %x, <2 x float> <float 2.0, float 2.0>)
@ -212,7 +212,7 @@ define double @pow2_double_strict(double %x) {
}
define <2 x double> @pow2_double_strictv(<2 x double> %x) {
; CHECK-LABEL: @pow2_double_strictv(
; CHECK-NEXT: [[POW2:%.*]] = call <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double 2.000000e+00, double 2.000000e+00>)
; CHECK-NEXT: [[POW2:%.*]] = fmul <2 x double> %x, %x
; CHECK-NEXT: ret <2 x double> [[POW2]]
;
%r = call <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double 2.0, double 2.0>)
@ -243,7 +243,7 @@ define float @pow_neg1_strict(float %x) {
define <2 x float> @pow_neg1_strictv(<2 x float> %x) {
; CHECK-LABEL: @pow_neg1_strictv(
; CHECK-NEXT: [[POWRECIP:%.*]] = call <2 x float> @llvm.pow.v2f32(<2 x float> %x, <2 x float> <float -1.000000e+00, float -1.000000e+00>)
; CHECK-NEXT: [[POWRECIP:%.*]] = fdiv <2 x float> <float 1.000000e+00, float 1.000000e+00>, %x
; CHECK-NEXT: ret <2 x float> [[POWRECIP]]
;
%r = call <2 x float> @llvm.pow.v2f32(<2 x float> %x, <2 x float> <float -1.0, float -1.0>)
@ -261,7 +261,7 @@ define double @pow_neg1_double_fast(double %x) {
define <2 x double> @pow_neg1_double_fastv(<2 x double> %x) {
; CHECK-LABEL: @pow_neg1_double_fastv(
; CHECK-NEXT: [[POWRECIP:%.*]] = call fast <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double -1.000000e+00, double -1.000000e+00>)
; CHECK-NEXT: [[POWRECIP:%.*]] = fdiv fast <2 x double> <double 1.000000e+00, double 1.000000e+00>, %x
; CHECK-NEXT: ret <2 x double> [[POWRECIP]]
;
%r = call fast <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double -1.0, double -1.0>)

View File

@ -48,4 +48,3 @@ define float @test_simplify_unavailable3(float %f, float %g) {
%fr = fptrunc double %call to float
ret float %fr
}

View File

@ -3,17 +3,8 @@
declare double @llvm.pow.f64(double, double)
declare float @llvm.pow.f32(float, float)
; pow(x, 4.0f)
define float @test_simplify_4f(float %x) {
; CHECK-LABEL: @test_simplify_4f(
; CHECK-NEXT: [[TMP1:%.*]] = fmul fast float [[X:%.*]], [[X]]
; CHECK-NEXT: [[TMP2:%.*]] = fmul fast float [[TMP1]], [[TMP1]]
; CHECK-NEXT: ret float [[TMP2]]
;
%1 = call fast float @llvm.pow.f32(float %x, float 4.000000e+00)
ret float %1
}
declare <2 x double> @llvm.pow.v2f64(<2 x double>, <2 x double>)
declare <2 x float> @llvm.pow.v2f32(<2 x float>, <2 x float>)
; pow(x, 3.0)
define double @test_simplify_3(double %x) {
@ -26,6 +17,17 @@ define double @test_simplify_3(double %x) {
ret double %1
}
; powf(x, 4.0)
define float @test_simplify_4f(float %x) {
; CHECK-LABEL: @test_simplify_4f(
; CHECK-NEXT: [[TMP1:%.*]] = fmul fast float [[X:%.*]], [[X]]
; CHECK-NEXT: [[TMP2:%.*]] = fmul fast float [[TMP1]], [[TMP1]]
; CHECK-NEXT: ret float [[TMP2]]
;
%1 = call fast float @llvm.pow.f32(float %x, float 4.000000e+00)
ret float %1
}
; pow(x, 4.0)
define double @test_simplify_4(double %x) {
; CHECK-LABEL: @test_simplify_4(
@ -37,48 +39,48 @@ define double @test_simplify_4(double %x) {
ret double %1
}
; pow(x, 15.0)
define double @test_simplify_15(double %x) {
; powf(x, <15.0, 15.0>)
define <2 x float> @test_simplify_15(<2 x float> %x) {
; CHECK-LABEL: @test_simplify_15(
; CHECK-NEXT: [[TMP1:%.*]] = fmul fast double [[X:%.*]], [[X]]
; CHECK-NEXT: [[TMP2:%.*]] = fmul fast double [[TMP1]], [[X]]
; CHECK-NEXT: [[TMP3:%.*]] = fmul fast double [[TMP2]], [[TMP2]]
; CHECK-NEXT: [[TMP4:%.*]] = fmul fast double [[TMP3]], [[TMP3]]
; CHECK-NEXT: [[TMP5:%.*]] = fmul fast double [[TMP2]], [[TMP4]]
; CHECK-NEXT: ret double [[TMP5]]
; CHECK-NEXT: [[TMP1:%.*]] = fmul fast <2 x float> [[X:%.*]], [[X]]
; CHECK-NEXT: [[TMP2:%.*]] = fmul fast <2 x float> [[TMP1]], [[X]]
; CHECK-NEXT: [[TMP3:%.*]] = fmul fast <2 x float> [[TMP2]], [[TMP2]]
; CHECK-NEXT: [[TMP4:%.*]] = fmul fast <2 x float> [[TMP3]], [[TMP3]]
; CHECK-NEXT: [[TMP5:%.*]] = fmul fast <2 x float> [[TMP2]], [[TMP4]]
; CHECK-NEXT: ret <2 x float> [[TMP5]]
;
%1 = call fast double @llvm.pow.f64(double %x, double 1.500000e+01)
ret double %1
%1 = call fast <2 x float> @llvm.pow.v2f32(<2 x float> %x, <2 x float> <float 1.500000e+01, float 1.500000e+01>)
ret <2 x float> %1
}
; pow(x, -7.0)
define double @test_simplify_neg_7(double %x) {
define <2 x double> @test_simplify_neg_7(<2 x double> %x) {
; CHECK-LABEL: @test_simplify_neg_7(
; CHECK-NEXT: [[TMP1:%.*]] = fmul fast double [[X:%.*]], [[X]]
; CHECK-NEXT: [[TMP2:%.*]] = fmul fast double [[TMP1]], [[TMP1]]
; CHECK-NEXT: [[TMP3:%.*]] = fmul fast double [[TMP2]], [[X]]
; CHECK-NEXT: [[TMP4:%.*]] = fmul fast double [[TMP1]], [[TMP3]]
; CHECK-NEXT: [[TMP5:%.*]] = fdiv fast double 1.000000e+00, [[TMP4]]
; CHECK-NEXT: ret double [[TMP5]]
; CHECK-NEXT: [[TMP1:%.*]] = fmul fast <2 x double> [[X:%.*]], [[X]]
; CHECK-NEXT: [[TMP2:%.*]] = fmul fast <2 x double> [[TMP1]], [[TMP1]]
; CHECK-NEXT: [[TMP3:%.*]] = fmul fast <2 x double> [[TMP2]], [[X]]
; CHECK-NEXT: [[TMP4:%.*]] = fmul fast <2 x double> [[TMP1]], [[TMP3]]
; CHECK-NEXT: [[TMP5:%.*]] = fdiv fast <2 x double> <double 1.000000e+00, double 1.000000e+00>, [[TMP4]]
; CHECK-NEXT: ret <2 x double> [[TMP5]]
;
%1 = call fast double @llvm.pow.f64(double %x, double -7.000000e+00)
ret double %1
%1 = call fast <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double -7.000000e+00, double -7.000000e+00>)
ret <2 x double> %1
}
; pow(x, -19.0)
define double @test_simplify_neg_19(double %x) {
; powf(x, -19.0)
define float @test_simplify_neg_19(float %x) {
; CHECK-LABEL: @test_simplify_neg_19(
; CHECK-NEXT: [[TMP1:%.*]] = fmul fast double [[X:%.*]], [[X]]
; CHECK-NEXT: [[TMP2:%.*]] = fmul fast double [[TMP1]], [[TMP1]]
; CHECK-NEXT: [[TMP3:%.*]] = fmul fast double [[TMP2]], [[TMP2]]
; CHECK-NEXT: [[TMP4:%.*]] = fmul fast double [[TMP3]], [[TMP3]]
; CHECK-NEXT: [[TMP5:%.*]] = fmul fast double [[TMP1]], [[TMP4]]
; CHECK-NEXT: [[TMP6:%.*]] = fmul fast double [[TMP5]], [[X]]
; CHECK-NEXT: [[TMP7:%.*]] = fdiv fast double 1.000000e+00, [[TMP6]]
; CHECK-NEXT: ret double [[TMP7]]
; CHECK-NEXT: [[TMP1:%.*]] = fmul fast float [[X:%.*]], [[X]]
; CHECK-NEXT: [[TMP2:%.*]] = fmul fast float [[TMP1]], [[TMP1]]
; CHECK-NEXT: [[TMP3:%.*]] = fmul fast float [[TMP2]], [[TMP2]]
; CHECK-NEXT: [[TMP4:%.*]] = fmul fast float [[TMP3]], [[TMP3]]
; CHECK-NEXT: [[TMP5:%.*]] = fmul fast float [[TMP1]], [[TMP4]]
; CHECK-NEXT: [[TMP6:%.*]] = fmul fast float [[TMP5]], [[X]]
; CHECK-NEXT: [[TMP7:%.*]] = fdiv fast float 1.000000e+00, [[TMP6]]
; CHECK-NEXT: ret float [[TMP7]]
;
%1 = call fast double @llvm.pow.f64(double %x, double -1.900000e+01)
ret double %1
%1 = call fast float @llvm.pow.f32(float %x, float -1.900000e+01)
ret float %1
}
; pow(x, 11.23)
@ -91,18 +93,18 @@ define double @test_simplify_11_23(double %x) {
ret double %1
}
; pow(x, 32.0)
define double @test_simplify_32(double %x) {
; powf(x, 32.0)
define float @test_simplify_32(float %x) {
; CHECK-LABEL: @test_simplify_32(
; CHECK-NEXT: [[TMP1:%.*]] = fmul fast double [[X:%.*]], [[X]]
; CHECK-NEXT: [[TMP2:%.*]] = fmul fast double [[TMP1]], [[TMP1]]
; CHECK-NEXT: [[TMP3:%.*]] = fmul fast double [[TMP2]], [[TMP2]]
; CHECK-NEXT: [[TMP4:%.*]] = fmul fast double [[TMP3]], [[TMP3]]
; CHECK-NEXT: [[TMP5:%.*]] = fmul fast double [[TMP4]], [[TMP4]]
; CHECK-NEXT: ret double [[TMP5]]
; CHECK-NEXT: [[TMP1:%.*]] = fmul fast float [[X:%.*]], [[X]]
; CHECK-NEXT: [[TMP2:%.*]] = fmul fast float [[TMP1]], [[TMP1]]
; CHECK-NEXT: [[TMP3:%.*]] = fmul fast float [[TMP2]], [[TMP2]]
; CHECK-NEXT: [[TMP4:%.*]] = fmul fast float [[TMP3]], [[TMP3]]
; CHECK-NEXT: [[TMP5:%.*]] = fmul fast float [[TMP4]], [[TMP4]]
; CHECK-NEXT: ret float [[TMP5]]
;
%1 = call fast double @llvm.pow.f64(double %x, double 3.200000e+01)
ret double %1
%1 = call fast float @llvm.pow.f32(float %x, float 3.200000e+01)
ret float %1
}
; pow(x, 33.0)