[X86][SSE] Split IsSplatValue into GetSplatValue and IsSplatVector

Refactor towards making this recursive (necessary for PR38243 rotation splat detection).
IsSplatVector returns the original vector source of the splat and the splat index.
GetSplatValue returns the scalar splatted value as an extraction from IsSplatVector.

llvm-svn: 347168
This commit is contained in:
Simon Pilgrim 2018-11-18 17:15:06 +00:00
parent bc23408fe5
commit 50828c75d0
1 changed files with 40 additions and 36 deletions

View File

@ -23901,18 +23901,23 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
return SDValue();
}
// Determine if V is a splat value, and return the scalar.
// If V is a splat value, return the source vector and splat index;
// TODO - can we make this generic and move to SelectionDAG?
static SDValue IsSplatValue(MVT VT, SDValue V, const SDLoc &dl,
SelectionDAG &DAG) {
static SDValue IsSplatVector(MVT VT, SDValue V, int &SplatIdx) {
V = peekThroughEXTRACT_SUBVECTORs(V);
// Check if this is a splat build_vector node.
if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(V)) {
SDValue SplatAmt = BV->getSplatValue();
if (SplatAmt && SplatAmt.isUndef())
BitVector BVUndefElts;
SDValue SplatAmt = BV->getSplatValue(&BVUndefElts);
if (SplatAmt && !SplatAmt.isUndef()) {
for (int i = 0, e = BVUndefElts.size(); i != e; ++i)
if (!BVUndefElts[i]) {
SplatIdx = i;
return V;
}
}
return SDValue();
return SplatAmt;
}
// Check for SUB(SPLAT_BV, SPLAT) cases from rotate patterns.
@ -23925,11 +23930,12 @@ static SDValue IsSplatValue(MVT VT, SDValue V, const SDLoc &dl,
BuildVectorSDNode *BV0 = dyn_cast<BuildVectorSDNode>(LHS);
ShuffleVectorSDNode *SVN1 = dyn_cast<ShuffleVectorSDNode>(RHS);
if (BV0 && SVN1 && BV0->getSplatValue(&UndefElts) && SVN1->isSplat()) {
unsigned SplatIdx = (unsigned)SVN1->getSplatIndex();
if (!UndefElts[SplatIdx])
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
VT.getVectorElementType(), V,
DAG.getIntPtrConstant(SplatIdx, dl));
int Idx = SVN1->getSplatIndex();
if (!UndefElts[Idx]) {
SplatIdx = Idx;
return V;
}
return SDValue();
}
}
@ -23937,23 +23943,19 @@ static SDValue IsSplatValue(MVT VT, SDValue V, const SDLoc &dl,
ShuffleVectorSDNode *SVN = dyn_cast<ShuffleVectorSDNode>(V);
if (!SVN || !SVN->isSplat())
return SDValue();
int Idx = SVN->getSplatIndex();
int NumElts = V.getValueType().getVectorNumElements();
SplatIdx = Idx % NumElts;
return V.getOperand(Idx / NumElts);
}
unsigned SplatIdx = (unsigned)SVN->getSplatIndex();
SDValue InVec = V.getOperand(0);
if (InVec.getOpcode() == ISD::BUILD_VECTOR) {
assert((SplatIdx < VT.getVectorNumElements()) &&
"Unexpected shuffle index found!");
return InVec.getOperand(SplatIdx);
} else if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT) {
if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(InVec.getOperand(2)))
if (C->getZExtValue() == SplatIdx)
return InVec.getOperand(1);
}
// Avoid introducing an extract element from a shuffle.
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
VT.getVectorElementType(), InVec,
DAG.getIntPtrConstant(SplatIdx, dl));
static SDValue GetSplatValue(MVT VT, SDValue V, const SDLoc &dl,
SelectionDAG &DAG) {
int SplatIdx;
if (SDValue SrcVector = IsSplatVector(VT, V, SplatIdx))
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT.getVectorElementType(),
SrcVector, DAG.getIntPtrConstant(SplatIdx, dl));
return SDValue();
}
static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG,
@ -23968,7 +23970,7 @@ static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG,
Amt = peekThroughEXTRACT_SUBVECTORs(Amt);
if (SDValue BaseShAmt = IsSplatValue(VT, Amt, dl, DAG)) {
if (SDValue BaseShAmt = GetSplatValue(VT, Amt, dl, DAG)) {
if (SupportedVectorShiftWithBaseAmnt(VT, Subtarget, Opcode)) {
MVT EltVT = VT.getVectorElementType();
assert(EltVT.bitsLE(MVT::i64) && "Unexpected element type!");
@ -24670,14 +24672,16 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget,
// Rotate by splat - expand back to shifts.
// TODO - legalizers should be able to handle this.
if ((EltSizeInBits >= 16 || Subtarget.hasBWI()) &&
IsSplatValue(VT, Amt, DL, DAG)) {
if (EltSizeInBits >= 16 || Subtarget.hasBWI()) {
int SplatIdx;
if (IsSplatVector(VT, Amt, SplatIdx)) {
SDValue AmtR = DAG.getConstant(EltSizeInBits, DL, VT);
AmtR = DAG.getNode(ISD::SUB, DL, VT, AmtR, Amt);
SDValue SHL = DAG.getNode(ISD::SHL, DL, VT, R, Amt);
SDValue SRL = DAG.getNode(ISD::SRL, DL, VT, R, AmtR);
return DAG.getNode(ISD::OR, DL, VT, SHL, SRL);
}
}
// v16i8/v32i8: Split rotation into rot4/rot2/rot1 stages and select by
// the amount bit.