fold: sqrt(x * x * y) -> fabs(x) * sqrt(y)

If a square root call has an FP multiplication argument that can be reassociated,
then we can hoist a repeated factor out of the square root call and into a fabs().

In the simplest case, this:

   y = sqrt(x * x);

becomes this:

   y = fabs(x);

This patch relies on an earlier optimization in instcombine or reassociate to put the
multiplication tree into a canonical form, so we don't have to search over
every permutation of the multiplication tree.

Because there are no IR-level FastMathFlags for intrinsics (PR21290), we have to
use function-level attributes to do this optimization. This needs to be fixed
for both the intrinsics and in the backend.

Differential Revision: http://reviews.llvm.org/D5787

llvm-svn: 219944
This commit is contained in:
Sanjay Patel 2014-10-16 18:48:17 +00:00
parent d70f3c20b8
commit c699a6117b
3 changed files with 258 additions and 1 deletions

View File

@ -93,6 +93,7 @@ private:
Value *optimizePow(CallInst *CI, IRBuilder<> &B);
Value *optimizeExp2(CallInst *CI, IRBuilder<> &B);
Value *optimizeFabs(CallInst *CI, IRBuilder<> &B);
Value *optimizeSqrt(CallInst *CI, IRBuilder<> &B);
Value *optimizeSinCosPi(CallInst *CI, IRBuilder<> &B);
// Integer Library Call Optimizations

View File

@ -27,12 +27,14 @@
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Target/TargetLibraryInfo.h"
#include "llvm/Transforms/Utils/BuildLibCalls.h"
using namespace llvm;
using namespace PatternMatch;
static cl::opt<bool>
ColdErrorCalls("error-reporting-is-cold", cl::init(true), cl::Hidden,
@ -1254,6 +1256,85 @@ Value *LibCallSimplifier::optimizeFabs(CallInst *CI, IRBuilder<> &B) {
return Ret;
}
Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) {
Function *Callee = CI->getCalledFunction();
Value *Ret = nullptr;
if (UnsafeFPShrink && Callee->getName() == "sqrt" &&
TLI->has(LibFunc::sqrtf)) {
Ret = optimizeUnaryDoubleFP(CI, B, true);
}
// FIXME: For finer-grain optimization, we need intrinsics to have the same
// fast-math flag decorations that are applied to FP instructions. For now,
// we have to rely on the function-level unsafe-fp-math attribute to do this
// optimization because there's no other way to express that the sqrt can be
// reassociated.
Function *F = CI->getParent()->getParent();
if (F->hasFnAttribute("unsafe-fp-math")) {
// Check for unsafe-fp-math = true.
Attribute Attr = F->getFnAttribute("unsafe-fp-math");
if (Attr.getValueAsString() != "true")
return Ret;
}
Value *Op = CI->getArgOperand(0);
if (Instruction *I = dyn_cast<Instruction>(Op)) {
if (I->getOpcode() == Instruction::FMul && I->hasUnsafeAlgebra()) {
// We're looking for a repeated factor in a multiplication tree,
// so we can do this fold: sqrt(x * x) -> fabs(x);
// or this fold: sqrt(x * x * y) -> fabs(x) * sqrt(y).
Value *Op0 = I->getOperand(0);
Value *Op1 = I->getOperand(1);
Value *RepeatOp = nullptr;
Value *OtherOp = nullptr;
if (Op0 == Op1) {
// Simple match: the operands of the multiply are identical.
RepeatOp = Op0;
} else {
// Look for a more complicated pattern: one of the operands is itself
// a multiply, so search for a common factor in that multiply.
// Note: We don't bother looking any deeper than this first level or for
// variations of this pattern because instcombine's visitFMUL and/or the
// reassociation pass should give us this form.
Value *OtherMul0, *OtherMul1;
if (match(Op0, m_FMul(m_Value(OtherMul0), m_Value(OtherMul1)))) {
// Pattern: sqrt((x * y) * z)
if (OtherMul0 == OtherMul1) {
// Matched: sqrt((x * x) * z)
RepeatOp = OtherMul0;
OtherOp = Op1;
}
}
}
if (RepeatOp) {
// Fast math flags for any created instructions should match the sqrt
// and multiply.
// FIXME: We're not checking the sqrt because it doesn't have
// fast-math-flags (see earlier comment).
IRBuilder<true, ConstantFolder,
IRBuilderDefaultInserter<true> >::FastMathFlagGuard Guard(B);
B.SetFastMathFlags(I->getFastMathFlags());
// If we found a repeated factor, hoist it out of the square root and
// replace it with the fabs of that factor.
Module *M = Callee->getParent();
Type *ArgType = Op->getType();
Value *Fabs = Intrinsic::getDeclaration(M, Intrinsic::fabs, ArgType);
Value *FabsCall = B.CreateCall(Fabs, RepeatOp, "fabs");
if (OtherOp) {
// If we found a non-repeated factor, we still need to get its square
// root. We then multiply that by the value that was simplified out
// of the square root calculation.
Value *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, ArgType);
Value *SqrtCall = B.CreateCall(Sqrt, OtherOp, "sqrt");
return B.CreateFMul(FabsCall, SqrtCall);
}
return FabsCall;
}
}
}
return Ret;
}
static bool isTrigLibCall(CallInst *CI);
static void insertSinCosCall(IRBuilder<> &B, Function *OrigCallee, Value *Arg,
bool UseFloat, Value *&Sin, Value *&Cos,
@ -1919,6 +2000,8 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
return optimizeExp2(CI, Builder);
case Intrinsic::fabs:
return optimizeFabs(CI, Builder);
case Intrinsic::sqrt:
return optimizeSqrt(CI, Builder);
default:
return nullptr;
}
@ -1995,6 +2078,10 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
case LibFunc::fabs:
case LibFunc::fabsl:
return optimizeFabs(CI, Builder);
case LibFunc::sqrtf:
case LibFunc::sqrt:
case LibFunc::sqrtl:
return optimizeSqrt(CI, Builder);
case LibFunc::ffs:
case LibFunc::ffsl:
case LibFunc::ffsll:
@ -2055,7 +2142,6 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
case LibFunc::logb:
case LibFunc::sin:
case LibFunc::sinh:
case LibFunc::sqrt:
case LibFunc::tan:
case LibFunc::tanh:
if (UnsafeFPShrink && hasFloatVersion(FuncName))

View File

@ -530,3 +530,173 @@ define float @fact_div6(float %x) {
; CHECK: fact_div6
; CHECK: %t3 = fsub fast float %t1, %t2
}
; =========================================================================
;
; Test-cases for square root
;
; =========================================================================
; A squared factor fed into a square root intrinsic should be hoisted out
; as a fabs() value.
; We have to rely on a function-level attribute to enable this optimization
; because intrinsics don't currently have access to IR-level fast-math
; flags. If that changes, we can relax the requirement on all of these
; tests to just specify 'fast' on the sqrt.
attributes #0 = { "unsafe-fp-math" = "true" }
declare double @llvm.sqrt.f64(double)
define double @sqrt_intrinsic_arg_squared(double %x) #0 {
%mul = fmul fast double %x, %x
%sqrt = call double @llvm.sqrt.f64(double %mul)
ret double %sqrt
; CHECK-LABEL: sqrt_intrinsic_arg_squared(
; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
; CHECK-NEXT: ret double %fabs
}
; Check all 6 combinations of a 3-way multiplication tree where
; one factor is repeated.
define double @sqrt_intrinsic_three_args1(double %x, double %y) #0 {
%mul = fmul fast double %y, %x
%mul2 = fmul fast double %mul, %x
%sqrt = call double @llvm.sqrt.f64(double %mul2)
ret double %sqrt
; CHECK-LABEL: sqrt_intrinsic_three_args1(
; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
; CHECK-NEXT: ret double %1
}
define double @sqrt_intrinsic_three_args2(double %x, double %y) #0 {
%mul = fmul fast double %x, %y
%mul2 = fmul fast double %mul, %x
%sqrt = call double @llvm.sqrt.f64(double %mul2)
ret double %sqrt
; CHECK-LABEL: sqrt_intrinsic_three_args2(
; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
; CHECK-NEXT: ret double %1
}
define double @sqrt_intrinsic_three_args3(double %x, double %y) #0 {
%mul = fmul fast double %x, %x
%mul2 = fmul fast double %mul, %y
%sqrt = call double @llvm.sqrt.f64(double %mul2)
ret double %sqrt
; CHECK-LABEL: sqrt_intrinsic_three_args3(
; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
; CHECK-NEXT: ret double %1
}
define double @sqrt_intrinsic_three_args4(double %x, double %y) #0 {
%mul = fmul fast double %y, %x
%mul2 = fmul fast double %x, %mul
%sqrt = call double @llvm.sqrt.f64(double %mul2)
ret double %sqrt
; CHECK-LABEL: sqrt_intrinsic_three_args4(
; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
; CHECK-NEXT: ret double %1
}
define double @sqrt_intrinsic_three_args5(double %x, double %y) #0 {
%mul = fmul fast double %x, %y
%mul2 = fmul fast double %x, %mul
%sqrt = call double @llvm.sqrt.f64(double %mul2)
ret double %sqrt
; CHECK-LABEL: sqrt_intrinsic_three_args5(
; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
; CHECK-NEXT: ret double %1
}
define double @sqrt_intrinsic_three_args6(double %x, double %y) #0 {
%mul = fmul fast double %x, %x
%mul2 = fmul fast double %y, %mul
%sqrt = call double @llvm.sqrt.f64(double %mul2)
ret double %sqrt
; CHECK-LABEL: sqrt_intrinsic_three_args6(
; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
; CHECK-NEXT: ret double %1
}
define double @sqrt_intrinsic_arg_4th(double %x) #0 {
%mul = fmul fast double %x, %x
%mul2 = fmul fast double %mul, %mul
%sqrt = call double @llvm.sqrt.f64(double %mul2)
ret double %sqrt
; CHECK-LABEL: sqrt_intrinsic_arg_4th(
; CHECK-NEXT: %mul = fmul fast double %x, %x
; CHECK-NEXT: ret double %mul
}
define double @sqrt_intrinsic_arg_5th(double %x) #0 {
%mul = fmul fast double %x, %x
%mul2 = fmul fast double %mul, %x
%mul3 = fmul fast double %mul2, %mul
%sqrt = call double @llvm.sqrt.f64(double %mul3)
ret double %sqrt
; CHECK-LABEL: sqrt_intrinsic_arg_5th(
; CHECK-NEXT: %mul = fmul fast double %x, %x
; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %x)
; CHECK-NEXT: %1 = fmul fast double %mul, %sqrt1
; CHECK-NEXT: ret double %1
}
; Check that square root calls have the same behavior.
declare float @sqrtf(float)
declare double @sqrt(double)
declare fp128 @sqrtl(fp128)
define float @sqrt_call_squared_f32(float %x) #0 {
%mul = fmul fast float %x, %x
%sqrt = call float @sqrtf(float %mul)
ret float %sqrt
; CHECK-LABEL: sqrt_call_squared_f32(
; CHECK-NEXT: %fabs = call float @llvm.fabs.f32(float %x)
; CHECK-NEXT: ret float %fabs
}
define double @sqrt_call_squared_f64(double %x) #0 {
%mul = fmul fast double %x, %x
%sqrt = call double @sqrt(double %mul)
ret double %sqrt
; CHECK-LABEL: sqrt_call_squared_f64(
; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
; CHECK-NEXT: ret double %fabs
}
define fp128 @sqrt_call_squared_f128(fp128 %x) #0 {
%mul = fmul fast fp128 %x, %x
%sqrt = call fp128 @sqrtl(fp128 %mul)
ret fp128 %sqrt
; CHECK-LABEL: sqrt_call_squared_f128(
; CHECK-NEXT: %fabs = call fp128 @llvm.fabs.f128(fp128 %x)
; CHECK-NEXT: ret fp128 %fabs
}