forked from OSchip/llvm-project
[X86] Add DAG combine to turn (trunc (srl (mul ext, ext), 16) into PMULHW/PMULHUW.
Ultimately I want to use this to remove the intrinsics for these instructions. llvm-svn: 330520
This commit is contained in:
parent
1b223e75da
commit
fe59bea07b
|
@ -36102,6 +36102,59 @@ static SDValue detectAddSubSatPattern(SDValue In, EVT VT, SelectionDAG &DAG,
|
|||
AddSubSatBuilder);
|
||||
}
|
||||
|
||||
// Try to form a MULHU or MULHS node by looking for
|
||||
// (trunc (srl (mul ext, ext), 16))
|
||||
// TODO: This is X86 specific because we want to be able to handle wide types
|
||||
// before type legalization. But we can only do it if the vector will be
|
||||
// legalized via widening/splitting. Type legalization can't handle promotion
|
||||
// of a MULHU/MULHS. There isn't a way to convey this to the generic DAG
|
||||
// combiner.
|
||||
static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
|
||||
SelectionDAG &DAG, const X86Subtarget &Subtarget) {
|
||||
// First instruction should be a right shift of a multiply.
|
||||
if (Src.getOpcode() != ISD::SRL ||
|
||||
Src.getOperand(0).getOpcode() != ISD::MUL)
|
||||
return SDValue();
|
||||
|
||||
if (!Subtarget.hasSSE2())
|
||||
return SDValue();
|
||||
|
||||
// Only handle vXi16 types that are at least 128-bits.
|
||||
if (!VT.isVector() || VT.getVectorElementType() != MVT::i16 ||
|
||||
VT.getVectorNumElements() < 8)
|
||||
return SDValue();
|
||||
|
||||
// Input type should be vXi32.
|
||||
EVT InVT = Src.getValueType();
|
||||
if (InVT.getVectorElementType() != MVT::i32)
|
||||
return SDValue();
|
||||
|
||||
// Need a shift by 16.
|
||||
APInt ShiftAmt;
|
||||
if (!ISD::isConstantSplatVector(Src.getOperand(1).getNode(), ShiftAmt) ||
|
||||
ShiftAmt != 16)
|
||||
return SDValue();
|
||||
|
||||
SDValue LHS = Src.getOperand(0).getOperand(0);
|
||||
SDValue RHS = Src.getOperand(0).getOperand(1);
|
||||
|
||||
unsigned ExtOpc = LHS.getOpcode();
|
||||
if ((ExtOpc != ISD::SIGN_EXTEND && ExtOpc != ISD::ZERO_EXTEND) ||
|
||||
RHS.getOpcode() != ExtOpc)
|
||||
return SDValue();
|
||||
|
||||
// Peek through the extends.
|
||||
LHS = LHS.getOperand(0);
|
||||
RHS = RHS.getOperand(0);
|
||||
|
||||
// Ensure the input types match.
|
||||
if (LHS.getValueType() != VT || RHS.getValueType() != VT)
|
||||
return SDValue();
|
||||
|
||||
unsigned Opc = ExtOpc == ISD::SIGN_EXTEND ? ISD::MULHS : ISD::MULHU;
|
||||
return DAG.getNode(Opc, DL, VT, LHS, RHS);
|
||||
}
|
||||
|
||||
static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
|
||||
const X86Subtarget &Subtarget) {
|
||||
EVT VT = N->getValueType(0);
|
||||
|
@ -36124,6 +36177,10 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
|
|||
if (SDValue Val = combineTruncateWithSat(Src, VT, DL, DAG, Subtarget))
|
||||
return Val;
|
||||
|
||||
// Try to combine PMULHUW/PMULHW for vXi16.
|
||||
if (SDValue V = combinePMULH(Src, VT, DL, DAG, Subtarget))
|
||||
return V;
|
||||
|
||||
// The bitcast source is a direct mmx result.
|
||||
// Detect bitcasts between i32 to x86mmx
|
||||
if (Src.getOpcode() == ISD::BITCAST && VT == MVT::i32) {
|
||||
|
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue