[SVE] Fix invalid uses of VectorType::getNumElements() in ValueTracking

Summary:
Any function in this module that make use of DemandedElts laregely does
not work with scalable vectors. DemandedElts is used to define which
elements of the vector to look at. At best, for scalable vectors, we can
express the first N elements of the vector. However, in practice, most
code that uses these functions expect to be able to talk about the
entire vector. In principle, this module should be able to be extended
to work with scalable vectors. However, before we can do that, we should
ensure that it does not cause code with scalable vectors to miscompile.
All functions that use a DemandedElts will bail out if the vector is
scalable. Usages of getNumElements() are updated to go through
FixedVectorType pointers.

Reviewers: rengolin, efriedma, sdesmalen, c-rhodes, spatel

Reviewed By: efriedma

Subscribers: david-arm, tschuett, kristof.beyls, hiraditya, rkruppe, psnobl, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D79053
This commit is contained in:
Christopher Tetreault 2020-05-06 09:53:57 -07:00
parent 6d6d48add8
commit 782231ac79
1 changed files with 139 additions and 94 deletions

View File

@ -206,11 +206,16 @@ static void computeKnownBits(const Value *V, const APInt &DemandedElts,
static void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth,
const Query &Q) {
Type *Ty = V->getType();
// FIXME: We currently have no way to represent the DemandedElts of a scalable
// vector
if (isa<ScalableVectorType>(V->getType())) {
Known.resetAll();
return;
}
auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
APInt DemandedElts =
Ty->isVectorTy()
? APInt::getAllOnesValue(cast<VectorType>(Ty)->getNumElements())
: APInt(1, 1);
FVTy ? APInt::getAllOnesValue(FVTy->getNumElements()) : APInt(1, 1);
computeKnownBits(V, DemandedElts, Known, Depth, Q);
}
@ -374,11 +379,14 @@ static unsigned ComputeNumSignBits(const Value *V, const APInt &DemandedElts,
static unsigned ComputeNumSignBits(const Value *V, unsigned Depth,
const Query &Q) {
Type *Ty = V->getType();
// FIXME: We currently have no way to represent the DemandedElts of a scalable
// vector
if (isa<ScalableVectorType>(V->getType()))
return 1;
auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
APInt DemandedElts =
Ty->isVectorTy()
? APInt::getAllOnesValue(cast<VectorType>(Ty)->getNumElements())
: APInt(1, 1);
FVTy ? APInt::getAllOnesValue(FVTy->getNumElements()) : APInt(1, 1);
return ComputeNumSignBits(V, DemandedElts, Depth, Q);
}
@ -1808,7 +1816,12 @@ static void computeKnownBitsFromOperator(const Operator *I,
const Value *Vec = I->getOperand(0);
const Value *Idx = I->getOperand(1);
auto *CIdx = dyn_cast<ConstantInt>(Idx);
unsigned NumElts = cast<VectorType>(Vec->getType())->getNumElements();
if (isa<ScalableVectorType>(Vec->getType())) {
// FIXME: there's probably *something* we can do with scalable vectors
Known.resetAll();
break;
}
unsigned NumElts = cast<FixedVectorType>(Vec->getType())->getNumElements();
APInt DemandedVecElts = APInt::getAllOnesValue(NumElts);
if (CIdx && CIdx->getValue().ult(NumElts))
DemandedVecElts = APInt::getOneBitSet(NumElts, CIdx->getZExtValue());
@ -1880,31 +1893,42 @@ KnownBits computeKnownBits(const Value *V, unsigned Depth, const Query &Q) {
/// for all of the demanded elements in the vector specified by DemandedElts.
void computeKnownBits(const Value *V, const APInt &DemandedElts,
KnownBits &Known, unsigned Depth, const Query &Q) {
assert(V && "No Value?");
assert(Depth <= MaxDepth && "Limit Search Depth");
unsigned BitWidth = Known.getBitWidth();
Type *Ty = V->getType();
assert((Ty->isIntOrIntVectorTy(BitWidth) || Ty->isPtrOrPtrVectorTy()) &&
"Not integer or pointer type!");
assert(((Ty->isVectorTy() && cast<VectorType>(Ty)->getNumElements() ==
DemandedElts.getBitWidth()) ||
(!Ty->isVectorTy() && DemandedElts == APInt(1, 1))) &&
"Unexpected vector size");
Type *ScalarTy = Ty->getScalarType();
unsigned ExpectedWidth = ScalarTy->isPointerTy() ?
Q.DL.getPointerTypeSizeInBits(ScalarTy) : Q.DL.getTypeSizeInBits(ScalarTy);
assert(ExpectedWidth == BitWidth && "V and Known should have same BitWidth");
(void)BitWidth;
(void)ExpectedWidth;
if (!DemandedElts) {
// No demanded elts, better to assume we don't know anything.
if (!DemandedElts || isa<ScalableVectorType>(V->getType())) {
// No demanded elts or V is a scalable vector, better to assume we don't
// know anything.
Known.resetAll();
return;
}
assert(V && "No Value?");
assert(Depth <= MaxDepth && "Limit Search Depth");
#ifndef NDEBUG
Type *Ty = V->getType();
unsigned BitWidth = Known.getBitWidth();
assert((Ty->isIntOrIntVectorTy(BitWidth) || Ty->isPtrOrPtrVectorTy()) &&
"Not integer or pointer type!");
if (auto *FVTy = dyn_cast<FixedVectorType>(Ty)) {
assert(
FVTy->getNumElements() == DemandedElts.getBitWidth() &&
"DemandedElt width should equal the fixed vector number of elements");
} else {
assert(DemandedElts == APInt(1, 1) &&
"DemandedElt width should be 1 for scalars");
}
Type *ScalarTy = Ty->getScalarType();
if (ScalarTy->isPointerTy()) {
assert(BitWidth == Q.DL.getPointerTypeSizeInBits(ScalarTy) &&
"V and Known should have same BitWidth");
} else {
assert(BitWidth == Q.DL.getTypeSizeInBits(ScalarTy) &&
"V and Known should have same BitWidth");
}
#endif
const APInt *C;
if (match(V, m_APInt(C))) {
// We know all of the bits for a scalar constant or a splat vector constant!
@ -1919,17 +1943,14 @@ void computeKnownBits(const Value *V, const APInt &DemandedElts,
}
// Handle a constant vector by taking the intersection of the known bits of
// each element.
if (const ConstantDataSequential *CDS = dyn_cast<ConstantDataSequential>(V)) {
assert((!Ty->isVectorTy() ||
CDS->getNumElements() == DemandedElts.getBitWidth()) &&
"Unexpected vector size");
// We know that CDS must be a vector of integers. Take the intersection of
if (const ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(V)) {
// We know that CDV must be a vector of integers. Take the intersection of
// each element.
Known.Zero.setAllBits(); Known.One.setAllBits();
for (unsigned i = 0, e = CDS->getNumElements(); i != e; ++i) {
if (Ty->isVectorTy() && !DemandedElts[i])
for (unsigned i = 0, e = CDV->getNumElements(); i != e; ++i) {
if (!DemandedElts[i])
continue;
APInt Elt = CDS->getElementAsAPInt(i);
APInt Elt = CDV->getElementAsAPInt(i);
Known.Zero &= ~Elt;
Known.One &= Elt;
}
@ -1937,8 +1958,6 @@ void computeKnownBits(const Value *V, const APInt &DemandedElts,
}
if (const auto *CV = dyn_cast<ConstantVector>(V)) {
assert(CV->getNumOperands() == DemandedElts.getBitWidth() &&
"Unexpected vector size");
// We know that CV must be a vector of integers. Take the intersection of
// each element.
Known.Zero.setAllBits(); Known.One.setAllBits();
@ -1986,7 +2005,7 @@ void computeKnownBits(const Value *V, const APInt &DemandedElts,
computeKnownBitsFromOperator(I, DemandedElts, Known, Depth, Q);
// Aligned pointers have trailing zeros - refine Known.Zero set
if (Ty->isPointerTy()) {
if (isa<PointerType>(V->getType())) {
const MaybeAlign Align = V->getPointerAlignment(Q.DL);
if (Align)
Known.Zero.setLowBits(countTrailingZeros(Align->value()));
@ -2274,6 +2293,11 @@ static bool rangeMetadataExcludesValue(const MDNode* Ranges, const APInt& Value)
/// Supports values with integer or pointer type and vectors of integers.
bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth,
const Query &Q) {
// FIXME: We currently have no way to represent the DemandedElts of a scalable
// vector
if (isa<ScalableVectorType>(V->getType()))
return false;
if (auto *C = dyn_cast<Constant>(V)) {
if (C->isNullValue())
return false;
@ -2292,7 +2316,7 @@ bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth,
// For constant vectors, check that all elements are undefined or known
// non-zero to determine that the whole vector is known non-zero.
if (auto *VecTy = dyn_cast<VectorType>(C->getType())) {
if (auto *VecTy = dyn_cast<FixedVectorType>(C->getType())) {
for (unsigned i = 0, e = VecTy->getNumElements(); i != e; ++i) {
if (!DemandedElts[i])
continue;
@ -2527,7 +2551,7 @@ bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth,
const Value *Vec = EEI->getVectorOperand();
const Value *Idx = EEI->getIndexOperand();
auto *CIdx = dyn_cast<ConstantInt>(Idx);
unsigned NumElts = cast<VectorType>(Vec->getType())->getNumElements();
unsigned NumElts = cast<FixedVectorType>(Vec->getType())->getNumElements();
APInt DemandedVecElts = APInt::getAllOnesValue(NumElts);
if (CIdx && CIdx->getValue().ult(NumElts))
DemandedVecElts = APInt::getOneBitSet(NumElts, CIdx->getZExtValue());
@ -2540,11 +2564,14 @@ bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth,
}
bool isKnownNonZero(const Value* V, unsigned Depth, const Query& Q) {
Type *Ty = V->getType();
// FIXME: We currently have no way to represent the DemandedElts of a scalable
// vector
if (isa<ScalableVectorType>(V->getType()))
return false;
auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
APInt DemandedElts =
Ty->isVectorTy()
? APInt::getAllOnesValue(cast<VectorType>(Ty)->getNumElements())
: APInt(1, 1);
FVTy ? APInt::getAllOnesValue(FVTy->getNumElements()) : APInt(1, 1);
return isKnownNonZero(V, DemandedElts, Depth, Q);
}
@ -2641,11 +2668,11 @@ static unsigned computeNumSignBitsVectorConstant(const Value *V,
const APInt &DemandedElts,
unsigned TyBits) {
const auto *CV = dyn_cast<Constant>(V);
if (!CV || !CV->getType()->isVectorTy())
if (!CV || !isa<FixedVectorType>(CV->getType()))
return 0;
unsigned MinSignBits = TyBits;
unsigned NumElts = cast<VectorType>(CV->getType())->getNumElements();
unsigned NumElts = cast<FixedVectorType>(CV->getType())->getNumElements();
for (unsigned i = 0; i != NumElts; ++i) {
if (!DemandedElts[i])
continue;
@ -2681,18 +2708,30 @@ static unsigned ComputeNumSignBits(const Value *V, const APInt &DemandedElts,
static unsigned ComputeNumSignBitsImpl(const Value *V,
const APInt &DemandedElts,
unsigned Depth, const Query &Q) {
Type *Ty = V->getType();
// FIXME: We currently have no way to represent the DemandedElts of a scalable
// vector
if (isa<ScalableVectorType>(Ty))
return 1;
#ifndef NDEBUG
assert(Depth <= MaxDepth && "Limit Search Depth");
if (auto *FVTy = dyn_cast<FixedVectorType>(Ty)) {
assert(
FVTy->getNumElements() == DemandedElts.getBitWidth() &&
"DemandedElt width should equal the fixed vector number of elements");
} else {
assert(DemandedElts == APInt(1, 1) &&
"DemandedElt width should be 1 for scalars");
}
#endif
// We return the minimum number of sign bits that are guaranteed to be present
// in V, so for undef we have to conservatively return 1. We don't have the
// same behavior for poison though -- that's a FIXME today.
Type *Ty = V->getType();
assert(((Ty->isVectorTy() && cast<VectorType>(Ty)->getNumElements() ==
DemandedElts.getBitWidth()) ||
(!Ty->isVectorTy() && DemandedElts == APInt(1, 1))) &&
"Unexpected vector size");
Type *ScalarTy = Ty->getScalarType();
unsigned TyBits = ScalarTy->isPointerTy() ?
Q.DL.getPointerTypeSizeInBits(ScalarTy) :
@ -3266,8 +3305,8 @@ static bool cannotBeOrderedLessThanZeroImpl(const Value *V,
// Handle vector of constants.
if (auto *CV = dyn_cast<Constant>(V)) {
if (auto *CVVTy = dyn_cast<VectorType>(CV->getType())) {
unsigned NumElts = CVVTy->getNumElements();
if (auto *CVFVTy = dyn_cast<FixedVectorType>(CV->getType())) {
unsigned NumElts = CVFVTy->getNumElements();
for (unsigned i = 0; i != NumElts; ++i) {
auto *CFP = dyn_cast_or_null<ConstantFP>(CV->getAggregateElement(i));
if (!CFP)
@ -3438,24 +3477,26 @@ bool llvm::isKnownNeverInfinity(const Value *V, const TargetLibraryInfo *TLI,
}
}
// Bail out for constant expressions, but try to handle vector constants.
if (!V->getType()->isVectorTy() || !isa<Constant>(V))
return false;
// For vectors, verify that each element is not infinity.
unsigned NumElts = cast<VectorType>(V->getType())->getNumElements();
for (unsigned i = 0; i != NumElts; ++i) {
Constant *Elt = cast<Constant>(V)->getAggregateElement(i);
if (!Elt)
return false;
if (isa<UndefValue>(Elt))
continue;
auto *CElt = dyn_cast<ConstantFP>(Elt);
if (!CElt || CElt->isInfinity())
return false;
// try to handle fixed width vector constants
if (isa<FixedVectorType>(V->getType()) && isa<Constant>(V)) {
// For vectors, verify that each element is not infinity.
unsigned NumElts = cast<VectorType>(V->getType())->getNumElements();
for (unsigned i = 0; i != NumElts; ++i) {
Constant *Elt = cast<Constant>(V)->getAggregateElement(i);
if (!Elt)
return false;
if (isa<UndefValue>(Elt))
continue;
auto *CElt = dyn_cast<ConstantFP>(Elt);
if (!CElt || CElt->isInfinity())
return false;
}
// All elements were confirmed non-infinity or undefined.
return true;
}
// All elements were confirmed non-infinity or undefined.
return true;
// was not able to prove that V never contains infinity
return false;
}
bool llvm::isKnownNeverNaN(const Value *V, const TargetLibraryInfo *TLI,
@ -3539,24 +3580,26 @@ bool llvm::isKnownNeverNaN(const Value *V, const TargetLibraryInfo *TLI,
}
}
// Bail out for constant expressions, but try to handle vector constants.
if (!V->getType()->isVectorTy() || !isa<Constant>(V))
return false;
// For vectors, verify that each element is not NaN.
unsigned NumElts = cast<VectorType>(V->getType())->getNumElements();
for (unsigned i = 0; i != NumElts; ++i) {
Constant *Elt = cast<Constant>(V)->getAggregateElement(i);
if (!Elt)
return false;
if (isa<UndefValue>(Elt))
continue;
auto *CElt = dyn_cast<ConstantFP>(Elt);
if (!CElt || CElt->isNaN())
return false;
// Try to handle fixed width vector constants
if (isa<FixedVectorType>(V->getType()) && isa<Constant>(V)) {
// For vectors, verify that each element is not NaN.
unsigned NumElts = cast<VectorType>(V->getType())->getNumElements();
for (unsigned i = 0; i != NumElts; ++i) {
Constant *Elt = cast<Constant>(V)->getAggregateElement(i);
if (!Elt)
return false;
if (isa<UndefValue>(Elt))
continue;
auto *CElt = dyn_cast<ConstantFP>(Elt);
if (!CElt || CElt->isNaN())
return false;
}
// All elements were confirmed not-NaN or undefined.
return true;
}
// All elements were confirmed not-NaN or undefined.
return true;
// Was not able to prove that V never contains NaN
return false;
}
Value *llvm::isBytewiseValue(Value *V, const DataLayout &DL) {
@ -4634,11 +4677,13 @@ bool llvm::canCreatePoison(const Instruction *I) {
// Shifts return poison if shiftwidth is larger than the bitwidth.
if (auto *C = dyn_cast<Constant>(I->getOperand(1))) {
SmallVector<Constant *, 4> ShiftAmounts;
if (C->getType()->isVectorTy()) {
unsigned NumElts = cast<VectorType>(C->getType())->getNumElements();
if (auto *FVTy = dyn_cast<FixedVectorType>(C->getType())) {
unsigned NumElts = FVTy->getNumElements();
for (unsigned i = 0; i < NumElts; ++i)
ShiftAmounts.push_back(C->getAggregateElement(i));
} else
} else if (isa<ScalableVectorType>(C->getType()))
return true; // Can't tell, just return true to be safe
else
ShiftAmounts.push_back(C);
bool Safe = llvm::all_of(ShiftAmounts, [](Constant *C) {