[RISCV] Fold (sra (add (shl X, 32), C1), 32 - C) -> (shl (sext_inreg (add X, C1), C)

Similar for a subtract with a constant left hand side.

(sra (add (shl X, 32), C1<<32), 32) is the canonical IR from InstCombine
for (sext (add (trunc X to i32), 32) to i32).

For RISCV, we should lower this as addiw which means turning it into
(sext_inreg (add X, C1)).

There is an existing DAG combine to convert back to (sext (add (trunc X
to i32), 32) to i32), but it requires isTruncateFree to return true
and for i32 to be a legal type as it used sign_extend and truncate
nodes. So that doesn't work for RISCV.

If the outer sra happens be used by a shl by constant, it will be
folded and the shift amount of the sra will be changed before we
can do our own DAG combine. This requires us to match the more
general pattern and restore the shl.

I had wanted to do this as a separate (add (shl X, 32), C1<<32) ->
(shl (add X, C1), 32) combine, but that hit an infinite loop for some
values of C1.

Reviewed By: asb

Differential Revision: https://reviews.llvm.org/D128869
This commit is contained in:
Craig Topper 2022-06-30 08:52:57 -07:00
parent 9ace5af049
commit 51d672946e
2 changed files with 66 additions and 33 deletions

View File

@ -8532,6 +8532,10 @@ static unsigned negateFMAOpcode(unsigned Opcode, bool NegMul, bool NegAcc) {
// Combine (sra (shl X, 32), 32 - C) -> (shl (sext_inreg X, i32), C)
// FIXME: Should this be a generic combine? There's a similar combine on X86.
//
// Also try these folds where an add or sub is in the middle.
// (sra (add (shl X, 32), C1), 32 - C) -> (shl (sext_inreg (add X, C1), C)
// (sra (sub C1, (shl X, 32)), 32 - C) -> (shl (sext_inreg (sub C1, X), C)
static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
assert(N->getOpcode() == ISD::SRA && "Unexpected opcode");
@ -8539,21 +8543,63 @@ static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
if (N->getValueType(0) != MVT::i64 || !Subtarget.is64Bit())
return SDValue();
auto *C = dyn_cast<ConstantSDNode>(N->getOperand(1));
if (!C || C->getZExtValue() >= 32)
auto *ShAmtC = dyn_cast<ConstantSDNode>(N->getOperand(1));
if (!ShAmtC || ShAmtC->getZExtValue() > 32)
return SDValue();
SDValue N0 = N->getOperand(0);
if (N0.getOpcode() != ISD::SHL || !N0.hasOneUse() ||
!isa<ConstantSDNode>(N0.getOperand(1)) ||
N0.getConstantOperandVal(1) != 32)
SDValue Shl;
ConstantSDNode *AddC = nullptr;
// We might have an ADD or SUB between the SRA and SHL.
bool IsAdd = N0.getOpcode() == ISD::ADD;
if ((IsAdd || N0.getOpcode() == ISD::SUB)) {
if (!N0.hasOneUse())
return SDValue();
// Other operand needs to be a constant we can modify.
AddC = dyn_cast<ConstantSDNode>(N0.getOperand(IsAdd ? 1 : 0));
if (!AddC)
return SDValue();
// AddC needs to have at least 32 trailing zeros.
if (AddC->getAPIntValue().countTrailingZeros() < 32)
return SDValue();
Shl = N0.getOperand(IsAdd ? 0 : 1);
} else {
// Not an ADD or SUB.
Shl = N0;
}
// Look for a shift left by 32.
if (Shl.getOpcode() != ISD::SHL || !Shl.hasOneUse() ||
!isa<ConstantSDNode>(Shl.getOperand(1)) ||
Shl.getConstantOperandVal(1) != 32)
return SDValue();
SDLoc DL(N);
SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64,
N0.getOperand(0), DAG.getValueType(MVT::i32));
return DAG.getNode(ISD::SHL, DL, MVT::i64, SExt,
DAG.getConstant(32 - C->getZExtValue(), DL, MVT::i64));
SDValue In = Shl.getOperand(0);
// If we looked through an ADD or SUB, we need to rebuild it with the shifted
// constant.
if (AddC) {
SDValue ShiftedAddC =
DAG.getConstant(AddC->getAPIntValue().lshr(32), DL, MVT::i64);
if (IsAdd)
In = DAG.getNode(ISD::ADD, DL, MVT::i64, In, ShiftedAddC);
else
In = DAG.getNode(ISD::SUB, DL, MVT::i64, ShiftedAddC, In);
}
SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64, In,
DAG.getValueType(MVT::i32));
if (ShAmtC->getZExtValue() == 32)
return SExt;
return DAG.getNode(
ISD::SHL, DL, MVT::i64, SExt,
DAG.getConstant(32 - ShAmtC->getZExtValue(), DL, MVT::i64));
}
SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,

View File

@ -84,11 +84,7 @@ define i64 @test6(i32 signext %a, i32 signext %b) nounwind {
define i64 @test7(i32* %0, i64 %1) {
; RV64I-LABEL: test7:
; RV64I: # %bb.0:
; RV64I-NEXT: slli a0, a1, 32
; RV64I-NEXT: li a1, 1
; RV64I-NEXT: slli a1, a1, 32
; RV64I-NEXT: add a0, a0, a1
; RV64I-NEXT: srai a0, a0, 32
; RV64I-NEXT: addiw a0, a1, 1
; RV64I-NEXT: ret
%3 = shl i64 %1, 32
%4 = add i64 %3, 4294967296
@ -102,11 +98,8 @@ define i64 @test7(i32* %0, i64 %1) {
define i64 @test8(i32* %0, i64 %1) {
; RV64I-LABEL: test8:
; RV64I: # %bb.0:
; RV64I-NEXT: slli a0, a1, 32
; RV64I-NEXT: li a1, 1
; RV64I-NEXT: slli a1, a1, 32
; RV64I-NEXT: sub a0, a1, a0
; RV64I-NEXT: srai a0, a0, 32
; RV64I-NEXT: li a0, 1
; RV64I-NEXT: subw a0, a0, a1
; RV64I-NEXT: ret
%3 = mul i64 %1, -4294967296
%4 = add i64 %3, 4294967296
@ -119,11 +112,10 @@ define i64 @test8(i32* %0, i64 %1) {
define signext i32 @test9(i32* %0, i64 %1) {
; RV64I-LABEL: test9:
; RV64I: # %bb.0:
; RV64I-NEXT: slli a1, a1, 32
; RV64I-NEXT: lui a2, 4097
; RV64I-NEXT: slli a2, a2, 20
; RV64I-NEXT: add a1, a1, a2
; RV64I-NEXT: srai a1, a1, 30
; RV64I-NEXT: lui a2, 1
; RV64I-NEXT: addiw a2, a2, 1
; RV64I-NEXT: addw a1, a1, a2
; RV64I-NEXT: slli a1, a1, 2
; RV64I-NEXT: add a0, a0, a1
; RV64I-NEXT: lw a0, 0(a0)
; RV64I-NEXT: ret
@ -140,12 +132,10 @@ define signext i32 @test9(i32* %0, i64 %1) {
define signext i32 @test10(i32* %0, i64 %1) {
; RV64I-LABEL: test10:
; RV64I: # %bb.0:
; RV64I-NEXT: slli a1, a1, 32
; RV64I-NEXT: lui a2, 30141
; RV64I-NEXT: addiw a2, a2, -747
; RV64I-NEXT: slli a2, a2, 32
; RV64I-NEXT: sub a1, a2, a1
; RV64I-NEXT: srai a1, a1, 30
; RV64I-NEXT: subw a1, a2, a1
; RV64I-NEXT: slli a1, a1, 2
; RV64I-NEXT: add a0, a0, a1
; RV64I-NEXT: lw a0, 0(a0)
; RV64I-NEXT: ret
@ -160,11 +150,8 @@ define signext i32 @test10(i32* %0, i64 %1) {
define i64 @test11(i32* %0, i64 %1) {
; RV64I-LABEL: test11:
; RV64I: # %bb.0:
; RV64I-NEXT: slli a0, a1, 32
; RV64I-NEXT: li a1, -1
; RV64I-NEXT: slli a1, a1, 63
; RV64I-NEXT: sub a0, a1, a0
; RV64I-NEXT: srai a0, a0, 32
; RV64I-NEXT: lui a0, 524288
; RV64I-NEXT: subw a0, a0, a1
; RV64I-NEXT: ret
%3 = mul i64 %1, -4294967296
%4 = add i64 %3, 9223372036854775808 ;0x8000'0000'0000'0000