diff --git a/llvm/lib/Transforms/Scalar/InstructionCombining.cpp b/llvm/lib/Transforms/Scalar/InstructionCombining.cpp index 2f2d32ceab63..d041083f040c 100644 --- a/llvm/lib/Transforms/Scalar/InstructionCombining.cpp +++ b/llvm/lib/Transforms/Scalar/InstructionCombining.cpp @@ -489,7 +489,50 @@ struct AddMaskingAnd { } }; +static Value *FoldOperationIntoSelectOperand(Instruction &BI, Value *SO, + InstCombiner *IC) { + // Figure out if the constant is the left or the right argument. + bool ConstIsRHS = isa(BI.getOperand(1)); + Constant *ConstOperand = cast(BI.getOperand(ConstIsRHS)); + if (Constant *SOC = dyn_cast(SO)) { + if (ConstIsRHS) + return ConstantExpr::get(BI.getOpcode(), SOC, ConstOperand); + return ConstantExpr::get(BI.getOpcode(), ConstOperand, SOC); + } + + Value *Op0 = SO, *Op1 = ConstOperand; + if (!ConstIsRHS) + std::swap(Op0, Op1); + Instruction *New; + if (BinaryOperator *BO = dyn_cast(&BI)) + New = BinaryOperator::create(BO->getOpcode(), Op0, Op1); + else if (ShiftInst *SI = dyn_cast(&BI)) + New = new ShiftInst(SI->getOpcode(), Op0, Op1); + else + assert(0 && "Unknown binary instruction type!"); + return IC->InsertNewInstBefore(New, BI); +} + +// FoldBinOpIntoSelect - Given an instruction with a select as one operand and a +// constant as the other operand, try to fold the binary operator into the +// select arguments. +static Instruction *FoldBinOpIntoSelect(Instruction &BI, SelectInst *SI, + InstCombiner *IC) { + // Don't modify shared select instructions + if (!SI->hasOneUse()) return 0; + Value *TV = SI->getOperand(1); + Value *FV = SI->getOperand(2); + + if (isa(TV) || isa(FV)) { + Value *SelectTrueVal = FoldOperationIntoSelectOperand(BI, TV, IC); + Value *SelectFalseVal = FoldOperationIntoSelectOperand(BI, FV, IC); + + return new SelectInst(SI->getCondition(), SelectTrueVal, + SelectFalseVal); + } + return 0; +} Instruction *InstCombiner::visitAdd(BinaryOperator &I) { bool Changed = SimplifyCommutative(I); @@ -547,6 +590,11 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { CRHS, ConstantInt::get(I.getType(), 1)), ILHS->getOperand(0)); break; + case Instruction::Select: + // Try to fold constant add into select arguments. + if (Instruction *R = FoldBinOpIntoSelect(I,cast(ILHS),this)) + return R; + default: break; } } @@ -632,6 +680,11 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { } } } + + // Try to fold constant sub into select arguments. + if (SelectInst *SI = dyn_cast(Op1)) + if (Instruction *R = FoldBinOpIntoSelect(I, SI, this)) + return R; } if (BinaryOperator *Op1I = dyn_cast(Op1)) @@ -740,6 +793,11 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (Op1F->getValue() == 1.0) return ReplaceInstUsesWith(I, Op0); // Eliminate 'mul double %X, 1.0' } + + // Try to fold constant mul into select arguments. + if (SelectInst *SI = dyn_cast(Op0)) + if (Instruction *R = FoldBinOpIntoSelect(I, SI, this)) + return R; } if (Value *Op0v = dyn_castNegVal(Op0)) // -X * -Y = X*Y @@ -1093,6 +1151,11 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (Instruction *Res = OptAndOp(Op0I, Op0CI, RHS, I)) return Res; } + + // Try to fold constant and into select arguments. + if (SelectInst *SI = dyn_cast(Op0)) + if (Instruction *R = FoldBinOpIntoSelect(I, SI, this)) + return R; } Value *Op0NotVal = dyn_castNotVal(Op0); @@ -1158,6 +1221,11 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { NotConstant(RHS))); } } + + // Try to fold constant and into select arguments. + if (SelectInst *SI = dyn_cast(Op0)) + if (Instruction *R = FoldBinOpIntoSelect(I, SI, this)) + return R; } // (A & C1)|(A & C2) == A & (C1|C2) @@ -1268,6 +1336,11 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { default: break; } } + + // Try to fold constant and into select arguments. + if (SelectInst *SI = dyn_cast(Op0)) + if (Instruction *R = FoldBinOpIntoSelect(I, SI, this)) + return R; } if (Value *X = dyn_castNotVal(Op0)) // ~A ^ A == -1 @@ -1668,6 +1741,12 @@ Instruction *InstCombiner::visitShiftInst(ShiftInst &I) { if (CSI->isAllOnesValue()) return ReplaceInstUsesWith(I, CSI); + // Try to fold constant and into select arguments. + if (isa(Op0)) + if (SelectInst *SI = dyn_cast(Op1)) + if (Instruction *R = FoldBinOpIntoSelect(I, SI, this)) + return R; + if (ConstantUInt *CUI = dyn_cast(Op1)) { // shl uint X, 32 = 0 and shr ubyte Y, 9 = 0, ... just don't eliminate shr // of a signed value. @@ -1689,6 +1768,10 @@ Instruction *InstCombiner::visitShiftInst(ShiftInst &I) { return BinaryOperator::create(Instruction::Mul, BO->getOperand(0), ConstantExpr::get(Instruction::Shl, BOOp, CUI)); + // Try to fold constant and into select arguments. + if (SelectInst *SI = dyn_cast(Op0)) + if (Instruction *R = FoldBinOpIntoSelect(I, SI, this)) + return R; // If the operand is an bitwise operator with a constant RHS, and the // shift is the only use, we can pull it out of the shift. @@ -2052,22 +2135,18 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } } - // Selecting between two constants? - if (Constant *TrueValC = dyn_cast(TrueVal)) - if (Constant *FalseValC = dyn_cast(FalseVal)) { - if (SI.getType()->isInteger()) { - // select C, 1, 0 -> cast C to int - if (FalseValC->isNullValue() && isa(TrueValC) && - cast(TrueValC)->getRawValue() == 1) { - return new CastInst(CondVal, SI.getType()); - } else if (TrueValC->isNullValue() && isa(FalseValC) && - cast(FalseValC)->getRawValue() == 1) { - // select C, 0, 1 -> cast !C to int - Value *NotCond = - InsertNewInstBefore(BinaryOperator::createNot(CondVal, + // Selecting between two integer constants? + if (ConstantInt *TrueValC = dyn_cast(TrueVal)) + if (ConstantInt *FalseValC = dyn_cast(FalseVal)) { + // select C, 1, 0 -> cast C to int + if (FalseValC->isNullValue() && TrueValC->getRawValue() == 1) { + return new CastInst(CondVal, SI.getType()); + } else if (TrueValC->isNullValue() && FalseValC->getRawValue() == 1) { + // select C, 0, 1 -> cast !C to int + Value *NotCond = + InsertNewInstBefore(BinaryOperator::createNot(CondVal, "not."+CondVal->getName()), SI); - return new CastInst(NotCond, SI.getType()); - } + return new CastInst(NotCond, SI.getType()); } }