[X86] LowerFunnelShift - use supportedVectorShiftWithBaseAmnt to check for supported scalar shifts

Allows us to reuse the ISD shift opcode instead of a mixture of ISD/X86ISD variants
This commit is contained in:
Simon Pilgrim 2022-01-23 21:13:58 +00:00
parent 413684313d
commit 32dc14f876
1 changed files with 12 additions and 12 deletions

View File

@ -29832,6 +29832,7 @@ static SDValue LowerFunnelShift(SDValue Op, const X86Subtarget &Subtarget,
SDValue AmtMask = DAG.getConstant(EltSizeInBits - 1, DL, VT); SDValue AmtMask = DAG.getConstant(EltSizeInBits - 1, DL, VT);
SDValue AmtMod = DAG.getNode(ISD::AND, DL, VT, Amt, AmtMask); SDValue AmtMod = DAG.getNode(ISD::AND, DL, VT, Amt, AmtMask);
unsigned ShiftOpc = IsFSHR ? ISD::SRL : ISD::SHL;
unsigned NumElts = VT.getVectorNumElements(); unsigned NumElts = VT.getVectorNumElements();
MVT ExtSVT = MVT::getIntegerVT(2 * EltSizeInBits); MVT ExtSVT = MVT::getIntegerVT(2 * EltSizeInBits);
MVT ExtVT = MVT::getVectorVT(ExtSVT, NumElts / 2); MVT ExtVT = MVT::getVectorVT(ExtSVT, NumElts / 2);
@ -29848,20 +29849,19 @@ static SDValue LowerFunnelShift(SDValue Op, const X86Subtarget &Subtarget,
} }
// Attempt to fold scalar shift as unpack(y,x) << zext(splat(z)) // Attempt to fold scalar shift as unpack(y,x) << zext(splat(z))
if (SDValue ScalarAmt = DAG.getSplatValue(AmtMod)) { if (supportedVectorShiftWithBaseAmnt(ExtVT, Subtarget, ShiftOpc)) {
unsigned ShiftX86Opc = IsFSHR ? X86ISD::VSRLI : X86ISD::VSHLI; if (SDValue ScalarAmt = DAG.getSplatValue(AmtMod)) {
SDValue Lo = DAG.getBitcast(ExtVT, getUnpackl(DAG, DL, VT, Op1, Op0)); SDValue Lo = DAG.getBitcast(ExtVT, getUnpackl(DAG, DL, VT, Op1, Op0));
SDValue Hi = DAG.getBitcast(ExtVT, getUnpackh(DAG, DL, VT, Op1, Op0)); SDValue Hi = DAG.getBitcast(ExtVT, getUnpackh(DAG, DL, VT, Op1, Op0));
ScalarAmt = DAG.getZExtOrTrunc(ScalarAmt, DL, MVT::i32); ScalarAmt = DAG.getZExtOrTrunc(ScalarAmt, DL, MVT::i32);
Lo = getTargetVShiftNode(ShiftX86Opc, DL, ExtVT, Lo, ScalarAmt, Subtarget, Lo = getTargetVShiftNode(ShiftOpc, DL, ExtVT, Lo, ScalarAmt, Subtarget,
DAG); DAG);
Hi = getTargetVShiftNode(ShiftX86Opc, DL, ExtVT, Hi, ScalarAmt, Subtarget, Hi = getTargetVShiftNode(ShiftOpc, DL, ExtVT, Hi, ScalarAmt, Subtarget,
DAG); DAG);
return getPack(DAG, Subtarget, DL, VT, Lo, Hi, !IsFSHR); return getPack(DAG, Subtarget, DL, VT, Lo, Hi, !IsFSHR);
}
} }
unsigned ShiftOpc = IsFSHR ? ISD::SRL : ISD::SHL;
MVT WideSVT = MVT::getIntegerVT( MVT WideSVT = MVT::getIntegerVT(
std::min<unsigned>(EltSizeInBits * 2, Subtarget.hasBWI() ? 16 : 32)); std::min<unsigned>(EltSizeInBits * 2, Subtarget.hasBWI() ? 16 : 32));
MVT WideVT = MVT::getVectorVT(WideSVT, NumElts); MVT WideVT = MVT::getVectorVT(WideSVT, NumElts);