[DAGCombiner] Transform (zext (select c, load1, load2)) -> (select c, zextload1, zextload2)

If extload is legal, the following transform
    (zext (select c, load1, load2)) -> (select c, zextload1, zextload2)
can save one ext instruction.

Differential Revision: https://reviews.llvm.org/D95086
This commit is contained in:
Guozhi Wei 2021-02-18 13:12:19 -08:00
parent ea2ff54ccc
commit 66f2d09ebf
2 changed files with 105 additions and 34 deletions

View File

@ -10029,6 +10029,77 @@ SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) {
return SDValue();
}
/// Return true when N is a load whose result may be extended with ExtOpcode.
/// All of the following must hold:
///   - N has exactly one use.
///   - N is a LoadSDNode.
///   - The load's own extension kind does not conflict with ExtOpcode:
///       * a NON_EXTLOAD or EXTLOAD is accepted for any ExtOpcode;
///       * a SEXTLOAD is accepted only when ExtOpcode is ISD::SIGN_EXTEND;
///       * a ZEXTLOAD is accepted only when ExtOpcode is ISD::ZERO_EXTEND.
static bool isCompatibleLoad(SDValue N, unsigned ExtOpcode) {
  if (!N.hasOneUse())
    return false;

  auto *Ld = dyn_cast<LoadSDNode>(N);
  if (!Ld)
    return false;

  // An explicitly sign/zero-extending load must be paired with the matching
  // extend opcode; every other extension kind is compatible with anything.
  const ISD::LoadExtType Kind = Ld->getExtensionType();
  if (Kind == ISD::SEXTLOAD)
    return ExtOpcode == ISD::SIGN_EXTEND;
  if (Kind == ISD::ZEXTLOAD)
    return ExtOpcode == ISD::ZERO_EXTEND;
  return true;
}
/// Push an extend through a select of two loads:
///   (sext (select c, load x, load y)) -> (select c, sextload x, sextload y)
///   (zext (select c, load x, load y)) -> (select c, zextload x, zextload y)
///   (aext (select c, load x, load y)) -> (select c, extload x, extload y)
/// The rewrite is only performed when the select (SELECT or VSELECT) has a
/// single use, both selected values pass isCompatibleLoad for the extend
/// opcode, and the target reports the matching extending load as legal for
/// both memory types. On success the new select node is returned; the two
/// operands are wrapped in fresh extend nodes of the same opcode as N.
/// Otherwise an empty SDValue is returned.
/// This function is called by the DAGCombiner when visiting sext/zext/aext
/// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
                                         SelectionDAG &DAG) {
  const unsigned ExtOpc = N->getOpcode();
  SDValue Sel = N->getOperand(0);
  assert((ExtOpc == ISD::SIGN_EXTEND || ExtOpc == ISD::ZERO_EXTEND ||
          ExtOpc == ISD::ANY_EXTEND) &&
         "Expected EXTEND dag node in input!");

  // The extended value must be a (v)select that nobody else uses.
  const unsigned SelOpc = Sel->getOpcode();
  if (SelOpc != ISD::SELECT && SelOpc != ISD::VSELECT)
    return SDValue();
  if (!Sel.hasOneUse())
    return SDValue();

  SDValue TrueVal = Sel->getOperand(1);
  SDValue FalseVal = Sel->getOperand(2);
  if (!isCompatibleLoad(TrueVal, ExtOpc))
    return SDValue();
  if (!isCompatibleLoad(FalseVal, ExtOpc))
    return SDValue();

  // Map the extend opcode onto the corresponding load extension kind;
  // ANY_EXTEND maps to the plain EXTLOAD.
  ISD::LoadExtType LoadExtKind = ISD::EXTLOAD;
  switch (ExtOpc) {
  case ISD::SIGN_EXTEND:
    LoadExtKind = ISD::SEXTLOAD;
    break;
  case ISD::ZERO_EXTEND:
    LoadExtKind = ISD::ZEXTLOAD;
    break;
  default:
    break;
  }

  // Both loads must be extendable to the result type on this target.
  EVT VT = N->getValueType(0);
  auto IsExtLoadLegalFor = [&](SDValue LoadVal) {
    return TLI.isLoadExtLegal(LoadExtKind, VT,
                              cast<LoadSDNode>(LoadVal)->getMemoryVT());
  };
  if (!IsExtLoadLegalFor(TrueVal) || !IsExtLoadLegalFor(FalseVal))
    return SDValue();

  // Extend each arm separately and rebuild the select on the wide type.
  SDLoc DL(N);
  SDValue Cond = Sel->getOperand(0);
  SDValue ExtTrue = DAG.getNode(ExtOpc, DL, VT, TrueVal);
  SDValue ExtFalse = DAG.getNode(ExtOpc, DL, VT, FalseVal);
  return DAG.getSelect(DL, VT, Cond, ExtTrue, ExtFalse);
}
/// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
/// a build_vector of constants.
/// This function is called by the DAGCombiner when visiting sext/zext/aext
@ -10813,6 +10884,9 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
}
if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
return Res;
return SDValue();
}
@ -11125,6 +11199,9 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
if (SDValue NewCtPop = widenCtPop(N, DAG))
return NewCtPop;
if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
return Res;
return SDValue();
}
@ -11277,6 +11354,9 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
if (SDValue NewCtPop = widenCtPop(N, DAG))
return NewCtPop;
if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
return Res;
return SDValue();
}

View File

@ -1,15 +1,14 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+sse4.1 | FileCheck %s
; TODO: (zext(select c, load1, load2)) -> (select c, zextload1, zextload2)
; (zext(select c, load1, load2)) -> (select c, zextload1, zextload2)
define i64 @zext_scalar(i8* %p, i1 zeroext %c) {
; CHECK-LABEL: zext_scalar:
; CHECK: # %bb.0:
; CHECK-NEXT: movzbl (%rdi), %eax
; CHECK-NEXT: movzbl 1(%rdi), %ecx
; CHECK-NEXT: movzbl (%rdi), %ecx
; CHECK-NEXT: movzbl 1(%rdi), %eax
; CHECK-NEXT: testl %esi, %esi
; CHECK-NEXT: cmovel %eax, %ecx
; CHECK-NEXT: movzbl %cl, %eax
; CHECK-NEXT: cmoveq %rcx, %rax
; CHECK-NEXT: retq
%ld1 = load volatile i8, i8* %p
%arrayidx1 = getelementptr inbounds i8, i8* %p, i64 1
@ -22,13 +21,10 @@ define i64 @zext_scalar(i8* %p, i1 zeroext %c) {
define i64 @zext_scalar2(i8* %p, i16* %q, i1 zeroext %c) {
; CHECK-LABEL: zext_scalar2:
; CHECK: # %bb.0:
; CHECK-NEXT: movzbl (%rdi), %eax
; CHECK-NEXT: testl %edx, %edx
; CHECK-NEXT: je .LBB1_2
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: movzbl (%rdi), %ecx
; CHECK-NEXT: movzwl (%rsi), %eax
; CHECK-NEXT: .LBB1_2:
; CHECK-NEXT: movzwl %ax, %eax
; CHECK-NEXT: testl %edx, %edx
; CHECK-NEXT: cmoveq %rcx, %rax
; CHECK-NEXT: retq
%ld1 = load volatile i8, i8* %p
%ext_ld1 = zext i8 %ld1 to i16
@ -58,15 +54,14 @@ define i64 @zext_scalar_neg(i8* %p, i16* %q, i1 zeroext %c) {
ret i64 %cond
}
; TODO: (sext(select c, load1, load2)) -> (select c, sextload1, sextload2)
; (sext(select c, load1, load2)) -> (select c, sextload1, sextload2)
define i64 @sext_scalar(i8* %p, i1 zeroext %c) {
; CHECK-LABEL: sext_scalar:
; CHECK: # %bb.0:
; CHECK-NEXT: movzbl (%rdi), %eax
; CHECK-NEXT: movzbl 1(%rdi), %ecx
; CHECK-NEXT: movsbq (%rdi), %rcx
; CHECK-NEXT: movsbq 1(%rdi), %rax
; CHECK-NEXT: testl %esi, %esi
; CHECK-NEXT: cmovel %eax, %ecx
; CHECK-NEXT: movsbq %cl, %rax
; CHECK-NEXT: cmoveq %rcx, %rax
; CHECK-NEXT: retq
%ld1 = load volatile i8, i8* %p
%arrayidx1 = getelementptr inbounds i8, i8* %p, i64 1
@ -80,14 +75,13 @@ define i64 @sext_scalar(i8* %p, i1 zeroext %c) {
define <2 x i64> @zext_vector_i1(<2 x i32>* %p, i1 zeroext %c) {
; CHECK-LABEL: zext_vector_i1:
; CHECK: # %bb.0:
; CHECK-NEXT: movq {{.*#+}} xmm1 = mem[0],zero
; CHECK-NEXT: movq {{.*#+}} xmm0 = mem[0],zero
; CHECK-NEXT: pmovzxdq {{.*#+}} xmm1 = mem[0],zero,mem[1],zero
; CHECK-NEXT: pmovzxdq {{.*#+}} xmm0 = mem[0],zero,mem[1],zero
; CHECK-NEXT: testl %esi, %esi
; CHECK-NEXT: jne .LBB4_2
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: movdqa %xmm1, %xmm0
; CHECK-NEXT: .LBB4_2:
; CHECK-NEXT: pmovzxdq {{.*#+}} xmm0 = xmm0[0],zero,xmm0[1],zero
; CHECK-NEXT: retq
%ld1 = load volatile <2 x i32>, <2 x i32>* %p
%arrayidx1 = getelementptr inbounds <2 x i32>, <2 x i32>* %p, i64 1
@ -100,12 +94,11 @@ define <2 x i64> @zext_vector_i1(<2 x i32>* %p, i1 zeroext %c) {
define <2 x i64> @zext_vector_v2i1(<2 x i32>* %p, <2 x i1> %c) {
; CHECK-LABEL: zext_vector_v2i1:
; CHECK: # %bb.0:
; CHECK-NEXT: pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
; CHECK-NEXT: pslld $31, %xmm0
; CHECK-NEXT: movsd {{.*#+}} xmm1 = mem[0],zero
; CHECK-NEXT: movsd {{.*#+}} xmm2 = mem[0],zero
; CHECK-NEXT: blendvps %xmm0, %xmm2, %xmm1
; CHECK-NEXT: pmovzxdq {{.*#+}} xmm0 = xmm1[0],zero,xmm1[1],zero
; CHECK-NEXT: psllq $63, %xmm0
; CHECK-NEXT: pmovzxdq {{.*#+}} xmm1 = mem[0],zero,mem[1],zero
; CHECK-NEXT: pmovzxdq {{.*#+}} xmm2 = mem[0],zero,mem[1],zero
; CHECK-NEXT: blendvpd %xmm0, %xmm2, %xmm1
; CHECK-NEXT: movapd %xmm1, %xmm0
; CHECK-NEXT: retq
%ld1 = load volatile <2 x i32>, <2 x i32>* %p
%arrayidx1 = getelementptr inbounds <2 x i32>, <2 x i32>* %p, i64 1
@ -119,14 +112,13 @@ define <2 x i64> @zext_vector_v2i1(<2 x i32>* %p, <2 x i1> %c) {
define <2 x i64> @sext_vector_i1(<2 x i32>* %p, i1 zeroext %c) {
; CHECK-LABEL: sext_vector_i1:
; CHECK: # %bb.0:
; CHECK-NEXT: movq {{.*#+}} xmm1 = mem[0],zero
; CHECK-NEXT: movq {{.*#+}} xmm0 = mem[0],zero
; CHECK-NEXT: pmovsxdq (%rdi), %xmm1
; CHECK-NEXT: pmovsxdq 8(%rdi), %xmm0
; CHECK-NEXT: testl %esi, %esi
; CHECK-NEXT: jne .LBB6_2
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: movdqa %xmm1, %xmm0
; CHECK-NEXT: .LBB6_2:
; CHECK-NEXT: pmovsxdq %xmm0, %xmm0
; CHECK-NEXT: retq
%ld1 = load volatile <2 x i32>, <2 x i32>* %p
%arrayidx1 = getelementptr inbounds <2 x i32>, <2 x i32>* %p, i64 1
@ -139,12 +131,11 @@ define <2 x i64> @sext_vector_i1(<2 x i32>* %p, i1 zeroext %c) {
define <2 x i64> @sext_vector_v2i1(<2 x i32>* %p, <2 x i1> %c) {
; CHECK-LABEL: sext_vector_v2i1:
; CHECK: # %bb.0:
; CHECK-NEXT: pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
; CHECK-NEXT: pslld $31, %xmm0
; CHECK-NEXT: movsd {{.*#+}} xmm1 = mem[0],zero
; CHECK-NEXT: movsd {{.*#+}} xmm2 = mem[0],zero
; CHECK-NEXT: blendvps %xmm0, %xmm2, %xmm1
; CHECK-NEXT: pmovsxdq %xmm1, %xmm0
; CHECK-NEXT: psllq $63, %xmm0
; CHECK-NEXT: pmovsxdq (%rdi), %xmm1
; CHECK-NEXT: pmovsxdq 8(%rdi), %xmm2
; CHECK-NEXT: blendvpd %xmm0, %xmm2, %xmm1
; CHECK-NEXT: movapd %xmm1, %xmm0
; CHECK-NEXT: retq
%ld1 = load volatile <2 x i32>, <2 x i32>* %p
%arrayidx1 = getelementptr inbounds <2 x i32>, <2 x i32>* %p, i64 1