[RISCV] Support fixed-length vector truncates

This patch extends support for our custom lowering of scalable-vector
truncates to cover fixed-length vectors as well. It does this by
co-opting the custom RISCVISD::TRUNCATE_VECTOR node, renaming it to
RISCVISD::TRUNCATE_VECTOR_VL and giving it mask and VL operands. This
avoids unnecessary duplication of patterns and inflation of the ISel
table.

Some truncates are lowered via CONCAT_VECTORS, which currently isn't
handled efficiently as it goes through the stack. This can be improved
upon in the future.
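
As a rough illustration (mirroring the new tests added below, and using a
hypothetical function name), a fixed-length truncate such as the one here
is now custom-lowered as a chain of element-width-halving truncate nodes,
each of which selects to a vnsrl.wi with shift amount 0. An i32->i8
truncate therefore takes two steps (i32->i16->i8) and emits two vnsrl.wi
instructions, as seen in trunc_v4i8_v4i32 below:

    define <4 x i8> @trunc_example(<4 x i32> %v) {
      ; Truncating i32 elements to i8 halves the element width twice,
      ; so this becomes two RISCVISD::TRUNCATE_VECTOR_VL nodes.
      %t = trunc <4 x i32> %v to <4 x i8>
      ret <4 x i8> %t
    }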

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D97202
Fraser Cormack 2021-02-22 16:51:24 +00:00
parent 3bc5ed3875
commit 84413e1947
5 changed files with 130 additions and 25 deletions


@@ -446,7 +446,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FP_TO_SINT, VT, Custom);
setOperationAction(ISD::FP_TO_UINT, VT, Custom);
// Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR"
// Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR_VL"
// nodes which truncate by one power of two at a time.
setOperationAction(ISD::TRUNCATE, VT, Custom);
@@ -526,6 +526,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
// By default everything must be expanded.
for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op)
setOperationAction(Op, VT, Expand);
for (MVT OtherVT : MVT::fixedlen_vector_valuetypes())
setTruncStoreAction(VT, OtherVT, Expand);
// We use EXTRACT_SUBVECTOR as a "cast" from scalable to fixed.
setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
@@ -571,6 +573,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::VSELECT, VT, Custom);
setOperationAction(ISD::TRUNCATE, VT, Custom);
setOperationAction(ISD::ANY_EXTEND, VT, Custom);
setOperationAction(ISD::SIGN_EXTEND, VT, Custom);
setOperationAction(ISD::ZERO_EXTEND, VT, Custom);
@@ -1171,7 +1174,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
}
case ISD::TRUNCATE: {
SDLoc DL(Op);
EVT VT = Op.getValueType();
MVT VT = Op.getSimpleValueType();
// Only custom-lower vector truncates
if (!VT.isVector())
return Op;
@@ -1181,28 +1184,42 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return lowerVectorMaskTrunc(Op, DAG);
// RVV only has truncates which operate from SEW*2->SEW, so lower arbitrary
// truncates as a series of "RISCVISD::TRUNCATE_VECTOR" nodes which
// truncates as a series of "RISCVISD::TRUNCATE_VECTOR_VL" nodes which
// truncate by one power of two at a time.
EVT DstEltVT = VT.getVectorElementType();
MVT DstEltVT = VT.getVectorElementType();
SDValue Src = Op.getOperand(0);
EVT SrcVT = Src.getValueType();
EVT SrcEltVT = SrcVT.getVectorElementType();
MVT SrcVT = Src.getSimpleValueType();
MVT SrcEltVT = SrcVT.getVectorElementType();
assert(DstEltVT.bitsLT(SrcEltVT) &&
isPowerOf2_64(DstEltVT.getSizeInBits()) &&
isPowerOf2_64(SrcEltVT.getSizeInBits()) &&
"Unexpected vector truncate lowering");
MVT ContainerVT = SrcVT;
if (SrcVT.isFixedLengthVector()) {
ContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector(
DAG, SrcVT, Subtarget);
Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
}
SDValue Result = Src;
SDValue Mask, VL;
std::tie(Mask, VL) =
getDefaultVLOps(SrcVT, ContainerVT, DL, DAG, Subtarget);
LLVMContext &Context = *DAG.getContext();
const ElementCount Count = SrcVT.getVectorElementCount();
const ElementCount Count = ContainerVT.getVectorElementCount();
do {
SrcEltVT = EVT::getIntegerVT(Context, SrcEltVT.getSizeInBits() / 2);
SrcEltVT = MVT::getIntegerVT(SrcEltVT.getSizeInBits() / 2);
EVT ResultVT = EVT::getVectorVT(Context, SrcEltVT, Count);
Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR, DL, ResultVT, Result);
Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, ResultVT, Result,
Mask, VL);
} while (SrcEltVT != DstEltVT);
if (SrcVT.isFixedLengthVector())
Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
return Result;
}
case ISD::ANY_EXTEND:
@@ -5437,7 +5454,9 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(VMV_X_S)
NODE_NAME_CASE(SPLAT_VECTOR_I64)
NODE_NAME_CASE(READ_VLENB)
NODE_NAME_CASE(TRUNCATE_VECTOR)
NODE_NAME_CASE(TRUNCATE_VECTOR_VL)
NODE_NAME_CASE(VLEFF)
NODE_NAME_CASE(VLEFF_MASK)
NODE_NAME_CASE(VSLIDEUP_VL)
NODE_NAME_CASE(VSLIDEDOWN_VL)
NODE_NAME_CASE(VID_VL)


@@ -105,8 +105,12 @@ enum NodeType : unsigned {
SPLAT_VECTOR_I64,
// Read VLENB CSR
READ_VLENB,
// Truncates a RVV integer vector by one power-of-two.
TRUNCATE_VECTOR,
// Truncates a RVV integer vector by one power-of-two. Carries both an extra
// mask and VL operand.
TRUNCATE_VECTOR_VL,
// Unit-stride fault-only-first load
VLEFF,
VLEFF_MASK,
// Matches the semantics of vslideup/vslidedown. The first operand is the
// pass-thru operand, the second is the source vector, the third is the
// XLenVT index (either constant or non-constant), the fourth is the mask


@@ -28,10 +28,6 @@ def SDTSplatI64 : SDTypeProfile<1, 1, [
def rv32_splat_i64 : SDNode<"RISCVISD::SPLAT_VECTOR_I64", SDTSplatI64>;
def riscv_trunc_vector : SDNode<"RISCVISD::TRUNCATE_VECTOR",
SDTypeProfile<1, 1,
[SDTCisVec<0>, SDTCisVec<1>]>>;
// Give explicit Complexity to prefer simm5/uimm5.
def SplatPat : ComplexPattern<vAny, 1, "selectVSplat", [splat_vector, rv32_splat_i64], [], 1>;
def SplatPat_simm5 : ComplexPattern<vAny, 1, "selectVSplatSimm5", [splat_vector, rv32_splat_i64], [], 2>;
@@ -433,15 +429,6 @@ defm "" : VPatBinarySDNode_VV_VX_VI<shl, "PseudoVSLL", uimm5>;
defm "" : VPatBinarySDNode_VV_VX_VI<srl, "PseudoVSRL", uimm5>;
defm "" : VPatBinarySDNode_VV_VX_VI<sra, "PseudoVSRA", uimm5>;
// 12.7. Vector Narrowing Integer Right Shift Instructions
foreach vtiTofti = AllFractionableVF2IntVectors in {
defvar vti = vtiTofti.Vti;
defvar fti = vtiTofti.Fti;
def : Pat<(fti.Vector (riscv_trunc_vector (vti.Vector vti.RegClass:$rs1))),
(!cast<Instruction>("PseudoVNSRL_WI_"#fti.LMul.MX)
vti.RegClass:$rs1, 0, fti.AVL, fti.SEW)>;
}
// 12.8. Vector Integer Comparison Instructions
defm "" : VPatIntegerSetCCSDNode_VV_VX_VI<SETEQ, "PseudoVMSEQ">;
defm "" : VPatIntegerSetCCSDNode_VV_VX_VI<SETNE, "PseudoVMSNE">;


@@ -148,6 +148,13 @@ def SDT_RISCVVEXTEND_VL : SDTypeProfile<1, 3, [SDTCisVec<0>,
def riscv_sext_vl : SDNode<"RISCVISD::VSEXT_VL", SDT_RISCVVEXTEND_VL>;
def riscv_zext_vl : SDNode<"RISCVISD::VZEXT_VL", SDT_RISCVVEXTEND_VL>;
def riscv_trunc_vector_vl : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL",
SDTypeProfile<1, 3, [SDTCisVec<0>,
SDTCisVec<1>,
SDTCisSameNumEltsAs<0, 2>,
SDTCVecEltisVT<2, i1>,
SDTCisVT<3, XLenVT>]>>;
// Ignore the vl operand.
def SplatFPOp : PatFrag<(ops node:$op),
(riscv_vfmv_v_f_vl node:$op, srcvalue)>;
@@ -443,6 +450,17 @@ defm "" : VPatBinaryVL_VV_VX_VI<riscv_shl_vl, "PseudoVSLL", uimm5>;
defm "" : VPatBinaryVL_VV_VX_VI<riscv_srl_vl, "PseudoVSRL", uimm5>;
defm "" : VPatBinaryVL_VV_VX_VI<riscv_sra_vl, "PseudoVSRA", uimm5>;
// 12.7. Vector Narrowing Integer Right Shift Instructions
foreach vtiTofti = AllFractionableVF2IntVectors in {
defvar vti = vtiTofti.Vti;
defvar fti = vtiTofti.Fti;
def : Pat<(fti.Vector (riscv_trunc_vector_vl (vti.Vector vti.RegClass:$rs1),
(vti.Mask true_mask),
(XLenVT (VLOp GPR:$vl)))),
(!cast<Instruction>("PseudoVNSRL_WI_"#fti.LMul.MX)
vti.RegClass:$rs1, 0, GPR:$vl, fti.SEW)>;
}
// 12.8. Vector Integer Comparison Instructions
foreach vti = AllIntegerVectors in {
defm "" : VPatIntegerSetCCVL_VV<vti, "PseudoVMSEQ", SETEQ>;


@@ -165,3 +165,80 @@ define void @sext_v32i8_v32i32(<32 x i8>* %x, <32 x i32>* %z) {
store <32 x i32> %b, <32 x i32>* %z
ret void
}
define void @trunc_v4i8_v4i32(<4 x i32>* %x, <4 x i8>* %z) {
; CHECK-LABEL: trunc_v4i8_v4i32:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli a2, 4, e32,m1,ta,mu
; CHECK-NEXT: vle32.v v25, (a0)
; CHECK-NEXT: vsetivli a0, 4, e16,mf2,ta,mu
; CHECK-NEXT: vnsrl.wi v26, v25, 0
; CHECK-NEXT: vsetivli a0, 4, e8,mf4,ta,mu
; CHECK-NEXT: vnsrl.wi v25, v26, 0
; CHECK-NEXT: vsetivli a0, 4, e8,m1,ta,mu
; CHECK-NEXT: vse8.v v25, (a1)
; CHECK-NEXT: ret
%a = load <4 x i32>, <4 x i32>* %x
%b = trunc <4 x i32> %a to <4 x i8>
store <4 x i8> %b, <4 x i8>* %z
ret void
}
define void @trunc_v8i8_v8i32(<8 x i32>* %x, <8 x i8>* %z) {
; LMULMAX8-LABEL: trunc_v8i8_v8i32:
; LMULMAX8: # %bb.0:
; LMULMAX8-NEXT: vsetivli a2, 8, e32,m2,ta,mu
; LMULMAX8-NEXT: vle32.v v26, (a0)
; LMULMAX8-NEXT: vsetivli a0, 8, e16,m1,ta,mu
; LMULMAX8-NEXT: vnsrl.wi v25, v26, 0
; LMULMAX8-NEXT: vsetivli a0, 8, e8,mf2,ta,mu
; LMULMAX8-NEXT: vnsrl.wi v26, v25, 0
; LMULMAX8-NEXT: vsetivli a0, 8, e8,m1,ta,mu
; LMULMAX8-NEXT: vse8.v v26, (a1)
; LMULMAX8-NEXT: ret
;
; LMULMAX2-LABEL: trunc_v8i8_v8i32:
; LMULMAX2: # %bb.0:
; LMULMAX2-NEXT: vsetivli a2, 8, e32,m2,ta,mu
; LMULMAX2-NEXT: vle32.v v26, (a0)
; LMULMAX2-NEXT: vsetivli a0, 8, e16,m1,ta,mu
; LMULMAX2-NEXT: vnsrl.wi v25, v26, 0
; LMULMAX2-NEXT: vsetivli a0, 8, e8,mf2,ta,mu
; LMULMAX2-NEXT: vnsrl.wi v26, v25, 0
; LMULMAX2-NEXT: vsetivli a0, 8, e8,m1,ta,mu
; LMULMAX2-NEXT: vse8.v v26, (a1)
; LMULMAX2-NEXT: ret
;
; LMULMAX1-LABEL: trunc_v8i8_v8i32:
; LMULMAX1: # %bb.0:
; LMULMAX1-NEXT: addi sp, sp, -16
; LMULMAX1-NEXT: .cfi_def_cfa_offset 16
; LMULMAX1-NEXT: vsetivli a2, 4, e32,m1,ta,mu
; LMULMAX1-NEXT: addi a2, a0, 16
; LMULMAX1-NEXT: vle32.v v25, (a2)
; LMULMAX1-NEXT: vle32.v v26, (a0)
; LMULMAX1-NEXT: vsetivli a0, 4, e16,mf2,ta,mu
; LMULMAX1-NEXT: vnsrl.wi v27, v25, 0
; LMULMAX1-NEXT: vsetivli a0, 4, e8,mf4,ta,mu
; LMULMAX1-NEXT: vnsrl.wi v25, v27, 0
; LMULMAX1-NEXT: addi a0, sp, 12
; LMULMAX1-NEXT: vsetivli a2, 4, e8,m1,ta,mu
; LMULMAX1-NEXT: vse8.v v25, (a0)
; LMULMAX1-NEXT: vsetivli a0, 4, e16,mf2,ta,mu
; LMULMAX1-NEXT: vnsrl.wi v25, v26, 0
; LMULMAX1-NEXT: vsetivli a0, 4, e8,mf4,ta,mu
; LMULMAX1-NEXT: vnsrl.wi v26, v25, 0
; LMULMAX1-NEXT: vsetivli a0, 4, e8,m1,ta,mu
; LMULMAX1-NEXT: addi a0, sp, 8
; LMULMAX1-NEXT: vse8.v v26, (a0)
; LMULMAX1-NEXT: vsetivli a0, 8, e8,m1,ta,mu
; LMULMAX1-NEXT: addi a0, sp, 8
; LMULMAX1-NEXT: vle8.v v25, (a0)
; LMULMAX1-NEXT: vse8.v v25, (a1)
; LMULMAX1-NEXT: addi sp, sp, 16
; LMULMAX1-NEXT: ret
%a = load <8 x i32>, <8 x i32>* %x
%b = trunc <8 x i32> %a to <8 x i8>
store <8 x i8> %b, <8 x i8>* %z
ret void
}