[TargetLowering] SimplifyDemandedBits - use getValidShiftAmountConstant helper.

Use the SelectionDAG::getValidShiftAmountConstant helper to get const/constsplat shift amounts, which allows us to drop the out of range shift amount early-out.

First step towards better non-uniform shift amount support in SimplifyDemandedBits.
This commit is contained in:
Simon Pilgrim 2020-02-21 13:29:40 +00:00
parent 35b685270b
commit 86c52af05a
2 changed files with 19 additions and 31 deletions

View File

@ -7571,8 +7571,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, NewOp1);
}
// TODO - support non-uniform vector shift amounts.
if (N1C && SimplifyDemandedBits(SDValue(N, 0)))
if (SimplifyDemandedBits(SDValue(N, 0)))
return SDValue(N, 0);
// fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
@ -7938,8 +7937,7 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
}
// Simplify, based on bits shifted out of the LHS.
// TODO - support non-uniform vector shift amounts.
if (N1C && SimplifyDemandedBits(SDValue(N, 0)))
if (SimplifyDemandedBits(SDValue(N, 0)))
return SDValue(N, 0);
// If the sign bit is known to be zero, switch this to a SRL.
@ -8135,8 +8133,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
// fold operands of srl based on knowledge that the low bits are not
// demanded.
// TODO - support non-uniform vector shift amounts.
if (N1C && SimplifyDemandedBits(SDValue(N, 0)))
if (SimplifyDemandedBits(SDValue(N, 0)))
return SDValue(N, 0);
if (N1C && !N1C->isOpaque())

View File

@ -1365,11 +1365,8 @@ bool TargetLowering::SimplifyDemandedBits(
SDValue Op0 = Op.getOperand(0);
SDValue Op1 = Op.getOperand(1);
if (ConstantSDNode *SA = isConstOrConstSplat(Op1, DemandedElts)) {
// If the shift count is an invalid immediate, don't do anything.
if (SA->getAPIntValue().uge(BitWidth))
break;
if (const APInt *SA =
TLO.DAG.getValidShiftAmountConstant(Op, DemandedElts)) {
unsigned ShAmt = SA->getZExtValue();
if (ShAmt == 0)
return TLO.CombineTo(Op, Op0);
@ -1380,9 +1377,9 @@ bool TargetLowering::SimplifyDemandedBits(
// TODO - support non-uniform vector amounts.
if (Op0.getOpcode() == ISD::SRL) {
if (!DemandedBits.intersects(APInt::getLowBitsSet(BitWidth, ShAmt))) {
if (ConstantSDNode *SA2 =
isConstOrConstSplat(Op0.getOperand(1), DemandedElts)) {
if (SA2->getAPIntValue().ult(BitWidth)) {
if (const APInt *SA2 =
TLO.DAG.getValidShiftAmountConstant(Op0, DemandedElts)) {
if (SA2->ult(BitWidth)) {
unsigned C1 = SA2->getZExtValue();
unsigned Opc = ISD::SHL;
int Diff = ShAmt - C1;
@ -1434,8 +1431,8 @@ bool TargetLowering::SimplifyDemandedBits(
// x aren't demanded.
if (Op0.hasOneUse() && InnerOp.getOpcode() == ISD::SRL &&
InnerOp.hasOneUse()) {
if (ConstantSDNode *SA2 =
isConstOrConstSplat(InnerOp.getOperand(1))) {
if (const APInt *SA2 =
TLO.DAG.getValidShiftAmountConstant(InnerOp, DemandedElts)) {
unsigned InnerShAmt = SA2->getLimitedValue(InnerBits);
if (InnerShAmt < ShAmt && InnerShAmt < InnerBits &&
DemandedBits.getActiveBits() <=
@ -1463,11 +1460,8 @@ bool TargetLowering::SimplifyDemandedBits(
SDValue Op0 = Op.getOperand(0);
SDValue Op1 = Op.getOperand(1);
if (ConstantSDNode *SA = isConstOrConstSplat(Op1, DemandedElts)) {
// If the shift count is an invalid immediate, don't do anything.
if (SA->getAPIntValue().uge(BitWidth))
break;
if (const APInt *SA =
TLO.DAG.getValidShiftAmountConstant(Op, DemandedElts)) {
unsigned ShAmt = SA->getZExtValue();
if (ShAmt == 0)
return TLO.CombineTo(Op, Op0);
@ -1485,11 +1479,11 @@ bool TargetLowering::SimplifyDemandedBits(
// are never demanded.
// TODO - support non-uniform vector amounts.
if (Op0.getOpcode() == ISD::SHL) {
if (ConstantSDNode *SA2 =
isConstOrConstSplat(Op0.getOperand(1), DemandedElts)) {
if (const APInt *SA2 =
TLO.DAG.getValidShiftAmountConstant(Op0, DemandedElts)) {
if (!DemandedBits.intersects(
APInt::getHighBitsSet(BitWidth, ShAmt))) {
if (SA2->getAPIntValue().ult(BitWidth)) {
if (SA2->ult(BitWidth)) {
unsigned C1 = SA2->getZExtValue();
unsigned Opc = ISD::SRL;
int Diff = ShAmt - C1;
@ -1513,8 +1507,8 @@ bool TargetLowering::SimplifyDemandedBits(
assert(!Known.hasConflict() && "Bits known to be one AND zero?");
Known.Zero.lshrInPlace(ShAmt);
Known.One.lshrInPlace(ShAmt);
Known.Zero.setHighBits(ShAmt); // High bits known zero.
// High bits known zero.
Known.Zero.setHighBits(ShAmt);
}
break;
}
@ -1536,11 +1530,8 @@ bool TargetLowering::SimplifyDemandedBits(
if (DemandedBits.isOneValue())
return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1));
if (ConstantSDNode *SA = isConstOrConstSplat(Op1, DemandedElts)) {
// If the shift count is an invalid immediate, don't do anything.
if (SA->getAPIntValue().uge(BitWidth))
break;
if (const APInt *SA =
TLO.DAG.getValidShiftAmountConstant(Op, DemandedElts)) {
unsigned ShAmt = SA->getZExtValue();
if (ShAmt == 0)
return TLO.CombineTo(Op, Op0);