[DAGCombiner] clean up visitEXTRACT_VECTOR_ELT

This isn't quite NFC, but I don't know how to expose
any outward diffs from these changes. Mostly, this
was confusing because it used 'VT' to refer to the
operand type rather than the usual type of the input node.

There's also a large block at the end that is dedicated 
solely to matching loads, but that wasn't obvious. This
could probably be split up into separate functions to
make it easier to see. 

It's still not clear to me when we make certain transforms 
because the legality and constant conditions are 
intertwined in a way that might be improved.

llvm-svn: 349095
This commit is contained in:
Sanjay Patel 2018-12-14 00:09:08 +00:00
parent 178abc59ac
commit 093ab45d4c
1 changed files with 135 additions and 144 deletions

View File

@ -264,8 +264,9 @@ namespace {
/// \param EltNo index of the vector element to load.
/// \param OriginalLoad load that EVE came from to be replaced.
/// \returns EVE on success SDValue() on failure.
SDValue ReplaceExtractVectorEltOfLoadWithNarrowedLoad(
SDNode *EVE, EVT InVecVT, SDValue EltNo, LoadSDNode *OriginalLoad);
SDValue scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
SDValue EltNo,
LoadSDNode *OriginalLoad);
void ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad);
SDValue PromoteOperand(SDValue Op, EVT PVT, bool &Replace);
SDValue SExtPromoteOperand(SDValue Op, EVT PVT);
@ -15463,8 +15464,9 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
return DAG.getBuildVector(VT, DL, Ops);
}
SDValue DAGCombiner::ReplaceExtractVectorEltOfLoadWithNarrowedLoad(
SDNode *EVE, EVT InVecVT, SDValue EltNo, LoadSDNode *OriginalLoad) {
SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
SDValue EltNo,
LoadSDNode *OriginalLoad) {
assert(!OriginalLoad->isVolatile());
EVT ResultVT = EVE->getValueType(0);
@ -15547,43 +15549,52 @@ SDValue DAGCombiner::ReplaceExtractVectorEltOfLoadWithNarrowedLoad(
}
SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
SDValue InVec = N->getOperand(0);
EVT VT = InVec.getValueType();
EVT NVT = N->getValueType(0);
if (InVec.isUndef())
return DAG.getUNDEF(NVT);
SDValue VecOp = N->getOperand(0);
SDValue Index = N->getOperand(1);
EVT ScalarVT = N->getValueType(0);
EVT VecVT = VecOp.getValueType();
if (VecOp.isUndef())
return DAG.getUNDEF(ScalarVT);
// extract_vector_elt (insert_vector_elt vec, val, idx), idx) -> val
//
// This only really matters if the index is non-constant since other combines
// on the constant elements already work.
SDLoc DL(N);
if (VecOp.getOpcode() == ISD::INSERT_VECTOR_ELT &&
Index == VecOp.getOperand(2)) {
SDValue Elt = VecOp.getOperand(1);
return VecVT.isInteger() ? DAG.getAnyExtOrTrunc(Elt, DL, ScalarVT) : Elt;
}
// (vextract (scalar_to_vector val, 0) -> val
if (InVec.getOpcode() == ISD::SCALAR_TO_VECTOR) {
if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR) {
// Check if the result type doesn't match the inserted element type. A
// SCALAR_TO_VECTOR may truncate the inserted element and the
// EXTRACT_VECTOR_ELT may widen the extracted vector.
SDValue InOp = InVec.getOperand(0);
if (InOp.getValueType() != NVT) {
assert(InOp.getValueType().isInteger() && NVT.isInteger());
return DAG.getSExtOrTrunc(InOp, SDLoc(InVec), NVT);
SDValue InOp = VecOp.getOperand(0);
if (InOp.getValueType() != ScalarVT) {
assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
return DAG.getSExtOrTrunc(InOp, DL, ScalarVT);
}
return InOp;
}
SDValue EltNo = N->getOperand(1);
ConstantSDNode *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo);
// extract_vector_elt of out-of-bounds element -> UNDEF
if (ConstEltNo && ConstEltNo->getAPIntValue().uge(VT.getVectorNumElements()))
return DAG.getUNDEF(NVT);
auto *IndexC = dyn_cast<ConstantSDNode>(Index);
unsigned NumElts = VecVT.getVectorNumElements();
if (IndexC && IndexC->getAPIntValue().uge(NumElts))
return DAG.getUNDEF(ScalarVT);
// extract_vector_elt (build_vector x, y), 1 -> y
if (ConstEltNo &&
InVec.getOpcode() == ISD::BUILD_VECTOR &&
TLI.isTypeLegal(VT) &&
(InVec.hasOneUse() ||
TLI.aggressivelyPreferBuildVectorSources(VT))) {
SDValue Elt = InVec.getOperand(ConstEltNo->getZExtValue());
if (IndexC && VecOp.getOpcode() == ISD::BUILD_VECTOR &&
TLI.isTypeLegal(VecVT) &&
(VecOp.hasOneUse() || TLI.aggressivelyPreferBuildVectorSources(VecVT))) {
SDValue Elt = VecOp.getOperand(IndexC->getZExtValue());
EVT InEltVT = Elt.getValueType();
// Sometimes build_vector's scalar input types do not match result type.
if (NVT == InEltVT)
if (ScalarVT == InEltVT)
return Elt;
// TODO: It may be useful to truncate if free if the build_vector implicitly
@ -15593,27 +15604,27 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
// TODO: These transforms should not require the 'hasOneUse' restriction, but
// there are regressions on multiple targets without it. We can end up with a
// mess of scalar and vector code if we reduce only part of the DAG to scalar.
if (ConstEltNo && InVec.getOpcode() == ISD::BITCAST && VT.isInteger() &&
InVec.hasOneUse()) {
if (IndexC && VecOp.getOpcode() == ISD::BITCAST && VecVT.isInteger() &&
VecOp.hasOneUse()) {
// The vector index of the LSBs of the source depend on the endian-ness.
bool IsLE = DAG.getDataLayout().isLittleEndian();
unsigned ExtractIndex = ConstEltNo->getZExtValue();
unsigned ExtractIndex = IndexC->getZExtValue();
// extract_elt (v2i32 (bitcast i64:x)), BCTruncElt -> i32 (trunc i64:x)
unsigned BCTruncElt = IsLE ? 0 : VT.getVectorNumElements() - 1;
SDValue BCSrc = InVec.getOperand(0);
unsigned BCTruncElt = IsLE ? 0 : NumElts - 1;
SDValue BCSrc = VecOp.getOperand(0);
if (ExtractIndex == BCTruncElt && BCSrc.getValueType().isScalarInteger())
return DAG.getNode(ISD::TRUNCATE, SDLoc(N), NVT, BCSrc);
return DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, BCSrc);
if (LegalTypes && BCSrc.getValueType().isInteger() &&
BCSrc.getOpcode() == ISD::SCALAR_TO_VECTOR) {
// ext_elt (bitcast (scalar_to_vec i64 X to v2i64) to v4i32), TruncElt -->
// trunc i64 X to i32
SDValue X = BCSrc.getOperand(0);
assert(X.getValueType().isScalarInteger() && NVT.isScalarInteger() &&
assert(X.getValueType().isScalarInteger() && ScalarVT.isScalarInteger() &&
"Extract element and scalar to vector can't change element type "
"from FP to integer.");
unsigned XBitWidth = X.getValueSizeInBits();
unsigned VecEltBitWidth = VT.getScalarSizeInBits();
unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
BCTruncElt = IsLE ? 0 : XBitWidth / VecEltBitWidth - 1;
// An extract element return value type can be wider than its vector
@ -15622,51 +15633,40 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
if (ExtractIndex == BCTruncElt && XBitWidth > VecEltBitWidth) {
assert(XBitWidth % VecEltBitWidth == 0 &&
"Scalar bitwidth must be a multiple of vector element bitwidth");
return DAG.getAnyExtOrTrunc(X, SDLoc(N), NVT);
return DAG.getAnyExtOrTrunc(X, DL, ScalarVT);
}
}
}
// extract_vector_elt (insert_vector_elt vec, val, idx), idx) -> val
//
// This only really matters if the index is non-constant since other combines
// on the constant elements already work.
if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT &&
EltNo == InVec.getOperand(2)) {
SDValue Elt = InVec.getOperand(1);
return VT.isInteger() ? DAG.getAnyExtOrTrunc(Elt, SDLoc(N), NVT) : Elt;
}
// Transform: (EXTRACT_VECTOR_ELT( VECTOR_SHUFFLE )) -> EXTRACT_VECTOR_ELT.
// We only perform this optimization before the op legalization phase because
// we may introduce new vector instructions which are not backed by TD
// patterns. For example on AVX, extracting elements from a wide vector
// without using extract_subvector. However, if we can find an underlying
// scalar value, then we can always use that.
if (ConstEltNo && InVec.getOpcode() == ISD::VECTOR_SHUFFLE) {
int NumElem = VT.getVectorNumElements();
ShuffleVectorSDNode *SVOp = cast<ShuffleVectorSDNode>(InVec);
if (IndexC && VecOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
auto *Shuf = cast<ShuffleVectorSDNode>(VecOp);
// Find the new index to extract from.
int OrigElt = SVOp->getMaskElt(ConstEltNo->getZExtValue());
int OrigElt = Shuf->getMaskElt(IndexC->getZExtValue());
// Extracting an undef index is undef.
if (OrigElt == -1)
return DAG.getUNDEF(NVT);
return DAG.getUNDEF(ScalarVT);
// Select the right vector half to extract from.
SDValue SVInVec;
if (OrigElt < NumElem) {
SVInVec = InVec->getOperand(0);
if (OrigElt < (int)NumElts) {
SVInVec = VecOp.getOperand(0);
} else {
SVInVec = InVec->getOperand(1);
OrigElt -= NumElem;
SVInVec = VecOp.getOperand(1);
OrigElt -= NumElts;
}
if (SVInVec.getOpcode() == ISD::BUILD_VECTOR) {
SDValue InOp = SVInVec.getOperand(OrigElt);
if (InOp.getValueType() != NVT) {
assert(InOp.getValueType().isInteger() && NVT.isInteger());
InOp = DAG.getSExtOrTrunc(InOp, SDLoc(SVInVec), NVT);
if (InOp.getValueType() != ScalarVT) {
assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
InOp = DAG.getSExtOrTrunc(InOp, DL, ScalarVT);
}
return InOp;
@ -15677,28 +15677,28 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
if (!LegalOperations ||
// FIXME: Should really be just isOperationLegalOrCustom.
TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VT) ||
TLI.isOperationExpand(ISD::VECTOR_SHUFFLE, VT)) {
TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecVT) ||
TLI.isOperationExpand(ISD::VECTOR_SHUFFLE, VecVT)) {
EVT IndexTy = TLI.getVectorIdxTy(DAG.getDataLayout());
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(N), NVT, SVInVec,
DAG.getConstant(OrigElt, SDLoc(SVOp), IndexTy));
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, SVInVec,
DAG.getConstant(OrigElt, DL, IndexTy));
}
}
// If only EXTRACT_VECTOR_ELT nodes use the source vector we can
// simplify it based on the (valid) extraction indices.
if (llvm::all_of(InVec->uses(), [&](SDNode *Use) {
if (llvm::all_of(VecOp->uses(), [&](SDNode *Use) {
return Use->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
Use->getOperand(0) == InVec &&
Use->getOperand(0) == VecOp &&
isa<ConstantSDNode>(Use->getOperand(1));
})) {
APInt DemandedElts = APInt::getNullValue(VT.getVectorNumElements());
for (SDNode *Use : InVec->uses()) {
APInt DemandedElts = APInt::getNullValue(NumElts);
for (SDNode *Use : VecOp->uses()) {
auto *CstElt = cast<ConstantSDNode>(Use->getOperand(1));
if (CstElt->getAPIntValue().ult(VT.getVectorNumElements()))
if (CstElt->getAPIntValue().ult(NumElts))
DemandedElts.setBit(CstElt->getZExtValue());
}
if (SimplifyDemandedVectorElts(InVec, DemandedElts, true)) {
if (SimplifyDemandedVectorElts(VecOp, DemandedElts, true)) {
// We simplified the vector operand of this extract element. If this
// extract is not dead, visit it again so it is folded properly.
if (N->getOpcode() != ISD::DELETED_NODE)
@ -15707,111 +15707,102 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
}
}
bool BCNumEltsChanged = false;
EVT ExtVT = VT.getVectorElementType();
EVT LVT = ExtVT;
// Everything under here is trying to match an extract of a loaded value.
// If the result of load has to be truncated, then it's not necessarily
// profitable.
if (NVT.bitsLT(LVT) && !TLI.isTruncateFree(LVT, NVT))
bool BCNumEltsChanged = false;
EVT ExtVT = VecVT.getVectorElementType();
EVT LVT = ExtVT;
if (ScalarVT.bitsLT(LVT) && !TLI.isTruncateFree(LVT, ScalarVT))
return SDValue();
if (InVec.getOpcode() == ISD::BITCAST) {
if (VecOp.getOpcode() == ISD::BITCAST) {
// Don't duplicate a load with other uses.
if (!InVec.hasOneUse())
if (!VecOp.hasOneUse())
return SDValue();
EVT BCVT = InVec.getOperand(0).getValueType();
EVT BCVT = VecOp.getOperand(0).getValueType();
if (!BCVT.isVector() || ExtVT.bitsGT(BCVT.getVectorElementType()))
return SDValue();
if (VT.getVectorNumElements() != BCVT.getVectorNumElements())
if (NumElts != BCVT.getVectorNumElements())
BCNumEltsChanged = true;
InVec = InVec.getOperand(0);
VecOp = VecOp.getOperand(0);
ExtVT = BCVT.getVectorElementType();
}
// (vextract (vN[if]M load $addr), i) -> ([if]M load $addr + i * size)
if (!LegalOperations && !ConstEltNo && InVec.hasOneUse() &&
ISD::isNormalLoad(InVec.getNode()) &&
!N->getOperand(1)->hasPredecessor(InVec.getNode())) {
if (!LegalOperations && !IndexC && VecOp.hasOneUse() &&
ISD::isNormalLoad(VecOp.getNode()) &&
!N->getOperand(1)->hasPredecessor(VecOp.getNode())) {
SDValue Index = N->getOperand(1);
if (LoadSDNode *OrigLoad = dyn_cast<LoadSDNode>(InVec)) {
if (!OrigLoad->isVolatile()) {
return ReplaceExtractVectorEltOfLoadWithNarrowedLoad(N, VT, Index,
OrigLoad);
}
}
if (auto *OrigLoad = dyn_cast<LoadSDNode>(VecOp))
if (!OrigLoad->isVolatile())
return scalarizeExtractedVectorLoad(N, VecVT, Index, OrigLoad);
}
// Perform only after legalization to ensure build_vector / vector_shuffle
// optimizations have already been done.
if (!LegalOperations) return SDValue();
if (!LegalOperations || !IndexC)
return SDValue();
// (vextract (v4f32 load $addr), c) -> (f32 load $addr+c*size)
// (vextract (v4f32 s2v (f32 load $addr)), c) -> (f32 load $addr+c*size)
// (vextract (v4f32 shuffle (load $addr), <1,u,u,u>), 0) -> (f32 load $addr)
if (ConstEltNo) {
int Elt = cast<ConstantSDNode>(EltNo)->getZExtValue();
LoadSDNode *LN0 = nullptr;
const ShuffleVectorSDNode *SVN = nullptr;
if (ISD::isNormalLoad(InVec.getNode())) {
LN0 = cast<LoadSDNode>(InVec);
} else if (InVec.getOpcode() == ISD::SCALAR_TO_VECTOR &&
InVec.getOperand(0).getValueType() == ExtVT &&
ISD::isNormalLoad(InVec.getOperand(0).getNode())) {
// Don't duplicate a load with other uses.
if (!InVec.hasOneUse())
return SDValue();
LN0 = cast<LoadSDNode>(InVec.getOperand(0));
} else if ((SVN = dyn_cast<ShuffleVectorSDNode>(InVec))) {
// (vextract (vector_shuffle (load $addr), v2, <1, u, u, u>), 1)
// =>
// (load $addr+1*size)
// Don't duplicate a load with other uses.
if (!InVec.hasOneUse())
return SDValue();
// If the bit convert changed the number of elements, it is unsafe
// to examine the mask.
if (BCNumEltsChanged)
return SDValue();
// Select the input vector, guarding against out of range extract vector.
unsigned NumElems = VT.getVectorNumElements();
int Idx = (Elt > (int)NumElems) ? -1 : SVN->getMaskElt(Elt);
InVec = (Idx < (int)NumElems) ? InVec.getOperand(0) : InVec.getOperand(1);
if (InVec.getOpcode() == ISD::BITCAST) {
// Don't duplicate a load with other uses.
if (!InVec.hasOneUse())
return SDValue();
InVec = InVec.getOperand(0);
}
if (ISD::isNormalLoad(InVec.getNode())) {
LN0 = cast<LoadSDNode>(InVec);
Elt = (Idx < (int)NumElems) ? Idx : Idx - (int)NumElems;
EltNo = DAG.getConstant(Elt, SDLoc(EltNo), EltNo.getValueType());
}
}
// Make sure we found a non-volatile load and the extractelement is
// the only use.
if (!LN0 || !LN0->hasNUsesOfValue(1,0) || LN0->isVolatile())
int Elt = IndexC->getZExtValue();
LoadSDNode *LN0 = nullptr;
if (ISD::isNormalLoad(VecOp.getNode())) {
LN0 = cast<LoadSDNode>(VecOp);
} else if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR &&
VecOp.getOperand(0).getValueType() == ExtVT &&
ISD::isNormalLoad(VecOp.getOperand(0).getNode())) {
// Don't duplicate a load with other uses.
if (!VecOp.hasOneUse())
return SDValue();
// If Idx was -1 above, Elt is going to be -1, so just return undef.
if (Elt == -1)
return DAG.getUNDEF(LVT);
LN0 = cast<LoadSDNode>(VecOp.getOperand(0));
}
if (auto *Shuf = dyn_cast<ShuffleVectorSDNode>(VecOp)) {
// (vextract (vector_shuffle (load $addr), v2, <1, u, u, u>), 1)
// =>
// (load $addr+1*size)
return ReplaceExtractVectorEltOfLoadWithNarrowedLoad(N, VT, EltNo, LN0);
// Don't duplicate a load with other uses.
if (!VecOp.hasOneUse())
return SDValue();
// If the bit convert changed the number of elements, it is unsafe
// to examine the mask.
if (BCNumEltsChanged)
return SDValue();
// Select the input vector, guarding against out of range extract vector.
int Idx = (Elt > (int)NumElts) ? -1 : Shuf->getMaskElt(Elt);
VecOp = (Idx < (int)NumElts) ? VecOp.getOperand(0) : VecOp.getOperand(1);
if (VecOp.getOpcode() == ISD::BITCAST) {
// Don't duplicate a load with other uses.
if (!VecOp.hasOneUse())
return SDValue();
VecOp = VecOp.getOperand(0);
}
if (ISD::isNormalLoad(VecOp.getNode())) {
LN0 = cast<LoadSDNode>(VecOp);
Elt = (Idx < (int)NumElts) ? Idx : Idx - (int)NumElts;
Index = DAG.getConstant(Elt, DL, Index.getValueType());
}
}
return SDValue();
// Make sure we found a non-volatile load and the extractelement is
// the only use.
if (!LN0 || !LN0->hasNUsesOfValue(1,0) || LN0->isVolatile())
return SDValue();
// If Idx was -1 above, Elt is going to be -1, so just return undef.
if (Elt == -1)
return DAG.getUNDEF(LVT);
return scalarizeExtractedVectorLoad(N, VecVT, Index, LN0);
}
// Simplify (build_vec (ext )) to (bitcast (build_vec ))