From 24dfba8d507e383eeccfc0cfa30ad6340f71a377 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper@sifive.com>
Date: Fri, 6 Aug 2021 22:19:38 -0700
Subject: [PATCH] [X86] Teach shouldSinkOperands to recognize pmuldq/pmuludq
 patterns.

The IR for pmuldq/pmuludq intrinsics uses a sext_inreg/zext_inreg
pattern on the inputs. Ideally we pattern match these away during
isel. It is possible for LICM or other middle end optimizations
to separate the extend from the mul. This prevents SelectionDAG
from removing it or depending on how the extend is lowered, we
may not be able to generate an AssertSExt/AssertZExt in the
mul basic block. This will prevent pmuldq/pmuludq from being
formed at all.

This patch teaches shouldSinkOperands to recognize this so
that CodeGenPrepare will clone the extend into the same basic
block as the mul.

Fixes PR51371.

Differential Revision: https://reviews.llvm.org/D107689
---
 llvm/lib/Target/X86/X86ISelLowering.cpp | 33 ++++++++++++++++++++++++-
 llvm/test/CodeGen/X86/pr51371.ll        | 28 ++++-----------------
 2 files changed, 37 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index a147993c1ef8..a316f2430751 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -48,9 +48,10 @@
 #include "llvm/IR/Function.h"
 #include "llvm/IR/GlobalAlias.h"
 #include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Intrinsics.h"
-#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/MC/MCAsmInfo.h"
 #include "llvm/MC/MCContext.h"
 #include "llvm/MC/MCExpr.h"
@@ -32091,6 +32092,36 @@ bool X86TargetLowering::isZExtFree(SDValue Val, EVT VT2) const {
 
 bool X86TargetLowering::shouldSinkOperands(Instruction *I,
                                            SmallVectorImpl<Use *> &Ops) const {
+  using namespace llvm::PatternMatch;
+
+  FixedVectorType *VTy = dyn_cast<FixedVectorType>(I->getType());
+  if (!VTy)
+    return false;
+
+  if (I->getOpcode() == Instruction::Mul &&
+      VTy->getElementType()->isIntegerTy(64)) {
+    for (auto &Op : I->operands()) {
+      // Make sure we are not already sinking this operand
+      if (any_of(Ops, [&](Use *U) { return U->get() == Op; }))
+        continue;
+
+      // Look for PMULDQ pattern where the input is a sext_inreg from vXi32 or
+      // the PMULUDQ pattern where the input is a zext_inreg from vXi32.
+      if (Subtarget.hasSSE41() &&
+          match(Op.get(), m_AShr(m_Shl(m_Value(), m_SpecificInt(32)),
+                                 m_SpecificInt(32)))) {
+        Ops.push_back(&cast<Instruction>(Op)->getOperandUse(0));
+        Ops.push_back(&Op);
+      } else if (Subtarget.hasSSE2() &&
+                 match(Op.get(),
+                       m_And(m_Value(), m_SpecificInt(UINT64_C(0xffffffff))))) {
+        Ops.push_back(&Op);
+      }
+    }
+
+    return !Ops.empty();
+  }
+
   // A uniform shift amount in a vector shift or funnel shift may be much
   // cheaper than a generic variable vector shift, so make that pattern visible
   // to SDAG by sinking the shuffle instruction next to the shift.
diff --git a/llvm/test/CodeGen/X86/pr51371.ll b/llvm/test/CodeGen/X86/pr51371.ll
index 6f3edd43b8f1..12915e298f8b 100644
--- a/llvm/test/CodeGen/X86/pr51371.ll
+++ b/llvm/test/CodeGen/X86/pr51371.ll
@@ -8,28 +8,12 @@ define void @pmuldq(<2 x i64>* nocapture %0, i32 %1, i64 %2) {
 ; CHECK-NEXT:    je .LBB0_3
 ; CHECK-NEXT:  # %bb.1:
 ; CHECK-NEXT:    movd %esi, %xmm0
-; CHECK-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[0,0,0,0]
-; CHECK-NEXT:    movdqa %xmm1, %xmm0
-; CHECK-NEXT:    psllq $32, %xmm0
-; CHECK-NEXT:    psrad $31, %xmm0
-; CHECK-NEXT:    pblendw {{.*#+}} xmm0 = xmm1[0,1],xmm0[2,3],xmm1[4,5],xmm0[6,7]
-; CHECK-NEXT:    movdqa %xmm0, %xmm1
-; CHECK-NEXT:    psrlq $32, %xmm1
+; CHECK-NEXT:    pshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
 ; CHECK-NEXT:    .p2align 4, 0x90
 ; CHECK-NEXT:  .LBB0_2: # =>This Inner Loop Header: Depth=1
-; CHECK-NEXT:    movdqa (%rdi), %xmm2
-; CHECK-NEXT:    movdqa %xmm2, %xmm3
-; CHECK-NEXT:    psllq $32, %xmm3
-; CHECK-NEXT:    psrad $31, %xmm3
-; CHECK-NEXT:    movdqa %xmm2, %xmm4
-; CHECK-NEXT:    pmuludq %xmm1, %xmm4
-; CHECK-NEXT:    psrlq $32, %xmm3
-; CHECK-NEXT:    pmuludq %xmm0, %xmm3
-; CHECK-NEXT:    paddq %xmm4, %xmm3
-; CHECK-NEXT:    psllq $32, %xmm3
-; CHECK-NEXT:    pmuludq %xmm0, %xmm2
-; CHECK-NEXT:    paddq %xmm3, %xmm2
-; CHECK-NEXT:    movdqa %xmm2, (%rdi)
+; CHECK-NEXT:    movdqa (%rdi), %xmm1
+; CHECK-NEXT:    pmuldq %xmm0, %xmm1
+; CHECK-NEXT:    movdqa %xmm1, (%rdi)
 ; CHECK-NEXT:    addq $16, %rdi
 ; CHECK-NEXT:    decq %rdx
 ; CHECK-NEXT:    jne .LBB0_2
@@ -66,9 +50,7 @@ define void @pmuludq(<2 x i64>* nocapture %0, i32 %1, i64 %2) {
 ; CHECK-NEXT:    je .LBB1_3
 ; CHECK-NEXT:  # %bb.1:
 ; CHECK-NEXT:    movd %esi, %xmm0
-; CHECK-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[0,0,0,0]
-; CHECK-NEXT:    pxor %xmm0, %xmm0
-; CHECK-NEXT:    pblendw {{.*#+}} xmm0 = xmm1[0,1],xmm0[2,3],xmm1[4,5],xmm0[6,7]
+; CHECK-NEXT:    pshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
 ; CHECK-NEXT:    .p2align 4, 0x90
 ; CHECK-NEXT:  .LBB1_2: # =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    movdqa (%rdi), %xmm1