[X86][AVX] Attempt to fold a scaled index into a gather/scatter scale immediate (PR13310)

If the index operand for a gather/scatter intrinsic is being scaled (self-addition or a shl-by-immediate) then we may be able to fold that scaling into the intrinsic scale immediate value instead.

Fixes PR13310.

Differential Revision: https://reviews.llvm.org/D108539
This commit is contained in:
Simon Pilgrim 2021-10-28 14:07:17 +01:00
parent fbf1745722
commit d29ccbecd0
2 changed files with 60 additions and 36 deletions

View File

@ -50227,9 +50227,40 @@ static SDValue combineMOVMSK(SDNode *N, SelectionDAG &DAG,
}
static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI) {
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
auto *MemOp = cast<X86MaskedGatherScatterSDNode>(N);
SDValue Index = MemOp->getIndex();
SDValue Scale = MemOp->getScale();
SDValue Mask = MemOp->getMask();
// Attempt to fold an index scale into the scale value directly.
// TODO: Move this into X86DAGToDAGISel::matchVectorAddressRecursively?
if ((Index.getOpcode() == X86ISD::VSHLI ||
(Index.getOpcode() == ISD::ADD &&
Index.getOperand(0) == Index.getOperand(1))) &&
isa<ConstantSDNode>(Scale)) {
unsigned ShiftAmt =
Index.getOpcode() == ISD::ADD ? 1 : Index.getConstantOperandVal(1);
uint64_t ScaleAmt = cast<ConstantSDNode>(Scale)->getZExtValue();
uint64_t NewScaleAmt = ScaleAmt * (1ULL << ShiftAmt);
if (isPowerOf2_64(NewScaleAmt) && NewScaleAmt <= 8) {
SDValue NewIndex = Index.getOperand(0);
SDValue NewScale =
DAG.getTargetConstant(NewScaleAmt, SDLoc(N), Scale.getValueType());
if (N->getOpcode() == X86ISD::MGATHER)
return getAVX2GatherNode(N->getOpcode(), SDValue(N, 0), DAG,
MemOp->getOperand(1), Mask,
MemOp->getBasePtr(), NewIndex, NewScale,
MemOp->getChain(), Subtarget);
if (N->getOpcode() == X86ISD::MSCATTER)
return getScatterNode(N->getOpcode(), SDValue(N, 0), DAG,
MemOp->getOperand(1), Mask, MemOp->getBasePtr(),
NewIndex, NewScale, MemOp->getChain(), Subtarget);
}
}
// With vector masks we only demand the upper bit of the mask.
SDValue Mask = cast<X86MaskedGatherScatterSDNode>(N)->getMask();
if (Mask.getScalarValueSizeInBits() != 1) {
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));
@ -52886,7 +52917,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case X86ISD::FMSUBADD: return combineFMADDSUB(N, DAG, DCI);
case X86ISD::MOVMSK: return combineMOVMSK(N, DAG, DCI, Subtarget);
case X86ISD::MGATHER:
case X86ISD::MSCATTER: return combineX86GatherScatter(N, DAG, DCI);
case X86ISD::MSCATTER:
return combineX86GatherScatter(N, DAG, DCI, Subtarget);
case ISD::MGATHER:
case ISD::MSCATTER: return combineGatherScatter(N, DAG, DCI);
case X86ISD::PCMPEQ:

View File

@ -808,20 +808,19 @@ define <16 x float> @test14(float* %base, i32 %ind, <16 x float*> %vec) {
; KNL_64-NEXT: vmovd %esi, %xmm0
; KNL_64-NEXT: vpbroadcastd %xmm0, %ymm0
; KNL_64-NEXT: vpmovsxdq %ymm0, %zmm0
; KNL_64-NEXT: vpsllq $2, %zmm0, %zmm0
; KNL_64-NEXT: kxnorw %k0, %k0, %k1
; KNL_64-NEXT: vxorps %xmm1, %xmm1, %xmm1
; KNL_64-NEXT: vgatherqps (%rax,%zmm0), %ymm1 {%k1}
; KNL_64-NEXT: vgatherqps (%rax,%zmm0,4), %ymm1 {%k1}
; KNL_64-NEXT: vinsertf64x4 $1, %ymm1, %zmm1, %zmm0
; KNL_64-NEXT: retq
;
; KNL_32-LABEL: test14:
; KNL_32: # %bb.0:
; KNL_32-NEXT: vmovd %xmm0, %eax
; KNL_32-NEXT: vpslld $2, {{[0-9]+}}(%esp){1to16}, %zmm1
; KNL_32-NEXT: vbroadcastss {{[0-9]+}}(%esp), %zmm1
; KNL_32-NEXT: kxnorw %k0, %k0, %k1
; KNL_32-NEXT: vpxor %xmm0, %xmm0, %xmm0
; KNL_32-NEXT: vgatherdps (%eax,%zmm1), %zmm0 {%k1}
; KNL_32-NEXT: vgatherdps (%eax,%zmm1,4), %zmm0 {%k1}
; KNL_32-NEXT: retl
;
; SKX-LABEL: test14:
@ -829,20 +828,19 @@ define <16 x float> @test14(float* %base, i32 %ind, <16 x float*> %vec) {
; SKX-NEXT: vmovq %xmm0, %rax
; SKX-NEXT: vpbroadcastd %esi, %ymm0
; SKX-NEXT: vpmovsxdq %ymm0, %zmm0
; SKX-NEXT: vpsllq $2, %zmm0, %zmm0
; SKX-NEXT: kxnorw %k0, %k0, %k1
; SKX-NEXT: vxorps %xmm1, %xmm1, %xmm1
; SKX-NEXT: vgatherqps (%rax,%zmm0), %ymm1 {%k1}
; SKX-NEXT: vgatherqps (%rax,%zmm0,4), %ymm1 {%k1}
; SKX-NEXT: vinsertf64x4 $1, %ymm1, %zmm1, %zmm0
; SKX-NEXT: retq
;
; SKX_32-LABEL: test14:
; SKX_32: # %bb.0:
; SKX_32-NEXT: vmovd %xmm0, %eax
; SKX_32-NEXT: vpslld $2, {{[0-9]+}}(%esp){1to16}, %zmm1
; SKX_32-NEXT: vbroadcastss {{[0-9]+}}(%esp), %zmm1
; SKX_32-NEXT: kxnorw %k0, %k0, %k1
; SKX_32-NEXT: vpxor %xmm0, %xmm0, %xmm0
; SKX_32-NEXT: vgatherdps (%eax,%zmm1), %zmm0 {%k1}
; SKX_32-NEXT: vgatherdps (%eax,%zmm1,4), %zmm0 {%k1}
; SKX_32-NEXT: retl
%broadcast.splatinsert = insertelement <16 x float*> %vec, float* %base, i32 1
@ -4988,38 +4986,38 @@ define void @splat_ptr_scatter(i32* %ptr, <4 x i1> %mask, <4 x i32> %val) {
;
; PR13310
; FIXME: Failure to fold scaled-index into gather/scatter scale operand.
; Failure to fold scaled-index into gather/scatter scale operand.
;
define <8 x float> @scaleidx_x86gather(float* %base, <8 x i32> %index, <8 x i32> %imask) nounwind {
; KNL_64-LABEL: scaleidx_x86gather:
; KNL_64: # %bb.0:
; KNL_64-NEXT: vpslld $2, %ymm0, %ymm2
; KNL_64-NEXT: vpxor %xmm0, %xmm0, %xmm0
; KNL_64-NEXT: vgatherdps %ymm1, (%rdi,%ymm2), %ymm0
; KNL_64-NEXT: vxorps %xmm2, %xmm2, %xmm2
; KNL_64-NEXT: vgatherdps %ymm1, (%rdi,%ymm0,4), %ymm2
; KNL_64-NEXT: vmovaps %ymm2, %ymm0
; KNL_64-NEXT: retq
;
; KNL_32-LABEL: scaleidx_x86gather:
; KNL_32: # %bb.0:
; KNL_32-NEXT: movl {{[0-9]+}}(%esp), %eax
; KNL_32-NEXT: vpslld $2, %ymm0, %ymm2
; KNL_32-NEXT: vpxor %xmm0, %xmm0, %xmm0
; KNL_32-NEXT: vgatherdps %ymm1, (%eax,%ymm2), %ymm0
; KNL_32-NEXT: vxorps %xmm2, %xmm2, %xmm2
; KNL_32-NEXT: vgatherdps %ymm1, (%eax,%ymm0,4), %ymm2
; KNL_32-NEXT: vmovaps %ymm2, %ymm0
; KNL_32-NEXT: retl
;
; SKX-LABEL: scaleidx_x86gather:
; SKX: # %bb.0:
; SKX-NEXT: vpslld $2, %ymm0, %ymm2
; SKX-NEXT: vpxor %xmm0, %xmm0, %xmm0
; SKX-NEXT: vgatherdps %ymm1, (%rdi,%ymm2), %ymm0
; SKX-NEXT: vxorps %xmm2, %xmm2, %xmm2
; SKX-NEXT: vgatherdps %ymm1, (%rdi,%ymm0,4), %ymm2
; SKX-NEXT: vmovaps %ymm2, %ymm0
; SKX-NEXT: retq
;
; SKX_32-LABEL: scaleidx_x86gather:
; SKX_32: # %bb.0:
; SKX_32-NEXT: movl {{[0-9]+}}(%esp), %eax
; SKX_32-NEXT: vpslld $2, %ymm0, %ymm2
; SKX_32-NEXT: vpxor %xmm0, %xmm0, %xmm0
; SKX_32-NEXT: vgatherdps %ymm1, (%eax,%ymm2), %ymm0
; SKX_32-NEXT: vxorps %xmm2, %xmm2, %xmm2
; SKX_32-NEXT: vgatherdps %ymm1, (%eax,%ymm0,4), %ymm2
; SKX_32-NEXT: vmovaps %ymm2, %ymm0
; SKX_32-NEXT: retl
%ptr = bitcast float* %base to i8*
%mask = bitcast <8 x i32> %imask to <8 x float>
@ -5070,8 +5068,7 @@ define void @scaleidx_x86scatter(<16 x float> %value, float* %base, <16 x i32> %
; KNL_64-LABEL: scaleidx_x86scatter:
; KNL_64: # %bb.0:
; KNL_64-NEXT: kmovw %esi, %k1
; KNL_64-NEXT: vpaddd %zmm1, %zmm1, %zmm1
; KNL_64-NEXT: vscatterdps %zmm0, (%rdi,%zmm1,2) {%k1}
; KNL_64-NEXT: vscatterdps %zmm0, (%rdi,%zmm1,4) {%k1}
; KNL_64-NEXT: vzeroupper
; KNL_64-NEXT: retq
;
@ -5079,16 +5076,14 @@ define void @scaleidx_x86scatter(<16 x float> %value, float* %base, <16 x i32> %
; KNL_32: # %bb.0:
; KNL_32-NEXT: movl {{[0-9]+}}(%esp), %eax
; KNL_32-NEXT: kmovw {{[0-9]+}}(%esp), %k1
; KNL_32-NEXT: vpaddd %zmm1, %zmm1, %zmm1
; KNL_32-NEXT: vscatterdps %zmm0, (%eax,%zmm1,2) {%k1}
; KNL_32-NEXT: vscatterdps %zmm0, (%eax,%zmm1,4) {%k1}
; KNL_32-NEXT: vzeroupper
; KNL_32-NEXT: retl
;
; SKX-LABEL: scaleidx_x86scatter:
; SKX: # %bb.0:
; SKX-NEXT: kmovw %esi, %k1
; SKX-NEXT: vpaddd %zmm1, %zmm1, %zmm1
; SKX-NEXT: vscatterdps %zmm0, (%rdi,%zmm1,2) {%k1}
; SKX-NEXT: vscatterdps %zmm0, (%rdi,%zmm1,4) {%k1}
; SKX-NEXT: vzeroupper
; SKX-NEXT: retq
;
@ -5096,8 +5091,7 @@ define void @scaleidx_x86scatter(<16 x float> %value, float* %base, <16 x i32> %
; SKX_32: # %bb.0:
; SKX_32-NEXT: movl {{[0-9]+}}(%esp), %eax
; SKX_32-NEXT: kmovw {{[0-9]+}}(%esp), %k1
; SKX_32-NEXT: vpaddd %zmm1, %zmm1, %zmm1
; SKX_32-NEXT: vscatterdps %zmm0, (%eax,%zmm1,2) {%k1}
; SKX_32-NEXT: vscatterdps %zmm0, (%eax,%zmm1,4) {%k1}
; SKX_32-NEXT: vzeroupper
; SKX_32-NEXT: retl
%ptr = bitcast float* %base to i8*
@ -5135,18 +5129,16 @@ define void @scaleidx_scatter(<8 x float> %value, float* %base, <8 x i32> %index
;
; SKX-LABEL: scaleidx_scatter:
; SKX: # %bb.0:
; SKX-NEXT: vpaddd %ymm1, %ymm1, %ymm1
; SKX-NEXT: kmovw %esi, %k1
; SKX-NEXT: vscatterdps %ymm0, (%rdi,%ymm1,4) {%k1}
; SKX-NEXT: vscatterdps %ymm0, (%rdi,%ymm1,8) {%k1}
; SKX-NEXT: vzeroupper
; SKX-NEXT: retq
;
; SKX_32-LABEL: scaleidx_scatter:
; SKX_32: # %bb.0:
; SKX_32-NEXT: movl {{[0-9]+}}(%esp), %eax
; SKX_32-NEXT: vpaddd %ymm1, %ymm1, %ymm1
; SKX_32-NEXT: kmovb {{[0-9]+}}(%esp), %k1
; SKX_32-NEXT: vscatterdps %ymm0, (%eax,%ymm1,4) {%k1}
; SKX_32-NEXT: vscatterdps %ymm0, (%eax,%ymm1,8) {%k1}
; SKX_32-NEXT: vzeroupper
; SKX_32-NEXT: retl
%scaledindex = mul <8 x i32> %index, <i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2>