[SelectionDAG][VP] Provide expansion for VP_MERGE

This patch adds support for expanding VP_MERGE through a sequence of
vector operations: producing a full-length mask that sets the elements
past the EVL/pivot to false, combining that mask with the original mask,
and culminating in a full-length vector select.
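
For illustration only, the expansion is roughly equivalent to the following
IR (a sketch, not part of the patch: the i32 element type and the function
names are made up, and the real transform operates on SelectionDAG nodes
rather than on IR):

declare <4 x i32> @llvm.vp.merge.v4i32(<4 x i1>, <4 x i32>, <4 x i32>, i32)

define <4 x i32> @merge(<4 x i1> %m, <4 x i32> %a, <4 x i32> %b, i32 %evl) {
  %v = call <4 x i32> @llvm.vp.merge.v4i32(<4 x i1> %m, <4 x i32> %a, <4 x i32> %b, i32 %evl)
  ret <4 x i32> %v
}

; A step vector compared (unsigned) against the splatted EVL yields a mask of
; the lanes below the pivot; ANDing it with %m gives a full-length mask that
; feeds an ordinary full-length select.
define <4 x i32> @merge_expanded(<4 x i1> %m, <4 x i32> %a, <4 x i32> %b, i32 %evl) {
  %evl.head = insertelement <4 x i32> poison, i32 %evl, i32 0
  %evl.splat = shufflevector <4 x i32> %evl.head, <4 x i32> poison, <4 x i32> zeroinitializer
  %evl.mask = icmp ult <4 x i32> <i32 0, i32 1, i32 2, i32 3>, %evl.splat
  %full.mask = and <4 x i1> %m, %evl.mask
  %v = select <4 x i1> %full.mask, <4 x i32> %a, <4 x i32> %b
  ret <4 x i32> %v
}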

This expansion should work for any data type, though the only use for
RVV is for boolean vectors, which themselves rely on an expansion for
the VSELECT.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D118058
Fraser Cormack 2022-01-24 18:11:14 +00:00
parent 6730df4779
commit 84e85e025e
4 changed files with 119 additions and 16 deletions

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

@@ -134,6 +134,7 @@ class VectorLegalizer {
  /// supported by the target.
  SDValue ExpandVSELECT(SDNode *Node);
  SDValue ExpandVP_SELECT(SDNode *Node);
  SDValue ExpandVP_MERGE(SDNode *Node);
  SDValue ExpandSELECT(SDNode *Node);
  std::pair<SDValue, SDValue> ExpandLoad(SDNode *N);
  SDValue ExpandStore(SDNode *N);
@@ -877,6 +878,9 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
  case ISD::UREM:
    ExpandREM(Node, Results);
    return;
  case ISD::VP_MERGE:
    Results.push_back(ExpandVP_MERGE(Node));
    return;
  }
  Results.push_back(DAG.UnrollVectorOp(Node));
@@ -1238,6 +1242,48 @@ SDValue VectorLegalizer::ExpandVP_SELECT(SDNode *Node) {
  return DAG.getNode(ISD::VP_OR, DL, VT, Op1, Op2, Mask, EVL);
}

SDValue VectorLegalizer::ExpandVP_MERGE(SDNode *Node) {
  // Implement VP_MERGE in terms of VSELECT. Construct a mask where vector
  // indices less than the EVL/pivot are true. Combine that with the original
  // mask for a full-length mask. Use a full-length VSELECT to select between
  // the true and false values.
  SDLoc DL(Node);

  SDValue Mask = Node->getOperand(0);
  SDValue Op1 = Node->getOperand(1);
  SDValue Op2 = Node->getOperand(2);
  SDValue EVL = Node->getOperand(3);

  EVT MaskVT = Mask.getValueType();
  bool IsFixedLen = MaskVT.isFixedLengthVector();

  EVT EVLVecVT = EVT::getVectorVT(*DAG.getContext(), EVL.getValueType(),
                                  MaskVT.getVectorElementCount());

  // If we can't construct the EVL mask efficiently, it's better to unroll.
  if ((IsFixedLen &&
       !TLI.isOperationLegalOrCustom(ISD::BUILD_VECTOR, EVLVecVT)) ||
      (!IsFixedLen &&
       (!TLI.isOperationLegalOrCustom(ISD::STEP_VECTOR, EVLVecVT) ||
        !TLI.isOperationLegalOrCustom(ISD::SPLAT_VECTOR, EVLVecVT))))
    return DAG.UnrollVectorOp(Node);

  // If using a SETCC would result in a different type than the mask type,
  // unroll.
  if (TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(),
                             EVLVecVT) != MaskVT)
    return DAG.UnrollVectorOp(Node);

  SDValue StepVec = DAG.getStepVector(DL, EVLVecVT);
  SDValue SplatEVL = IsFixedLen ? DAG.getSplatBuildVector(EVLVecVT, DL, EVL)
                                : DAG.getSplatVector(EVLVecVT, DL, EVL);

  SDValue EVLMask =
      DAG.getSetCC(DL, MaskVT, StepVec, SplatEVL, ISD::CondCode::SETULT);
  SDValue FullMask = DAG.getNode(ISD::AND, DL, MaskVT, Mask, EVLMask);
  return DAG.getSelect(DL, Node->getValueType(0), FullMask, Op1, Op2);
}

void VectorLegalizer::ExpandFP_TO_UINT(SDNode *Node,
                                       SmallVectorImpl<SDValue> &Results) {
  // Attempt to expand using TargetLowering.

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

@@ -572,6 +572,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
      setOperationAction(ISD::SELECT, VT, Custom);
      setOperationAction(ISD::SELECT_CC, VT, Expand);
      setOperationAction(ISD::VSELECT, VT, Expand);
      setOperationAction(ISD::VP_MERGE, VT, Expand);
      setOperationAction(ISD::VP_SELECT, VT, Expand);

      setOperationAction(ISD::VP_AND, VT, Custom);
llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vpmerge.ll

@@ -4,6 +4,34 @@
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+v,+m -target-abi=lp64d -riscv-v-vector-bits-min=128 \
; RUN: -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,RV64
declare <4 x i1> @llvm.vp.merge.v4i1(<4 x i1>, <4 x i1>, <4 x i1>, i32)
define <4 x i1> @vpmerge_vv_v4i1(<4 x i1> %va, <4 x i1> %vb, <4 x i1> %m, i32 zeroext %evl) {
; RV32-LABEL: vpmerge_vv_v4i1:
; RV32: # %bb.0:
; RV32-NEXT: vsetivli zero, 4, e32, m1, ta, mu
; RV32-NEXT: vid.v v10
; RV32-NEXT: vmsltu.vx v10, v10, a0
; RV32-NEXT: vmand.mm v9, v9, v10
; RV32-NEXT: vmandn.mm v8, v8, v9
; RV32-NEXT: vmand.mm v9, v0, v9
; RV32-NEXT: vmor.mm v0, v9, v8
; RV32-NEXT: ret
;
; RV64-LABEL: vpmerge_vv_v4i1:
; RV64: # %bb.0:
; RV64-NEXT: vsetivli zero, 4, e64, m2, ta, mu
; RV64-NEXT: vid.v v10
; RV64-NEXT: vmsltu.vx v12, v10, a0
; RV64-NEXT: vmand.mm v9, v9, v12
; RV64-NEXT: vmandn.mm v8, v8, v9
; RV64-NEXT: vmand.mm v9, v0, v9
; RV64-NEXT: vmor.mm v0, v9, v8
; RV64-NEXT: ret
%v = call <4 x i1> @llvm.vp.merge.v4i1(<4 x i1> %m, <4 x i1> %va, <4 x i1> %vb, i32 %evl)
ret <4 x i1> %v
}
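
The vmandn.mm/vmand.mm/vmor.mm sequence above comes from the VSELECT-of-i1
expansion the commit message refers to; roughly, as an illustrative sketch
only (the compiler performs this on SelectionDAG nodes, not on IR like this):

define <4 x i1> @select_i1(<4 x i1> %c, <4 x i1> %a, <4 x i1> %b) {
  %not.c = xor <4 x i1> %c, <i1 true, i1 true, i1 true, i1 true>
  %a.sel = and <4 x i1> %c, %a     ; lanes where %c is set come from %a
  %b.sel = and <4 x i1> %not.c, %b ; remaining lanes come from %b (vmandn.mm)
  %v = or <4 x i1> %a.sel, %b.sel  ; combine the two halves (vmor.mm)
  ret <4 x i1> %v
}

Here %c plays the role of the combined full-length mask, while %a and %b are
the true and false operands of the select.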
declare <2 x i8> @llvm.vp.merge.v2i8(<2 x i1>, <2 x i8>, <2 x i8>, i32)
define <2 x i8> @vpmerge_vv_v2i8(<2 x i8> %va, <2 x i8> %vb, <2 x i1> %m, i32 zeroext %evl) {

llvm/test/CodeGen/RISCV/rvv/vpmerge-sdnode.ll

@@ -4,6 +4,34 @@
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+v,+m -target-abi=lp64d \
; RUN: -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,RV64
declare <vscale x 1 x i1> @llvm.vp.merge.nxv1i1(<vscale x 1 x i1>, <vscale x 1 x i1>, <vscale x 1 x i1>, i32)
define <vscale x 1 x i1> @vpmerge_nxv1i1(<vscale x 1 x i1> %va, <vscale x 1 x i1> %vb, <vscale x 1 x i1> %m, i32 zeroext %evl) {
; RV32-LABEL: vpmerge_nxv1i1:
; RV32: # %bb.0:
; RV32-NEXT: vsetvli a1, zero, e32, mf2, ta, mu
; RV32-NEXT: vid.v v10
; RV32-NEXT: vmsltu.vx v10, v10, a0
; RV32-NEXT: vmand.mm v9, v9, v10
; RV32-NEXT: vmandn.mm v8, v8, v9
; RV32-NEXT: vmand.mm v9, v0, v9
; RV32-NEXT: vmor.mm v0, v9, v8
; RV32-NEXT: ret
;
; RV64-LABEL: vpmerge_nxv1i1:
; RV64: # %bb.0:
; RV64-NEXT: vsetvli a1, zero, e64, m1, ta, mu
; RV64-NEXT: vid.v v10
; RV64-NEXT: vmsltu.vx v10, v10, a0
; RV64-NEXT: vmand.mm v9, v9, v10
; RV64-NEXT: vmandn.mm v8, v8, v9
; RV64-NEXT: vmand.mm v9, v0, v9
; RV64-NEXT: vmor.mm v0, v9, v8
; RV64-NEXT: ret
%v = call <vscale x 1 x i1> @llvm.vp.merge.nxv1i1(<vscale x 1 x i1> %m, <vscale x 1 x i1> %va, <vscale x 1 x i1> %vb, i32 %evl)
ret <vscale x 1 x i1> %v
}
declare <vscale x 1 x i8> @llvm.vp.merge.nxv1i8(<vscale x 1 x i1>, <vscale x 1 x i8>, <vscale x 1 x i8>, i32)
define <vscale x 1 x i8> @vpmerge_vv_nxv1i8(<vscale x 1 x i8> %va, <vscale x 1 x i8> %vb, <vscale x 1 x i1> %m, i32 zeroext %evl) {
@@ -332,10 +360,10 @@ define <vscale x 128 x i8> @vpmerge_vv_nxv128i8(<vscale x 128 x i8> %va, <vscale
; RV32-NEXT: addi a2, sp, 16
; RV32-NEXT: vs8r.v v8, (a2) # Unknown-size Folded Spill
; RV32-NEXT: li a2, 0
; RV32-NEXT: bltu a3, a4, .LBB24_2
; RV32-NEXT: bltu a3, a4, .LBB25_2
; RV32-NEXT: # %bb.1:
; RV32-NEXT: mv a2, a4
; RV32-NEXT: .LBB24_2:
; RV32-NEXT: .LBB25_2:
; RV32-NEXT: vl8r.v v8, (a0)
; RV32-NEXT: vsetvli zero, a2, e8, m8, tu, mu
; RV32-NEXT: vmv1r.v v0, v2
@@ -350,10 +378,10 @@ define <vscale x 128 x i8> @vpmerge_vv_nxv128i8(<vscale x 128 x i8> %va, <vscale
; RV32-NEXT: addi a0, a0, 16
; RV32-NEXT: vl8re8.v v16, (a0) # Unknown-size Folded Reload
; RV32-NEXT: vmerge.vvm v16, v16, v24, v0
; RV32-NEXT: bltu a3, a1, .LBB24_4
; RV32-NEXT: bltu a3, a1, .LBB25_4
; RV32-NEXT: # %bb.3:
; RV32-NEXT: mv a3, a1
; RV32-NEXT: .LBB24_4:
; RV32-NEXT: .LBB25_4:
; RV32-NEXT: vsetvli zero, a3, e8, m8, tu, mu
; RV32-NEXT: vmv1r.v v0, v1
; RV32-NEXT: addi a0, sp, 16
@@ -384,18 +412,18 @@ define <vscale x 128 x i8> @vpmerge_vv_nxv128i8(<vscale x 128 x i8> %va, <vscale
; RV64-NEXT: addi a2, sp, 16
; RV64-NEXT: vs8r.v v8, (a2) # Unknown-size Folded Spill
; RV64-NEXT: li a2, 0
; RV64-NEXT: bltu a3, a4, .LBB24_2
; RV64-NEXT: bltu a3, a4, .LBB25_2
; RV64-NEXT: # %bb.1:
; RV64-NEXT: mv a2, a4
; RV64-NEXT: .LBB24_2:
; RV64-NEXT: .LBB25_2:
; RV64-NEXT: vl8r.v v8, (a0)
; RV64-NEXT: vsetvli zero, a2, e8, m8, tu, mu
; RV64-NEXT: vmv1r.v v0, v2
; RV64-NEXT: vmerge.vvm v24, v24, v16, v0
; RV64-NEXT: bltu a3, a1, .LBB24_4
; RV64-NEXT: bltu a3, a1, .LBB25_4
; RV64-NEXT: # %bb.3:
; RV64-NEXT: mv a3, a1
; RV64-NEXT: .LBB24_4:
; RV64-NEXT: .LBB25_4:
; RV64-NEXT: vsetvli zero, a3, e8, m8, tu, mu
; RV64-NEXT: vmv1r.v v0, v1
; RV64-NEXT: addi a0, sp, 16
@@ -417,20 +445,20 @@ define <vscale x 128 x i8> @vpmerge_vx_nxv128i8(i8 %a, <vscale x 128 x i8> %vb,
; CHECK-NEXT: csrr a3, vlenb
; CHECK-NEXT: slli a3, a3, 3
; CHECK-NEXT: mv a4, a2
; CHECK-NEXT: bltu a2, a3, .LBB25_2
; CHECK-NEXT: bltu a2, a3, .LBB26_2
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: mv a4, a3
; CHECK-NEXT: .LBB25_2:
; CHECK-NEXT: .LBB26_2:
; CHECK-NEXT: li a5, 0
; CHECK-NEXT: vsetvli a6, zero, e8, m8, ta, mu
; CHECK-NEXT: vlm.v v24, (a1)
; CHECK-NEXT: vsetvli zero, a4, e8, m8, tu, mu
; CHECK-NEXT: sub a1, a2, a3
; CHECK-NEXT: vmerge.vxm v8, v8, a0, v0
; CHECK-NEXT: bltu a2, a1, .LBB25_4
; CHECK-NEXT: bltu a2, a1, .LBB26_4
; CHECK-NEXT: # %bb.3:
; CHECK-NEXT: mv a5, a1
; CHECK-NEXT: .LBB25_4:
; CHECK-NEXT: .LBB26_4:
; CHECK-NEXT: vsetvli zero, a5, e8, m8, tu, mu
; CHECK-NEXT: vmv1r.v v0, v24
; CHECK-NEXT: vmerge.vxm v16, v16, a0, v0
@@ -447,20 +475,20 @@ define <vscale x 128 x i8> @vpmerge_vi_nxv128i8(<vscale x 128 x i8> %vb, <vscale
; CHECK-NEXT: csrr a2, vlenb
; CHECK-NEXT: slli a2, a2, 3
; CHECK-NEXT: mv a3, a1
; CHECK-NEXT: bltu a1, a2, .LBB26_2
; CHECK-NEXT: bltu a1, a2, .LBB27_2
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: mv a3, a2
; CHECK-NEXT: .LBB26_2:
; CHECK-NEXT: .LBB27_2:
; CHECK-NEXT: li a4, 0
; CHECK-NEXT: vsetvli a5, zero, e8, m8, ta, mu
; CHECK-NEXT: vlm.v v24, (a0)
; CHECK-NEXT: vsetvli zero, a3, e8, m8, tu, mu
; CHECK-NEXT: sub a0, a1, a2
; CHECK-NEXT: vmerge.vim v8, v8, 2, v0
; CHECK-NEXT: bltu a1, a0, .LBB26_4
; CHECK-NEXT: bltu a1, a0, .LBB27_4
; CHECK-NEXT: # %bb.3:
; CHECK-NEXT: mv a4, a0
; CHECK-NEXT: .LBB26_4:
; CHECK-NEXT: .LBB27_4:
; CHECK-NEXT: vsetvli zero, a4, e8, m8, tu, mu
; CHECK-NEXT: vmv1r.v v0, v24
; CHECK-NEXT: vmerge.vim v16, v16, 2, v0