Perform factorization as a last resort of unsafe fadd/fsub simplification.

Rules include:
  1)1 x*y +/- x*z => x*(y +/- z) 
    (the order of operands dosen't matter)

  2) y/x +/- z/x => (y +/- z)/x 

 The transformation is disabled if the new add/sub expr "y +/- z" is a 
denormal/naz/inifinity.

rdar://12911472

llvm-svn: 177088
This commit is contained in:
Shuxin Yang 2013-03-14 18:08:26 +00:00
parent ed6d955416
commit 2eca602f8b
2 changed files with 196 additions and 5 deletions

View File

@ -150,7 +150,9 @@ namespace {
typedef SmallVector<const FAddend*, 4> AddendVect;
Value *simplifyFAdd(AddendVect& V, unsigned InstrQuota);
Value *performFactorization(Instruction *I);
/// Convert given addend to a Value
Value *createAddendVal(const FAddend &A, bool& NeedNeg);
@ -159,6 +161,7 @@ namespace {
Value *createFSub(Value *Opnd0, Value *Opnd1);
Value *createFAdd(Value *Opnd0, Value *Opnd1);
Value *createFMul(Value *Opnd0, Value *Opnd1);
Value *createFDiv(Value *Opnd0, Value *Opnd1);
Value *createFNeg(Value *V);
Value *createNaryFAdd(const AddendVect& Opnds, unsigned InstrQuota);
void createInstPostProc(Instruction *NewInst);
@ -388,6 +391,78 @@ unsigned FAddend::drillAddendDownOneStep
return BreakNum;
}
// Try to perform following optimization on the input instruction I. Return the
// simplified expression if was successful; otherwise, return 0.
//
// Instruction "I" is Simplified into
// -------------------------------------------------------
// (x * y) +/- (x * z) x * (y +/- z)
// (y / x) +/- (z / x) (y +/- z) / x
//
Value *FAddCombine::performFactorization(Instruction *I) {
assert((I->getOpcode() == Instruction::FAdd ||
I->getOpcode() == Instruction::FSub) && "Expect add/sub");
Instruction *I0 = dyn_cast<Instruction>(I->getOperand(0));
Instruction *I1 = dyn_cast<Instruction>(I->getOperand(1));
if (!I0 || !I1 || I0->getOpcode() != I1->getOpcode())
return 0;
bool isMpy = false;
if (I0->getOpcode() == Instruction::FMul)
isMpy = true;
else if (I0->getOpcode() != Instruction::FDiv)
return 0;
Value *Opnd0_0 = I0->getOperand(0);
Value *Opnd0_1 = I0->getOperand(1);
Value *Opnd1_0 = I1->getOperand(0);
Value *Opnd1_1 = I1->getOperand(1);
// Input Instr I Factor AddSub0 AddSub1
// ----------------------------------------------
// (x*y) +/- (x*z) x y z
// (y/x) +/- (z/x) x y z
//
Value *Factor = 0;
Value *AddSub0 = 0, *AddSub1 = 0;
if (isMpy) {
if (Opnd0_0 == Opnd1_0 || Opnd0_0 == Opnd1_1)
Factor = Opnd0_0;
else if (Opnd0_1 == Opnd1_0 || Opnd0_1 == Opnd1_1)
Factor = Opnd0_1;
if (Factor) {
AddSub0 = (Factor == Opnd0_0) ? Opnd0_1 : Opnd0_0;
AddSub1 = (Factor == Opnd1_0) ? Opnd1_1 : Opnd1_0;
}
} else if (Opnd0_1 == Opnd1_1) {
Factor = Opnd0_1;
AddSub0 = Opnd0_0;
AddSub1 = Opnd1_0;
}
if (!Factor)
return 0;
// Create expression "NewAddSub = AddSub0 +/- AddsSub1"
Value *NewAddSub = (I->getOpcode() == Instruction::FAdd) ?
createFAdd(AddSub0, AddSub1) :
createFSub(AddSub0, AddSub1);
if (ConstantFP *CFP = dyn_cast<ConstantFP>(NewAddSub)) {
const APFloat &F = CFP->getValueAPF();
if (!F.isNormal() || F.isDenormal())
return 0;
}
if (isMpy)
return createFMul(Factor, NewAddSub);
return createFDiv(NewAddSub, Factor);
}
Value *FAddCombine::simplify(Instruction *I) {
assert(I->hasUnsafeAlgebra() && "Should be in unsafe mode");
@ -471,7 +546,8 @@ Value *FAddCombine::simplify(Instruction *I) {
return R;
}
return 0;
// step 6: Try factorization as the last resort,
return performFactorization(I);
}
Value *FAddCombine::simplifyFAdd(AddendVect& Addends, unsigned InstrQuota) {
@ -627,7 +703,8 @@ Value *FAddCombine::createNaryFAdd
Value *FAddCombine::createFSub
(Value *Opnd0, Value *Opnd1) {
Value *V = Builder->CreateFSub(Opnd0, Opnd1);
createInstPostProc(cast<Instruction>(V));
if (Instruction *I = dyn_cast<Instruction>(V))
createInstPostProc(I);
return V;
}
@ -639,13 +716,22 @@ Value *FAddCombine::createFNeg(Value *V) {
Value *FAddCombine::createFAdd
(Value *Opnd0, Value *Opnd1) {
Value *V = Builder->CreateFAdd(Opnd0, Opnd1);
createInstPostProc(cast<Instruction>(V));
if (Instruction *I = dyn_cast<Instruction>(V))
createInstPostProc(I);
return V;
}
Value *FAddCombine::createFMul(Value *Opnd0, Value *Opnd1) {
Value *V = Builder->CreateFMul(Opnd0, Opnd1);
createInstPostProc(cast<Instruction>(V));
if (Instruction *I = dyn_cast<Instruction>(V))
createInstPostProc(I);
return V;
}
Value *FAddCombine::createFDiv(Value *Opnd0, Value *Opnd1) {
Value *V = Builder->CreateFDiv(Opnd0, Opnd1);
if (Instruction *I = dyn_cast<Instruction>(V))
createInstPostProc(I);
return V;
}

View File

@ -350,3 +350,108 @@ define float @fdiv9(float %x) {
; CHECK: @fdiv9
; CHECK: fmul fast float %x, 5.000000e+00
}
; =========================================================================
;
; Testing-cases about factorization
;
; =========================================================================
; x*z + y*z => (x+y) * z
define float @fact_mul1(float %x, float %y, float %z) {
%t1 = fmul fast float %x, %z
%t2 = fmul fast float %y, %z
%t3 = fadd fast float %t1, %t2
ret float %t3
; CHECK: @fact_mul1
; CHECK: fmul fast float %1, %z
}
; z*x + y*z => (x+y) * z
define float @fact_mul2(float %x, float %y, float %z) {
%t1 = fmul fast float %z, %x
%t2 = fmul fast float %y, %z
%t3 = fsub fast float %t1, %t2
ret float %t3
; CHECK: @fact_mul2
; CHECK: fmul fast float %1, %z
}
; z*x - z*y => (x-y) * z
define float @fact_mul3(float %x, float %y, float %z) {
%t2 = fmul fast float %z, %y
%t1 = fmul fast float %z, %x
%t3 = fsub fast float %t1, %t2
ret float %t3
; CHECK: @fact_mul3
; CHECK: fmul fast float %1, %z
}
; x*z - z*y => (x-y) * z
define float @fact_mul4(float %x, float %y, float %z) {
%t1 = fmul fast float %x, %z
%t2 = fmul fast float %z, %y
%t3 = fsub fast float %t1, %t2
ret float %t3
; CHECK: @fact_mul4
; CHECK: fmul fast float %1, %z
}
; x/y + x/z, no xform
define float @fact_div1(float %x, float %y, float %z) {
%t1 = fdiv fast float %x, %y
%t2 = fdiv fast float %x, %z
%t3 = fadd fast float %t1, %t2
ret float %t3
; CHECK: fact_div1
; CHECK: fadd fast float %t1, %t2
}
; x/y + z/x; no xform
define float @fact_div2(float %x, float %y, float %z) {
%t1 = fdiv fast float %x, %y
%t2 = fdiv fast float %z, %x
%t3 = fadd fast float %t1, %t2
ret float %t3
; CHECK: fact_div2
; CHECK: fadd fast float %t1, %t2
}
; y/x + z/x => (y+z)/x
define float @fact_div3(float %x, float %y, float %z) {
%t1 = fdiv fast float %y, %x
%t2 = fdiv fast float %z, %x
%t3 = fadd fast float %t1, %t2
ret float %t3
; CHECK: fact_div3
; CHECK: fdiv fast float %1, %x
}
; y/x - z/x => (y-z)/x
define float @fact_div4(float %x, float %y, float %z) {
%t1 = fdiv fast float %y, %x
%t2 = fdiv fast float %z, %x
%t3 = fsub fast float %t1, %t2
ret float %t3
; CHECK: fact_div4
; CHECK: fdiv fast float %1, %x
}
; y/x - z/x => (y-z)/x is disabled if y-z is denormal.
define float @fact_div5(float %x) {
%t1 = fdiv fast float 0x3810000000000000, %x
%t2 = fdiv fast float 0x3800000000000000, %x
%t3 = fadd fast float %t1, %t2
ret float %t3
; CHECK: fact_div5
; CHECK: fdiv fast float 0x3818000000000000, %x
}
; y/x - z/x => (y-z)/x is disabled if y-z is denormal.
define float @fact_div6(float %x) {
%t1 = fdiv fast float 0x3810000000000000, %x
%t2 = fdiv fast float 0x3800000000000000, %x
%t3 = fsub fast float %t1, %t2
ret float %t3
; CHECK: fact_div6
; CHECK: %t3 = fsub fast float %t1, %t2
}