[X86][AVX] lowerShuffleAsBroadcast - improve load folding by avoiding bitcasts

AVX1 broadcast load-folding was failing because we were adding bitcasts that caused MayFoldLoad's hasOneUse check to return false.
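
A minimal sketch of the failure mode, using a hypothetical Node type rather than LLVM's actual SelectionDAG classes:

#include <vector>

struct Node {
  std::vector<Node *> Uses;
  bool hasOneUse() const { return Uses.size() == 1; }
};

int main() {
  Node Load;    // the vector load we hope to fold
  Node Shuffle; // the broadcast shuffle consuming it
  Load.Uses.push_back(&Shuffle);
  // Load.hasOneUse() == true: a MayFoldLoad-style one-use check succeeds.

  Node Bitcast; // a bitcast of the load, created early during lowering
  Load.Uses.push_back(&Bitcast);
  // Load.hasOneUse() == false: the check now fails, even though the only
  // extra use is a bitcast that the lowering itself introduced.
  return 0;
}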

This patch stops introducing bitcasts so early, and replaces the scaling of the broadcast index through bitcasts (which cannot succeed in some cases) with tracking of the bit offset, which can be converted back to a broadcast index later on.
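
As a standalone illustration of the bit-offset idea (a sketch, not the patch itself): an element index only survives a change of element type when the scaling divides evenly, whereas a bit offset is independent of the element type and can be converted back to an index once the final type is known:

#include <cassert>

int main() {
  const int SrcEltBits = 32; // e.g. an <8 x i32> view of the value
  const int DstEltBits = 64; // e.g. the same bits viewed as <4 x i64>

  // Broadcasting element 1 of the 32-bit view.
  int BroadcastIdx = 1;
  int BitOffset = BroadcastIdx * SrcEltBits; // 32

  // Index scaling across the bitcast fails: bit 32 lands mid-way through a
  // 64-bit element (32 % 64 != 0), which is where the old code gave up.
  assert((BitOffset % DstEltBits) != 0);

  // The bit offset itself stays valid through any bitcast; convert it back
  // to a broadcast index once the walk settles on a final element type.
  assert((BitOffset % SrcEltBits) == 0 && "Illegal bit-offset");
  int FinalIdx = BitOffset / SrcEltBits; // element 1 again
  (void)FinalIdx;
  return 0;
}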

Differential Revision: https://reviews.llvm.org/D58888

llvm-svn: 356043
Simon Pilgrim, 2019-03-13 12:20:39 +00:00
commit 7abbd70300 (parent 8eacea80ad)
2 changed files (llvm/lib/Target/X86, llvm/test/CodeGen/X86) with 67 additions and 145 deletions

@@ -11930,6 +11930,7 @@ static SDValue lowerShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue V1,
   // With MOVDDUP (v2f64) we can broadcast from a register or a load, otherwise
   // we can only broadcast from a register with AVX2.
   unsigned NumElts = Mask.size();
+  unsigned NumEltBits = VT.getScalarSizeInBits();
   unsigned Opcode = (VT == MVT::v2f64 && !Subtarget.hasAVX2())
                         ? X86ISD::MOVDDUP
                         : X86ISD::VBROADCAST;
@@ -11953,29 +11954,19 @@ static SDValue lowerShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue V1,
   // Go up the chain of (vector) values to find a scalar load that we can
   // combine with the broadcast.
+  int BitOffset = BroadcastIdx * NumEltBits;
   SDValue V = V1;
   for (;;) {
     switch (V.getOpcode()) {
     case ISD::BITCAST: {
-      // Peek through bitcasts as long as BroadcastIdx can be adjusted.
-      SDValue VSrc = V.getOperand(0);
-      unsigned NumEltBits = V.getScalarValueSizeInBits();
-      unsigned NumSrcBits = VSrc.getScalarValueSizeInBits();
-      if ((NumEltBits % NumSrcBits) == 0)
-        BroadcastIdx *= (NumEltBits / NumSrcBits);
-      else if ((NumSrcBits % NumEltBits) == 0 &&
-               (BroadcastIdx % (NumSrcBits / NumEltBits)) == 0)
-        BroadcastIdx /= (NumSrcBits / NumEltBits);
-      else
-        break;
-      V = VSrc;
+      V = V.getOperand(0);
       continue;
     }
     case ISD::CONCAT_VECTORS: {
-      int OperandSize =
-          V.getOperand(0).getSimpleValueType().getVectorNumElements();
-      V = V.getOperand(BroadcastIdx / OperandSize);
-      BroadcastIdx %= OperandSize;
+      int OpBitWidth = V.getOperand(0).getValueSizeInBits();
+      int OpIdx = BitOffset / OpBitWidth;
+      V = V.getOperand(OpIdx);
+      BitOffset %= OpBitWidth;
       continue;
     }
     case ISD::INSERT_SUBVECTOR: {
@@ -11984,11 +11975,13 @@ static SDValue lowerShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue V1,
       if (!ConstantIdx)
         break;
-      int BeginIdx = (int)ConstantIdx->getZExtValue();
-      int EndIdx =
-          BeginIdx + (int)VInner.getSimpleValueType().getVectorNumElements();
-      if (BroadcastIdx >= BeginIdx && BroadcastIdx < EndIdx) {
-        BroadcastIdx -= BeginIdx;
+      int EltBitWidth = VOuter.getScalarValueSizeInBits();
+      int Idx = (int)ConstantIdx->getZExtValue();
+      int NumSubElts = (int)VInner.getSimpleValueType().getVectorNumElements();
+      int BeginOffset = Idx * EltBitWidth;
+      int EndOffset = BeginOffset + NumSubElts * EltBitWidth;
+      if (BeginOffset <= BitOffset && BitOffset < EndOffset) {
+        BitOffset -= BeginOffset;
         V = VInner;
       } else {
         V = VOuter;
@@ -11998,48 +11991,34 @@ static SDValue lowerShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue V1,
     }
     break;
   }
+  assert((BitOffset % NumEltBits) == 0 && "Illegal bit-offset");
+  BroadcastIdx = BitOffset / NumEltBits;
-  // Ensure the source vector and BroadcastIdx are for a suitable type.
-  if (VT.getScalarSizeInBits() != V.getScalarValueSizeInBits()) {
-    unsigned NumEltBits = VT.getScalarSizeInBits();
-    unsigned NumSrcBits = V.getScalarValueSizeInBits();
-    if ((NumSrcBits % NumEltBits) == 0)
-      BroadcastIdx *= (NumSrcBits / NumEltBits);
-    else if ((NumEltBits % NumSrcBits) == 0 &&
-             (BroadcastIdx % (NumEltBits / NumSrcBits)) == 0)
-      BroadcastIdx /= (NumEltBits / NumSrcBits);
-    else
-      return SDValue();
-    unsigned NumSrcElts = V.getValueSizeInBits() / NumEltBits;
-    MVT SrcVT = MVT::getVectorVT(VT.getScalarType(), NumSrcElts);
-    V = DAG.getBitcast(SrcVT, V);
-  }
+  // Do we need to bitcast the source to retrieve the original broadcast index?
+  bool BitCastSrc = V.getScalarValueSizeInBits() != NumEltBits;
   // Check if this is a broadcast of a scalar. We special case lowering
   // for scalars so that we can more effectively fold with loads.
-  // First, look through bitcast: if the original value has a larger element
-  // type than the shuffle, the broadcast element is in essence truncated.
-  // Make that explicit to ease folding.
-  if (V.getOpcode() == ISD::BITCAST && VT.isInteger())
+  // If the original value has a larger element type than the shuffle, the
+  // broadcast element is in essence truncated. Make that explicit to ease
+  // folding.
+  if (BitCastSrc && VT.isInteger())
     if (SDValue TruncBroadcast = lowerShuffleAsTruncBroadcast(
-            DL, VT, V.getOperand(0), BroadcastIdx, Subtarget, DAG))
+            DL, VT, V, BroadcastIdx, Subtarget, DAG))
       return TruncBroadcast;
   MVT BroadcastVT = VT;
-  // Peek through any bitcast (only useful for loads).
-  SDValue BC = peekThroughBitcasts(V);
   // Also check the simpler case, where we can directly reuse the scalar.
-  if ((V.getOpcode() == ISD::BUILD_VECTOR && V.hasOneUse()) ||
-      (V.getOpcode() == ISD::SCALAR_TO_VECTOR && BroadcastIdx == 0)) {
+  if (!BitCastSrc &&
+      ((V.getOpcode() == ISD::BUILD_VECTOR && V.hasOneUse()) ||
+       (V.getOpcode() == ISD::SCALAR_TO_VECTOR && BroadcastIdx == 0))) {
     V = V.getOperand(BroadcastIdx);
     // If we can't broadcast from a register, check that the input is a load.
     if (!BroadcastFromReg && !isShuffleFoldableLoad(V))
       return SDValue();
-  } else if (MayFoldLoad(BC) && !cast<LoadSDNode>(BC)->isVolatile()) {
+  } else if (MayFoldLoad(V) && !cast<LoadSDNode>(V)->isVolatile()) {
    // 32-bit targets need to load i64 as a f64 and then bitcast the result.
    if (!Subtarget.is64Bit() && VT.getScalarType() == MVT::i64) {
      BroadcastVT = MVT::getVectorVT(MVT::f64, VT.getVectorNumElements());
@@ -12050,10 +12029,11 @@ static SDValue lowerShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue V1,
     // If we are broadcasting a load that is only used by the shuffle
     // then we can reduce the vector load to the broadcasted scalar load.
-    LoadSDNode *Ld = cast<LoadSDNode>(BC);
+    LoadSDNode *Ld = cast<LoadSDNode>(V);
     SDValue BaseAddr = Ld->getOperand(1);
     EVT SVT = BroadcastVT.getScalarType();
     unsigned Offset = BroadcastIdx * SVT.getStoreSize();
+    assert((Offset * 8) == BitOffset && "Unexpected bit-offset");
     SDValue NewAddr = DAG.getMemBasePlusOffset(BaseAddr, Offset, DL);
     V = DAG.getLoad(SVT, DL, Ld->getChain(), NewAddr,
                     DAG.getMachineFunction().getMachineMemOperand(
@@ -12062,7 +12042,7 @@ static SDValue lowerShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue V1,
   } else if (!BroadcastFromReg) {
     // We can't broadcast from a vector register.
     return SDValue();
-  } else if (BroadcastIdx != 0) {
+  } else if (BitOffset != 0) {
     // We can only broadcast from the zero-element of a vector register,
     // but it can be advantageous to broadcast from the zero-element of a
     // subvector.
@@ -12074,18 +12054,15 @@ static SDValue lowerShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue V1,
       return SDValue();
     // Only broadcast the zero-element of a 128-bit subvector.
-    unsigned EltSize = VT.getScalarSizeInBits();
-    if (((BroadcastIdx * EltSize) % 128) != 0)
+    if ((BitOffset % 128) != 0)
       return SDValue();
-    // The shuffle input might have been a bitcast we looked through; look at
-    // the original input vector. Emit an EXTRACT_SUBVECTOR of that type; we'll
-    // later bitcast it to BroadcastVT.
-    assert(V.getScalarValueSizeInBits() == BroadcastVT.getScalarSizeInBits() &&
-           "Unexpected vector element size");
+    assert((BitOffset % V.getScalarValueSizeInBits()) == 0 &&
+           "Unexpected bit-offset");
     assert((V.getValueSizeInBits() == 256 || V.getValueSizeInBits() == 512) &&
            "Unexpected vector size");
-    V = extract128BitVector(V, BroadcastIdx, DAG, DL);
+    unsigned ExtractIdx = BitOffset / V.getScalarValueSizeInBits();
+    V = extract128BitVector(V, ExtractIdx, DAG, DL);
   }
   if (Opcode == X86ISD::MOVDDUP && !V.getValueType().isVector())
@@ -12093,21 +12070,21 @@ static SDValue lowerShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue V1,
                     DAG.getBitcast(MVT::f64, V));
   // Bitcast back to the same scalar type as BroadcastVT.
-  MVT SrcVT = V.getSimpleValueType();
-  if (SrcVT.getScalarType() != BroadcastVT.getScalarType()) {
-    assert(SrcVT.getScalarSizeInBits() == BroadcastVT.getScalarSizeInBits() &&
+  if (V.getValueType().getScalarType() != BroadcastVT.getScalarType()) {
+    assert(NumEltBits == BroadcastVT.getScalarSizeInBits() &&
           "Unexpected vector element size");
-    if (SrcVT.isVector()) {
-      unsigned NumSrcElts = SrcVT.getVectorNumElements();
-      SrcVT = MVT::getVectorVT(BroadcastVT.getScalarType(), NumSrcElts);
+    MVT ExtVT;
+    if (V.getValueType().isVector()) {
+      unsigned NumSrcElts = V.getValueSizeInBits() / NumEltBits;
+      ExtVT = MVT::getVectorVT(BroadcastVT.getScalarType(), NumSrcElts);
     } else {
-      SrcVT = BroadcastVT.getScalarType();
+      ExtVT = BroadcastVT.getScalarType();
     }
-    V = DAG.getBitcast(SrcVT, V);
+    V = DAG.getBitcast(ExtVT, V);
   }
   // 32-bit targets need to load i64 as a f64 and then bitcast the result.
-  if (!Subtarget.is64Bit() && SrcVT == MVT::i64) {
+  if (!Subtarget.is64Bit() && V.getValueType() == MVT::i64) {
     V = DAG.getBitcast(MVT::f64, V);
     unsigned NumBroadcastElts = BroadcastVT.getVectorNumElements();
     BroadcastVT = MVT::getVectorVT(MVT::f64, NumBroadcastElts);
@@ -12116,9 +12093,9 @@ static SDValue lowerShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue V1,
   // We only support broadcasting from 128-bit vectors to minimize the
   // number of patterns we need to deal with in isel. So extract down to
   // 128-bits, removing as many bitcasts as possible.
-  if (SrcVT.getSizeInBits() > 128) {
-    MVT ExtVT = MVT::getVectorVT(SrcVT.getScalarType(),
-                                 128 / SrcVT.getScalarSizeInBits());
+  if (V.getValueSizeInBits() > 128) {
+    MVT ExtVT = V.getSimpleValueType().getScalarType();
+    ExtVT = MVT::getVectorVT(ExtVT, 128 / ExtVT.getScalarSizeInBits());
     V = extract128BitVector(peekThroughBitcasts(V), 0, DAG, DL);
     V = DAG.getBitcast(ExtVT, V);
   }

@@ -110,21 +110,10 @@ define <8 x i32> @load_splat_8i32_4i32_01010101(<4 x i32>* %ptr) nounwind uwtabl
 ; SSE-NEXT:    movdqa %xmm0, %xmm1
 ; SSE-NEXT:    retq
 ;
-; AVX1-LABEL: load_splat_8i32_4i32_01010101:
-; AVX1:       # %bb.0: # %entry
-; AVX1-NEXT:    vpermilps {{.*#+}} xmm0 = mem[0,1,0,1]
-; AVX1-NEXT:    vinsertf128 $1, %xmm0, %ymm0, %ymm0
-; AVX1-NEXT:    retq
-;
-; AVX2-LABEL: load_splat_8i32_4i32_01010101:
-; AVX2:       # %bb.0: # %entry
-; AVX2-NEXT:    vbroadcastsd (%rdi), %ymm0
-; AVX2-NEXT:    retq
-;
-; AVX512-LABEL: load_splat_8i32_4i32_01010101:
-; AVX512:       # %bb.0: # %entry
-; AVX512-NEXT:    vbroadcastsd (%rdi), %ymm0
-; AVX512-NEXT:    retq
+; AVX-LABEL: load_splat_8i32_4i32_01010101:
+; AVX:       # %bb.0: # %entry
+; AVX-NEXT:    vbroadcastsd (%rdi), %ymm0
+; AVX-NEXT:    retq
 entry:
   %ld = load <4 x i32>, <4 x i32>* %ptr
   %ret = shufflevector <4 x i32> %ld, <4 x i32> undef, <8 x i32> <i32 0, i32 1, i32 0, i32 1, i32 0, i32 1, i32 0, i32 1>
@@ -207,21 +196,10 @@ define <16 x i16> @load_splat_16i16_8i16_0101010101010101(<8 x i16>* %ptr) nounw
 ; SSE-NEXT:    movdqa %xmm0, %xmm1
 ; SSE-NEXT:    retq
 ;
-; AVX1-LABEL: load_splat_16i16_8i16_0101010101010101:
-; AVX1:       # %bb.0: # %entry
-; AVX1-NEXT:    vpermilps {{.*#+}} xmm0 = mem[0,0,0,0]
-; AVX1-NEXT:    vinsertf128 $1, %xmm0, %ymm0, %ymm0
-; AVX1-NEXT:    retq
-;
-; AVX2-LABEL: load_splat_16i16_8i16_0101010101010101:
-; AVX2:       # %bb.0: # %entry
-; AVX2-NEXT:    vbroadcastss (%rdi), %ymm0
-; AVX2-NEXT:    retq
-;
-; AVX512-LABEL: load_splat_16i16_8i16_0101010101010101:
-; AVX512:       # %bb.0: # %entry
-; AVX512-NEXT:    vbroadcastss (%rdi), %ymm0
-; AVX512-NEXT:    retq
+; AVX-LABEL: load_splat_16i16_8i16_0101010101010101:
+; AVX:       # %bb.0: # %entry
+; AVX-NEXT:    vbroadcastss (%rdi), %ymm0
+; AVX-NEXT:    retq
 entry:
   %ld = load <8 x i16>, <8 x i16>* %ptr
   %ret = shufflevector <8 x i16> %ld, <8 x i16> undef, <16 x i32> <i32 0, i32 1, i32 0, i32 1, i32 0, i32 1, i32 0, i32 1, i32 0, i32 1, i32 0, i32 1, i32 0, i32 1, i32 0, i32 1>
@@ -235,21 +213,10 @@ define <16 x i16> @load_splat_16i16_8i16_0123012301230123(<8 x i16>* %ptr) nounw
 ; SSE-NEXT:    movdqa %xmm0, %xmm1
 ; SSE-NEXT:    retq
 ;
-; AVX1-LABEL: load_splat_16i16_8i16_0123012301230123:
-; AVX1:       # %bb.0: # %entry
-; AVX1-NEXT:    vpermilps {{.*#+}} xmm0 = mem[0,1,0,1]
-; AVX1-NEXT:    vinsertf128 $1, %xmm0, %ymm0, %ymm0
-; AVX1-NEXT:    retq
-;
-; AVX2-LABEL: load_splat_16i16_8i16_0123012301230123:
-; AVX2:       # %bb.0: # %entry
-; AVX2-NEXT:    vbroadcastsd (%rdi), %ymm0
-; AVX2-NEXT:    retq
-;
-; AVX512-LABEL: load_splat_16i16_8i16_0123012301230123:
-; AVX512:       # %bb.0: # %entry
-; AVX512-NEXT:    vbroadcastsd (%rdi), %ymm0
-; AVX512-NEXT:    retq
+; AVX-LABEL: load_splat_16i16_8i16_0123012301230123:
+; AVX:       # %bb.0: # %entry
+; AVX-NEXT:    vbroadcastsd (%rdi), %ymm0
+; AVX-NEXT:    retq
 entry:
   %ld = load <8 x i16>, <8 x i16>* %ptr
   %ret = shufflevector <8 x i16> %ld, <8 x i16> undef, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3>
@@ -407,21 +374,10 @@ define <32 x i8> @load_splat_32i8_16i8_01230123012301230123012301230123(<16 x i8
 ; SSE-NEXT:    movdqa %xmm0, %xmm1
 ; SSE-NEXT:    retq
 ;
-; AVX1-LABEL: load_splat_32i8_16i8_01230123012301230123012301230123:
-; AVX1:       # %bb.0: # %entry
-; AVX1-NEXT:    vpermilps {{.*#+}} xmm0 = mem[0,0,0,0]
-; AVX1-NEXT:    vinsertf128 $1, %xmm0, %ymm0, %ymm0
-; AVX1-NEXT:    retq
-;
-; AVX2-LABEL: load_splat_32i8_16i8_01230123012301230123012301230123:
-; AVX2:       # %bb.0: # %entry
-; AVX2-NEXT:    vbroadcastss (%rdi), %ymm0
-; AVX2-NEXT:    retq
-;
-; AVX512-LABEL: load_splat_32i8_16i8_01230123012301230123012301230123:
-; AVX512:       # %bb.0: # %entry
-; AVX512-NEXT:    vbroadcastss (%rdi), %ymm0
-; AVX512-NEXT:    retq
+; AVX-LABEL: load_splat_32i8_16i8_01230123012301230123012301230123:
+; AVX:       # %bb.0: # %entry
+; AVX-NEXT:    vbroadcastss (%rdi), %ymm0
+; AVX-NEXT:    retq
 entry:
   %ld = load <16 x i8>, <16 x i8>* %ptr
   %ret = shufflevector <16 x i8> %ld, <16 x i8> undef, <32 x i32> <i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3>
@@ -435,21 +391,10 @@ define <32 x i8> @load_splat_32i8_16i8_01234567012345670123456701234567(<16 x i8
 ; SSE-NEXT:    movdqa %xmm0, %xmm1
 ; SSE-NEXT:    retq
 ;
-; AVX1-LABEL: load_splat_32i8_16i8_01234567012345670123456701234567:
-; AVX1:       # %bb.0: # %entry
-; AVX1-NEXT:    vpermilps {{.*#+}} xmm0 = mem[0,1,0,1]
-; AVX1-NEXT:    vinsertf128 $1, %xmm0, %ymm0, %ymm0
-; AVX1-NEXT:    retq
-;
-; AVX2-LABEL: load_splat_32i8_16i8_01234567012345670123456701234567:
-; AVX2:       # %bb.0: # %entry
-; AVX2-NEXT:    vbroadcastsd (%rdi), %ymm0
-; AVX2-NEXT:    retq
-;
-; AVX512-LABEL: load_splat_32i8_16i8_01234567012345670123456701234567:
-; AVX512:       # %bb.0: # %entry
-; AVX512-NEXT:    vbroadcastsd (%rdi), %ymm0
-; AVX512-NEXT:    retq
+; AVX-LABEL: load_splat_32i8_16i8_01234567012345670123456701234567:
+; AVX:       # %bb.0: # %entry
+; AVX-NEXT:    vbroadcastsd (%rdi), %ymm0
+; AVX-NEXT:    retq
 entry:
   %ld = load <16 x i8>, <16 x i8>* %ptr
   %ret = shufflevector <16 x i8> %ld, <16 x i8> undef, <32 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>