forked from OSchip/llvm-project
[LoopVectorizer] NFCI: Calculate register usage based on TLI.getTypeLegalizationCost.
This is more accurate than dividing the vector's bit width (the element size multiplied by the element count) by the maximum register size, because it reuses whatever has already been calculated for the legalization of these types. This change is also necessary when calculating register usage for scalable vectors: the legalization of those types cannot be based on the widest register size, because that size does not take the 'vscale' component into account. Reviewed By: SjoerdMeijer. Differential Revision: https://reviews.llvm.org/D91059
This commit is contained in:
parent
91ce6fb5a6
commit
b873aba394
|
@ -708,6 +708,9 @@ public:
|
|||
/// Return true if this type is legal.
|
||||
bool isTypeLegal(Type *Ty) const;
|
||||
|
||||
/// Returns the estimated number of registers required to represent \p Ty.
/// The out-of-line definition shown in this diff obtains the value by
/// calling TTIImpl->getRegUsageForType(Ty).
|
||||
unsigned getRegUsageForType(Type *Ty) const;
|
||||
|
||||
/// Return true if switches should be turned into lookup tables for the
|
||||
/// target.
|
||||
bool shouldBuildLookupTables() const;
|
||||
|
@ -1447,6 +1450,7 @@ public:
|
|||
// Pure virtual declarations (= 0): no body is provided here, so a concrete
// subclass has to override each of them.
virtual bool isProfitableToHoist(Instruction *I) = 0;
|
||||
virtual bool useAA() = 0;
|
||||
virtual bool isTypeLegal(Type *Ty) = 0;
|
||||
// Register-usage query for \p Ty; yields an unsigned count.
virtual unsigned getRegUsageForType(Type *Ty) = 0;
|
||||
virtual bool shouldBuildLookupTables() = 0;
|
||||
virtual bool shouldBuildLookupTablesForConstant(Constant *C) = 0;
|
||||
virtual bool useColdCCForColdCall(Function &F) = 0;
|
||||
|
@ -1807,6 +1811,9 @@ public:
|
|||
}
|
||||
/// Forwards to Impl.useAA() and returns its result unchanged.
bool useAA() override { return Impl.useAA(); }
|
||||
/// Forwards the legality query for \p Ty to Impl.isTypeLegal(Ty).
bool isTypeLegal(Type *Ty) override { return Impl.isTypeLegal(Ty); }
|
||||
/// Forwards the register-usage query for \p Ty to the wrapped Impl object
/// and returns its result unchanged.
unsigned getRegUsageForType(Type *Ty) override {
|
||||
return Impl.getRegUsageForType(Ty);
|
||||
}
|
||||
/// Forwards to Impl.shouldBuildLookupTables() and returns its result
/// unchanged.
bool shouldBuildLookupTables() override {
|
||||
return Impl.shouldBuildLookupTables();
|
||||
}
|
||||
|
|
|
@ -259,6 +259,8 @@ public:
|
|||
|
||||
/// Stub that ignores \p Ty and always returns false.
bool isTypeLegal(Type *Ty) { return false; }
|
||||
|
||||
/// Stub that ignores \p Ty and always reports a usage of one register.
unsigned getRegUsageForType(Type *Ty) { return 1; }
|
||||
|
||||
/// Stub that always returns true.
bool shouldBuildLookupTables() { return true; }
|
||||
/// Stub that ignores \p C and always returns true.
bool shouldBuildLookupTablesForConstant(Constant *C) { return true; }
|
||||
|
||||
|
|
|
@ -297,6 +297,10 @@ public:
|
|||
return getTLI()->isTypeLegal(VT);
|
||||
}
|
||||
|
||||
/// Returns the register usage for \p Ty, taken from the `.first` member of
/// the value that getTLI()->getTypeLegalizationCost(DL, Ty) returns.
/// NOTE(review): the result is converted to `unsigned` on return; the
/// underlying type of `.first` is not visible here — confirm it is never
/// negative.
unsigned getRegUsageForType(Type *Ty) {
|
||||
return getTLI()->getTypeLegalizationCost(DL, Ty).first;
|
||||
}
|
||||
|
||||
int getGEPCost(Type *PointeeType, const Value *Ptr,
|
||||
ArrayRef<const Value *> Operands) {
|
||||
return BaseT::getGEPCost(PointeeType, Ptr, Operands);
|
||||
|
|
|
@ -482,6 +482,10 @@ bool TargetTransformInfo::isTypeLegal(Type *Ty) const {
|
|||
return TTIImpl->isTypeLegal(Ty);
|
||||
}
|
||||
|
||||
// Delegates the register-usage query for Ty to the implementation object
// held in TTIImpl and returns its answer unchanged.
unsigned TargetTransformInfo::getRegUsageForType(Type *Ty) const {
|
||||
return TTIImpl->getRegUsageForType(Ty);
|
||||
}
|
||||
|
||||
// Delegates to the implementation object held in TTIImpl and returns its
// answer unchanged.
bool TargetTransformInfo::shouldBuildLookupTables() const {
|
||||
return TTIImpl->shouldBuildLookupTables();
|
||||
}
|
||||
|
|
|
@ -5793,8 +5793,6 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<ElementCount> VFs) {
|
|||
unsigned MaxSafeDepDist = -1U;
|
||||
if (Legal->getMaxSafeDepDistBytes() != -1U)
|
||||
MaxSafeDepDist = Legal->getMaxSafeDepDistBytes() * 8;
|
||||
unsigned WidestRegister =
|
||||
std::min(TTI.getRegisterBitWidth(true), MaxSafeDepDist);
|
||||
const DataLayout &DL = TheFunction->getParent()->getDataLayout();
|
||||
|
||||
SmallVector<RegisterUsage, 8> RUs(VFs.size());
|
||||
|
@ -5803,13 +5801,10 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<ElementCount> VFs) {
|
|||
LLVM_DEBUG(dbgs() << "LV(REG): Calculating max register usage:\n");
|
||||
|
||||
// A lambda that gets the register usage for the given type and VF.
|
||||
auto GetRegUsage = [&DL, WidestRegister](Type *Ty, ElementCount VF) {
|
||||
auto GetRegUsage = [&DL, &TTI=TTI](Type *Ty, ElementCount VF) {
|
||||
if (Ty->isTokenTy())
|
||||
return 0U;
|
||||
unsigned TypeSize = DL.getTypeSizeInBits(Ty->getScalarType());
|
||||
assert(!VF.isScalable() && "scalable vectors not yet supported.");
|
||||
return std::max<unsigned>(1, VF.getKnownMinValue() * TypeSize /
|
||||
WidestRegister);
|
||||
return TTI.getRegUsageForType(VectorType::get(Ty, VF));
|
||||
};
|
||||
|
||||
for (unsigned int i = 0, s = IdxToInstr.size(); i < s; ++i) {
|
||||
|
|
Loading…
Reference in New Issue