[ValueTracking] Fix computeKnownBits() with bitwidth-changing ptrtoint

computeKnownBitsFromAssume() currently asserts if m_V matches a
ptrtoint that changes the bitwidth. Because InstCombine
canonicalizes ptrtoint instructions to use explicit zext/trunc,
we never ran into the issue in practice. I'm adding unit tests,
as I don't know if this can be triggered via IR anywhere.

Fix this by calling anyextOrTrunc(BitWidth) on the computed
KnownBits. Note that we are going from the KnownBits of the
ptrtoint result to the KnownBits of the ptrtoint operand,
so we need to truncate if the ptrtoint zexted and anyext if
the ptrtoint truncated.

Differential Revision: https://reviews.llvm.org/D79234
This commit is contained in:
Nikita Popov 2020-05-01 15:19:41 +02:00
parent 3f66bb2017
commit d86fff6ae7
2 changed files with 89 additions and 47 deletions

View File

@ -785,6 +785,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
if (!Cmp)
continue;
// Note that ptrtoint may change the bitwidth.
Value *A, *B;
auto m_V = m_CombineOr(m_Specific(V), m_PtrToInt(m_Specific(V)));
@ -797,18 +798,18 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
// assume(v = a)
if (match(Cmp, m_c_ICmp(Pred, m_V, m_Value(A))) &&
isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
KnownBits RHSKnown(BitWidth);
computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
KnownBits RHSKnown =
computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
Known.Zero |= RHSKnown.Zero;
Known.One |= RHSKnown.One;
// assume(v & b = a)
} else if (match(Cmp,
m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A))) &&
isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
KnownBits RHSKnown(BitWidth);
computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
KnownBits MaskKnown(BitWidth);
computeKnownBits(B, MaskKnown, Depth+1, Query(Q, I));
KnownBits RHSKnown =
computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
KnownBits MaskKnown =
computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
// For those bits in the mask that are known to be one, we can propagate
// known bits from the RHS to V.
@ -818,10 +819,10 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
} else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_And(m_V, m_Value(B))),
m_Value(A))) &&
isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
KnownBits RHSKnown(BitWidth);
computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
KnownBits MaskKnown(BitWidth);
computeKnownBits(B, MaskKnown, Depth+1, Query(Q, I));
KnownBits RHSKnown =
computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
KnownBits MaskKnown =
computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
// For those bits in the mask that are known to be one, we can propagate
// inverted known bits from the RHS to V.
@ -831,10 +832,10 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
} else if (match(Cmp,
m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A))) &&
isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
KnownBits RHSKnown(BitWidth);
computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
KnownBits BKnown(BitWidth);
computeKnownBits(B, BKnown, Depth+1, Query(Q, I));
KnownBits RHSKnown =
computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
KnownBits BKnown =
computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
// For those bits in B that are known to be zero, we can propagate known
// bits from the RHS to V.
@ -844,10 +845,10 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
} else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_Or(m_V, m_Value(B))),
m_Value(A))) &&
isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
KnownBits RHSKnown(BitWidth);
computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
KnownBits BKnown(BitWidth);
computeKnownBits(B, BKnown, Depth+1, Query(Q, I));
KnownBits RHSKnown =
computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
KnownBits BKnown =
computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
// For those bits in B that are known to be zero, we can propagate
// inverted known bits from the RHS to V.
@ -857,10 +858,10 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
} else if (match(Cmp,
m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A))) &&
isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
KnownBits RHSKnown(BitWidth);
computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
KnownBits BKnown(BitWidth);
computeKnownBits(B, BKnown, Depth+1, Query(Q, I));
KnownBits RHSKnown =
computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
KnownBits BKnown =
computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
// For those bits in B that are known to be zero, we can propagate known
// bits from the RHS to V. For those bits in B that are known to be one,
@ -873,10 +874,10 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
} else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_Xor(m_V, m_Value(B))),
m_Value(A))) &&
isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
KnownBits RHSKnown(BitWidth);
computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
KnownBits BKnown(BitWidth);
computeKnownBits(B, BKnown, Depth+1, Query(Q, I));
KnownBits RHSKnown =
computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
KnownBits BKnown =
computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
// For those bits in B that are known to be zero, we can propagate
// inverted known bits from the RHS to V. For those bits in B that are
@ -889,8 +890,9 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
} else if (match(Cmp, m_c_ICmp(Pred, m_Shl(m_V, m_ConstantInt(C)),
m_Value(A))) &&
isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
KnownBits RHSKnown(BitWidth);
computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
KnownBits RHSKnown =
computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
// For those bits in RHS that are known, we can propagate them to known
// bits in V shifted to the right by C.
RHSKnown.Zero.lshrInPlace(C);
@ -901,8 +903,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
} else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_Shl(m_V, m_ConstantInt(C))),
m_Value(A))) &&
isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
KnownBits RHSKnown(BitWidth);
computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
KnownBits RHSKnown =
computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
// For those bits in RHS that are known, we can propagate them inverted
// to known bits in V shifted to the right by C.
RHSKnown.One.lshrInPlace(C);
@ -913,8 +915,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
} else if (match(Cmp, m_c_ICmp(Pred, m_Shr(m_V, m_ConstantInt(C)),
m_Value(A))) &&
isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
KnownBits RHSKnown(BitWidth);
computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
KnownBits RHSKnown =
computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
// For those bits in RHS that are known, we can propagate them to known
// bits in V shifted to the right by C.
Known.Zero |= RHSKnown.Zero << C;
@ -923,8 +925,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
} else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_Shr(m_V, m_ConstantInt(C))),
m_Value(A))) &&
isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
KnownBits RHSKnown(BitWidth);
computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
KnownBits RHSKnown =
computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
// For those bits in RHS that are known, we can propagate them inverted
// to known bits in V shifted to the right by C.
Known.Zero |= RHSKnown.One << C;
@ -935,8 +937,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
// assume(v >=_s c) where c is non-negative
if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
KnownBits RHSKnown(BitWidth);
computeKnownBits(A, RHSKnown, Depth + 1, Query(Q, I));
KnownBits RHSKnown =
computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth);
if (RHSKnown.isNonNegative()) {
// We know that the sign bit is zero.
@ -948,8 +950,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
// assume(v >_s c) where c is at least -1.
if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
KnownBits RHSKnown(BitWidth);
computeKnownBits(A, RHSKnown, Depth + 1, Query(Q, I));
KnownBits RHSKnown =
computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth);
if (RHSKnown.isAllOnes() || RHSKnown.isNonNegative()) {
// We know that the sign bit is zero.
@ -961,8 +963,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
// assume(v <=_s c) where c is negative
if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
KnownBits RHSKnown(BitWidth);
computeKnownBits(A, RHSKnown, Depth + 1, Query(Q, I));
KnownBits RHSKnown =
computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth);
if (RHSKnown.isNegative()) {
// We know that the sign bit is one.
@ -974,8 +976,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
// assume(v <_s c) where c is non-positive
if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
KnownBits RHSKnown(BitWidth);
computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
KnownBits RHSKnown =
computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
if (RHSKnown.isZero() || RHSKnown.isNegative()) {
// We know that the sign bit is one.
@ -987,8 +989,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
// assume(v <=_u c)
if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
KnownBits RHSKnown(BitWidth);
computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
KnownBits RHSKnown =
computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
// Whatever high bits in c are zero are known to be zero.
Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros());
@ -998,8 +1000,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
// assume(v <_u c)
if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
KnownBits RHSKnown(BitWidth);
computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
KnownBits RHSKnown =
computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
// If the RHS is known zero, then this assumption must be wrong (nothing
// is unsigned less than zero). Signal a conflict and get out of here.

View File

@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
@ -40,7 +41,7 @@ protected:
M = parseModule(Assembly);
ASSERT_TRUE(M);
Function *F = M->getFunction("test");
F = M->getFunction("test");
ASSERT_TRUE(F) << "Test must have a function @test";
if (!F)
return;
@ -57,6 +58,7 @@ protected:
LLVMContext Context;
std::unique_ptr<Module> M;
Function *F = nullptr;
Instruction *A = nullptr;
};
@ -954,6 +956,44 @@ TEST_F(ComputeKnownBitsTest, ComputeKnownUSubSatZerosPreserved) {
expectKnownBits(/*zero*/ 2u, /*one*/ 0u);
}
TEST_F(ComputeKnownBitsTest, ComputeKnownBitsPtrToIntTrunc) {
// ptrtoint truncates the pointer type.
parseAssembly(
"define void @test(i8** %p) {\n"
" %A = load i8*, i8** %p\n"
" %i = ptrtoint i8* %A to i32\n"
" %m = and i32 %i, 31\n"
" %c = icmp eq i32 %m, 0\n"
" call void @llvm.assume(i1 %c)\n"
" ret void\n"
"}\n"
"declare void @llvm.assume(i1)\n");
AssumptionCache AC(*F);
KnownBits Known = computeKnownBits(
A, M->getDataLayout(), /* Depth */ 0, &AC, F->front().getTerminator());
EXPECT_EQ(Known.Zero.getZExtValue(), 31u);
EXPECT_EQ(Known.One.getZExtValue(), 0u);
}
TEST_F(ComputeKnownBitsTest, ComputeKnownBitsPtrToIntZext) {
// ptrtoint zero extends the pointer type.
parseAssembly(
"define void @test(i8** %p) {\n"
" %A = load i8*, i8** %p\n"
" %i = ptrtoint i8* %A to i128\n"
" %m = and i128 %i, 31\n"
" %c = icmp eq i128 %m, 0\n"
" call void @llvm.assume(i1 %c)\n"
" ret void\n"
"}\n"
"declare void @llvm.assume(i1)\n");
AssumptionCache AC(*F);
KnownBits Known = computeKnownBits(
A, M->getDataLayout(), /* Depth */ 0, &AC, F->front().getTerminator());
EXPECT_EQ(Known.Zero.getZExtValue(), 31u);
EXPECT_EQ(Known.One.getZExtValue(), 0u);
}
class IsBytewiseValueTest : public ValueTrackingTest,
public ::testing::WithParamInterface<
std::pair<const char *, const char *>> {