forked from OSchip/llvm-project
[NFCI] [LoopIdiom] Let processLoopStridedStore take StoreSize as SCEV instead of unsigned
Letting it take SCEV allows further modification on the function to optimize if the StoreSize / Stride is runtime determined. This is a precursor to D107353. The big picture is to let LoopIdiom deal with runtime-determined sizes. Reviewed By: Whitney, lebedev.ri Differential Revision: https://reviews.llvm.org/D104595
This commit is contained in:
parent
9c3345ad10
commit
26aa1bbe97
|
@ -217,7 +217,7 @@ private:
|
||||||
bool processLoopMemCpy(MemCpyInst *MCI, const SCEV *BECount);
|
bool processLoopMemCpy(MemCpyInst *MCI, const SCEV *BECount);
|
||||||
bool processLoopMemSet(MemSetInst *MSI, const SCEV *BECount);
|
bool processLoopMemSet(MemSetInst *MSI, const SCEV *BECount);
|
||||||
|
|
||||||
bool processLoopStridedStore(Value *DestPtr, unsigned StoreSize,
|
bool processLoopStridedStore(Value *DestPtr, const SCEV *StoreSizeSCEV,
|
||||||
MaybeAlign StoreAlignment, Value *StoredVal,
|
MaybeAlign StoreAlignment, Value *StoredVal,
|
||||||
Instruction *TheStore,
|
Instruction *TheStore,
|
||||||
SmallPtrSetImpl<Instruction *> &Stores,
|
SmallPtrSetImpl<Instruction *> &Stores,
|
||||||
|
@ -786,7 +786,8 @@ bool LoopIdiomRecognize::processLoopStores(SmallVectorImpl<StoreInst *> &SL,
|
||||||
|
|
||||||
bool NegStride = StoreSize == -Stride;
|
bool NegStride = StoreSize == -Stride;
|
||||||
|
|
||||||
if (processLoopStridedStore(StorePtr, StoreSize,
|
const SCEV *StoreSizeSCEV = SE->getConstant(BECount->getType(), StoreSize);
|
||||||
|
if (processLoopStridedStore(StorePtr, StoreSizeSCEV,
|
||||||
MaybeAlign(HeadStore->getAlignment()),
|
MaybeAlign(HeadStore->getAlignment()),
|
||||||
StoredVal, HeadStore, AdjacentStores, StoreEv,
|
StoredVal, HeadStore, AdjacentStores, StoreEv,
|
||||||
BECount, NegStride)) {
|
BECount, NegStride)) {
|
||||||
|
@ -936,9 +937,10 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
|
||||||
SmallPtrSet<Instruction *, 1> MSIs;
|
SmallPtrSet<Instruction *, 1> MSIs;
|
||||||
MSIs.insert(MSI);
|
MSIs.insert(MSI);
|
||||||
bool NegStride = SizeInBytes == -Stride;
|
bool NegStride = SizeInBytes == -Stride;
|
||||||
return processLoopStridedStore(
|
return processLoopStridedStore(Pointer, SE->getSCEV(MSI->getLength()),
|
||||||
Pointer, (unsigned)SizeInBytes, MaybeAlign(MSI->getDestAlignment()),
|
MaybeAlign(MSI->getDestAlignment()),
|
||||||
SplatValue, MSI, MSIs, Ev, BECount, NegStride, /*IsLoopMemset=*/true);
|
SplatValue, MSI, MSIs, Ev, BECount, NegStride,
|
||||||
|
/*IsLoopMemset=*/true);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// mayLoopAccessLocation - Return true if the specified loop might access the
|
/// mayLoopAccessLocation - Return true if the specified loop might access the
|
||||||
|
@ -946,7 +948,7 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
|
||||||
/// argument specifies what the verboten forms of access are (read or write).
|
/// argument specifies what the verboten forms of access are (read or write).
|
||||||
static bool
|
static bool
|
||||||
mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
|
mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
|
||||||
const SCEV *BECount, unsigned StoreSize,
|
const SCEV *BECount, const SCEV *StoreSizeSCEV,
|
||||||
AliasAnalysis &AA,
|
AliasAnalysis &AA,
|
||||||
SmallPtrSetImpl<Instruction *> &IgnoredStores) {
|
SmallPtrSetImpl<Instruction *> &IgnoredStores) {
|
||||||
// Get the location that may be stored across the loop. Since the access is
|
// Get the location that may be stored across the loop. Since the access is
|
||||||
|
@ -956,9 +958,11 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
|
||||||
|
|
||||||
// If the loop iterates a fixed number of times, we can refine the access size
|
// If the loop iterates a fixed number of times, we can refine the access size
|
||||||
// to be exactly the size of the memset, which is (BECount+1)*StoreSize
|
// to be exactly the size of the memset, which is (BECount+1)*StoreSize
|
||||||
if (const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount))
|
const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount);
|
||||||
|
const SCEVConstant *ConstSize = dyn_cast<SCEVConstant>(StoreSizeSCEV);
|
||||||
|
if (BECst && ConstSize)
|
||||||
AccessSize = LocationSize::precise((BECst->getValue()->getZExtValue() + 1) *
|
AccessSize = LocationSize::precise((BECst->getValue()->getZExtValue() + 1) *
|
||||||
StoreSize);
|
ConstSize->getValue()->getZExtValue());
|
||||||
|
|
||||||
// TODO: For this to be really effective, we have to dive into the pointer
|
// TODO: For this to be really effective, we have to dive into the pointer
|
||||||
// operand in the store. Store to &A[i] of 100 will always return may alias
|
// operand in the store. Store to &A[i] of 100 will always return may alias
|
||||||
|
@ -973,7 +977,6 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
|
||||||
isModOrRefSet(
|
isModOrRefSet(
|
||||||
intersectModRef(AA.getModRefInfo(&I, StoreLoc), Access)))
|
intersectModRef(AA.getModRefInfo(&I, StoreLoc), Access)))
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -981,15 +984,46 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
|
||||||
// we're trying to memset. Therefore, we need to recompute the base pointer,
|
// we're trying to memset. Therefore, we need to recompute the base pointer,
|
||||||
// which is just Start - BECount*Size.
|
// which is just Start - BECount*Size.
|
||||||
static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount,
|
static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount,
|
||||||
Type *IntPtr, unsigned StoreSize,
|
Type *IntPtr, const SCEV *StoreSizeSCEV,
|
||||||
ScalarEvolution *SE) {
|
ScalarEvolution *SE) {
|
||||||
const SCEV *Index = SE->getTruncateOrZeroExtend(BECount, IntPtr);
|
const SCEV *Index = SE->getTruncateOrZeroExtend(BECount, IntPtr);
|
||||||
if (StoreSize != 1)
|
if (!StoreSizeSCEV->isOne()) {
|
||||||
Index = SE->getMulExpr(Index, SE->getConstant(IntPtr, StoreSize),
|
// index = back edge count * store size
|
||||||
|
Index = SE->getMulExpr(Index,
|
||||||
|
SE->getTruncateOrZeroExtend(StoreSizeSCEV, IntPtr),
|
||||||
SCEV::FlagNUW);
|
SCEV::FlagNUW);
|
||||||
|
}
|
||||||
|
// base pointer = start - index * store size
|
||||||
return SE->getMinusSCEV(Start, Index);
|
return SE->getMinusSCEV(Start, Index);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Compute trip count from the backedge taken count.
|
||||||
|
static const SCEV *getTripCount(const SCEV *BECount, Type *IntPtr,
|
||||||
|
Loop *CurLoop, const DataLayout *DL,
|
||||||
|
ScalarEvolution *SE) {
|
||||||
|
const SCEV *TripCountS = nullptr;
|
||||||
|
// The # stored bytes is (BECount+1). Expand the trip count out to
|
||||||
|
// pointer size if it isn't already.
|
||||||
|
//
|
||||||
|
// If we're going to need to zero extend the BE count, check if we can add
|
||||||
|
// one to it prior to zero extending without overflow. Provided this is safe,
|
||||||
|
// it allows better simplification of the +1.
|
||||||
|
if (DL->getTypeSizeInBits(BECount->getType()) <
|
||||||
|
DL->getTypeSizeInBits(IntPtr) &&
|
||||||
|
SE->isLoopEntryGuardedByCond(
|
||||||
|
CurLoop, ICmpInst::ICMP_NE, BECount,
|
||||||
|
SE->getNegativeSCEV(SE->getOne(BECount->getType())))) {
|
||||||
|
TripCountS = SE->getZeroExtendExpr(
|
||||||
|
SE->getAddExpr(BECount, SE->getOne(BECount->getType()), SCEV::FlagNUW),
|
||||||
|
IntPtr);
|
||||||
|
} else {
|
||||||
|
TripCountS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr),
|
||||||
|
SE->getOne(IntPtr), SCEV::FlagNUW);
|
||||||
|
}
|
||||||
|
|
||||||
|
return TripCountS;
|
||||||
|
}
|
||||||
|
|
||||||
/// Compute the number of bytes as a SCEV from the backedge taken count.
|
/// Compute the number of bytes as a SCEV from the backedge taken count.
|
||||||
///
|
///
|
||||||
/// This also maps the SCEV into the provided type and tries to handle the
|
/// This also maps the SCEV into the provided type and tries to handle the
|
||||||
|
@ -997,38 +1031,31 @@ static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount,
|
||||||
static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
|
static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
|
||||||
unsigned StoreSize, Loop *CurLoop,
|
unsigned StoreSize, Loop *CurLoop,
|
||||||
const DataLayout *DL, ScalarEvolution *SE) {
|
const DataLayout *DL, ScalarEvolution *SE) {
|
||||||
const SCEV *NumBytesS;
|
const SCEV *TripCountSCEV = getTripCount(BECount, IntPtr, CurLoop, DL, SE);
|
||||||
// The # stored bytes is (BECount+1)*Size. Expand the trip count out to
|
|
||||||
// pointer size if it isn't already.
|
|
||||||
//
|
|
||||||
// If we're going to need to zero extend the BE count, check if we can add
|
|
||||||
// one to it prior to zero extending without overflow. Provided this is safe,
|
|
||||||
// it allows better simplification of the +1.
|
|
||||||
if (DL->getTypeSizeInBits(BECount->getType()).getFixedSize() <
|
|
||||||
DL->getTypeSizeInBits(IntPtr).getFixedSize() &&
|
|
||||||
SE->isLoopEntryGuardedByCond(
|
|
||||||
CurLoop, ICmpInst::ICMP_NE, BECount,
|
|
||||||
SE->getNegativeSCEV(SE->getOne(BECount->getType())))) {
|
|
||||||
NumBytesS = SE->getZeroExtendExpr(
|
|
||||||
SE->getAddExpr(BECount, SE->getOne(BECount->getType()), SCEV::FlagNUW),
|
|
||||||
IntPtr);
|
|
||||||
} else {
|
|
||||||
NumBytesS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr),
|
|
||||||
SE->getOne(IntPtr), SCEV::FlagNUW);
|
|
||||||
}
|
|
||||||
|
|
||||||
// And scale it based on the store size.
|
// And scale it based on the store size.
|
||||||
if (StoreSize != 1) {
|
if (StoreSize != 1) {
|
||||||
NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtr, StoreSize),
|
return SE->getMulExpr(TripCountSCEV, SE->getConstant(IntPtr, StoreSize),
|
||||||
SCEV::FlagNUW);
|
SCEV::FlagNUW);
|
||||||
}
|
}
|
||||||
return NumBytesS;
|
return TripCountSCEV;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// getNumBytes that takes StoreSize as a SCEV
|
||||||
|
static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
|
||||||
|
const SCEV *StoreSizeSCEV, Loop *CurLoop,
|
||||||
|
const DataLayout *DL, ScalarEvolution *SE) {
|
||||||
|
const SCEV *TripCountSCEV = getTripCount(BECount, IntPtr, CurLoop, DL, SE);
|
||||||
|
|
||||||
|
return SE->getMulExpr(TripCountSCEV,
|
||||||
|
SE->getTruncateOrZeroExtend(StoreSizeSCEV, IntPtr),
|
||||||
|
SCEV::FlagNUW);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// processLoopStridedStore - We see a strided store of some value. If we can
|
/// processLoopStridedStore - We see a strided store of some value. If we can
|
||||||
/// transform this into a memset or memset_pattern in the loop preheader, do so.
|
/// transform this into a memset or memset_pattern in the loop preheader, do so.
|
||||||
bool LoopIdiomRecognize::processLoopStridedStore(
|
bool LoopIdiomRecognize::processLoopStridedStore(
|
||||||
Value *DestPtr, unsigned StoreSize, MaybeAlign StoreAlignment,
|
Value *DestPtr, const SCEV *StoreSizeSCEV, MaybeAlign StoreAlignment,
|
||||||
Value *StoredVal, Instruction *TheStore,
|
Value *StoredVal, Instruction *TheStore,
|
||||||
SmallPtrSetImpl<Instruction *> &Stores, const SCEVAddRecExpr *Ev,
|
SmallPtrSetImpl<Instruction *> &Stores, const SCEVAddRecExpr *Ev,
|
||||||
const SCEV *BECount, bool NegStride, bool IsLoopMemset) {
|
const SCEV *BECount, bool NegStride, bool IsLoopMemset) {
|
||||||
|
@ -1057,7 +1084,7 @@ bool LoopIdiomRecognize::processLoopStridedStore(
|
||||||
const SCEV *Start = Ev->getStart();
|
const SCEV *Start = Ev->getStart();
|
||||||
// Handle negative strided loops.
|
// Handle negative strided loops.
|
||||||
if (NegStride)
|
if (NegStride)
|
||||||
Start = getStartForNegStride(Start, BECount, IntIdxTy, StoreSize, SE);
|
Start = getStartForNegStride(Start, BECount, IntIdxTy, StoreSizeSCEV, SE);
|
||||||
|
|
||||||
// TODO: ideally we should still be able to generate memset if SCEV expander
|
// TODO: ideally we should still be able to generate memset if SCEV expander
|
||||||
// is taught to generate the dependencies at the latest point.
|
// is taught to generate the dependencies at the latest point.
|
||||||
|
@ -1082,7 +1109,7 @@ bool LoopIdiomRecognize::processLoopStridedStore(
|
||||||
Changed = true;
|
Changed = true;
|
||||||
|
|
||||||
if (mayLoopAccessLocation(BasePtr, ModRefInfo::ModRef, CurLoop, BECount,
|
if (mayLoopAccessLocation(BasePtr, ModRefInfo::ModRef, CurLoop, BECount,
|
||||||
StoreSize, *AA, Stores))
|
StoreSizeSCEV, *AA, Stores))
|
||||||
return Changed;
|
return Changed;
|
||||||
|
|
||||||
if (avoidLIRForMultiBlockLoop(/*IsMemset=*/true, IsLoopMemset))
|
if (avoidLIRForMultiBlockLoop(/*IsMemset=*/true, IsLoopMemset))
|
||||||
|
@ -1091,7 +1118,7 @@ bool LoopIdiomRecognize::processLoopStridedStore(
|
||||||
// Okay, everything looks good, insert the memset.
|
// Okay, everything looks good, insert the memset.
|
||||||
|
|
||||||
const SCEV *NumBytesS =
|
const SCEV *NumBytesS =
|
||||||
getNumBytes(BECount, IntIdxTy, StoreSize, CurLoop, DL, SE);
|
getNumBytes(BECount, IntIdxTy, StoreSizeSCEV, CurLoop, DL, SE);
|
||||||
|
|
||||||
// TODO: ideally we should still be able to generate memset if SCEV expander
|
// TODO: ideally we should still be able to generate memset if SCEV expander
|
||||||
// is taught to generate the dependencies at the latest point.
|
// is taught to generate the dependencies at the latest point.
|
||||||
|
@ -1215,9 +1242,11 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(
|
||||||
APInt Stride = getStoreStride(StoreEv);
|
APInt Stride = getStoreStride(StoreEv);
|
||||||
bool NegStride = StoreSize == -Stride;
|
bool NegStride = StoreSize == -Stride;
|
||||||
|
|
||||||
|
const SCEV *StoreSizeSCEV = SE->getConstant(BECount->getType(), StoreSize);
|
||||||
// Handle negative strided loops.
|
// Handle negative strided loops.
|
||||||
if (NegStride)
|
if (NegStride)
|
||||||
StrStart = getStartForNegStride(StrStart, BECount, IntIdxTy, StoreSize, SE);
|
StrStart =
|
||||||
|
getStartForNegStride(StrStart, BECount, IntIdxTy, StoreSizeSCEV, SE);
|
||||||
|
|
||||||
// Okay, we have a strided store "p[i]" of a loaded value. We can turn
|
// Okay, we have a strided store "p[i]" of a loaded value. We can turn
|
||||||
// this into a memcpy in the loop preheader now if we want. However, this
|
// this into a memcpy in the loop preheader now if we want. However, this
|
||||||
|
@ -1245,11 +1274,11 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(
|
||||||
|
|
||||||
bool UseMemMove =
|
bool UseMemMove =
|
||||||
mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop, BECount,
|
mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop, BECount,
|
||||||
StoreSize, *AA, Stores);
|
StoreSizeSCEV, *AA, Stores);
|
||||||
if (UseMemMove) {
|
if (UseMemMove) {
|
||||||
Stores.insert(TheLoad);
|
Stores.insert(TheLoad);
|
||||||
if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop,
|
if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop,
|
||||||
BECount, StoreSize, *AA, Stores)) {
|
BECount, StoreSizeSCEV, *AA, Stores)) {
|
||||||
ORE.emit([&]() {
|
ORE.emit([&]() {
|
||||||
return OptimizationRemarkMissed(DEBUG_TYPE, "LoopMayAccessStore",
|
return OptimizationRemarkMissed(DEBUG_TYPE, "LoopMayAccessStore",
|
||||||
TheStore)
|
TheStore)
|
||||||
|
@ -1268,7 +1297,8 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(
|
||||||
|
|
||||||
// Handle negative strided loops.
|
// Handle negative strided loops.
|
||||||
if (NegStride)
|
if (NegStride)
|
||||||
LdStart = getStartForNegStride(LdStart, BECount, IntIdxTy, StoreSize, SE);
|
LdStart =
|
||||||
|
getStartForNegStride(LdStart, BECount, IntIdxTy, StoreSizeSCEV, SE);
|
||||||
|
|
||||||
// For a memcpy, we have to make sure that the input array is not being
|
// For a memcpy, we have to make sure that the input array is not being
|
||||||
// mutated by the loop.
|
// mutated by the loop.
|
||||||
|
@ -1280,7 +1310,7 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(
|
||||||
if (IsMemCpy)
|
if (IsMemCpy)
|
||||||
Stores.erase(TheStore);
|
Stores.erase(TheStore);
|
||||||
if (mayLoopAccessLocation(LoadBasePtr, ModRefInfo::Mod, CurLoop, BECount,
|
if (mayLoopAccessLocation(LoadBasePtr, ModRefInfo::Mod, CurLoop, BECount,
|
||||||
StoreSize, *AA, Stores)) {
|
StoreSizeSCEV, *AA, Stores)) {
|
||||||
ORE.emit([&]() {
|
ORE.emit([&]() {
|
||||||
return OptimizationRemarkMissed(DEBUG_TYPE, "LoopMayAccessLoad", TheLoad)
|
return OptimizationRemarkMissed(DEBUG_TYPE, "LoopMayAccessLoad", TheLoad)
|
||||||
<< ore::NV("Inst", InstRemark) << " in "
|
<< ore::NV("Inst", InstRemark) << " in "
|
||||||
|
|
Loading…
Reference in New Issue