[Local] collectBitParts - bail out if we find more than one root input value.

All the uses that we have for collectBitParts revolve around us matching down to an operation with a single root value - I don't think we're intending to change that (and a lot of collectBitParts assumes it).

The binops cases (OR/FSHL/FSHR) already check if the providers are the same, but that would still mean we waste time collecting through unaryops before getting to them.
This commit is contained in:
Simon Pilgrim 2021-05-15 13:58:42 +01:00
parent 401d6685c0
commit f0660a977e
1 changed files with 34 additions and 25 deletions

View File

@ -2879,7 +2879,8 @@ struct BitPart {
/// does not invalidate internal references (std::map instead of DenseMap).
static const Optional<BitPart> &
collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
std::map<Value *, Optional<BitPart>> &BPS, int Depth) {
std::map<Value *, Optional<BitPart>> &BPS, int Depth,
bool &FoundRoot) {
auto I = BPS.find(V);
if (I != BPS.end())
return I->second;
@ -2904,13 +2905,13 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
// If this is an or instruction, it may be an inner node of the bswap.
if (match(V, m_Or(m_Value(X), m_Value(Y)))) {
// Check we have both sources and they are from the same provider.
const auto &A =
collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
const auto &A = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS,
Depth + 1, FoundRoot);
if (!A || !A->Provider)
return Result;
const auto &B =
collectBitParts(Y, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
const auto &B = collectBitParts(Y, MatchBSwaps, MatchBitReversals, BPS,
Depth + 1, FoundRoot);
if (!B || A->Provider != B->Provider)
return Result;
@ -2943,8 +2944,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
if (!MatchBitReversals && (BitShift.getZExtValue() % 8) != 0)
return Result;
const auto &Res =
collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
const auto &Res = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS,
Depth + 1, FoundRoot);
if (!Res)
return Result;
Result = Res;
@ -2973,8 +2974,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
if (!MatchBitReversals && (NumMaskedBits % 8) != 0)
return Result;
const auto &Res =
collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
const auto &Res = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS,
Depth + 1, FoundRoot);
if (!Res)
return Result;
Result = Res;
@ -2988,8 +2989,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
// If this is a zext instruction zero extend the result.
if (match(V, m_ZExt(m_Value(X)))) {
const auto &Res =
collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
const auto &Res = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS,
Depth + 1, FoundRoot);
if (!Res)
return Result;
@ -3004,8 +3005,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
// If this is a truncate instruction, extract the lower bits.
if (match(V, m_Trunc(m_Value(X)))) {
const auto &Res =
collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
const auto &Res = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS,
Depth + 1, FoundRoot);
if (!Res)
return Result;
@ -3018,8 +3019,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
// BITREVERSE - most likely due to us previous matching a partial
// bitreverse.
if (match(V, m_BitReverse(m_Value(X)))) {
const auto &Res =
collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
const auto &Res = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS,
Depth + 1, FoundRoot);
if (!Res)
return Result;
@ -3031,8 +3032,8 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
// BSWAP - most likely due to us previous matching a partial bswap.
if (match(V, m_BSwap(m_Value(X)))) {
const auto &Res =
collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
const auto &Res = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS,
Depth + 1, FoundRoot);
if (!Res)
return Result;
@ -3063,13 +3064,13 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
return Result;
// Check we have both sources and they are from the same provider.
const auto &LHS =
collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
const auto &LHS = collectBitParts(X, MatchBSwaps, MatchBitReversals, BPS,
Depth + 1, FoundRoot);
if (!LHS || !LHS->Provider)
return Result;
return Result;
const auto &RHS =
collectBitParts(Y, MatchBSwaps, MatchBitReversals, BPS, Depth + 1);
const auto &RHS = collectBitParts(Y, MatchBSwaps, MatchBitReversals, BPS,
Depth + 1, FoundRoot);
if (!RHS || LHS->Provider != RHS->Provider)
return Result;
@ -3083,8 +3084,14 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
}
}
// Okay, we got to something that isn't a shift, 'or' or 'and'. This must be
// the input value to the bswap/bitreverse.
// If we've already found a root input value then we're never going to merge
// these back together.
if (FoundRoot)
return Result;
// Okay, we got to something that isn't a shift, 'or', 'and', etc. This must
// be the root input value to the bswap/bitreverse.
FoundRoot = true;
Result = BitPart(V, BitWidth);
for (unsigned BitIdx = 0; BitIdx < BitWidth; ++BitIdx)
Result->Provenance[BitIdx] = BitIdx;
@ -3126,8 +3133,10 @@ bool llvm::recognizeBSwapOrBitReverseIdiom(
DemandedTy = Trunc->getType();
// Try to find all the pieces corresponding to the bswap.
bool FoundRoot = false;
std::map<Value *, Optional<BitPart>> BPS;
auto Res = collectBitParts(I, MatchBSwaps, MatchBitReversals, BPS, 0);
const auto &Res =
collectBitParts(I, MatchBSwaps, MatchBitReversals, BPS, 0, FoundRoot);
if (!Res)
return false;
ArrayRef<int8_t> BitProvenance = Res->Provenance;