[SelectionDAG][X86] Reorder the operands of the MaskedStoreSDNode to put the value first.

Summary:
Previously the value being stored was the last operand in the SDNode. This caused the type legalizer to visit the mask operand before the value operand. The type legalizer was more complicated because of this, since we want the type of the value to drive the decisions.

This patch moves the value to be the first operand so we visit it first during type legalization. It also simplifies the type legalization code accordingly.

X86 is currently the only in-tree target that uses this SDNode. I am not sure whether there are any out-of-tree users.

Reviewers: RKSimon, delena, hfinkel, eli.friedman

Reviewed By: RKSimon

Subscribers: llvm-commits

Differential Revision: https://reviews.llvm.org/D50402

llvm-svn: 340689
This commit is contained in:
Craig Topper 2018-08-25 17:48:17 +00:00
parent bce8680605
commit a11a3b3818
9 changed files with 49 additions and 59 deletions

View File

@ -2113,12 +2113,15 @@ public:
MachineMemOperand *MMO)
: MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {}
// In the both nodes address is Op1, mask is Op2:
// MaskedLoadSDNode (Chain, ptr, mask, src0), src0 is a passthru value
// MaskedStoreSDNode (Chain, ptr, mask, data)
// MaskedLoadSDNode (Chain, ptr, mask, passthru)
// MaskedStoreSDNode (Chain, data, ptr, mask)
// Mask is a vector of i1 elements
const SDValue &getBasePtr() const { return getOperand(1); }
const SDValue &getMask() const { return getOperand(2); }
const SDValue &getBasePtr() const {
return getOperand(getOpcode() == ISD::MLOAD ? 1 : 2);
}
const SDValue &getMask() const {
return getOperand(getOpcode() == ISD::MLOAD ? 2 : 3);
}
static bool classof(const SDNode *N) {
return N->getOpcode() == ISD::MLOAD ||
@ -2143,7 +2146,10 @@ public:
return static_cast<ISD::LoadExtType>(LoadSDNodeBits.ExtTy);
}
const SDValue &getBasePtr() const { return getOperand(1); }
const SDValue &getMask() const { return getOperand(2); }
const SDValue &getPassThru() const { return getOperand(3); }
static bool classof(const SDNode *N) {
return N->getOpcode() == ISD::MLOAD;
}
@ -2175,7 +2181,9 @@ public:
/// memory at base_addr.
bool isCompressingStore() const { return StoreSDNodeBits.IsCompressing; }
const SDValue &getValue() const { return getOperand(3); }
const SDValue &getValue() const { return getOperand(1); }
const SDValue &getBasePtr() const { return getOperand(2); }
const SDValue &getMask() const { return getOperand(3); }
static bool classof(const SDNode *N) {
return N->getOpcode() == ISD::MSTORE;

View File

@ -217,7 +217,7 @@ def SDTIStore : SDTypeProfile<1, 3, [ // indexed store
]>;
def SDTMaskedStore: SDTypeProfile<0, 3, [ // masked store
SDTCisPtrTy<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisSameNumEltsAs<1, 2>
SDTCisVec<0>, SDTCisPtrTy<1>, SDTCisVec<2>, SDTCisSameNumEltsAs<0, 2>
]>;
def SDTMaskedLoad: SDTypeProfile<1, 3, [ // masked load

View File

@ -1219,28 +1219,14 @@ SDValue DAGTypeLegalizer::PromoteIntOp_MSTORE(MaskedStoreSDNode *N,
SDLoc dl(N);
bool TruncateStore = false;
if (OpNo == 2) {
// Mask comes before the data operand. If the data operand is legal, we just
// promote the mask.
// When the data operand has illegal type, we should legalize the data
// operand first. The mask will be promoted/splitted/widened according to
// the data operand type.
if (TLI.isTypeLegal(DataVT)) {
Mask = PromoteTargetBoolean(Mask, DataVT);
// Update in place.
SmallVector<SDValue, 4> NewOps(N->op_begin(), N->op_end());
NewOps[2] = Mask;
return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
}
if (getTypeAction(DataVT) == TargetLowering::TypePromoteInteger)
return PromoteIntOp_MSTORE(N, 3);
if (getTypeAction(DataVT) == TargetLowering::TypeWidenVector)
return WidenVecOp_MSTORE(N, 3);
assert (getTypeAction(DataVT) == TargetLowering::TypeSplitVector);
return SplitVecOp_MSTORE(N, 3);
if (OpNo == 3) {
Mask = PromoteTargetBoolean(Mask, DataVT);
// Update in place.
SmallVector<SDValue, 4> NewOps(N->op_begin(), N->op_end());
NewOps[3] = Mask;
return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
} else { // Data operand
assert(OpNo == 3 && "Unexpected operand for promotion");
assert(OpNo == 1 && "Unexpected operand for promotion");
DataOp = GetPromotedInteger(DataOp);
TruncateStore = true;
}

View File

@ -3860,7 +3860,7 @@ SDValue DAGTypeLegalizer::WidenVecOp_STORE(SDNode *N) {
}
SDValue DAGTypeLegalizer::WidenVecOp_MSTORE(SDNode *N, unsigned OpNo) {
assert((OpNo == 2 || OpNo == 3) &&
assert((OpNo == 1 || OpNo == 3) &&
"Can widen only data or mask operand of mstore");
MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(N);
SDValue Mask = MST->getMask();
@ -3868,8 +3868,8 @@ SDValue DAGTypeLegalizer::WidenVecOp_MSTORE(SDNode *N, unsigned OpNo) {
SDValue StVal = MST->getValue();
SDLoc dl(N);
if (OpNo == 3) {
// Widen the value
if (OpNo == 1) {
// Widen the value.
StVal = GetWidenedVector(StVal);
// The mask should be widened as well.
@ -3879,18 +3879,15 @@ SDValue DAGTypeLegalizer::WidenVecOp_MSTORE(SDNode *N, unsigned OpNo) {
WideVT.getVectorNumElements());
Mask = ModifyToType(Mask, WideMaskVT, true);
} else {
// Widen the mask.
EVT WideMaskVT = TLI.getTypeToTransformTo(*DAG.getContext(), MaskVT);
Mask = ModifyToType(Mask, WideMaskVT, true);
EVT ValueVT = StVal.getValueType();
if (getTypeAction(ValueVT) == TargetLowering::TypeWidenVector)
StVal = GetWidenedVector(StVal);
else {
EVT WideVT = EVT::getVectorVT(*DAG.getContext(),
ValueVT.getVectorElementType(),
WideMaskVT.getVectorNumElements());
StVal = ModifyToType(StVal, WideVT);
}
EVT WideVT = EVT::getVectorVT(*DAG.getContext(),
ValueVT.getVectorElementType(),
WideMaskVT.getVectorNumElements());
StVal = ModifyToType(StVal, WideVT);
}
assert(Mask.getValueType().getVectorNumElements() ==

View File

@ -6580,11 +6580,11 @@ SDValue SelectionDAG::getIndexedStore(SDValue OrigStore, const SDLoc &dl,
}
SDValue SelectionDAG::getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain,
SDValue Ptr, SDValue Mask, SDValue Src0,
SDValue Ptr, SDValue Mask, SDValue PassThru,
EVT MemVT, MachineMemOperand *MMO,
ISD::LoadExtType ExtTy, bool isExpanding) {
SDVTList VTs = getVTList(VT, MVT::Other);
SDValue Ops[] = { Chain, Ptr, Mask, Src0 };
SDValue Ops[] = { Chain, Ptr, Mask, PassThru };
FoldingSetNodeID ID;
AddNodeIDNode(ID, ISD::MLOAD, VTs, Ops);
ID.AddInteger(VT.getRawBits());
@ -6615,7 +6615,7 @@ SDValue SelectionDAG::getMaskedStore(SDValue Chain, const SDLoc &dl,
"Invalid chain type");
EVT VT = Val.getValueType();
SDVTList VTs = getVTList(MVT::Other);
SDValue Ops[] = { Chain, Ptr, Mask, Val };
SDValue Ops[] = { Chain, Val, Ptr, Mask };
FoldingSetNodeID ID;
AddNodeIDNode(ID, ISD::MSTORE, VTs, Ops);
ID.AddInteger(VT.getRawBits());

View File

@ -21723,7 +21723,7 @@ EmitMaskedTruncSStore(bool SignedSat, SDValue Chain, const SDLoc &Dl,
MachineMemOperand *MMO, SelectionDAG &DAG) {
SDVTList VTs = DAG.getVTList(MVT::Other);
SDValue Ops[] = { Chain, Ptr, Mask, Val };
SDValue Ops[] = { Chain, Val, Ptr, Mask };
return SignedSat ?
DAG.getTargetMemSDNode<MaskedTruncSStoreSDNode>(VTs, Ops, Dl, MemVT, MMO) :
DAG.getTargetMemSDNode<MaskedTruncUSStoreSDNode>(VTs, Ops, Dl, MemVT, MMO);

View File

@ -1407,9 +1407,9 @@ namespace llvm {
MachineMemOperand *MMO)
: MemSDNode(Opcode, Order, dl, VTs, MemVT, MMO) {}
const SDValue &getBasePtr() const { return getOperand(1); }
const SDValue &getMask() const { return getOperand(2); }
const SDValue &getValue() const { return getOperand(3); }
const SDValue &getValue() const { return getOperand(1); }
const SDValue &getBasePtr() const { return getOperand(2); }
const SDValue &getMask() const { return getOperand(3); }
static bool classof(const SDNode *N) {
return N->getOpcode() == X86ISD::VMTRUNCSTORES ||

View File

@ -3474,7 +3474,7 @@ multiclass avx512_store<bits<8> opc, string OpcodeStr, string BaseName,
[], _.ExeDomain>, EVEX, EVEX_K, Sched<[Sched.MR]>,
NotMemoryFoldable;
def: Pat<(mstore addr:$ptr, _.KRCWM:$mask, (_.VT _.RC:$src)),
def: Pat<(mstore (_.VT _.RC:$src), addr:$ptr, _.KRCWM:$mask),
(!cast<Instruction>(BaseName#_.ZSuffix#mrk) addr:$ptr,
_.KRCWM:$mask, _.RC:$src)>;
@ -4029,10 +4029,10 @@ def : Pat<(_.VT (OpNode _.RC:$src0,
multiclass avx512_store_scalar_lowering<string InstrStr, AVX512VLVectorVTInfo _,
dag Mask, RegisterClass MaskRC> {
def : Pat<(masked_store addr:$dst, Mask,
def : Pat<(masked_store
(_.info512.VT (insert_subvector undef,
(_.info128.VT _.info128.RC:$src),
(iPTR 0)))),
(iPTR 0))), addr:$dst, Mask),
(!cast<Instruction>(InstrStr#mrk) addr:$dst,
(COPY_TO_REGCLASS MaskRC:$mask, VK1WM),
(COPY_TO_REGCLASS _.info128.RC:$src, _.info128.FRC))>;
@ -4044,10 +4044,10 @@ multiclass avx512_store_scalar_lowering_subreg<string InstrStr,
dag Mask, RegisterClass MaskRC,
SubRegIndex subreg> {
def : Pat<(masked_store addr:$dst, Mask,
def : Pat<(masked_store
(_.info512.VT (insert_subvector undef,
(_.info128.VT _.info128.RC:$src),
(iPTR 0)))),
(iPTR 0))), addr:$dst, Mask),
(!cast<Instruction>(InstrStr#mrk) addr:$dst,
(COPY_TO_REGCLASS (i32 (INSERT_SUBREG (IMPLICIT_DEF), MaskRC:$mask, subreg)), VK1WM),
(COPY_TO_REGCLASS _.info128.RC:$src, _.info128.FRC))>;
@ -4064,16 +4064,16 @@ multiclass avx512_store_scalar_lowering_subreg2<string InstrStr,
SubRegIndex subreg> {
// AVX512F pattern.
def : Pat<(masked_store addr:$dst, Mask512,
def : Pat<(masked_store
(_.info512.VT (insert_subvector undef,
(_.info128.VT _.info128.RC:$src),
(iPTR 0)))),
(iPTR 0))), addr:$dst, Mask512),
(!cast<Instruction>(InstrStr#mrk) addr:$dst,
(COPY_TO_REGCLASS (i32 (INSERT_SUBREG (IMPLICIT_DEF), MaskRC:$mask, subreg)), VK1WM),
(COPY_TO_REGCLASS _.info128.RC:$src, _.info128.FRC))>;
// AVX512VL pattern.
def : Pat<(masked_store addr:$dst, Mask128, (_.info128.VT _.info128.RC:$src)),
def : Pat<(masked_store (_.info128.VT _.info128.RC:$src), addr:$dst, Mask128),
(!cast<Instruction>(InstrStr#mrk) addr:$dst,
(COPY_TO_REGCLASS (i32 (INSERT_SUBREG (IMPLICIT_DEF), MaskRC:$mask, subreg)), VK1WM),
(COPY_TO_REGCLASS _.info128.RC:$src, _.info128.FRC))>;
@ -8992,8 +8992,8 @@ multiclass avx512_trunc_mr_lowering<X86VectorVTInfo SrcInfo,
(!cast<Instruction>(Name#SrcInfo.ZSuffix##mr)
addr:$dst, SrcInfo.RC:$src)>;
def : Pat<(mtruncFrag addr:$dst, SrcInfo.KRCWM:$mask,
(SrcInfo.VT SrcInfo.RC:$src)),
def : Pat<(mtruncFrag (SrcInfo.VT SrcInfo.RC:$src), addr:$dst,
SrcInfo.KRCWM:$mask),
(!cast<Instruction>(Name#SrcInfo.ZSuffix##mrk)
addr:$dst, SrcInfo.KRCWM:$mask, SrcInfo.RC:$src)>;
}
@ -9714,8 +9714,7 @@ multiclass compress_by_vec_width_common<bits<8> opc, X86VectorVTInfo _,
}
multiclass compress_by_vec_width_lowering<X86VectorVTInfo _, string Name> {
def : Pat<(X86mCompressingStore addr:$dst, _.KRCWM:$mask,
(_.VT _.RC:$src)),
def : Pat<(X86mCompressingStore (_.VT _.RC:$src), addr:$dst, _.KRCWM:$mask),
(!cast<Instruction>(Name#_.ZSuffix##mrk)
addr:$dst, _.KRCWM:$mask, _.RC:$src)>;
}

View File

@ -7940,7 +7940,7 @@ defm VPMASKMOVQ : avx2_pmovmask<"vpmaskmovq",
multiclass maskmov_lowering<string InstrStr, RegisterClass RC, ValueType VT,
ValueType MaskVT, string BlendStr, ValueType ZeroVT> {
// masked store
def: Pat<(X86mstore addr:$ptr, (MaskVT RC:$mask), (VT RC:$src)),
def: Pat<(X86mstore (VT RC:$src), addr:$ptr, (MaskVT RC:$mask)),
(!cast<Instruction>(InstrStr#"mr") addr:$ptr, RC:$mask, RC:$src)>;
// masked load
def: Pat<(VT (X86mload addr:$ptr, (MaskVT RC:$mask), undef)),