From 66f2d09ebf8d81d019a5524cdc5e7f88acbb7504 Mon Sep 17 00:00:00 2001 From: Guozhi Wei Date: Thu, 18 Feb 2021 13:12:19 -0800 Subject: [PATCH] [DAGCombiner] Transform (zext (select c, load1, load2)) -> (select c, zextload1, zextload2) If extload is legal, following transform (zext (select c, load1, load2)) -> (select c, zextload1, zextload2) can save one ext instruction. Differential Revision: https://reviews.llvm.org/D95086 --- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 80 +++++++++++++++++++ llvm/test/CodeGen/X86/select-ext.ll | 59 ++++++-------- 2 files changed, 105 insertions(+), 34 deletions(-) diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 737997a3eae6..7f3aeeb0353f 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -10029,6 +10029,77 @@ SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) { return SDValue(); } +/// Check if N satisfies: +/// N is used once. +/// N is a Load. +/// The load is compatible with ExtOpcode. It means +/// If load has explicit zero/sign extension, ExpOpcode must have the same +/// extension. +/// Otherwise returns true. +static bool isCompatibleLoad(SDValue N, unsigned ExtOpcode) { + if (!N.hasOneUse()) + return false; + + if (!isa(N)) + return false; + + LoadSDNode *Load = cast(N); + ISD::LoadExtType LoadExt = Load->getExtensionType(); + if (LoadExt == ISD::NON_EXTLOAD || LoadExt == ISD::EXTLOAD) + return true; + + // Now LoadExt is either SEXTLOAD or ZEXTLOAD, ExtOpcode must have the same + // extension. + if ((LoadExt == ISD::SEXTLOAD && ExtOpcode != ISD::SIGN_EXTEND) || + (LoadExt == ISD::ZEXTLOAD && ExtOpcode != ISD::ZERO_EXTEND)) + return false; + + return true; +} + +/// Fold +/// (sext (select c, load x, load y)) -> (select c, sextload x, sextload y) +/// (zext (select c, load x, load y)) -> (select c, zextload x, zextload y) +/// (aext (select c, load x, load y)) -> (select c, extload x, extload y) +/// This function is called by the DAGCombiner when visiting sext/zext/aext +/// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND). +static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI, + SelectionDAG &DAG) { + unsigned Opcode = N->getOpcode(); + SDValue N0 = N->getOperand(0); + EVT VT = N->getValueType(0); + SDLoc DL(N); + + assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND || + Opcode == ISD::ANY_EXTEND) && + "Expected EXTEND dag node in input!"); + + if (!(N0->getOpcode() == ISD::SELECT || N0->getOpcode() == ISD::VSELECT) || + !N0.hasOneUse()) + return SDValue(); + + SDValue Op1 = N0->getOperand(1); + SDValue Op2 = N0->getOperand(2); + if (!isCompatibleLoad(Op1, Opcode) || !isCompatibleLoad(Op2, Opcode)) + return SDValue(); + + auto ExtLoadOpcode = ISD::EXTLOAD; + if (Opcode == ISD::SIGN_EXTEND) + ExtLoadOpcode = ISD::SEXTLOAD; + else if (Opcode == ISD::ZERO_EXTEND) + ExtLoadOpcode = ISD::ZEXTLOAD; + + LoadSDNode *Load1 = cast(Op1); + LoadSDNode *Load2 = cast(Op2); + if (!TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load1->getMemoryVT()) || + !TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load2->getMemoryVT())) + return SDValue(); + + SDValue Ext1 = DAG.getNode(Opcode, DL, VT, Op1); + SDValue Ext2 = DAG.getNode(Opcode, DL, VT, Op2); + return DAG.getSelect(DL, VT, N0->getOperand(0), Ext1, Ext2); +} + /// Try to fold a sext/zext/aext dag node into a ConstantSDNode or /// a build_vector of constants. /// This function is called by the DAGCombiner when visiting sext/zext/aext @@ -10813,6 +10884,9 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) { return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT)); } + if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG)) + return Res; + return SDValue(); } @@ -11125,6 +11199,9 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { if (SDValue NewCtPop = widenCtPop(N, DAG)) return NewCtPop; + if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG)) + return Res; + return SDValue(); } @@ -11277,6 +11354,9 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) { if (SDValue NewCtPop = widenCtPop(N, DAG)) return NewCtPop; + if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG)) + return Res; + return SDValue(); } diff --git a/llvm/test/CodeGen/X86/select-ext.ll b/llvm/test/CodeGen/X86/select-ext.ll index acbd7577e9fd..82e79b18534c 100644 --- a/llvm/test/CodeGen/X86/select-ext.ll +++ b/llvm/test/CodeGen/X86/select-ext.ll @@ -1,15 +1,14 @@ ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py ; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+sse4.1 | FileCheck %s -; TODO: (zext(select c, load1, load2)) -> (select c, zextload1, zextload2) +; (zext(select c, load1, load2)) -> (select c, zextload1, zextload2) define i64 @zext_scalar(i8* %p, i1 zeroext %c) { ; CHECK-LABEL: zext_scalar: ; CHECK: # %bb.0: -; CHECK-NEXT: movzbl (%rdi), %eax -; CHECK-NEXT: movzbl 1(%rdi), %ecx +; CHECK-NEXT: movzbl (%rdi), %ecx +; CHECK-NEXT: movzbl 1(%rdi), %eax ; CHECK-NEXT: testl %esi, %esi -; CHECK-NEXT: cmovel %eax, %ecx -; CHECK-NEXT: movzbl %cl, %eax +; CHECK-NEXT: cmoveq %rcx, %rax ; CHECK-NEXT: retq %ld1 = load volatile i8, i8* %p %arrayidx1 = getelementptr inbounds i8, i8* %p, i64 1 @@ -22,13 +21,10 @@ define i64 @zext_scalar(i8* %p, i1 zeroext %c) { define i64 @zext_scalar2(i8* %p, i16* %q, i1 zeroext %c) { ; CHECK-LABEL: zext_scalar2: ; CHECK: # %bb.0: -; CHECK-NEXT: movzbl (%rdi), %eax -; CHECK-NEXT: testl %edx, %edx -; CHECK-NEXT: je .LBB1_2 -; CHECK-NEXT: # %bb.1: +; CHECK-NEXT: movzbl (%rdi), %ecx ; CHECK-NEXT: movzwl (%rsi), %eax -; CHECK-NEXT: .LBB1_2: -; CHECK-NEXT: movzwl %ax, %eax +; CHECK-NEXT: testl %edx, %edx +; CHECK-NEXT: cmoveq %rcx, %rax ; CHECK-NEXT: retq %ld1 = load volatile i8, i8* %p %ext_ld1 = zext i8 %ld1 to i16 @@ -58,15 +54,14 @@ define i64 @zext_scalar_neg(i8* %p, i16* %q, i1 zeroext %c) { ret i64 %cond } -; TODO: (sext(select c, load1, load2)) -> (select c, sextload1, sextload2) +; (sext(select c, load1, load2)) -> (select c, sextload1, sextload2) define i64 @sext_scalar(i8* %p, i1 zeroext %c) { ; CHECK-LABEL: sext_scalar: ; CHECK: # %bb.0: -; CHECK-NEXT: movzbl (%rdi), %eax -; CHECK-NEXT: movzbl 1(%rdi), %ecx +; CHECK-NEXT: movsbq (%rdi), %rcx +; CHECK-NEXT: movsbq 1(%rdi), %rax ; CHECK-NEXT: testl %esi, %esi -; CHECK-NEXT: cmovel %eax, %ecx -; CHECK-NEXT: movsbq %cl, %rax +; CHECK-NEXT: cmoveq %rcx, %rax ; CHECK-NEXT: retq %ld1 = load volatile i8, i8* %p %arrayidx1 = getelementptr inbounds i8, i8* %p, i64 1 @@ -80,14 +75,13 @@ define i64 @sext_scalar(i8* %p, i1 zeroext %c) { define <2 x i64> @zext_vector_i1(<2 x i32>* %p, i1 zeroext %c) { ; CHECK-LABEL: zext_vector_i1: ; CHECK: # %bb.0: -; CHECK-NEXT: movq {{.*#+}} xmm1 = mem[0],zero -; CHECK-NEXT: movq {{.*#+}} xmm0 = mem[0],zero +; CHECK-NEXT: pmovzxdq {{.*#+}} xmm1 = mem[0],zero,mem[1],zero +; CHECK-NEXT: pmovzxdq {{.*#+}} xmm0 = mem[0],zero,mem[1],zero ; CHECK-NEXT: testl %esi, %esi ; CHECK-NEXT: jne .LBB4_2 ; CHECK-NEXT: # %bb.1: ; CHECK-NEXT: movdqa %xmm1, %xmm0 ; CHECK-NEXT: .LBB4_2: -; CHECK-NEXT: pmovzxdq {{.*#+}} xmm0 = xmm0[0],zero,xmm0[1],zero ; CHECK-NEXT: retq %ld1 = load volatile <2 x i32>, <2 x i32>* %p %arrayidx1 = getelementptr inbounds <2 x i32>, <2 x i32>* %p, i64 1 @@ -100,12 +94,11 @@ define <2 x i64> @zext_vector_i1(<2 x i32>* %p, i1 zeroext %c) { define <2 x i64> @zext_vector_v2i1(<2 x i32>* %p, <2 x i1> %c) { ; CHECK-LABEL: zext_vector_v2i1: ; CHECK: # %bb.0: -; CHECK-NEXT: pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3] -; CHECK-NEXT: pslld $31, %xmm0 -; CHECK-NEXT: movsd {{.*#+}} xmm1 = mem[0],zero -; CHECK-NEXT: movsd {{.*#+}} xmm2 = mem[0],zero -; CHECK-NEXT: blendvps %xmm0, %xmm2, %xmm1 -; CHECK-NEXT: pmovzxdq {{.*#+}} xmm0 = xmm1[0],zero,xmm1[1],zero +; CHECK-NEXT: psllq $63, %xmm0 +; CHECK-NEXT: pmovzxdq {{.*#+}} xmm1 = mem[0],zero,mem[1],zero +; CHECK-NEXT: pmovzxdq {{.*#+}} xmm2 = mem[0],zero,mem[1],zero +; CHECK-NEXT: blendvpd %xmm0, %xmm2, %xmm1 +; CHECK-NEXT: movapd %xmm1, %xmm0 ; CHECK-NEXT: retq %ld1 = load volatile <2 x i32>, <2 x i32>* %p %arrayidx1 = getelementptr inbounds <2 x i32>, <2 x i32>* %p, i64 1 @@ -119,14 +112,13 @@ define <2 x i64> @zext_vector_v2i1(<2 x i32>* %p, <2 x i1> %c) { define <2 x i64> @sext_vector_i1(<2 x i32>* %p, i1 zeroext %c) { ; CHECK-LABEL: sext_vector_i1: ; CHECK: # %bb.0: -; CHECK-NEXT: movq {{.*#+}} xmm1 = mem[0],zero -; CHECK-NEXT: movq {{.*#+}} xmm0 = mem[0],zero +; CHECK-NEXT: pmovsxdq (%rdi), %xmm1 +; CHECK-NEXT: pmovsxdq 8(%rdi), %xmm0 ; CHECK-NEXT: testl %esi, %esi ; CHECK-NEXT: jne .LBB6_2 ; CHECK-NEXT: # %bb.1: ; CHECK-NEXT: movdqa %xmm1, %xmm0 ; CHECK-NEXT: .LBB6_2: -; CHECK-NEXT: pmovsxdq %xmm0, %xmm0 ; CHECK-NEXT: retq %ld1 = load volatile <2 x i32>, <2 x i32>* %p %arrayidx1 = getelementptr inbounds <2 x i32>, <2 x i32>* %p, i64 1 @@ -139,12 +131,11 @@ define <2 x i64> @sext_vector_i1(<2 x i32>* %p, i1 zeroext %c) { define <2 x i64> @sext_vector_v2i1(<2 x i32>* %p, <2 x i1> %c) { ; CHECK-LABEL: sext_vector_v2i1: ; CHECK: # %bb.0: -; CHECK-NEXT: pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3] -; CHECK-NEXT: pslld $31, %xmm0 -; CHECK-NEXT: movsd {{.*#+}} xmm1 = mem[0],zero -; CHECK-NEXT: movsd {{.*#+}} xmm2 = mem[0],zero -; CHECK-NEXT: blendvps %xmm0, %xmm2, %xmm1 -; CHECK-NEXT: pmovsxdq %xmm1, %xmm0 +; CHECK-NEXT: psllq $63, %xmm0 +; CHECK-NEXT: pmovsxdq (%rdi), %xmm1 +; CHECK-NEXT: pmovsxdq 8(%rdi), %xmm2 +; CHECK-NEXT: blendvpd %xmm0, %xmm2, %xmm1 +; CHECK-NEXT: movapd %xmm1, %xmm0 ; CHECK-NEXT: retq %ld1 = load volatile <2 x i32>, <2 x i32>* %p %arrayidx1 = getelementptr inbounds <2 x i32>, <2 x i32>* %p, i64 1