diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h index 6ea6d2361eba..c6637709ff7b 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -1042,6 +1042,9 @@ public: /// \return True if prefetching should also be done for writes. bool enableWritePrefetching() const; + /// \return if target want to issue a prefetch in address space \p AS. + bool shouldPrefetchAddressSpace(unsigned AS) const; + /// \return The maximum interleave factor that any transform should try to /// perform for this target. This number depends on the level of parallelism /// and the number of execution units in the CPU. @@ -1705,6 +1708,9 @@ public: /// \return True if prefetching should also be done for writes. virtual bool enableWritePrefetching() const = 0; + /// \return if target want to issue a prefetch in address space \p AS. + virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0; + virtual unsigned getMaxInterleaveFactor(unsigned VF) = 0; virtual InstructionCost getArithmeticInstrCost( unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind, @@ -2231,6 +2237,11 @@ public: return Impl.enableWritePrefetching(); } + /// \return if target want to issue a prefetch in address space \p AS. + bool shouldPrefetchAddressSpace(unsigned AS) const override { + return Impl.shouldPrefetchAddressSpace(AS); + } + unsigned getMaxInterleaveFactor(unsigned VF) override { return Impl.getMaxInterleaveFactor(VF); } diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h index 1a75cb35549e..eb1e688735d6 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -475,6 +475,7 @@ public: } unsigned getMaxPrefetchIterationsAhead() const { return UINT_MAX; } bool enableWritePrefetching() const { return false; } + bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; } unsigned getMaxInterleaveFactor(unsigned VF) const { return 1; } diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h index c35a9e878613..2dc63917ea5e 100644 --- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h +++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h @@ -683,6 +683,10 @@ public: return getST()->enableWritePrefetching(); } + virtual bool shouldPrefetchAddressSpace(unsigned AS) const { + return getST()->shouldPrefetchAddressSpace(AS); + } + /// @} /// \name Vector TTI Implementations diff --git a/llvm/include/llvm/MC/MCSubtargetInfo.h b/llvm/include/llvm/MC/MCSubtargetInfo.h index e1f0a86141e3..0f33d3b6a239 100644 --- a/llvm/include/llvm/MC/MCSubtargetInfo.h +++ b/llvm/include/llvm/MC/MCSubtargetInfo.h @@ -282,6 +282,9 @@ public: unsigned NumStridedMemAccesses, unsigned NumPrefetches, bool HasCall) const; + + /// \return if target want to issue a prefetch in address space \p AS. + virtual bool shouldPrefetchAddressSpace(unsigned AS) const; }; } // end namespace llvm diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index cfa6e3a97626..afd24950667d 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -704,6 +704,10 @@ bool TargetTransformInfo::enableWritePrefetching() const { return TTIImpl->enableWritePrefetching(); } +bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const { + return TTIImpl->shouldPrefetchAddressSpace(AS); +} + unsigned TargetTransformInfo::getMaxInterleaveFactor(unsigned VF) const { return TTIImpl->getMaxInterleaveFactor(VF); } diff --git a/llvm/lib/MC/MCSubtargetInfo.cpp b/llvm/lib/MC/MCSubtargetInfo.cpp index 33971e5dc171..defb1436146f 100644 --- a/llvm/lib/MC/MCSubtargetInfo.cpp +++ b/llvm/lib/MC/MCSubtargetInfo.cpp @@ -366,3 +366,7 @@ unsigned MCSubtargetInfo::getMinPrefetchStride(unsigned NumMemAccesses, bool HasCall) const { return 1; } + +bool MCSubtargetInfo::shouldPrefetchAddressSpace(unsigned AS) const { + return !AS; +} diff --git a/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp b/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp index 013a119c5096..7c2770979a90 100644 --- a/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp +++ b/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp @@ -338,7 +338,7 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) { } else continue; unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace(); - if (PtrAddrSpace) + if (!TTI->shouldPrefetchAddressSpace(PtrAddrSpace)) continue; NumMemAccesses++; if (L->isLoopInvariant(PtrValue)) @@ -398,7 +398,8 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) { if (!SCEVE.isSafeToExpand(NextLSCEV)) continue; - Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), 0/*PtrAddrSpace*/); + unsigned PtrAddrSpace = NextLSCEV->getType()->getPointerAddressSpace(); + Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), PtrAddrSpace); Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, P.InsertPt); IRBuilder<> Builder(P.InsertPt);