[SLP] Enhance SLPVectorizer to vectorize different combinations of aggregates

Summary:
Make SLPVectorize to recognize homogeneous aggregates like
`{<2 x float>, <2 x float>}`, `{{float, float}, {float, float}}`,
`[2 x {float, float}]` and so on.
It's a follow-up of https://reviews.llvm.org/D70068.
Merged `findBuildVector()` and `findBuildAggregate()` to
one `findBuildAggregate()` function making it recursive
to recognize multidimensional aggregates. Aggregates required
to be homogeneous.

Reviewers: RKSimon, ABataev, dtemirbulatov, spatel, vporpo

Subscribers: hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D70587
This commit is contained in:
Anton Afanasyev 2019-11-21 18:41:52 +03:00
parent c094e7dc4b
commit a315519c17
2 changed files with 96 additions and 109 deletions

View File

@ -629,9 +629,10 @@ public:
return MinVecRegSize;
}
/// Check if ArrayType or StructType is isomorphic to some VectorType.
/// Accepts homogeneous aggregate of vectors like
/// { <2 x float>, <2 x float> }
/// Check if homogeneous aggregate is isomorphic to some VectorType.
/// Accepts homogeneous multidimensional aggregate of scalars/vectors like
/// {[4 x i16], [4 x i16]}, { <2 x float>, <2 x float> },
/// {{{i16, i16}, {i16, i16}}, {{i16, i16}, {i16, i16}}} and so on.
///
/// \returns number of elements in vector if isomorphism exists, 0 otherwise.
unsigned canMapToVector(Type *T, const DataLayout &DL) const;
@ -3088,20 +3089,22 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
}
unsigned BoUpSLP::canMapToVector(Type *T, const DataLayout &DL) const {
unsigned N;
Type *EltTy;
auto *ST = dyn_cast<StructType>(T);
if (ST) {
N = ST->getNumElements();
EltTy = *ST->element_begin();
} else {
N = cast<ArrayType>(T)->getNumElements();
EltTy = cast<ArrayType>(T)->getElementType();
}
unsigned N = 1;
Type *EltTy = T;
if (auto *VT = dyn_cast<VectorType>(EltTy)) {
EltTy = VT->getElementType();
N *= VT->getNumElements();
while (isa<CompositeType>(EltTy)) {
if (auto *ST = dyn_cast<StructType>(EltTy)) {
// Check that struct is homogeneous.
for (const auto *Ty : ST->elements())
if (Ty != *ST->element_begin())
return 0;
N *= ST->getNumElements();
EltTy = *ST->element_begin();
} else {
auto *SeqT = cast<SequentialType>(EltTy);
N *= SeqT->getNumElements();
EltTy = SeqT->getElementType();
}
}
if (!isValidElementType(EltTy))
@ -3109,12 +3112,6 @@ unsigned BoUpSLP::canMapToVector(Type *T, const DataLayout &DL) const {
uint64_t VTSize = DL.getTypeStoreSizeInBits(VectorType::get(EltTy, N));
if (VTSize < MinVecRegSize || VTSize > MaxVecRegSize || VTSize != DL.getTypeStoreSizeInBits(T))
return 0;
if (ST) {
// Check that struct is homogeneous.
for (const auto *Ty : ST->elements())
if (Ty != *ST->element_begin())
return 0;
}
return N;
}
@ -6940,57 +6937,54 @@ private:
/// %rb = insertelement <4 x float> %ra, float %s1, i32 1
/// %rc = insertelement <4 x float> %rb, float %s2, i32 2
/// %rd = insertelement <4 x float> %rc, float %s3, i32 3
/// starting from the last insertelement instruction.
/// starting from the last insertelement or insertvalue instruction.
///
/// Returns true if it matches
static bool findBuildVector(InsertElementInst *LastInsertElem,
TargetTransformInfo *TTI,
SmallVectorImpl<Value *> &BuildVectorOpds,
int &UserCost) {
UserCost = 0;
Value *V = nullptr;
do {
if (auto *CI = dyn_cast<ConstantInt>(LastInsertElem->getOperand(2))) {
UserCost += TTI->getVectorInstrCost(Instruction::InsertElement,
LastInsertElem->getType(),
CI->getZExtValue());
}
BuildVectorOpds.push_back(LastInsertElem->getOperand(1));
V = LastInsertElem->getOperand(0);
if (isa<UndefValue>(V))
break;
LastInsertElem = dyn_cast<InsertElementInst>(V);
if (!LastInsertElem || !LastInsertElem->hasOneUse())
return false;
} while (true);
std::reverse(BuildVectorOpds.begin(), BuildVectorOpds.end());
return true;
}
/// Like findBuildVector, but looks for construction of aggregate.
/// Accepts homegeneous aggregate of vectors like { <2 x float>, <2 x float> }.
/// Also recognize aggregates like {<2 x float>, <2 x float>},
/// {{float, float}, {float, float}}, [2 x {float, float}] and so on.
/// See llvm/test/Transforms/SLPVectorizer/X86/pr42022.ll for examples.
///
/// Assume LastInsertInst is of InsertElementInst or InsertValueInst type.
///
/// \return true if it matches.
static bool findBuildAggregate(InsertValueInst *IV, TargetTransformInfo *TTI,
static bool findBuildAggregate(Value *LastInsertInst, TargetTransformInfo *TTI,
SmallVectorImpl<Value *> &BuildVectorOpds,
int &UserCost) {
assert((isa<InsertElementInst>(LastInsertInst) ||
isa<InsertValueInst>(LastInsertInst)) &&
"Expected insertelement or insertvalue instruction!");
UserCost = 0;
do {
if (auto *IE = dyn_cast<InsertElementInst>(IV->getInsertedValueOperand())) {
Value *InsertedOperand;
if (auto *IE = dyn_cast<InsertElementInst>(LastInsertInst)) {
InsertedOperand = IE->getOperand(1);
LastInsertInst = IE->getOperand(0);
if (auto *CI = dyn_cast<ConstantInt>(IE->getOperand(2))) {
UserCost += TTI->getVectorInstrCost(Instruction::InsertElement,
IE->getType(), CI->getZExtValue());
}
} else {
auto *IV = cast<InsertValueInst>(LastInsertInst);
InsertedOperand = IV->getInsertedValueOperand();
LastInsertInst = IV->getAggregateOperand();
}
if (isa<InsertElementInst>(InsertedOperand) ||
isa<InsertValueInst>(InsertedOperand)) {
int TmpUserCost;
SmallVector<Value *, 4> TmpBuildVectorOpds;
if (!findBuildVector(IE, TTI, TmpBuildVectorOpds, TmpUserCost))
SmallVector<Value *, 8> TmpBuildVectorOpds;
if (!findBuildAggregate(InsertedOperand, TTI, TmpBuildVectorOpds,
TmpUserCost))
return false;
BuildVectorOpds.append(TmpBuildVectorOpds.rbegin(), TmpBuildVectorOpds.rend());
BuildVectorOpds.append(TmpBuildVectorOpds.rbegin(),
TmpBuildVectorOpds.rend());
UserCost += TmpUserCost;
} else {
BuildVectorOpds.push_back(IV->getInsertedValueOperand());
BuildVectorOpds.push_back(InsertedOperand);
}
Value *V = IV->getAggregateOperand();
if (isa<UndefValue>(V))
if (isa<UndefValue>(LastInsertInst))
break;
IV = dyn_cast<InsertValueInst>(V);
if (!IV || !IV->hasOneUse())
if ((!isa<InsertValueInst>(LastInsertInst) &&
!isa<InsertElementInst>(LastInsertInst)) ||
!LastInsertInst->hasOneUse())
return false;
} while (true);
std::reverse(BuildVectorOpds.begin(), BuildVectorOpds.end());
@ -7177,7 +7171,7 @@ bool SLPVectorizerPass::vectorizeInsertElementInst(InsertElementInst *IEI,
BasicBlock *BB, BoUpSLP &R) {
int UserCost;
SmallVector<Value *, 16> BuildVectorOpds;
if (!findBuildVector(IEI, TTI, BuildVectorOpds, UserCost) ||
if (!findBuildAggregate(IEI, TTI, BuildVectorOpds, UserCost) ||
(llvm::all_of(BuildVectorOpds,
[](Value *V) { return isa<ExtractElementInst>(V); }) &&
isShuffle(BuildVectorOpds)))

View File

@ -55,21 +55,20 @@ define { <2 x float>, <2 x float> } @StructOfVectors(float *%Ptr) {
define [2 x %StructTy] @ArrayOfStruct(float *%Ptr) {
; CHECK-LABEL: @ArrayOfStruct(
; CHECK-NEXT: [[GEP0:%.*]] = getelementptr inbounds float, float* [[PTR:%.*]], i64 0
; CHECK-NEXT: [[L0:%.*]] = load float, float* [[GEP0]]
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds float, float* [[PTR]], i64 1
; CHECK-NEXT: [[L1:%.*]] = load float, float* [[GEP1]]
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds float, float* [[PTR]], i64 2
; CHECK-NEXT: [[L2:%.*]] = load float, float* [[GEP2]]
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr inbounds float, float* [[PTR]], i64 3
; CHECK-NEXT: [[L3:%.*]] = load float, float* [[GEP3]]
; CHECK-NEXT: [[FADD0:%.*]] = fadd fast float [[L0]], 1.100000e+01
; CHECK-NEXT: [[FADD1:%.*]] = fadd fast float [[L1]], 1.200000e+01
; CHECK-NEXT: [[FADD2:%.*]] = fadd fast float [[L2]], 1.300000e+01
; CHECK-NEXT: [[FADD3:%.*]] = fadd fast float [[L3]], 1.400000e+01
; CHECK-NEXT: [[STRUCTIN0:%.*]] = insertvalue [[STRUCTTY:%.*]] undef, float [[FADD0]], 0
; CHECK-NEXT: [[STRUCTIN1:%.*]] = insertvalue [[STRUCTTY]] %StructIn0, float [[FADD1]], 1
; CHECK-NEXT: [[STRUCTIN2:%.*]] = insertvalue [[STRUCTTY]] undef, float [[FADD2]], 0
; CHECK-NEXT: [[STRUCTIN3:%.*]] = insertvalue [[STRUCTTY]] %StructIn2, float [[FADD3]], 1
; CHECK-NEXT: [[TMP1:%.*]] = bitcast float* [[GEP0]] to <4 x float>*
; CHECK-NEXT: [[TMP2:%.*]] = load <4 x float>, <4 x float>* [[TMP1]], align 4
; CHECK-NEXT: [[TMP3:%.*]] = fadd fast <4 x float> [[TMP2]], <float 1.100000e+01, float 1.200000e+01, float 1.300000e+01, float 1.400000e+01>
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <4 x float> [[TMP3]], i32 0
; CHECK-NEXT: [[STRUCTIN0:%.*]] = insertvalue [[STRUCTTY:%.*]] undef, float [[TMP4]], 0
; CHECK-NEXT: [[TMP5:%.*]] = extractelement <4 x float> [[TMP3]], i32 1
; CHECK-NEXT: [[STRUCTIN1:%.*]] = insertvalue [[STRUCTTY]] %StructIn0, float [[TMP5]], 1
; CHECK-NEXT: [[TMP6:%.*]] = extractelement <4 x float> [[TMP3]], i32 2
; CHECK-NEXT: [[STRUCTIN2:%.*]] = insertvalue [[STRUCTTY]] undef, float [[TMP6]], 0
; CHECK-NEXT: [[TMP7:%.*]] = extractelement <4 x float> [[TMP3]], i32 3
; CHECK-NEXT: [[STRUCTIN3:%.*]] = insertvalue [[STRUCTTY]] %StructIn2, float [[TMP7]], 1
; CHECK-NEXT: [[RET0:%.*]] = insertvalue [2 x %StructTy] undef, [[STRUCTTY]] %StructIn1, 0
; CHECK-NEXT: [[RET1:%.*]] = insertvalue [2 x %StructTy] [[RET0]], [[STRUCTTY]] %StructIn3, 1
; CHECK-NEXT: ret [2 x %StructTy] [[RET1]]
@ -102,21 +101,20 @@ define [2 x %StructTy] @ArrayOfStruct(float *%Ptr) {
define {%StructTy, %StructTy} @StructOfStruct(float *%Ptr) {
; CHECK-LABEL: @StructOfStruct(
; CHECK-NEXT: [[GEP0:%.*]] = getelementptr inbounds float, float* [[PTR:%.*]], i64 0
; CHECK-NEXT: [[L0:%.*]] = load float, float* [[GEP0]]
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds float, float* [[PTR]], i64 1
; CHECK-NEXT: [[L1:%.*]] = load float, float* [[GEP1]]
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds float, float* [[PTR]], i64 2
; CHECK-NEXT: [[L2:%.*]] = load float, float* [[GEP2]]
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr inbounds float, float* [[PTR]], i64 3
; CHECK-NEXT: [[L3:%.*]] = load float, float* [[GEP3]]
; CHECK-NEXT: [[FADD0:%.*]] = fadd fast float [[L0]], 1.100000e+01
; CHECK-NEXT: [[FADD1:%.*]] = fadd fast float [[L1]], 1.200000e+01
; CHECK-NEXT: [[FADD2:%.*]] = fadd fast float [[L2]], 1.300000e+01
; CHECK-NEXT: [[FADD3:%.*]] = fadd fast float [[L3]], 1.400000e+01
; CHECK-NEXT: [[STRUCTIN0:%.*]] = insertvalue [[STRUCTTY:%.*]] undef, float [[FADD0]], 0
; CHECK-NEXT: [[STRUCTIN1:%.*]] = insertvalue [[STRUCTTY]] %StructIn0, float [[FADD1]], 1
; CHECK-NEXT: [[STRUCTIN2:%.*]] = insertvalue [[STRUCTTY]] undef, float [[FADD2]], 0
; CHECK-NEXT: [[STRUCTIN3:%.*]] = insertvalue [[STRUCTTY]] %StructIn2, float [[FADD3]], 1
; CHECK-NEXT: [[TMP1:%.*]] = bitcast float* [[GEP0]] to <4 x float>*
; CHECK-NEXT: [[TMP2:%.*]] = load <4 x float>, <4 x float>* [[TMP1]], align 4
; CHECK-NEXT: [[TMP3:%.*]] = fadd fast <4 x float> [[TMP2]], <float 1.100000e+01, float 1.200000e+01, float 1.300000e+01, float 1.400000e+01>
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <4 x float> [[TMP3]], i32 0
; CHECK-NEXT: [[STRUCTIN0:%.*]] = insertvalue [[STRUCTTY:%.*]] undef, float [[TMP4]], 0
; CHECK-NEXT: [[TMP5:%.*]] = extractelement <4 x float> [[TMP3]], i32 1
; CHECK-NEXT: [[STRUCTIN1:%.*]] = insertvalue [[STRUCTTY]] %StructIn0, float [[TMP5]], 1
; CHECK-NEXT: [[TMP6:%.*]] = extractelement <4 x float> [[TMP3]], i32 2
; CHECK-NEXT: [[STRUCTIN2:%.*]] = insertvalue [[STRUCTTY]] undef, float [[TMP6]], 0
; CHECK-NEXT: [[TMP7:%.*]] = extractelement <4 x float> [[TMP3]], i32 3
; CHECK-NEXT: [[STRUCTIN3:%.*]] = insertvalue [[STRUCTTY]] %StructIn2, float [[TMP7]], 1
; CHECK-NEXT: [[RET0:%.*]] = insertvalue { [[STRUCTTY]], [[STRUCTTY]] } undef, [[STRUCTTY]] %StructIn1, 0
; CHECK-NEXT: [[RET1:%.*]] = insertvalue { [[STRUCTTY]], [[STRUCTTY]] } [[RET0]], [[STRUCTTY]] %StructIn3, 1
; CHECK-NEXT: ret { [[STRUCTTY]], [[STRUCTTY]] } [[RET1]]
@ -196,37 +194,32 @@ define {%StructTy, float, float} @NonHomogeneousStruct(float *%Ptr) {
define {%Struct2Ty, %Struct2Ty} @StructOfStructOfStruct(i16 *%Ptr) {
; CHECK-LABEL: @StructOfStructOfStruct(
; CHECK-NEXT: [[GEP0:%.*]] = getelementptr inbounds i16, i16* [[PTR:%.*]], i64 0
; CHECK-NEXT: [[L0:%.*]] = load i16, i16* [[GEP0]]
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds i16, i16* [[PTR]], i64 1
; CHECK-NEXT: [[L1:%.*]] = load i16, i16* [[GEP1]]
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds i16, i16* [[PTR]], i64 2
; CHECK-NEXT: [[L2:%.*]] = load i16, i16* [[GEP2]]
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr inbounds i16, i16* [[PTR]], i64 3
; CHECK-NEXT: [[L3:%.*]] = load i16, i16* [[GEP3]]
; CHECK-NEXT: [[GEP4:%.*]] = getelementptr inbounds i16, i16* [[PTR]], i64 4
; CHECK-NEXT: [[L4:%.*]] = load i16, i16* [[GEP4]]
; CHECK-NEXT: [[GEP5:%.*]] = getelementptr inbounds i16, i16* [[PTR]], i64 5
; CHECK-NEXT: [[L5:%.*]] = load i16, i16* [[GEP5]]
; CHECK-NEXT: [[GEP6:%.*]] = getelementptr inbounds i16, i16* [[PTR]], i64 6
; CHECK-NEXT: [[L6:%.*]] = load i16, i16* [[GEP6]]
; CHECK-NEXT: [[GEP7:%.*]] = getelementptr inbounds i16, i16* [[PTR]], i64 7
; CHECK-NEXT: [[L7:%.*]] = load i16, i16* [[GEP7]]
; CHECK-NEXT: [[FADD0:%.*]] = add i16 [[L0]], 1
; CHECK-NEXT: [[FADD1:%.*]] = add i16 [[L1]], 2
; CHECK-NEXT: [[FADD2:%.*]] = add i16 [[L2]], 3
; CHECK-NEXT: [[FADD3:%.*]] = add i16 [[L3]], 4
; CHECK-NEXT: [[FADD4:%.*]] = add i16 [[L4]], 5
; CHECK-NEXT: [[FADD5:%.*]] = add i16 [[L5]], 6
; CHECK-NEXT: [[FADD6:%.*]] = add i16 [[L6]], 7
; CHECK-NEXT: [[FADD7:%.*]] = add i16 [[L7]], 8
; CHECK-NEXT: [[STRUCTIN0:%.*]] = insertvalue [[STRUCT1TY:%.*]] undef, i16 [[FADD0]], 0
; CHECK-NEXT: [[STRUCTIN1:%.*]] = insertvalue [[STRUCT1TY]] %StructIn0, i16 [[FADD1]], 1
; CHECK-NEXT: [[STRUCTIN2:%.*]] = insertvalue [[STRUCT1TY]] undef, i16 [[FADD2]], 0
; CHECK-NEXT: [[STRUCTIN3:%.*]] = insertvalue [[STRUCT1TY]] %StructIn2, i16 [[FADD3]], 1
; CHECK-NEXT: [[STRUCTIN4:%.*]] = insertvalue [[STRUCT1TY]] undef, i16 [[FADD4]], 0
; CHECK-NEXT: [[STRUCTIN5:%.*]] = insertvalue [[STRUCT1TY]] %StructIn4, i16 [[FADD5]], 1
; CHECK-NEXT: [[STRUCTIN6:%.*]] = insertvalue [[STRUCT1TY]] undef, i16 [[FADD6]], 0
; CHECK-NEXT: [[STRUCTIN7:%.*]] = insertvalue [[STRUCT1TY]] %StructIn6, i16 [[FADD7]], 1
; CHECK-NEXT: [[TMP1:%.*]] = bitcast i16* [[GEP0]] to <8 x i16>*
; CHECK-NEXT: [[TMP2:%.*]] = load <8 x i16>, <8 x i16>* [[TMP1]], align 2
; CHECK-NEXT: [[TMP3:%.*]] = add <8 x i16> [[TMP2]], <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <8 x i16> [[TMP3]], i32 0
; CHECK-NEXT: [[STRUCTIN0:%.*]] = insertvalue [[STRUCT1TY:%.*]] undef, i16 [[TMP4]], 0
; CHECK-NEXT: [[TMP5:%.*]] = extractelement <8 x i16> [[TMP3]], i32 1
; CHECK-NEXT: [[STRUCTIN1:%.*]] = insertvalue [[STRUCT1TY]] %StructIn0, i16 [[TMP5]], 1
; CHECK-NEXT: [[TMP6:%.*]] = extractelement <8 x i16> [[TMP3]], i32 2
; CHECK-NEXT: [[STRUCTIN2:%.*]] = insertvalue [[STRUCT1TY]] undef, i16 [[TMP6]], 0
; CHECK-NEXT: [[TMP7:%.*]] = extractelement <8 x i16> [[TMP3]], i32 3
; CHECK-NEXT: [[STRUCTIN3:%.*]] = insertvalue [[STRUCT1TY]] %StructIn2, i16 [[TMP7]], 1
; CHECK-NEXT: [[TMP8:%.*]] = extractelement <8 x i16> [[TMP3]], i32 4
; CHECK-NEXT: [[STRUCTIN4:%.*]] = insertvalue [[STRUCT1TY]] undef, i16 [[TMP8]], 0
; CHECK-NEXT: [[TMP9:%.*]] = extractelement <8 x i16> [[TMP3]], i32 5
; CHECK-NEXT: [[STRUCTIN5:%.*]] = insertvalue [[STRUCT1TY]] %StructIn4, i16 [[TMP9]], 1
; CHECK-NEXT: [[TMP10:%.*]] = extractelement <8 x i16> [[TMP3]], i32 6
; CHECK-NEXT: [[STRUCTIN6:%.*]] = insertvalue [[STRUCT1TY]] undef, i16 [[TMP10]], 0
; CHECK-NEXT: [[TMP11:%.*]] = extractelement <8 x i16> [[TMP3]], i32 7
; CHECK-NEXT: [[STRUCTIN7:%.*]] = insertvalue [[STRUCT1TY]] %StructIn6, i16 [[TMP11]], 1
; CHECK-NEXT: [[STRUCT2IN0:%.*]] = insertvalue [[STRUCT2TY:%.*]] undef, [[STRUCT1TY]] %StructIn1, 0
; CHECK-NEXT: [[STRUCT2IN1:%.*]] = insertvalue [[STRUCT2TY]] %Struct2In0, [[STRUCT1TY]] %StructIn3, 1
; CHECK-NEXT: [[STRUCT2IN2:%.*]] = insertvalue [[STRUCT2TY]] undef, [[STRUCT1TY]] %StructIn5, 0