[LibCallSimplifier] partly fix pow(x, 0.5) -> sqrt() transforms

As the first test shows, we could transform an llvm intrinsic which never sets errno 
into a libcall which could set errno (even though it's marked readnone?), so that's 
not ideal.

It's possible that we can also transform a libcall which could set errno into an
intrinsic, given the fast-math-flags constraint, but that's deferred until we
determine exactly which set of FMF is needed.

Differential Revision: https://reviews.llvm.org/D40150

llvm-svn: 318628
This commit is contained in:
Sanjay Patel 2017-11-19 16:13:14 +00:00
parent eb731b09f3
commit fbd3e66b9a
3 changed files with 54 additions and 36 deletions

View File

@ -131,6 +131,7 @@ private:
// Math Library Optimizations
Value *optimizeCos(CallInst *CI, IRBuilder<> &B);
Value *optimizePow(CallInst *CI, IRBuilder<> &B);
// Helper called from optimizePow: tries to replace pow(x, +/-0.5) with a
// square root. Returns nullptr when no replacement is made.
Value *replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B);
Value *optimizeExp2(CallInst *CI, IRBuilder<> &B);
Value *optimizeFMinFMax(CallInst *CI, IRBuilder<> &B);
Value *optimizeLog(CallInst *CI, IRBuilder<> &B);

View File

@ -1074,6 +1074,52 @@ static Value *getPow(Value *InnerChain[33], unsigned Exp, IRBuilder<> &B) {
return InnerChain[Exp];
}
/// Rewrite pow(x, 0.5) as sqrt(x) and pow(x, -0.5) as 1.0 / sqrt(x).
///
/// \param Pow  the pow call being simplified.
/// \param B    builder used to emit the replacement instructions.
/// \returns the replacement value, or nullptr when the call is left untouched:
///          the call is not 'fast', the exponent is not a ConstantFP equal to
///          +/-0.5, or neither a sqrt intrinsic nor a sqrt libcall can be used.
Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B) {
  // TODO: There is some subset of 'fast' under which these transforms should
  // be allowed.
  if (!Pow->isFast())
    return nullptr;

  // TODO: This should use m_APFloat to allow vector splats.
  auto *ExpC = dyn_cast<ConstantFP>(Pow->getArgOperand(1));
  if (!ExpC)
    return nullptr;

  // Only exponents of exactly +0.5 or -0.5 are handled here.
  const bool IsNegHalf = ExpC->isExactlyValue(-0.5);
  if (!IsNegHalf && !ExpC->isExactlyValue(0.5))
    return nullptr;

  // Every instruction emitted below carries the fast-math flags of the pow()
  // call; the guard restores the builder's previous flags on scope exit.
  IRBuilder<>::FastMathFlagGuard FMFGuard(B);
  B.setFastMathFlags(Pow->getFastMathFlags());

  Type *PowTy = Pow->getType();
  Value *Base = Pow->getArgOperand(0);
  Value *Root = nullptr;
  if (Pow->hasFnAttr(Attribute::ReadNone)) {
    // The call is readnone, so errno is known not to be set and the intrinsic
    // is a valid replacement:
    //   pow(x, 0.5)      --> llvm.sqrt(x)
    //   llvm.pow(x, 0.5) --> llvm.sqrt(x)
    auto *SqrtDecl =
        Intrinsic::getDeclaration(Pow->getModule(), Intrinsic::sqrt, PowTy);
    Root = B.CreateCall(SqrtDecl, Base);
  } else {
    // Errno could be set, so only a sqrt libcall is acceptable. Give up if
    // there is none for this type.
    // TODO: We also should check that the target can in fact lower the sqrt
    // libcall. There is currently no way to ask that question, so we ask
    // whether the target has a sqrt libcall, which is not exactly the same.
    if (!hasUnaryFloatFn(TLI, PowTy, LibFunc_sqrt, LibFunc_sqrtf,
                         LibFunc_sqrtl))
      return nullptr;
    Root = emitUnaryFloatFnCall(Base, TLI->getName(LibFunc_sqrt), B,
                                Pow->getCalledFunction()->getAttributes());
  }

  // pow(x, -0.5) is the reciprocal of the square root.
  return IsNegHalf ? B.CreateFDiv(ConstantFP::get(PowTy, 1.0), Root) : Root;
}
Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {
Function *Callee = CI->getCalledFunction();
Value *Ret = nullptr;
@ -1131,42 +1177,13 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {
if (Op2C->getValueAPF().isZero()) // pow(x, 0.0) -> 1.0
return ConstantFP::get(CI->getType(), 1.0);
if (Op2C->isExactlyValue(-0.5) &&
hasUnaryFloatFn(TLI, Op2->getType(), LibFunc_sqrt, LibFunc_sqrtf,
LibFunc_sqrtl)) {
// If -ffast-math:
// pow(x, -0.5) -> 1.0 / sqrt(x)
if (CI->isFast()) {
IRBuilder<>::FastMathFlagGuard Guard(B);
B.setFastMathFlags(CI->getFastMathFlags());
// TODO: If the pow call is an intrinsic, we should lower to the sqrt
// intrinsic, so we match errno semantics. We also should check that the
// target can in fact lower the sqrt intrinsic -- we currently have no way
// to ask this question other than asking whether the target has a sqrt
// libcall, which is a sufficient but not necessary condition.
Value *Sqrt = emitUnaryFloatFnCall(Op1, TLI->getName(LibFunc_sqrt), B,
Callee->getAttributes());
return B.CreateFDiv(ConstantFP::get(CI->getType(), 1.0), Sqrt, "sqrtrecip");
}
}
if (Value *Sqrt = replacePowWithSqrt(CI, B))
return Sqrt;
// FIXME: Correct the transforms and pull this into replacePowWithSqrt().
if (Op2C->isExactlyValue(0.5) &&
hasUnaryFloatFn(TLI, Op2->getType(), LibFunc_sqrt, LibFunc_sqrtf,
LibFunc_sqrtl)) {
// In -ffast-math, pow(x, 0.5) -> sqrt(x).
if (CI->isFast()) {
IRBuilder<>::FastMathFlagGuard Guard(B);
B.setFastMathFlags(CI->getFastMathFlags());
// TODO: As above, we should lower to the sqrt intrinsic if the pow is an
// intrinsic, to match errno semantics.
return emitUnaryFloatFnCall(Op1, TLI->getName(LibFunc_sqrt), B,
Callee->getAttributes());
}
// Expand pow(x, 0.5) to (x == -infinity ? +infinity : fabs(sqrt(x))).
// This is faster than calling pow, and still handles negative zero
// and negative infinity correctly.

View File

@ -2,8 +2,8 @@
define double @pow_intrinsic_half_fast(double %x) {
; CHECK-LABEL: @pow_intrinsic_half_fast(
; CHECK-NEXT: [[SQRT:%.*]] = call fast double @sqrt(double %x) #1
; CHECK-NEXT: ret double [[SQRT]]
; CHECK-NEXT: [[TMP1:%.*]] = call fast double @llvm.sqrt.f64(double %x)
; CHECK-NEXT: ret double [[TMP1]]
;
%pow = call fast double @llvm.pow.f64(double %x, double 5.000000e-01)
ret double %pow
@ -51,8 +51,8 @@ define double @pow_intrinsic_neghalf_approx(double %x) {
define float @pow_libcall_neghalf_fast(float %x) {
; CHECK-LABEL: @pow_libcall_neghalf_fast(
; CHECK-NEXT: [[SQRTF:%.*]] = call fast float @sqrtf(float %x)
; CHECK-NEXT: [[SQRTRECIP:%.*]] = fdiv fast float 1.000000e+00, [[SQRTF]]
; CHECK-NEXT: ret float [[SQRTRECIP]]
; CHECK-NEXT: [[TMP1:%.*]] = fdiv fast float 1.000000e+00, [[SQRTF]]
; CHECK-NEXT: ret float [[TMP1]]
;
%pow = call fast float @powf(float %x, float -5.0e-01)
ret float %pow