[RISCV] Unify scalable- and fixed-vector EXTRACT_SUBVECTOR lowering

This patch unifies the two disparate paths for lowering
EXTRACT_SUBVECTOR operations. As a result, it is now possible to
support any fixed-length subvector extraction, not just "cast-like"
ones.
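
For example, a non-register-aligned extraction like the one below can
now be lowered (modelled on the @extract_v2i8_v8i8_6 case in the new
test file; the wrapper function here is illustrative):

define <2 x i8> @extract_hi(<8 x i8> %a) {
  ; Elements 6 and 7 do not start at a vector register boundary, so this is
  ; not a simple "cast"; it is now handled with a vector slidedown.
  %c = call <2 x i8> @llvm.experimental.vector.extract.v2i8.v8i8(<8 x i8> %a, i64 6)
  ret <2 x i8> %c
}
declare <2 x i8> @llvm.experimental.vector.extract.v2i8.v8i8(<8 x i8>, i64)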

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D97192
Fraser Cormack 2021-02-22 14:14:10 +00:00
parent d0a6f8bb65
commit 821f8bb29a
4 changed files with 245 additions and 66 deletions


@@ -994,59 +994,44 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
MVT InVT = V.getSimpleValueType();
SDLoc DL(V);
// TODO: This method of selecting EXTRACT_SUBVECTOR should work
// with any type of extraction (fixed <-> scalable) but we don't yet
// correctly identify the canonical register class for fixed-length types.
// For now, keep the two paths separate.
if (VT.isScalableVector() && InVT.isScalableVector()) {
const auto *TRI = Subtarget->getRegisterInfo();
unsigned SubRegIdx;
std::tie(SubRegIdx, Idx) =
RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
InVT, VT, Idx, TRI);
MVT SubVecContainerVT = VT;
// Establish the correct scalable-vector types for any fixed-length type.
if (VT.isFixedLengthVector())
SubVecContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector(
*CurDAG, VT, *Subtarget);
if (InVT.isFixedLengthVector())
InVT = RISCVTargetLowering::getContainerForFixedLengthVector(
*CurDAG, InVT, *Subtarget);
// If the Idx hasn't been completely eliminated then this is a subvector
// extract which doesn't naturally align to a vector register. These must
// be handled using instructions to manipulate the vector registers.
if (Idx != 0)
break;
const auto *TRI = Subtarget->getRegisterInfo();
unsigned SubRegIdx;
std::tie(SubRegIdx, Idx) =
RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
InVT, SubVecContainerVT, Idx, TRI);
// If we haven't set a SubRegIdx, then we must be going between LMUL<=1
// types (VR -> VR). This can be done as a copy.
if (SubRegIdx == RISCV::NoSubRegister) {
unsigned InRegClassID =
RISCVTargetLowering::getRegClassIDForVecVT(InVT);
assert(RISCVTargetLowering::getRegClassIDForVecVT(VT) ==
RISCV::VRRegClassID &&
InRegClassID == RISCV::VRRegClassID &&
"Unexpected subvector extraction");
SDValue RC =
CurDAG->getTargetConstant(InRegClassID, DL, XLenVT);
SDNode *NewNode = CurDAG->getMachineNode(TargetOpcode::COPY_TO_REGCLASS,
DL, VT, V, RC);
return ReplaceNode(Node, NewNode);
}
SDNode *NewNode = CurDAG->getMachineNode(
TargetOpcode::EXTRACT_SUBREG, DL, VT, V,
CurDAG->getTargetConstant(SubRegIdx, DL, XLenVT));
// If the Idx hasn't been completely eliminated then this is a subvector
// extract which doesn't naturally align to a vector register. These must
// be handled using instructions to manipulate the vector registers.
if (Idx != 0)
break;
// If we haven't set a SubRegIdx, then we must be going between
// equally-sized LMUL types (e.g. VR -> VR). This can be done as a copy.
if (SubRegIdx == RISCV::NoSubRegister) {
unsigned InRegClassID = RISCVTargetLowering::getRegClassIDForVecVT(InVT);
assert(RISCVTargetLowering::getRegClassIDForVecVT(SubVecContainerVT) ==
InRegClassID &&
"Unexpected subvector extraction");
SDValue RC = CurDAG->getTargetConstant(InRegClassID, DL, XLenVT);
SDNode *NewNode =
CurDAG->getMachineNode(TargetOpcode::COPY_TO_REGCLASS, DL, VT, V, RC);
return ReplaceNode(Node, NewNode);
}
if (VT.isFixedLengthVector() && InVT.isScalableVector()) {
// Bail when not a "cast" like extract_subvector.
if (Idx != 0)
break;
unsigned InRegClassID = RISCVTargetLowering::getRegClassIDForVecVT(InVT);
SDValue RC =
CurDAG->getTargetConstant(InRegClassID, DL, Subtarget->getXLenVT());
SDNode *NewNode =
CurDAG->getMachineNode(TargetOpcode::COPY_TO_REGCLASS, DL, VT, V, RC);
ReplaceNode(Node, NewNode);
return;
}
break;
SDNode *NewNode = CurDAG->getMachineNode(
TargetOpcode::EXTRACT_SUBREG, DL, VT, V,
CurDAG->getTargetConstant(SubRegIdx, DL, XLenVT));
return ReplaceNode(Node, NewNode);
}
}
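
As the comments above note, a register-aligned extract needs no vector
instructions at all: it is selected as an EXTRACT_SUBREG, or as a plain copy
between equally-sized register classes. A hedged sketch of such a case (this
one is not among the new tests; the lowering described in the comments is an
expectation based on the subregister scheme above, not verified output):

define <vscale x 4 x i32> @extract_aligned(<vscale x 16 x i32> %x) {
  ; Index 4 is a multiple of the LMUL=2 register-group size, so the result is
  ; simply the second vrm2 subregister of the vrm8 group holding %x; at most a
  ; whole-register move should be emitted, with no slides or vsetvli changes.
  %c = call <vscale x 4 x i32> @llvm.experimental.vector.extract.nxv4i32.nxv16i32(<vscale x 16 x i32> %x, i64 4)
  ret <vscale x 4 x i32> %c
}
declare <vscale x 4 x i32> @llvm.experimental.vector.extract.nxv4i32.nxv16i32(<vscale x 16 x i32>, i64)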


@@ -528,7 +528,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(Op, VT, Expand);
// We use EXTRACT_SUBVECTOR as a "cast" from scalable to fixed.
setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Legal);
setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
setOperationAction(ISD::BUILD_VECTOR, VT, Custom);
@@ -587,7 +587,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(Op, VT, Expand);
// We use EXTRACT_SUBVECTOR as a "cast" from scalable to fixed.
setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Legal);
setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
setOperationAction(ISD::BUILD_VECTOR, VT, Custom);
setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom);
@@ -918,8 +918,8 @@ RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
}
// Return the largest legal scalable vector type that matches VT's element type.
static MVT getContainerForFixedLengthVector(SelectionDAG &DAG, MVT VT,
const RISCVSubtarget &Subtarget) {
MVT RISCVTargetLowering::getContainerForFixedLengthVector(
SelectionDAG &DAG, MVT VT, const RISCVSubtarget &Subtarget) {
assert(VT.isFixedLengthVector() &&
DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
"Expected legal fixed length vector!");
@@ -1003,7 +1003,8 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
MVT VT = Op.getSimpleValueType();
assert(VT.isFixedLengthVector() && "Unexpected vector!");
MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
MVT ContainerVT =
RISCVTargetLowering::getContainerForFixedLengthVector(DAG, VT, Subtarget);
SDLoc DL(Op);
SDValue Mask, VL;
@@ -1058,7 +1059,8 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
if (SVN->isSplat()) {
int Lane = SVN->getSplatIndex();
if (Lane >= 0) {
MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
MVT ContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector(
DAG, VT, Subtarget);
V1 = convertToScalableVector(ContainerVT, V1, DAG, Subtarget);
assert(Lane < (int)VT.getVectorNumElements() && "Unexpected lane!");
@@ -1911,7 +1913,8 @@ SDValue RISCVTargetLowering::lowerVectorMaskExt(SDValue Op, SelectionDAG &DAG,
return DAG.getNode(ISD::VSELECT, DL, VecVT, Src, SplatTrueVal, SplatZero);
}
MVT ContainerVT = getContainerForFixedLengthVector(DAG, VecVT, Subtarget);
MVT ContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector(
DAG, VecVT, Subtarget);
MVT I1ContainerVT =
MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
@@ -2335,15 +2338,41 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
MVT SubVecVT = Op.getSimpleValueType();
MVT VecVT = Vec.getSimpleValueType();
// TODO: Only handle scalable->scalable extracts for now, and revisit this
// for fixed-length vectors later.
if (!SubVecVT.isScalableVector() || !VecVT.isScalableVector())
return Op;
SDLoc DL(Op);
MVT XLenVT = Subtarget.getXLenVT();
unsigned OrigIdx = Op.getConstantOperandVal(1);
const RISCVRegisterInfo *TRI = Subtarget.getRegisterInfo();
// If the subvector is a fixed-length type, we cannot use subregister
// manipulation to simplify the codegen; we don't know which register of a
// LMUL group contains the specific subvector as we only know the minimum
// register size. Therefore we must slide the vector group down the full
// amount.
if (SubVecVT.isFixedLengthVector()) {
// With an index of 0 this is a cast-like subvector, which can be performed
// with subregister operations.
if (OrigIdx == 0)
return Op;
MVT ContainerVT = VecVT;
if (VecVT.isFixedLengthVector()) {
ContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector(
DAG, VecVT, Subtarget);
Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
}
SDValue Mask =
getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).first;
// Set the vector length to only the number of elements we care about. This
// avoids sliding down elements we're going to discard straight away.
SDValue VL = DAG.getConstant(SubVecVT.getVectorNumElements(), DL, XLenVT);
SDValue SlidedownAmt = DAG.getConstant(OrigIdx, DL, XLenVT);
SDValue Slidedown =
DAG.getNode(RISCVISD::VSLIDEDOWN_VL, DL, ContainerVT,
DAG.getUNDEF(ContainerVT), Vec, SlidedownAmt, Mask, VL);
// Now we can use a cast-like subvector extract to get the result.
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVecVT, Slidedown,
DAG.getConstant(0, DL, XLenVT));
}
unsigned SubRegIdx, RemIdx;
std::tie(SubRegIdx, RemIdx) =
RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
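
Putting the pieces above together: a non-aligned fixed-from-fixed extract is
now lowered as a VL-limited slidedown followed by a cast-like extract at
index 0. The sketch below mirrors @extract_v2i32_v8i32_2 from the new tests
(LMULMAX2 configuration; the wrapper name is illustrative), with comments
mapping the expected instructions to the steps above:

define void @slidedown_extract(<8 x i32>* %x, <2 x i32>* %y) {
  %a = load <8 x i32>, <8 x i32>* %x
  ;   vsetivli      a2, 8, e32,m2,ta,mu     load the <8 x i32> source (LMUL=2)
  ;   vle32.v       v26, (a0)
  ;   vsetivli      a0, 2, e32,m2,ta,mu     VL = 2: only the elements we keep
  ;   vslidedown.vi v26, v26, 2             slide down by the original index
  ;   vsetivli      a0, 2, e32,m1,ta,mu     cast-like extract of the low part
  ;   vse32.v       v26, (a1)               store the <2 x i32> result
  %c = call <2 x i32> @llvm.experimental.vector.extract.v2i32.v8i32(<8 x i32> %a, i64 2)
  store <2 x i32> %c, <2 x i32>* %y
  ret void
}
declare <2 x i32> @llvm.experimental.vector.extract.v2i32.v8i32(<8 x i32>, i64)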
@@ -2357,7 +2386,6 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
// Else we must shift our vector register directly to extract the subvector.
// Do this using VSLIDEDOWN.
MVT XLenVT = Subtarget.getXLenVT();
// Extract a subvector equal to the nearest full vector register type. This
// should resolve to an EXTRACT_SUBREG instruction.
@@ -2392,7 +2420,8 @@ RISCVTargetLowering::lowerFixedLengthVectorLoadToRVV(SDValue Op,
SDLoc DL(Op);
MVT VT = Op.getSimpleValueType();
MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
MVT ContainerVT =
RISCVTargetLowering::getContainerForFixedLengthVector(DAG, VT, Subtarget);
SDValue VL =
DAG.getConstant(VT.getVectorNumElements(), DL, Subtarget.getXLenVT());
@@ -2417,7 +2446,8 @@ RISCVTargetLowering::lowerFixedLengthVectorStoreToRVV(SDValue Op,
// FIXME: We probably need to zero any extra bits in a byte for mask stores.
// This is tricky to do.
MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
MVT ContainerVT =
RISCVTargetLowering::getContainerForFixedLengthVector(DAG, VT, Subtarget);
SDValue VL =
DAG.getConstant(VT.getVectorNumElements(), DL, Subtarget.getXLenVT());
@@ -2434,7 +2464,8 @@ SDValue
RISCVTargetLowering::lowerFixedLengthVectorSetccToRVV(SDValue Op,
SelectionDAG &DAG) const {
MVT InVT = Op.getOperand(0).getSimpleValueType();
MVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT, Subtarget);
MVT ContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector(
DAG, InVT, Subtarget);
MVT VT = Op.getSimpleValueType();
@@ -2547,7 +2578,8 @@ SDValue RISCVTargetLowering::lowerFixedLengthVectorLogicOpToRVV(
SDValue RISCVTargetLowering::lowerFixedLengthVectorSelectToRVV(
SDValue Op, SelectionDAG &DAG) const {
MVT VT = Op.getSimpleValueType();
MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
MVT ContainerVT =
RISCVTargetLowering::getContainerForFixedLengthVector(DAG, VT, Subtarget);
MVT I1ContainerVT =
MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
@@ -2575,7 +2607,8 @@ SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op, SelectionDAG &DAG,
MVT VT = Op.getSimpleValueType();
assert(useRVVForFixedLengthVectorVT(VT) &&
"Only expected to lower fixed length vector operation!");
MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
MVT ContainerVT =
RISCVTargetLowering::getContainerForFixedLengthVector(DAG, VT, Subtarget);
// Create list of operands by converting existing ones to scalable types.
SmallVector<SDValue, 6> Ops;


@@ -380,6 +380,8 @@ public:
decomposeSubvectorInsertExtractToSubRegs(MVT VecVT, MVT SubVecVT,
unsigned InsertExtractIdx,
const RISCVRegisterInfo *TRI);
static MVT getContainerForFixedLengthVector(SelectionDAG &DAG, MVT VT,
const RISCVSubtarget &Subtarget);
private:
void analyzeInputArgs(MachineFunction &MF, CCState &CCInfo,


@@ -0,0 +1,159 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=riscv64 -mattr=+m,+experimental-v -verify-machineinstrs -riscv-v-vector-bits-min=128 -riscv-v-fixed-length-vector-lmul-max=2 < %s | FileCheck %s --check-prefixes=CHECK,LMULMAX2
; RUN: llc -mtriple=riscv64 -mattr=+m,+experimental-v -verify-machineinstrs -riscv-v-vector-bits-min=128 -riscv-v-fixed-length-vector-lmul-max=1 < %s | FileCheck %s --check-prefixes=CHECK,LMULMAX1
define void @extract_v2i8_v8i8_0(<8 x i8>* %x, <2 x i8>* %y) {
; CHECK-LABEL: extract_v2i8_v8i8_0:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli a2, 8, e8,m1,ta,mu
; CHECK-NEXT: vle8.v v25, (a0)
; CHECK-NEXT: vsetivli a0, 2, e8,m1,ta,mu
; CHECK-NEXT: vse8.v v25, (a1)
; CHECK-NEXT: ret
%a = load <8 x i8>, <8 x i8>* %x
%c = call <2 x i8> @llvm.experimental.vector.extract.v2i8.v8i8(<8 x i8> %a, i64 0)
store <2 x i8> %c, <2 x i8>* %y
ret void
}
define void @extract_v2i8_v8i8_6(<8 x i8>* %x, <2 x i8>* %y) {
; CHECK-LABEL: extract_v2i8_v8i8_6:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli a2, 8, e8,m1,ta,mu
; CHECK-NEXT: vle8.v v25, (a0)
; CHECK-NEXT: vsetivli a0, 2, e8,m1,ta,mu
; CHECK-NEXT: vslidedown.vi v25, v25, 6
; CHECK-NEXT: vse8.v v25, (a1)
; CHECK-NEXT: ret
%a = load <8 x i8>, <8 x i8>* %x
%c = call <2 x i8> @llvm.experimental.vector.extract.v2i8.v8i8(<8 x i8> %a, i64 6)
store <2 x i8> %c, <2 x i8>* %y
ret void
}
define void @extract_v2i32_v8i32_0(<8 x i32>* %x, <2 x i32>* %y) {
; LMULMAX2-LABEL: extract_v2i32_v8i32_0:
; LMULMAX2: # %bb.0:
; LMULMAX2-NEXT: vsetivli a2, 8, e32,m2,ta,mu
; LMULMAX2-NEXT: vle32.v v26, (a0)
; LMULMAX2-NEXT: vsetivli a0, 2, e32,m1,ta,mu
; LMULMAX2-NEXT: vse32.v v26, (a1)
; LMULMAX2-NEXT: ret
;
; LMULMAX1-LABEL: extract_v2i32_v8i32_0:
; LMULMAX1: # %bb.0:
; LMULMAX1-NEXT: vsetivli a2, 4, e32,m1,ta,mu
; LMULMAX1-NEXT: vle32.v v25, (a0)
; LMULMAX1-NEXT: vsetivli a0, 2, e32,m1,ta,mu
; LMULMAX1-NEXT: vse32.v v25, (a1)
; LMULMAX1-NEXT: ret
%a = load <8 x i32>, <8 x i32>* %x
%c = call <2 x i32> @llvm.experimental.vector.extract.v2i32.v8i32(<8 x i32> %a, i64 0)
store <2 x i32> %c, <2 x i32>* %y
ret void
}
define void @extract_v2i32_v8i32_2(<8 x i32>* %x, <2 x i32>* %y) {
; LMULMAX2-LABEL: extract_v2i32_v8i32_2:
; LMULMAX2: # %bb.0:
; LMULMAX2-NEXT: vsetivli a2, 8, e32,m2,ta,mu
; LMULMAX2-NEXT: vle32.v v26, (a0)
; LMULMAX2-NEXT: vsetivli a0, 2, e32,m2,ta,mu
; LMULMAX2-NEXT: vslidedown.vi v26, v26, 2
; LMULMAX2-NEXT: vsetivli a0, 2, e32,m1,ta,mu
; LMULMAX2-NEXT: vse32.v v26, (a1)
; LMULMAX2-NEXT: ret
;
; LMULMAX1-LABEL: extract_v2i32_v8i32_2:
; LMULMAX1: # %bb.0:
; LMULMAX1-NEXT: vsetivli a2, 4, e32,m1,ta,mu
; LMULMAX1-NEXT: vle32.v v25, (a0)
; LMULMAX1-NEXT: vsetivli a0, 2, e32,m1,ta,mu
; LMULMAX1-NEXT: vslidedown.vi v25, v25, 2
; LMULMAX1-NEXT: vse32.v v25, (a1)
; LMULMAX1-NEXT: ret
%a = load <8 x i32>, <8 x i32>* %x
%c = call <2 x i32> @llvm.experimental.vector.extract.v2i32.v8i32(<8 x i32> %a, i64 2)
store <2 x i32> %c, <2 x i32>* %y
ret void
}
define void @extract_v2i32_v8i32_6(<8 x i32>* %x, <2 x i32>* %y) {
; LMULMAX2-LABEL: extract_v2i32_v8i32_6:
; LMULMAX2: # %bb.0:
; LMULMAX2-NEXT: vsetivli a2, 8, e32,m2,ta,mu
; LMULMAX2-NEXT: vle32.v v26, (a0)
; LMULMAX2-NEXT: vsetivli a0, 2, e32,m2,ta,mu
; LMULMAX2-NEXT: vslidedown.vi v26, v26, 6
; LMULMAX2-NEXT: vsetivli a0, 2, e32,m1,ta,mu
; LMULMAX2-NEXT: vse32.v v26, (a1)
; LMULMAX2-NEXT: ret
;
; LMULMAX1-LABEL: extract_v2i32_v8i32_6:
; LMULMAX1: # %bb.0:
; LMULMAX1-NEXT: addi a0, a0, 16
; LMULMAX1-NEXT: vsetivli a2, 4, e32,m1,ta,mu
; LMULMAX1-NEXT: vle32.v v25, (a0)
; LMULMAX1-NEXT: vsetivli a0, 2, e32,m1,ta,mu
; LMULMAX1-NEXT: vslidedown.vi v25, v25, 2
; LMULMAX1-NEXT: vse32.v v25, (a1)
; LMULMAX1-NEXT: ret
%a = load <8 x i32>, <8 x i32>* %x
%c = call <2 x i32> @llvm.experimental.vector.extract.v2i32.v8i32(<8 x i32> %a, i64 6)
store <2 x i32> %c, <2 x i32>* %y
ret void
}
define void @extract_v2i32_nxv16i32_0(<vscale x 16 x i32> %x, <2 x i32>* %y) {
; CHECK-LABEL: extract_v2i32_nxv16i32_0:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli a1, 2, e32,m1,ta,mu
; CHECK-NEXT: vse32.v v8, (a0)
; CHECK-NEXT: ret
%c = call <2 x i32> @llvm.experimental.vector.extract.v2i32.nxv16i32(<vscale x 16 x i32> %x, i64 0)
store <2 x i32> %c, <2 x i32>* %y
ret void
}
define void @extract_v2i32_nxv16i32_6(<vscale x 16 x i32> %x, <2 x i32>* %y) {
; CHECK-LABEL: extract_v2i32_nxv16i32_6:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli a1, 2, e32,m8,ta,mu
; CHECK-NEXT: vslidedown.vi v8, v8, 6
; CHECK-NEXT: vsetivli a1, 2, e32,m1,ta,mu
; CHECK-NEXT: vse32.v v8, (a0)
; CHECK-NEXT: ret
%c = call <2 x i32> @llvm.experimental.vector.extract.v2i32.nxv16i32(<vscale x 16 x i32> %x, i64 6)
store <2 x i32> %c, <2 x i32>* %y
ret void
}
define void @extract_v8i32_nxv16i32_8(<vscale x 16 x i32> %x, <8 x i32>* %y) {
; LMULMAX2-LABEL: extract_v8i32_nxv16i32_8:
; LMULMAX2: # %bb.0:
; LMULMAX2-NEXT: vsetivli a1, 8, e32,m8,ta,mu
; LMULMAX2-NEXT: vslidedown.vi v8, v8, 8
; LMULMAX2-NEXT: vsetivli a1, 8, e32,m2,ta,mu
; LMULMAX2-NEXT: vse32.v v8, (a0)
; LMULMAX2-NEXT: ret
;
; LMULMAX1-LABEL: extract_v8i32_nxv16i32_8:
; LMULMAX1: # %bb.0:
; LMULMAX1-NEXT: vsetivli a1, 4, e32,m8,ta,mu
; LMULMAX1-NEXT: vslidedown.vi v16, v8, 8
; LMULMAX1-NEXT: vslidedown.vi v8, v8, 12
; LMULMAX1-NEXT: addi a1, a0, 16
; LMULMAX1-NEXT: vsetivli a2, 4, e32,m1,ta,mu
; LMULMAX1-NEXT: vse32.v v8, (a1)
; LMULMAX1-NEXT: vse32.v v16, (a0)
; LMULMAX1-NEXT: ret
%c = call <8 x i32> @llvm.experimental.vector.extract.v8i32.nxv16i32(<vscale x 16 x i32> %x, i64 8)
store <8 x i32> %c, <8 x i32>* %y
ret void
}
declare <2 x i8> @llvm.experimental.vector.extract.v2i8.v8i8(<8 x i8> %vec, i64 %idx)
declare <2 x i32> @llvm.experimental.vector.extract.v2i32.v8i32(<8 x i32> %vec, i64 %idx)
declare <2 x i32> @llvm.experimental.vector.extract.v2i32.nxv16i32(<vscale x 16 x i32> %vec, i64 %idx)
declare <8 x i32> @llvm.experimental.vector.extract.v8i32.nxv16i32(<vscale x 16 x i32> %vec, i64 %idx)