forked from OSchip/llvm-project
fold: sqrt(x * x * y) -> fabs(x) * sqrt(y)
If a square root call has an FP multiplication argument that can be reassociated, then we can hoist a repeated factor out of the square root call and into a fabs(). In the simplest case, this: y = sqrt(x * x); becomes this: y = fabs(x); This patch relies on an earlier optimization in instcombine or reassociate to put the multiplication tree into a canonical form, so we don't have to search over every permutation of the multiplication tree. Because there are no IR-level FastMathFlags for intrinsics (PR21290), we have to use function-level attributes to do this optimization. This needs to be fixed for both the intrinsics and in the backend. Differential Revision: http://reviews.llvm.org/D5787 llvm-svn: 219944
This commit is contained in:
parent
d70f3c20b8
commit
c699a6117b
|
@ -93,6 +93,7 @@ private:
|
|||
Value *optimizePow(CallInst *CI, IRBuilder<> &B);
|
||||
Value *optimizeExp2(CallInst *CI, IRBuilder<> &B);
|
||||
Value *optimizeFabs(CallInst *CI, IRBuilder<> &B);
|
||||
Value *optimizeSqrt(CallInst *CI, IRBuilder<> &B);
|
||||
Value *optimizeSinCosPi(CallInst *CI, IRBuilder<> &B);
|
||||
|
||||
// Integer Library Call Optimizations
|
||||
|
|
|
@ -27,12 +27,14 @@
|
|||
#include "llvm/IR/Intrinsics.h"
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/PatternMatch.h"
|
||||
#include "llvm/Support/Allocator.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Target/TargetLibraryInfo.h"
|
||||
#include "llvm/Transforms/Utils/BuildLibCalls.h"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace PatternMatch;
|
||||
|
||||
static cl::opt<bool>
|
||||
ColdErrorCalls("error-reporting-is-cold", cl::init(true), cl::Hidden,
|
||||
|
@ -1254,6 +1256,85 @@ Value *LibCallSimplifier::optimizeFabs(CallInst *CI, IRBuilder<> &B) {
|
|||
return Ret;
|
||||
}
|
||||
|
||||
Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) {
|
||||
Function *Callee = CI->getCalledFunction();
|
||||
|
||||
Value *Ret = nullptr;
|
||||
if (UnsafeFPShrink && Callee->getName() == "sqrt" &&
|
||||
TLI->has(LibFunc::sqrtf)) {
|
||||
Ret = optimizeUnaryDoubleFP(CI, B, true);
|
||||
}
|
||||
|
||||
// FIXME: For finer-grain optimization, we need intrinsics to have the same
|
||||
// fast-math flag decorations that are applied to FP instructions. For now,
|
||||
// we have to rely on the function-level unsafe-fp-math attribute to do this
|
||||
// optimization because there's no other way to express that the sqrt can be
|
||||
// reassociated.
|
||||
Function *F = CI->getParent()->getParent();
|
||||
if (F->hasFnAttribute("unsafe-fp-math")) {
|
||||
// Check for unsafe-fp-math = true.
|
||||
Attribute Attr = F->getFnAttribute("unsafe-fp-math");
|
||||
if (Attr.getValueAsString() != "true")
|
||||
return Ret;
|
||||
}
|
||||
Value *Op = CI->getArgOperand(0);
|
||||
if (Instruction *I = dyn_cast<Instruction>(Op)) {
|
||||
if (I->getOpcode() == Instruction::FMul && I->hasUnsafeAlgebra()) {
|
||||
// We're looking for a repeated factor in a multiplication tree,
|
||||
// so we can do this fold: sqrt(x * x) -> fabs(x);
|
||||
// or this fold: sqrt(x * x * y) -> fabs(x) * sqrt(y).
|
||||
Value *Op0 = I->getOperand(0);
|
||||
Value *Op1 = I->getOperand(1);
|
||||
Value *RepeatOp = nullptr;
|
||||
Value *OtherOp = nullptr;
|
||||
if (Op0 == Op1) {
|
||||
// Simple match: the operands of the multiply are identical.
|
||||
RepeatOp = Op0;
|
||||
} else {
|
||||
// Look for a more complicated pattern: one of the operands is itself
|
||||
// a multiply, so search for a common factor in that multiply.
|
||||
// Note: We don't bother looking any deeper than this first level or for
|
||||
// variations of this pattern because instcombine's visitFMUL and/or the
|
||||
// reassociation pass should give us this form.
|
||||
Value *OtherMul0, *OtherMul1;
|
||||
if (match(Op0, m_FMul(m_Value(OtherMul0), m_Value(OtherMul1)))) {
|
||||
// Pattern: sqrt((x * y) * z)
|
||||
if (OtherMul0 == OtherMul1) {
|
||||
// Matched: sqrt((x * x) * z)
|
||||
RepeatOp = OtherMul0;
|
||||
OtherOp = Op1;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (RepeatOp) {
|
||||
// Fast math flags for any created instructions should match the sqrt
|
||||
// and multiply.
|
||||
// FIXME: We're not checking the sqrt because it doesn't have
|
||||
// fast-math-flags (see earlier comment).
|
||||
IRBuilder<true, ConstantFolder,
|
||||
IRBuilderDefaultInserter<true> >::FastMathFlagGuard Guard(B);
|
||||
B.SetFastMathFlags(I->getFastMathFlags());
|
||||
// If we found a repeated factor, hoist it out of the square root and
|
||||
// replace it with the fabs of that factor.
|
||||
Module *M = Callee->getParent();
|
||||
Type *ArgType = Op->getType();
|
||||
Value *Fabs = Intrinsic::getDeclaration(M, Intrinsic::fabs, ArgType);
|
||||
Value *FabsCall = B.CreateCall(Fabs, RepeatOp, "fabs");
|
||||
if (OtherOp) {
|
||||
// If we found a non-repeated factor, we still need to get its square
|
||||
// root. We then multiply that by the value that was simplified out
|
||||
// of the square root calculation.
|
||||
Value *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, ArgType);
|
||||
Value *SqrtCall = B.CreateCall(Sqrt, OtherOp, "sqrt");
|
||||
return B.CreateFMul(FabsCall, SqrtCall);
|
||||
}
|
||||
return FabsCall;
|
||||
}
|
||||
}
|
||||
}
|
||||
return Ret;
|
||||
}
|
||||
|
||||
static bool isTrigLibCall(CallInst *CI);
|
||||
static void insertSinCosCall(IRBuilder<> &B, Function *OrigCallee, Value *Arg,
|
||||
bool UseFloat, Value *&Sin, Value *&Cos,
|
||||
|
@ -1919,6 +2000,8 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
|
|||
return optimizeExp2(CI, Builder);
|
||||
case Intrinsic::fabs:
|
||||
return optimizeFabs(CI, Builder);
|
||||
case Intrinsic::sqrt:
|
||||
return optimizeSqrt(CI, Builder);
|
||||
default:
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -1995,6 +2078,10 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
|
|||
case LibFunc::fabs:
|
||||
case LibFunc::fabsl:
|
||||
return optimizeFabs(CI, Builder);
|
||||
case LibFunc::sqrtf:
|
||||
case LibFunc::sqrt:
|
||||
case LibFunc::sqrtl:
|
||||
return optimizeSqrt(CI, Builder);
|
||||
case LibFunc::ffs:
|
||||
case LibFunc::ffsl:
|
||||
case LibFunc::ffsll:
|
||||
|
@ -2055,7 +2142,6 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
|
|||
case LibFunc::logb:
|
||||
case LibFunc::sin:
|
||||
case LibFunc::sinh:
|
||||
case LibFunc::sqrt:
|
||||
case LibFunc::tan:
|
||||
case LibFunc::tanh:
|
||||
if (UnsafeFPShrink && hasFloatVersion(FuncName))
|
||||
|
|
|
@ -530,3 +530,173 @@ define float @fact_div6(float %x) {
|
|||
; CHECK: fact_div6
|
||||
; CHECK: %t3 = fsub fast float %t1, %t2
|
||||
}
|
||||
|
||||
; =========================================================================
|
||||
;
|
||||
; Test-cases for square root
|
||||
;
|
||||
; =========================================================================
|
||||
|
||||
; A squared factor fed into a square root intrinsic should be hoisted out
|
||||
; as a fabs() value.
|
||||
; We have to rely on a function-level attribute to enable this optimization
|
||||
; because intrinsics don't currently have access to IR-level fast-math
|
||||
; flags. If that changes, we can relax the requirement on all of these
|
||||
; tests to just specify 'fast' on the sqrt.
|
||||
|
||||
attributes #0 = { "unsafe-fp-math" = "true" }
|
||||
|
||||
declare double @llvm.sqrt.f64(double)
|
||||
|
||||
define double @sqrt_intrinsic_arg_squared(double %x) #0 {
|
||||
%mul = fmul fast double %x, %x
|
||||
%sqrt = call double @llvm.sqrt.f64(double %mul)
|
||||
ret double %sqrt
|
||||
|
||||
; CHECK-LABEL: sqrt_intrinsic_arg_squared(
|
||||
; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
|
||||
; CHECK-NEXT: ret double %fabs
|
||||
}
|
||||
|
||||
; Check all 6 combinations of a 3-way multiplication tree where
|
||||
; one factor is repeated.
|
||||
|
||||
define double @sqrt_intrinsic_three_args1(double %x, double %y) #0 {
|
||||
%mul = fmul fast double %y, %x
|
||||
%mul2 = fmul fast double %mul, %x
|
||||
%sqrt = call double @llvm.sqrt.f64(double %mul2)
|
||||
ret double %sqrt
|
||||
|
||||
; CHECK-LABEL: sqrt_intrinsic_three_args1(
|
||||
; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
|
||||
; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
|
||||
; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
|
||||
; CHECK-NEXT: ret double %1
|
||||
}
|
||||
|
||||
define double @sqrt_intrinsic_three_args2(double %x, double %y) #0 {
|
||||
%mul = fmul fast double %x, %y
|
||||
%mul2 = fmul fast double %mul, %x
|
||||
%sqrt = call double @llvm.sqrt.f64(double %mul2)
|
||||
ret double %sqrt
|
||||
|
||||
; CHECK-LABEL: sqrt_intrinsic_three_args2(
|
||||
; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
|
||||
; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
|
||||
; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
|
||||
; CHECK-NEXT: ret double %1
|
||||
}
|
||||
|
||||
define double @sqrt_intrinsic_three_args3(double %x, double %y) #0 {
|
||||
%mul = fmul fast double %x, %x
|
||||
%mul2 = fmul fast double %mul, %y
|
||||
%sqrt = call double @llvm.sqrt.f64(double %mul2)
|
||||
ret double %sqrt
|
||||
|
||||
; CHECK-LABEL: sqrt_intrinsic_three_args3(
|
||||
; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
|
||||
; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
|
||||
; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
|
||||
; CHECK-NEXT: ret double %1
|
||||
}
|
||||
|
||||
define double @sqrt_intrinsic_three_args4(double %x, double %y) #0 {
|
||||
%mul = fmul fast double %y, %x
|
||||
%mul2 = fmul fast double %x, %mul
|
||||
%sqrt = call double @llvm.sqrt.f64(double %mul2)
|
||||
ret double %sqrt
|
||||
|
||||
; CHECK-LABEL: sqrt_intrinsic_three_args4(
|
||||
; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
|
||||
; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
|
||||
; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
|
||||
; CHECK-NEXT: ret double %1
|
||||
}
|
||||
|
||||
define double @sqrt_intrinsic_three_args5(double %x, double %y) #0 {
|
||||
%mul = fmul fast double %x, %y
|
||||
%mul2 = fmul fast double %x, %mul
|
||||
%sqrt = call double @llvm.sqrt.f64(double %mul2)
|
||||
ret double %sqrt
|
||||
|
||||
; CHECK-LABEL: sqrt_intrinsic_three_args5(
|
||||
; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
|
||||
; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
|
||||
; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
|
||||
; CHECK-NEXT: ret double %1
|
||||
}
|
||||
|
||||
define double @sqrt_intrinsic_three_args6(double %x, double %y) #0 {
|
||||
%mul = fmul fast double %x, %x
|
||||
%mul2 = fmul fast double %y, %mul
|
||||
%sqrt = call double @llvm.sqrt.f64(double %mul2)
|
||||
ret double %sqrt
|
||||
|
||||
; CHECK-LABEL: sqrt_intrinsic_three_args6(
|
||||
; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
|
||||
; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
|
||||
; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
|
||||
; CHECK-NEXT: ret double %1
|
||||
}
|
||||
|
||||
define double @sqrt_intrinsic_arg_4th(double %x) #0 {
|
||||
%mul = fmul fast double %x, %x
|
||||
%mul2 = fmul fast double %mul, %mul
|
||||
%sqrt = call double @llvm.sqrt.f64(double %mul2)
|
||||
ret double %sqrt
|
||||
|
||||
; CHECK-LABEL: sqrt_intrinsic_arg_4th(
|
||||
; CHECK-NEXT: %mul = fmul fast double %x, %x
|
||||
; CHECK-NEXT: ret double %mul
|
||||
}
|
||||
|
||||
define double @sqrt_intrinsic_arg_5th(double %x) #0 {
|
||||
%mul = fmul fast double %x, %x
|
||||
%mul2 = fmul fast double %mul, %x
|
||||
%mul3 = fmul fast double %mul2, %mul
|
||||
%sqrt = call double @llvm.sqrt.f64(double %mul3)
|
||||
ret double %sqrt
|
||||
|
||||
; CHECK-LABEL: sqrt_intrinsic_arg_5th(
|
||||
; CHECK-NEXT: %mul = fmul fast double %x, %x
|
||||
; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %x)
|
||||
; CHECK-NEXT: %1 = fmul fast double %mul, %sqrt1
|
||||
; CHECK-NEXT: ret double %1
|
||||
}
|
||||
|
||||
; Check that square root calls have the same behavior.
|
||||
|
||||
declare float @sqrtf(float)
|
||||
declare double @sqrt(double)
|
||||
declare fp128 @sqrtl(fp128)
|
||||
|
||||
define float @sqrt_call_squared_f32(float %x) #0 {
|
||||
%mul = fmul fast float %x, %x
|
||||
%sqrt = call float @sqrtf(float %mul)
|
||||
ret float %sqrt
|
||||
|
||||
; CHECK-LABEL: sqrt_call_squared_f32(
|
||||
; CHECK-NEXT: %fabs = call float @llvm.fabs.f32(float %x)
|
||||
; CHECK-NEXT: ret float %fabs
|
||||
}
|
||||
|
||||
define double @sqrt_call_squared_f64(double %x) #0 {
|
||||
%mul = fmul fast double %x, %x
|
||||
%sqrt = call double @sqrt(double %mul)
|
||||
ret double %sqrt
|
||||
|
||||
; CHECK-LABEL: sqrt_call_squared_f64(
|
||||
; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
|
||||
; CHECK-NEXT: ret double %fabs
|
||||
}
|
||||
|
||||
define fp128 @sqrt_call_squared_f128(fp128 %x) #0 {
|
||||
%mul = fmul fast fp128 %x, %x
|
||||
%sqrt = call fp128 @sqrtl(fp128 %mul)
|
||||
ret fp128 %sqrt
|
||||
|
||||
; CHECK-LABEL: sqrt_call_squared_f128(
|
||||
; CHECK-NEXT: %fabs = call fp128 @llvm.fabs.f128(fp128 %x)
|
||||
; CHECK-NEXT: ret fp128 %fabs
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue