diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index d0675b77444b..ad2bd1841f1f 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -1498,39 +1498,47 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp, BinaryOperator *And, const APInt *C1) { - // FIXME: This check restricts all folds under here to scalar types. - ConstantInt *RHS = dyn_cast(Cmp.getOperand(1)); - if (!RHS) - return nullptr; - - // FIXME: Use m_APInt. - auto *C2 = dyn_cast(And->getOperand(1)); - if (!C2) + const APInt *C2; + if (!match(And->getOperand(1), m_APInt(C2))) return nullptr; if (!And->hasOneUse() || !And->getOperand(0)->hasOneUse()) return nullptr; - // If the LHS is an AND of a truncating cast, we can widen the and/compare to - // be the input width without changing the value produced, eliminating a cast. - if (TruncInst *Cast = dyn_cast(And->getOperand(0))) { - // We can do this transformation if either the AND constant does not have - // its sign bit set or if it is an equality comparison. Extending a - // relational comparison when we're checking the sign bit would not work. - if (Cmp.isEquality() || (!C2->isNegative() && C1->isNonNegative())) { - Value *NewAnd = Builder->CreateAnd( - Cast->getOperand(0), ConstantExpr::getZExt(C2, Cast->getSrcTy())); - NewAnd->takeName(And); - return new ICmpInst(Cmp.getPredicate(), NewAnd, - ConstantExpr::getZExt(RHS, Cast->getSrcTy())); + // If the LHS is an 'and' of a truncate and we can widen the and/compare to + // the input width without changing the value produced, eliminate the cast: + // + // icmp (and (trunc W), C2), C1 -> icmp (and W, C2'), C1' + // + // We can do this transformation if the constants do not have their sign bits + // set or if it is an equality comparison. Extending a relational comparison + // when we're checking the sign bit would not work. + Value *W; + if (match(And->getOperand(0), m_Trunc(m_Value(W))) && + (Cmp.isEquality() || (!C1->isNegative() && !C2->isNegative()))) { + // TODO: Is this a good transform for vectors? Wider types may reduce + // throughput. Should this transform be limited (even for scalars) by using + // ShouldChangeType()? + if (!Cmp.getType()->isVectorTy()) { + Type *WideType = W->getType(); + unsigned WideScalarBits = WideType->getScalarSizeInBits(); + Constant *ZextC1 = ConstantInt::get(WideType, C1->zext(WideScalarBits)); + Constant *ZextC2 = ConstantInt::get(WideType, C2->zext(WideScalarBits)); + Value *NewAnd = Builder->CreateAnd(W, ZextC2, And->getName()); + return new ICmpInst(Cmp.getPredicate(), NewAnd, ZextC1); } } if (Instruction *I = foldICmpAndShift(Cmp, And, C1)) return I; + // FIXME: This check restricts all folds under here to scalar types. + ConstantInt *RHS = dyn_cast(Cmp.getOperand(1)); + if (!RHS) + return nullptr; + // (icmp pred (and (or (lshr A, B), A), 1), 0) --> - // (icmp pred (and A, (or (shl 1, B), 1), 0)) + // (icmp pred (and A, (or (shl 1, B), 1), 0)) // // iff pred isn't signed { @@ -1573,7 +1581,7 @@ Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp, // Replace ((X & C2) > C1) with ((X & C2) != 0), if any bit set in (X & C2) // will produce a result greater than C1. if (Cmp.getPredicate() == ICmpInst::ICMP_UGT) { - unsigned NTZ = C2->getValue().countTrailingZeros(); + unsigned NTZ = C2->countTrailingZeros(); if ((NTZ < C2->getBitWidth()) && APInt::getOneBitSet(C2->getBitWidth(), NTZ).ugt(*C1)) return new ICmpInst(ICmpInst::ICMP_NE, And,