From ddd0cb5ecf224cc9d5b1b502d7ca42d8d735c0c0 Mon Sep 17 00:00:00 2001 From: Karthik Bhat Date: Sat, 3 May 2014 09:59:54 +0000 Subject: [PATCH] Vectorize intrinsic math function calls in SLPVectorizer. This patch adds support to recognize and vectorize intrinsic math functions in SLPVectorizer. Review: http://reviews.llvm.org/D3560 and http://reviews.llvm.org/D3559 llvm-svn: 207901 --- .../llvm/Transforms/Utils/VectorUtils.h | 124 +++++++++++++++++ .../Transforms/Vectorize/LoopVectorize.cpp | 123 ----------------- .../Transforms/Vectorize/SLPVectorizer.cpp | 42 +++--- .../test/Transforms/SLPVectorizer/X86/call.ll | 128 ++++++++++++++++++ 4 files changed, 274 insertions(+), 143 deletions(-) create mode 100644 llvm/test/Transforms/SLPVectorizer/X86/call.ll diff --git a/llvm/include/llvm/Transforms/Utils/VectorUtils.h b/llvm/include/llvm/Transforms/Utils/VectorUtils.h index 2b1a9050ac75..d8738c80249c 100644 --- a/llvm/include/llvm/Transforms/Utils/VectorUtils.h +++ b/llvm/include/llvm/Transforms/Utils/VectorUtils.h @@ -15,6 +15,7 @@ #define LLVM_TRANSFORMS_UTILS_VECTORUTILS_H #include "llvm/IR/Intrinsics.h" +#include "llvm/Target/TargetLibraryInfo.h" namespace llvm { @@ -51,6 +52,129 @@ static inline bool isTriviallyVectorizable(Intrinsic::ID ID) { } } + +static Intrinsic::ID checkUnaryFloatSignature(const CallInst &I, + Intrinsic::ID ValidIntrinsicID) { + if (I.getNumArgOperands() != 1 || + !I.getArgOperand(0)->getType()->isFloatingPointTy() || + I.getType() != I.getArgOperand(0)->getType() || + !I.onlyReadsMemory()) + return Intrinsic::not_intrinsic; + + return ValidIntrinsicID; +} + +static Intrinsic::ID checkBinaryFloatSignature(const CallInst &I, + Intrinsic::ID ValidIntrinsicID) { + if (I.getNumArgOperands() != 2 || + !I.getArgOperand(0)->getType()->isFloatingPointTy() || + !I.getArgOperand(1)->getType()->isFloatingPointTy() || + I.getType() != I.getArgOperand(0)->getType() || + I.getType() != I.getArgOperand(1)->getType() || + !I.onlyReadsMemory()) + return Intrinsic::not_intrinsic; + + return ValidIntrinsicID; +} + +static Intrinsic::ID +getIntrinsicIDForCall(CallInst *CI, const TargetLibraryInfo *TLI) { + // If we have an intrinsic call, check if it is trivially vectorizable. + if (IntrinsicInst *II = dyn_cast(CI)) { + Intrinsic::ID ID = II->getIntrinsicID(); + if (isTriviallyVectorizable(ID) || ID == Intrinsic::lifetime_start || + ID == Intrinsic::lifetime_end) + return ID; + else + return Intrinsic::not_intrinsic; + } + + if (!TLI) + return Intrinsic::not_intrinsic; + + LibFunc::Func Func; + Function *F = CI->getCalledFunction(); + // We're going to make assumptions on the semantics of the functions, check + // that the target knows that it's available in this environment and it does + // not have local linkage. + if (!F || F->hasLocalLinkage() || !TLI->getLibFunc(F->getName(), Func)) + return Intrinsic::not_intrinsic; + + // Otherwise check if we have a call to a function that can be turned into a + // vector intrinsic. + switch (Func) { + default: + break; + case LibFunc::sin: + case LibFunc::sinf: + case LibFunc::sinl: + return checkUnaryFloatSignature(*CI, Intrinsic::sin); + case LibFunc::cos: + case LibFunc::cosf: + case LibFunc::cosl: + return checkUnaryFloatSignature(*CI, Intrinsic::cos); + case LibFunc::exp: + case LibFunc::expf: + case LibFunc::expl: + return checkUnaryFloatSignature(*CI, Intrinsic::exp); + case LibFunc::exp2: + case LibFunc::exp2f: + case LibFunc::exp2l: + return checkUnaryFloatSignature(*CI, Intrinsic::exp2); + case LibFunc::log: + case LibFunc::logf: + case LibFunc::logl: + return checkUnaryFloatSignature(*CI, Intrinsic::log); + case LibFunc::log10: + case LibFunc::log10f: + case LibFunc::log10l: + return checkUnaryFloatSignature(*CI, Intrinsic::log10); + case LibFunc::log2: + case LibFunc::log2f: + case LibFunc::log2l: + return checkUnaryFloatSignature(*CI, Intrinsic::log2); + case LibFunc::fabs: + case LibFunc::fabsf: + case LibFunc::fabsl: + return checkUnaryFloatSignature(*CI, Intrinsic::fabs); + case LibFunc::copysign: + case LibFunc::copysignf: + case LibFunc::copysignl: + return checkBinaryFloatSignature(*CI, Intrinsic::copysign); + case LibFunc::floor: + case LibFunc::floorf: + case LibFunc::floorl: + return checkUnaryFloatSignature(*CI, Intrinsic::floor); + case LibFunc::ceil: + case LibFunc::ceilf: + case LibFunc::ceill: + return checkUnaryFloatSignature(*CI, Intrinsic::ceil); + case LibFunc::trunc: + case LibFunc::truncf: + case LibFunc::truncl: + return checkUnaryFloatSignature(*CI, Intrinsic::trunc); + case LibFunc::rint: + case LibFunc::rintf: + case LibFunc::rintl: + return checkUnaryFloatSignature(*CI, Intrinsic::rint); + case LibFunc::nearbyint: + case LibFunc::nearbyintf: + case LibFunc::nearbyintl: + return checkUnaryFloatSignature(*CI, Intrinsic::nearbyint); + case LibFunc::round: + case LibFunc::roundf: + case LibFunc::roundl: + return checkUnaryFloatSignature(*CI, Intrinsic::round); + case LibFunc::pow: + case LibFunc::powf: + case LibFunc::powl: + return checkBinaryFloatSignature(*CI, Intrinsic::pow); + } + + return Intrinsic::not_intrinsic; +} + + } // llvm namespace #endif diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index ab0cf9a5142b..e89237051b62 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -85,7 +85,6 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/Format.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Target/TargetLibraryInfo.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" @@ -2300,128 +2299,6 @@ LoopVectorizationLegality::getReductionIdentity(ReductionKind K, Type *Tp) { } } -static Intrinsic::ID checkUnaryFloatSignature(const CallInst &I, - Intrinsic::ID ValidIntrinsicID) { - if (I.getNumArgOperands() != 1 || - !I.getArgOperand(0)->getType()->isFloatingPointTy() || - I.getType() != I.getArgOperand(0)->getType() || - !I.onlyReadsMemory()) - return Intrinsic::not_intrinsic; - - return ValidIntrinsicID; -} - -static Intrinsic::ID checkBinaryFloatSignature(const CallInst &I, - Intrinsic::ID ValidIntrinsicID) { - if (I.getNumArgOperands() != 2 || - !I.getArgOperand(0)->getType()->isFloatingPointTy() || - !I.getArgOperand(1)->getType()->isFloatingPointTy() || - I.getType() != I.getArgOperand(0)->getType() || - I.getType() != I.getArgOperand(1)->getType() || - !I.onlyReadsMemory()) - return Intrinsic::not_intrinsic; - - return ValidIntrinsicID; -} - - -static Intrinsic::ID -getIntrinsicIDForCall(CallInst *CI, const TargetLibraryInfo *TLI) { - // If we have an intrinsic call, check if it is trivially vectorizable. - if (IntrinsicInst *II = dyn_cast(CI)) { - Intrinsic::ID ID = II->getIntrinsicID(); - if (isTriviallyVectorizable(ID) || ID == Intrinsic::lifetime_start || - ID == Intrinsic::lifetime_end) - return ID; - else - return Intrinsic::not_intrinsic; - } - - if (!TLI) - return Intrinsic::not_intrinsic; - - LibFunc::Func Func; - Function *F = CI->getCalledFunction(); - // We're going to make assumptions on the semantics of the functions, check - // that the target knows that it's available in this environment and it does - // not have local linkage. - if (!F || F->hasLocalLinkage() || !TLI->getLibFunc(F->getName(), Func)) - return Intrinsic::not_intrinsic; - - // Otherwise check if we have a call to a function that can be turned into a - // vector intrinsic. - switch (Func) { - default: - break; - case LibFunc::sin: - case LibFunc::sinf: - case LibFunc::sinl: - return checkUnaryFloatSignature(*CI, Intrinsic::sin); - case LibFunc::cos: - case LibFunc::cosf: - case LibFunc::cosl: - return checkUnaryFloatSignature(*CI, Intrinsic::cos); - case LibFunc::exp: - case LibFunc::expf: - case LibFunc::expl: - return checkUnaryFloatSignature(*CI, Intrinsic::exp); - case LibFunc::exp2: - case LibFunc::exp2f: - case LibFunc::exp2l: - return checkUnaryFloatSignature(*CI, Intrinsic::exp2); - case LibFunc::log: - case LibFunc::logf: - case LibFunc::logl: - return checkUnaryFloatSignature(*CI, Intrinsic::log); - case LibFunc::log10: - case LibFunc::log10f: - case LibFunc::log10l: - return checkUnaryFloatSignature(*CI, Intrinsic::log10); - case LibFunc::log2: - case LibFunc::log2f: - case LibFunc::log2l: - return checkUnaryFloatSignature(*CI, Intrinsic::log2); - case LibFunc::fabs: - case LibFunc::fabsf: - case LibFunc::fabsl: - return checkUnaryFloatSignature(*CI, Intrinsic::fabs); - case LibFunc::copysign: - case LibFunc::copysignf: - case LibFunc::copysignl: - return checkBinaryFloatSignature(*CI, Intrinsic::copysign); - case LibFunc::floor: - case LibFunc::floorf: - case LibFunc::floorl: - return checkUnaryFloatSignature(*CI, Intrinsic::floor); - case LibFunc::ceil: - case LibFunc::ceilf: - case LibFunc::ceill: - return checkUnaryFloatSignature(*CI, Intrinsic::ceil); - case LibFunc::trunc: - case LibFunc::truncf: - case LibFunc::truncl: - return checkUnaryFloatSignature(*CI, Intrinsic::trunc); - case LibFunc::rint: - case LibFunc::rintf: - case LibFunc::rintl: - return checkUnaryFloatSignature(*CI, Intrinsic::rint); - case LibFunc::nearbyint: - case LibFunc::nearbyintf: - case LibFunc::nearbyintl: - return checkUnaryFloatSignature(*CI, Intrinsic::nearbyint); - case LibFunc::round: - case LibFunc::roundf: - case LibFunc::roundl: - return checkUnaryFloatSignature(*CI, Intrinsic::round); - case LibFunc::pow: - case LibFunc::powf: - case LibFunc::powl: - return checkBinaryFloatSignature(*CI, Intrinsic::pow); - } - - return Intrinsic::not_intrinsic; -} - /// This function translates the reduction kind to an LLVM binary operator. static unsigned getReductionBinOp(LoopVectorizationLegality::ReductionKind Kind) { diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index b49b1b0ff5ca..15fb6d72b84d 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -346,9 +346,9 @@ public: typedef SmallVector StoreList; BoUpSLP(Function *Func, ScalarEvolution *Se, const DataLayout *Dl, - TargetTransformInfo *Tti, AliasAnalysis *Aa, LoopInfo *Li, + TargetTransformInfo *Tti, TargetLibraryInfo *TLi, AliasAnalysis *Aa, LoopInfo *Li, DominatorTree *Dt) : - F(Func), SE(Se), DL(Dl), TTI(Tti), AA(Aa), LI(Li), DT(Dt), + F(Func), SE(Se), DL(Dl), TTI(Tti), TLI(TLi), AA(Aa), LI(Li), DT(Dt), Builder(Se->getContext()) { // Setup the block numbering utility for all of the blocks in the // function. @@ -536,6 +536,7 @@ private: ScalarEvolution *SE; const DataLayout *DL; TargetTransformInfo *TTI; + TargetLibraryInfo *TLI; AliasAnalysis *AA; LoopInfo *LI; DominatorTree *DT; @@ -949,34 +950,36 @@ void BoUpSLP::buildTree_rec(ArrayRef VL, unsigned Depth) { } case Instruction::Call: { // Check if the calls are all to the same vectorizable intrinsic. - IntrinsicInst *II = dyn_cast(VL[0]); - Intrinsic::ID ID = II ? II->getIntrinsicID() : Intrinsic::not_intrinsic; - + CallInst *CI = cast(VL[0]); + // Check if this is an Intrinsic call or something that can be + // represented by an intrinsic call + Intrinsic::ID ID = getIntrinsicIDForCall(CI, TLI); if (!isTriviallyVectorizable(ID)) { newTreeEntry(VL, false); DEBUG(dbgs() << "SLP: Non-vectorizable call.\n"); return; } - Function *Int = II->getCalledFunction(); + Function *Int = CI->getCalledFunction(); for (unsigned i = 1, e = VL.size(); i != e; ++i) { - IntrinsicInst *II2 = dyn_cast(VL[i]); - if (!II2 || II2->getCalledFunction() != Int) { + CallInst *CI2 = dyn_cast(VL[i]); + if (!CI2 || CI2->getCalledFunction() != Int || + getIntrinsicIDForCall(CI2, TLI) != ID) { newTreeEntry(VL, false); - DEBUG(dbgs() << "SLP: mismatched calls:" << *II << "!=" << *VL[i] + DEBUG(dbgs() << "SLP: mismatched calls:" << *CI << "!=" << *VL[i] << "\n"); return; } } newTreeEntry(VL, true); - for (unsigned i = 0, e = II->getNumArgOperands(); i != e; ++i) { + for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i) { ValueList Operands; // Prepare the operand vector. for (unsigned j = 0; j < VL.size(); ++j) { - IntrinsicInst *II2 = dyn_cast(VL[j]); - Operands.push_back(II2->getArgOperand(i)); + CallInst *CI2 = dyn_cast(VL[j]); + Operands.push_back(CI2->getArgOperand(i)); } buildTree_rec(Operands, Depth + 1); } @@ -1132,12 +1135,11 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { } case Instruction::Call: { CallInst *CI = cast(VL0); - IntrinsicInst *II = cast(CI); - Intrinsic::ID ID = II->getIntrinsicID(); + Intrinsic::ID ID = getIntrinsicIDForCall(CI, TLI); // Calculate the cost of the scalar and vector calls. SmallVector ScalarTys, VecTys; - for (unsigned op = 0, opc = II->getNumArgOperands(); op!= opc; ++op) { + for (unsigned op = 0, opc = CI->getNumArgOperands(); op!= opc; ++op) { ScalarTys.push_back(CI->getArgOperand(op)->getType()); VecTys.push_back(VectorType::get(CI->getArgOperand(op)->getType(), VecTy->getNumElements())); @@ -1150,7 +1152,7 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { DEBUG(dbgs() << "SLP: Call cost "<< VecCallCost - ScalarCallCost << " (" << VecCallCost << "-" << ScalarCallCost << ")" - << " for " << *II << "\n"); + << " for " << *CI << "\n"); return VecCallCost - ScalarCallCost; } @@ -1643,7 +1645,6 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } case Instruction::Call: { CallInst *CI = cast(VL0); - setInsertPointAfterBundle(E->Scalars); std::vector OpVecs; for (int j = 0, e = CI->getNumArgOperands(); j < e; ++j) { @@ -1659,8 +1660,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } Module *M = F->getParent(); - IntrinsicInst *II = cast(CI); - Intrinsic::ID ID = II->getIntrinsicID(); + Intrinsic::ID ID = getIntrinsicIDForCall(CI, TLI); Type *Tys[] = { VectorType::get(CI->getType(), E->Scalars.size()) }; Function *CF = Intrinsic::getDeclaration(M, ID, Tys); Value *V = Builder.CreateCall(CF, OpVecs); @@ -1867,6 +1867,7 @@ struct SLPVectorizer : public FunctionPass { ScalarEvolution *SE; const DataLayout *DL; TargetTransformInfo *TTI; + TargetLibraryInfo *TLI; AliasAnalysis *AA; LoopInfo *LI; DominatorTree *DT; @@ -1879,6 +1880,7 @@ struct SLPVectorizer : public FunctionPass { DataLayoutPass *DLP = getAnalysisIfAvailable(); DL = DLP ? &DLP->getDataLayout() : nullptr; TTI = &getAnalysis(); + TLI = getAnalysisIfAvailable(); AA = &getAnalysis(); LI = &getAnalysis(); DT = &getAnalysis().getDomTree(); @@ -1904,7 +1906,7 @@ struct SLPVectorizer : public FunctionPass { // Use the bottom up slp vectorizer to construct chains that start with // he store instructions. - BoUpSLP R(&F, SE, DL, TTI, AA, LI, DT); + BoUpSLP R(&F, SE, DL, TTI, TLI, AA, LI, DT); // Scan the blocks in the function in post order. for (po_iterator it = po_begin(&F.getEntryBlock()), diff --git a/llvm/test/Transforms/SLPVectorizer/X86/call.ll b/llvm/test/Transforms/SLPVectorizer/X86/call.ll new file mode 100644 index 000000000000..83d45c0a9d72 --- /dev/null +++ b/llvm/test/Transforms/SLPVectorizer/X86/call.ll @@ -0,0 +1,128 @@ +; RUN: opt < %s -basicaa -slp-vectorizer -slp-threshold=-999 -dce -S -mtriple=x86_64-apple-macosx10.8.0 -mcpu=corei7-avx | FileCheck %s + +target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64-S128" +target triple = "x86_64-apple-macosx10.8.0" + +declare double @sin(double) +declare double @cos(double) +declare double @pow(double, double) +declare double @exp2(double) +declare i64 @round(i64) + + +; CHECK: sin_libm +; CHECK: call <2 x double> @llvm.sin.v2f64 +; CHECK: ret void +define void @sin_libm(double* %a, double* %b, double* %c) { +entry: + %i0 = load double* %a, align 8 + %i1 = load double* %b, align 8 + %mul = fmul double %i0, %i1 + %call = tail call double @sin(double %mul) nounwind readnone + %arrayidx3 = getelementptr inbounds double* %a, i64 1 + %i3 = load double* %arrayidx3, align 8 + %arrayidx4 = getelementptr inbounds double* %b, i64 1 + %i4 = load double* %arrayidx4, align 8 + %mul5 = fmul double %i3, %i4 + %call5 = tail call double @sin(double %mul5) nounwind readnone + store double %call, double* %c, align 8 + %arrayidx5 = getelementptr inbounds double* %c, i64 1 + store double %call5, double* %arrayidx5, align 8 + ret void +} + +; CHECK: cos_libm +; CHECK: call <2 x double> @llvm.cos.v2f64 +; CHECK: ret void +define void @cos_libm(double* %a, double* %b, double* %c) { +entry: + %i0 = load double* %a, align 8 + %i1 = load double* %b, align 8 + %mul = fmul double %i0, %i1 + %call = tail call double @cos(double %mul) nounwind readnone + %arrayidx3 = getelementptr inbounds double* %a, i64 1 + %i3 = load double* %arrayidx3, align 8 + %arrayidx4 = getelementptr inbounds double* %b, i64 1 + %i4 = load double* %arrayidx4, align 8 + %mul5 = fmul double %i3, %i4 + %call5 = tail call double @cos(double %mul5) nounwind readnone + store double %call, double* %c, align 8 + %arrayidx5 = getelementptr inbounds double* %c, i64 1 + store double %call5, double* %arrayidx5, align 8 + ret void +} + +; CHECK: pow_libm +; CHECK: call <2 x double> @llvm.pow.v2f64 +; CHECK: ret void +define void @pow_libm(double* %a, double* %b, double* %c) { +entry: + %i0 = load double* %a, align 8 + %i1 = load double* %b, align 8 + %mul = fmul double %i0, %i1 + %call = tail call double @pow(double %mul,double %mul) nounwind readnone + %arrayidx3 = getelementptr inbounds double* %a, i64 1 + %i3 = load double* %arrayidx3, align 8 + %arrayidx4 = getelementptr inbounds double* %b, i64 1 + %i4 = load double* %arrayidx4, align 8 + %mul5 = fmul double %i3, %i4 + %call5 = tail call double @pow(double %mul5,double %mul5) nounwind readnone + store double %call, double* %c, align 8 + %arrayidx5 = getelementptr inbounds double* %c, i64 1 + store double %call5, double* %arrayidx5, align 8 + ret void +} + + +; CHECK: exp2_libm +; CHECK: call <2 x double> @llvm.exp2.v2f64 +; CHECK: ret void +define void @exp2_libm(double* %a, double* %b, double* %c) { +entry: + %i0 = load double* %a, align 8 + %i1 = load double* %b, align 8 + %mul = fmul double %i0, %i1 + %call = tail call double @exp2(double %mul) nounwind readnone + %arrayidx3 = getelementptr inbounds double* %a, i64 1 + %i3 = load double* %arrayidx3, align 8 + %arrayidx4 = getelementptr inbounds double* %b, i64 1 + %i4 = load double* %arrayidx4, align 8 + %mul5 = fmul double %i3, %i4 + %call5 = tail call double @exp2(double %mul5) nounwind readnone + store double %call, double* %c, align 8 + %arrayidx5 = getelementptr inbounds double* %c, i64 1 + store double %call5, double* %arrayidx5, align 8 + ret void +} + + +; Negative test case +; CHECK: round_custom +; CHECK-NOT: load <4 x i64> +; CHECK: ret void +define void @round_custom(i64* %a, i64* %b, i64* %c) { +entry: + %i0 = load i64* %a, align 8 + %i1 = load i64* %b, align 8 + %mul = mul i64 %i0, %i1 + %call = tail call i64 @round(i64 %mul) nounwind readnone + %arrayidx3 = getelementptr inbounds i64* %a, i64 1 + %i3 = load i64* %arrayidx3, align 8 + %arrayidx4 = getelementptr inbounds i64* %b, i64 1 + %i4 = load i64* %arrayidx4, align 8 + %mul5 = mul i64 %i3, %i4 + %call5 = tail call i64 @round(i64 %mul5) nounwind readnone + store i64 %call, i64* %c, align 8 + %arrayidx5 = getelementptr inbounds i64* %c, i64 1 + store i64 %call5, i64* %arrayidx5, align 8 + ret void +} + + +; CHECK: declare <2 x double> @llvm.sin.v2f64(<2 x double>) #0 +; CHECK: declare <2 x double> @llvm.cos.v2f64(<2 x double>) #0 +; CHECK: declare <2 x double> @llvm.pow.v2f64(<2 x double>, <2 x double>) #0 +; CHECK: declare <2 x double> @llvm.exp2.v2f64(<2 x double>) #0 + +; CHECK: attributes #0 = { nounwind readnone } +