diff --git a/llvm/lib/Transforms/Scalar/InstructionCombining.cpp b/llvm/lib/Transforms/Scalar/InstructionCombining.cpp index 0d23f44799ee..bcc17f1b98a3 100644 --- a/llvm/lib/Transforms/Scalar/InstructionCombining.cpp +++ b/llvm/lib/Transforms/Scalar/InstructionCombining.cpp @@ -405,65 +405,66 @@ static ConstantInt *SubOne(ConstantInt *C) { ConstantInt::get(C->getType(), 1))); } -/// MaskedValueIsZero - Return true if 'V & Mask' is known to be zero. We use -/// this predicate to simplify operations downstream. Mask is known to be zero -/// for bits that V cannot have. -static bool MaskedValueIsZero(Value *V, uint64_t Mask, unsigned Depth = 0) { +/// ComputeMaskedNonZeroBits - Determine which of the bits specified in Mask are +/// not known to be zero and return them as a bitmask. The bits that we can +/// guarantee to be zero are returned as zero bits in the result. +static uint64_t ComputeMaskedNonZeroBits(Value *V, uint64_t Mask, + unsigned Depth = 0) { // Note, we cannot consider 'undef' to be "IsZero" here. The problem is that // we cannot optimize based on the assumption that it is zero without changing // it to be an explicit zero. If we don't change it to zero, other code could // optimized based on the contradictory assumption that it is non-zero. // Because instcombine aggressively folds operations with undef args anyway, // this won't lose us code quality. - if (Mask == 0) - return true; if (ConstantIntegral *CI = dyn_cast(V)) - return (CI->getRawValue() & Mask) == 0; - - if (Depth == 6) return false; // Limit search depth. + return CI->getRawValue() & Mask; + if (Depth == 6 || Mask == 0) + return Mask; // Limit search depth. if (Instruction *I = dyn_cast(V)) { switch (I->getOpcode()) { case Instruction::And: // (X & C1) & C2 == 0 iff C1 & C2 == 0. if (ConstantIntegral *CI = dyn_cast(I->getOperand(1))) - return MaskedValueIsZero(I->getOperand(0), CI->getRawValue() & Mask, - Depth+1); + return ComputeMaskedNonZeroBits(I->getOperand(0), + CI->getRawValue() & Mask, Depth+1); // If either the LHS or the RHS are MaskedValueIsZero, the result is zero. - return MaskedValueIsZero(I->getOperand(1), Mask, Depth+1) || - MaskedValueIsZero(I->getOperand(0), Mask, Depth+1); + Mask = ComputeMaskedNonZeroBits(I->getOperand(1), Mask, Depth+1); + Mask = ComputeMaskedNonZeroBits(I->getOperand(0), Mask, Depth+1); + return Mask; case Instruction::Or: case Instruction::Xor: - // If the LHS and the RHS are MaskedValueIsZero, the result is also zero. - return MaskedValueIsZero(I->getOperand(1), Mask, Depth+1) && - MaskedValueIsZero(I->getOperand(0), Mask, Depth+1); + // Any non-zero bits in the LHS or RHS are potentially non-zero in the + // result. + return ComputeMaskedNonZeroBits(I->getOperand(1), Mask, Depth+1) | + ComputeMaskedNonZeroBits(I->getOperand(0), Mask, Depth+1); case Instruction::Select: - // If the T and F values are MaskedValueIsZero, the result is also zero. - return MaskedValueIsZero(I->getOperand(2), Mask, Depth+1) && - MaskedValueIsZero(I->getOperand(1), Mask, Depth+1); + // Any non-zero bits in the T or F values are potentially non-zero in the + // result. + return ComputeMaskedNonZeroBits(I->getOperand(2), Mask, Depth+1) | + ComputeMaskedNonZeroBits(I->getOperand(1), Mask, Depth+1); case Instruction::Cast: { const Type *SrcTy = I->getOperand(0)->getType(); if (SrcTy == Type::BoolTy) - return (Mask & 1) == 0; - if (!SrcTy->isInteger()) return false; + return ComputeMaskedNonZeroBits(I->getOperand(0), Mask & 1, Depth+1); + if (!SrcTy->isInteger()) return Mask; // (cast X to int) & C2 == 0 iff could not have contained C2. - if (SrcTy->isUnsigned()) // Only handle zero ext. - return MaskedValueIsZero(I->getOperand(0), - Mask & SrcTy->getIntegralTypeMask(), Depth+1); - - // If this is a noop or trunc cast, recurse. - if (SrcTy->getPrimitiveSizeInBits() >= - I->getType()->getPrimitiveSizeInBits()) - return MaskedValueIsZero(I->getOperand(0), - Mask & SrcTy->getIntegralTypeMask(), Depth+1); + if (SrcTy->isUnsigned() || // Only handle zero ext/trunc/noop + SrcTy->getPrimitiveSizeInBits() >= + I->getType()->getPrimitiveSizeInBits()) { + Mask &= SrcTy->getIntegralTypeMask(); + return ComputeMaskedNonZeroBits(I->getOperand(0), Mask, Depth+1); + } + + // FIXME: handle sext casts. break; } case Instruction::Shl: // (shl X, C1) & C2 == 0 iff (X & C2 >>u C1) == 0 if (ConstantUInt *SA = dyn_cast(I->getOperand(1))) - return MaskedValueIsZero(I->getOperand(0), Mask >> SA->getValue(), - Depth+1); + return ComputeMaskedNonZeroBits(I->getOperand(0),Mask >> SA->getValue(), + Depth+1); break; case Instruction::Shr: // (ushr X, C1) & C2 == 0 iff (-1 >> C1) & C2 == 0 @@ -471,13 +472,20 @@ static bool MaskedValueIsZero(Value *V, uint64_t Mask, unsigned Depth = 0) { if (I->getType()->isUnsigned()) { Mask <<= SA->getValue(); Mask &= I->getType()->getIntegralTypeMask(); - return MaskedValueIsZero(I->getOperand(0), Mask, Depth+1); + return ComputeMaskedNonZeroBits(I->getOperand(0), Mask, Depth+1); } break; } } - return false; + return Mask; +} + +/// MaskedValueIsZero - Return true if 'V & Mask' is known to be zero. We use +/// this predicate to simplify operations downstream. Mask is known to be zero +/// for bits that V cannot have. +static bool MaskedValueIsZero(Value *V, uint64_t Mask, unsigned Depth = 0) { + return ComputeMaskedNonZeroBits(V, Mask, Depth) == 0; } /// SimplifyDemandedBits - Look at V. At this point, we know that only the Mask @@ -493,7 +501,9 @@ bool InstCombiner::SimplifyDemandedBits(Value *V, uint64_t Mask, // just set the Mask to all bits. Mask = V->getType()->getIntegralTypeMask(); } else if (Mask == 0) { // Not demanding any bits from V. - return UpdateValueUsesWith(V, UndefValue::get(V->getType())); + if (V != UndefValue::get(V->getType())) + return UpdateValueUsesWith(V, UndefValue::get(V->getType())); + return false; } else if (Depth == 6) { // Limit search depth. return false; } @@ -509,15 +519,14 @@ bool InstCombiner::SimplifyDemandedBits(Value *V, uint64_t Mask, if (SimplifyDemandedBits(I->getOperand(0), RHS->getRawValue() & Mask, Depth+1)) return true; - if (~Mask & RHS->getRawValue()) { + if (~Mask & RHS->getZExtValue()) { // If this is producing any bits that are not needed, simplify the RHS. - if (I->getType()->isSigned()) { - int64_t Val = Mask & cast(RHS)->getValue(); - I->setOperand(1, ConstantSInt::get(I->getType(), Val)); - } else { - uint64_t Val = Mask & cast(RHS)->getValue(); - I->setOperand(1, ConstantUInt::get(I->getType(), Val)); - } + uint64_t Val = Mask & RHS->getZExtValue(); + Constant *RHS = + ConstantUInt::get(I->getType()->getUnsignedVersion(), Val); + if (I->getType()->isSigned()) + RHS = ConstantExpr::getCast(RHS, I->getType()); + I->setOperand(1, RHS); return UpdateValueUsesWith(I, I); } } @@ -833,7 +842,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { // X + (signbit) --> X ^ signbit if (ConstantInt *CI = dyn_cast(RHSC)) { - uint64_t Val = CI->getRawValue() & CI->getType()->getIntegralTypeMask(); + uint64_t Val = CI->getZExtValue(); if (Val == (1ULL << (CI->getType()->getPrimitiveSizeInBits()-1))) return BinaryOperator::createXor(LHS, RHS); }