diff --git a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp index 0bd0f28de6f9..f2ab1ec51a9d 100644 --- a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp +++ b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp @@ -2095,6 +2095,8 @@ HexagonTargetLowering::HexagonTargetLowering(const TargetMachine &TM, } setOperationAction(ISD::MUL, T, Custom); + setOperationAction(ISD::MULHS, T, Custom); + setOperationAction(ISD::MULHU, T, Custom); setOperationAction(ISD::SETCC, T, Custom); setOperationAction(ISD::BUILD_VECTOR, T, Custom); setOperationAction(ISD::INSERT_SUBVECTOR, T, Custom); @@ -3018,6 +3020,11 @@ HexagonTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { if (Subtarget.useHVXOps()) return LowerHvxMul(Op, DAG); break; + case ISD::MULHS: + case ISD::MULHU: + if (Subtarget.useHVXOps()) + return LowerHvxMulh(Op, DAG); + break; } return SDValue(); } diff --git a/llvm/lib/Target/Hexagon/HexagonISelLowering.h b/llvm/lib/Target/Hexagon/HexagonISelLowering.h index 66214a409b32..4330cfb7302f 100644 --- a/llvm/lib/Target/Hexagon/HexagonISelLowering.h +++ b/llvm/lib/Target/Hexagon/HexagonISelLowering.h @@ -360,6 +360,7 @@ namespace HexagonISD { SDValue LowerHvxExtractSubvector(SDValue Op, SelectionDAG &DAG) const; SDValue LowerHvxInsertSubvector(SDValue Op, SelectionDAG &DAG) const; SDValue LowerHvxMul(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerHvxMulh(SDValue Op, SelectionDAG &DAG) const; SDValue LowerHvxSetCC(SDValue Op, SelectionDAG &DAG) const; SDValue LowerHvxExtend(SDValue Op, SelectionDAG &DAG) const; diff --git a/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp b/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp index 6488b9aee0ab..acf8b3e1f27f 100644 --- a/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp +++ b/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp @@ -412,8 +412,7 @@ HexagonTargetLowering::LowerHvxInsertSubvector(SDValue Op, SelectionDAG &DAG) SDValue HexagonTargetLowering::LowerHvxMul(SDValue Op, SelectionDAG &DAG) const { MVT ResTy = ty(Op); - if (!ResTy.isVector()) - return SDValue(); + assert(ResTy.isVector()); const SDLoc &dl(Op); SmallVector ShuffMask; @@ -424,7 +423,7 @@ HexagonTargetLowering::LowerHvxMul(SDValue Op, SelectionDAG &DAG) const { switch (ElemTy.SimpleTy) { case MVT::i8: - case MVT::i16: { + case MVT::i16: { // V6_vmpyih // For i8 vectors Vs = (a0, a1, ...), Vt = (b0, b1, ...), // V6_vmpybv Vs, Vt produces a pair of i16 vectors Hi:Lo, // where Lo = (a0*b0, a2*b2, ...), Hi = (a1*b1, a3*b3, ...). @@ -463,6 +462,112 @@ HexagonTargetLowering::LowerHvxMul(SDValue Op, SelectionDAG &DAG) const { return SDValue(); } +SDValue +HexagonTargetLowering::LowerHvxMulh(SDValue Op, SelectionDAG &DAG) const { + MVT ResTy = ty(Op); + assert(ResTy.isVector()); + const SDLoc &dl(Op); + SmallVector ShuffMask; + + MVT ElemTy = ResTy.getVectorElementType(); + unsigned VecLen = ResTy.getVectorNumElements(); + SDValue Vs = Op.getOperand(0); + SDValue Vt = Op.getOperand(1); + bool IsSigned = Op.getOpcode() == ISD::MULHS; + + if (ElemTy == MVT::i8 || ElemTy == MVT::i16) { + // For i8 vectors Vs = (a0, a1, ...), Vt = (b0, b1, ...), + // V6_vmpybv Vs, Vt produces a pair of i16 vectors Hi:Lo, + // where Lo = (a0*b0, a2*b2, ...), Hi = (a1*b1, a3*b3, ...). + // For i16, use V6_vmpyhv, which behaves in an analogous way to + // V6_vmpybv: results Lo and Hi are products of even/odd elements + // respectively. + MVT ExtTy = typeExtElem(ResTy, 2); + unsigned MpyOpc = ElemTy == MVT::i8 + ? (IsSigned ? Hexagon::V6_vmpybv : Hexagon::V6_vmpyubv) + : (IsSigned ? Hexagon::V6_vmpyhv : Hexagon::V6_vmpyuhv); + SDValue M = getNode(MpyOpc, dl, ExtTy, {Vs, Vt}, DAG); + + // Discard low halves of the resulting values, collect the high halves. + for (unsigned I = 0; I < VecLen; I += 2) { + ShuffMask.push_back(I+1); // Pick even element. + ShuffMask.push_back(I+VecLen+1); // Pick odd element. + } + VectorPair P = opSplit(opCastElem(M, ElemTy, DAG), dl, DAG); + SDValue BS = getByteShuffle(dl, P.first, P.second, ShuffMask, DAG); + return DAG.getBitcast(ResTy, BS); + } + + assert(ElemTy == MVT::i32); + SDValue S16 = DAG.getConstant(16, dl, MVT::i32); + + if (IsSigned) { + // mulhs(Vs,Vt) = + // = [(Hi(Vs)*2^16 + Lo(Vs)) *s (Hi(Vt)*2^16 + Lo(Vt))] >> 32 + // = [Hi(Vs)*2^16 *s Hi(Vt)*2^16 + Hi(Vs) *su Lo(Vt)*2^16 + // + Lo(Vs) *us (Hi(Vt)*2^16 + Lo(Vt))] >> 32 + // = [Hi(Vs) *s Hi(Vt)*2^32 + Hi(Vs) *su Lo(Vt)*2^16 + // + Lo(Vs) *us Vt] >> 32 + // The low half of Lo(Vs)*Lo(Vt) will be discarded (it's not added to + // anything, so it cannot produce any carry over to higher bits), + // so everything in [] can be shifted by 16 without loss of precision. + // = [Hi(Vs) *s Hi(Vt)*2^16 + Hi(Vs)*su Lo(Vt) + Lo(Vs)*Vt >> 16] >> 16 + // = [Hi(Vs) *s Hi(Vt)*2^16 + Hi(Vs)*su Lo(Vt) + V6_vmpyewuh(Vs,Vt)] >> 16 + // Denote Hi(Vs) = Vs': + // = [Vs'*s Hi(Vt)*2^16 + Vs' *su Lo(Vt) + V6_vmpyewuh(Vt,Vs)] >> 16 + // = Vs'*s Hi(Vt) + (V6_vmpyiewuh(Vs',Vt) + V6_vmpyewuh(Vt,Vs)) >> 16 + SDValue T0 = getNode(Hexagon::V6_vmpyewuh, dl, ResTy, {Vt, Vs}, DAG); + // Get Vs': + SDValue S0 = getNode(Hexagon::V6_vasrw, dl, ResTy, {Vs, S16}, DAG); + SDValue T1 = getNode(Hexagon::V6_vmpyiewuh_acc, dl, ResTy, + {T0, S0, Vt}, DAG); + // Shift by 16: + SDValue S2 = getNode(Hexagon::V6_vasrw, dl, ResTy, {T1, S16}, DAG); + // Get Vs'*Hi(Vt): + SDValue T2 = getNode(Hexagon::V6_vmpyiowh, dl, ResTy, {S0, Vt}, DAG); + // Add: + SDValue T3 = DAG.getNode(ISD::ADD, dl, ResTy, {S2, T2}); + return T3; + } + + // Unsigned mulhw. (Would expansion using signed mulhw be better?) + + auto LoVec = [&DAG,ResTy,dl] (SDValue Pair) { + return DAG.getTargetExtractSubreg(Hexagon::vsub_lo, dl, ResTy, Pair); + }; + auto HiVec = [&DAG,ResTy,dl] (SDValue Pair) { + return DAG.getTargetExtractSubreg(Hexagon::vsub_hi, dl, ResTy, Pair); + }; + + MVT PairTy = typeJoin({ResTy, ResTy}); + SDValue P = getNode(Hexagon::V6_lvsplatw, dl, ResTy, + {DAG.getConstant(0x02020202, dl, MVT::i32)}, DAG); + // Multiply-unsigned halfwords: + // LoVec = Vs.uh[2i] * Vt.uh[2i], + // HiVec = Vs.uh[2i+1] * Vt.uh[2i+1] + SDValue T0 = getNode(Hexagon::V6_vmpyuhv, dl, PairTy, {Vs, Vt}, DAG); + // The low halves in the LoVec of the pair can be discarded. They are + // not added to anything (in the full-precision product), so they cannot + // produce a carry into the higher bits. + SDValue T1 = getNode(Hexagon::V6_vlsrw, dl, ResTy, {LoVec(T0), S16}, DAG); + // Swap low and high halves in Vt, and do the halfword multiplication + // to get products Vs.uh[2i] * Vt.uh[2i+1] and Vs.uh[2i+1] * Vt.uh[2i]. + SDValue D0 = getNode(Hexagon::V6_vdelta, dl, ResTy, {Vt, P}, DAG); + SDValue T2 = getNode(Hexagon::V6_vmpyuhv, dl, PairTy, {Vs, D0}, DAG); + // T2 has mixed products of halfwords: Lo(Vt)*Hi(Vs) and Hi(Vt)*Lo(Vs). + // These products are words, but cannot be added directly because the + // sums could overflow. Add these products, by halfwords, where each sum + // of a pair of halfwords gives a word. + SDValue T3 = getNode(Hexagon::V6_vadduhw, dl, PairTy, + {LoVec(T2), HiVec(T2)}, DAG); + // Add the high halfwords from the products of the low halfwords. + SDValue T4 = DAG.getNode(ISD::ADD, dl, ResTy, {T1, LoVec(T3)}); + SDValue T5 = getNode(Hexagon::V6_vlsrw, dl, ResTy, {T4, S16}, DAG); + SDValue T6 = DAG.getNode(ISD::ADD, dl, ResTy, {HiVec(T0), HiVec(T3)}); + SDValue T7 = DAG.getNode(ISD::ADD, dl, ResTy, {T5, T6}); + return T7; +} + SDValue HexagonTargetLowering::LowerHvxSetCC(SDValue Op, SelectionDAG &DAG) const { MVT VecTy = ty(Op.getOperand(0)); diff --git a/llvm/lib/Target/Hexagon/HexagonPatterns.td b/llvm/lib/Target/Hexagon/HexagonPatterns.td index 7355b2d8eb81..bf1b55b7b891 100644 --- a/llvm/lib/Target/Hexagon/HexagonPatterns.td +++ b/llvm/lib/Target/Hexagon/HexagonPatterns.td @@ -1274,6 +1274,56 @@ def: AccRRI_pat, I32, s32_0ImmPred>; def: AccRRI_pat, I32, s32_0ImmPred>; def: AccRRR_pat, I32, I32>; +// Mulh for vectors +// +def: Pat<(v2i32 (mulhu V2I32:$Rss, V2I32:$Rtt)), + (Combinew (M2_mpyu_up (HiReg $Rss), (HiReg $Rtt)), + (M2_mpyu_up (LoReg $Rss), (LoReg $Rtt)))>; + +def: Pat<(v2i32 (mulhs V2I32:$Rs, V2I32:$Rt)), + (Combinew (M2_mpy_up (HiReg $Rs), (HiReg $Rt)), + (M2_mpy_up (LoReg $Rt), (LoReg $Rt)))>; + +def Mulhub: + OutPatFrag<(ops node:$Rss, node:$Rtt), + (Combinew (S2_vtrunohb (M5_vmpybuu (HiReg $Rss), (HiReg $Rtt))), + (S2_vtrunohb (M5_vmpybuu (LoReg $Rss), (LoReg $Rtt))))>; + +// Equivalent of byte-wise arithmetic shift right by 7 in v8i8. +def Asr7: + OutPatFrag<(ops node:$Rss), (C2_mask (C2_not (A4_vcmpbgti $Rss, 0)))>; + +def: Pat<(v8i8 (mulhu V8I8:$Rss, V8I8:$Rtt)), + (Mulhub $Rss, $Rtt)>; + +def: Pat<(v8i8 (mulhs V8I8:$Rss, V8I8:$Rtt)), + (A2_vsubub + (Mulhub $Rss, $Rtt), + (A2_vaddub (A2_andp V8I8:$Rss, (Asr7 $Rtt)), + (A2_andp V8I8:$Rtt, (Asr7 $Rss))))>; + +def Mpysh: + OutPatFrag<(ops node:$Rs, node:$Rt), (M2_vmpy2s_s0 $Rs, $Rt)>; +def Mpyshh: + OutPatFrag<(ops node:$Rss, node:$Rtt), (Mpysh (HiReg $Rss), (HiReg $Rtt))>; +def Mpyshl: + OutPatFrag<(ops node:$Rss, node:$Rtt), (Mpysh (LoReg $Rss), (LoReg $Rtt))>; + +def Mulhsh: + OutPatFrag<(ops node:$Rss, node:$Rtt), + (Combinew (A2_combine_hh (HiReg (Mpyshh $Rss, $Rtt)), + (LoReg (Mpyshh $Rss, $Rtt))), + (A2_combine_hh (HiReg (Mpyshl $Rss, $Rtt)), + (LoReg (Mpyshl $Rss, $Rtt))))>; + +def: Pat<(v4i16 (mulhs V4I16:$Rss, V4I16:$Rtt)), (Mulhsh $Rss, $Rtt)>; + +def: Pat<(v4i16 (mulhu V4I16:$Rss, V4I16:$Rtt)), + (A2_vaddh + (Mulhsh $Rss, $Rtt), + (A2_vaddh (A2_andp V4I16:$Rss, (S2_asr_i_vh $Rtt, 15)), + (A2_andp V4I16:$Rtt, (S2_asr_i_vh $Rss, 15))))>; + def: Pat<(ineg (mul I32:$Rs, u8_0ImmPred:$u8)), (M2_mpysin IntRegs:$Rs, imm:$u8)>;