From 19e523514714688e4eb6588925eab079997e21a6 Mon Sep 17 00:00:00 2001
From: zhongyunde
Date: Wed, 6 Apr 2022 20:47:32 +0800
Subject: [PATCH] [AArch64][InstCombine] Fold MLOAD and zero extensions into MLOAD

According to the discussion in D122281, we are missing an ISD::AND combine
for MLOAD because it relies on BuildVectorSDNode, which fails for scalable
vectors. This patch is intended to handle that, so we can circle back and
remove the MVT::nxv2i32 special case.

Reviewed By: paulwalker-arm

Differential Revision: https://reviews.llvm.org/D122703
---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 26 +++++++++----------
 .../Target/AArch64/AArch64ISelLowering.cpp    |  1 -
 .../CodeGen/AArch64/sve-masked-ldst-zext.ll   |  2 +-
 3 files changed, 13 insertions(+), 16 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index bd38e1669144..7bf89ca1580d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -6067,27 +6067,25 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
     if (ISD::isConstantSplatVectorAllOnes(N1.getNode()))
       return N0;
 
-    // fold (and (masked_load) (build_vec (x, ...))) to zext_masked_load
+    // fold (and (masked_load) (splat_vec (x, ...))) to zext_masked_load
     auto *MLoad = dyn_cast<MaskedLoadSDNode>(N0);
-    auto *BVec = dyn_cast<BuildVectorSDNode>(N1);
-    if (MLoad && BVec && MLoad->getExtensionType() == ISD::EXTLOAD &&
-        N0.hasOneUse() && N1.hasOneUse()) {
+    ConstantSDNode *Splat = isConstOrConstSplat(N1, true, true);
+    if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && N0.hasOneUse() &&
+        Splat && N1.hasOneUse()) {
       EVT LoadVT = MLoad->getMemoryVT();
       EVT ExtVT = VT;
       if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT)) {
         // For this AND to be a zero extension of the masked load the elements
         // of the BuildVec must mask the bottom bits of the extended element
         // type
-        if (ConstantSDNode *Splat = BVec->getConstantSplatNode()) {
-          uint64_t ElementSize =
-              LoadVT.getVectorElementType().getScalarSizeInBits();
-          if (Splat->getAPIntValue().isMask(ElementSize)) {
-            return DAG.getMaskedLoad(
-                ExtVT, SDLoc(N), MLoad->getChain(), MLoad->getBasePtr(),
-                MLoad->getOffset(), MLoad->getMask(), MLoad->getPassThru(),
-                LoadVT, MLoad->getMemOperand(), MLoad->getAddressingMode(),
-                ISD::ZEXTLOAD, MLoad->isExpandingLoad());
-          }
+        uint64_t ElementSize =
+            LoadVT.getVectorElementType().getScalarSizeInBits();
+        if (Splat->getAPIntValue().isMask(ElementSize)) {
+          return DAG.getMaskedLoad(
+              ExtVT, SDLoc(N), MLoad->getChain(), MLoad->getBasePtr(),
+              MLoad->getOffset(), MLoad->getMask(), MLoad->getPassThru(),
+              LoadVT, MLoad->getMemOperand(), MLoad->getAddressingMode(),
+              ISD::ZEXTLOAD, MLoad->isExpandingLoad());
         }
       }
     }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 68c8d73fbcb6..a00d2a5661ce 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1230,7 +1230,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setLoadExtAction(Op, MVT::nxv2i64, MVT::nxv2i16, Legal);
       setLoadExtAction(Op, MVT::nxv2i64, MVT::nxv2i32, Legal);
       setLoadExtAction(Op, MVT::nxv4i32, MVT::nxv4i8, Legal);
-      setLoadExtAction(Op, MVT::nxv2i32, MVT::nxv2i16, Legal);
      setLoadExtAction(Op, MVT::nxv4i32, MVT::nxv4i16, Legal);
       setLoadExtAction(Op, MVT::nxv8i16, MVT::nxv8i8, Legal);
     }
diff --git a/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll b/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
index 66c4798f3059..1ddcd16f1c12 100644
--- a/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
@@ -97,7 +97,7 @@ define <vscale x 2 x double> @masked_zload_2i16_2f64(<vscale x 2 x i16>* noalias
 ; CHECK-LABEL: masked_zload_2i16_2f64:
 ; CHECK: ld1h { z0.d }, p0/z, [x0]
 ; CHECK-NEXT: ptrue p0.d
-; CHECK-NEXT: ucvtf z0.d, p0/m, z0.s
+; CHECK-NEXT: ucvtf z0.d, p0/m, z0.d
 ; CHECK-NEXT: ret
   %wide.load = call <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>* %in, i32 2, <vscale x 2 x i1> %mask, <vscale x 2 x i16> undef)
   %zext = zext <vscale x 2 x i16> %wide.load to
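
For reference, a sketch of the full masked_zload_2i16_2f64 test touched by the
hunk above. Only the first two body lines appear in the diff; the %mask
parameter name, the uitofp/ret tail, and the declare line are reconstructed
here for illustration and may differ slightly from the checked-in file:

  define <vscale x 2 x double> @masked_zload_2i16_2f64(<vscale x 2 x i16>* noalias %in, <vscale x 2 x i1> %mask) {
    ; Masked load of nxv2i16 elements, then a zero extend to nxv2i64.
    %wide.load = call <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>* %in, i32 2, <vscale x 2 x i1> %mask, <vscale x 2 x i16> undef)
    ; On scalable vectors this zext reaches visitAND as
    ; (and (masked_load), splat(0xffff)); with isConstOrConstSplat the combine
    ; above now folds it into a zero-extending masked load (ld1h { z0.d }).
    %zext = zext <vscale x 2 x i16> %wide.load to <vscale x 2 x i64>
    %res = uitofp <vscale x 2 x i64> %zext to <vscale x 2 x double>
    ret <vscale x 2 x double> %res
  }

  declare <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>*, i32, <vscale x 2 x i1>, <vscale x 2 x i16>)

The only codegen difference for this test is that ucvtf now consumes the
zero-extended 64-bit elements (z0.d) rather than the 32-bit form (z0.s),
which is why the nxv2i32 setLoadExtAction workaround can be dropped.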