diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp index 3a3c6428add6..d2d915617627 100644 --- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp +++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp @@ -2184,13 +2184,16 @@ SDNode *X86DAGToDAGISel::Select(SDNode *Node) { SDValue N1 = Node->getOperand(1); bool isSigned = Opcode == ISD::SMUL_LOHI; + bool hasBMI2 = Subtarget->hasBMI2(); if (!isSigned) { switch (NVT.getSimpleVT().SimpleTy) { default: llvm_unreachable("Unsupported VT!"); case MVT::i8: Opc = X86::MUL8r; MOpc = X86::MUL8m; break; case MVT::i16: Opc = X86::MUL16r; MOpc = X86::MUL16m; break; - case MVT::i32: Opc = X86::MUL32r; MOpc = X86::MUL32m; break; - case MVT::i64: Opc = X86::MUL64r; MOpc = X86::MUL64m; break; + case MVT::i32: Opc = hasBMI2 ? X86::MULX32rr : X86::MUL32r; + MOpc = hasBMI2 ? X86::MULX32rm : X86::MUL32m; break; + case MVT::i64: Opc = hasBMI2 ? X86::MULX64rr : X86::MUL64r; + MOpc = hasBMI2 ? X86::MULX64rm : X86::MUL64m; break; } } else { switch (NVT.getSimpleVT().SimpleTy) { @@ -2202,13 +2205,31 @@ SDNode *X86DAGToDAGISel::Select(SDNode *Node) { } } - unsigned LoReg, HiReg; - switch (NVT.getSimpleVT().SimpleTy) { - default: llvm_unreachable("Unsupported VT!"); - case MVT::i8: LoReg = X86::AL; HiReg = X86::AH; break; - case MVT::i16: LoReg = X86::AX; HiReg = X86::DX; break; - case MVT::i32: LoReg = X86::EAX; HiReg = X86::EDX; break; - case MVT::i64: LoReg = X86::RAX; HiReg = X86::RDX; break; + unsigned SrcReg, LoReg, HiReg; + switch (Opc) { + default: llvm_unreachable("Unknown MUL opcode!"); + case X86::IMUL8r: + case X86::MUL8r: + SrcReg = LoReg = X86::AL; HiReg = X86::AH; + break; + case X86::IMUL16r: + case X86::MUL16r: + SrcReg = LoReg = X86::AX; HiReg = X86::DX; + break; + case X86::IMUL32r: + case X86::MUL32r: + SrcReg = LoReg = X86::EAX; HiReg = X86::EDX; + break; + case X86::IMUL64r: + case X86::MUL64r: + SrcReg = LoReg = X86::RAX; HiReg = X86::RDX; + break; + case X86::MULX32rr: + SrcReg = X86::EDX; LoReg = HiReg = 0; + break; + case X86::MULX64rr: + SrcReg = X86::RDX; LoReg = HiReg = 0; + break; } SDValue Tmp0, Tmp1, Tmp2, Tmp3, Tmp4; @@ -2220,22 +2241,47 @@ SDNode *X86DAGToDAGISel::Select(SDNode *Node) { std::swap(N0, N1); } - SDValue InFlag = CurDAG->getCopyToReg(CurDAG->getEntryNode(), dl, LoReg, + SDValue InFlag = CurDAG->getCopyToReg(CurDAG->getEntryNode(), dl, SrcReg, N0, SDValue()).getValue(1); + SDValue ResHi, ResLo; if (foldedLoad) { + SDValue Chain; SDValue Ops[] = { Tmp0, Tmp1, Tmp2, Tmp3, Tmp4, N1.getOperand(0), InFlag }; - SDNode *CNode = - CurDAG->getMachineNode(MOpc, dl, MVT::Other, MVT::Glue, Ops, - array_lengthof(Ops)); - InFlag = SDValue(CNode, 1); + if (MOpc == X86::MULX32rm || MOpc == X86::MULX64rm) { + SDVTList VTs = CurDAG->getVTList(NVT, NVT, MVT::Other, MVT::Glue); + SDNode *CNode = CurDAG->getMachineNode(MOpc, dl, VTs, Ops, + array_lengthof(Ops)); + ResHi = SDValue(CNode, 0); + ResLo = SDValue(CNode, 1); + Chain = SDValue(CNode, 2); + InFlag = SDValue(CNode, 3); + } else { + SDVTList VTs = CurDAG->getVTList(MVT::Other, MVT::Glue); + SDNode *CNode = CurDAG->getMachineNode(MOpc, dl, VTs, Ops, + array_lengthof(Ops)); + Chain = SDValue(CNode, 0); + InFlag = SDValue(CNode, 1); + } // Update the chain. - ReplaceUses(N1.getValue(1), SDValue(CNode, 0)); + ReplaceUses(N1.getValue(1), Chain); } else { - SDNode *CNode = CurDAG->getMachineNode(Opc, dl, MVT::Glue, N1, InFlag); - InFlag = SDValue(CNode, 0); + SDValue Ops[] = { N1, InFlag }; + if (Opc == X86::MULX32rr || Opc == X86::MULX64rr) { + SDVTList VTs = CurDAG->getVTList(NVT, NVT, MVT::Glue); + SDNode *CNode = CurDAG->getMachineNode(Opc, dl, VTs, Ops, + array_lengthof(Ops)); + ResHi = SDValue(CNode, 0); + ResLo = SDValue(CNode, 1); + InFlag = SDValue(CNode, 2); + } else { + SDVTList VTs = CurDAG->getVTList(MVT::Glue); + SDNode *CNode = CurDAG->getMachineNode(Opc, dl, VTs, Ops, + array_lengthof(Ops)); + InFlag = SDValue(CNode, 0); + } } // Prevent use of AH in a REX instruction by referencing AX instead. @@ -2260,19 +2306,25 @@ SDNode *X86DAGToDAGISel::Select(SDNode *Node) { } // Copy the low half of the result, if it is needed. if (!SDValue(Node, 0).use_empty()) { - SDValue Result = CurDAG->getCopyFromReg(CurDAG->getEntryNode(), dl, - LoReg, NVT, InFlag); - InFlag = Result.getValue(2); - ReplaceUses(SDValue(Node, 0), Result); - DEBUG(dbgs() << "=> "; Result.getNode()->dump(CurDAG); dbgs() << '\n'); + if (ResLo.getNode() == 0) { + assert(LoReg && "Register for low half is not defined!"); + ResLo = CurDAG->getCopyFromReg(CurDAG->getEntryNode(), dl, LoReg, NVT, + InFlag); + InFlag = ResLo.getValue(2); + } + ReplaceUses(SDValue(Node, 0), ResLo); + DEBUG(dbgs() << "=> "; ResLo.getNode()->dump(CurDAG); dbgs() << '\n'); } // Copy the high half of the result, if it is needed. if (!SDValue(Node, 1).use_empty()) { - SDValue Result = CurDAG->getCopyFromReg(CurDAG->getEntryNode(), dl, - HiReg, NVT, InFlag); - InFlag = Result.getValue(2); - ReplaceUses(SDValue(Node, 1), Result); - DEBUG(dbgs() << "=> "; Result.getNode()->dump(CurDAG); dbgs() << '\n'); + if (ResHi.getNode() == 0) { + assert(HiReg && "Register for high half is not defined!"); + ResHi = CurDAG->getCopyFromReg(CurDAG->getEntryNode(), dl, HiReg, NVT, + InFlag); + InFlag = ResHi.getValue(2); + } + ReplaceUses(SDValue(Node, 1), ResHi); + DEBUG(dbgs() << "=> "; ResHi.getNode()->dump(CurDAG); dbgs() << '\n'); } return NULL; diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp index 820ac06dc1c7..f575e8018449 100644 --- a/llvm/lib/Target/X86/X86InstrInfo.cpp +++ b/llvm/lib/Target/X86/X86InstrInfo.cpp @@ -1140,6 +1140,10 @@ X86InstrInfo::X86InstrInfo(X86TargetMachine &tm) { X86::VFMSUBADDPD4rr, X86::VFMSUBADDPD4mr, TB_ALIGN_16 }, { X86::VFMSUBADDPS4rrY, X86::VFMSUBADDPS4mrY, TB_ALIGN_32 }, { X86::VFMSUBADDPD4rrY, X86::VFMSUBADDPD4mrY, TB_ALIGN_32 }, + + // BMI/BMI2 foldable instructions + { X86::MULX32rr, X86::MULX32rm, 0 }, + { X86::MULX64rr, X86::MULX64rm, 0 }, }; for (unsigned i = 0, e = array_lengthof(OpTbl2); i != e; ++i) { diff --git a/llvm/test/CodeGen/X86/mulx32.ll b/llvm/test/CodeGen/X86/mulx32.ll new file mode 100644 index 000000000000..6a6450d500e2 --- /dev/null +++ b/llvm/test/CodeGen/X86/mulx32.ll @@ -0,0 +1,22 @@ +; RUN: llc -mcpu=core-avx2 -march=x86 < %s | FileCheck %s + +define i64 @f1(i32 %a, i32 %b) { + %x = zext i32 %a to i64 + %y = zext i32 %b to i64 + %r = mul i64 %x, %y +; CHECK: f1 +; CHECK: mulxl +; CHECK: ret + ret i64 %r +} + +define i64 @f2(i32 %a, i32* %p) { + %b = load i32* %p + %x = zext i32 %a to i64 + %y = zext i32 %b to i64 + %r = mul i64 %x, %y +; CHECK: f1 +; CHECK: mulxl ({{.+}}), %{{.+}}, %{{.+}} +; CHECK: ret + ret i64 %r +} diff --git a/llvm/test/CodeGen/X86/mulx64.ll b/llvm/test/CodeGen/X86/mulx64.ll new file mode 100644 index 000000000000..c59adc6069d7 --- /dev/null +++ b/llvm/test/CodeGen/X86/mulx64.ll @@ -0,0 +1,22 @@ +; RUN: llc -mcpu=core-avx2 -march=x86-64 < %s | FileCheck %s + +define i128 @f1(i64 %a, i64 %b) { + %x = zext i64 %a to i128 + %y = zext i64 %b to i128 + %r = mul i128 %x, %y +; CHECK: f1 +; CHECK: mulxq +; CHECK: ret + ret i128 %r +} + +define i128 @f2(i64 %a, i64* %p) { + %b = load i64* %p + %x = zext i64 %a to i128 + %y = zext i64 %b to i128 + %r = mul i128 %x, %y +; CHECK: f1 +; CHECK: mulxq ({{.+}}), %{{.+}}, %{{.+}} +; CHECK: ret + ret i128 %r +}