From 8819202dfd2c39a7ed4dd69f0d7e0e0bcf409e2a Mon Sep 17 00:00:00 2001 From: Christopher Tetreault Date: Wed, 17 Jun 2020 14:12:48 -0700 Subject: [PATCH] [SVE] Eliminate bad VectorType::getNumElements() calls from ConstantFold Summary: Assume all usages of this function are explicitly fixed-width operations and cast to FixedVectorType Reviewers: efriedma, sdesmalen, c-rhodes, majnemer, dblaikie Reviewed By: sdesmalen Subscribers: tschuett, hiraditya, rkruppe, psnobl, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D80262 --- llvm/lib/IR/ConstantFold.cpp | 82 +++++++++++-------- .../ConstantFolding/extractelement-vscale.ll | 13 +++ 2 files changed, 62 insertions(+), 33 deletions(-) create mode 100644 llvm/test/Analysis/ConstantFolding/extractelement-vscale.ll diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp index 3fb49e94870f..ef584afc68bc 100644 --- a/llvm/lib/IR/ConstantFold.cpp +++ b/llvm/lib/IR/ConstantFold.cpp @@ -55,8 +55,8 @@ static Constant *BitCastConstantVector(Constant *CV, VectorType *DstTy) { // If this cast changes element count then we can't handle it here: // doing so requires endianness information. This should be handled by // Analysis/ConstantFolding.cpp - unsigned NumElts = DstTy->getNumElements(); - if (NumElts != cast(CV->getType())->getNumElements()) + unsigned NumElts = cast(DstTy)->getNumElements(); + if (NumElts != cast(CV->getType())->getNumElements()) return nullptr; Type *DstEltTy = DstTy->getElementType(); @@ -573,8 +573,8 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V, // count may be mismatched; don't attempt to handle that here. if ((isa(V) || isa(V)) && DestTy->isVectorTy() && - cast(DestTy)->getNumElements() == - cast(V->getType())->getNumElements()) { + cast(DestTy)->getNumElements() == + cast(V->getType())->getNumElements()) { VectorType *DestVecTy = cast(DestTy); Type *DstEltTy = DestVecTy->getElementType(); // Fast path for splatted constants. 
@@ -585,7 +585,8 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V, } SmallVector res; Type *Ty = IntegerType::get(V->getContext(), 32); - for (unsigned i = 0, e = cast(V->getType())->getNumElements(); + for (unsigned i = 0, + e = cast(V->getType())->getNumElements(); i != e; ++i) { Constant *C = ConstantExpr::getExtractElement(V, ConstantInt::get(Ty, i)); @@ -809,9 +810,11 @@ Constant *llvm::ConstantFoldExtractElementInstruction(Constant *Val, if (!CIdx) return nullptr; - // ee({w,x,y,z}, wrong_value) -> undef - if (CIdx->uge(ValVTy->getNumElements())) - return UndefValue::get(ValVTy->getElementType()); + if (auto *ValFVTy = dyn_cast(Val->getType())) { + // ee({w,x,y,z}, wrong_value) -> undef + if (CIdx->uge(ValFVTy->getNumElements())) + return UndefValue::get(ValFVTy->getElementType()); + } // ee (gep (ptr, idx0, ...), idx) -> gep (ee (ptr, idx), ee (idx0, idx), ...) if (auto *CE = dyn_cast(Val)) { @@ -823,7 +826,7 @@ Constant *llvm::ConstantFoldExtractElementInstruction(Constant *Val, if (Op->getType()->isVectorTy()) { Constant *ScalarOp = ConstantExpr::getExtractElement(Op, Idx); if (!ScalarOp) - return nullptr; + return nullptr; Ops.push_back(ScalarOp); } else Ops.push_back(Op); @@ -833,6 +836,16 @@ Constant *llvm::ConstantFoldExtractElementInstruction(Constant *Val, } } + // CAZ of type ScalableVectorType and n < CAZ->getMinNumElements() => + // extractelt CAZ, n -> 0 + if (auto *ValSVTy = dyn_cast(Val->getType())) { + if (!CIdx->uge(ValSVTy->getMinNumElements())) { + if (auto *CAZ = dyn_cast(Val)) + return CAZ->getElementValue(CIdx->getZExtValue()); + } + return nullptr; + } + return Val->getAggregateElement(CIdx); } @@ -847,11 +860,12 @@ Constant *llvm::ConstantFoldInsertElementInstruction(Constant *Val, // Do not iterate on scalable vector. The num of elements is unknown at // compile-time. 
- VectorType *ValTy = cast(Val->getType()); - if (isa(ValTy)) + if (isa(Val->getType())) return nullptr; - unsigned NumElts = cast(Val->getType())->getNumElements(); + auto *ValTy = cast(Val->getType()); + + unsigned NumElts = ValTy->getNumElements(); if (CIdx->uge(NumElts)) return UndefValue::get(Val->getType()); @@ -898,7 +912,7 @@ Constant *llvm::ConstantFoldShuffleVectorInstruction(Constant *V1, Constant *V2, if (isa(V1VTy)) return nullptr; - unsigned SrcNumElts = V1VTy->getNumElements(); + unsigned SrcNumElts = V1VTy->getElementCount().Min; // Loop over the shuffle mask, evaluating each element. SmallVector Result; @@ -998,11 +1012,8 @@ Constant *llvm::ConstantFoldUnaryInstruction(unsigned Opcode, Constant *C) { case Instruction::FNeg: return ConstantFP::get(C->getContext(), neg(CV)); } - } else if (VectorType *VTy = dyn_cast(C->getType())) { - // Do not iterate on scalable vector. The number of elements is unknown at - // compile-time. - if (IsScalableVector) - return nullptr; + } else if (auto *VTy = dyn_cast(C->getType())) { + Type *Ty = IntegerType::get(VTy->getContext(), 32); // Fast path for splatted constants. if (Constant *Splat = C->getSplatValue()) { @@ -1011,7 +1022,7 @@ Constant *llvm::ConstantFoldUnaryInstruction(unsigned Opcode, Constant *C) { } // Fold each element and create a vector constant from those constants. - SmallVector Result; + SmallVector Result; for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) { Constant *ExtractIdx = ConstantInt::get(Ty, i); Constant *Elt = ConstantExpr::getExtractElement(C, ExtractIdx); @@ -1367,11 +1378,12 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1, return ConstantFP::get(C1->getContext(), C3V); } } - } else if (VectorType *VTy = dyn_cast(C1->getType())) { + } else if (IsScalableVector) { // Do not iterate on scalable vector. The number of elements is unknown at // compile-time. 
- if (IsScalableVector) - return nullptr; + // FIXME: this branch can potentially be removed + return nullptr; + } else if (auto *VTy = dyn_cast(C1->getType())) { // Fast path for splatted constants. if (Constant *C2Splat = C2->getSplatValue()) { if (Instruction::isIntDivRem(Opcode) && C2Splat->isNullValue()) @@ -2014,7 +2026,7 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, SmallVector ResElts; Type *Ty = IntegerType::get(C1->getContext(), 32); // Compare the elements, producing an i1 result or constant expr. - for (unsigned i = 0, e = C1VTy->getNumElements(); i != e; ++i) { + for (unsigned i = 0, e = C1VTy->getElementCount().Min; i != e; ++i) { Constant *C1E = ConstantExpr::getExtractElement(C1, ConstantInt::get(Ty, i)); Constant *C2E = @@ -2286,14 +2298,18 @@ Constant *llvm::ConstantFoldGetElementPtr(Type *PointeeTy, Constant *C, assert(Ty && "Invalid indices for GEP!"); Type *OrigGEPTy = PointerType::get(Ty, PtrTy->getAddressSpace()); Type *GEPTy = PointerType::get(Ty, PtrTy->getAddressSpace()); - if (VectorType *VT = dyn_cast(C->getType())) - GEPTy = FixedVectorType::get(OrigGEPTy, VT->getNumElements()); - + if (VectorType *VT = dyn_cast(C->getType())) { + // FIXME: handle scalable vectors (use getElementCount()) + GEPTy = FixedVectorType::get( + OrigGEPTy, cast(VT)->getNumElements()); + } // The GEP returns a vector of pointers when one of more of // its arguments is a vector. 
for (unsigned i = 0, e = Idxs.size(); i != e; ++i) { if (auto *VT = dyn_cast(Idxs[i]->getType())) { - GEPTy = FixedVectorType::get(OrigGEPTy, VT->getNumElements()); + // FIXME: handle scalable vectors + GEPTy = FixedVectorType::get( + OrigGEPTy, cast(VT)->getNumElements()); break; } } @@ -2500,19 +2516,19 @@ Constant *llvm::ConstantFoldGetElementPtr(Type *PointeeTy, Constant *C, if (!IsCurrIdxVector && IsPrevIdxVector) CurrIdx = ConstantDataVector::getSplat( - cast(PrevIdx->getType())->getNumElements(), CurrIdx); + cast(PrevIdx->getType())->getNumElements(), CurrIdx); if (!IsPrevIdxVector && IsCurrIdxVector) PrevIdx = ConstantDataVector::getSplat( - cast(CurrIdx->getType())->getNumElements(), PrevIdx); + cast(CurrIdx->getType())->getNumElements(), PrevIdx); Constant *Factor = ConstantInt::get(CurrIdx->getType()->getScalarType(), NumElements); if (UseVector) Factor = ConstantDataVector::getSplat( IsPrevIdxVector - ? cast(PrevIdx->getType())->getNumElements() - : cast(CurrIdx->getType())->getNumElements(), + ? cast(PrevIdx->getType())->getNumElements() + : cast(CurrIdx->getType())->getNumElements(), Factor); NewIdxs[i] = ConstantExpr::getSRem(CurrIdx, Factor); @@ -2531,8 +2547,8 @@ Constant *llvm::ConstantFoldGetElementPtr(Type *PointeeTy, Constant *C, ExtendedTy = FixedVectorType::get( ExtendedTy, IsPrevIdxVector - ? cast(PrevIdx->getType())->getNumElements() - : cast(CurrIdx->getType())->getNumElements()); + ? 
cast<FixedVectorType>(PrevIdx->getType())->getNumElements() + : cast<FixedVectorType>(CurrIdx->getType())->getNumElements()); if (!PrevIdx->getType()->isIntOrIntVectorTy(CommonExtendedWidth)) PrevIdx = ConstantExpr::getSExt(PrevIdx, ExtendedTy); diff --git a/llvm/test/Analysis/ConstantFolding/extractelement-vscale.ll b/llvm/test/Analysis/ConstantFolding/extractelement-vscale.ll new file mode 100644 index 000000000000..c4b42be45019 --- /dev/null +++ b/llvm/test/Analysis/ConstantFolding/extractelement-vscale.ll @@ -0,0 +1,13 @@ +; RUN: opt -instcombine -S < %s | FileCheck %s + +; CHECK-LABEL: definitely_in_bounds +; CHECK: ret i8 0 +define i8 @definitely_in_bounds() { + ret i8 extractelement (<vscale x 16 x i8> zeroinitializer, i64 15) +} + +; CHECK-LABEL: maybe_in_bounds +; CHECK: ret i8 extractelement (<vscale x 16 x i8> zeroinitializer, i64 16) +define i8 @maybe_in_bounds() { + ret i8 extractelement (<vscale x 16 x i8> zeroinitializer, i64 16) +}