[Hexagon] Implement llvm.masked.load and llvm.masked.store for HVX

Krzysztof Parzyszek 2020-08-24 18:29:57 -05:00
parent f78687df9b
commit e15143d31b
9 changed files with 215 additions and 23 deletions
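For reference, the generic masked-vector intrinsics that this change teaches the HVX lowering to handle have the shapes below. This is only a reminder of the interface; the declarations mirror the ones used by the new test added at the end of this commit. The operands are the pointer, the alignment, the per-lane mask, and, for the load, the pass-through value substituted into masked-off lanes.

declare <128 x i8> @llvm.masked.load.v128i8.p0v128i8(<128 x i8>*, i32 immarg, <128 x i1>, <128 x i8>)
declare void @llvm.masked.store.v128i8.p0v128i8(<128 x i8>, <128 x i8>*, i32 immarg, <128 x i1>)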


@@ -469,8 +469,7 @@ namespace HexagonISD {
SDValue LowerHvxExtend(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerHvxShift(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerHvxIntrinsic(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerHvxStore(SDValue Op, SelectionDAG &DAG) const;
SDValue HvxVecPredBitcastComputation(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerHvxMaskedOp(SDValue Op, SelectionDAG &DAG) const;
SDValue SplitHvxPairOp(SDValue Op, SelectionDAG &DAG) const;
SDValue SplitHvxMemOp(SDValue Op, SelectionDAG &DAG) const;


@@ -97,6 +97,8 @@ HexagonTargetLowering::initializeHVXLowering() {
setOperationAction(ISD::CTTZ, T, Custom);
setOperationAction(ISD::LOAD, T, Custom);
setOperationAction(ISD::MLOAD, T, Custom);
setOperationAction(ISD::MSTORE, T, Custom);
setOperationAction(ISD::MUL, T, Custom);
setOperationAction(ISD::MULHS, T, Custom);
setOperationAction(ISD::MULHU, T, Custom);
@@ -150,6 +152,8 @@ HexagonTargetLowering::initializeHVXLowering() {
setOperationAction(ISD::LOAD, T, Custom);
setOperationAction(ISD::STORE, T, Custom);
setOperationAction(ISD::MLOAD, T, Custom);
setOperationAction(ISD::MSTORE, T, Custom);
setOperationAction(ISD::CTLZ, T, Custom);
setOperationAction(ISD::CTTZ, T, Custom);
setOperationAction(ISD::CTPOP, T, Custom);
@@ -188,6 +192,9 @@ HexagonTargetLowering::initializeHVXLowering() {
setOperationAction(ISD::AND, BoolW, Custom);
setOperationAction(ISD::OR, BoolW, Custom);
setOperationAction(ISD::XOR, BoolW, Custom);
// Masked load/store takes a mask that may need splitting.
setOperationAction(ISD::MLOAD, BoolW, Custom);
setOperationAction(ISD::MSTORE, BoolW, Custom);
}
for (MVT T : LegalV) {
@@ -1593,7 +1600,7 @@ HexagonTargetLowering::LowerHvxShift(SDValue Op, SelectionDAG &DAG) const {
SDValue
HexagonTargetLowering::LowerHvxIntrinsic(SDValue Op, SelectionDAG &DAG) const {
const SDLoc &dl(Op);
MVT ResTy = ty(Op);
unsigned IntNo = cast<ConstantSDNode>(Op.getOperand(0))->getZExtValue();
@@ -1613,6 +1620,75 @@ HexagonTargetLowering::LowerHvxIntrinsic(SDValue Op, SelectionDAG &DAG) const {
return Op;
}
SDValue
HexagonTargetLowering::LowerHvxMaskedOp(SDValue Op, SelectionDAG &DAG) const {
const SDLoc &dl(Op);
unsigned HwLen = Subtarget.getVectorLength();
auto *MaskN = cast<MaskedLoadStoreSDNode>(Op.getNode());
SDValue Mask = MaskN->getMask();
SDValue Chain = MaskN->getChain();
SDValue Base = MaskN->getBasePtr();
auto *MemOp = MaskN->getMemOperand();
unsigned Opc = Op->getOpcode();
assert(Opc == ISD::MLOAD || Opc == ISD::MSTORE);
if (Opc == ISD::MLOAD) {
MVT ValTy = ty(Op);
SDValue Load = DAG.getLoad(ValTy, dl, Chain, Base, MaskN->getMemOperand());
SDValue Thru = cast<MaskedLoadSDNode>(MaskN)->getPassThru();
if (isUndef(Thru))
return Load;
SDValue VSel = DAG.getNode(ISD::VSELECT, dl, ValTy, Mask, Load, Thru);
return DAG.getMergeValues({VSel, Load.getValue(1)}, dl);
}
// MSTORE
// HVX only has aligned masked stores.
// TODO: Fold negations of the mask into the store.
unsigned StoreOpc = Hexagon::V6_vS32b_qpred_ai;
SDValue Value = cast<MaskedStoreSDNode>(MaskN)->getValue();
SDValue Offset0 = DAG.getTargetConstant(0, dl, ty(Base));
if (MaskN->getAlign().value() % HwLen == 0) {
SDValue Store = getInstr(StoreOpc, dl, MVT::Other,
{Mask, Base, Offset0, Value, Chain}, DAG);
DAG.setNodeMemRefs(cast<MachineSDNode>(Store.getNode()), {MemOp});
return Store;
}
// Unaligned case.
auto StoreAlign = [&](SDValue V, SDValue A) {
SDValue Z = getZero(dl, ty(V), DAG);
// TODO: use funnel shifts?
// vlalign(Vu,Vv,Rt) rotates the pair Vu:Vv left by Rt and takes the
// upper half.
SDValue LoV = getInstr(Hexagon::V6_vlalignb, dl, ty(V), {V, Z, A}, DAG);
SDValue HiV = getInstr(Hexagon::V6_vlalignb, dl, ty(V), {Z, V, A}, DAG);
return std::make_pair(LoV, HiV);
};
MVT ByteTy = MVT::getVectorVT(MVT::i8, HwLen);
MVT BoolTy = MVT::getVectorVT(MVT::i1, HwLen);
SDValue MaskV = DAG.getNode(HexagonISD::Q2V, dl, ByteTy, Mask);
VectorPair Tmp = StoreAlign(MaskV, Base);
VectorPair MaskU = {DAG.getNode(HexagonISD::V2Q, dl, BoolTy, Tmp.first),
DAG.getNode(HexagonISD::V2Q, dl, BoolTy, Tmp.second)};
VectorPair ValueU = StoreAlign(Value, Base);
SDValue Offset1 = DAG.getTargetConstant(HwLen, dl, MVT::i32);
SDValue StoreLo =
getInstr(StoreOpc, dl, MVT::Other,
{MaskU.first, Base, Offset0, ValueU.first, Chain}, DAG);
SDValue StoreHi =
getInstr(StoreOpc, dl, MVT::Other,
{MaskU.second, Base, Offset1, ValueU.second, Chain}, DAG);
DAG.setNodeMemRefs(cast<MachineSDNode>(StoreLo.getNode()), {MemOp});
DAG.setNodeMemRefs(cast<MachineSDNode>(StoreHi.getNode()), {MemOp});
return DAG.getNode(ISD::TokenFactor, dl, MVT::Other, {StoreLo, StoreHi});
}
SDValue
HexagonTargetLowering::SplitHvxPairOp(SDValue Op, SelectionDAG &DAG) const {
assert(!Op.isMachineOpcode());
@@ -1648,45 +1724,81 @@ HexagonTargetLowering::SplitHvxPairOp(SDValue Op, SelectionDAG &DAG) const {
SDValue
HexagonTargetLowering::SplitHvxMemOp(SDValue Op, SelectionDAG &DAG) const {
LSBaseSDNode *BN = cast<LSBaseSDNode>(Op.getNode());
assert(BN->isUnindexed());
MVT MemTy = BN->getMemoryVT().getSimpleVT();
auto *MemN = cast<MemSDNode>(Op.getNode());
MVT MemTy = MemN->getMemoryVT().getSimpleVT();
if (!isHvxPairTy(MemTy))
return Op;
const SDLoc &dl(Op);
unsigned HwLen = Subtarget.getVectorLength();
MVT SingleTy = typeSplit(MemTy).first;
SDValue Chain = BN->getChain();
SDValue Base0 = BN->getBasePtr();
SDValue Chain = MemN->getChain();
SDValue Base0 = MemN->getBasePtr();
SDValue Base1 = DAG.getMemBasePlusOffset(Base0, TypeSize::Fixed(HwLen), dl);
MachineMemOperand *MOp0 = nullptr, *MOp1 = nullptr;
if (MachineMemOperand *MMO = BN->getMemOperand()) {
if (MachineMemOperand *MMO = MemN->getMemOperand()) {
MachineFunction &MF = DAG.getMachineFunction();
MOp0 = MF.getMachineMemOperand(MMO, 0, HwLen);
MOp1 = MF.getMachineMemOperand(MMO, HwLen, HwLen);
}
unsigned MemOpc = BN->getOpcode();
SDValue NewOp;
unsigned MemOpc = MemN->getOpcode();
if (MemOpc == ISD::LOAD) {
assert(cast<LoadSDNode>(Op)->isUnindexed());
SDValue Load0 = DAG.getLoad(SingleTy, dl, Chain, Base0, MOp0);
SDValue Load1 = DAG.getLoad(SingleTy, dl, Chain, Base1, MOp1);
NewOp = DAG.getMergeValues(
{ DAG.getNode(ISD::CONCAT_VECTORS, dl, MemTy, Load0, Load1),
DAG.getNode(ISD::TokenFactor, dl, MVT::Other,
Load0.getValue(1), Load1.getValue(1)) }, dl);
} else {
assert(MemOpc == ISD::STORE);
return DAG.getMergeValues(
{ DAG.getNode(ISD::CONCAT_VECTORS, dl, MemTy, Load0, Load1),
DAG.getNode(ISD::TokenFactor, dl, MVT::Other,
Load0.getValue(1), Load1.getValue(1)) }, dl);
}
if (MemOpc == ISD::STORE) {
assert(cast<StoreSDNode>(Op)->isUnindexed());
VectorPair Vals = opSplit(cast<StoreSDNode>(Op)->getValue(), dl, DAG);
SDValue Store0 = DAG.getStore(Chain, dl, Vals.first, Base0, MOp0);
SDValue Store1 = DAG.getStore(Chain, dl, Vals.second, Base1, MOp1);
NewOp = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Store0, Store1);
return DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Store0, Store1);
}
return NewOp;
assert(MemOpc == ISD::MLOAD || MemOpc == ISD::MSTORE);
auto MaskN = cast<MaskedLoadStoreSDNode>(Op);
assert(MaskN->isUnindexed());
VectorPair Masks = opSplit(MaskN->getMask(), dl, DAG);
SDValue Offset = DAG.getUNDEF(MVT::i32);
if (MemOpc == ISD::MLOAD) {
VectorPair Thru =
opSplit(cast<MaskedLoadSDNode>(Op)->getPassThru(), dl, DAG);
SDValue MLoad0 =
DAG.getMaskedLoad(SingleTy, dl, Chain, Base0, Offset, Masks.first,
Thru.first, SingleTy, MOp0, ISD::UNINDEXED,
ISD::NON_EXTLOAD, false);
SDValue MLoad1 =
DAG.getMaskedLoad(SingleTy, dl, Chain, Base1, Offset, Masks.second,
Thru.second, SingleTy, MOp1, ISD::UNINDEXED,
ISD::NON_EXTLOAD, false);
return DAG.getMergeValues(
{ DAG.getNode(ISD::CONCAT_VECTORS, dl, MemTy, MLoad0, MLoad1),
DAG.getNode(ISD::TokenFactor, dl, MVT::Other,
MLoad0.getValue(1), MLoad1.getValue(1)) }, dl);
}
if (MemOpc == ISD::MSTORE) {
VectorPair Vals = opSplit(cast<MaskedStoreSDNode>(Op)->getValue(), dl, DAG);
SDValue MStore0 = DAG.getMaskedStore(Chain, dl, Vals.first, Base0, Offset,
Masks.first, SingleTy, MOp0,
ISD::UNINDEXED, false, false);
SDValue MStore1 = DAG.getMaskedStore(Chain, dl, Vals.second, Base1, Offset,
Masks.second, SingleTy, MOp1,
ISD::UNINDEXED, false, false);
return DAG.getNode(ISD::TokenFactor, dl, MVT::Other, MStore0, MStore1);
}
std::string Name = "Unexpected operation: " + Op->getOperationName(&DAG);
llvm_unreachable(Name.c_str());
}
SDValue
@@ -1749,6 +1861,8 @@ HexagonTargetLowering::LowerHvxOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::SETCC:
case ISD::INTRINSIC_VOID: return Op;
case ISD::INTRINSIC_WO_CHAIN: return LowerHvxIntrinsic(Op, DAG);
case ISD::MLOAD:
case ISD::MSTORE: return LowerHvxMaskedOp(Op, DAG);
// Unaligned loads will be handled by the default lowering.
case ISD::LOAD: return SDValue();
}
@@ -1761,6 +1875,25 @@ HexagonTargetLowering::LowerHvxOperation(SDValue Op, SelectionDAG &DAG) const {
void
HexagonTargetLowering::LowerHvxOperationWrapper(SDNode *N,
SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
unsigned Opc = N->getOpcode();
SDValue Op(N, 0);
switch (Opc) {
case ISD::MLOAD:
if (isHvxPairTy(ty(Op))) {
SDValue S = SplitHvxMemOp(Op, DAG);
assert(S->getOpcode() == ISD::MERGE_VALUES);
Results.push_back(S.getOperand(0));
Results.push_back(S.getOperand(1));
}
break;
case ISD::MSTORE:
if (isHvxPairTy(ty(Op->getOperand(1)))) { // Stored value
SDValue S = SplitHvxMemOp(Op, DAG);
Results.push_back(S);
}
break;
}
}
void
@@ -1783,6 +1916,8 @@ HexagonTargetLowering::ReplaceHvxNodeResults(SDNode *N,
SDValue
HexagonTargetLowering::PerformHvxDAGCombine(SDNode *N, DAGCombinerInfo &DCI)
const {
if (DCI.isBeforeLegalizeOps())
return SDValue();
const SDLoc &dl(N);
SDValue Op(N, 0);


@@ -2721,6 +2721,8 @@ bool HexagonInstrInfo::isValidOffset(unsigned Opcode, int Offset,
case Hexagon::PS_vloadrw_nt_ai:
case Hexagon::V6_vL32b_ai:
case Hexagon::V6_vS32b_ai:
case Hexagon::V6_vS32b_qpred_ai:
case Hexagon::V6_vS32b_nqpred_ai:
case Hexagon::V6_vL32b_nt_ai:
case Hexagon::V6_vS32b_nt_ai:
case Hexagon::V6_vL32Ub_ai:


@@ -364,6 +364,14 @@ let Predicates = [UseHVX] in {
(V6_vasrw (V6_vaslw HVI32:$Vs, (A2_tfrsi 16)), (A2_tfrsi 16))>;
}
// Take a pair of vectors Vt:Vs and shift them towards LSB by (Rt & HwLen).
def: Pat<(VecI8 (valign HVI8:$Vt, HVI8:$Vs, I32:$Rt)),
(LoVec (V6_valignb HvxVR:$Vt, HvxVR:$Vs, I32:$Rt))>;
def: Pat<(VecI16 (valign HVI16:$Vt, HVI16:$Vs, I32:$Rt)),
(LoVec (V6_valignb HvxVR:$Vt, HvxVR:$Vs, I32:$Rt))>;
def: Pat<(VecI32 (valign HVI32:$Vt, HVI32:$Vs, I32:$Rt)),
(LoVec (V6_valignb HvxVR:$Vt, HvxVR:$Vs, I32:$Rt))>;
def: Pat<(HexagonVASL HVI8:$Vs, I32:$Rt),
(V6_vpackeb (V6_vaslh (HiVec (VZxtb HvxVR:$Vs)), I32:$Rt),
(V6_vaslh (LoVec (VZxtb HvxVR:$Vs)), I32:$Rt))>;


@@ -35,6 +35,9 @@ static cl::opt<bool> EmitLookupTables("hexagon-emit-lookup-tables",
cl::init(true), cl::Hidden,
cl::desc("Control lookup table emission on Hexagon target"));
static cl::opt<bool> HexagonMaskedVMem("hexagon-masked-vmem", cl::init(true),
cl::Hidden, cl::desc("Enable masked loads/stores for HVX"));
// Constant "cost factor" to make floating point operations more expensive
// in terms of vectorization cost. This isn't the best way, but it should
// do. Ultimately, the cost should use cycles.
@@ -45,8 +48,7 @@ bool HexagonTTIImpl::useHVX() const {
}
bool HexagonTTIImpl::isTypeForHVX(Type *VecTy) const {
assert(VecTy->isVectorTy());
if (isa<ScalableVectorType>(VecTy))
if (!VecTy->isVectorTy() || isa<ScalableVectorType>(VecTy))
return false;
// Avoid types like <2 x i32*>.
if (!cast<VectorType>(VecTy)->getElementType()->isIntegerTy())
@@ -308,6 +310,14 @@ unsigned HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
return 1;
}
bool HexagonTTIImpl::isLegalMaskedStore(Type *DataType, Align /*Alignment*/) {
return HexagonMaskedVMem && isTypeForHVX(DataType);
}
bool HexagonTTIImpl::isLegalMaskedLoad(Type *DataType, Align /*Alignment*/) {
return HexagonMaskedVMem && isTypeForHVX(DataType);
}
/// --- Vector TTI end ---
unsigned HexagonTTIImpl::getPrefetchDistance() const {


@@ -155,6 +155,9 @@ public:
return 1;
}
bool isLegalMaskedStore(Type *DataType, Align Alignment);
bool isLegalMaskedLoad(Type *DataType, Align Alignment);
/// @}
int getUserCost(const User *U, ArrayRef<const Value *> Operands,


@@ -0,0 +1,35 @@
; RUN: llc -march=hexagon < %s | FileCheck %s
; CHECK-LABEL: f0:
; CHECK: vmemu
; CHECK: vmux
define <128 x i8> @f0(<128 x i8>* %a0, i32 %a1, i32 %a2) #0 {
%q0 = call <128 x i1> @llvm.hexagon.V6.pred.scalar2.128B(i32 %a2)
%v0 = call <32 x i32> @llvm.hexagon.V6.lvsplatb.128B(i32 %a1)
%v1 = bitcast <32 x i32> %v0 to <128 x i8>
%v2 = call <128 x i8> @llvm.masked.load.v128i8.p0v128i8(<128 x i8>* %a0, i32 4, <128 x i1> %q0, <128 x i8> %v1)
ret <128 x i8> %v2
}
; CHECK-LABEL: f1:
; CHECK: vlalign
; CHECK: if (q{{.}}) vmem{{.*}} = v
define void @f1(<128 x i8>* %a0, i32 %a1, i32 %a2) #0 {
%q0 = call <128 x i1> @llvm.hexagon.V6.pred.scalar2.128B(i32 %a2)
%v0 = call <32 x i32> @llvm.hexagon.V6.lvsplatb.128B(i32 %a1)
%v1 = bitcast <32 x i32> %v0 to <128 x i8>
call void @llvm.masked.store.v128i8.p0v128i8(<128 x i8> %v1, <128 x i8>* %a0, i32 4, <128 x i1> %q0)
ret void
}
declare <128 x i1> @llvm.hexagon.V6.pred.scalar2.128B(i32) #1
declare <32 x i32> @llvm.hexagon.V6.lvsplatb.128B(i32) #1
declare <128 x i8> @llvm.masked.load.v128i8.p0v128i8(<128 x i8>*, i32 immarg, <128 x i1>, <128 x i8>) #2
declare void @llvm.masked.store.v128i8.p0v128i8(<128 x i8>, <128 x i8>*, i32 immarg, <128 x i1>) #3
attributes #0 = { "target-cpu"="hexagonv65" "target-features"="+hvx,+hvx-length128b" }
attributes #1 = { nounwind readnone }
attributes #2 = { argmemonly nounwind readonly willreturn }
attributes #3 = { argmemonly nounwind willreturn }


@@ -1,4 +1,4 @@
; RUN: llc -march=hexagon -hexagon-instsimplify=0 < %s | FileCheck %s
; RUN: llc -march=hexagon -hexagon-instsimplify=0 -hexagon-masked-vmem=0 < %s | FileCheck %s
; Test that LLVM does not assert and that bitcasting v64i1 to i64 is lowered
; without crashing.


@@ -1,4 +1,4 @@
; RUN: llc -march=hexagon -hexagon-instsimplify=0 < %s | FileCheck %s
; RUN: llc -march=hexagon -hexagon-instsimplify=0 -hexagon-masked-vmem=0 < %s | FileCheck %s
; This test checks that a store of a vector predicate of type v128i1 is lowered
; without crashing.