[RISCV] Support fixed-length vector truncates
This patch extends support for our custom lowering of scalable-vector truncates to cover fixed-length vectors as well. It does this by co-opting the custom RISCVISD::TRUNCATE_VECTOR node, adding mask and VL operands (and renaming it RISCVISD::TRUNCATE_VECTOR_VL accordingly). This avoids unnecessary duplication of patterns and inflation of the ISel table.

Some truncates go through CONCAT_VECTORS, which isn't yet handled efficiently: it currently round-trips through the stack. This can be improved upon in the future.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D97202
Commit 84413e1947 (parent 3bc5ed3875)
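The strategy is easiest to see with a small standalone sketch (illustrative only; this is not code from the patch): RVV narrowing instructions only halve the element width (SEW*2 -> SEW), so an arbitrary truncate is decomposed into a chain of halving steps, one TRUNCATE_VECTOR_VL node (and ultimately one vnsrl.wi) per step.

#include <cstdio>

// Hedged sketch: count the halving steps a SrcBits -> DstBits element
// truncate decomposes into. The widths below are illustrative.
int main() {
  unsigned SrcEltBits = 32, DstEltBits = 8; // e.g. <4 x i32> -> <4 x i8>
  unsigned Steps = 0;
  while (SrcEltBits != DstEltBits) {
    SrcEltBits /= 2; // one SEW*2 -> SEW truncate per iteration
    ++Steps;
    std::printf("step %u: narrow to i%u\n", Steps, SrcEltBits);
  }
  return 0; // for i32 -> i8 this prints two steps: i16, then i8
}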
@@ -446,7 +446,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::FP_TO_SINT, VT, Custom);
       setOperationAction(ISD::FP_TO_UINT, VT, Custom);
 
-      // Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR"
+      // Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR_VL"
       // nodes which truncate by one power of two at a time.
       setOperationAction(ISD::TRUNCATE, VT, Custom);
 
@@ -526,6 +526,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
         // By default everything must be expanded.
         for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op)
           setOperationAction(Op, VT, Expand);
+        for (MVT OtherVT : MVT::fixedlen_vector_valuetypes())
+          setTruncStoreAction(VT, OtherVT, Expand);
 
         // We use EXTRACT_SUBVECTOR as a "cast" from scalable to fixed.
         setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
@@ -571,6 +573,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
         setOperationAction(ISD::VSELECT, VT, Custom);
 
+        setOperationAction(ISD::TRUNCATE, VT, Custom);
         setOperationAction(ISD::ANY_EXTEND, VT, Custom);
         setOperationAction(ISD::SIGN_EXTEND, VT, Custom);
         setOperationAction(ISD::ZERO_EXTEND, VT, Custom);
 
@@ -1171,7 +1174,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
   }
   case ISD::TRUNCATE: {
     SDLoc DL(Op);
-    EVT VT = Op.getValueType();
+    MVT VT = Op.getSimpleValueType();
     // Only custom-lower vector truncates
     if (!VT.isVector())
       return Op;
@@ -1181,28 +1184,42 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
       return lowerVectorMaskTrunc(Op, DAG);
 
     // RVV only has truncates which operate from SEW*2->SEW, so lower arbitrary
-    // truncates as a series of "RISCVISD::TRUNCATE_VECTOR" nodes which
+    // truncates as a series of "RISCVISD::TRUNCATE_VECTOR_VL" nodes which
     // truncate by one power of two at a time.
-    EVT DstEltVT = VT.getVectorElementType();
+    MVT DstEltVT = VT.getVectorElementType();
 
     SDValue Src = Op.getOperand(0);
-    EVT SrcVT = Src.getValueType();
-    EVT SrcEltVT = SrcVT.getVectorElementType();
+    MVT SrcVT = Src.getSimpleValueType();
+    MVT SrcEltVT = SrcVT.getVectorElementType();
 
     assert(DstEltVT.bitsLT(SrcEltVT) &&
            isPowerOf2_64(DstEltVT.getSizeInBits()) &&
            isPowerOf2_64(SrcEltVT.getSizeInBits()) &&
           "Unexpected vector truncate lowering");
 
+    MVT ContainerVT = SrcVT;
+    if (SrcVT.isFixedLengthVector()) {
+      ContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector(
+          DAG, SrcVT, Subtarget);
+      Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
+    }
+
     SDValue Result = Src;
+    SDValue Mask, VL;
+    std::tie(Mask, VL) =
+        getDefaultVLOps(SrcVT, ContainerVT, DL, DAG, Subtarget);
     LLVMContext &Context = *DAG.getContext();
-    const ElementCount Count = SrcVT.getVectorElementCount();
+    const ElementCount Count = ContainerVT.getVectorElementCount();
     do {
-      SrcEltVT = EVT::getIntegerVT(Context, SrcEltVT.getSizeInBits() / 2);
+      SrcEltVT = MVT::getIntegerVT(SrcEltVT.getSizeInBits() / 2);
      EVT ResultVT = EVT::getVectorVT(Context, SrcEltVT, Count);
-      Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR, DL, ResultVT, Result);
+      Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, ResultVT, Result,
+                           Mask, VL);
     } while (SrcEltVT != DstEltVT);
 
+    if (SrcVT.isFixedLengthVector())
+      Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
+
     return Result;
   }
   case ISD::ANY_EXTEND:
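To check what the emitted chain computes, here is a self-contained model (assumed element values; not LLVM code): each TRUNCATE_VECTOR_VL keeps the low half of every element, and composing the halving steps gives the same result as a direct truncate.

#include <array>
#include <cassert>
#include <cstdint>

// One halving step: keep the low bits of each element, elementwise.
template <typename Narrow, typename Wide, std::size_t N>
std::array<Narrow, N> truncStep(const std::array<Wide, N> &Src) {
  std::array<Narrow, N> Dst{};
  for (std::size_t I = 0; I < N; ++I)
    Dst[I] = static_cast<Narrow>(Src[I]);
  return Dst;
}

int main() {
  std::array<std::uint32_t, 4> V{0x11111111u, 0x2222u, 0xDEADBEEFu, 0xFFu};
  auto V16 = truncStep<std::uint16_t>(V); // models the first node: i32 -> i16
  auto V8 = truncStep<std::uint8_t>(V16); // models the second node: i16 -> i8
  assert(V8[2] == 0xEF); // matches a direct i32 -> i8 truncate
  return 0;
}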
@@ -5437,7 +5454,9 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(VMV_X_S)
   NODE_NAME_CASE(SPLAT_VECTOR_I64)
   NODE_NAME_CASE(READ_VLENB)
-  NODE_NAME_CASE(TRUNCATE_VECTOR)
+  NODE_NAME_CASE(TRUNCATE_VECTOR_VL)
   NODE_NAME_CASE(VLEFF)
   NODE_NAME_CASE(VLEFF_MASK)
+  NODE_NAME_CASE(VSLIDEUP_VL)
+  NODE_NAME_CASE(VSLIDEDOWN_VL)
   NODE_NAME_CASE(VID_VL)

@@ -105,8 +105,12 @@ enum NodeType : unsigned {
   SPLAT_VECTOR_I64,
   // Read VLENB CSR
   READ_VLENB,
-  // Truncates a RVV integer vector by one power-of-two.
-  TRUNCATE_VECTOR,
+  // Truncates a RVV integer vector by one power-of-two. Carries both an extra
+  // mask and VL operand.
+  TRUNCATE_VECTOR_VL,
   // Unit-stride fault-only-first load
   VLEFF,
   VLEFF_MASK,
+  // Matches the semantics of vslideup/vslidedown. The first operand is the
+  // pass-thru operand, the second is the source vector, the third is the
+  // XLenVT index (either constant or non-constant), the fourth is the mask

@@ -28,10 +28,6 @@ def SDTSplatI64 : SDTypeProfile<1, 1, [
 
 def rv32_splat_i64 : SDNode<"RISCVISD::SPLAT_VECTOR_I64", SDTSplatI64>;
 
-def riscv_trunc_vector : SDNode<"RISCVISD::TRUNCATE_VECTOR",
-                                SDTypeProfile<1, 1,
-                                 [SDTCisVec<0>, SDTCisVec<1>]>>;
-
 // Give explicit Complexity to prefer simm5/uimm5.
 def SplatPat : ComplexPattern<vAny, 1, "selectVSplat", [splat_vector, rv32_splat_i64], [], 1>;
 def SplatPat_simm5 : ComplexPattern<vAny, 1, "selectVSplatSimm5", [splat_vector, rv32_splat_i64], [], 2>;
@@ -433,15 +429,6 @@ defm "" : VPatBinarySDNode_VV_VX_VI<shl, "PseudoVSLL", uimm5>;
 defm "" : VPatBinarySDNode_VV_VX_VI<srl, "PseudoVSRL", uimm5>;
 defm "" : VPatBinarySDNode_VV_VX_VI<sra, "PseudoVSRA", uimm5>;
 
-// 12.7. Vector Narrowing Integer Right Shift Instructions
-foreach vtiTofti = AllFractionableVF2IntVectors in {
-  defvar vti = vtiTofti.Vti;
-  defvar fti = vtiTofti.Fti;
-  def : Pat<(fti.Vector (riscv_trunc_vector (vti.Vector vti.RegClass:$rs1))),
-            (!cast<Instruction>("PseudoVNSRL_WI_"#fti.LMul.MX)
-                vti.RegClass:$rs1, 0, fti.AVL, fti.SEW)>;
-}
-
 // 12.8. Vector Integer Comparison Instructions
 defm "" : VPatIntegerSetCCSDNode_VV_VX_VI<SETEQ, "PseudoVMSEQ">;
 defm "" : VPatIntegerSetCCSDNode_VV_VX_VI<SETNE, "PseudoVMSNE">;

@@ -148,6 +148,13 @@ def SDT_RISCVVEXTEND_VL : SDTypeProfile<1, 3, [SDTCisVec<0>,
 def riscv_sext_vl : SDNode<"RISCVISD::VSEXT_VL", SDT_RISCVVEXTEND_VL>;
 def riscv_zext_vl : SDNode<"RISCVISD::VZEXT_VL", SDT_RISCVVEXTEND_VL>;
 
+def riscv_trunc_vector_vl : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL",
+                                   SDTypeProfile<1, 3, [SDTCisVec<0>,
+                                                        SDTCisVec<1>,
+                                                        SDTCisSameNumEltsAs<0, 2>,
+                                                        SDTCVecEltisVT<2, i1>,
+                                                        SDTCisVT<3, XLenVT>]>>;
+
 // Ignore the vl operand.
 def SplatFPOp : PatFrag<(ops node:$op),
                         (riscv_vfmv_v_f_vl node:$op, srcvalue)>;
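Read off the profile above: the new node produces one result and takes three operands — the wide source vector, a mask whose i1 element count matches the result's, and the XLenVT vector length.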
@@ -443,6 +450,17 @@ defm "" : VPatBinaryVL_VV_VX_VI<riscv_shl_vl, "PseudoVSLL", uimm5>;
 defm "" : VPatBinaryVL_VV_VX_VI<riscv_srl_vl, "PseudoVSRL", uimm5>;
 defm "" : VPatBinaryVL_VV_VX_VI<riscv_sra_vl, "PseudoVSRA", uimm5>;
 
+// 12.7. Vector Narrowing Integer Right Shift Instructions
+foreach vtiTofti = AllFractionableVF2IntVectors in {
+  defvar vti = vtiTofti.Vti;
+  defvar fti = vtiTofti.Fti;
+  def : Pat<(fti.Vector (riscv_trunc_vector_vl (vti.Vector vti.RegClass:$rs1),
+                                               (vti.Mask true_mask),
+                                               (XLenVT (VLOp GPR:$vl)))),
+            (!cast<Instruction>("PseudoVNSRL_WI_"#fti.LMul.MX)
+                vti.RegClass:$rs1, 0, GPR:$vl, fti.SEW)>;
+}
+
 // 12.8. Vector Integer Comparison Instructions
 foreach vti = AllIntegerVectors in {
   defm "" : VPatIntegerSetCCVL_VV<vti, "PseudoVMSEQ", SETEQ>;
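The selected instruction uses an immediate shift amount of 0: a narrowing logical right shift by zero keeps exactly the low half of each SEW*2-wide element, which is a truncate. A hedged scalar model of one such step (values illustrative):

#include <cassert>
#include <cstdint>

// Model of one vnsrl.wi step with shift amount 0: shift the wide element
// right by zero, then keep only its low half.
int main() {
  std::uint32_t WideElt = 0xDEADBEEFu;
  std::uint16_t NarrowElt = static_cast<std::uint16_t>(WideElt >> 0);
  assert(NarrowElt == 0xBEEF); // identical to a plain i32 -> i16 truncate
  return 0;
}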
@@ -165,3 +165,80 @@ define void @sext_v32i8_v32i32(<32 x i8>* %x, <32 x i32>* %z) {
   store <32 x i32> %b, <32 x i32>* %z
   ret void
 }
+
+define void @trunc_v4i8_v4i32(<4 x i32>* %x, <4 x i8>* %z) {
+; CHECK-LABEL: trunc_v4i8_v4i32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli a2, 4, e32,m1,ta,mu
+; CHECK-NEXT:    vle32.v v25, (a0)
+; CHECK-NEXT:    vsetivli a0, 4, e16,mf2,ta,mu
+; CHECK-NEXT:    vnsrl.wi v26, v25, 0
+; CHECK-NEXT:    vsetivli a0, 4, e8,mf4,ta,mu
+; CHECK-NEXT:    vnsrl.wi v25, v26, 0
+; CHECK-NEXT:    vsetivli a0, 4, e8,m1,ta,mu
+; CHECK-NEXT:    vse8.v v25, (a1)
+; CHECK-NEXT:    ret
+  %a = load <4 x i32>, <4 x i32>* %x
+  %b = trunc <4 x i32> %a to <4 x i8>
+  store <4 x i8> %b, <4 x i8>* %z
+  ret void
+}
+
+define void @trunc_v8i8_v8i32(<8 x i32>* %x, <8 x i8>* %z) {
+; LMULMAX8-LABEL: trunc_v8i8_v8i32:
+; LMULMAX8:       # %bb.0:
+; LMULMAX8-NEXT:    vsetivli a2, 8, e32,m2,ta,mu
+; LMULMAX8-NEXT:    vle32.v v26, (a0)
+; LMULMAX8-NEXT:    vsetivli a0, 8, e16,m1,ta,mu
+; LMULMAX8-NEXT:    vnsrl.wi v25, v26, 0
+; LMULMAX8-NEXT:    vsetivli a0, 8, e8,mf2,ta,mu
+; LMULMAX8-NEXT:    vnsrl.wi v26, v25, 0
+; LMULMAX8-NEXT:    vsetivli a0, 8, e8,m1,ta,mu
+; LMULMAX8-NEXT:    vse8.v v26, (a1)
+; LMULMAX8-NEXT:    ret
+;
+; LMULMAX2-LABEL: trunc_v8i8_v8i32:
+; LMULMAX2:       # %bb.0:
+; LMULMAX2-NEXT:    vsetivli a2, 8, e32,m2,ta,mu
+; LMULMAX2-NEXT:    vle32.v v26, (a0)
+; LMULMAX2-NEXT:    vsetivli a0, 8, e16,m1,ta,mu
+; LMULMAX2-NEXT:    vnsrl.wi v25, v26, 0
+; LMULMAX2-NEXT:    vsetivli a0, 8, e8,mf2,ta,mu
+; LMULMAX2-NEXT:    vnsrl.wi v26, v25, 0
+; LMULMAX2-NEXT:    vsetivli a0, 8, e8,m1,ta,mu
+; LMULMAX2-NEXT:    vse8.v v26, (a1)
+; LMULMAX2-NEXT:    ret
+;
+; LMULMAX1-LABEL: trunc_v8i8_v8i32:
+; LMULMAX1:       # %bb.0:
+; LMULMAX1-NEXT:    addi sp, sp, -16
+; LMULMAX1-NEXT:    .cfi_def_cfa_offset 16
+; LMULMAX1-NEXT:    vsetivli a2, 4, e32,m1,ta,mu
+; LMULMAX1-NEXT:    addi a2, a0, 16
+; LMULMAX1-NEXT:    vle32.v v25, (a2)
+; LMULMAX1-NEXT:    vle32.v v26, (a0)
+; LMULMAX1-NEXT:    vsetivli a0, 4, e16,mf2,ta,mu
+; LMULMAX1-NEXT:    vnsrl.wi v27, v25, 0
+; LMULMAX1-NEXT:    vsetivli a0, 4, e8,mf4,ta,mu
+; LMULMAX1-NEXT:    vnsrl.wi v25, v27, 0
+; LMULMAX1-NEXT:    addi a0, sp, 12
+; LMULMAX1-NEXT:    vsetivli a2, 4, e8,m1,ta,mu
+; LMULMAX1-NEXT:    vse8.v v25, (a0)
+; LMULMAX1-NEXT:    vsetivli a0, 4, e16,mf2,ta,mu
+; LMULMAX1-NEXT:    vnsrl.wi v25, v26, 0
+; LMULMAX1-NEXT:    vsetivli a0, 4, e8,mf4,ta,mu
+; LMULMAX1-NEXT:    vnsrl.wi v26, v25, 0
+; LMULMAX1-NEXT:    vsetivli a0, 4, e8,m1,ta,mu
+; LMULMAX1-NEXT:    addi a0, sp, 8
+; LMULMAX1-NEXT:    vse8.v v26, (a0)
+; LMULMAX1-NEXT:    vsetivli a0, 8, e8,m1,ta,mu
+; LMULMAX1-NEXT:    addi a0, sp, 8
+; LMULMAX1-NEXT:    vle8.v v25, (a0)
+; LMULMAX1-NEXT:    vse8.v v25, (a1)
+; LMULMAX1-NEXT:    addi sp, sp, 16
+; LMULMAX1-NEXT:    ret
+  %a = load <8 x i32>, <8 x i32>* %x
+  %b = trunc <8 x i32> %a to <8 x i8>
+  store <8 x i8> %b, <8 x i8>* %z
+  ret void
+}
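Note how the LMULMAX1 output above illustrates the caveat from the commit message: with LMUL capped at 1, the two <4 x i8> halves are recombined through a stack slot (the addi sp / vse8.v / vle8.v sequence), because the CONCAT_VECTORS this produces is not yet handled efficiently.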