[RISCV] Add custom type legalization to form MULHSU when possible.

There's no target independent ISD opcode for MULHSU, so custom
legalize 2*XLen multiplies ourselves. We have to be a little
careful to prefer MULHU or MULHSU.

I thought about doing this in isel by pattern matching the
(add (mul X, (srai Y, XLen-1)), (mulhu X, Y)) pattern. I decided
against this because the add might become part of a chain of adds.
I don't trust DAG combine not to reassociate with other adds making
it difficult to find both pieces again.

Reviewed By: asb

Differential Revision: https://reviews.llvm.org/D99479
This commit is contained in:
Craig Topper 2021-04-01 09:41:36 -07:00
parent dadcd940f0
commit b7c2e577cc
4 changed files with 62 additions and 23 deletions

View File

@ -219,10 +219,10 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::UDIV, XLenVT, Expand);
setOperationAction(ISD::SREM, XLenVT, Expand);
setOperationAction(ISD::UREM, XLenVT, Expand);
}
if (Subtarget.is64Bit() && Subtarget.hasStdExtM()) {
} else {
if (Subtarget.is64Bit()) {
setOperationAction(ISD::MUL, MVT::i32, Custom);
setOperationAction(ISD::MUL, MVT::i128, Custom);
setOperationAction(ISD::SDIV, MVT::i8, Custom);
setOperationAction(ISD::UDIV, MVT::i8, Custom);
@ -233,6 +233,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::SDIV, MVT::i32, Custom);
setOperationAction(ISD::UDIV, MVT::i32, Custom);
setOperationAction(ISD::UREM, MVT::i32, Custom);
} else {
setOperationAction(ISD::MUL, MVT::i64, Custom);
}
}
setOperationAction(ISD::SDIVREM, XLenVT, Expand);
@ -3868,9 +3871,47 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
Results.push_back(RCW.getValue(2));
break;
}
case ISD::MUL: {
unsigned Size = N->getSimpleValueType(0).getSizeInBits();
unsigned XLen = Subtarget.getXLen();
// This multiply needs to be expanded, try to use MULHSU+MUL if possible.
if (Size > XLen) {
assert(Size == (XLen * 2) && "Unexpected custom legalisation");
SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);
APInt HighMask = APInt::getHighBitsSet(Size, XLen);
bool LHSIsU = DAG.MaskedValueIsZero(LHS, HighMask);
bool RHSIsU = DAG.MaskedValueIsZero(RHS, HighMask);
// We need exactly one side to be unsigned.
if (LHSIsU == RHSIsU)
return;
auto MakeMULPair = [&](SDValue S, SDValue U) {
MVT XLenVT = Subtarget.getXLenVT();
S = DAG.getNode(ISD::TRUNCATE, DL, XLenVT, S);
U = DAG.getNode(ISD::TRUNCATE, DL, XLenVT, U);
SDValue Lo = DAG.getNode(ISD::MUL, DL, XLenVT, S, U);
SDValue Hi = DAG.getNode(RISCVISD::MULHSU, DL, XLenVT, S, U);
return DAG.getNode(ISD::BUILD_PAIR, DL, N->getValueType(0), Lo, Hi);
};
bool LHSIsS = DAG.ComputeNumSignBits(LHS) > XLen;
bool RHSIsS = DAG.ComputeNumSignBits(RHS) > XLen;
// The other operand should be signed, but still prefer MULH when
// possible.
if (RHSIsU && LHSIsS && !RHSIsS)
Results.push_back(MakeMULPair(LHS, RHS));
else if (LHSIsU && RHSIsS && !LHSIsS)
Results.push_back(MakeMULPair(RHS, LHS));
return;
}
LLVM_FALLTHROUGH;
}
case ISD::ADD:
case ISD::SUB:
case ISD::MUL:
assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
"Unexpected custom legalisation");
if (N->getOperand(1).getOpcode() == ISD::Constant)
@ -6784,6 +6825,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(BuildPairF64)
NODE_NAME_CASE(SplitF64)
NODE_NAME_CASE(TAIL)
NODE_NAME_CASE(MULHSU)
NODE_NAME_CASE(SLLW)
NODE_NAME_CASE(SRAW)
NODE_NAME_CASE(SRLW)

View File

@ -40,6 +40,8 @@ enum NodeType : unsigned {
BuildPairF64,
SplitF64,
TAIL,
// Multiply high for signedxunsigned.
MULHSU,
// RV64I shifts, directly matching the semantics of the named RISC-V
// instructions.
SLLW,

View File

@ -15,6 +15,7 @@
// RISC-V specific DAG Nodes.
//===----------------------------------------------------------------------===//
def riscv_mulhsu : SDNode<"RISCVISD::MULHSU", SDTIntBinOp>;
def riscv_divw : SDNode<"RISCVISD::DIVW", SDT_RISCVIntBinOpW>;
def riscv_divuw : SDNode<"RISCVISD::DIVUW", SDT_RISCVIntBinOpW>;
def riscv_remuw : SDNode<"RISCVISD::REMUW", SDT_RISCVIntBinOpW>;
@ -63,7 +64,7 @@ let Predicates = [HasStdExtM] in {
def : PatGprGpr<mul, MUL>;
def : PatGprGpr<mulhs, MULH>;
def : PatGprGpr<mulhu, MULHU>;
// No ISDOpcode for mulhsu
def : PatGprGpr<riscv_mulhsu, MULHSU>;
def : PatGprGpr<sdiv, DIV>;
def : PatGprGpr<udiv, DIVU>;
def : PatGprGpr<srem, REM>;

View File

@ -398,10 +398,7 @@ define i32 @mulhsu(i32 %a, i32 %b) nounwind {
;
; RV32IM-LABEL: mulhsu:
; RV32IM: # %bb.0:
; RV32IM-NEXT: srai a2, a1, 31
; RV32IM-NEXT: mulhu a1, a0, a1
; RV32IM-NEXT: mul a0, a0, a2
; RV32IM-NEXT: add a0, a1, a0
; RV32IM-NEXT: mulhsu a0, a1, a0
; RV32IM-NEXT: ret
;
; RV64I-LABEL: mulhsu:
@ -1423,10 +1420,7 @@ define i64 @mulhsu_i64(i64 %a, i64 %b) nounwind {
;
; RV64IM-LABEL: mulhsu_i64:
; RV64IM: # %bb.0:
; RV64IM-NEXT: srai a2, a1, 63
; RV64IM-NEXT: mulhu a1, a0, a1
; RV64IM-NEXT: mul a0, a0, a2
; RV64IM-NEXT: add a0, a1, a0
; RV64IM-NEXT: mulhsu a0, a1, a0
; RV64IM-NEXT: ret
%1 = zext i64 %a to i128
%2 = sext i64 %b to i128