[SVE] Eliminate bad VectorType::getNumElements() calls from ConstantFold

Summary:
Assume all usages of this function are explicitly fixed-width operations
and cast to FixedVectorType

Reviewers: efriedma, sdesmalen, c-rhodes, majnemer, dblaikie

Reviewed By: sdesmalen

Subscribers: tschuett, hiraditya, rkruppe, psnobl, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D80262
This commit is contained in:
Christopher Tetreault 2020-06-17 14:12:48 -07:00
parent 4b776a98f1
commit 8819202dfd
2 changed files with 62 additions and 33 deletions

View File

@ -55,8 +55,8 @@ static Constant *BitCastConstantVector(Constant *CV, VectorType *DstTy) {
// If this cast changes element count then we can't handle it here:
// doing so requires endianness information. This should be handled by
// Analysis/ConstantFolding.cpp
unsigned NumElts = DstTy->getNumElements();
if (NumElts != cast<VectorType>(CV->getType())->getNumElements())
unsigned NumElts = cast<FixedVectorType>(DstTy)->getNumElements();
if (NumElts != cast<FixedVectorType>(CV->getType())->getNumElements())
return nullptr;
Type *DstEltTy = DstTy->getElementType();
@ -573,8 +573,8 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
// count may be mismatched; don't attempt to handle that here.
if ((isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) &&
DestTy->isVectorTy() &&
cast<VectorType>(DestTy)->getNumElements() ==
cast<VectorType>(V->getType())->getNumElements()) {
cast<FixedVectorType>(DestTy)->getNumElements() ==
cast<FixedVectorType>(V->getType())->getNumElements()) {
VectorType *DestVecTy = cast<VectorType>(DestTy);
Type *DstEltTy = DestVecTy->getElementType();
// Fast path for splatted constants.
@ -585,7 +585,8 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
}
SmallVector<Constant *, 16> res;
Type *Ty = IntegerType::get(V->getContext(), 32);
for (unsigned i = 0, e = cast<VectorType>(V->getType())->getNumElements();
for (unsigned i = 0,
e = cast<FixedVectorType>(V->getType())->getNumElements();
i != e; ++i) {
Constant *C =
ConstantExpr::getExtractElement(V, ConstantInt::get(Ty, i));
@ -809,9 +810,11 @@ Constant *llvm::ConstantFoldExtractElementInstruction(Constant *Val,
if (!CIdx)
return nullptr;
// ee({w,x,y,z}, wrong_value) -> undef
if (CIdx->uge(ValVTy->getNumElements()))
return UndefValue::get(ValVTy->getElementType());
if (auto *ValFVTy = dyn_cast<FixedVectorType>(Val->getType())) {
// ee({w,x,y,z}, wrong_value) -> undef
if (CIdx->uge(ValFVTy->getNumElements()))
return UndefValue::get(ValFVTy->getElementType());
}
// ee (gep (ptr, idx0, ...), idx) -> gep (ee (ptr, idx), ee (idx0, idx), ...)
if (auto *CE = dyn_cast<ConstantExpr>(Val)) {
@ -823,7 +826,7 @@ Constant *llvm::ConstantFoldExtractElementInstruction(Constant *Val,
if (Op->getType()->isVectorTy()) {
Constant *ScalarOp = ConstantExpr::getExtractElement(Op, Idx);
if (!ScalarOp)
return nullptr;
return nullptr;
Ops.push_back(ScalarOp);
} else
Ops.push_back(Op);
@ -833,6 +836,16 @@ Constant *llvm::ConstantFoldExtractElementInstruction(Constant *Val,
}
}
// CAZ of type ScalableVectorType and n < CAZ->getMinNumElements() =>
// extractelt CAZ, n -> 0
if (auto *ValSVTy = dyn_cast<ScalableVectorType>(Val->getType())) {
if (!CIdx->uge(ValSVTy->getMinNumElements())) {
if (auto *CAZ = dyn_cast<ConstantAggregateZero>(Val))
return CAZ->getElementValue(CIdx->getZExtValue());
}
return nullptr;
}
return Val->getAggregateElement(CIdx);
}
@ -847,11 +860,12 @@ Constant *llvm::ConstantFoldInsertElementInstruction(Constant *Val,
// Do not iterate on scalable vector. The num of elements is unknown at
// compile-time.
VectorType *ValTy = cast<VectorType>(Val->getType());
if (isa<ScalableVectorType>(ValTy))
if (isa<ScalableVectorType>(Val->getType()))
return nullptr;
unsigned NumElts = cast<VectorType>(Val->getType())->getNumElements();
auto *ValTy = cast<FixedVectorType>(Val->getType());
unsigned NumElts = ValTy->getNumElements();
if (CIdx->uge(NumElts))
return UndefValue::get(Val->getType());
@ -898,7 +912,7 @@ Constant *llvm::ConstantFoldShuffleVectorInstruction(Constant *V1, Constant *V2,
if (isa<ScalableVectorType>(V1VTy))
return nullptr;
unsigned SrcNumElts = V1VTy->getNumElements();
unsigned SrcNumElts = V1VTy->getElementCount().Min;
// Loop over the shuffle mask, evaluating each element.
SmallVector<Constant*, 32> Result;
@ -998,11 +1012,8 @@ Constant *llvm::ConstantFoldUnaryInstruction(unsigned Opcode, Constant *C) {
case Instruction::FNeg:
return ConstantFP::get(C->getContext(), neg(CV));
}
} else if (VectorType *VTy = dyn_cast<VectorType>(C->getType())) {
// Do not iterate on scalable vector. The number of elements is unknown at
// compile-time.
if (IsScalableVector)
return nullptr;
} else if (auto *VTy = dyn_cast<FixedVectorType>(C->getType())) {
Type *Ty = IntegerType::get(VTy->getContext(), 32);
// Fast path for splatted constants.
if (Constant *Splat = C->getSplatValue()) {
@ -1011,7 +1022,7 @@ Constant *llvm::ConstantFoldUnaryInstruction(unsigned Opcode, Constant *C) {
}
// Fold each element and create a vector constant from those constants.
SmallVector<Constant*, 16> Result;
SmallVector<Constant *, 16> Result;
for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) {
Constant *ExtractIdx = ConstantInt::get(Ty, i);
Constant *Elt = ConstantExpr::getExtractElement(C, ExtractIdx);
@ -1367,11 +1378,12 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1,
return ConstantFP::get(C1->getContext(), C3V);
}
}
} else if (VectorType *VTy = dyn_cast<VectorType>(C1->getType())) {
} else if (IsScalableVector) {
// Do not iterate on scalable vector. The number of elements is unknown at
// compile-time.
if (IsScalableVector)
return nullptr;
// FIXME: this branch can potentially be removed
return nullptr;
} else if (auto *VTy = dyn_cast<FixedVectorType>(C1->getType())) {
// Fast path for splatted constants.
if (Constant *C2Splat = C2->getSplatValue()) {
if (Instruction::isIntDivRem(Opcode) && C2Splat->isNullValue())
@ -2014,7 +2026,7 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
SmallVector<Constant*, 4> ResElts;
Type *Ty = IntegerType::get(C1->getContext(), 32);
// Compare the elements, producing an i1 result or constant expr.
for (unsigned i = 0, e = C1VTy->getNumElements(); i != e; ++i) {
for (unsigned i = 0, e = C1VTy->getElementCount().Min; i != e; ++i) {
Constant *C1E =
ConstantExpr::getExtractElement(C1, ConstantInt::get(Ty, i));
Constant *C2E =
@ -2286,14 +2298,18 @@ Constant *llvm::ConstantFoldGetElementPtr(Type *PointeeTy, Constant *C,
assert(Ty && "Invalid indices for GEP!");
Type *OrigGEPTy = PointerType::get(Ty, PtrTy->getAddressSpace());
Type *GEPTy = PointerType::get(Ty, PtrTy->getAddressSpace());
if (VectorType *VT = dyn_cast<VectorType>(C->getType()))
GEPTy = FixedVectorType::get(OrigGEPTy, VT->getNumElements());
if (VectorType *VT = dyn_cast<VectorType>(C->getType())) {
// FIXME: handle scalable vectors (use getElementCount())
GEPTy = FixedVectorType::get(
OrigGEPTy, cast<FixedVectorType>(VT)->getNumElements());
}
// The GEP returns a vector of pointers when one of more of
// its arguments is a vector.
for (unsigned i = 0, e = Idxs.size(); i != e; ++i) {
if (auto *VT = dyn_cast<VectorType>(Idxs[i]->getType())) {
GEPTy = FixedVectorType::get(OrigGEPTy, VT->getNumElements());
// FIXME: handle scalable vectors
GEPTy = FixedVectorType::get(
OrigGEPTy, cast<FixedVectorType>(VT)->getNumElements());
break;
}
}
@ -2500,19 +2516,19 @@ Constant *llvm::ConstantFoldGetElementPtr(Type *PointeeTy, Constant *C,
if (!IsCurrIdxVector && IsPrevIdxVector)
CurrIdx = ConstantDataVector::getSplat(
cast<VectorType>(PrevIdx->getType())->getNumElements(), CurrIdx);
cast<FixedVectorType>(PrevIdx->getType())->getNumElements(), CurrIdx);
if (!IsPrevIdxVector && IsCurrIdxVector)
PrevIdx = ConstantDataVector::getSplat(
cast<VectorType>(CurrIdx->getType())->getNumElements(), PrevIdx);
cast<FixedVectorType>(CurrIdx->getType())->getNumElements(), PrevIdx);
Constant *Factor =
ConstantInt::get(CurrIdx->getType()->getScalarType(), NumElements);
if (UseVector)
Factor = ConstantDataVector::getSplat(
IsPrevIdxVector
? cast<VectorType>(PrevIdx->getType())->getNumElements()
: cast<VectorType>(CurrIdx->getType())->getNumElements(),
? cast<FixedVectorType>(PrevIdx->getType())->getNumElements()
: cast<FixedVectorType>(CurrIdx->getType())->getNumElements(),
Factor);
NewIdxs[i] = ConstantExpr::getSRem(CurrIdx, Factor);
@ -2531,8 +2547,8 @@ Constant *llvm::ConstantFoldGetElementPtr(Type *PointeeTy, Constant *C,
ExtendedTy = FixedVectorType::get(
ExtendedTy,
IsPrevIdxVector
? cast<VectorType>(PrevIdx->getType())->getNumElements()
: cast<VectorType>(CurrIdx->getType())->getNumElements());
? cast<FixedVectorType>(PrevIdx->getType())->getNumElements()
: cast<FixedVectorType>(CurrIdx->getType())->getNumElements());
if (!PrevIdx->getType()->isIntOrIntVectorTy(CommonExtendedWidth))
PrevIdx = ConstantExpr::getSExt(PrevIdx, ExtendedTy);

View File

@ -0,0 +1,13 @@
; RUN: opt -instcombine -S < %s | FileCheck %s
; CHECK-LABEL: definitely_in_bounds
; CHECK: ret i8 0
define i8 @definitely_in_bounds() {
ret i8 extractelement (<vscale x 16 x i8> zeroinitializer, i64 15)
}
; CHECK-LABEL: maybe_in_bounds
; CHECK: ret i8 extractelement (<vscale x 16 x i8> zeroinitializer, i64 16)
define i8 @maybe_in_bounds() {
ret i8 extractelement (<vscale x 16 x i8> zeroinitializer, i64 16)
}