diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp index cbb28863850f..abf6a3ac6916 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp @@ -134,6 +134,7 @@ class VectorLegalizer { /// supported by the target. SDValue ExpandVSELECT(SDNode *Node); SDValue ExpandVP_SELECT(SDNode *Node); + SDValue ExpandVP_MERGE(SDNode *Node); SDValue ExpandSELECT(SDNode *Node); std::pair ExpandLoad(SDNode *N); SDValue ExpandStore(SDNode *N); @@ -877,6 +878,9 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl &Results) { case ISD::UREM: ExpandREM(Node, Results); return; + case ISD::VP_MERGE: + Results.push_back(ExpandVP_MERGE(Node)); + return; } Results.push_back(DAG.UnrollVectorOp(Node)); @@ -1238,6 +1242,48 @@ SDValue VectorLegalizer::ExpandVP_SELECT(SDNode *Node) { return DAG.getNode(ISD::VP_OR, DL, VT, Op1, Op2, Mask, EVL); } +SDValue VectorLegalizer::ExpandVP_MERGE(SDNode *Node) { + // Implement VP_MERGE in terms of VSELECT. Construct a mask where vector + // indices less than the EVL/pivot are true. Combine that with the original + // mask for a full-length mask. Use a full-length VSELECT to select between + // the true and false values. + SDLoc DL(Node); + + SDValue Mask = Node->getOperand(0); + SDValue Op1 = Node->getOperand(1); + SDValue Op2 = Node->getOperand(2); + SDValue EVL = Node->getOperand(3); + + EVT MaskVT = Mask.getValueType(); + bool IsFixedLen = MaskVT.isFixedLengthVector(); + + EVT EVLVecVT = EVT::getVectorVT(*DAG.getContext(), EVL.getValueType(), + MaskVT.getVectorElementCount()); + + // If we can't construct the EVL mask efficiently, it's better to unroll. + if ((IsFixedLen && + !TLI.isOperationLegalOrCustom(ISD::BUILD_VECTOR, EVLVecVT)) || + (!IsFixedLen && + (!TLI.isOperationLegalOrCustom(ISD::STEP_VECTOR, EVLVecVT) || + !TLI.isOperationLegalOrCustom(ISD::SPLAT_VECTOR, EVLVecVT)))) + return DAG.UnrollVectorOp(Node); + + // If using a SETCC would result in a different type than the mask type, + // unroll. + if (TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), + EVLVecVT) != MaskVT) + return DAG.UnrollVectorOp(Node); + + SDValue StepVec = DAG.getStepVector(DL, EVLVecVT); + SDValue SplatEVL = IsFixedLen ? DAG.getSplatBuildVector(EVLVecVT, DL, EVL) + : DAG.getSplatVector(EVLVecVT, DL, EVL); + SDValue EVLMask = + DAG.getSetCC(DL, MaskVT, StepVec, SplatEVL, ISD::CondCode::SETULT); + + SDValue FullMask = DAG.getNode(ISD::AND, DL, MaskVT, Mask, EVLMask); + return DAG.getSelect(DL, Node->getValueType(0), FullMask, Op1, Op2); +} + void VectorLegalizer::ExpandFP_TO_UINT(SDNode *Node, SmallVectorImpl &Results) { // Attempt to expand using TargetLowering. diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 205f71a6fe47..5cc3aa35d4d2 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -572,6 +572,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setOperationAction(ISD::SELECT, VT, Custom); setOperationAction(ISD::SELECT_CC, VT, Expand); setOperationAction(ISD::VSELECT, VT, Expand); + setOperationAction(ISD::VP_MERGE, VT, Expand); setOperationAction(ISD::VP_SELECT, VT, Expand); setOperationAction(ISD::VP_AND, VT, Custom); diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vpmerge.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vpmerge.ll index 8ac3184f02c4..f99a90df1fb2 100644 --- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vpmerge.ll +++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vpmerge.ll @@ -4,6 +4,34 @@ ; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+v,+m -target-abi=lp64d -riscv-v-vector-bits-min=128 \ ; RUN: -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,RV64 +declare <4 x i1> @llvm.vp.merge.v4i1(<4 x i1>, <4 x i1>, <4 x i1>, i32) + +define <4 x i1> @vpmerge_vv_v4i1(<4 x i1> %va, <4 x i1> %vb, <4 x i1> %m, i32 zeroext %evl) { +; RV32-LABEL: vpmerge_vv_v4i1: +; RV32: # %bb.0: +; RV32-NEXT: vsetivli zero, 4, e32, m1, ta, mu +; RV32-NEXT: vid.v v10 +; RV32-NEXT: vmsltu.vx v10, v10, a0 +; RV32-NEXT: vmand.mm v9, v9, v10 +; RV32-NEXT: vmandn.mm v8, v8, v9 +; RV32-NEXT: vmand.mm v9, v0, v9 +; RV32-NEXT: vmor.mm v0, v9, v8 +; RV32-NEXT: ret +; +; RV64-LABEL: vpmerge_vv_v4i1: +; RV64: # %bb.0: +; RV64-NEXT: vsetivli zero, 4, e64, m2, ta, mu +; RV64-NEXT: vid.v v10 +; RV64-NEXT: vmsltu.vx v12, v10, a0 +; RV64-NEXT: vmand.mm v9, v9, v12 +; RV64-NEXT: vmandn.mm v8, v8, v9 +; RV64-NEXT: vmand.mm v9, v0, v9 +; RV64-NEXT: vmor.mm v0, v9, v8 +; RV64-NEXT: ret + %v = call <4 x i1> @llvm.vp.merge.v4i1(<4 x i1> %m, <4 x i1> %va, <4 x i1> %vb, i32 %evl) + ret <4 x i1> %v +} + declare <2 x i8> @llvm.vp.merge.v2i8(<2 x i1>, <2 x i8>, <2 x i8>, i32) define <2 x i8> @vpmerge_vv_v2i8(<2 x i8> %va, <2 x i8> %vb, <2 x i1> %m, i32 zeroext %evl) { diff --git a/llvm/test/CodeGen/RISCV/rvv/vpmerge-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vpmerge-sdnode.ll index 6a4ac666b110..3ea8029aaffe 100644 --- a/llvm/test/CodeGen/RISCV/rvv/vpmerge-sdnode.ll +++ b/llvm/test/CodeGen/RISCV/rvv/vpmerge-sdnode.ll @@ -4,6 +4,34 @@ ; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+v,+m -target-abi=lp64d \ ; RUN: -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,RV64 +declare @llvm.vp.merge.nxv1i1(, , , i32) + +define @vpmerge_nxv1i1( %va, %vb, %m, i32 zeroext %evl) { +; RV32-LABEL: vpmerge_nxv1i1: +; RV32: # %bb.0: +; RV32-NEXT: vsetvli a1, zero, e32, mf2, ta, mu +; RV32-NEXT: vid.v v10 +; RV32-NEXT: vmsltu.vx v10, v10, a0 +; RV32-NEXT: vmand.mm v9, v9, v10 +; RV32-NEXT: vmandn.mm v8, v8, v9 +; RV32-NEXT: vmand.mm v9, v0, v9 +; RV32-NEXT: vmor.mm v0, v9, v8 +; RV32-NEXT: ret +; +; RV64-LABEL: vpmerge_nxv1i1: +; RV64: # %bb.0: +; RV64-NEXT: vsetvli a1, zero, e64, m1, ta, mu +; RV64-NEXT: vid.v v10 +; RV64-NEXT: vmsltu.vx v10, v10, a0 +; RV64-NEXT: vmand.mm v9, v9, v10 +; RV64-NEXT: vmandn.mm v8, v8, v9 +; RV64-NEXT: vmand.mm v9, v0, v9 +; RV64-NEXT: vmor.mm v0, v9, v8 +; RV64-NEXT: ret + %v = call @llvm.vp.merge.nxv1i1( %m, %va, %vb, i32 %evl) + ret %v +} + declare @llvm.vp.merge.nxv1i8(, , , i32) define @vpmerge_vv_nxv1i8( %va, %vb, %m, i32 zeroext %evl) { @@ -332,10 +360,10 @@ define @vpmerge_vv_nxv128i8( %va, @vpmerge_vv_nxv128i8( %va, @vpmerge_vv_nxv128i8( %va, @vpmerge_vx_nxv128i8(i8 %a, %vb, ; CHECK-NEXT: csrr a3, vlenb ; CHECK-NEXT: slli a3, a3, 3 ; CHECK-NEXT: mv a4, a2 -; CHECK-NEXT: bltu a2, a3, .LBB25_2 +; CHECK-NEXT: bltu a2, a3, .LBB26_2 ; CHECK-NEXT: # %bb.1: ; CHECK-NEXT: mv a4, a3 -; CHECK-NEXT: .LBB25_2: +; CHECK-NEXT: .LBB26_2: ; CHECK-NEXT: li a5, 0 ; CHECK-NEXT: vsetvli a6, zero, e8, m8, ta, mu ; CHECK-NEXT: vlm.v v24, (a1) ; CHECK-NEXT: vsetvli zero, a4, e8, m8, tu, mu ; CHECK-NEXT: sub a1, a2, a3 ; CHECK-NEXT: vmerge.vxm v8, v8, a0, v0 -; CHECK-NEXT: bltu a2, a1, .LBB25_4 +; CHECK-NEXT: bltu a2, a1, .LBB26_4 ; CHECK-NEXT: # %bb.3: ; CHECK-NEXT: mv a5, a1 -; CHECK-NEXT: .LBB25_4: +; CHECK-NEXT: .LBB26_4: ; CHECK-NEXT: vsetvli zero, a5, e8, m8, tu, mu ; CHECK-NEXT: vmv1r.v v0, v24 ; CHECK-NEXT: vmerge.vxm v16, v16, a0, v0 @@ -447,20 +475,20 @@ define @vpmerge_vi_nxv128i8( %vb,