diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index e8d47c62b317..fd7bfcbea33d 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -11930,6 +11930,7 @@ static SDValue lowerShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue V1, // With MOVDDUP (v2f64) we can broadcast from a register or a load, otherwise // we can only broadcast from a register with AVX2. unsigned NumElts = Mask.size(); + unsigned NumEltBits = VT.getScalarSizeInBits(); unsigned Opcode = (VT == MVT::v2f64 && !Subtarget.hasAVX2()) ? X86ISD::MOVDDUP : X86ISD::VBROADCAST; @@ -11953,29 +11954,19 @@ static SDValue lowerShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue V1, // Go up the chain of (vector) values to find a scalar load that we can // combine with the broadcast. + int BitOffset = BroadcastIdx * NumEltBits; SDValue V = V1; for (;;) { switch (V.getOpcode()) { case ISD::BITCAST: { - // Peek through bitcasts as long as BroadcastIdx can be adjusted. - SDValue VSrc = V.getOperand(0); - unsigned NumEltBits = V.getScalarValueSizeInBits(); - unsigned NumSrcBits = VSrc.getScalarValueSizeInBits(); - if ((NumEltBits % NumSrcBits) == 0) - BroadcastIdx *= (NumEltBits / NumSrcBits); - else if ((NumSrcBits % NumEltBits) == 0 && - (BroadcastIdx % (NumSrcBits / NumEltBits)) == 0) - BroadcastIdx /= (NumSrcBits / NumEltBits); - else - break; - V = VSrc; + V = V.getOperand(0); continue; } case ISD::CONCAT_VECTORS: { - int OperandSize = - V.getOperand(0).getSimpleValueType().getVectorNumElements(); - V = V.getOperand(BroadcastIdx / OperandSize); - BroadcastIdx %= OperandSize; + int OpBitWidth = V.getOperand(0).getValueSizeInBits(); + int OpIdx = BitOffset / OpBitWidth; + V = V.getOperand(OpIdx); + BitOffset %= OpBitWidth; continue; } case ISD::INSERT_SUBVECTOR: { @@ -11984,11 +11975,13 @@ static SDValue lowerShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue V1, if (!ConstantIdx) break; - int BeginIdx = (int)ConstantIdx->getZExtValue(); - int EndIdx = - BeginIdx + (int)VInner.getSimpleValueType().getVectorNumElements(); - if (BroadcastIdx >= BeginIdx && BroadcastIdx < EndIdx) { - BroadcastIdx -= BeginIdx; + int EltBitWidth = VOuter.getScalarValueSizeInBits(); + int Idx = (int)ConstantIdx->getZExtValue(); + int NumSubElts = (int)VInner.getSimpleValueType().getVectorNumElements(); + int BeginOffset = Idx * EltBitWidth; + int EndOffset = BeginOffset + NumSubElts * EltBitWidth; + if (BeginOffset <= BitOffset && BitOffset < EndOffset) { + BitOffset -= BeginOffset; V = VInner; } else { V = VOuter; @@ -11998,48 +11991,34 @@ static SDValue lowerShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue V1, } break; } + assert((BitOffset % NumEltBits) == 0 && "Illegal bit-offset"); + BroadcastIdx = BitOffset / NumEltBits; - // Ensure the source vector and BroadcastIdx are for a suitable type. - if (VT.getScalarSizeInBits() != V.getScalarValueSizeInBits()) { - unsigned NumEltBits = VT.getScalarSizeInBits(); - unsigned NumSrcBits = V.getScalarValueSizeInBits(); - if ((NumSrcBits % NumEltBits) == 0) - BroadcastIdx *= (NumSrcBits / NumEltBits); - else if ((NumEltBits % NumSrcBits) == 0 && - (BroadcastIdx % (NumEltBits / NumSrcBits)) == 0) - BroadcastIdx /= (NumEltBits / NumSrcBits); - else - return SDValue(); - - unsigned NumSrcElts = V.getValueSizeInBits() / NumEltBits; - MVT SrcVT = MVT::getVectorVT(VT.getScalarType(), NumSrcElts); - V = DAG.getBitcast(SrcVT, V); - } + // Do we need to bitcast the source to retrieve the original broadcast index? + bool BitCastSrc = V.getScalarValueSizeInBits() != NumEltBits; // Check if this is a broadcast of a scalar. We special case lowering // for scalars so that we can more effectively fold with loads. - // First, look through bitcast: if the original value has a larger element - // type than the shuffle, the broadcast element is in essence truncated. - // Make that explicit to ease folding. - if (V.getOpcode() == ISD::BITCAST && VT.isInteger()) + // If the original value has a larger element type than the shuffle, the + // broadcast element is in essence truncated. Make that explicit to ease + // folding. + if (BitCastSrc && VT.isInteger()) if (SDValue TruncBroadcast = lowerShuffleAsTruncBroadcast( - DL, VT, V.getOperand(0), BroadcastIdx, Subtarget, DAG)) + DL, VT, V, BroadcastIdx, Subtarget, DAG)) return TruncBroadcast; MVT BroadcastVT = VT; - // Peek through any bitcast (only useful for loads). - SDValue BC = peekThroughBitcasts(V); - // Also check the simpler case, where we can directly reuse the scalar. - if ((V.getOpcode() == ISD::BUILD_VECTOR && V.hasOneUse()) || - (V.getOpcode() == ISD::SCALAR_TO_VECTOR && BroadcastIdx == 0)) { + if (!BitCastSrc && + ((V.getOpcode() == ISD::BUILD_VECTOR && V.hasOneUse()) || + (V.getOpcode() == ISD::SCALAR_TO_VECTOR && BroadcastIdx == 0))) { V = V.getOperand(BroadcastIdx); // If we can't broadcast from a register, check that the input is a load. if (!BroadcastFromReg && !isShuffleFoldableLoad(V)) return SDValue(); - } else if (MayFoldLoad(BC) && !cast(BC)->isVolatile()) { + } else if (MayFoldLoad(V) && !cast(V)->isVolatile()) { // 32-bit targets need to load i64 as a f64 and then bitcast the result. if (!Subtarget.is64Bit() && VT.getScalarType() == MVT::i64) { BroadcastVT = MVT::getVectorVT(MVT::f64, VT.getVectorNumElements()); @@ -12050,10 +12029,11 @@ static SDValue lowerShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue V1, // If we are broadcasting a load that is only used by the shuffle // then we can reduce the vector load to the broadcasted scalar load. - LoadSDNode *Ld = cast(BC); + LoadSDNode *Ld = cast(V); SDValue BaseAddr = Ld->getOperand(1); EVT SVT = BroadcastVT.getScalarType(); unsigned Offset = BroadcastIdx * SVT.getStoreSize(); + assert((Offset * 8) == BitOffset && "Unexpected bit-offset"); SDValue NewAddr = DAG.getMemBasePlusOffset(BaseAddr, Offset, DL); V = DAG.getLoad(SVT, DL, Ld->getChain(), NewAddr, DAG.getMachineFunction().getMachineMemOperand( @@ -12062,7 +12042,7 @@ static SDValue lowerShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue V1, } else if (!BroadcastFromReg) { // We can't broadcast from a vector register. return SDValue(); - } else if (BroadcastIdx != 0) { + } else if (BitOffset != 0) { // We can only broadcast from the zero-element of a vector register, // but it can be advantageous to broadcast from the zero-element of a // subvector. @@ -12074,18 +12054,15 @@ static SDValue lowerShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue V1, return SDValue(); // Only broadcast the zero-element of a 128-bit subvector. - unsigned EltSize = VT.getScalarSizeInBits(); - if (((BroadcastIdx * EltSize) % 128) != 0) + if ((BitOffset % 128) != 0) return SDValue(); - // The shuffle input might have been a bitcast we looked through; look at - // the original input vector. Emit an EXTRACT_SUBVECTOR of that type; we'll - // later bitcast it to BroadcastVT. - assert(V.getScalarValueSizeInBits() == BroadcastVT.getScalarSizeInBits() && - "Unexpected vector element size"); + assert((BitOffset % V.getScalarValueSizeInBits()) == 0 && + "Unexpected bit-offset"); assert((V.getValueSizeInBits() == 256 || V.getValueSizeInBits() == 512) && "Unexpected vector size"); - V = extract128BitVector(V, BroadcastIdx, DAG, DL); + unsigned ExtractIdx = BitOffset / V.getScalarValueSizeInBits(); + V = extract128BitVector(V, ExtractIdx, DAG, DL); } if (Opcode == X86ISD::MOVDDUP && !V.getValueType().isVector()) @@ -12093,21 +12070,21 @@ static SDValue lowerShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue V1, DAG.getBitcast(MVT::f64, V)); // Bitcast back to the same scalar type as BroadcastVT. - MVT SrcVT = V.getSimpleValueType(); - if (SrcVT.getScalarType() != BroadcastVT.getScalarType()) { - assert(SrcVT.getScalarSizeInBits() == BroadcastVT.getScalarSizeInBits() && + if (V.getValueType().getScalarType() != BroadcastVT.getScalarType()) { + assert(NumEltBits == BroadcastVT.getScalarSizeInBits() && "Unexpected vector element size"); - if (SrcVT.isVector()) { - unsigned NumSrcElts = SrcVT.getVectorNumElements(); - SrcVT = MVT::getVectorVT(BroadcastVT.getScalarType(), NumSrcElts); + MVT ExtVT; + if (V.getValueType().isVector()) { + unsigned NumSrcElts = V.getValueSizeInBits() / NumEltBits; + ExtVT = MVT::getVectorVT(BroadcastVT.getScalarType(), NumSrcElts); } else { - SrcVT = BroadcastVT.getScalarType(); + ExtVT = BroadcastVT.getScalarType(); } - V = DAG.getBitcast(SrcVT, V); + V = DAG.getBitcast(ExtVT, V); } // 32-bit targets need to load i64 as a f64 and then bitcast the result. - if (!Subtarget.is64Bit() && SrcVT == MVT::i64) { + if (!Subtarget.is64Bit() && V.getValueType() == MVT::i64) { V = DAG.getBitcast(MVT::f64, V); unsigned NumBroadcastElts = BroadcastVT.getVectorNumElements(); BroadcastVT = MVT::getVectorVT(MVT::f64, NumBroadcastElts); @@ -12116,9 +12093,9 @@ static SDValue lowerShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue V1, // We only support broadcasting from 128-bit vectors to minimize the // number of patterns we need to deal with in isel. So extract down to // 128-bits, removing as many bitcasts as possible. - if (SrcVT.getSizeInBits() > 128) { - MVT ExtVT = MVT::getVectorVT(SrcVT.getScalarType(), - 128 / SrcVT.getScalarSizeInBits()); + if (V.getValueSizeInBits() > 128) { + MVT ExtVT = V.getSimpleValueType().getScalarType(); + ExtVT = MVT::getVectorVT(ExtVT, 128 / ExtVT.getScalarSizeInBits()); V = extract128BitVector(peekThroughBitcasts(V), 0, DAG, DL); V = DAG.getBitcast(ExtVT, V); } diff --git a/llvm/test/CodeGen/X86/widened-broadcast.ll b/llvm/test/CodeGen/X86/widened-broadcast.ll index b43c8a4649fd..2ffc413420f4 100644 --- a/llvm/test/CodeGen/X86/widened-broadcast.ll +++ b/llvm/test/CodeGen/X86/widened-broadcast.ll @@ -110,21 +110,10 @@ define <8 x i32> @load_splat_8i32_4i32_01010101(<4 x i32>* %ptr) nounwind uwtabl ; SSE-NEXT: movdqa %xmm0, %xmm1 ; SSE-NEXT: retq ; -; AVX1-LABEL: load_splat_8i32_4i32_01010101: -; AVX1: # %bb.0: # %entry -; AVX1-NEXT: vpermilps {{.*#+}} xmm0 = mem[0,1,0,1] -; AVX1-NEXT: vinsertf128 $1, %xmm0, %ymm0, %ymm0 -; AVX1-NEXT: retq -; -; AVX2-LABEL: load_splat_8i32_4i32_01010101: -; AVX2: # %bb.0: # %entry -; AVX2-NEXT: vbroadcastsd (%rdi), %ymm0 -; AVX2-NEXT: retq -; -; AVX512-LABEL: load_splat_8i32_4i32_01010101: -; AVX512: # %bb.0: # %entry -; AVX512-NEXT: vbroadcastsd (%rdi), %ymm0 -; AVX512-NEXT: retq +; AVX-LABEL: load_splat_8i32_4i32_01010101: +; AVX: # %bb.0: # %entry +; AVX-NEXT: vbroadcastsd (%rdi), %ymm0 +; AVX-NEXT: retq entry: %ld = load <4 x i32>, <4 x i32>* %ptr %ret = shufflevector <4 x i32> %ld, <4 x i32> undef, <8 x i32> @@ -207,21 +196,10 @@ define <16 x i16> @load_splat_16i16_8i16_0101010101010101(<8 x i16>* %ptr) nounw ; SSE-NEXT: movdqa %xmm0, %xmm1 ; SSE-NEXT: retq ; -; AVX1-LABEL: load_splat_16i16_8i16_0101010101010101: -; AVX1: # %bb.0: # %entry -; AVX1-NEXT: vpermilps {{.*#+}} xmm0 = mem[0,0,0,0] -; AVX1-NEXT: vinsertf128 $1, %xmm0, %ymm0, %ymm0 -; AVX1-NEXT: retq -; -; AVX2-LABEL: load_splat_16i16_8i16_0101010101010101: -; AVX2: # %bb.0: # %entry -; AVX2-NEXT: vbroadcastss (%rdi), %ymm0 -; AVX2-NEXT: retq -; -; AVX512-LABEL: load_splat_16i16_8i16_0101010101010101: -; AVX512: # %bb.0: # %entry -; AVX512-NEXT: vbroadcastss (%rdi), %ymm0 -; AVX512-NEXT: retq +; AVX-LABEL: load_splat_16i16_8i16_0101010101010101: +; AVX: # %bb.0: # %entry +; AVX-NEXT: vbroadcastss (%rdi), %ymm0 +; AVX-NEXT: retq entry: %ld = load <8 x i16>, <8 x i16>* %ptr %ret = shufflevector <8 x i16> %ld, <8 x i16> undef, <16 x i32> @@ -235,21 +213,10 @@ define <16 x i16> @load_splat_16i16_8i16_0123012301230123(<8 x i16>* %ptr) nounw ; SSE-NEXT: movdqa %xmm0, %xmm1 ; SSE-NEXT: retq ; -; AVX1-LABEL: load_splat_16i16_8i16_0123012301230123: -; AVX1: # %bb.0: # %entry -; AVX1-NEXT: vpermilps {{.*#+}} xmm0 = mem[0,1,0,1] -; AVX1-NEXT: vinsertf128 $1, %xmm0, %ymm0, %ymm0 -; AVX1-NEXT: retq -; -; AVX2-LABEL: load_splat_16i16_8i16_0123012301230123: -; AVX2: # %bb.0: # %entry -; AVX2-NEXT: vbroadcastsd (%rdi), %ymm0 -; AVX2-NEXT: retq -; -; AVX512-LABEL: load_splat_16i16_8i16_0123012301230123: -; AVX512: # %bb.0: # %entry -; AVX512-NEXT: vbroadcastsd (%rdi), %ymm0 -; AVX512-NEXT: retq +; AVX-LABEL: load_splat_16i16_8i16_0123012301230123: +; AVX: # %bb.0: # %entry +; AVX-NEXT: vbroadcastsd (%rdi), %ymm0 +; AVX-NEXT: retq entry: %ld = load <8 x i16>, <8 x i16>* %ptr %ret = shufflevector <8 x i16> %ld, <8 x i16> undef, <16 x i32> @@ -407,21 +374,10 @@ define <32 x i8> @load_splat_32i8_16i8_01230123012301230123012301230123(<16 x i8 ; SSE-NEXT: movdqa %xmm0, %xmm1 ; SSE-NEXT: retq ; -; AVX1-LABEL: load_splat_32i8_16i8_01230123012301230123012301230123: -; AVX1: # %bb.0: # %entry -; AVX1-NEXT: vpermilps {{.*#+}} xmm0 = mem[0,0,0,0] -; AVX1-NEXT: vinsertf128 $1, %xmm0, %ymm0, %ymm0 -; AVX1-NEXT: retq -; -; AVX2-LABEL: load_splat_32i8_16i8_01230123012301230123012301230123: -; AVX2: # %bb.0: # %entry -; AVX2-NEXT: vbroadcastss (%rdi), %ymm0 -; AVX2-NEXT: retq -; -; AVX512-LABEL: load_splat_32i8_16i8_01230123012301230123012301230123: -; AVX512: # %bb.0: # %entry -; AVX512-NEXT: vbroadcastss (%rdi), %ymm0 -; AVX512-NEXT: retq +; AVX-LABEL: load_splat_32i8_16i8_01230123012301230123012301230123: +; AVX: # %bb.0: # %entry +; AVX-NEXT: vbroadcastss (%rdi), %ymm0 +; AVX-NEXT: retq entry: %ld = load <16 x i8>, <16 x i8>* %ptr %ret = shufflevector <16 x i8> %ld, <16 x i8> undef, <32 x i32> @@ -435,21 +391,10 @@ define <32 x i8> @load_splat_32i8_16i8_01234567012345670123456701234567(<16 x i8 ; SSE-NEXT: movdqa %xmm0, %xmm1 ; SSE-NEXT: retq ; -; AVX1-LABEL: load_splat_32i8_16i8_01234567012345670123456701234567: -; AVX1: # %bb.0: # %entry -; AVX1-NEXT: vpermilps {{.*#+}} xmm0 = mem[0,1,0,1] -; AVX1-NEXT: vinsertf128 $1, %xmm0, %ymm0, %ymm0 -; AVX1-NEXT: retq -; -; AVX2-LABEL: load_splat_32i8_16i8_01234567012345670123456701234567: -; AVX2: # %bb.0: # %entry -; AVX2-NEXT: vbroadcastsd (%rdi), %ymm0 -; AVX2-NEXT: retq -; -; AVX512-LABEL: load_splat_32i8_16i8_01234567012345670123456701234567: -; AVX512: # %bb.0: # %entry -; AVX512-NEXT: vbroadcastsd (%rdi), %ymm0 -; AVX512-NEXT: retq +; AVX-LABEL: load_splat_32i8_16i8_01234567012345670123456701234567: +; AVX: # %bb.0: # %entry +; AVX-NEXT: vbroadcastsd (%rdi), %ymm0 +; AVX-NEXT: retq entry: %ld = load <16 x i8>, <16 x i8>* %ptr %ret = shufflevector <16 x i8> %ld, <16 x i8> undef, <32 x i32>