[DAGCombiner] Teach visitMLOAD to replace an all ones mask with an unmasked load

If we have an all ones mask, we can just a regular masked load. InstCombine already gets this in IR. But the all ones mask can appear after type legalization.

Only avx512 test cases are affected because X86 backend already looks for element 0 and the last element being 1. It replaces this with an unmasked load and blend. The all ones mask is a special case of that where the blend will be removed. That transform is only enabled on avx2 targets. I believe that's because a non-zero passthru on avx2 already requires a separate blend so its more profitable to handle mixed constant masks.

This patch adds a dedicated all ones handling to the target independent DAG combiner. I've skipped extending, expanding, and index loads for now. X86 doesn't use index so I don't know much about it. Extending made me nervous because I wasn't sure I could trust the memory VT had the right element count due to some weirdness in vector splitting. For expanding I wasn't sure if we needed different undef handling.

Differential Revision: https://reviews.llvm.org/D87788
This commit is contained in:
Craig Topper 2020-09-16 13:21:15 -07:00
parent 65ef2e50a2
commit 89ee4c0314
2 changed files with 18 additions and 26 deletions

View File

@ -9272,6 +9272,16 @@ SDValue DAGCombiner::visitMLOAD(SDNode *N) {
if (ISD::isBuildVectorAllZeros(Mask.getNode()))
return CombineTo(N, MLD->getPassThru(), MLD->getChain());
// If this is a masked load with an all ones mask, we can use a unmasked load.
// FIXME: Can we do this for indexed, expanding, or extending loads?
if (ISD::isBuildVectorAllOnes(Mask.getNode()) &&
MLD->isUnindexed() && !MLD->isExpandingLoad() &&
MLD->getExtensionType() == ISD::NON_EXTLOAD) {
SDValue NewLd = DAG.getLoad(N->getValueType(0), SDLoc(N), MLD->getChain(),
MLD->getBasePtr(), MLD->getMemOperand());
return CombineTo(N, NewLd, NewLd.getValue(1));
}
// Try transforming N to an indexed load.
if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
return SDValue(N, 0);

View File

@ -6171,25 +6171,10 @@ define <4 x float> @mload_constmask_v4f32_all(<4 x float>* %addr) {
; SSE-NEXT: movups (%rdi), %xmm0
; SSE-NEXT: retq
;
; AVX1OR2-LABEL: mload_constmask_v4f32_all:
; AVX1OR2: ## %bb.0:
; AVX1OR2-NEXT: vmovups (%rdi), %xmm0
; AVX1OR2-NEXT: retq
;
; AVX512F-LABEL: mload_constmask_v4f32_all:
; AVX512F: ## %bb.0:
; AVX512F-NEXT: movw $15, %ax
; AVX512F-NEXT: kmovw %eax, %k1
; AVX512F-NEXT: vmovups (%rdi), %zmm0 {%k1} {z}
; AVX512F-NEXT: ## kill: def $xmm0 killed $xmm0 killed $zmm0
; AVX512F-NEXT: vzeroupper
; AVX512F-NEXT: retq
;
; AVX512VL-LABEL: mload_constmask_v4f32_all:
; AVX512VL: ## %bb.0:
; AVX512VL-NEXT: kxnorw %k0, %k0, %k1
; AVX512VL-NEXT: vmovups (%rdi), %xmm0 {%k1} {z}
; AVX512VL-NEXT: retq
; AVX-LABEL: mload_constmask_v4f32_all:
; AVX: ## %bb.0:
; AVX-NEXT: vmovups (%rdi), %xmm0
; AVX-NEXT: retq
%res = call <4 x float> @llvm.masked.load.v4f32.p0v4f32(<4 x float>* %addr, i32 4, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x float>undef)
ret <4 x float> %res
}
@ -6573,7 +6558,7 @@ define <8 x double> @mload_constmask_v8f64(<8 x double>* %addr, <8 x double> %ds
ret <8 x double> %res
}
; FIXME: We should be able to detect the mask is all ones after type
; Make sure we detect the mask is all ones after type
; legalization to use an unmasked load for some of the avx512 instructions.
define <16 x double> @mload_constmask_v16f64_allones_split(<16 x double>* %addr, <16 x double> %dst) {
; SSE-LABEL: mload_constmask_v16f64_allones_split:
@ -6611,29 +6596,26 @@ define <16 x double> @mload_constmask_v16f64_allones_split(<16 x double>* %addr,
;
; AVX512F-LABEL: mload_constmask_v16f64_allones_split:
; AVX512F: ## %bb.0:
; AVX512F-NEXT: kxnorw %k0, %k0, %k1
; AVX512F-NEXT: vmovupd (%rdi), %zmm0 {%k1}
; AVX512F-NEXT: movb $85, %al
; AVX512F-NEXT: kmovw %eax, %k1
; AVX512F-NEXT: vmovupd 64(%rdi), %zmm1 {%k1}
; AVX512F-NEXT: vmovups (%rdi), %zmm0
; AVX512F-NEXT: retq
;
; AVX512VLDQ-LABEL: mload_constmask_v16f64_allones_split:
; AVX512VLDQ: ## %bb.0:
; AVX512VLDQ-NEXT: kxnorw %k0, %k0, %k1
; AVX512VLDQ-NEXT: vmovupd (%rdi), %zmm0 {%k1}
; AVX512VLDQ-NEXT: movb $85, %al
; AVX512VLDQ-NEXT: kmovw %eax, %k1
; AVX512VLDQ-NEXT: vmovupd 64(%rdi), %zmm1 {%k1}
; AVX512VLDQ-NEXT: vmovups (%rdi), %zmm0
; AVX512VLDQ-NEXT: retq
;
; AVX512VLBW-LABEL: mload_constmask_v16f64_allones_split:
; AVX512VLBW: ## %bb.0:
; AVX512VLBW-NEXT: kxnorw %k0, %k0, %k1
; AVX512VLBW-NEXT: vmovupd (%rdi), %zmm0 {%k1}
; AVX512VLBW-NEXT: movb $85, %al
; AVX512VLBW-NEXT: kmovd %eax, %k1
; AVX512VLBW-NEXT: vmovupd 64(%rdi), %zmm1 {%k1}
; AVX512VLBW-NEXT: vmovups (%rdi), %zmm0
; AVX512VLBW-NEXT: retq
%res = call <16 x double> @llvm.masked.load.v16f64.p0v16f64(<16 x double>* %addr, i32 4, <16 x i1> <i1 1, i1 1, i1 1, i1 1, i1 1, i1 1, i1 1, i1 1, i1 1, i1 0, i1 1, i1 0, i1 1, i1 0, i1 1, i1 0>, <16 x double> %dst)
ret <16 x double> %res