forked from OSchip/llvm-project
[RISCV][SelectionDAG] Introduce an ISD::SPLAT_VECTOR_PARTS node that can represent a splat of 2 i32 values into a nxvXi64 vector for riscv32.
On riscv32, i64 isn't a legal scalar type but we would like to support scalable vectors of i64. This patch introduces a new node that can represent a splat made of multiple scalar values. I've used this new node to solve the current crashes we experience when getConstant is used after type legalization. For RISCV, we are now default expanding SPLAT_VECTOR to SPLAT_VECTOR_PARTS when needed and then handling the SPLAT_VECTOR_PARTS later during LegalizeOps. I've removed the special case I previously put in for ABS in D97991 as the default expansion is now able to successfully use getConstant. Reviewed By: frasercrmck Differential Revision: https://reviews.llvm.org/D98004
This commit is contained in:
parent
0c73a506e8
commit
9106d04554
llvm
include/llvm/CodeGen
lib
CodeGen/SelectionDAG
Target/RISCV
|
@ -583,6 +583,15 @@ enum NodeType {
|
|||
/// implicitly truncated to it.
|
||||
SPLAT_VECTOR,
|
||||
|
||||
/// SPLAT_VECTOR_PARTS(SCALAR1, SCALAR2, ...) - Returns a vector with the
|
||||
/// scalar values joined together and then duplicated in all lanes. This
|
||||
/// represents a SPLAT_VECTOR that has had its scalar operand expanded. This
|
||||
/// allows representing a 64-bit splat on a target with 32-bit integers. The
|
||||
/// total width of the scalars must cover the element width. SCALAR1 contains
|
||||
/// the least significant bits of the value regardless of endianness and all
|
||||
/// scalars should have the same type.
|
||||
SPLAT_VECTOR_PARTS,
|
||||
|
||||
/// MULHU/MULHS - Multiply high - Multiply two integers of type iN,
|
||||
/// producing an unsigned/signed value of type i[2*N], then return the top
|
||||
/// part.
|
||||
|
|
|
@ -4194,6 +4194,7 @@ bool DAGTypeLegalizer::ExpandIntegerOperand(SDNode *N, unsigned OpNo) {
|
|||
case ISD::EXTRACT_ELEMENT: Res = ExpandOp_EXTRACT_ELEMENT(N); break;
|
||||
case ISD::INSERT_VECTOR_ELT: Res = ExpandOp_INSERT_VECTOR_ELT(N); break;
|
||||
case ISD::SCALAR_TO_VECTOR: Res = ExpandOp_SCALAR_TO_VECTOR(N); break;
|
||||
case ISD::SPLAT_VECTOR: Res = ExpandIntOp_SPLAT_VECTOR(N); break;
|
||||
case ISD::SELECT_CC: Res = ExpandIntOp_SELECT_CC(N); break;
|
||||
case ISD::SETCC: Res = ExpandIntOp_SETCC(N); break;
|
||||
case ISD::SETCCCARRY: Res = ExpandIntOp_SETCCCARRY(N); break;
|
||||
|
@ -4449,6 +4450,14 @@ SDValue DAGTypeLegalizer::ExpandIntOp_SETCCCARRY(SDNode *N) {
|
|||
LowCmp.getValue(1), Cond);
|
||||
}
|
||||
|
||||
// Operand-expansion handler for ISD::SPLAT_VECTOR (dispatched from
// ExpandIntegerOperand). Returns a new ISD::SPLAT_VECTOR_PARTS node with the
// same result type as N, built from the two expanded halves of N's scalar
// operand.
SDValue DAGTypeLegalizer::ExpandIntOp_SPLAT_VECTOR(SDNode *N) {
|
||||
// Split the operand into its expanded Lo/Hi halves and replace the splat with SPLAT_VECTOR_PARTS.
|
||||
SDValue Lo, Hi;
|
||||
GetExpandedInteger(N->getOperand(0), Lo, Hi);
|
||||
// Lo is passed as the first operand and Hi as the second; the node keeps
// N's location and value type.
return DAG.getNode(ISD::SPLAT_VECTOR_PARTS, SDLoc(N), N->getValueType(0), Lo,
|
||||
Hi);
|
||||
}
|
||||
|
||||
SDValue DAGTypeLegalizer::ExpandIntOp_Shift(SDNode *N) {
|
||||
// The value being shifted is legal, but the shift amount is too big.
|
||||
// It follows that either the result of the shift is undefined, or the
|
||||
|
|
|
@ -481,6 +481,7 @@ private:
|
|||
SDValue ExpandIntOp_UINT_TO_FP(SDNode *N);
|
||||
SDValue ExpandIntOp_RETURNADDR(SDNode *N);
|
||||
SDValue ExpandIntOp_ATOMIC_STORE(SDNode *N);
|
||||
SDValue ExpandIntOp_SPLAT_VECTOR(SDNode *N);
|
||||
|
||||
void IntegerExpandSetCCOperands(SDValue &NewLHS, SDValue &NewRHS,
|
||||
ISD::CondCode &CCCode, const SDLoc &dl);
|
||||
|
|
|
@ -1383,6 +1383,22 @@ SDValue SelectionDAG::getConstant(const ConstantInt &Val, const SDLoc &DL,
|
|||
const APInt &NewVal = Elt->getValue();
|
||||
EVT ViaEltVT = TLI->getTypeToTransformTo(*getContext(), EltVT);
|
||||
unsigned ViaEltSizeInBits = ViaEltVT.getSizeInBits();
|
||||
|
||||
// For scalable vectors, try to use a SPLAT_VECTOR_PARTS node.
|
||||
if (VT.isScalableVector()) {
|
||||
assert(EltVT.getSizeInBits() % ViaEltSizeInBits == 0 &&
|
||||
"Can only handle an even split!");
|
||||
unsigned Parts = EltVT.getSizeInBits() / ViaEltSizeInBits;
|
||||
|
||||
SmallVector<SDValue, 2> ScalarParts;
|
||||
for (unsigned i = 0; i != Parts; ++i)
|
||||
ScalarParts.push_back(getConstant(
|
||||
NewVal.lshr(i * ViaEltSizeInBits).trunc(ViaEltSizeInBits), DL,
|
||||
ViaEltVT, isT, isO));
|
||||
|
||||
return getNode(ISD::SPLAT_VECTOR_PARTS, DL, VT, ScalarParts);
|
||||
}
|
||||
|
||||
unsigned ViaVecNumElts = VT.getSizeInBits() / ViaEltSizeInBits;
|
||||
EVT ViaVecVT = EVT::getVectorVT(*getContext(), ViaEltVT, ViaVecNumElts);
|
||||
|
||||
|
|
|
@ -290,6 +290,7 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
|
|||
case ISD::VECTOR_SHUFFLE: return "vector_shuffle";
|
||||
case ISD::VECTOR_SPLICE: return "vector_splice";
|
||||
case ISD::SPLAT_VECTOR: return "splat_vector";
|
||||
case ISD::SPLAT_VECTOR_PARTS: return "splat_vector_parts";
|
||||
case ISD::VECTOR_REVERSE: return "vector_reverse";
|
||||
case ISD::CARRY_FALSE: return "carry_false";
|
||||
case ISD::ADDC: return "addc";
|
||||
|
|
|
@ -399,7 +399,6 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
|
|||
if (!Subtarget.is64Bit()) {
|
||||
// We must custom-lower certain vXi64 operations on RV32 due to the vector
|
||||
// element type being illegal.
|
||||
setOperationAction(ISD::SPLAT_VECTOR, MVT::i64, Custom);
|
||||
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::i64, Custom);
|
||||
setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::i64, Custom);
|
||||
|
||||
|
@ -424,15 +423,13 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
|
|||
|
||||
for (MVT VT : IntVecVTs) {
|
||||
setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
|
||||
setOperationAction(ISD::SPLAT_VECTOR_PARTS, VT, Custom);
|
||||
|
||||
setOperationAction(ISD::SMIN, VT, Legal);
|
||||
setOperationAction(ISD::SMAX, VT, Legal);
|
||||
setOperationAction(ISD::UMIN, VT, Legal);
|
||||
setOperationAction(ISD::UMAX, VT, Legal);
|
||||
|
||||
if (!Subtarget.is64Bit() && VT.getVectorElementType() == MVT::i64)
|
||||
setOperationAction(ISD::ABS, VT, Custom);
|
||||
|
||||
setOperationAction(ISD::ROTL, VT, Expand);
|
||||
setOperationAction(ISD::ROTR, VT, Expand);
|
||||
|
||||
|
@ -1313,8 +1310,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
|
|||
Op.getOperand(0).getValueType().getVectorElementType() == MVT::i1)
|
||||
return lowerVectorMaskExt(Op, DAG, /*ExtVal*/ -1);
|
||||
return lowerFixedLengthVectorExtendToRVV(Op, DAG, RISCVISD::VSEXT_VL);
|
||||
case ISD::SPLAT_VECTOR:
|
||||
return lowerSPLATVECTOR(Op, DAG);
|
||||
case ISD::SPLAT_VECTOR_PARTS:
|
||||
return lowerSPLAT_VECTOR_PARTS(Op, DAG);
|
||||
case ISD::INSERT_VECTOR_ELT:
|
||||
return lowerINSERT_VECTOR_ELT(Op, DAG);
|
||||
case ISD::EXTRACT_VECTOR_ELT:
|
||||
|
@ -2035,30 +2032,28 @@ SDValue RISCVTargetLowering::lowerShiftRightParts(SDValue Op, SelectionDAG &DAG,
|
|||
return DAG.getMergeValues(Parts, DL);
|
||||
}
|
||||
|
||||
// Custom-lower a SPLAT_VECTOR where XLEN<SEW, as the SEW element type is
|
||||
// Custom-lower a SPLAT_VECTOR_PARTS where XLEN<SEW, as the SEW element type is
|
||||
// illegal (currently only vXi64 RV32).
|
||||
// FIXME: We could also catch non-constant sign-extended i32 values and lower
|
||||
// them to SPLAT_VECTOR_I64
|
||||
SDValue RISCVTargetLowering::lowerSPLATVECTOR(SDValue Op,
|
||||
SelectionDAG &DAG) const {
|
||||
SDValue RISCVTargetLowering::lowerSPLAT_VECTOR_PARTS(SDValue Op,
|
||||
SelectionDAG &DAG) const {
|
||||
SDLoc DL(Op);
|
||||
EVT VecVT = Op.getValueType();
|
||||
assert(!Subtarget.is64Bit() && VecVT.getVectorElementType() == MVT::i64 &&
|
||||
"Unexpected SPLAT_VECTOR lowering");
|
||||
SDValue SplatVal = Op.getOperand(0);
|
||||
"Unexpected SPLAT_VECTOR_PARTS lowering");
|
||||
|
||||
// If we can prove that the value is a sign-extended 32-bit value, lower this
|
||||
// as a custom node in order to try and match RVV vector/scalar instructions.
|
||||
if (auto *CVal = dyn_cast<ConstantSDNode>(SplatVal)) {
|
||||
if (isInt<32>(CVal->getSExtValue()))
|
||||
return DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT,
|
||||
DAG.getConstant(CVal->getSExtValue(), DL, MVT::i32));
|
||||
}
|
||||
assert(Op.getNumOperands() == 2 && "Unexpected number of operands!");
|
||||
SDValue Lo = Op.getOperand(0);
|
||||
SDValue Hi = Op.getOperand(1);
|
||||
|
||||
if (SplatVal.getOpcode() == ISD::SIGN_EXTEND &&
|
||||
SplatVal.getOperand(0).getValueType() == MVT::i32) {
|
||||
return DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT,
|
||||
SplatVal.getOperand(0));
|
||||
if (isa<ConstantSDNode>(Lo) && isa<ConstantSDNode>(Hi)) {
|
||||
int32_t LoC = cast<ConstantSDNode>(Lo)->getSExtValue();
|
||||
int32_t HiC = cast<ConstantSDNode>(Hi)->getSExtValue();
|
||||
// If Hi constant is all the same sign bit as Lo, lower this as a custom
|
||||
// node in order to try and match RVV vector/scalar instructions.
|
||||
if ((LoC >> 31) == HiC)
|
||||
return DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, Lo);
|
||||
}
|
||||
|
||||
// Else, on RV32 we lower an i64-element SPLAT_VECTOR thus, being careful not
|
||||
|
@ -2069,11 +2064,7 @@ SDValue RISCVTargetLowering::lowerSPLATVECTOR(SDValue Op,
|
|||
// vsll.vx vY, vY, /*32*/
|
||||
// vsrl.vx vY, vY, /*32*/
|
||||
// vor.vv vX, vX, vY
|
||||
SDValue One = DAG.getConstant(1, DL, MVT::i32);
|
||||
SDValue Zero = DAG.getConstant(0, DL, MVT::i32);
|
||||
SDValue ThirtyTwoV = DAG.getConstant(32, DL, VecVT);
|
||||
SDValue Lo = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i32, SplatVal, Zero);
|
||||
SDValue Hi = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i32, SplatVal, One);
|
||||
|
||||
Lo = DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, Lo);
|
||||
Lo = DAG.getNode(ISD::SHL, DL, VecVT, Lo, ThirtyTwoV);
|
||||
|
@ -3162,17 +3153,6 @@ SDValue RISCVTargetLowering::lowerABS(SDValue Op, SelectionDAG &DAG) const {
|
|||
MVT VT = Op.getSimpleValueType();
|
||||
SDValue X = Op.getOperand(0);
|
||||
|
||||
// For scalable vectors we just need to deal with i64 on RV32 since the
|
||||
// default expansion crashes in getConstant.
|
||||
if (VT.isScalableVector()) {
|
||||
assert(!Subtarget.is64Bit() && VT.getVectorElementType() == MVT::i64 &&
|
||||
"Unexpected custom lowering!");
|
||||
SDValue SplatZero = DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VT,
|
||||
DAG.getConstant(0, DL, MVT::i32));
|
||||
SDValue NegX = DAG.getNode(ISD::SUB, DL, VT, SplatZero, X);
|
||||
return DAG.getNode(ISD::SMAX, DL, VT, X, NegX);
|
||||
}
|
||||
|
||||
assert(VT.isFixedLengthVector() && "Unexpected type");
|
||||
|
||||
MVT ContainerVT =
|
||||
|
|
|
@ -437,7 +437,7 @@ private:
|
|||
SDValue lowerRETURNADDR(SDValue Op, SelectionDAG &DAG) const;
|
||||
SDValue lowerShiftLeftParts(SDValue Op, SelectionDAG &DAG) const;
|
||||
SDValue lowerShiftRightParts(SDValue Op, SelectionDAG &DAG, bool IsSRA) const;
|
||||
SDValue lowerSPLATVECTOR(SDValue Op, SelectionDAG &DAG) const;
|
||||
SDValue lowerSPLAT_VECTOR_PARTS(SDValue Op, SelectionDAG &DAG) const;
|
||||
SDValue lowerVectorMaskExt(SDValue Op, SelectionDAG &DAG,
|
||||
int64_t ExtTrueVal) const;
|
||||
SDValue lowerVectorMaskTrunc(SDValue Op, SelectionDAG &DAG) const;
|
||||
|
|
Loading…
Reference in New Issue