forked from OSchip/llvm-project
Masked gather and scatter: Added code for SelectionDAG.
All other patches, including tests will follow. http://reviews.llvm.org/D7665 llvm-svn: 235970
This commit is contained in:
parent
426bdf8861
commit
584ce378ab
|
@ -687,9 +687,16 @@ namespace ISD {
|
|||
ATOMIC_LOAD_UMIN,
|
||||
ATOMIC_LOAD_UMAX,
|
||||
|
||||
// Masked load and store
|
||||
// Masked load and store - consecutive vector load and store operations
|
||||
// with additional mask operand that prevents memory accesses to the
|
||||
// masked-off lanes.
|
||||
MLOAD, MSTORE,
|
||||
|
||||
// Masked gather and scatter - load and store operations for a vector of
|
||||
// random addresses with additional mask operand that prevents memory
|
||||
// accesses to the masked-off lanes.
|
||||
MGATHER, MSCATTER,
|
||||
|
||||
/// This corresponds to the llvm.lifetime.* intrinsics. The first operand
|
||||
/// is the chain and the second operand is the alloca pointer.
|
||||
LIFETIME_START, LIFETIME_END,
|
||||
|
|
|
@ -856,6 +856,10 @@ public:
|
|||
SDValue getMaskedStore(SDValue Chain, SDLoc dl, SDValue Val,
|
||||
SDValue Ptr, SDValue Mask, EVT MemVT,
|
||||
MachineMemOperand *MMO, bool IsTrunc);
|
||||
SDValue getMaskedGather(SDVTList VTs, EVT VT, SDLoc dl,
|
||||
ArrayRef<SDValue> Ops, MachineMemOperand *MMO);
|
||||
SDValue getMaskedScatter(SDVTList VTs, EVT VT, SDLoc dl,
|
||||
ArrayRef<SDValue> Ops, MachineMemOperand *MMO);
|
||||
/// Construct a node to track a Value* through the backend.
|
||||
SDValue getSrcValue(const Value *v);
|
||||
|
||||
|
|
|
@ -1151,6 +1151,8 @@ public:
|
|||
N->getOpcode() == ISD::ATOMIC_STORE ||
|
||||
N->getOpcode() == ISD::MLOAD ||
|
||||
N->getOpcode() == ISD::MSTORE ||
|
||||
N->getOpcode() == ISD::MGATHER ||
|
||||
N->getOpcode() == ISD::MSCATTER ||
|
||||
N->isMemIntrinsic() ||
|
||||
N->isTargetMemoryOpcode();
|
||||
}
|
||||
|
@ -1987,6 +1989,82 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// This is a base class is used to represent
|
||||
/// MGATHER and MSCATTER nodes
|
||||
///
|
||||
class MaskedGatherScatterSDNode : public MemSDNode {
|
||||
// Operands
|
||||
SDUse Ops[5];
|
||||
public:
|
||||
friend class SelectionDAG;
|
||||
MaskedGatherScatterSDNode(ISD::NodeType NodeTy, unsigned Order, DebugLoc dl,
|
||||
ArrayRef<SDValue> Operands, SDVTList VTs, EVT MemVT,
|
||||
MachineMemOperand *MMO)
|
||||
: MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {
|
||||
assert(Operands.size() == 5 && "Incompatible number of operands");
|
||||
InitOperands(Ops, Operands.data(), Operands.size());
|
||||
}
|
||||
|
||||
// In the both nodes address is Op1, mask is Op2:
|
||||
// MaskedGatherSDNode (Chain, src0, mask, base, index), src0 is a passthru value
|
||||
// MaskedScatterSDNode (Chain, value, mask, base, index)
|
||||
// Mask is a vector of i1 elements
|
||||
const SDValue &getBasePtr() const { return getOperand(3); }
|
||||
const SDValue &getIndex() const { return getOperand(4); }
|
||||
const SDValue &getMask() const { return getOperand(2); }
|
||||
const SDValue &getValue() const { return getOperand(1); }
|
||||
|
||||
static bool classof(const SDNode *N) {
|
||||
return N->getOpcode() == ISD::MGATHER ||
|
||||
N->getOpcode() == ISD::MSCATTER;
|
||||
}
|
||||
};
|
||||
|
||||
/// This class is used to represent an MGATHER node
|
||||
///
|
||||
class MaskedGatherSDNode : public MaskedGatherScatterSDNode {
|
||||
public:
|
||||
friend class SelectionDAG;
|
||||
MaskedGatherSDNode(unsigned Order, DebugLoc dl, ArrayRef<SDValue> Operands,
|
||||
SDVTList VTs, EVT MemVT, MachineMemOperand *MMO)
|
||||
: MaskedGatherScatterSDNode(ISD::MGATHER, Order, dl, Operands, VTs, MemVT,
|
||||
MMO) {
|
||||
assert(getValue().getValueType() == getValueType(0) &&
|
||||
"Incompatible type of the PathThru value in MaskedGatherSDNode");
|
||||
assert(getMask().getValueType().getVectorNumElements() ==
|
||||
getValueType(0).getVectorNumElements() &&
|
||||
"Vector width mismatch between mask and data");
|
||||
assert(getMask().getValueType().getScalarType() == MVT::i1 &&
|
||||
"Vector width mismatch between mask and data");
|
||||
}
|
||||
|
||||
static bool classof(const SDNode *N) {
|
||||
return N->getOpcode() == ISD::MGATHER;
|
||||
}
|
||||
};
|
||||
|
||||
/// This class is used to represent an MSCATTER node
|
||||
///
|
||||
class MaskedScatterSDNode : public MaskedGatherScatterSDNode {
|
||||
|
||||
public:
|
||||
friend class SelectionDAG;
|
||||
MaskedScatterSDNode(unsigned Order, DebugLoc dl,ArrayRef<SDValue> Operands,
|
||||
SDVTList VTs, EVT MemVT, MachineMemOperand *MMO)
|
||||
: MaskedGatherScatterSDNode(ISD::MSCATTER, Order, dl, Operands, VTs, MemVT,
|
||||
MMO) {
|
||||
assert(getMask().getValueType().getVectorNumElements() ==
|
||||
getValue().getValueType().getVectorNumElements() &&
|
||||
"Vector width mismatch between mask and data");
|
||||
assert(getMask().getValueType().getScalarType() == MVT::i1 &&
|
||||
"Vector width mismatch between mask and data");
|
||||
}
|
||||
|
||||
static bool classof(const SDNode *N) {
|
||||
return N->getOpcode() == ISD::MSCATTER;
|
||||
}
|
||||
};
|
||||
|
||||
/// An SDNode that represents everything that will be needed
|
||||
/// to construct a MachineInstr. These nodes are created during the
|
||||
/// instruction selection proper phase.
|
||||
|
@ -2078,7 +2156,7 @@ template <> struct GraphTraits<SDNode*> {
|
|||
};
|
||||
|
||||
/// The largest SDNode class.
|
||||
typedef AtomicSDNode LargestSDNode;
|
||||
typedef MaskedGatherScatterSDNode LargestSDNode;
|
||||
|
||||
/// The SDNode class with the greatest alignment requirement.
|
||||
typedef GlobalAddressSDNode MostAlignedSDNode;
|
||||
|
|
|
@ -5097,6 +5097,55 @@ SDValue SelectionDAG::getMaskedStore(SDValue Chain, SDLoc dl, SDValue Val,
|
|||
return SDValue(N, 0);
|
||||
}
|
||||
|
||||
/// Return an MGATHER node with result types \p VTs, memory type \p VT and
/// operands \p Ops. If CSEMap already holds a node with the same key, that
/// node is returned after refineAlignment(MMO) has been called on it;
/// otherwise a new node is allocated, registered and returned.
SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, SDLoc dl,
                                      ArrayRef<SDValue> Ops,
                                      MachineMemOperand *MMO) {
  // Build the CSE key: opcode/types/operands, memory type, memory flags and
  // address space.
  FoldingSetNodeID Key;
  AddNodeIDNode(Key, ISD::MGATHER, VTs, Ops);
  Key.AddInteger(VT.getRawBits());
  Key.AddInteger(encodeMemSDNodeFlags(ISD::NON_EXTLOAD, ISD::UNINDEXED,
                                      MMO->isVolatile(),
                                      MMO->isNonTemporal(),
                                      MMO->isInvariant()));
  Key.AddInteger(MMO->getPointerInfo().getAddrSpace());

  void *InsertPos = nullptr;
  SDNode *Existing = CSEMap.FindNodeOrInsertPos(Key, InsertPos);
  if (Existing) {
    cast<MaskedGatherSDNode>(Existing)->refineAlignment(MMO);
    return SDValue(Existing, 0);
  }

  // No equivalent node yet: create one and record it for later lookups.
  MaskedGatherSDNode *Gather = new (NodeAllocator)
      MaskedGatherSDNode(dl.getIROrder(), dl.getDebugLoc(), Ops, VTs, VT, MMO);
  CSEMap.InsertNode(Gather, InsertPos);
  InsertNode(Gather);
  return SDValue(Gather, 0);
}
|
||||
|
||||
/// Return an MSCATTER node with result types \p VTs, memory type \p VT and
/// operands \p Ops. If CSEMap already holds a node with the same key, that
/// node is returned after refineAlignment(MMO) has been called on it;
/// otherwise a new node is allocated, registered and returned.
SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT VT, SDLoc dl,
                                       ArrayRef<SDValue> Ops,
                                       MachineMemOperand *MMO) {
  // Build the CSE key: opcode/types/operands, memory type, memory flags and
  // address space.
  FoldingSetNodeID Key;
  AddNodeIDNode(Key, ISD::MSCATTER, VTs, Ops);
  Key.AddInteger(VT.getRawBits());
  Key.AddInteger(encodeMemSDNodeFlags(false, ISD::UNINDEXED,
                                      MMO->isVolatile(),
                                      MMO->isNonTemporal(),
                                      MMO->isInvariant()));
  Key.AddInteger(MMO->getPointerInfo().getAddrSpace());

  void *InsertPos = nullptr;
  SDNode *Existing = CSEMap.FindNodeOrInsertPos(Key, InsertPos);
  if (Existing) {
    cast<MaskedScatterSDNode>(Existing)->refineAlignment(MMO);
    return SDValue(Existing, 0);
  }

  // No equivalent node yet: create one and record it for later lookups.
  SDNode *Scatter = new (NodeAllocator)
      MaskedScatterSDNode(dl.getIROrder(), dl.getDebugLoc(), Ops, VTs, VT, MMO);
  CSEMap.InsertNode(Scatter, InsertPos);
  InsertNode(Scatter);
  return SDValue(Scatter, 0);
}
|
||||
|
||||
SDValue SelectionDAG::getVAArg(EVT VT, SDLoc dl,
|
||||
SDValue Chain, SDValue Ptr,
|
||||
SDValue SV,
|
||||
|
|
|
@ -1059,6 +1059,12 @@ SDValue SelectionDAGBuilder::getValue(const Value *V) {
|
|||
return Val;
|
||||
}
|
||||
|
||||
// Return true if SDValue exists for the given Value
|
||||
bool SelectionDAGBuilder::findValue(const Value *V) const {
|
||||
return (NodeMap.find(V) != NodeMap.end()) ||
|
||||
(FuncInfo.ValueMap.find(V) != FuncInfo.ValueMap.end());
|
||||
}
|
||||
|
||||
/// getNonRegisterValue - Return an SDValue for the given Value, but
|
||||
/// don't look in FuncInfo.ValueMap for a virtual register.
|
||||
SDValue SelectionDAGBuilder::getNonRegisterValue(const Value *V) {
|
||||
|
@ -3026,6 +3032,92 @@ void SelectionDAGBuilder::visitMaskedStore(const CallInst &I) {
|
|||
setValue(&I, StoreNode);
|
||||
}
|
||||
|
||||
// Gather/scatter receive a vector of pointers.
|
||||
// This vector of pointers may be represented as a base pointer + vector of
|
||||
// indices, it depends on GEP and instruction preceeding GEP
|
||||
// that calculates indices
|
||||
static bool getUniformBase(Value *& Ptr, SDValue& Base, SDValue& Index,
|
||||
SelectionDAGBuilder* SDB) {
|
||||
|
||||
assert (Ptr->getType()->isVectorTy() && "Uexpected pointer type");
|
||||
GetElementPtrInst *Gep = dyn_cast<GetElementPtrInst>(Ptr);
|
||||
if (!Gep || Gep->getNumOperands() > 2)
|
||||
return false;
|
||||
ShuffleVectorInst *ShuffleInst =
|
||||
dyn_cast<ShuffleVectorInst>(Gep->getPointerOperand());
|
||||
if (!ShuffleInst || !ShuffleInst->getMask()->isNullValue() ||
|
||||
cast<Instruction>(ShuffleInst->getOperand(0))->getOpcode() !=
|
||||
Instruction::InsertElement)
|
||||
return false;
|
||||
|
||||
Ptr = cast<InsertElementInst>(ShuffleInst->getOperand(0))->getOperand(1);
|
||||
|
||||
SelectionDAG& DAG = SDB->DAG;
|
||||
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
|
||||
// Check is the Ptr is inside current basic block
|
||||
// If not, look for the shuffle instruction
|
||||
if (SDB->findValue(Ptr))
|
||||
Base = SDB->getValue(Ptr);
|
||||
else if (SDB->findValue(ShuffleInst)) {
|
||||
SDValue ShuffleNode = SDB->getValue(ShuffleInst);
|
||||
Base = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(ShuffleNode),
|
||||
ShuffleNode.getValueType().getScalarType(), ShuffleNode,
|
||||
DAG.getConstant(0, TLI.getVectorIdxTy()));
|
||||
SDB->setValue(Ptr, Base);
|
||||
}
|
||||
else
|
||||
return false;
|
||||
|
||||
Value *IndexVal = Gep->getOperand(1);
|
||||
if (SDB->findValue(IndexVal)) {
|
||||
Index = SDB->getValue(IndexVal);
|
||||
|
||||
if (SExtInst* Sext = dyn_cast<SExtInst>(IndexVal)) {
|
||||
IndexVal = Sext->getOperand(0);
|
||||
if (SDB->findValue(IndexVal))
|
||||
Index = SDB->getValue(IndexVal);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {
|
||||
SDLoc sdl = getCurSDLoc();
|
||||
|
||||
// llvm.masked.scatter.*(Src0, Ptrs, alignemt, Mask)
|
||||
Value *Ptr = I.getArgOperand(1);
|
||||
SDValue Src0 = getValue(I.getArgOperand(0));
|
||||
SDValue Mask = getValue(I.getArgOperand(3));
|
||||
EVT VT = Src0.getValueType();
|
||||
unsigned Alignment = (cast<ConstantInt>(I.getArgOperand(2)))->getZExtValue();
|
||||
if (!Alignment)
|
||||
Alignment = DAG.getEVTAlignment(VT);
|
||||
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
|
||||
|
||||
AAMDNodes AAInfo;
|
||||
I.getAAMetadata(AAInfo);
|
||||
|
||||
SDValue Base;
|
||||
SDValue Index;
|
||||
Value *BasePtr = Ptr;
|
||||
bool UniformBase = getUniformBase(BasePtr, Base, Index, this);
|
||||
|
||||
Value *MemOpBasePtr = UniformBase ? BasePtr : NULL;
|
||||
MachineMemOperand *MMO = DAG.getMachineFunction().
|
||||
getMachineMemOperand(MachinePointerInfo(MemOpBasePtr),
|
||||
MachineMemOperand::MOStore, VT.getStoreSize(),
|
||||
Alignment, AAInfo);
|
||||
if (!UniformBase) {
|
||||
Base = DAG.getTargetConstant(0, TLI.getPointerTy());
|
||||
Index = getValue(Ptr);
|
||||
}
|
||||
SDValue Ops[] = { getRoot(), Src0, Mask, Base, Index };
|
||||
SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), VT, sdl, Ops, MMO);
|
||||
DAG.setRoot(Scatter);
|
||||
setValue(&I, Scatter);
|
||||
}
|
||||
|
||||
void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I) {
|
||||
SDLoc sdl = getCurSDLoc();
|
||||
|
||||
|
@ -3067,6 +3159,60 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I) {
|
|||
setValue(&I, Load);
|
||||
}
|
||||
|
||||
void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
|
||||
SDLoc sdl = getCurSDLoc();
|
||||
|
||||
// @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
|
||||
Value *Ptr = I.getArgOperand(0);
|
||||
SDValue Src0 = getValue(I.getArgOperand(3));
|
||||
SDValue Mask = getValue(I.getArgOperand(2));
|
||||
|
||||
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
|
||||
EVT VT = TLI.getValueType(I.getType());
|
||||
unsigned Alignment = (cast<ConstantInt>(I.getArgOperand(1)))->getZExtValue();
|
||||
if (!Alignment)
|
||||
Alignment = DAG.getEVTAlignment(VT);
|
||||
|
||||
AAMDNodes AAInfo;
|
||||
I.getAAMetadata(AAInfo);
|
||||
const MDNode *Ranges = I.getMetadata(LLVMContext::MD_range);
|
||||
|
||||
SDValue Root = DAG.getRoot();
|
||||
SDValue Base;
|
||||
SDValue Index;
|
||||
Value *BasePtr = Ptr;
|
||||
bool UniformBase = getUniformBase(BasePtr, Base, Index, this);
|
||||
bool ConstantMemory = false;
|
||||
if (UniformBase && AA->pointsToConstantMemory(
|
||||
AliasAnalysis::Location(BasePtr,
|
||||
AA->getTypeStoreSize(I.getType()),
|
||||
AAInfo))) {
|
||||
// Do not serialize (non-volatile) loads of constant memory with anything.
|
||||
Root = DAG.getEntryNode();
|
||||
ConstantMemory = true;
|
||||
}
|
||||
|
||||
MachineMemOperand *MMO =
|
||||
DAG.getMachineFunction().
|
||||
getMachineMemOperand(MachinePointerInfo(UniformBase ? BasePtr : NULL),
|
||||
MachineMemOperand::MOLoad, VT.getStoreSize(),
|
||||
Alignment, AAInfo, Ranges);
|
||||
|
||||
if (!UniformBase) {
|
||||
Base = DAG.getTargetConstant(0, TLI.getPointerTy());
|
||||
Index = getValue(Ptr);
|
||||
}
|
||||
|
||||
SDValue Ops[] = { Root, Src0, Mask, Base, Index };
|
||||
SDValue Gather = DAG.getMaskedGather(DAG.getVTList(VT, MVT::Other), VT, sdl,
|
||||
Ops, MMO);
|
||||
|
||||
SDValue OutChain = Gather.getValue(1);
|
||||
if (!ConstantMemory)
|
||||
PendingLoads.push_back(OutChain);
|
||||
setValue(&I, Gather);
|
||||
}
|
||||
|
||||
void SelectionDAGBuilder::visitAtomicCmpXchg(const AtomicCmpXchgInst &I) {
|
||||
SDLoc dl = getCurSDLoc();
|
||||
AtomicOrdering SuccessOrder = I.getSuccessOrdering();
|
||||
|
@ -4216,9 +4362,13 @@ SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I, unsigned Intrinsic) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
case Intrinsic::masked_gather:
|
||||
visitMaskedGather(I);
|
||||
case Intrinsic::masked_load:
|
||||
visitMaskedLoad(I);
|
||||
return nullptr;
|
||||
case Intrinsic::masked_scatter:
|
||||
visitMaskedScatter(I);
|
||||
case Intrinsic::masked_store:
|
||||
visitMaskedStore(I);
|
||||
return nullptr;
|
||||
|
|
|
@ -667,6 +667,8 @@ public:
|
|||
// generate the debug data structures now that we've seen its definition.
|
||||
void resolveDanglingDebugInfo(const Value *V, SDValue Val);
|
||||
SDValue getValue(const Value *V);
|
||||
bool findValue(const Value *V) const;
|
||||
|
||||
SDValue getNonRegisterValue(const Value *V);
|
||||
SDValue getValueImpl(const Value *V);
|
||||
|
||||
|
@ -814,6 +816,8 @@ private:
|
|||
void visitStore(const StoreInst &I);
|
||||
void visitMaskedLoad(const CallInst &I);
|
||||
void visitMaskedStore(const CallInst &I);
|
||||
void visitMaskedGather(const CallInst &I);
|
||||
void visitMaskedScatter(const CallInst &I);
|
||||
void visitAtomicCmpXchg(const AtomicCmpXchgInst &I);
|
||||
void visitAtomicRMW(const AtomicRMWInst &I);
|
||||
void visitFence(const FenceInst &I);
|
||||
|
|
Loading…
Reference in New Issue