[LibCallSimplifier] allow splat vectors for pow(x, 0.5) -> sqrt() transforms

llvm-svn: 318629
This commit is contained in:
Sanjay Patel 2017-11-19 16:42:27 +00:00
parent fbd3e66b9a
commit 9771a96f6e
2 changed files with 10 additions and 10 deletions

View File

@ -1081,11 +1081,10 @@ Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B) {
if (!Pow->isFast())
return nullptr;
// TODO: This should use m_APFloat to allow vector splats.
ConstantFP *Op2C = dyn_cast<ConstantFP>(Pow->getArgOperand(1));
if (!Op2C)
const APFloat *Arg1C;
if (!match(Pow->getArgOperand(1), m_APFloat(Arg1C)))
return nullptr;
if (!Op2C->isExactlyValue(0.5) && !Op2C->isExactlyValue(-0.5))
if (!Arg1C->isExactlyValue(0.5) && !Arg1C->isExactlyValue(-0.5))
return nullptr;
// Fast-math flags from the pow() are propagated to all replacement ops.
@ -1114,7 +1113,7 @@ Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B) {
}
// If this is pow(x, -0.5), get the reciprocal.
if (Op2C->isExactlyValue(-0.5))
if (Arg1C->isExactlyValue(-0.5))
Sqrt = B.CreateFDiv(ConstantFP::get(Ty, 1.0), Sqrt);
return Sqrt;
@ -1170,6 +1169,9 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {
}
}
if (Value *Sqrt = replacePowWithSqrt(CI, B))
return Sqrt;
ConstantFP *Op2C = dyn_cast<ConstantFP>(Op2);
if (!Op2C)
return Ret;
@ -1177,9 +1179,6 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {
if (Op2C->getValueAPF().isZero()) // pow(x, 0.0) -> 1.0
return ConstantFP::get(CI->getType(), 1.0);
if (Value *Sqrt = replacePowWithSqrt(CI, B))
return Sqrt;
// FIXME: Correct the transforms and pull this into replacePowWithSqrt().
if (Op2C->isExactlyValue(0.5) &&
hasUnaryFloatFn(TLI, Op2->getType(), LibFunc_sqrt, LibFunc_sqrtf,

View File

@ -32,8 +32,9 @@ define double @pow_libcall_half_approx(double %x) {
define <2 x double> @pow_intrinsic_neghalf_fast(<2 x double> %x) {
; CHECK-LABEL: @pow_intrinsic_neghalf_fast(
; CHECK-NEXT: [[POW:%.*]] = call fast <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double -5.000000e-01, double -5.000000e-01>)
; CHECK-NEXT: ret <2 x double> [[POW]]
; CHECK-NEXT: [[TMP1:%.*]] = call fast <2 x double> @llvm.sqrt.v2f64(<2 x double> %x)
; CHECK-NEXT: [[TMP2:%.*]] = fdiv fast <2 x double> <double 1.000000e+00, double 1.000000e+00>, [[TMP1]]
; CHECK-NEXT: ret <2 x double> [[TMP2]]
;
%pow = call fast <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double -5.0e-01, double -5.0e-01>)
ret <2 x double> %pow