diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 073b4e01377a..8acdb9de793a 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -16086,9 +16086,6 @@ static SDValue LowerAVXExtend(SDValue Op, SelectionDAG &DAG, MVT InVT = In.getSimpleValueType(); SDLoc dl(Op); - if (VT.is512BitVector() || InVT.getVectorElementType() == MVT::i1) - return DAG.getNode(ISD::ZERO_EXTEND, dl, VT, In); - // Optimize vectors in AVX mode: // // v8i16 -> v8i32 @@ -16158,6 +16155,13 @@ static SDValue LowerZERO_EXTEND_AVX512(SDValue Op, static SDValue LowerANY_EXTEND(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG) { + MVT VT = Op->getSimpleValueType(0); + SDValue In = Op->getOperand(0); + MVT InVT = In.getSimpleValueType(); + + if (VT.is512BitVector() || InVT.getVectorElementType() == MVT::i1) + return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Op), VT, In); + if (Subtarget.hasFp256()) if (SDValue Res = LowerAVXExtend(Op, DAG, Subtarget)) return Res; @@ -16167,7 +16171,6 @@ static SDValue LowerANY_EXTEND(SDValue Op, const X86Subtarget &Subtarget, static SDValue LowerZERO_EXTEND(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG) { - SDLoc DL(Op); MVT VT = Op.getSimpleValueType(); SDValue In = Op.getOperand(0); MVT SVT = In.getSimpleValueType(); @@ -18268,14 +18271,6 @@ static SDValue LowerSIGN_EXTEND_AVX512(SDValue Op, MVT InVTElt = InVT.getVectorElementType(); SDLoc dl(Op); - // SKX processor - if ((InVTElt == MVT::i1) && - (((Subtarget.hasBWI() && VTElt.getSizeInBits() <= 16)) || - - ((Subtarget.hasDQI() && VTElt.getSizeInBits() >= 32)))) - - return DAG.getNode(X86ISD::VSEXT, dl, VT, In); - unsigned NumElts = VT.getVectorNumElements(); if (VT.is512BitVector() && InVTElt != MVT::i1 && @@ -18288,6 +18283,11 @@ static SDValue LowerSIGN_EXTEND_AVX512(SDValue Op, if (InVTElt != MVT::i1) return SDValue(); + // SKX processor + if (((Subtarget.hasBWI() && VTElt.getSizeInBits() <= 16)) || + ((Subtarget.hasDQI() && VTElt.getSizeInBits() >= 32))) + return DAG.getNode(X86ISD::VSEXT, dl, VT, In); + MVT ExtVT = VT; if (!VT.is512BitVector() && !Subtarget.hasVLX()) { ExtVT = MVT::getVectorVT(MVT::getIntegerVT(512/NumElts), NumElts);