From 110b1138baf4b0736cd65c5fa4b3b697c318a6da Mon Sep 17 00:00:00 2001 From: Evandro Menezes Date: Mon, 30 Sep 2019 20:52:21 +0000 Subject: [PATCH] [InstCombine] Expand the simplification of log() Expand the simplification of special cases of `log()` to include `log2()` and `log10()` as well as intrinsics and more types. Differential revision: https://reviews.llvm.org/D67199 llvm-svn: 373261 --- .../lib/Transforms/Utils/SimplifyLibCalls.cpp | 186 ++++++++++++++---- llvm/test/Transforms/InstCombine/log-pow.ll | 62 +++--- 2 files changed, 180 insertions(+), 68 deletions(-) diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp index 813816e82d25..b252167ba27d 100644 --- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -35,6 +35,7 @@ #include "llvm/IR/PatternMatch.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/KnownBits.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" #include "llvm/Transforms/Utils/SizeOpts.h" @@ -1457,9 +1458,7 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilder<> &B) { StringRef ExpName; Intrinsic::ID ID; Value *ExpFn; - LibFunc LibFnFloat; - LibFunc LibFnDouble; - LibFunc LibFnLongDouble; + LibFunc LibFnFloat, LibFnDouble, LibFnLongDouble; switch (LibFn) { default: @@ -1809,48 +1808,155 @@ Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilder<> &B) { return B.CreateCall(F, { CI->getArgOperand(0), CI->getArgOperand(1) }); } -Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); +Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilder<> &B) { + Function *LogFn = Log->getCalledFunction(); + AttributeList Attrs = LogFn->getAttributes(); + StringRef LogNm = LogFn->getName(); + Intrinsic::ID LogID = LogFn->getIntrinsicID(); + Module *Mod = Log->getModule(); + Type *Ty = Log->getType(); Value *Ret = nullptr; - StringRef Name = Callee->getName(); - if (UnsafeFPShrink && hasFloatVersion(Name)) - Ret = optimizeUnaryDoubleFP(CI, B, true); - if (!CI->isFast()) - return Ret; - Value *Op1 = CI->getArgOperand(0); - auto *OpC = dyn_cast(Op1); + if (UnsafeFPShrink && hasFloatVersion(LogNm)) + Ret = optimizeUnaryDoubleFP(Log, B, true); // The earlier call must also be 'fast' in order to do these transforms. - if (!OpC || !OpC->isFast()) + CallInst *Arg = dyn_cast(Log->getArgOperand(0)); + if (!Log->isFast() || !Arg || !Arg->isFast() || !Arg->hasOneUse()) return Ret; - // log(pow(x,y)) -> y*log(x) - // This is only applicable to log, log2, log10. - if (Name != "log" && Name != "log2" && Name != "log10") + LibFunc LogLb, ExpLb, Exp2Lb, Exp10Lb, PowLb; + + // This is only applicable to log(), log2(), log10(). + if (TLI->getLibFunc(LogNm, LogLb)) + switch (LogLb) { + case LibFunc_logf: + LogID = Intrinsic::log; + ExpLb = LibFunc_expf; + Exp2Lb = LibFunc_exp2f; + Exp10Lb = LibFunc_exp10f; + PowLb = LibFunc_powf; + break; + case LibFunc_log: + LogID = Intrinsic::log; + ExpLb = LibFunc_exp; + Exp2Lb = LibFunc_exp2; + Exp10Lb = LibFunc_exp10; + PowLb = LibFunc_pow; + break; + case LibFunc_logl: + LogID = Intrinsic::log; + ExpLb = LibFunc_expl; + Exp2Lb = LibFunc_exp2l; + Exp10Lb = LibFunc_exp10l; + PowLb = LibFunc_powl; + break; + case LibFunc_log2f: + LogID = Intrinsic::log2; + ExpLb = LibFunc_expf; + Exp2Lb = LibFunc_exp2f; + Exp10Lb = LibFunc_exp10f; + PowLb = LibFunc_powf; + break; + case LibFunc_log2: + LogID = Intrinsic::log2; + ExpLb = LibFunc_exp; + Exp2Lb = LibFunc_exp2; + Exp10Lb = LibFunc_exp10; + PowLb = LibFunc_pow; + break; + case LibFunc_log2l: + LogID = Intrinsic::log2; + ExpLb = LibFunc_expl; + Exp2Lb = LibFunc_exp2l; + Exp10Lb = LibFunc_exp10l; + PowLb = LibFunc_powl; + break; + case LibFunc_log10f: + LogID = Intrinsic::log10; + ExpLb = LibFunc_expf; + Exp2Lb = LibFunc_exp2f; + Exp10Lb = LibFunc_exp10f; + PowLb = LibFunc_powf; + break; + case LibFunc_log10: + LogID = Intrinsic::log10; + ExpLb = LibFunc_exp; + Exp2Lb = LibFunc_exp2; + Exp10Lb = LibFunc_exp10; + PowLb = LibFunc_pow; + break; + case LibFunc_log10l: + LogID = Intrinsic::log10; + ExpLb = LibFunc_expl; + Exp2Lb = LibFunc_exp2l; + Exp10Lb = LibFunc_exp10l; + PowLb = LibFunc_powl; + break; + default: + return Ret; + } + else if (LogID == Intrinsic::log || LogID == Intrinsic::log2 || + LogID == Intrinsic::log10) { + if (Ty->getScalarType()->isFloatTy()) { + ExpLb = LibFunc_expf; + Exp2Lb = LibFunc_exp2f; + Exp10Lb = LibFunc_exp10f; + PowLb = LibFunc_powf; + } else if (Ty->getScalarType()->isDoubleTy()) { + ExpLb = LibFunc_exp; + Exp2Lb = LibFunc_exp2; + Exp10Lb = LibFunc_exp10; + PowLb = LibFunc_pow; + } else + return Ret; + } else return Ret; IRBuilder<>::FastMathFlagGuard Guard(B); - FastMathFlags FMF; - FMF.setFast(); - B.setFastMathFlags(FMF); + B.setFastMathFlags(FastMathFlags::getFast()); - LibFunc Func; - Function *F = OpC->getCalledFunction(); - if (F && ((TLI->getLibFunc(F->getName(), Func) && TLI->has(Func) && - Func == LibFunc_pow) || F->getIntrinsicID() == Intrinsic::pow)) - return B.CreateFMul(OpC->getArgOperand(1), - emitUnaryFloatFnCall(OpC->getOperand(0), Callee->getName(), B, - Callee->getAttributes()), "mul"); + Function *ArgFn = Arg->getCalledFunction(); + StringRef ArgNm = ArgFn->getName(); + Intrinsic::ID ArgID = ArgFn->getIntrinsicID(); + LibFunc ArgLb = NotLibFunc; + TLI->getLibFunc(ArgNm, ArgLb); + + // log(pow(x,y)) -> y*log(x) + if (ArgLb == PowLb || ArgID == Intrinsic::pow) { + Value *LogX = + Log->doesNotAccessMemory() + ? B.CreateCall(Intrinsic::getDeclaration(Mod, LogID, Ty), + Arg->getOperand(0), "log") + : emitUnaryFloatFnCall(Arg->getOperand(0), LogNm, B, Attrs); + Value *MulY = B.CreateFMul(Arg->getArgOperand(1), LogX, "mul"); + // Since pow() may have side effects, e.g. errno, + // dead code elimination may not be trusted to remove it. + substituteInParent(Arg, MulY); + return MulY; + } + // log(exp{,2,10}(y)) -> y*log({e,2,10}) + // TODO: There is no exp10() intrinsic yet. + else if (ArgLb == ExpLb || ArgLb == Exp2Lb || ArgLb == Exp10Lb || + ArgID == Intrinsic::exp || ArgID == Intrinsic::exp2) { + Constant *Eul; + if (ArgLb == ExpLb || ArgID == Intrinsic::exp) + Eul = ConstantFP::get(Log->getType(), M_E); + else if (ArgLb == Exp2Lb || ArgID == Intrinsic::exp2) + Eul = ConstantFP::get(Log->getType(), 2.0); + else + Eul = ConstantFP::get(Log->getType(), 10.0); + Value *LogE = Log->doesNotAccessMemory() + ? B.CreateCall(Intrinsic::getDeclaration(Mod, LogID, Ty), + Eul, "log") + : emitUnaryFloatFnCall(Eul, LogNm, B, Attrs); + Value *MulY = B.CreateFMul(Arg->getArgOperand(0), LogE, "mul"); + // Since exp() may have side effects, e.g. errno, + // dead code elimination may not be trusted to remove it. + substituteInParent(Arg, MulY); + return MulY; + } - // log(exp2(y)) -> y*log(2) - if (F && Name == "log" && TLI->getLibFunc(F->getName(), Func) && - TLI->has(Func) && Func == LibFunc_exp2) - return B.CreateFMul( - OpC->getArgOperand(0), - emitUnaryFloatFnCall(ConstantFP::get(CI->getType(), 2.0), - Callee->getName(), B, Callee->getAttributes()), - "logmul"); return Ret; } @@ -2801,11 +2907,21 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, case LibFunc_sqrt: case LibFunc_sqrtl: return optimizeSqrt(CI, Builder); + case LibFunc_logf: case LibFunc_log: + case LibFunc_logl: + case LibFunc_log10f: case LibFunc_log10: + case LibFunc_log10l: + case LibFunc_log1pf: case LibFunc_log1p: + case LibFunc_log1pl: + case LibFunc_log2f: case LibFunc_log2: + case LibFunc_log2l: + case LibFunc_logbf: case LibFunc_logb: + case LibFunc_logbl: return optimizeLog(CI, Builder); case LibFunc_tan: case LibFunc_tanf: @@ -2896,6 +3012,8 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { case Intrinsic::exp2: return optimizeExp2(CI, Builder); case Intrinsic::log: + case Intrinsic::log2: + case Intrinsic::log10: return optimizeLog(CI, Builder); case Intrinsic::sqrt: return optimizeSqrt(CI, Builder); diff --git a/llvm/test/Transforms/InstCombine/log-pow.ll b/llvm/test/Transforms/InstCombine/log-pow.ll index 3a6b4eaa4e7b..b7085e955ef4 100644 --- a/llvm/test/Transforms/InstCombine/log-pow.ll +++ b/llvm/test/Transforms/InstCombine/log-pow.ll @@ -3,7 +3,7 @@ define double @log_pow(double %x, double %y) { ; CHECK-LABEL: @log_pow( -; CHECK-NEXT: [[LOG1:%.*]] = call fast double @log(double [[X:%.*]]) #0 +; CHECK-NEXT: [[LOG1:%.*]] = call fast double @llvm.log.f64(double [[X:%.*]]) ; CHECK-NEXT: [[MUL:%.*]] = fmul fast double [[LOG1]], [[Y:%.*]] ; CHECK-NEXT: ret double [[MUL]] ; @@ -12,24 +12,22 @@ define double @log_pow(double %x, double %y) { ret double %log } -; FIXME: log10f() should also be simplified. define float @log10f_powf(float %x, float %y) { ; CHECK-LABEL: @log10f_powf( -; CHECK-NEXT: [[POW:%.*]] = call fast float @powf(float [[X:%.*]], float [[Y:%.*]]) -; CHECK-NEXT: [[LOG:%.*]] = call fast float @log10f(float [[POW]]) -; CHECK-NEXT: ret float [[LOG]] +; CHECK-NEXT: [[LOG1:%.*]] = call fast float @llvm.log10.f32(float [[X:%.*]]) +; CHECK-NEXT: [[MUL:%.*]] = fmul fast float [[LOG1]], [[Y:%.*]] +; CHECK-NEXT: ret float [[MUL]] ; %pow = call fast float @powf(float %x, float %y) - %log = call fast float @log10f(float %pow) + %log = call fast float @llvm.log10.f32(float %pow) ret float %log } -; FIXME: Intrinsic log2() should also be simplified. define <2 x double> @log2v_powv(<2 x double> %x, <2 x double> %y) { ; CHECK-LABEL: @log2v_powv( -; CHECK-NEXT: [[POW:%.*]] = call fast <2 x double> @llvm.pow.v2f64(<2 x double> [[X:%.*]], <2 x double> [[Y:%.*]]) -; CHECK-NEXT: [[LOG:%.*]] = call fast <2 x double> @llvm.log2.v2f64(<2 x double> [[POW]]) -; CHECK-NEXT: ret <2 x double> [[LOG]] +; CHECK-NEXT: [[LOG1:%.*]] = call fast <2 x double> @llvm.log2.v2f64(<2 x double> [[X:%.*]]) +; CHECK-NEXT: [[MUL:%.*]] = fmul fast <2 x double> [[LOG1]], [[Y:%.*]] +; CHECK-NEXT: ret <2 x double> [[MUL]] ; %pow = call fast <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> %y) %log = call fast <2 x double> @llvm.log2.v2f64(<2 x double> %pow) @@ -58,39 +56,33 @@ define float @function_pointer(float ()* %fptr, float %p1) { ret float %log } -; FIXME: The call to exp2() should be removed. -define double @log_exp2(double %x) { -; CHECK-LABEL: @log_exp2( -; CHECK-NEXT: [[EXP:%.*]] = call fast double @exp2(double [[X:%.*]]) -; CHECK-NEXT: [[LOGMUL:%.*]] = fmul fast double [[X]], 0x3FE62E42FEFA39EF -; CHECK-NEXT: ret double [[LOGMUL]] +define double @log10_exp(double %x) { +; CHECK-LABEL: @log10_exp( +; CHECK-NEXT: [[MUL:%.*]] = fmul fast double [[X:%.*]], 0x3FDBCB7B1526E50E +; CHECK-NEXT: ret double [[MUL]] ; - %exp = call fast double @exp2(double %x) - %log = call fast double @log(double %exp) + %exp = call fast double @exp(double %x) + %log = call fast double @log10(double %exp) ret double %log } -; FIXME: Intrinsic logf() should also be simplified. define <2 x float> @logv_exp2v(<2 x float> %x) { ; CHECK-LABEL: @logv_exp2v( -; CHECK-NEXT: [[EXP:%.*]] = call fast <2 x float> @llvm.exp2.v2f32(<2 x float> [[X:%.*]]) -; CHECK-NEXT: [[LOG:%.*]] = call fast <2 x float> @llvm.log.v2f32(<2 x float> [[EXP]]) -; CHECK-NEXT: ret <2 x float> [[LOG]] +; CHECK-NEXT: [[MUL:%.*]] = fmul fast <2 x float> [[X:%.*]], +; CHECK-NEXT: ret <2 x float> [[MUL]] ; %exp = call fast <2 x float> @llvm.exp2.v2f32(<2 x float> %x) %log = call fast <2 x float> @llvm.log.v2f32(<2 x float> %exp) ret <2 x float> %log } -; FIXME: log10f() should also be simplified. -define float @log10f_exp2f(float %x) { -; CHECK-LABEL: @log10f_exp2f( -; CHECK-NEXT: [[EXP:%.*]] = call fast float @exp2f(float [[X:%.*]]) -; CHECK-NEXT: [[LOG:%.*]] = call fast float @log10f(float [[EXP]]) -; CHECK-NEXT: ret float [[LOG]] +define float @log2f_exp10f(float %x) { +; CHECK-LABEL: @log2f_exp10f( +; CHECK-NEXT: [[MUL:%.*]] = fmul fast float [[X:%.*]], 0x400A934F00000000 +; CHECK-NEXT: ret float [[MUL]] ; - %exp = call fast float @exp2f(float %x) - %log = call fast float @log10f(float %exp) + %exp = call fast float @exp10f(float %x) + %log = call fast float @log2f(float %exp) ret float %log } @@ -108,11 +100,13 @@ define double @log_exp2_not_fast(double %x) { declare double @log(double) #0 declare float @logf(float) #0 declare <2 x float> @llvm.log.v2f32(<2 x float>) -declare double @log2(double) #0 +declare float @log2f(float) #0 declare <2 x double> @llvm.log2.v2f64(<2 x double>) -declare float @log10f(float) #0 -declare double @exp2(double) -declare float @exp2f(float) +declare double @log10(double) #0 +declare float @llvm.log10.f32(float) +declare double @exp(double %x) #0 +declare double @exp2(double) #0 +declare float @exp10f(float) #0 declare <2 x float> @llvm.exp2.v2f32(<2 x float>) declare double @pow(double, double) #0 declare float @powf(float, float) #0