[RISCV] Add support for selecting vid.v from build_vector

This patch optimizes a build_vector "index sequence" and lowers it to
the existing custom RISCVISD::VID node. This pattern is common in
autovectorized code.

The custom node was updated to allow it to be used by both scalable and
fixed-length vectors, thus avoiding pattern duplication.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D96332
This commit is contained in:
Fraser Cormack 2021-02-09 12:09:10 +00:00
parent 013613964f
commit a3c74d6d53
5 changed files with 121 additions and 23 deletions

View File

@ -821,19 +821,35 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
MVT VT = Op.getSimpleValueType(); MVT VT = Op.getSimpleValueType();
assert(VT.isFixedLengthVector() && "Unexpected vector!"); assert(VT.isFixedLengthVector() && "Unexpected vector!");
if (SDValue Splat = cast<BuildVectorSDNode>(Op)->getSplatValue()) {
MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget); MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
SDLoc DL(Op); SDLoc DL(Op);
SDValue VL = SDValue VL =
DAG.getConstant(VT.getVectorNumElements(), DL, Subtarget.getXLenVT()); DAG.getConstant(VT.getVectorNumElements(), DL, Subtarget.getXLenVT());
if (SDValue Splat = cast<BuildVectorSDNode>(Op)->getSplatValue()) {
unsigned Opc = VT.isFloatingPoint() ? RISCVISD::VFMV_V_F_VL unsigned Opc = VT.isFloatingPoint() ? RISCVISD::VFMV_V_F_VL
: RISCVISD::VMV_V_X_VL; : RISCVISD::VMV_V_X_VL;
Splat = DAG.getNode(Opc, DL, ContainerVT, Splat, VL); Splat = DAG.getNode(Opc, DL, ContainerVT, Splat, VL);
return convertFromScalableVector(VT, Splat, DAG, Subtarget); return convertFromScalableVector(VT, Splat, DAG, Subtarget);
} }
// Try and match an index sequence, which we can lower directly to the vid
// instruction. An all-undef vector is matched by getSplatValue, above.
bool IsVID = true;
if (VT.isInteger())
for (unsigned i = 0, e = Op.getNumOperands(); i < e && IsVID; i++)
IsVID &= Op.getOperand(i).isUndef() ||
(isa<ConstantSDNode>(Op.getOperand(i)) &&
Op.getConstantOperandVal(i) == i);
if (IsVID) {
MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
SDValue Mask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
SDValue VID = DAG.getNode(RISCVISD::VID_VL, DL, ContainerVT, Mask, VL);
return convertFromScalableVector(VT, VID, DAG, Subtarget);
}
return SDValue(); return SDValue();
} }
@ -1706,12 +1722,15 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
SDValue SplattedVal = DAG.getSplatVector(VecVT, DL, Val); SDValue SplattedVal = DAG.getSplatVector(VecVT, DL, Val);
SDValue SplattedIdx = DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, Idx); SDValue SplattedIdx = DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, Idx);
SDValue VID = DAG.getNode(RISCVISD::VID, DL, VecVT); SDValue VL = DAG.getRegister(RISCV::X0, Subtarget.getXLenVT());
MVT MaskVT = MVT::getVectorVT(MVT::i1, VecVT.getVectorElementCount());
SDValue Mask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
SDValue VID = DAG.getNode(RISCVISD::VID_VL, DL, VecVT, Mask, VL);
auto SetCCVT = auto SetCCVT =
getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VecVT); getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VecVT);
SDValue Mask = DAG.getSetCC(DL, SetCCVT, VID, SplattedIdx, ISD::SETEQ); SDValue SelectCond = DAG.getSetCC(DL, SetCCVT, VID, SplattedIdx, ISD::SETEQ);
return DAG.getNode(ISD::VSELECT, DL, VecVT, Mask, SplattedVal, Vec); return DAG.getNode(ISD::VSELECT, DL, VecVT, SelectCond, SplattedVal, Vec);
} }
// Custom-lower EXTRACT_VECTOR_ELT operations to slide the vector down, then // Custom-lower EXTRACT_VECTOR_ELT operations to slide the vector down, then
@ -4586,7 +4605,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(VLEFF_MASK) NODE_NAME_CASE(VLEFF_MASK)
NODE_NAME_CASE(VSLIDEUP) NODE_NAME_CASE(VSLIDEUP)
NODE_NAME_CASE(VSLIDEDOWN) NODE_NAME_CASE(VSLIDEDOWN)
NODE_NAME_CASE(VID) NODE_NAME_CASE(VID_VL)
NODE_NAME_CASE(VFNCVT_ROD) NODE_NAME_CASE(VFNCVT_ROD)
NODE_NAME_CASE(VECREDUCE_ADD) NODE_NAME_CASE(VECREDUCE_ADD)
NODE_NAME_CASE(VECREDUCE_UMAX) NODE_NAME_CASE(VECREDUCE_UMAX)

View File

@ -113,8 +113,9 @@ enum NodeType : unsigned {
// XLenVT index (either constant or non-constant). // XLenVT index (either constant or non-constant).
VSLIDEUP, VSLIDEUP,
VSLIDEDOWN, VSLIDEDOWN,
// Matches the semantics of the unmasked vid.v instruction. // Matches the semantics of the vid.v instruction, with a mask and VL
VID, // operand.
VID_VL,
// Matches the semantics of the vfcnvt.rod function (Convert double-width // Matches the semantics of the vfcnvt.rod function (Convert double-width
// float to single-width float, rounding towards odd). Takes a double-width // float to single-width float, rounding towards odd). Takes a double-width
// float vector and produces a single-width float vector. // float vector and produces a single-width float vector.

View File

@ -819,13 +819,6 @@ foreach vti = AllFloatVectors in {
vti.AVL, vti.SEW)>; vti.AVL, vti.SEW)>;
} }
//===----------------------------------------------------------------------===//
// Miscellaneous RISCVISD SDNodes
//===----------------------------------------------------------------------===//
def riscv_vid
: SDNode<"RISCVISD::VID", SDTypeProfile<1, 0, [SDTCisVec<0>]>, []>;
def SDTRVVSlide : SDTypeProfile<1, 3, [ def SDTRVVSlide : SDTypeProfile<1, 3, [
SDTCisVec<0>, SDTCisSameAs<1, 0>, SDTCisSameAs<2, 0>, SDTCisVT<3, XLenVT> SDTCisVec<0>, SDTCisSameAs<1, 0>, SDTCisSameAs<2, 0>, SDTCisVT<3, XLenVT>
]>; ]>;
@ -835,10 +828,6 @@ def riscv_slidedown : SDNode<"RISCVISD::VSLIDEDOWN", SDTRVVSlide, []>;
let Predicates = [HasStdExtV] in { let Predicates = [HasStdExtV] in {
foreach vti = AllIntegerVectors in
def : Pat<(vti.Vector riscv_vid),
(!cast<Instruction>("PseudoVID_V_"#vti.LMul.MX) vti.AVL, vti.SEW)>;
foreach vti = !listconcat(AllIntegerVectors, AllFloatVectors) in { foreach vti = !listconcat(AllIntegerVectors, AllFloatVectors) in {
def : Pat<(vti.Vector (riscv_slideup (vti.Vector vti.RegClass:$rs3), def : Pat<(vti.Vector (riscv_slideup (vti.Vector vti.RegClass:$rs3),
(vti.Vector vti.RegClass:$rs1), (vti.Vector vti.RegClass:$rs1),

View File

@ -210,3 +210,20 @@ foreach vti = AllFloatVectors in {
} }
} // Predicates = [HasStdExtV, HasStdExtF] } // Predicates = [HasStdExtV, HasStdExtF]
//===----------------------------------------------------------------------===//
// Miscellaneous RISCVISD SDNodes
//===----------------------------------------------------------------------===//
def riscv_vid_vl : SDNode<"RISCVISD::VID_VL", SDTypeProfile<1, 2,
[SDTCisVec<0>, SDTCisVec<1>, SDTCVecEltisVT<1, i1>,
SDTCisSameNumEltsAs<0, 1>, SDTCisVT<2, XLenVT>]>, []>;
let Predicates = [HasStdExtV] in {
foreach vti = AllIntegerVectors in
def : Pat<(vti.Vector (riscv_vid_vl (vti.Mask true_mask),
(XLenVT (VLOp GPR:$vl)))),
(!cast<Instruction>("PseudoVID_V_"#vti.LMul.MX) GPR:$vl, vti.SEW)>;
} // Predicates = [HasStdExtV]

View File

@ -0,0 +1,72 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=riscv32 -mattr=+experimental-v -verify-machineinstrs -riscv-v-vector-bits-min=128 -riscv-v-fixed-length-vector-lmul-max=1 -verify-machineinstrs < %s | FileCheck %s
; RUN: llc -mtriple=riscv64 -mattr=+experimental-v -verify-machineinstrs -riscv-v-vector-bits-min=128 -riscv-v-fixed-length-vector-lmul-max=1 -verify-machineinstrs < %s | FileCheck %s
define void @buildvec_vid_v16i8(<16 x i8>* %x) {
; CHECK-LABEL: buildvec_vid_v16i8:
; CHECK: # %bb.0:
; CHECK-NEXT: addi a1, zero, 16
; CHECK-NEXT: vsetvli a1, a1, e8,m1,ta,mu
; CHECK-NEXT: vid.v v25
; CHECK-NEXT: vse8.v v25, (a0)
; CHECK-NEXT: ret
store <16 x i8> <i8 0, i8 1, i8 2, i8 3, i8 4, i8 5, i8 6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <16 x i8>* %x
ret void
}
define void @buildvec_vid_undefelts_v16i8(<16 x i8>* %x) {
; CHECK-LABEL: buildvec_vid_undefelts_v16i8:
; CHECK: # %bb.0:
; CHECK-NEXT: addi a1, zero, 16
; CHECK-NEXT: vsetvli a1, a1, e8,m1,ta,mu
; CHECK-NEXT: vid.v v25
; CHECK-NEXT: vse8.v v25, (a0)
; CHECK-NEXT: ret
store <16 x i8> <i8 0, i8 1, i8 2, i8 undef, i8 4, i8 undef, i8 6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <16 x i8>* %x
ret void
}
; TODO: Could do VID then insertelement on missing elements
define void @buildvec_notquite_vid_v16i8(<16 x i8>* %x) {
; CHECK-LABEL: buildvec_notquite_vid_v16i8:
; CHECK: # %bb.0:
; CHECK-NEXT: lui a1, %hi(.LCPI2_0)
; CHECK-NEXT: addi a1, a1, %lo(.LCPI2_0)
; CHECK-NEXT: addi a2, zero, 16
; CHECK-NEXT: vsetvli a2, a2, e8,m1,ta,mu
; CHECK-NEXT: vle8.v v25, (a1)
; CHECK-NEXT: vse8.v v25, (a0)
; CHECK-NEXT: ret
store <16 x i8> <i8 0, i8 1, i8 3, i8 3, i8 4, i8 5, i8 6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <16 x i8>* %x
ret void
}
; TODO: Could do VID then add a constant splat
define void @buildvec_vid_plus_imm_v16i8(<16 x i8>* %x) {
; CHECK-LABEL: buildvec_vid_plus_imm_v16i8:
; CHECK: # %bb.0:
; CHECK-NEXT: lui a1, %hi(.LCPI3_0)
; CHECK-NEXT: addi a1, a1, %lo(.LCPI3_0)
; CHECK-NEXT: addi a2, zero, 16
; CHECK-NEXT: vsetvli a2, a2, e8,m1,ta,mu
; CHECK-NEXT: vle8.v v25, (a1)
; CHECK-NEXT: vse8.v v25, (a0)
; CHECK-NEXT: ret
store <16 x i8> <i8 2, i8 3, i8 4, i8 5, i8 6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16, i8 17>, <16 x i8>* %x
ret void
}
; TODO: Could do VID then multiply by a constant splat
define void @buildvec_vid_mpy_imm_v16i8(<16 x i8>* %x) {
; CHECK-LABEL: buildvec_vid_mpy_imm_v16i8:
; CHECK: # %bb.0:
; CHECK-NEXT: lui a1, %hi(.LCPI4_0)
; CHECK-NEXT: addi a1, a1, %lo(.LCPI4_0)
; CHECK-NEXT: addi a2, zero, 16
; CHECK-NEXT: vsetvli a2, a2, e8,m1,ta,mu
; CHECK-NEXT: vle8.v v25, (a1)
; CHECK-NEXT: vse8.v v25, (a0)
; CHECK-NEXT: ret
store <16 x i8> <i8 0, i8 3, i8 6, i8 9, i8 12, i8 15, i8 18, i8 21, i8 24, i8 27, i8 30, i8 33, i8 36, i8 39, i8 42, i8 45>, <16 x i8>* %x
ret void
}