[ValueTracking] don't recursively compute known bits using multiple llvm.assumes

This is an alternative to D99759 to avoid the compile-time explosion seen in:
https://llvm.org/PR49785

Another potential solution would be to make the exclusion logic stronger to
avoid blowing up, but note that we reduced the complexity of the exclusion
mechanism in D16204 because it was too costly.

So I'm questioning the need for recursion/exclusion entirely: what is the
optimization value vs. the compile-time cost of recursively computing known
bits based on assumptions?
This was built into the implementation from the start with 60db058,
and we have kept adding code/cost to deal with that capability.
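
To make the cost concrete, here is a minimal standalone C++ model of that
recursion; it is a sketch, not LLVM code, and every name in it is a stand-in
(MaxDepth for MaxAnalysisRecursionDepth, NumAssumes for the number of assumes
that mention the queried value). Each known-bits query consults every
applicable assume, and each assume re-queries the other operand, so one
top-level query fans out into roughly NumAssumes^MaxDepth calls:

  #include <cstdio>

  static long Calls = 0;
  static const int MaxDepth = 6;   // stand-in for MaxAnalysisRecursionDepth
  static const int NumAssumes = 8; // assumes that mention the queried value

  static void computeKnownBitsFromAssume(int Depth);

  // A known-bits query consults the assumptions (think assume(x == y)),
  // and each assumption re-queries known bits of its other operand.
  static void computeKnownBits(int Depth) {
    ++Calls;
    if (Depth >= MaxDepth)
      return;
    computeKnownBitsFromAssume(Depth);
  }

  static void computeKnownBitsFromAssume(int Depth) {
    for (int I = 0; I < NumAssumes; ++I)
      computeKnownBits(Depth + 1);
  }

  int main() {
    computeKnownBits(0);
    printf("calls: %ld\n", Calls); // prints 299593 for one top-level query
  }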

By clearing the AssumptionCache pointer on the Query used for the recursive
calls inside computeKnownBitsFromAssume(), this patch retains all existing
assume functionality except the ability to refine known bits based on even
more assumptions.
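
As a rough sketch of that effect on the same standalone model as above (again
with hypothetical stand-in types, not the actual LLVM implementation),
dropping the cache on the copied query cuts the recursion off after one level:

  #include <cstdio>

  // Hypothetical stand-ins for the real types in ValueTracking.cpp.
  struct AssumptionCache {};
  struct Query { AssumptionCache *AC = nullptr; };

  static long Calls = 0;
  static const int NumAssumes = 8;

  static void computeKnownBitsFromAssume(int Depth, Query Q);

  static void computeKnownBits(int Depth, Query Q) {
    ++Calls;
    computeKnownBitsFromAssume(Depth, Q);
  }

  static void computeKnownBitsFromAssume(int Depth, Query Q) {
    if (!Q.AC) // recursive queries carry no assumption cache: stop here
      return;
    Query QueryNoAC = Q;    // copy the query...
    QueryNoAC.AC = nullptr; // ...and drop the cache, as this patch does
    for (int I = 0; I < NumAssumes; ++I)
      computeKnownBits(Depth + 1, QueryNoAC);
  }

  int main() {
    AssumptionCache AC;
    computeKnownBits(0, Query{&AC});
    printf("calls: %ld\n", Calls); // prints 9: 1 query + NumAssumes recursions
  }

The top-level query still sees every assume; only the nested queries made
while interpreting an assume lose access to the cache.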

We have one regression test that shows a difference in optimization power.

Differential Revision: https://reviews.llvm.org/D100573
Author: Sanjay Patel
Date:   2021-04-16 08:39:22 -04:00
Parent: b06c55a698
Commit: bb907b26e2

2 changed files with 42 additions and 59 deletions

@@ -107,40 +107,13 @@ struct Query {
   // provide it currently.
   OptimizationRemarkEmitter *ORE;
 
-  /// Set of assumptions that should be excluded from further queries.
-  /// This is because of the potential for mutual recursion to cause
-  /// computeKnownBits to repeatedly visit the same assume intrinsic. The
-  /// classic case of this is assume(x = y), which will attempt to determine
-  /// bits in x from bits in y, which will attempt to determine bits in y from
-  /// bits in x, etc. Regarding the mutual recursion, computeKnownBits can call
-  /// isKnownNonZero, which calls computeKnownBits and isKnownToBeAPowerOfTwo
-  /// (all of which can call computeKnownBits), and so on.
-  std::array<const Value *, MaxAnalysisRecursionDepth> Excluded;
-
   /// If true, it is safe to use metadata during simplification.
   InstrInfoQuery IIQ;
 
-  unsigned NumExcluded = 0;
-
   Query(const DataLayout &DL, AssumptionCache *AC, const Instruction *CxtI,
         const DominatorTree *DT, bool UseInstrInfo,
         OptimizationRemarkEmitter *ORE = nullptr)
       : DL(DL), AC(AC), CxtI(CxtI), DT(DT), ORE(ORE), IIQ(UseInstrInfo) {}
-
-  Query(const Query &Q, const Value *NewExcl)
-      : DL(Q.DL), AC(Q.AC), CxtI(Q.CxtI), DT(Q.DT), ORE(Q.ORE), IIQ(Q.IIQ),
-        NumExcluded(Q.NumExcluded) {
-    Excluded = Q.Excluded;
-    Excluded[NumExcluded++] = NewExcl;
-    assert(NumExcluded <= Excluded.size());
-  }
-
-  bool isExcluded(const Value *Value) const {
-    if (NumExcluded == 0)
-      return false;
-    auto End = Excluded.begin() + NumExcluded;
-    return std::find(Excluded.begin(), End, Value) != End;
-  }
 };
 
 } // end anonymous namespace
@@ -632,8 +605,6 @@ static bool isKnownNonZeroFromAssume(const Value *V, const Query &Q) {
     CallInst *I = cast<CallInst>(AssumeVH);
     assert(I->getFunction() == Q.CxtI->getFunction() &&
            "Got assumption for the wrong function!");
-    if (Q.isExcluded(I))
-      continue;
 
     // Warning: This loop can end up being somewhat performance sensitive.
     // We're running this loop for once for each value queried resulting in a
@@ -681,8 +652,6 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
     CallInst *I = cast<CallInst>(AssumeVH);
     assert(I->getParent()->getParent() == Q.CxtI->getParent()->getParent() &&
            "Got assumption for the wrong function!");
-    if (Q.isExcluded(I))
-      continue;
 
     // Warning: This loop can end up being somewhat performance sensitive.
     // We're running this loop for once for each value queried resulting in a
@@ -713,6 +682,15 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
     if (!Cmp)
       continue;
 
+    // We are attempting to compute known bits for the operands of an assume.
+    // Do not try to use other assumptions for those recursive calls because
+    // that can lead to mutual recursion and a compile-time explosion.
+    // An example of the mutual recursion: computeKnownBits can call
+    // isKnownNonZero which calls computeKnownBitsFromAssume (this function)
+    // and so on.
+    Query QueryNoAC = Q;
+    QueryNoAC.AC = nullptr;
+
     // Note that ptrtoint may change the bitwidth.
     Value *A, *B;
     auto m_V = m_CombineOr(m_Specific(V), m_PtrToInt(m_Specific(V)));
@@ -727,7 +705,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
     if (match(Cmp, m_c_ICmp(Pred, m_V, m_Value(A))) &&
         isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
       KnownBits RHSKnown =
-          computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+          computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
       Known.Zero |= RHSKnown.Zero;
       Known.One |= RHSKnown.One;
       // assume(v & b = a)
@@ -735,9 +713,9 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                      m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A))) &&
                isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
       KnownBits RHSKnown =
-          computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+          computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
       KnownBits MaskKnown =
-          computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+          computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
       // For those bits in the mask that are known to be one, we can propagate
       // known bits from the RHS to V.
@@ -748,9 +726,9 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                          m_Value(A))) &&
                isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
       KnownBits RHSKnown =
-          computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+          computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
       KnownBits MaskKnown =
-          computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+          computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
       // For those bits in the mask that are known to be one, we can propagate
       // inverted known bits from the RHS to V.
@@ -761,9 +739,9 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                      m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A))) &&
                isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
       KnownBits RHSKnown =
-          computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+          computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
       KnownBits BKnown =
-          computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+          computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
       // For those bits in B that are known to be zero, we can propagate known
       // bits from the RHS to V.
@@ -774,9 +752,9 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                          m_Value(A))) &&
                isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
       KnownBits RHSKnown =
-          computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+          computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
       KnownBits BKnown =
-          computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+          computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
       // For those bits in B that are known to be zero, we can propagate
       // inverted known bits from the RHS to V.
@@ -787,9 +765,9 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                      m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A))) &&
                isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
       KnownBits RHSKnown =
-          computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+          computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
       KnownBits BKnown =
-          computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+          computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
       // For those bits in B that are known to be zero, we can propagate known
       // bits from the RHS to V. For those bits in B that are known to be one,
@@ -803,9 +781,9 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                          m_Value(A))) &&
                isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
       KnownBits RHSKnown =
-          computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+          computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
       KnownBits BKnown =
-          computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+          computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
       // For those bits in B that are known to be zero, we can propagate
       // inverted known bits from the RHS to V. For those bits in B that are
@@ -819,7 +797,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                          m_Value(A))) &&
                isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
       KnownBits RHSKnown =
-          computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+          computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
       // For those bits in RHS that are known, we can propagate them to known
       // bits in V shifted to the right by C.
@@ -832,7 +810,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                          m_Value(A))) &&
                isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
       KnownBits RHSKnown =
-          computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+          computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
       // For those bits in RHS that are known, we can propagate them inverted
       // to known bits in V shifted to the right by C.
       RHSKnown.One.lshrInPlace(C);
@@ -844,7 +822,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                          m_Value(A))) &&
                isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
       KnownBits RHSKnown =
-          computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+          computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
       // For those bits in RHS that are known, we can propagate them to known
       // bits in V shifted to the right by C.
       Known.Zero |= RHSKnown.Zero << C;
@@ -854,7 +832,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
                          m_Value(A))) &&
                isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
       KnownBits RHSKnown =
-          computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+          computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
       // For those bits in RHS that are known, we can propagate them inverted
       // to known bits in V shifted to the right by C.
       Known.Zero |= RHSKnown.One << C;
@@ -866,7 +844,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
     if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
         isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
       KnownBits RHSKnown =
-          computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth);
+          computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
 
       if (RHSKnown.isNonNegative()) {
         // We know that the sign bit is zero.
@@ -879,7 +857,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
     if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
         isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
       KnownBits RHSKnown =
-          computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth);
+          computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
 
       if (RHSKnown.isAllOnes() || RHSKnown.isNonNegative()) {
         // We know that the sign bit is zero.
@@ -892,7 +870,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
     if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
         isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
      KnownBits RHSKnown =
-          computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth);
+          computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
 
       if (RHSKnown.isNegative()) {
         // We know that the sign bit is one.
@@ -905,7 +883,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
     if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
         isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
       KnownBits RHSKnown =
-          computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+          computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
       if (RHSKnown.isZero() || RHSKnown.isNegative()) {
         // We know that the sign bit is one.
@@ -918,7 +896,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
     if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
         isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
       KnownBits RHSKnown =
-          computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+          computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
       // Whatever high bits in c are zero are known to be zero.
       Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros());
@@ -929,7 +907,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
     if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
         isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
       KnownBits RHSKnown =
-          computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+          computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
       // If the RHS is known zero, then this assumption must be wrong (nothing
       // is unsigned less than zero). Signal a conflict and get out of here.
@@ -941,7 +919,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
 
       // Whatever high bits in c are zero are known to be zero (if c is a power
       // of 2, then one more).
-      if (isKnownToBeAPowerOfTwo(A, false, Depth + 1, Query(Q, I)))
+      if (isKnownToBeAPowerOfTwo(A, false, Depth + 1, QueryNoAC))
         Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros() + 1);
       else
         Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros());

@@ -175,15 +175,20 @@ entry:
   ret i32 %and1
 }
 
-define i32 @bar4(i32 %a, i32 %b) {
-; CHECK-LABEL: @bar4(
+; If we allow recursive known bits queries based on
+; assumptions, we could do better here:
+; a == b and a & 7 == 1, so b & 7 == 1, so b & 3 == 1, so return 1.
+
+define i32 @known_bits_recursion_via_assumes(i32 %a, i32 %b) {
+; CHECK-LABEL: @known_bits_recursion_via_assumes(
 ; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[AND1:%.*]] = and i32 [[B:%.*]], 3
 ; CHECK-NEXT:    [[AND:%.*]] = and i32 [[A:%.*]], 7
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[AND]], 1
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i32 [[A]], [[B:%.*]]
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i32 [[A]], [[B]]
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP2]])
-; CHECK-NEXT:    ret i32 1
+; CHECK-NEXT:    ret i32 [[AND1]]
 ;
 entry:
   %and1 = and i32 %b, 3