[NFCI] [LoopIdiom] Let processLoopStridedStore take StoreSize as SCEV instead of unsigned

Letting it take SCEV allows further modification on the function to optimize
if the StoreSize / Stride is runtime determined.

This is a preceeding of D107353.
The big picture is to let LoopIdiom deal with runtime-determined sizes.

Reviewed By: Whitney, lebedev.ri

Differential Revision: https://reviews.llvm.org/D104595
This commit is contained in:
eopXD 2021-06-19 14:26:12 +00:00 committed by eopXD
parent 9c3345ad10
commit 26aa1bbe97
1 changed files with 73 additions and 43 deletions

View File

@ -217,7 +217,7 @@ private:
bool processLoopMemCpy(MemCpyInst *MCI, const SCEV *BECount);
bool processLoopMemSet(MemSetInst *MSI, const SCEV *BECount);
bool processLoopStridedStore(Value *DestPtr, unsigned StoreSize,
bool processLoopStridedStore(Value *DestPtr, const SCEV *StoreSizeSCEV,
MaybeAlign StoreAlignment, Value *StoredVal,
Instruction *TheStore,
SmallPtrSetImpl<Instruction *> &Stores,
@ -786,7 +786,8 @@ bool LoopIdiomRecognize::processLoopStores(SmallVectorImpl<StoreInst *> &SL,
bool NegStride = StoreSize == -Stride;
if (processLoopStridedStore(StorePtr, StoreSize,
const SCEV *StoreSizeSCEV = SE->getConstant(BECount->getType(), StoreSize);
if (processLoopStridedStore(StorePtr, StoreSizeSCEV,
MaybeAlign(HeadStore->getAlignment()),
StoredVal, HeadStore, AdjacentStores, StoreEv,
BECount, NegStride)) {
@ -936,9 +937,10 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
SmallPtrSet<Instruction *, 1> MSIs;
MSIs.insert(MSI);
bool NegStride = SizeInBytes == -Stride;
return processLoopStridedStore(
Pointer, (unsigned)SizeInBytes, MaybeAlign(MSI->getDestAlignment()),
SplatValue, MSI, MSIs, Ev, BECount, NegStride, /*IsLoopMemset=*/true);
return processLoopStridedStore(Pointer, SE->getSCEV(MSI->getLength()),
MaybeAlign(MSI->getDestAlignment()),
SplatValue, MSI, MSIs, Ev, BECount, NegStride,
/*IsLoopMemset=*/true);
}
/// mayLoopAccessLocation - Return true if the specified loop might access the
@ -946,7 +948,7 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
/// argument specifies what the verboten forms of access are (read or write).
static bool
mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
const SCEV *BECount, unsigned StoreSize,
const SCEV *BECount, const SCEV *StoreSizeSCEV,
AliasAnalysis &AA,
SmallPtrSetImpl<Instruction *> &IgnoredStores) {
// Get the location that may be stored across the loop. Since the access is
@ -956,9 +958,11 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
// If the loop iterates a fixed number of times, we can refine the access size
// to be exactly the size of the memset, which is (BECount+1)*StoreSize
if (const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount))
const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount);
const SCEVConstant *ConstSize = dyn_cast<SCEVConstant>(StoreSizeSCEV);
if (BECst && ConstSize)
AccessSize = LocationSize::precise((BECst->getValue()->getZExtValue() + 1) *
StoreSize);
ConstSize->getValue()->getZExtValue());
// TODO: For this to be really effective, we have to dive into the pointer
// operand in the store. Store to &A[i] of 100 will always return may alias
@ -973,7 +977,6 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
isModOrRefSet(
intersectModRef(AA.getModRefInfo(&I, StoreLoc), Access)))
return true;
return false;
}
@ -981,15 +984,46 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
// we're trying to memset. Therefore, we need to recompute the base pointer,
// which is just Start - BECount*Size.
static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount,
Type *IntPtr, unsigned StoreSize,
Type *IntPtr, const SCEV *StoreSizeSCEV,
ScalarEvolution *SE) {
const SCEV *Index = SE->getTruncateOrZeroExtend(BECount, IntPtr);
if (StoreSize != 1)
Index = SE->getMulExpr(Index, SE->getConstant(IntPtr, StoreSize),
if (!StoreSizeSCEV->isOne()) {
// index = back edge count * store size
Index = SE->getMulExpr(Index,
SE->getTruncateOrZeroExtend(StoreSizeSCEV, IntPtr),
SCEV::FlagNUW);
}
// base pointer = start - index * store size
return SE->getMinusSCEV(Start, Index);
}
/// Compute trip count from the backedge taken count.
static const SCEV *getTripCount(const SCEV *BECount, Type *IntPtr,
Loop *CurLoop, const DataLayout *DL,
ScalarEvolution *SE) {
const SCEV *TripCountS = nullptr;
// The # stored bytes is (BECount+1). Expand the trip count out to
// pointer size if it isn't already.
//
// If we're going to need to zero extend the BE count, check if we can add
// one to it prior to zero extending without overflow. Provided this is safe,
// it allows better simplification of the +1.
if (DL->getTypeSizeInBits(BECount->getType()) <
DL->getTypeSizeInBits(IntPtr) &&
SE->isLoopEntryGuardedByCond(
CurLoop, ICmpInst::ICMP_NE, BECount,
SE->getNegativeSCEV(SE->getOne(BECount->getType())))) {
TripCountS = SE->getZeroExtendExpr(
SE->getAddExpr(BECount, SE->getOne(BECount->getType()), SCEV::FlagNUW),
IntPtr);
} else {
TripCountS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr),
SE->getOne(IntPtr), SCEV::FlagNUW);
}
return TripCountS;
}
/// Compute the number of bytes as a SCEV from the backedge taken count.
///
/// This also maps the SCEV into the provided type and tries to handle the
@ -997,38 +1031,31 @@ static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount,
static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
unsigned StoreSize, Loop *CurLoop,
const DataLayout *DL, ScalarEvolution *SE) {
const SCEV *NumBytesS;
// The # stored bytes is (BECount+1)*Size. Expand the trip count out to
// pointer size if it isn't already.
//
// If we're going to need to zero extend the BE count, check if we can add
// one to it prior to zero extending without overflow. Provided this is safe,
// it allows better simplification of the +1.
if (DL->getTypeSizeInBits(BECount->getType()).getFixedSize() <
DL->getTypeSizeInBits(IntPtr).getFixedSize() &&
SE->isLoopEntryGuardedByCond(
CurLoop, ICmpInst::ICMP_NE, BECount,
SE->getNegativeSCEV(SE->getOne(BECount->getType())))) {
NumBytesS = SE->getZeroExtendExpr(
SE->getAddExpr(BECount, SE->getOne(BECount->getType()), SCEV::FlagNUW),
IntPtr);
} else {
NumBytesS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr),
SE->getOne(IntPtr), SCEV::FlagNUW);
}
const SCEV *TripCountSCEV = getTripCount(BECount, IntPtr, CurLoop, DL, SE);
// And scale it based on the store size.
if (StoreSize != 1) {
NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtr, StoreSize),
return SE->getMulExpr(TripCountSCEV, SE->getConstant(IntPtr, StoreSize),
SCEV::FlagNUW);
}
return NumBytesS;
return TripCountSCEV;
}
/// getNumBytes that takes StoreSize as a SCEV
static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
const SCEV *StoreSizeSCEV, Loop *CurLoop,
const DataLayout *DL, ScalarEvolution *SE) {
const SCEV *TripCountSCEV = getTripCount(BECount, IntPtr, CurLoop, DL, SE);
return SE->getMulExpr(TripCountSCEV,
SE->getTruncateOrZeroExtend(StoreSizeSCEV, IntPtr),
SCEV::FlagNUW);
}
/// processLoopStridedStore - We see a strided store of some value. If we can
/// transform this into a memset or memset_pattern in the loop preheader, do so.
bool LoopIdiomRecognize::processLoopStridedStore(
Value *DestPtr, unsigned StoreSize, MaybeAlign StoreAlignment,
Value *DestPtr, const SCEV *StoreSizeSCEV, MaybeAlign StoreAlignment,
Value *StoredVal, Instruction *TheStore,
SmallPtrSetImpl<Instruction *> &Stores, const SCEVAddRecExpr *Ev,
const SCEV *BECount, bool NegStride, bool IsLoopMemset) {
@ -1057,7 +1084,7 @@ bool LoopIdiomRecognize::processLoopStridedStore(
const SCEV *Start = Ev->getStart();
// Handle negative strided loops.
if (NegStride)
Start = getStartForNegStride(Start, BECount, IntIdxTy, StoreSize, SE);
Start = getStartForNegStride(Start, BECount, IntIdxTy, StoreSizeSCEV, SE);
// TODO: ideally we should still be able to generate memset if SCEV expander
// is taught to generate the dependencies at the latest point.
@ -1082,7 +1109,7 @@ bool LoopIdiomRecognize::processLoopStridedStore(
Changed = true;
if (mayLoopAccessLocation(BasePtr, ModRefInfo::ModRef, CurLoop, BECount,
StoreSize, *AA, Stores))
StoreSizeSCEV, *AA, Stores))
return Changed;
if (avoidLIRForMultiBlockLoop(/*IsMemset=*/true, IsLoopMemset))
@ -1091,7 +1118,7 @@ bool LoopIdiomRecognize::processLoopStridedStore(
// Okay, everything looks good, insert the memset.
const SCEV *NumBytesS =
getNumBytes(BECount, IntIdxTy, StoreSize, CurLoop, DL, SE);
getNumBytes(BECount, IntIdxTy, StoreSizeSCEV, CurLoop, DL, SE);
// TODO: ideally we should still be able to generate memset if SCEV expander
// is taught to generate the dependencies at the latest point.
@ -1215,9 +1242,11 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(
APInt Stride = getStoreStride(StoreEv);
bool NegStride = StoreSize == -Stride;
const SCEV *StoreSizeSCEV = SE->getConstant(BECount->getType(), StoreSize);
// Handle negative strided loops.
if (NegStride)
StrStart = getStartForNegStride(StrStart, BECount, IntIdxTy, StoreSize, SE);
StrStart =
getStartForNegStride(StrStart, BECount, IntIdxTy, StoreSizeSCEV, SE);
// Okay, we have a strided store "p[i]" of a loaded value. We can turn
// this into a memcpy in the loop preheader now if we want. However, this
@ -1245,11 +1274,11 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(
bool UseMemMove =
mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop, BECount,
StoreSize, *AA, Stores);
StoreSizeSCEV, *AA, Stores);
if (UseMemMove) {
Stores.insert(TheLoad);
if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop,
BECount, StoreSize, *AA, Stores)) {
BECount, StoreSizeSCEV, *AA, Stores)) {
ORE.emit([&]() {
return OptimizationRemarkMissed(DEBUG_TYPE, "LoopMayAccessStore",
TheStore)
@ -1268,7 +1297,8 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(
// Handle negative strided loops.
if (NegStride)
LdStart = getStartForNegStride(LdStart, BECount, IntIdxTy, StoreSize, SE);
LdStart =
getStartForNegStride(LdStart, BECount, IntIdxTy, StoreSizeSCEV, SE);
// For a memcpy, we have to make sure that the input array is not being
// mutated by the loop.
@ -1280,7 +1310,7 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(
if (IsMemCpy)
Stores.erase(TheStore);
if (mayLoopAccessLocation(LoadBasePtr, ModRefInfo::Mod, CurLoop, BECount,
StoreSize, *AA, Stores)) {
StoreSizeSCEV, *AA, Stores)) {
ORE.emit([&]() {
return OptimizationRemarkMissed(DEBUG_TYPE, "LoopMayAccessLoad", TheLoad)
<< ore::NV("Inst", InstRemark) << " in "