[RISCV][SelectionDAG] Introduce an ISD::SPLAT_VECTOR_PARTS node that can represent a splat of 2 i32 values into a nxvXi64 vector for riscv32.

On riscv32, i64 isn't a legal scalar type but we would like to
support scalable vectors of i64.

This patch introduces a new node that can represent a splat made
of multiple scalar values. I've used this new node to solve the current
crashes we experience when getConstant is used after type legalization.

For RISCV, we are now default expanding SPLAT_VECTOR to SPLAT_VECTOR_PARTS
when needed and then handling the SPLAT_VECTOR_PARTS later during
LegalizeOps. I've remove the special case I previously put in for
ABS for D97991 as the default expansion is now able to succesfully
use getConstant.

Reviewed By: frasercrmck

Differential Revision: https://reviews.llvm.org/D98004
This commit is contained in:
Craig Topper 2021-03-10 09:46:16 -08:00
parent 0c73a506e8
commit 9106d04554
7 changed files with 54 additions and 38 deletions

View File

@ -583,6 +583,15 @@ enum NodeType {
/// implicitly truncated to it.
SPLAT_VECTOR,
/// SPLAT_VECTOR_PARTS(SCALAR1, SCALAR2, ...) - Returns a vector with the
/// scalar values joined together and then duplicated in all lanes. This
/// represents a SPLAT_VECTOR that has had its scalar operand expanded. This
/// allows representing a 64-bit splat on a target with 32-bit integers. The
/// total width of the scalars must cover the element width. SCALAR1 contains
/// the least significant bits of the value regardless of endianness and all
/// scalars should have the same type.
SPLAT_VECTOR_PARTS,
/// MULHU/MULHS - Multiply high - Multiply two integers of type iN,
/// producing an unsigned/signed value of type i[2*N], then return the top
/// part.

View File

@ -4194,6 +4194,7 @@ bool DAGTypeLegalizer::ExpandIntegerOperand(SDNode *N, unsigned OpNo) {
case ISD::EXTRACT_ELEMENT: Res = ExpandOp_EXTRACT_ELEMENT(N); break;
case ISD::INSERT_VECTOR_ELT: Res = ExpandOp_INSERT_VECTOR_ELT(N); break;
case ISD::SCALAR_TO_VECTOR: Res = ExpandOp_SCALAR_TO_VECTOR(N); break;
case ISD::SPLAT_VECTOR: Res = ExpandIntOp_SPLAT_VECTOR(N); break;
case ISD::SELECT_CC: Res = ExpandIntOp_SELECT_CC(N); break;
case ISD::SETCC: Res = ExpandIntOp_SETCC(N); break;
case ISD::SETCCCARRY: Res = ExpandIntOp_SETCCCARRY(N); break;
@ -4449,6 +4450,14 @@ SDValue DAGTypeLegalizer::ExpandIntOp_SETCCCARRY(SDNode *N) {
LowCmp.getValue(1), Cond);
}
SDValue DAGTypeLegalizer::ExpandIntOp_SPLAT_VECTOR(SDNode *N) {
// Split the operand and replace with SPLAT_VECTOR_PARTS.
SDValue Lo, Hi;
GetExpandedInteger(N->getOperand(0), Lo, Hi);
return DAG.getNode(ISD::SPLAT_VECTOR_PARTS, SDLoc(N), N->getValueType(0), Lo,
Hi);
}
SDValue DAGTypeLegalizer::ExpandIntOp_Shift(SDNode *N) {
// The value being shifted is legal, but the shift amount is too big.
// It follows that either the result of the shift is undefined, or the

View File

@ -481,6 +481,7 @@ private:
SDValue ExpandIntOp_UINT_TO_FP(SDNode *N);
SDValue ExpandIntOp_RETURNADDR(SDNode *N);
SDValue ExpandIntOp_ATOMIC_STORE(SDNode *N);
SDValue ExpandIntOp_SPLAT_VECTOR(SDNode *N);
void IntegerExpandSetCCOperands(SDValue &NewLHS, SDValue &NewRHS,
ISD::CondCode &CCCode, const SDLoc &dl);

View File

@ -1383,6 +1383,22 @@ SDValue SelectionDAG::getConstant(const ConstantInt &Val, const SDLoc &DL,
const APInt &NewVal = Elt->getValue();
EVT ViaEltVT = TLI->getTypeToTransformTo(*getContext(), EltVT);
unsigned ViaEltSizeInBits = ViaEltVT.getSizeInBits();
// For scalable vectors, try to use a SPLAT_VECTOR_PARTS node.
if (VT.isScalableVector()) {
assert(EltVT.getSizeInBits() % ViaEltSizeInBits == 0 &&
"Can only handle an even split!");
unsigned Parts = EltVT.getSizeInBits() / ViaEltSizeInBits;
SmallVector<SDValue, 2> ScalarParts;
for (unsigned i = 0; i != Parts; ++i)
ScalarParts.push_back(getConstant(
NewVal.lshr(i * ViaEltSizeInBits).trunc(ViaEltSizeInBits), DL,
ViaEltVT, isT, isO));
return getNode(ISD::SPLAT_VECTOR_PARTS, DL, VT, ScalarParts);
}
unsigned ViaVecNumElts = VT.getSizeInBits() / ViaEltSizeInBits;
EVT ViaVecVT = EVT::getVectorVT(*getContext(), ViaEltVT, ViaVecNumElts);

View File

@ -290,6 +290,7 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::VECTOR_SHUFFLE: return "vector_shuffle";
case ISD::VECTOR_SPLICE: return "vector_splice";
case ISD::SPLAT_VECTOR: return "splat_vector";
case ISD::SPLAT_VECTOR_PARTS: return "splat_vector_parts";
case ISD::VECTOR_REVERSE: return "vector_reverse";
case ISD::CARRY_FALSE: return "carry_false";
case ISD::ADDC: return "addc";

View File

@ -399,7 +399,6 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
if (!Subtarget.is64Bit()) {
// We must custom-lower certain vXi64 operations on RV32 due to the vector
// element type being illegal.
setOperationAction(ISD::SPLAT_VECTOR, MVT::i64, Custom);
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::i64, Custom);
setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::i64, Custom);
@ -424,15 +423,13 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
for (MVT VT : IntVecVTs) {
setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
setOperationAction(ISD::SPLAT_VECTOR_PARTS, VT, Custom);
setOperationAction(ISD::SMIN, VT, Legal);
setOperationAction(ISD::SMAX, VT, Legal);
setOperationAction(ISD::UMIN, VT, Legal);
setOperationAction(ISD::UMAX, VT, Legal);
if (!Subtarget.is64Bit() && VT.getVectorElementType() == MVT::i64)
setOperationAction(ISD::ABS, VT, Custom);
setOperationAction(ISD::ROTL, VT, Expand);
setOperationAction(ISD::ROTR, VT, Expand);
@ -1313,8 +1310,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
Op.getOperand(0).getValueType().getVectorElementType() == MVT::i1)
return lowerVectorMaskExt(Op, DAG, /*ExtVal*/ -1);
return lowerFixedLengthVectorExtendToRVV(Op, DAG, RISCVISD::VSEXT_VL);
case ISD::SPLAT_VECTOR:
return lowerSPLATVECTOR(Op, DAG);
case ISD::SPLAT_VECTOR_PARTS:
return lowerSPLAT_VECTOR_PARTS(Op, DAG);
case ISD::INSERT_VECTOR_ELT:
return lowerINSERT_VECTOR_ELT(Op, DAG);
case ISD::EXTRACT_VECTOR_ELT:
@ -2035,30 +2032,28 @@ SDValue RISCVTargetLowering::lowerShiftRightParts(SDValue Op, SelectionDAG &DAG,
return DAG.getMergeValues(Parts, DL);
}
// Custom-lower a SPLAT_VECTOR where XLEN<SEW, as the SEW element type is
// Custom-lower a SPLAT_VECTOR_PARTS where XLEN<SEW, as the SEW element type is
// illegal (currently only vXi64 RV32).
// FIXME: We could also catch non-constant sign-extended i32 values and lower
// them to SPLAT_VECTOR_I64
SDValue RISCVTargetLowering::lowerSPLATVECTOR(SDValue Op,
SelectionDAG &DAG) const {
SDValue RISCVTargetLowering::lowerSPLAT_VECTOR_PARTS(SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op);
EVT VecVT = Op.getValueType();
assert(!Subtarget.is64Bit() && VecVT.getVectorElementType() == MVT::i64 &&
"Unexpected SPLAT_VECTOR lowering");
SDValue SplatVal = Op.getOperand(0);
"Unexpected SPLAT_VECTOR_PARTS lowering");
// If we can prove that the value is a sign-extended 32-bit value, lower this
// as a custom node in order to try and match RVV vector/scalar instructions.
if (auto *CVal = dyn_cast<ConstantSDNode>(SplatVal)) {
if (isInt<32>(CVal->getSExtValue()))
return DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT,
DAG.getConstant(CVal->getSExtValue(), DL, MVT::i32));
}
assert(Op.getNumOperands() == 2 && "Unexpected number of operands!");
SDValue Lo = Op.getOperand(0);
SDValue Hi = Op.getOperand(1);
if (SplatVal.getOpcode() == ISD::SIGN_EXTEND &&
SplatVal.getOperand(0).getValueType() == MVT::i32) {
return DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT,
SplatVal.getOperand(0));
if (isa<ConstantSDNode>(Lo) && isa<ConstantSDNode>(Hi)) {
int32_t LoC = cast<ConstantSDNode>(Lo)->getSExtValue();
int32_t HiC = cast<ConstantSDNode>(Hi)->getSExtValue();
// If Hi constant is all the same sign bit as Lo, lower this as a custom
// node in order to try and match RVV vector/scalar instructions.
if ((LoC >> 31) == HiC)
return DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, Lo);
}
// Else, on RV32 we lower an i64-element SPLAT_VECTOR thus, being careful not
@ -2069,11 +2064,7 @@ SDValue RISCVTargetLowering::lowerSPLATVECTOR(SDValue Op,
// vsll.vx vY, vY, /*32*/
// vsrl.vx vY, vY, /*32*/
// vor.vv vX, vX, vY
SDValue One = DAG.getConstant(1, DL, MVT::i32);
SDValue Zero = DAG.getConstant(0, DL, MVT::i32);
SDValue ThirtyTwoV = DAG.getConstant(32, DL, VecVT);
SDValue Lo = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i32, SplatVal, Zero);
SDValue Hi = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i32, SplatVal, One);
Lo = DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, Lo);
Lo = DAG.getNode(ISD::SHL, DL, VecVT, Lo, ThirtyTwoV);
@ -3162,17 +3153,6 @@ SDValue RISCVTargetLowering::lowerABS(SDValue Op, SelectionDAG &DAG) const {
MVT VT = Op.getSimpleValueType();
SDValue X = Op.getOperand(0);
// For scalable vectors we just need to deal with i64 on RV32 since the
// default expansion crashes in getConstant.
if (VT.isScalableVector()) {
assert(!Subtarget.is64Bit() && VT.getVectorElementType() == MVT::i64 &&
"Unexpected custom lowering!");
SDValue SplatZero = DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VT,
DAG.getConstant(0, DL, MVT::i32));
SDValue NegX = DAG.getNode(ISD::SUB, DL, VT, SplatZero, X);
return DAG.getNode(ISD::SMAX, DL, VT, X, NegX);
}
assert(VT.isFixedLengthVector() && "Unexpected type");
MVT ContainerVT =

View File

@ -437,7 +437,7 @@ private:
SDValue lowerRETURNADDR(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerShiftLeftParts(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerShiftRightParts(SDValue Op, SelectionDAG &DAG, bool IsSRA) const;
SDValue lowerSPLATVECTOR(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerSPLAT_VECTOR_PARTS(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVectorMaskExt(SDValue Op, SelectionDAG &DAG,
int64_t ExtTrueVal) const;
SDValue lowerVectorMaskTrunc(SDValue Op, SelectionDAG &DAG) const;