From b7c2e577cc8f9f92b7ce206ea7d6cba3eaa3f98c Mon Sep 17 00:00:00 2001 From: Craig Topper Date: Thu, 1 Apr 2021 09:41:36 -0700 Subject: [PATCH] [RISCV] Add custom type legalization to form MULHSU when possible. There's no target independent ISD opcode for MULHSU, so custom legalize 2*XLen multiplies ourselves. We have to be a little careful to prefer MULHU or MULHSU. I thought about doing this in isel by pattern matching the (add (mul X, (srai Y, XLen-1)), (mulhu X, Y)) pattern. I decided against this because the add might become part of a chain of adds. I don't trust DAG combine not to reassociate with other adds making it difficult to find both pieces again. Reviewed By: asb Differential Revision: https://reviews.llvm.org/D99479 --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 70 ++++++++++++++++----- llvm/lib/Target/RISCV/RISCVISelLowering.h | 2 + llvm/lib/Target/RISCV/RISCVInstrInfoM.td | 3 +- llvm/test/CodeGen/RISCV/mul.ll | 10 +-- 4 files changed, 62 insertions(+), 23 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 16a781751017..6cafa2791ed6 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -219,20 +219,23 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setOperationAction(ISD::UDIV, XLenVT, Expand); setOperationAction(ISD::SREM, XLenVT, Expand); setOperationAction(ISD::UREM, XLenVT, Expand); - } + } else { + if (Subtarget.is64Bit()) { + setOperationAction(ISD::MUL, MVT::i32, Custom); + setOperationAction(ISD::MUL, MVT::i128, Custom); - if (Subtarget.is64Bit() && Subtarget.hasStdExtM()) { - setOperationAction(ISD::MUL, MVT::i32, Custom); - - setOperationAction(ISD::SDIV, MVT::i8, Custom); - setOperationAction(ISD::UDIV, MVT::i8, Custom); - setOperationAction(ISD::UREM, MVT::i8, Custom); - setOperationAction(ISD::SDIV, MVT::i16, Custom); - setOperationAction(ISD::UDIV, MVT::i16, Custom); - setOperationAction(ISD::UREM, MVT::i16, Custom); - setOperationAction(ISD::SDIV, MVT::i32, Custom); - setOperationAction(ISD::UDIV, MVT::i32, Custom); - setOperationAction(ISD::UREM, MVT::i32, Custom); + setOperationAction(ISD::SDIV, MVT::i8, Custom); + setOperationAction(ISD::UDIV, MVT::i8, Custom); + setOperationAction(ISD::UREM, MVT::i8, Custom); + setOperationAction(ISD::SDIV, MVT::i16, Custom); + setOperationAction(ISD::UDIV, MVT::i16, Custom); + setOperationAction(ISD::UREM, MVT::i16, Custom); + setOperationAction(ISD::SDIV, MVT::i32, Custom); + setOperationAction(ISD::UDIV, MVT::i32, Custom); + setOperationAction(ISD::UREM, MVT::i32, Custom); + } else { + setOperationAction(ISD::MUL, MVT::i64, Custom); + } } setOperationAction(ISD::SDIVREM, XLenVT, Expand); @@ -3868,9 +3871,47 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N, Results.push_back(RCW.getValue(2)); break; } + case ISD::MUL: { + unsigned Size = N->getSimpleValueType(0).getSizeInBits(); + unsigned XLen = Subtarget.getXLen(); + // This multiply needs to be expanded, try to use MULHSU+MUL if possible. + if (Size > XLen) { + assert(Size == (XLen * 2) && "Unexpected custom legalisation"); + SDValue LHS = N->getOperand(0); + SDValue RHS = N->getOperand(1); + APInt HighMask = APInt::getHighBitsSet(Size, XLen); + + bool LHSIsU = DAG.MaskedValueIsZero(LHS, HighMask); + bool RHSIsU = DAG.MaskedValueIsZero(RHS, HighMask); + // We need exactly one side to be unsigned. + if (LHSIsU == RHSIsU) + return; + + auto MakeMULPair = [&](SDValue S, SDValue U) { + MVT XLenVT = Subtarget.getXLenVT(); + S = DAG.getNode(ISD::TRUNCATE, DL, XLenVT, S); + U = DAG.getNode(ISD::TRUNCATE, DL, XLenVT, U); + SDValue Lo = DAG.getNode(ISD::MUL, DL, XLenVT, S, U); + SDValue Hi = DAG.getNode(RISCVISD::MULHSU, DL, XLenVT, S, U); + return DAG.getNode(ISD::BUILD_PAIR, DL, N->getValueType(0), Lo, Hi); + }; + + bool LHSIsS = DAG.ComputeNumSignBits(LHS) > XLen; + bool RHSIsS = DAG.ComputeNumSignBits(RHS) > XLen; + + // The other operand should be signed, but still prefer MULH when + // possible. + if (RHSIsU && LHSIsS && !RHSIsS) + Results.push_back(MakeMULPair(LHS, RHS)); + else if (LHSIsU && RHSIsS && !LHSIsS) + Results.push_back(MakeMULPair(RHS, LHS)); + + return; + } + LLVM_FALLTHROUGH; + } case ISD::ADD: case ISD::SUB: - case ISD::MUL: assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() && "Unexpected custom legalisation"); if (N->getOperand(1).getOpcode() == ISD::Constant) @@ -6784,6 +6825,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(BuildPairF64) NODE_NAME_CASE(SplitF64) NODE_NAME_CASE(TAIL) + NODE_NAME_CASE(MULHSU) NODE_NAME_CASE(SLLW) NODE_NAME_CASE(SRAW) NODE_NAME_CASE(SRLW) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h index 20e96c625339..b17aa1527b79 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.h +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -40,6 +40,8 @@ enum NodeType : unsigned { BuildPairF64, SplitF64, TAIL, + // Multiply high for signedxunsigned. + MULHSU, // RV64I shifts, directly matching the semantics of the named RISC-V // instructions. SLLW, diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoM.td b/llvm/lib/Target/RISCV/RISCVInstrInfoM.td index e841d7fdea0b..f654ed1949a4 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoM.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoM.td @@ -15,6 +15,7 @@ // RISC-V specific DAG Nodes. //===----------------------------------------------------------------------===// +def riscv_mulhsu : SDNode<"RISCVISD::MULHSU", SDTIntBinOp>; def riscv_divw : SDNode<"RISCVISD::DIVW", SDT_RISCVIntBinOpW>; def riscv_divuw : SDNode<"RISCVISD::DIVUW", SDT_RISCVIntBinOpW>; def riscv_remuw : SDNode<"RISCVISD::REMUW", SDT_RISCVIntBinOpW>; @@ -63,7 +64,7 @@ let Predicates = [HasStdExtM] in { def : PatGprGpr; def : PatGprGpr; def : PatGprGpr; -// No ISDOpcode for mulhsu +def : PatGprGpr; def : PatGprGpr; def : PatGprGpr; def : PatGprGpr; diff --git a/llvm/test/CodeGen/RISCV/mul.ll b/llvm/test/CodeGen/RISCV/mul.ll index 00df918d6f63..2260233a4559 100644 --- a/llvm/test/CodeGen/RISCV/mul.ll +++ b/llvm/test/CodeGen/RISCV/mul.ll @@ -398,10 +398,7 @@ define i32 @mulhsu(i32 %a, i32 %b) nounwind { ; ; RV32IM-LABEL: mulhsu: ; RV32IM: # %bb.0: -; RV32IM-NEXT: srai a2, a1, 31 -; RV32IM-NEXT: mulhu a1, a0, a1 -; RV32IM-NEXT: mul a0, a0, a2 -; RV32IM-NEXT: add a0, a1, a0 +; RV32IM-NEXT: mulhsu a0, a1, a0 ; RV32IM-NEXT: ret ; ; RV64I-LABEL: mulhsu: @@ -1423,10 +1420,7 @@ define i64 @mulhsu_i64(i64 %a, i64 %b) nounwind { ; ; RV64IM-LABEL: mulhsu_i64: ; RV64IM: # %bb.0: -; RV64IM-NEXT: srai a2, a1, 63 -; RV64IM-NEXT: mulhu a1, a0, a1 -; RV64IM-NEXT: mul a0, a0, a2 -; RV64IM-NEXT: add a0, a1, a0 +; RV64IM-NEXT: mulhsu a0, a1, a0 ; RV64IM-NEXT: ret %1 = zext i64 %a to i128 %2 = sext i64 %b to i128