[X86] Add DAG combine to turn (trunc (srl (mul ext, ext), 16) into PMULHW/PMULHUW.

Ultimately I want to use this to remove the intrinsics for these instructions.

llvm-svn: 330520
This commit is contained in:
Craig Topper 2018-04-21 18:39:21 +00:00
parent 1b223e75da
commit fe59bea07b
2 changed files with 175 additions and 1106 deletions

View File

@ -36102,6 +36102,59 @@ static SDValue detectAddSubSatPattern(SDValue In, EVT VT, SelectionDAG &DAG,
AddSubSatBuilder);
}
// Try to form a MULHU or MULHS node by looking for
// (trunc (srl (mul ext, ext), 16))
// TODO: This is X86 specific because we want to be able to handle wide types
// before type legalization. But we can only do it if the vector will be
// legalized via widening/splitting. Type legalization can't handle promotion
// of a MULHU/MULHS. There isn't a way to convey this to the generic DAG
// combiner.
static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
SelectionDAG &DAG, const X86Subtarget &Subtarget) {
// First instruction should be a right shift of a multiply.
if (Src.getOpcode() != ISD::SRL ||
Src.getOperand(0).getOpcode() != ISD::MUL)
return SDValue();
if (!Subtarget.hasSSE2())
return SDValue();
// Only handle vXi16 types that are at least 128-bits.
if (!VT.isVector() || VT.getVectorElementType() != MVT::i16 ||
VT.getVectorNumElements() < 8)
return SDValue();
// Input type should be vXi32.
EVT InVT = Src.getValueType();
if (InVT.getVectorElementType() != MVT::i32)
return SDValue();
// Need a shift by 16.
APInt ShiftAmt;
if (!ISD::isConstantSplatVector(Src.getOperand(1).getNode(), ShiftAmt) ||
ShiftAmt != 16)
return SDValue();
SDValue LHS = Src.getOperand(0).getOperand(0);
SDValue RHS = Src.getOperand(0).getOperand(1);
unsigned ExtOpc = LHS.getOpcode();
if ((ExtOpc != ISD::SIGN_EXTEND && ExtOpc != ISD::ZERO_EXTEND) ||
RHS.getOpcode() != ExtOpc)
return SDValue();
// Peek through the extends.
LHS = LHS.getOperand(0);
RHS = RHS.getOperand(0);
// Ensure the input types match.
if (LHS.getValueType() != VT || RHS.getValueType() != VT)
return SDValue();
unsigned Opc = ExtOpc == ISD::SIGN_EXTEND ? ISD::MULHS : ISD::MULHU;
return DAG.getNode(Opc, DL, VT, LHS, RHS);
}
static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
EVT VT = N->getValueType(0);
@ -36124,6 +36177,10 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
if (SDValue Val = combineTruncateWithSat(Src, VT, DL, DAG, Subtarget))
return Val;
// Try to combine PMULHUW/PMULHW for vXi16.
if (SDValue V = combinePMULH(Src, VT, DL, DAG, Subtarget))
return V;
// The bitcast source is a direct mmx result.
// Detect bitcasts between i32 to x86mmx
if (Src.getOpcode() == ISD::BITCAST && VT == MVT::i32) {

File diff suppressed because it is too large Load Diff