forked from OSchip/llvm-project
[X86][AVX512] Detect repeated constant patterns in BUILD_VECTOR suitable for broadcasting.
Check if a build_vector node includes a repeated constant pattern and replace it with a broadcast of that pattern. For example: "build_vector <0, 1, 2, 3, 0, 1, 2, 3>" would be replaced by "broadcast <0, 1, 2, 3>" Differential Revision: https://reviews.llvm.org/D26802 llvm-svn: 288804
This commit is contained in:
parent
e004d4bfc2
commit
86c00b799f
|
@ -6316,8 +6316,47 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
|
|||
return SDValue();
|
||||
}
|
||||
|
||||
/// Attempt to use the vbroadcast instruction to generate a splat value for a
|
||||
/// splat BUILD_VECTOR which uses a single scalar load, or a constant.
|
||||
static Constant *getConstantVector(MVT VT, APInt SplatValue,
|
||||
unsigned SplatBitSize, LLVMContext &C) {
|
||||
unsigned ScalarSize = VT.getScalarSizeInBits();
|
||||
unsigned NumElm = SplatBitSize / ScalarSize;
|
||||
|
||||
SmallVector<Constant *, 32> ConstantVec;
|
||||
for (unsigned i = 0; i < NumElm; i++) {
|
||||
APInt Val = SplatValue.lshr(ScalarSize * i).trunc(ScalarSize);
|
||||
Constant *Const;
|
||||
if (VT.isFloatingPoint()) {
|
||||
assert((ScalarSize == 32 || ScalarSize == 64) &&
|
||||
"Unsupported floating point scalar size");
|
||||
if (ScalarSize == 32)
|
||||
Const = ConstantFP::get(Type::getFloatTy(C), Val.bitsToFloat());
|
||||
else
|
||||
Const = ConstantFP::get(Type::getDoubleTy(C), Val.bitsToDouble());
|
||||
} else
|
||||
Const = Constant::getIntegerValue(Type::getIntNTy(C, ScalarSize), Val);
|
||||
ConstantVec.push_back(Const);
|
||||
}
|
||||
return ConstantVector::get(ArrayRef<Constant *>(ConstantVec));
|
||||
}
|
||||
|
||||
static bool isUseOfShuffle(SDNode *N) {
|
||||
for (auto *U : N->uses()) {
|
||||
if (isTargetShuffle(U->getOpcode()))
|
||||
return true;
|
||||
if (U->getOpcode() == ISD::BITCAST) // Ignore bitcasts
|
||||
return isUseOfShuffle(U);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Attempt to use the vbroadcast instruction to generate a splat value for the
|
||||
/// following cases:
|
||||
/// 1. A splat BUILD_VECTOR which uses:
|
||||
/// a. A single scalar load, or a constant.
|
||||
/// b. Repeated pattern of constants (e.g. <0,1,0,1> or <0,1,2,3,0,1,2,3>).
|
||||
/// 2. A splat shuffle which uses a scalar_to_vector node which comes from
|
||||
/// a scalar load, or a constant.
|
||||
///
|
||||
/// The VBROADCAST node is returned when a pattern is found,
|
||||
/// or SDValue() otherwise.
|
||||
static SDValue LowerVectorBroadcast(BuildVectorSDNode *BVOp, const X86Subtarget &Subtarget,
|
||||
|
@ -6339,8 +6378,82 @@ static SDValue LowerVectorBroadcast(BuildVectorSDNode *BVOp, const X86Subtarget
|
|||
|
||||
// We need a splat of a single value to use broadcast, and it doesn't
|
||||
// make any sense if the value is only in one element of the vector.
|
||||
if (!Ld || (VT.getVectorNumElements() - UndefElements.count()) <= 1)
|
||||
if (!Ld || (VT.getVectorNumElements() - UndefElements.count()) <= 1) {
|
||||
APInt SplatValue, Undef;
|
||||
unsigned SplatBitSize;
|
||||
bool HasUndef;
|
||||
// Check if this is a repeated constant pattern suitable for broadcasting.
|
||||
if (BVOp->isConstantSplat(SplatValue, Undef, SplatBitSize, HasUndef) &&
|
||||
SplatBitSize > VT.getScalarSizeInBits() &&
|
||||
SplatBitSize < VT.getSizeInBits()) {
|
||||
// Avoid replacing with broadcast when it's a use of a shuffle
|
||||
// instruction to preserve the present custom lowering of shuffles.
|
||||
if (isUseOfShuffle(BVOp) || BVOp->hasOneUse())
|
||||
return SDValue();
|
||||
// replace BUILD_VECTOR with broadcast of the repeated constants.
|
||||
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
|
||||
LLVMContext *Ctx = DAG.getContext();
|
||||
MVT PVT = TLI.getPointerTy(DAG.getDataLayout());
|
||||
if (Subtarget.hasAVX()) {
|
||||
if (SplatBitSize <= 64 && Subtarget.hasAVX2() &&
|
||||
!(SplatBitSize == 64 && Subtarget.is32Bit())) {
|
||||
// Splatted value can fit in one INTEGER constant in constant pool.
|
||||
// Load the constant and broadcast it.
|
||||
MVT CVT = MVT::getIntegerVT(SplatBitSize);
|
||||
Type *ScalarTy = Type::getIntNTy(*Ctx, SplatBitSize);
|
||||
Constant *C = Constant::getIntegerValue(ScalarTy, SplatValue);
|
||||
SDValue CP = DAG.getConstantPool(C, PVT);
|
||||
unsigned Repeat = VT.getSizeInBits() / SplatBitSize;
|
||||
|
||||
unsigned Alignment = cast<ConstantPoolSDNode>(CP)->getAlignment();
|
||||
Ld = DAG.getLoad(
|
||||
CVT, dl, DAG.getEntryNode(), CP,
|
||||
MachinePointerInfo::getConstantPool(DAG.getMachineFunction()),
|
||||
Alignment);
|
||||
SDValue Brdcst = DAG.getNode(X86ISD::VBROADCAST, dl,
|
||||
MVT::getVectorVT(CVT, Repeat), Ld);
|
||||
return DAG.getBitcast(VT, Brdcst);
|
||||
} else if (SplatBitSize == 32 || SplatBitSize == 64) {
|
||||
// Splatted value can fit in one FLOAT constant in constant pool.
|
||||
// Load the constant and broadcast it.
|
||||
// AVX have support for 32 and 64 bit broadcast for floats only.
|
||||
// No 64bit integer in 32bit subtarget.
|
||||
MVT CVT = MVT::getFloatingPointVT(SplatBitSize);
|
||||
Constant *C = SplatBitSize == 32
|
||||
? ConstantFP::get(Type::getFloatTy(*Ctx),
|
||||
SplatValue.bitsToFloat())
|
||||
: ConstantFP::get(Type::getDoubleTy(*Ctx),
|
||||
SplatValue.bitsToDouble());
|
||||
SDValue CP = DAG.getConstantPool(C, PVT);
|
||||
unsigned Repeat = VT.getSizeInBits() / SplatBitSize;
|
||||
|
||||
unsigned Alignment = cast<ConstantPoolSDNode>(CP)->getAlignment();
|
||||
Ld = DAG.getLoad(
|
||||
CVT, dl, DAG.getEntryNode(), CP,
|
||||
MachinePointerInfo::getConstantPool(DAG.getMachineFunction()),
|
||||
Alignment);
|
||||
SDValue Brdcst = DAG.getNode(X86ISD::VBROADCAST, dl,
|
||||
MVT::getVectorVT(CVT, Repeat), Ld);
|
||||
return DAG.getBitcast(VT, Brdcst);
|
||||
} else if (SplatBitSize > 64) {
|
||||
// Load the vector of constants and broadcast it.
|
||||
MVT CVT = VT.getScalarType();
|
||||
Constant *VecC = getConstantVector(VT, SplatValue, SplatBitSize,
|
||||
*Ctx);
|
||||
SDValue VCP = DAG.getConstantPool(VecC, PVT);
|
||||
unsigned NumElm = SplatBitSize / VT.getScalarSizeInBits();
|
||||
unsigned Alignment = cast<ConstantPoolSDNode>(VCP)->getAlignment();
|
||||
Ld = DAG.getLoad(
|
||||
MVT::getVectorVT(CVT, NumElm), dl, DAG.getEntryNode(), VCP,
|
||||
MachinePointerInfo::getConstantPool(DAG.getMachineFunction()),
|
||||
Alignment);
|
||||
SDValue Brdcst = DAG.getNode(X86ISD::SUBV_BROADCAST, dl, VT, Ld);
|
||||
return DAG.getBitcast(VT, Brdcst);
|
||||
}
|
||||
}
|
||||
}
|
||||
return SDValue();
|
||||
}
|
||||
|
||||
bool ConstSplatVal =
|
||||
(Ld.getOpcode() == ISD::Constant || Ld.getOpcode() == ISD::ConstantFP);
|
||||
|
|
|
@ -2132,7 +2132,7 @@ define void @avg_v64i8_const(<64 x i8>* %a) {
|
|||
; AVX512F-NEXT: vpmovzxbd {{.*#+}} zmm1 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero,mem[4],zero,zero,zero,mem[5],zero,zero,zero,mem[6],zero,zero,zero,mem[7],zero,zero,zero,mem[8],zero,zero,zero,mem[9],zero,zero,zero,mem[10],zero,zero,zero,mem[11],zero,zero,zero,mem[12],zero,zero,zero,mem[13],zero,zero,zero,mem[14],zero,zero,zero,mem[15],zero,zero,zero
|
||||
; AVX512F-NEXT: vpmovzxbd {{.*#+}} zmm2 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero,mem[4],zero,zero,zero,mem[5],zero,zero,zero,mem[6],zero,zero,zero,mem[7],zero,zero,zero,mem[8],zero,zero,zero,mem[9],zero,zero,zero,mem[10],zero,zero,zero,mem[11],zero,zero,zero,mem[12],zero,zero,zero,mem[13],zero,zero,zero,mem[14],zero,zero,zero,mem[15],zero,zero,zero
|
||||
; AVX512F-NEXT: vpmovzxbd {{.*#+}} zmm3 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero,mem[4],zero,zero,zero,mem[5],zero,zero,zero,mem[6],zero,zero,zero,mem[7],zero,zero,zero,mem[8],zero,zero,zero,mem[9],zero,zero,zero,mem[10],zero,zero,zero,mem[11],zero,zero,zero,mem[12],zero,zero,zero,mem[13],zero,zero,zero,mem[14],zero,zero,zero,mem[15],zero,zero,zero
|
||||
; AVX512F-NEXT: vmovdqa32 {{.*#+}} zmm4 = [1,2,3,4,5,6,7,8,1,2,3,4,5,6,7,8]
|
||||
; AVX512F-NEXT: vbroadcasti64x4 {{.*#+}} zmm4 = mem[0,1,2,3,0,1,2,3]
|
||||
; AVX512F-NEXT: vpaddd %zmm4, %zmm3, %zmm3
|
||||
; AVX512F-NEXT: vpaddd %zmm4, %zmm2, %zmm2
|
||||
; AVX512F-NEXT: vpaddd %zmm4, %zmm1, %zmm1
|
||||
|
@ -2405,7 +2405,7 @@ define void @avg_v32i16_const(<32 x i16>* %a) {
|
|||
; AVX512F: # BB#0:
|
||||
; AVX512F-NEXT: vpmovzxwd {{.*#+}} zmm0 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero,mem[8],zero,mem[9],zero,mem[10],zero,mem[11],zero,mem[12],zero,mem[13],zero,mem[14],zero,mem[15],zero
|
||||
; AVX512F-NEXT: vpmovzxwd {{.*#+}} zmm1 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero,mem[8],zero,mem[9],zero,mem[10],zero,mem[11],zero,mem[12],zero,mem[13],zero,mem[14],zero,mem[15],zero
|
||||
; AVX512F-NEXT: vmovdqa32 {{.*#+}} zmm2 = [1,2,3,4,5,6,7,8,1,2,3,4,5,6,7,8]
|
||||
; AVX512F-NEXT: vbroadcasti64x4 {{.*#+}} zmm2 = mem[0,1,2,3,0,1,2,3]
|
||||
; AVX512F-NEXT: vpaddd %zmm2, %zmm1, %zmm1
|
||||
; AVX512F-NEXT: vpaddd %zmm2, %zmm0, %zmm0
|
||||
; AVX512F-NEXT: vpsrld $1, %zmm0, %zmm0
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -153,14 +153,14 @@ define <32 x i16> @test7(<32 x i16> %a) {
|
|||
;
|
||||
; AVX2-LABEL: test7:
|
||||
; AVX2: # BB#0:
|
||||
; AVX2-NEXT: vmovdqa {{.*#+}} ymm2 = [2,2,4,8,128,1,512,2048,2,2,4,8,128,1,512,2048]
|
||||
; AVX2-NEXT: vbroadcasti128 {{.*#+}} ymm2 = mem[0,1,0,1]
|
||||
; AVX2-NEXT: vpmullw %ymm2, %ymm0, %ymm0
|
||||
; AVX2-NEXT: vpmullw %ymm2, %ymm1, %ymm1
|
||||
; AVX2-NEXT: retq
|
||||
;
|
||||
; AVX512-LABEL: test7:
|
||||
; AVX512: # BB#0:
|
||||
; AVX512-NEXT: vmovdqa {{.*#+}} ymm2 = [2,2,4,8,128,1,512,2048,2,2,4,8,128,1,512,2048]
|
||||
; AVX512-NEXT: vbroadcasti128 {{.*#+}} ymm2 = mem[0,1,0,1]
|
||||
; AVX512-NEXT: vpmullw %ymm2, %ymm0, %ymm0
|
||||
; AVX512-NEXT: vpmullw %ymm2, %ymm1, %ymm1
|
||||
; AVX512-NEXT: retq
|
||||
|
@ -183,7 +183,7 @@ define <16 x i32> @test8(<16 x i32> %a) {
|
|||
;
|
||||
; AVX2-LABEL: test8:
|
||||
; AVX2: # BB#0:
|
||||
; AVX2-NEXT: vmovdqa {{.*#+}} ymm2 = [1,1,2,3,1,1,2,3]
|
||||
; AVX2-NEXT: vbroadcasti128 {{.*#+}} ymm2 = mem[0,1,0,1]
|
||||
; AVX2-NEXT: vpsllvd %ymm2, %ymm0, %ymm0
|
||||
; AVX2-NEXT: vpsllvd %ymm2, %ymm1, %ymm1
|
||||
; AVX2-NEXT: retq
|
||||
|
|
|
@ -466,7 +466,7 @@ define <64 x i8> @shuffle_v64i8_63_zz_61_zz_59_zz_57_zz_55_zz_53_zz_51_zz_49_zz_
|
|||
define <64 x i8> @shuffle_v64i8_63_64_61_66_59_68_57_70_55_72_53_74_51_76_49_78_47_80_45_82_43_84_41_86_39_88_37_90_35_92_33_94_31_96_29_98_27_100_25_102_23_104_21_106_19_108_17_110_15_112_13_114_11_116_9_118_7_120_5_122_3_124_1_126(<64 x i8> %a, <64 x i8> %b) {
|
||||
; AVX512F-LABEL: shuffle_v64i8_63_64_61_66_59_68_57_70_55_72_53_74_51_76_49_78_47_80_45_82_43_84_41_86_39_88_37_90_35_92_33_94_31_96_29_98_27_100_25_102_23_104_21_106_19_108_17_110_15_112_13_114_11_116_9_118_7_120_5_122_3_124_1_126:
|
||||
; AVX512F: # BB#0:
|
||||
; AVX512F-NEXT: vmovdqa {{.*#+}} ymm4 = [255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0]
|
||||
; AVX512F-NEXT: vpbroadcastw {{.*}}(%rip), %ymm4
|
||||
; AVX512F-NEXT: vpblendvb %ymm4, %ymm2, %ymm1, %ymm1
|
||||
; AVX512F-NEXT: vperm2i128 {{.*#+}} ymm2 = ymm1[2,3,0,1]
|
||||
; AVX512F-NEXT: vpblendvb %ymm4, %ymm1, %ymm2, %ymm1
|
||||
|
@ -482,7 +482,7 @@ define <64 x i8> @shuffle_v64i8_63_64_61_66_59_68_57_70_55_72_53_74_51_76_49_78_
|
|||
; AVX512BW-LABEL: shuffle_v64i8_63_64_61_66_59_68_57_70_55_72_53_74_51_76_49_78_47_80_45_82_43_84_41_86_39_88_37_90_35_92_33_94_31_96_29_98_27_100_25_102_23_104_21_106_19_108_17_110_15_112_13_114_11_116_9_118_7_120_5_122_3_124_1_126:
|
||||
; AVX512BW: # BB#0:
|
||||
; AVX512BW-NEXT: vextracti64x4 $1, %zmm1, %ymm2
|
||||
; AVX512BW-NEXT: vmovdqa {{.*#+}} ymm3 = [255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0]
|
||||
; AVX512BW-NEXT: vpbroadcastw {{.*}}(%rip), %ymm3
|
||||
; AVX512BW-NEXT: vpblendvb %ymm3, %ymm2, %ymm0, %ymm2
|
||||
; AVX512BW-NEXT: vperm2i128 {{.*#+}} ymm4 = ymm2[2,3,0,1]
|
||||
; AVX512BW-NEXT: vpblendvb %ymm3, %ymm2, %ymm4, %ymm2
|
||||
|
@ -498,7 +498,7 @@ define <64 x i8> @shuffle_v64i8_63_64_61_66_59_68_57_70_55_72_53_74_51_76_49_78_
|
|||
;
|
||||
; AVX512DQ-LABEL: shuffle_v64i8_63_64_61_66_59_68_57_70_55_72_53_74_51_76_49_78_47_80_45_82_43_84_41_86_39_88_37_90_35_92_33_94_31_96_29_98_27_100_25_102_23_104_21_106_19_108_17_110_15_112_13_114_11_116_9_118_7_120_5_122_3_124_1_126:
|
||||
; AVX512DQ: # BB#0:
|
||||
; AVX512DQ-NEXT: vmovdqa {{.*#+}} ymm4 = [255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0,255,0]
|
||||
; AVX512DQ-NEXT: vpbroadcastw {{.*}}(%rip), %ymm4
|
||||
; AVX512DQ-NEXT: vpblendvb %ymm4, %ymm2, %ymm1, %ymm1
|
||||
; AVX512DQ-NEXT: vperm2i128 {{.*#+}} ymm2 = ymm1[2,3,0,1]
|
||||
; AVX512DQ-NEXT: vpblendvb %ymm4, %ymm1, %ymm2, %ymm1
|
||||
|
|
Loading…
Reference in New Issue