[RISCV] Reduce duplicate code for calling SimplifyDemandedBits.

This encapsulates the APInt creation and worklist management into
a helper function.

To keep one common interface I've use Log2_32 in places that
previously created a mask by subtracting 1 from a power of 2.

Differential Revision: https://reviews.llvm.org/D108324
This commit is contained in:
Craig Topper 2021-08-18 12:21:04 -07:00
parent 765a421276
commit 36d8316cc8
1 changed files with 34 additions and 71 deletions

View File

@ -5968,6 +5968,20 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
SelectionDAG &DAG = DCI.DAG;
// Helper to call SimplifyDemandedBits on an operand of N where only some low
// bits are demanded. N will be added to the Worklist if it was not deleted.
// Caller should return SDValue(N, 0) if this returns true.
auto SimplifyDemandedLowBitsHelper = [&](unsigned OpNo, unsigned LowBits) {
SDValue Op = N->getOperand(OpNo);
APInt Mask = APInt::getLowBitsSet(Op.getValueSizeInBits(), LowBits);
if (!SimplifyDemandedBits(Op, Mask, DCI))
return false;
if (N->getOpcode() != ISD::DELETED_NODE)
DCI.AddToWorklist(N);
return true;
};
switch (N->getOpcode()) {
default:
break;
@ -6019,136 +6033,85 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
case RISCVISD::ROLW:
case RISCVISD::RORW: {
// Only the lower 32 bits of LHS and lower 5 bits of RHS are read.
SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);
APInt LHSMask = APInt::getLowBitsSet(LHS.getValueSizeInBits(), 32);
APInt RHSMask = APInt::getLowBitsSet(RHS.getValueSizeInBits(), 5);
if (SimplifyDemandedBits(N->getOperand(0), LHSMask, DCI) ||
SimplifyDemandedBits(N->getOperand(1), RHSMask, DCI)) {
if (N->getOpcode() != ISD::DELETED_NODE)
DCI.AddToWorklist(N);
if (SimplifyDemandedLowBitsHelper(0, 32) ||
SimplifyDemandedLowBitsHelper(1, 5))
return SDValue(N, 0);
}
break;
}
case RISCVISD::CLZW:
case RISCVISD::CTZW: {
// Only the lower 32 bits of the first operand are read
SDValue Op0 = N->getOperand(0);
APInt Mask = APInt::getLowBitsSet(Op0.getValueSizeInBits(), 32);
if (SimplifyDemandedBits(Op0, Mask, DCI)) {
if (N->getOpcode() != ISD::DELETED_NODE)
DCI.AddToWorklist(N);
if (SimplifyDemandedLowBitsHelper(0, 32))
return SDValue(N, 0);
}
break;
}
case RISCVISD::FSL:
case RISCVISD::FSR: {
// Only the lower log2(Bitwidth)+1 bits of the the shift amount are read.
SDValue ShAmt = N->getOperand(2);
unsigned BitWidth = ShAmt.getValueSizeInBits();
unsigned BitWidth = N->getOperand(2).getValueSizeInBits();
assert(isPowerOf2_32(BitWidth) && "Unexpected bit width");
APInt ShAmtMask(BitWidth, (BitWidth * 2) - 1);
if (SimplifyDemandedBits(ShAmt, ShAmtMask, DCI)) {
if (N->getOpcode() != ISD::DELETED_NODE)
DCI.AddToWorklist(N);
if (SimplifyDemandedLowBitsHelper(2, Log2_32(BitWidth) + 1))
return SDValue(N, 0);
}
break;
}
case RISCVISD::FSLW:
case RISCVISD::FSRW: {
// Only the lower 32 bits of Values and lower 6 bits of shift amount are
// read.
SDValue Op0 = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
SDValue ShAmt = N->getOperand(2);
APInt OpMask = APInt::getLowBitsSet(Op0.getValueSizeInBits(), 32);
APInt ShAmtMask = APInt::getLowBitsSet(ShAmt.getValueSizeInBits(), 6);
if (SimplifyDemandedBits(Op0, OpMask, DCI) ||
SimplifyDemandedBits(Op1, OpMask, DCI) ||
SimplifyDemandedBits(ShAmt, ShAmtMask, DCI)) {
if (N->getOpcode() != ISD::DELETED_NODE)
DCI.AddToWorklist(N);
if (SimplifyDemandedLowBitsHelper(0, 32) ||
SimplifyDemandedLowBitsHelper(1, 32) ||
SimplifyDemandedLowBitsHelper(2, 6))
return SDValue(N, 0);
}
break;
}
case RISCVISD::GREV:
case RISCVISD::GORC: {
// Only the lower log2(Bitwidth) bits of the the shift amount are read.
SDValue ShAmt = N->getOperand(1);
unsigned BitWidth = ShAmt.getValueSizeInBits();
unsigned BitWidth = N->getOperand(1).getValueSizeInBits();
assert(isPowerOf2_32(BitWidth) && "Unexpected bit width");
APInt ShAmtMask(BitWidth, BitWidth - 1);
if (SimplifyDemandedBits(ShAmt, ShAmtMask, DCI)) {
if (N->getOpcode() != ISD::DELETED_NODE)
DCI.AddToWorklist(N);
if (SimplifyDemandedLowBitsHelper(1, Log2_32(BitWidth)))
return SDValue(N, 0);
}
return combineGREVI_GORCI(N, DCI.DAG);
}
case RISCVISD::GREVW:
case RISCVISD::GORCW: {
// Only the lower 32 bits of LHS and lower 5 bits of RHS are read.
SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);
APInt LHSMask = APInt::getLowBitsSet(LHS.getValueSizeInBits(), 32);
APInt RHSMask = APInt::getLowBitsSet(RHS.getValueSizeInBits(), 5);
if (SimplifyDemandedBits(LHS, LHSMask, DCI) ||
SimplifyDemandedBits(RHS, RHSMask, DCI)) {
if (N->getOpcode() != ISD::DELETED_NODE)
DCI.AddToWorklist(N);
if (SimplifyDemandedLowBitsHelper(0, 32) ||
SimplifyDemandedLowBitsHelper(1, 5))
return SDValue(N, 0);
}
return combineGREVI_GORCI(N, DCI.DAG);
}
case RISCVISD::SHFL:
case RISCVISD::UNSHFL: {
// Only the lower log2(Bitwidth) bits of the the shift amount are read.
SDValue ShAmt = N->getOperand(1);
unsigned BitWidth = ShAmt.getValueSizeInBits();
// Only the lower log2(Bitwidth)-1 bits of the the shift amount are read.
unsigned BitWidth = N->getOperand(1).getValueSizeInBits();
assert(isPowerOf2_32(BitWidth) && "Unexpected bit width");
APInt ShAmtMask(BitWidth, (BitWidth / 2) - 1);
if (SimplifyDemandedBits(ShAmt, ShAmtMask, DCI)) {
if (N->getOpcode() != ISD::DELETED_NODE)
DCI.AddToWorklist(N);
if (SimplifyDemandedLowBitsHelper(1, Log2_32(BitWidth) - 1))
return SDValue(N, 0);
}
break;
}
case RISCVISD::SHFLW:
case RISCVISD::UNSHFLW: {
// Only the lower 32 bits of LHS and lower 5 bits of RHS are read.
// Only the lower 32 bits of LHS and lower 4 bits of RHS are read.
SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);
APInt LHSMask = APInt::getLowBitsSet(LHS.getValueSizeInBits(), 32);
APInt RHSMask = APInt::getLowBitsSet(RHS.getValueSizeInBits(), 4);
if (SimplifyDemandedBits(LHS, LHSMask, DCI) ||
SimplifyDemandedBits(RHS, RHSMask, DCI)) {
if (N->getOpcode() != ISD::DELETED_NODE)
DCI.AddToWorklist(N);
if (SimplifyDemandedLowBitsHelper(0, 32) ||
SimplifyDemandedLowBitsHelper(1, 4))
return SDValue(N, 0);
}
break;
}
case RISCVISD::BCOMPRESSW:
case RISCVISD::BDECOMPRESSW: {
// Only the lower 32 bits of LHS and RHS are read.
SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);
APInt Mask = APInt::getLowBitsSet(LHS.getValueSizeInBits(), 32);
if (SimplifyDemandedBits(LHS, Mask, DCI) ||
SimplifyDemandedBits(RHS, Mask, DCI)) {
if (N->getOpcode() != ISD::DELETED_NODE)
DCI.AddToWorklist(N);
if (SimplifyDemandedLowBitsHelper(0, 32) ||
SimplifyDemandedLowBitsHelper(1, 32))
return SDValue(N, 0);
}
break;
}