[InstCombine] Rewrite the binary op shrinking in visitFPTrunc to avoid creating overly small ConstantFPs that we'll just need to extend again.

Instead of returning the smaller FP constant we now return the minimal Type the constant can fit into. We also return the Type of the input to any fp extends. The legality checks are then done on just the size of these Types. If we find something profitable we then emit FPTruncs in front of the smaller binop and assume those FPTruncs will be constant folded or combined with any ConstantFPs or fpextends.

Differential Revision: https://reviews.llvm.org/D44038

llvm-svn: 326617
This commit is contained in:
Craig Topper 2018-03-02 21:25:18 +00:00
parent 1785e244eb
commit c7461e1aad
1 changed files with 43 additions and 47 deletions

View File

@ -1411,45 +1411,43 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) {
/// Return a Constant* for the specified floating-point constant if it fits /// Return a Constant* for the specified floating-point constant if it fits
/// in the specified FP type without changing its value. /// in the specified FP type without changing its value.
static Constant *fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) { static bool fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) {
bool losesInfo; bool losesInfo;
APFloat F = CFP->getValueAPF(); APFloat F = CFP->getValueAPF();
(void)F.convert(Sem, APFloat::rmNearestTiesToEven, &losesInfo); (void)F.convert(Sem, APFloat::rmNearestTiesToEven, &losesInfo);
if (!losesInfo) return !losesInfo;
return ConstantFP::get(CFP->getContext(), F);
return nullptr;
} }
static Constant *shrinkFPConstant(ConstantFP *CFP) { static Type *shrinkFPConstant(ConstantFP *CFP) {
if (CFP->getType() == Type::getPPC_FP128Ty(CFP->getContext())) if (CFP->getType() == Type::getPPC_FP128Ty(CFP->getContext()))
return nullptr; // No constant folding of this. return nullptr; // No constant folding of this.
// See if the value can be truncated to half and then reextended. // See if the value can be truncated to half and then reextended.
if (Constant *NewCFP = fitsInFPType(CFP, APFloat::IEEEhalf())) if (fitsInFPType(CFP, APFloat::IEEEhalf()))
return NewCFP; return Type::getHalfTy(CFP->getContext());
// See if the value can be truncated to float and then reextended. // See if the value can be truncated to float and then reextended.
if (Constant *NewCFP = fitsInFPType(CFP, APFloat::IEEEsingle())) if (fitsInFPType(CFP, APFloat::IEEEsingle()))
return NewCFP; return Type::getFloatTy(CFP->getContext());
if (CFP->getType()->isDoubleTy()) if (CFP->getType()->isDoubleTy())
return nullptr; // Won't shrink. return nullptr; // Won't shrink.
if (Constant *NewCFP = fitsInFPType(CFP, APFloat::IEEEdouble())) if (fitsInFPType(CFP, APFloat::IEEEdouble()))
return NewCFP; return Type::getDoubleTy(CFP->getContext());
// Don't try to shrink to various long double types. // Don't try to shrink to various long double types.
return nullptr; return nullptr;
} }
/// Look through floating-point extensions until we get the source value. /// Find the minimum FP type we can safely truncate to.
static Value *lookThroughFPExtensions(Value *V) { static Type *getMinimumFPType(Value *V) {
while (auto *FPExt = dyn_cast<FPExtInst>(V)) if (auto *FPExt = dyn_cast<FPExtInst>(V))
V = FPExt->getOperand(0); return FPExt->getOperand(0)->getType();
// If this value is a constant, return the constant in the smallest FP type // If this value is a constant, return the constant in the smallest FP type
// that can accurately represent it. This allows us to turn // that can accurately represent it. This allows us to turn
// (float)((double)X+2.0) into x+2.0f. // (float)((double)X+2.0) into x+2.0f.
if (auto *CFP = dyn_cast<ConstantFP>(V)) if (auto *CFP = dyn_cast<ConstantFP>(V))
if (Constant *NewCFP = shrinkFPConstant(CFP)) if (Type *T = shrinkFPConstant(CFP))
return NewCFP; return T;
return V; return V->getType();
} }
Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) {
@ -1464,11 +1462,11 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) {
// is explained below in the various case statements. // is explained below in the various case statements.
BinaryOperator *OpI = dyn_cast<BinaryOperator>(CI.getOperand(0)); BinaryOperator *OpI = dyn_cast<BinaryOperator>(CI.getOperand(0));
if (OpI && OpI->hasOneUse()) { if (OpI && OpI->hasOneUse()) {
Value *LHSOrig = lookThroughFPExtensions(OpI->getOperand(0)); Type *LHSMinType = getMinimumFPType(OpI->getOperand(0));
Value *RHSOrig = lookThroughFPExtensions(OpI->getOperand(1)); Type *RHSMinType = getMinimumFPType(OpI->getOperand(1));
unsigned OpWidth = OpI->getType()->getFPMantissaWidth(); unsigned OpWidth = OpI->getType()->getFPMantissaWidth();
unsigned LHSWidth = LHSOrig->getType()->getFPMantissaWidth(); unsigned LHSWidth = LHSMinType->getFPMantissaWidth();
unsigned RHSWidth = RHSOrig->getType()->getFPMantissaWidth(); unsigned RHSWidth = RHSMinType->getFPMantissaWidth();
unsigned SrcWidth = std::max(LHSWidth, RHSWidth); unsigned SrcWidth = std::max(LHSWidth, RHSWidth);
unsigned DstWidth = CI.getType()->getFPMantissaWidth(); unsigned DstWidth = CI.getType()->getFPMantissaWidth();
switch (OpI->getOpcode()) { switch (OpI->getOpcode()) {
@ -1494,12 +1492,10 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) {
// could be tightened for those cases, but they are rare (the main // could be tightened for those cases, but they are rare (the main
// case of interest here is (float)((double)float + float)). // case of interest here is (float)((double)float + float)).
if (OpWidth >= 2*DstWidth+1 && DstWidth >= SrcWidth) { if (OpWidth >= 2*DstWidth+1 && DstWidth >= SrcWidth) {
if (LHSOrig->getType() != CI.getType()) Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), CI.getType());
LHSOrig = Builder.CreateFPExt(LHSOrig, CI.getType()); Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), CI.getType());
if (RHSOrig->getType() != CI.getType())
RHSOrig = Builder.CreateFPExt(RHSOrig, CI.getType());
Instruction *RI = Instruction *RI =
BinaryOperator::Create(OpI->getOpcode(), LHSOrig, RHSOrig); BinaryOperator::Create(OpI->getOpcode(), LHS, RHS);
RI->copyFastMathFlags(OpI); RI->copyFastMathFlags(OpI);
return RI; return RI;
} }
@ -1511,12 +1507,10 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) {
// rounding can possibly occur; we can safely perform the operation // rounding can possibly occur; we can safely perform the operation
// in the destination format if it can represent both sources. // in the destination format if it can represent both sources.
if (OpWidth >= LHSWidth + RHSWidth && DstWidth >= SrcWidth) { if (OpWidth >= LHSWidth + RHSWidth && DstWidth >= SrcWidth) {
if (LHSOrig->getType() != CI.getType()) Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), CI.getType());
LHSOrig = Builder.CreateFPExt(LHSOrig, CI.getType()); Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), CI.getType());
if (RHSOrig->getType() != CI.getType())
RHSOrig = Builder.CreateFPExt(RHSOrig, CI.getType());
Instruction *RI = Instruction *RI =
BinaryOperator::CreateFMul(LHSOrig, RHSOrig); BinaryOperator::CreateFMul(LHS, RHS);
RI->copyFastMathFlags(OpI); RI->copyFastMathFlags(OpI);
return RI; return RI;
} }
@ -1529,33 +1523,35 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) {
// condition used here is a good conservative first pass. // condition used here is a good conservative first pass.
// TODO: Tighten bound via rigorous analysis of the unbalanced case. // TODO: Tighten bound via rigorous analysis of the unbalanced case.
if (OpWidth >= 2*DstWidth && DstWidth >= SrcWidth) { if (OpWidth >= 2*DstWidth && DstWidth >= SrcWidth) {
if (LHSOrig->getType() != CI.getType()) Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), CI.getType());
LHSOrig = Builder.CreateFPExt(LHSOrig, CI.getType()); Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), CI.getType());
if (RHSOrig->getType() != CI.getType())
RHSOrig = Builder.CreateFPExt(RHSOrig, CI.getType());
Instruction *RI = Instruction *RI =
BinaryOperator::CreateFDiv(LHSOrig, RHSOrig); BinaryOperator::CreateFDiv(LHS, RHS);
RI->copyFastMathFlags(OpI); RI->copyFastMathFlags(OpI);
return RI; return RI;
} }
break; break;
case Instruction::FRem: case Instruction::FRem: {
// Remainder is straightforward. Remainder is always exact, so the // Remainder is straightforward. Remainder is always exact, so the
// type of OpI doesn't enter into things at all. We simply evaluate // type of OpI doesn't enter into things at all. We simply evaluate
// in whichever source type is larger, then convert to the // in whichever source type is larger, then convert to the
// destination type. // destination type.
if (SrcWidth == OpWidth) if (SrcWidth == OpWidth)
break; break;
if (LHSWidth < SrcWidth) Value *LHS, *RHS;
LHSOrig = Builder.CreateFPExt(LHSOrig, RHSOrig->getType()); if (LHSWidth == SrcWidth) {
else if (RHSWidth <= SrcWidth) LHS = Builder.CreateFPTrunc(OpI->getOperand(0), LHSMinType);
RHSOrig = Builder.CreateFPExt(RHSOrig, LHSOrig->getType()); RHS = Builder.CreateFPTrunc(OpI->getOperand(1), LHSMinType);
if (LHSOrig != OpI->getOperand(0) || RHSOrig != OpI->getOperand(1)) { } else {
Value *ExactResult = Builder.CreateFRem(LHSOrig, RHSOrig); LHS = Builder.CreateFPTrunc(OpI->getOperand(0), RHSMinType);
if (Instruction *RI = dyn_cast<Instruction>(ExactResult)) RHS = Builder.CreateFPTrunc(OpI->getOperand(1), RHSMinType);
RI->copyFastMathFlags(OpI);
return CastInst::CreateFPCast(ExactResult, CI.getType());
} }
Value *ExactResult = Builder.CreateFRem(LHS, RHS);
if (Instruction *RI = dyn_cast<Instruction>(ExactResult))
RI->copyFastMathFlags(OpI);
return CastInst::CreateFPCast(ExactResult, CI.getType());
}
} }
// (fptrunc (fneg x)) -> (fneg (fptrunc x)) // (fptrunc (fneg x)) -> (fneg (fptrunc x))