[X86] combineX86ShuffleChain - add support for combining to X86ISD::VROTLI

Refactors matchShuffleAsBitRotate to allow use by both lowerShuffleAsBitRotate and matchUnaryPermuteShuffle.
Simon Pilgrim 2020-02-15 20:04:15 +00:00
parent aa5ebfdf20
commit 34a054ce71
4 changed files with 88 additions and 48 deletions
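
As a quick illustration of what the new combine catches (a sketch only, not part of the commit; the function name and llc invocation are purely illustrative): a single pshufb whose constant mask swaps the two bytes of every i16 element is really a bit-rotate by 8 within each 16-bit group. matchUnaryPermuteShuffle can now report that rotate to combineX86ShuffleChain as X86ISD::VROTLI, so on XOP targets the byte shuffle should become a single vprotw, much like the updated vpperm test below.

declare <16 x i8> @llvm.x86.ssse3.pshuf.b.128(<16 x i8>, <16 x i8>)

; Hypothetical test, modelled on the updated XOP test in this commit.
define <16 x i8> @combine_pshufb_as_rotl8_v8i16(<16 x i8> %a0) {
  ; pshufb mask 1,0,3,2,... byte-swaps each i16, i.e. rotates each i16 by 8 bits.
  %res = call <16 x i8> @llvm.x86.ssse3.pshuf.b.128(<16 x i8> %a0, <16 x i8> <i8 1, i8 0, i8 3, i8 2, i8 5, i8 4, i8 7, i8 6, i8 9, i8 8, i8 11, i8 10, i8 13, i8 12, i8 15, i8 14>)
  ret <16 x i8> %res
}

Run through something like llc -mtriple=x86_64-unknown-unknown -mattr=+xop, this should now produce vprotw $8, %xmm0, %xmm0. On SSSE3-only targets (no XOP or AVX512) the new match is guarded off and the pshufb is kept, per the guard added in matchUnaryPermuteShuffle.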


@@ -11704,24 +11704,12 @@ static int matchShuffleAsBitRotate(ArrayRef<int> Mask, int NumSubElts) {
   return RotateAmt;
 }
-/// Lower shuffle using X86ISD::VROTLI rotations.
-static SDValue lowerShuffleAsBitRotate(const SDLoc &DL, MVT VT, SDValue V1,
-                                       ArrayRef<int> Mask,
-                                       const X86Subtarget &Subtarget,
-                                       SelectionDAG &DAG) {
+static int matchShuffleAsBitRotate(MVT &RotateVT, int EltSizeInBits,
+                                   const X86Subtarget &Subtarget,
+                                   ArrayRef<int> Mask) {
   assert(!isNoopShuffleMask(Mask) && "We shouldn't lower no-op shuffles!");
-  MVT SVT = VT.getScalarType();
-  int EltSizeInBits = SVT.getScalarSizeInBits();
   assert(EltSizeInBits < 64 && "Can't rotate 64-bit integers");
-  // Only XOP + AVX512 targets have bit rotation instructions.
-  // If we at least have SSSE3 (PSHUFB) then we shouldn't attempt to use this.
-  bool IsLegal =
-      (VT.is128BitVector() && Subtarget.hasXOP()) || Subtarget.hasAVX512();
-  if (!IsLegal && Subtarget.hasSSE3())
-    return SDValue();
   // AVX512 only has vXi32/vXi64 rotates, so limit the rotation sub group size.
   int MinSubElts = Subtarget.hasAVX512() ? std::max(32 / EltSizeInBits, 2) : 2;
   int MaxSubElts = 64 / EltSizeInBits;
@@ -11730,36 +11718,55 @@ static SDValue lowerShuffleAsBitRotate(const SDLoc &DL, MVT VT, SDValue V1,
     if (RotateAmt < 0)
       continue;
-    int NumElts = VT.getVectorNumElements();
+    int NumElts = Mask.size();
     MVT RotateSVT = MVT::getIntegerVT(EltSizeInBits * NumSubElts);
-    MVT RotateVT = MVT::getVectorVT(RotateSVT, NumElts / NumSubElts);
+    RotateVT = MVT::getVectorVT(RotateSVT, NumElts / NumSubElts);
+    return RotateAmt * EltSizeInBits;
+  }
-    // For pre-SSSE3 targets, if we are shuffling vXi8 elts then ISD::ROTL,
-    // expanded to OR(SRL,SHL), will be more efficient, but if they can
-    // widen to vXi16 or more then existing lowering should will be better.
-    int RotateAmtInBits = RotateAmt * EltSizeInBits;
-    if (!IsLegal) {
-      if ((RotateAmtInBits % 16) == 0)
-        return SDValue();
-      // TODO: Use getTargetVShiftByConstNode.
-      unsigned ShlAmt = RotateAmtInBits;
-      unsigned SrlAmt = RotateSVT.getScalarSizeInBits() - RotateAmtInBits;
-      V1 = DAG.getBitcast(RotateVT, V1);
-      SDValue SHL = DAG.getNode(X86ISD::VSHLI, DL, RotateVT, V1,
-                                DAG.getTargetConstant(ShlAmt, DL, MVT::i8));
-      SDValue SRL = DAG.getNode(X86ISD::VSRLI, DL, RotateVT, V1,
-                                DAG.getTargetConstant(SrlAmt, DL, MVT::i8));
-      SDValue Rot = DAG.getNode(ISD::OR, DL, RotateVT, SHL, SRL);
-      return DAG.getBitcast(VT, Rot);
-    }
+  return -1;
+}
-    SDValue Rot =
-        DAG.getNode(X86ISD::VROTLI, DL, RotateVT, DAG.getBitcast(RotateVT, V1),
-                    DAG.getTargetConstant(RotateAmtInBits, DL, MVT::i8));
+/// Lower shuffle using X86ISD::VROTLI rotations.
+static SDValue lowerShuffleAsBitRotate(const SDLoc &DL, MVT VT, SDValue V1,
+                                       ArrayRef<int> Mask,
+                                       const X86Subtarget &Subtarget,
+                                       SelectionDAG &DAG) {
+  // Only XOP + AVX512 targets have bit rotation instructions.
+  // If we at least have SSSE3 (PSHUFB) then we shouldn't attempt to use this.
+  bool IsLegal =
+      (VT.is128BitVector() && Subtarget.hasXOP()) || Subtarget.hasAVX512();
+  if (!IsLegal && Subtarget.hasSSE3())
+    return SDValue();
+  MVT RotateVT;
+  int RotateAmt = matchShuffleAsBitRotate(RotateVT, VT.getScalarSizeInBits(),
+                                          Subtarget, Mask);
+  if (RotateAmt < 0)
+    return SDValue();
+  // For pre-SSSE3 targets, if we are shuffling vXi8 elts then ISD::ROTL,
+  // expanded to OR(SRL,SHL), will be more efficient, but if they can
+  // widen to vXi16 or more then existing lowering should will be better.
+  if (!IsLegal) {
+    if ((RotateAmt % 16) == 0)
+      return SDValue();
+    // TODO: Use getTargetVShiftByConstNode.
+    unsigned ShlAmt = RotateAmt;
+    unsigned SrlAmt = RotateVT.getScalarSizeInBits() - RotateAmt;
+    V1 = DAG.getBitcast(RotateVT, V1);
+    SDValue SHL = DAG.getNode(X86ISD::VSHLI, DL, RotateVT, V1,
+                              DAG.getTargetConstant(ShlAmt, DL, MVT::i8));
+    SDValue SRL = DAG.getNode(X86ISD::VSRLI, DL, RotateVT, V1,
+                              DAG.getTargetConstant(SrlAmt, DL, MVT::i8));
+    SDValue Rot = DAG.getNode(ISD::OR, DL, RotateVT, SHL, SRL);
     return DAG.getBitcast(VT, Rot);
   }
-  return SDValue();
+  SDValue Rot =
+      DAG.getNode(X86ISD::VROTLI, DL, RotateVT, DAG.getBitcast(RotateVT, V1),
+                  DAG.getTargetConstant(RotateAmt, DL, MVT::i8));
+  return DAG.getBitcast(VT, Rot);
 }
 /// Try to lower a vector shuffle as a byte rotation.
@@ -33538,6 +33545,19 @@ static bool matchUnaryPermuteShuffle(MVT MaskVT, ArrayRef<int> Mask,
     }
   }
+  // Attempt to match against bit rotates.
+  if (!ContainsZeros && AllowIntDomain && MaskScalarSizeInBits < 64 &&
+      ((MaskVT.is128BitVector() && Subtarget.hasXOP()) ||
+       Subtarget.hasAVX512())) {
+    int RotateAmt = matchShuffleAsBitRotate(ShuffleVT, MaskScalarSizeInBits,
+                                            Subtarget, Mask);
+    if (0 < RotateAmt) {
+      Shuffle = X86ISD::VROTLI;
+      PermuteImm = (unsigned)RotateAmt;
+      return true;
+    }
+  }
   return false;
 }


@@ -464,10 +464,17 @@ define <32 x i8> @combine_pshufb_as_pshufhw(<32 x i8> %a0) {
 }
 define <32 x i8> @combine_pshufb_not_as_pshufw(<32 x i8> %a0) {
-; CHECK-LABEL: combine_pshufb_not_as_pshufw:
-; CHECK: # %bb.0:
-; CHECK-NEXT: vpshufb {{.*#+}} ymm0 = ymm0[2,3,0,1,6,7,4,5,10,11,8,9,14,15,12,13,18,19,16,17,22,23,20,21,26,27,24,25,30,31,28,29]
-; CHECK-NEXT: ret{{[l|q]}}
+; AVX2-LABEL: combine_pshufb_not_as_pshufw:
+; AVX2: # %bb.0:
+; AVX2-NEXT: vpshufb {{.*#+}} ymm0 = ymm0[2,3,0,1,6,7,4,5,10,11,8,9,14,15,12,13,18,19,16,17,22,23,20,21,26,27,24,25,30,31,28,29]
+; AVX2-NEXT: ret{{[l|q]}}
+;
+; AVX512-LABEL: combine_pshufb_not_as_pshufw:
+; AVX512: # %bb.0:
+; AVX512-NEXT: # kill: def $ymm0 killed $ymm0 def $zmm0
+; AVX512-NEXT: vprold $16, %zmm0, %zmm0
+; AVX512-NEXT: # kill: def $ymm0 killed $ymm0 killed $zmm0
+; AVX512-NEXT: ret{{[l|q]}}
   %res0 = call <32 x i8> @llvm.x86.avx2.pshuf.b(<32 x i8> %a0, <32 x i8> <i8 2, i8 3, i8 0, i8 1, i8 6, i8 7, i8 4, i8 5, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 2, i8 3, i8 0, i8 1, i8 6, i8 7, i8 4, i8 5, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>)
   %res1 = call <32 x i8> @llvm.x86.avx2.pshuf.b(<32 x i8> %res0, <32 x i8> <i8 0, i8 1, i8 2, i8 3, i8 4, i8 5, i8 6, i8 7, i8 10, i8 11, i8 8, i8 9, i8 14, i8 15, i8 12, i8 13, i8 0, i8 1, i8 2, i8 3, i8 4, i8 5, i8 6, i8 7, i8 10, i8 11, i8 8, i8 9, i8 14, i8 15, i8 12, i8 13>)
   ret <32 x i8> %res1


@@ -403,10 +403,23 @@ define <16 x i8> @combine_pshufb_not_as_pshufw(<16 x i8> %a0) {
 ; SSE-NEXT: pshufb {{.*#+}} xmm0 = xmm0[2,3,0,1,6,7,4,5,10,11,8,9,14,15,12,13]
 ; SSE-NEXT: retq
 ;
-; AVX-LABEL: combine_pshufb_not_as_pshufw:
-; AVX: # %bb.0:
-; AVX-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[2,3,0,1,6,7,4,5,10,11,8,9,14,15,12,13]
-; AVX-NEXT: retq
+; AVX1-LABEL: combine_pshufb_not_as_pshufw:
+; AVX1: # %bb.0:
+; AVX1-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[2,3,0,1,6,7,4,5,10,11,8,9,14,15,12,13]
+; AVX1-NEXT: retq
+;
+; AVX2-LABEL: combine_pshufb_not_as_pshufw:
+; AVX2: # %bb.0:
+; AVX2-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[2,3,0,1,6,7,4,5,10,11,8,9,14,15,12,13]
+; AVX2-NEXT: retq
+;
+; AVX512F-LABEL: combine_pshufb_not_as_pshufw:
+; AVX512F: # %bb.0:
+; AVX512F-NEXT: # kill: def $xmm0 killed $xmm0 def $zmm0
+; AVX512F-NEXT: vprold $16, %zmm0, %zmm0
+; AVX512F-NEXT: # kill: def $xmm0 killed $xmm0 killed $zmm0
+; AVX512F-NEXT: vzeroupper
+; AVX512F-NEXT: retq
   %res0 = call <16 x i8> @llvm.x86.ssse3.pshuf.b.128(<16 x i8> %a0, <16 x i8> <i8 2, i8 3, i8 0, i8 1, i8 6, i8 7, i8 4, i8 5, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>)
   %res1 = call <16 x i8> @llvm.x86.ssse3.pshuf.b.128(<16 x i8> %res0, <16 x i8> <i8 0, i8 1, i8 2, i8 3, i8 4, i8 5, i8 6, i8 7, i8 10, i8 11, i8 8, i8 9, i8 14, i8 15, i8 12, i8 13>)
   ret <16 x i8> %res1


@@ -255,7 +255,7 @@ define <4 x i32> @combine_vpperm_10zz32BA(<4 x i32> %a0, <4 x i32> %a1) {
 define <16 x i8> @combine_vpperm_as_proti_v8i16(<16 x i8> %a0, <16 x i8> %a1) {
 ; CHECK-LABEL: combine_vpperm_as_proti_v8i16:
 ; CHECK: # %bb.0:
-; CHECK-NEXT: vpperm {{.*#+}} xmm0 = xmm0[1,0,3,2,5,4,7,6,9,8,11,10,13,12,15,14]
+; CHECK-NEXT: vprotw $8, %xmm0, %xmm0
 ; CHECK-NEXT: ret{{[l|q]}}
   %res0 = call <16 x i8> @llvm.x86.xop.vpperm(<16 x i8> %a0, <16 x i8> %a1, <16 x i8> <i8 1, i8 0, i8 3, i8 2, i8 5, i8 4, i8 7, i8 6, i8 9, i8 8, i8 11, i8 10, i8 13, i8 12, i8 15, i8 14>)
   ret <16 x i8> %res0