diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index 3d60e205b002..13b2703d5109 100644 --- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -217,7 +217,7 @@ private: bool processLoopMemCpy(MemCpyInst *MCI, const SCEV *BECount); bool processLoopMemSet(MemSetInst *MSI, const SCEV *BECount); - bool processLoopStridedStore(Value *DestPtr, unsigned StoreSize, + bool processLoopStridedStore(Value *DestPtr, const SCEV *StoreSizeSCEV, MaybeAlign StoreAlignment, Value *StoredVal, Instruction *TheStore, SmallPtrSetImpl &Stores, @@ -786,7 +786,8 @@ bool LoopIdiomRecognize::processLoopStores(SmallVectorImpl &SL, bool NegStride = StoreSize == -Stride; - if (processLoopStridedStore(StorePtr, StoreSize, + const SCEV *StoreSizeSCEV = SE->getConstant(BECount->getType(), StoreSize); + if (processLoopStridedStore(StorePtr, StoreSizeSCEV, MaybeAlign(HeadStore->getAlignment()), StoredVal, HeadStore, AdjacentStores, StoreEv, BECount, NegStride)) { @@ -936,9 +937,10 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI, SmallPtrSet MSIs; MSIs.insert(MSI); bool NegStride = SizeInBytes == -Stride; - return processLoopStridedStore( - Pointer, (unsigned)SizeInBytes, MaybeAlign(MSI->getDestAlignment()), - SplatValue, MSI, MSIs, Ev, BECount, NegStride, /*IsLoopMemset=*/true); + return processLoopStridedStore(Pointer, SE->getSCEV(MSI->getLength()), + MaybeAlign(MSI->getDestAlignment()), + SplatValue, MSI, MSIs, Ev, BECount, NegStride, + /*IsLoopMemset=*/true); } /// mayLoopAccessLocation - Return true if the specified loop might access the @@ -946,7 +948,7 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI, /// argument specifies what the verboten forms of access are (read or write). static bool mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L, - const SCEV *BECount, unsigned StoreSize, + const SCEV *BECount, const SCEV *StoreSizeSCEV, AliasAnalysis &AA, SmallPtrSetImpl &IgnoredStores) { // Get the location that may be stored across the loop. Since the access is @@ -956,9 +958,11 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L, // If the loop iterates a fixed number of times, we can refine the access size // to be exactly the size of the memset, which is (BECount+1)*StoreSize - if (const SCEVConstant *BECst = dyn_cast(BECount)) + const SCEVConstant *BECst = dyn_cast(BECount); + const SCEVConstant *ConstSize = dyn_cast(StoreSizeSCEV); + if (BECst && ConstSize) AccessSize = LocationSize::precise((BECst->getValue()->getZExtValue() + 1) * - StoreSize); + ConstSize->getValue()->getZExtValue()); // TODO: For this to be really effective, we have to dive into the pointer // operand in the store. Store to &A[i] of 100 will always return may alias @@ -973,7 +977,6 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L, isModOrRefSet( intersectModRef(AA.getModRefInfo(&I, StoreLoc), Access))) return true; - return false; } @@ -981,15 +984,46 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L, // we're trying to memset. Therefore, we need to recompute the base pointer, // which is just Start - BECount*Size. static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount, - Type *IntPtr, unsigned StoreSize, + Type *IntPtr, const SCEV *StoreSizeSCEV, ScalarEvolution *SE) { const SCEV *Index = SE->getTruncateOrZeroExtend(BECount, IntPtr); - if (StoreSize != 1) - Index = SE->getMulExpr(Index, SE->getConstant(IntPtr, StoreSize), + if (!StoreSizeSCEV->isOne()) { + // index = back edge count * store size + Index = SE->getMulExpr(Index, + SE->getTruncateOrZeroExtend(StoreSizeSCEV, IntPtr), SCEV::FlagNUW); + } + // base pointer = start - index * store size return SE->getMinusSCEV(Start, Index); } +/// Compute trip count from the backedge taken count. +static const SCEV *getTripCount(const SCEV *BECount, Type *IntPtr, + Loop *CurLoop, const DataLayout *DL, + ScalarEvolution *SE) { + const SCEV *TripCountS = nullptr; + // The # stored bytes is (BECount+1). Expand the trip count out to + // pointer size if it isn't already. + // + // If we're going to need to zero extend the BE count, check if we can add + // one to it prior to zero extending without overflow. Provided this is safe, + // it allows better simplification of the +1. + if (DL->getTypeSizeInBits(BECount->getType()) < + DL->getTypeSizeInBits(IntPtr) && + SE->isLoopEntryGuardedByCond( + CurLoop, ICmpInst::ICMP_NE, BECount, + SE->getNegativeSCEV(SE->getOne(BECount->getType())))) { + TripCountS = SE->getZeroExtendExpr( + SE->getAddExpr(BECount, SE->getOne(BECount->getType()), SCEV::FlagNUW), + IntPtr); + } else { + TripCountS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr), + SE->getOne(IntPtr), SCEV::FlagNUW); + } + + return TripCountS; +} + /// Compute the number of bytes as a SCEV from the backedge taken count. /// /// This also maps the SCEV into the provided type and tries to handle the @@ -997,38 +1031,31 @@ static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount, static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr, unsigned StoreSize, Loop *CurLoop, const DataLayout *DL, ScalarEvolution *SE) { - const SCEV *NumBytesS; - // The # stored bytes is (BECount+1)*Size. Expand the trip count out to - // pointer size if it isn't already. - // - // If we're going to need to zero extend the BE count, check if we can add - // one to it prior to zero extending without overflow. Provided this is safe, - // it allows better simplification of the +1. - if (DL->getTypeSizeInBits(BECount->getType()).getFixedSize() < - DL->getTypeSizeInBits(IntPtr).getFixedSize() && - SE->isLoopEntryGuardedByCond( - CurLoop, ICmpInst::ICMP_NE, BECount, - SE->getNegativeSCEV(SE->getOne(BECount->getType())))) { - NumBytesS = SE->getZeroExtendExpr( - SE->getAddExpr(BECount, SE->getOne(BECount->getType()), SCEV::FlagNUW), - IntPtr); - } else { - NumBytesS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr), - SE->getOne(IntPtr), SCEV::FlagNUW); - } + const SCEV *TripCountSCEV = getTripCount(BECount, IntPtr, CurLoop, DL, SE); // And scale it based on the store size. if (StoreSize != 1) { - NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtr, StoreSize), - SCEV::FlagNUW); + return SE->getMulExpr(TripCountSCEV, SE->getConstant(IntPtr, StoreSize), + SCEV::FlagNUW); } - return NumBytesS; + return TripCountSCEV; +} + +/// getNumBytes that takes StoreSize as a SCEV +static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr, + const SCEV *StoreSizeSCEV, Loop *CurLoop, + const DataLayout *DL, ScalarEvolution *SE) { + const SCEV *TripCountSCEV = getTripCount(BECount, IntPtr, CurLoop, DL, SE); + + return SE->getMulExpr(TripCountSCEV, + SE->getTruncateOrZeroExtend(StoreSizeSCEV, IntPtr), + SCEV::FlagNUW); } /// processLoopStridedStore - We see a strided store of some value. If we can /// transform this into a memset or memset_pattern in the loop preheader, do so. bool LoopIdiomRecognize::processLoopStridedStore( - Value *DestPtr, unsigned StoreSize, MaybeAlign StoreAlignment, + Value *DestPtr, const SCEV *StoreSizeSCEV, MaybeAlign StoreAlignment, Value *StoredVal, Instruction *TheStore, SmallPtrSetImpl &Stores, const SCEVAddRecExpr *Ev, const SCEV *BECount, bool NegStride, bool IsLoopMemset) { @@ -1057,7 +1084,7 @@ bool LoopIdiomRecognize::processLoopStridedStore( const SCEV *Start = Ev->getStart(); // Handle negative strided loops. if (NegStride) - Start = getStartForNegStride(Start, BECount, IntIdxTy, StoreSize, SE); + Start = getStartForNegStride(Start, BECount, IntIdxTy, StoreSizeSCEV, SE); // TODO: ideally we should still be able to generate memset if SCEV expander // is taught to generate the dependencies at the latest point. @@ -1082,7 +1109,7 @@ bool LoopIdiomRecognize::processLoopStridedStore( Changed = true; if (mayLoopAccessLocation(BasePtr, ModRefInfo::ModRef, CurLoop, BECount, - StoreSize, *AA, Stores)) + StoreSizeSCEV, *AA, Stores)) return Changed; if (avoidLIRForMultiBlockLoop(/*IsMemset=*/true, IsLoopMemset)) @@ -1091,7 +1118,7 @@ bool LoopIdiomRecognize::processLoopStridedStore( // Okay, everything looks good, insert the memset. const SCEV *NumBytesS = - getNumBytes(BECount, IntIdxTy, StoreSize, CurLoop, DL, SE); + getNumBytes(BECount, IntIdxTy, StoreSizeSCEV, CurLoop, DL, SE); // TODO: ideally we should still be able to generate memset if SCEV expander // is taught to generate the dependencies at the latest point. @@ -1215,9 +1242,11 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad( APInt Stride = getStoreStride(StoreEv); bool NegStride = StoreSize == -Stride; + const SCEV *StoreSizeSCEV = SE->getConstant(BECount->getType(), StoreSize); // Handle negative strided loops. if (NegStride) - StrStart = getStartForNegStride(StrStart, BECount, IntIdxTy, StoreSize, SE); + StrStart = + getStartForNegStride(StrStart, BECount, IntIdxTy, StoreSizeSCEV, SE); // Okay, we have a strided store "p[i]" of a loaded value. We can turn // this into a memcpy in the loop preheader now if we want. However, this @@ -1245,11 +1274,11 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad( bool UseMemMove = mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop, BECount, - StoreSize, *AA, Stores); + StoreSizeSCEV, *AA, Stores); if (UseMemMove) { Stores.insert(TheLoad); if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop, - BECount, StoreSize, *AA, Stores)) { + BECount, StoreSizeSCEV, *AA, Stores)) { ORE.emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "LoopMayAccessStore", TheStore) @@ -1268,7 +1297,8 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad( // Handle negative strided loops. if (NegStride) - LdStart = getStartForNegStride(LdStart, BECount, IntIdxTy, StoreSize, SE); + LdStart = + getStartForNegStride(LdStart, BECount, IntIdxTy, StoreSizeSCEV, SE); // For a memcpy, we have to make sure that the input array is not being // mutated by the loop. @@ -1280,7 +1310,7 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad( if (IsMemCpy) Stores.erase(TheStore); if (mayLoopAccessLocation(LoadBasePtr, ModRefInfo::Mod, CurLoop, BECount, - StoreSize, *AA, Stores)) { + StoreSizeSCEV, *AA, Stores)) { ORE.emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "LoopMayAccessLoad", TheLoad) << ore::NV("Inst", InstRemark) << " in "