From 9f87d951fccfc21239af3ac981a6cc7941da3f93 Mon Sep 17 00:00:00 2001 From: Christopher Tetreault Date: Thu, 9 Apr 2020 16:15:49 -0700 Subject: [PATCH] Clean up usages of asserting vector getters in Type Summary: Remove usages of asserting vector getters in Type in preparation for the VectorType refactor. The existence of these functions complicates the refactor while adding little value. Reviewers: mcrosier, efriedma, sdesmalen Reviewed By: efriedma Subscribers: hiraditya, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D77269 --- .../Target/AArch64/AArch64ISelLowering.cpp | 41 +++++++++---------- .../AArch64/AArch64TargetTransformInfo.cpp | 22 +++++----- .../AArch64/AArch64TargetTransformInfo.h | 9 ++-- 3 files changed, 34 insertions(+), 38 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 523ffe993efc..dd309488959f 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -9376,10 +9376,9 @@ bool AArch64TargetLowering::lowerInterleavedLoad( // A pointer vector can not be the return type of the ldN intrinsics. Need to // load integer vectors first and then convert to pointer vectors. - Type *EltTy = VecTy->getVectorElementType(); + Type *EltTy = VecTy->getElementType(); if (EltTy->isPointerTy()) - VecTy = - VectorType::get(DL.getIntPtrType(EltTy), VecTy->getVectorNumElements()); + VecTy = VectorType::get(DL.getIntPtrType(EltTy), VecTy->getNumElements()); IRBuilder<> Builder(LI); @@ -9389,15 +9388,15 @@ bool AArch64TargetLowering::lowerInterleavedLoad( if (NumLoads > 1) { // If we're going to generate more than one load, reset the sub-vector type // to something legal. - VecTy = VectorType::get(VecTy->getVectorElementType(), - VecTy->getVectorNumElements() / NumLoads); + VecTy = VectorType::get(VecTy->getElementType(), + VecTy->getNumElements() / NumLoads); // We will compute the pointer operand of each load from the original base // address using GEPs. Cast the base address to a pointer to the scalar // element type. BaseAddr = Builder.CreateBitCast( - BaseAddr, VecTy->getVectorElementType()->getPointerTo( - LI->getPointerAddressSpace())); + BaseAddr, + VecTy->getElementType()->getPointerTo(LI->getPointerAddressSpace())); } Type *PtrTy = VecTy->getPointerTo(LI->getPointerAddressSpace()); @@ -9418,9 +9417,8 @@ bool AArch64TargetLowering::lowerInterleavedLoad( // If we're generating more than one load, compute the base address of // subsequent loads as an offset from the previous. if (LoadCount > 0) - BaseAddr = - Builder.CreateConstGEP1_32(VecTy->getVectorElementType(), BaseAddr, - VecTy->getVectorNumElements() * Factor); + BaseAddr = Builder.CreateConstGEP1_32(VecTy->getElementType(), BaseAddr, + VecTy->getNumElements() * Factor); CallInst *LdN = Builder.CreateCall( LdNFunc, Builder.CreateBitCast(BaseAddr, PtrTy), "ldN"); @@ -9435,8 +9433,8 @@ bool AArch64TargetLowering::lowerInterleavedLoad( // Convert the integer vector to pointer vector if the element is pointer. if (EltTy->isPointerTy()) SubVec = Builder.CreateIntToPtr( - SubVec, VectorType::get(SVI->getType()->getVectorElementType(), - VecTy->getVectorNumElements())); + SubVec, VectorType::get(SVI->getType()->getElementType(), + VecTy->getNumElements())); SubVecs[SVI].push_back(SubVec); } } @@ -9488,11 +9486,10 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI, "Invalid interleave factor"); VectorType *VecTy = SVI->getType(); - assert(VecTy->getVectorNumElements() % Factor == 0 && - "Invalid interleaved store"); + assert(VecTy->getNumElements() % Factor == 0 && "Invalid interleaved store"); - unsigned LaneLen = VecTy->getVectorNumElements() / Factor; - Type *EltTy = VecTy->getVectorElementType(); + unsigned LaneLen = VecTy->getNumElements() / Factor; + Type *EltTy = VecTy->getElementType(); VectorType *SubVecTy = VectorType::get(EltTy, LaneLen); const DataLayout &DL = SI->getModule()->getDataLayout(); @@ -9513,7 +9510,7 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI, // vectors to integer vectors. if (EltTy->isPointerTy()) { Type *IntTy = DL.getIntPtrType(EltTy); - unsigned NumOpElts = Op0->getType()->getVectorNumElements(); + unsigned NumOpElts = cast(Op0->getType())->getNumElements(); // Convert to the corresponding integer vector. Type *IntVecTy = VectorType::get(IntTy, NumOpElts); @@ -9530,14 +9527,14 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI, // If we're going to generate more than one store, reset the lane length // and sub-vector type to something legal. LaneLen /= NumStores; - SubVecTy = VectorType::get(SubVecTy->getVectorElementType(), LaneLen); + SubVecTy = VectorType::get(SubVecTy->getElementType(), LaneLen); // We will compute the pointer operand of each store from the original base // address using GEPs. Cast the base address to a pointer to the scalar // element type. BaseAddr = Builder.CreateBitCast( - BaseAddr, SubVecTy->getVectorElementType()->getPointerTo( - SI->getPointerAddressSpace())); + BaseAddr, + SubVecTy->getElementType()->getPointerTo(SI->getPointerAddressSpace())); } auto Mask = SVI->getShuffleMask(); @@ -9582,7 +9579,7 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI, // If we generating more than one store, we compute the base address of // subsequent stores as an offset from the previous. if (StoreCount > 0) - BaseAddr = Builder.CreateConstGEP1_32(SubVecTy->getVectorElementType(), + BaseAddr = Builder.CreateConstGEP1_32(SubVecTy->getElementType(), BaseAddr, LaneLen * Factor); Ops.push_back(Builder.CreateBitCast(BaseAddr, PtrTy)); @@ -9697,7 +9694,7 @@ bool AArch64TargetLowering::isLegalAddressingMode(const DataLayout &DL, return false; // FIXME: Update this method to support scalable addressing modes. - if (Ty->isVectorTy() && Ty->getVectorIsScalable()) + if (Ty->isVectorTy() && cast(Ty)->isScalable()) return AM.HasBaseReg && !AM.BaseOffs && !AM.Scale; // check reg + imm case: diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index e8ba30c7e92a..cd8b71767997 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -209,7 +209,7 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode, // elements in type Ty determine the vector width. auto toVectorTy = [&](Type *ArgTy) { return VectorType::get(ArgTy->getScalarType(), - DstTy->getVectorNumElements()); + cast(DstTy)->getNumElements()); }; // Exit early if DstTy is not a vector type whose elements are at least @@ -661,7 +661,8 @@ int AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty, return LT.first * 2 * AmortizationCost; } - if (Ty->isVectorTy() && Ty->getVectorElementType()->isIntegerTy(8)) { + if (Ty->isVectorTy() && + cast(Ty)->getElementType()->isIntegerTy(8)) { unsigned ProfitableNumElements; if (Opcode == Instruction::Store) // We use a custom trunc store lowering so v.4b should be profitable. @@ -671,8 +672,8 @@ int AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty, // have to promote the elements to v.2. ProfitableNumElements = 8; - if (Ty->getVectorNumElements() < ProfitableNumElements) { - unsigned NumVecElts = Ty->getVectorNumElements(); + if (cast(Ty)->getNumElements() < ProfitableNumElements) { + unsigned NumVecElts = cast(Ty)->getNumElements(); unsigned NumVectorizableInstsToAmortize = NumVecElts * 2; // We generate 2 instructions per vector element. return NumVectorizableInstsToAmortize * NumVecElts * 2; @@ -690,11 +691,11 @@ int AArch64TTIImpl::getInterleavedMemoryOpCost(unsigned Opcode, Type *VecTy, bool UseMaskForCond, bool UseMaskForGaps) { assert(Factor >= 2 && "Invalid interleave factor"); - assert(isa(VecTy) && "Expect a vector type"); + auto *VecVTy = cast(VecTy); if (!UseMaskForCond && !UseMaskForGaps && Factor <= TLI->getMaxSupportedInterleaveFactor()) { - unsigned NumElts = VecTy->getVectorNumElements(); + unsigned NumElts = VecVTy->getNumElements(); auto *SubVecTy = VectorType::get(VecTy->getScalarType(), NumElts / Factor); // ldN/stN only support legal vector types of size 64 or 128 in bits. @@ -715,7 +716,7 @@ int AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef Tys) { for (auto *I : Tys) { if (!I->isVectorTy()) continue; - if (I->getScalarSizeInBits() * I->getVectorNumElements() == 128) + if (I->getScalarSizeInBits() * cast(I)->getNumElements() == 128) Cost += getMemoryOpCost(Instruction::Store, I, Align(128), 0) + getMemoryOpCost(Instruction::Load, I, Align(128), 0); } @@ -907,7 +908,7 @@ bool AArch64TTIImpl::shouldConsiderAddressTypePromotion( bool AArch64TTIImpl::useReductionIntrinsic(unsigned Opcode, Type *Ty, TTI::ReductionFlags Flags) const { - assert(isa(Ty) && "Expected Ty to be a vector type"); + auto *VTy = cast(Ty); unsigned ScalarBits = Ty->getScalarSizeInBits(); switch (Opcode) { case Instruction::FAdd: @@ -918,10 +919,9 @@ bool AArch64TTIImpl::useReductionIntrinsic(unsigned Opcode, Type *Ty, case Instruction::Mul: return false; case Instruction::Add: - return ScalarBits * Ty->getVectorNumElements() >= 128; + return ScalarBits * VTy->getNumElements() >= 128; case Instruction::ICmp: - return (ScalarBits < 64) && - (ScalarBits * Ty->getVectorNumElements() >= 128); + return (ScalarBits < 64) && (ScalarBits * VTy->getNumElements() >= 128); case Instruction::FCmp: return Flags.NoNaN; default: diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h index 2f07448acc10..a47f87a7bbcf 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h @@ -153,7 +153,7 @@ public: if (!isa(DataType) || !ST->hasSVE()) return false; - Type *Ty = DataType->getVectorElementType(); + Type *Ty = cast(DataType)->getElementType(); if (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy()) return true; @@ -180,10 +180,9 @@ public: // can be halved so that each half fits into a register. That's the case if // the element type fits into a register and the number of elements is a // power of 2 > 1. - if (isa(DataType)) { - unsigned NumElements = DataType->getVectorNumElements(); - unsigned EltSize = - DataType->getVectorElementType()->getScalarSizeInBits(); + if (auto *DataTypeVTy = dyn_cast(DataType)) { + unsigned NumElements = DataTypeVTy->getNumElements(); + unsigned EltSize = DataTypeVTy->getElementType()->getScalarSizeInBits(); return NumElements > 1 && isPowerOf2_64(NumElements) && EltSize >= 8 && EltSize <= 128 && isPowerOf2_64(EltSize); }