diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp index 9fb40ce68573..41c15b737b21 100644 --- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp +++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp @@ -215,7 +215,7 @@ namespace { bool selectTLSADDRAddr(SDValue N, SDValue &Base, SDValue &Scale, SDValue &Index, SDValue &Disp, SDValue &Segment); - bool selectScalarSSELoad(SDNode *Root, SDValue N, + bool selectScalarSSELoad(SDNode *Root, SDNode *Parent, SDValue N, SDValue &Base, SDValue &Scale, SDValue &Index, SDValue &Disp, SDValue &Segment, @@ -1685,8 +1685,7 @@ bool X86DAGToDAGISel::selectAddr(SDNode *Parent, SDValue N, SDValue &Base, // We can only fold a load if all nodes between it and the root node have a // single use. If there are additional uses, we could end up duplicating the // load. -static bool hasSingleUsesFromRoot(SDNode *Root, SDNode *N) { - SDNode *User = *N->use_begin(); +static bool hasSingleUsesFromRoot(SDNode *Root, SDNode *User) { while (User != Root) { if (!User->hasOneUse()) return false; @@ -1703,17 +1702,19 @@ static bool hasSingleUsesFromRoot(SDNode *Root, SDNode *N) { /// We also return: /// PatternChainNode: this is the matched node that has a chain input and /// output. -bool X86DAGToDAGISel::selectScalarSSELoad(SDNode *Root, +bool X86DAGToDAGISel::selectScalarSSELoad(SDNode *Root, SDNode *Parent, SDValue N, SDValue &Base, SDValue &Scale, SDValue &Index, SDValue &Disp, SDValue &Segment, SDValue &PatternNodeWithChain) { + if (!hasSingleUsesFromRoot(Root, Parent)) + return false; + // We can allow a full vector load here since narrowing a load is ok. if (ISD::isNON_EXTLoad(N.getNode())) { PatternNodeWithChain = N; if (IsProfitableToFold(PatternNodeWithChain, N.getNode(), Root) && - IsLegalToFold(PatternNodeWithChain, *N->use_begin(), Root, OptLevel) && - hasSingleUsesFromRoot(Root, N.getNode())) { + IsLegalToFold(PatternNodeWithChain, Parent, Root, OptLevel)) { LoadSDNode *LD = cast(PatternNodeWithChain); return selectAddr(LD, LD->getBasePtr(), Base, Scale, Index, Disp, Segment); @@ -1724,8 +1725,7 @@ bool X86DAGToDAGISel::selectScalarSSELoad(SDNode *Root, if (N.getOpcode() == X86ISD::VZEXT_LOAD) { PatternNodeWithChain = N; if (IsProfitableToFold(PatternNodeWithChain, N.getNode(), Root) && - IsLegalToFold(PatternNodeWithChain, *N->use_begin(), Root, OptLevel) && - hasSingleUsesFromRoot(Root, N.getNode())) { + IsLegalToFold(PatternNodeWithChain, Parent, Root, OptLevel)) { auto *MI = cast(PatternNodeWithChain); return selectAddr(MI, MI->getBasePtr(), Base, Scale, Index, Disp, Segment); @@ -1739,8 +1739,7 @@ bool X86DAGToDAGISel::selectScalarSSELoad(SDNode *Root, PatternNodeWithChain = N.getOperand(0); if (ISD::isNON_EXTLoad(PatternNodeWithChain.getNode()) && IsProfitableToFold(PatternNodeWithChain, N.getNode(), Root) && - IsLegalToFold(PatternNodeWithChain, N.getNode(), Root, OptLevel) && - hasSingleUsesFromRoot(Root, N.getNode())) { + IsLegalToFold(PatternNodeWithChain, N.getNode(), Root, OptLevel)) { LoadSDNode *LD = cast(PatternNodeWithChain); return selectAddr(LD, LD->getBasePtr(), Base, Scale, Index, Disp, Segment); @@ -1756,8 +1755,7 @@ bool X86DAGToDAGISel::selectScalarSSELoad(SDNode *Root, PatternNodeWithChain = N.getOperand(0).getOperand(0); if (ISD::isNON_EXTLoad(PatternNodeWithChain.getNode()) && IsProfitableToFold(PatternNodeWithChain, N.getNode(), Root) && - IsLegalToFold(PatternNodeWithChain, N.getNode(), Root, OptLevel) && - hasSingleUsesFromRoot(Root, N.getNode())) { + IsLegalToFold(PatternNodeWithChain, N.getNode(), Root, OptLevel)) { // Okay, this is a zero extending load. Fold it. LoadSDNode *LD = cast(PatternNodeWithChain); return selectAddr(LD, LD->getBasePtr(), Base, Scale, Index, Disp, diff --git a/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td b/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td index ee72a7231e36..91b2a568f4de 100644 --- a/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td +++ b/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td @@ -651,10 +651,10 @@ def X86GF2P8mulb : SDNode<"X86ISD::GF2P8MULB", SDTIntBinOp>; // forms. def sse_load_f32 : ComplexPattern; + SDNPWantRoot, SDNPWantParent]>; def sse_load_f64 : ComplexPattern; + SDNPWantRoot, SDNPWantParent]>; def ssmem : Operand { let PrintMethod = "printf32mem"; diff --git a/llvm/test/CodeGen/X86/fma-scalar-memfold.ll b/llvm/test/CodeGen/X86/fma-scalar-memfold.ll index 7822139c3e14..016a78a8dd3a 100644 --- a/llvm/test/CodeGen/X86/fma-scalar-memfold.ll +++ b/llvm/test/CodeGen/X86/fma-scalar-memfold.ll @@ -1,6 +1,6 @@ ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py -; RUN: llc < %s -mcpu=core-avx2 | FileCheck %s --check-prefix=CHECK --check-prefix=AVX2 -; RUN: llc < %s -mcpu=skx | FileCheck %s --check-prefix=CHECK --check-prefix=AVX512 +; RUN: llc < %s -disable-peephole -mcpu=core-avx2 | FileCheck %s --check-prefix=CHECK --check-prefix=AVX2 +; RUN: llc < %s -disable-peephole -mcpu=skx | FileCheck %s --check-prefix=CHECK --check-prefix=AVX512 target triple = "x86_64-unknown-unknown"