[SelectionDAG] Compute Known + Sign Bits - merge INSERT_VECTOR_ELT known/unknown index paths

Match the approach in SimplifyDemandedBits where we calculate the demanded elts and then have a common path for the ComputeKnownBits/ComputeNumSignBits call.
This commit is contained in:
Simon Pilgrim 2020-01-23 12:35:29 +00:00
parent 2f6987ba61
commit 48d4ba8fb2
1 changed files with 37 additions and 51 deletions

View File

@ -3261,38 +3261,31 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
break;
}
case ISD::INSERT_VECTOR_ELT: {
// If we know the element index, split the demand between the
// source vector and the inserted element, otherwise assume we need
// the original demanded vector elements and the value.
SDValue InVec = Op.getOperand(0);
SDValue InVal = Op.getOperand(1);
SDValue EltNo = Op.getOperand(2);
ConstantSDNode *CEltNo = dyn_cast<ConstantSDNode>(EltNo);
bool DemandedVal = true;
APInt DemandedVecElts = DemandedElts;
auto *CEltNo = dyn_cast<ConstantSDNode>(EltNo);
if (CEltNo && CEltNo->getAPIntValue().ult(NumElts)) {
// If we know the element index, split the demand between the
// source vector and the inserted element.
Known.Zero = Known.One = APInt::getAllOnesValue(BitWidth);
unsigned EltIdx = CEltNo->getZExtValue();
// If we demand the inserted element then add its common known bits.
if (DemandedElts[EltIdx]) {
Known2 = computeKnownBits(InVal, Depth + 1);
Known.One &= Known2.One.zextOrTrunc(Known.One.getBitWidth());
Known.Zero &= Known2.Zero.zextOrTrunc(Known.Zero.getBitWidth());
}
// If we demand the source vector then add its common known bits, ensuring
// that we don't demand the inserted element.
APInt VectorElts = DemandedElts & ~(APInt::getOneBitSet(NumElts, EltIdx));
if (!!VectorElts) {
Known2 = computeKnownBits(InVec, VectorElts, Depth + 1);
Known.One &= Known2.One;
Known.Zero &= Known2.Zero;
}
} else {
// Unknown element index, so ignore DemandedElts and demand them all.
Known = computeKnownBits(InVec, Depth + 1);
DemandedVal = !!DemandedElts[EltIdx];
DemandedVecElts.clearBit(EltIdx);
}
Known.One.setAllBits();
Known.Zero.setAllBits();
if (DemandedVal) {
Known2 = computeKnownBits(InVal, Depth + 1);
Known.One &= Known2.One.zextOrTrunc(Known.One.getBitWidth());
Known.Zero &= Known2.Zero.zextOrTrunc(Known.Zero.getBitWidth());
Known.One &= Known2.One.zextOrTrunc(BitWidth);
Known.Zero &= Known2.Zero.zextOrTrunc(BitWidth);
}
if (!!DemandedVecElts) {
Known2 = computeKnownBits(InVec, DemandedVecElts, Depth + 1);
Known.One &= Known2.One;
Known.Zero &= Known2.Zero;
}
break;
}
@ -3850,39 +3843,32 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
return std::max(std::min(KnownSign - rIndex * BitWidth, BitWidth), 0);
}
case ISD::INSERT_VECTOR_ELT: {
// If we know the element index, split the demand between the
// source vector and the inserted element, otherwise assume we need
// the original demanded vector elements and the value.
SDValue InVec = Op.getOperand(0);
SDValue InVal = Op.getOperand(1);
SDValue EltNo = Op.getOperand(2);
ConstantSDNode *CEltNo = dyn_cast<ConstantSDNode>(EltNo);
bool DemandedVal = true;
APInt DemandedVecElts = DemandedElts;
auto *CEltNo = dyn_cast<ConstantSDNode>(EltNo);
if (CEltNo && CEltNo->getAPIntValue().ult(NumElts)) {
// If we know the element index, split the demand between the
// source vector and the inserted element.
unsigned EltIdx = CEltNo->getZExtValue();
// If we demand the inserted element then get its sign bits.
Tmp = std::numeric_limits<unsigned>::max();
if (DemandedElts[EltIdx]) {
// TODO - handle implicit truncation of inserted elements.
if (InVal.getScalarValueSizeInBits() != VTBits)
break;
Tmp = ComputeNumSignBits(InVal, Depth + 1);
}
// If we demand the source vector then get its sign bits, and determine
// the minimum.
APInt VectorElts = DemandedElts;
VectorElts.clearBit(EltIdx);
if (!!VectorElts) {
Tmp2 = ComputeNumSignBits(InVec, VectorElts, Depth + 1);
Tmp = std::min(Tmp, Tmp2);
}
} else {
// Unknown element index, so ignore DemandedElts and demand them all.
Tmp = ComputeNumSignBits(InVec, Depth + 1);
DemandedVal = !!DemandedElts[EltIdx];
DemandedVecElts.clearBit(EltIdx);
}
Tmp = std::numeric_limits<unsigned>::max();
if (DemandedVal) {
// TODO - handle implicit truncation of inserted elements.
if (InVal.getScalarValueSizeInBits() != VTBits)
break;
Tmp2 = ComputeNumSignBits(InVal, Depth + 1);
Tmp = std::min(Tmp, Tmp2);
}
if (!!DemandedVecElts) {
Tmp2 = ComputeNumSignBits(InVec, DemandedVecElts, Depth + 1);
Tmp = std::min(Tmp, Tmp2);
}
assert(Tmp <= VTBits && "Failed to determine minimum sign bits");
return Tmp;
}