forked from OSchip/llvm-project
[SVE][CodeGen] Fix scalable vector issues in DAGTypeLegalizer::GenWidenVectorStores
In DAGTypeLegalizer::GenWidenVectorStores the algorithm assumes it only ever deals with fixed width types, hence the offsets for each individual store never take 'vscale' into account. I've changed the main loop in that function to use TypeSize instead of unsigned for tracking the remaining store amount and offset increment. In addition, I've changed the loop to use the new IncrementPointer helper function for updating the addresses in each iteration, since this handles scalable vector types. Whilst fixing this function I also fixed a minor issue in IncrementPointer whereby we were not adding the no-unsigned-wrap flag for the add instruction in the same way as the fixed width case does. Also, I've added a report_fatal_error in GenWidenVectorTruncStores, since this code currently uses a sequence of element-by-element scalar stores. I've added new tests in CodeGen/AArch64/sve-intrinsics-stores.ll CodeGen/AArch64/sve-st1-addressing-mode-reg-imm.ll for the changes in GenWidenVectorStores. Differential Revision: https://reviews.llvm.org/D84937
This commit is contained in:
parent
3ec3fcb97a
commit
6af1677161
|
@ -131,6 +131,20 @@ public:
|
|||
return { MinSize / RHS, IsScalable };
|
||||
}
|
||||
|
||||
TypeSize &operator-=(TypeSize RHS) {
|
||||
assert(IsScalable == RHS.IsScalable &&
|
||||
"Subtraction using mixed scalable and fixed types");
|
||||
MinSize -= RHS.MinSize;
|
||||
return *this;
|
||||
}
|
||||
|
||||
TypeSize &operator+=(TypeSize RHS) {
|
||||
assert(IsScalable == RHS.IsScalable &&
|
||||
"Addition using mixed scalable and fixed types");
|
||||
MinSize += RHS.MinSize;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Return the minimum size with the assumption that the size is exact.
|
||||
// Use in places where a scalable size doesn't make sense (e.g. non-vector
|
||||
// types, or vectors in backends which don't support scalable vectors).
|
||||
|
|
|
@ -780,8 +780,8 @@ private:
|
|||
|
||||
// Helper function for incrementing the pointer when splitting
|
||||
// memory operations
|
||||
void IncrementPointer(MemSDNode *N, EVT MemVT,
|
||||
MachinePointerInfo &MPI, SDValue &Ptr);
|
||||
void IncrementPointer(MemSDNode *N, EVT MemVT, MachinePointerInfo &MPI,
|
||||
SDValue &Ptr, uint64_t *ScaledOffset = nullptr);
|
||||
|
||||
// Vector Result Splitting: <128 x ty> -> 2 x <64 x ty>.
|
||||
void SplitVectorResult(SDNode *N, unsigned ResNo);
|
||||
|
|
|
@ -984,16 +984,20 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
|
|||
}
|
||||
|
||||
void DAGTypeLegalizer::IncrementPointer(MemSDNode *N, EVT MemVT,
|
||||
MachinePointerInfo &MPI,
|
||||
SDValue &Ptr) {
|
||||
MachinePointerInfo &MPI, SDValue &Ptr,
|
||||
uint64_t *ScaledOffset) {
|
||||
SDLoc DL(N);
|
||||
unsigned IncrementSize = MemVT.getSizeInBits().getKnownMinSize() / 8;
|
||||
|
||||
if (MemVT.isScalableVector()) {
|
||||
SDNodeFlags Flags;
|
||||
SDValue BytesIncrement = DAG.getVScale(
|
||||
DL, Ptr.getValueType(),
|
||||
APInt(Ptr.getValueSizeInBits().getFixedSize(), IncrementSize));
|
||||
MPI = MachinePointerInfo(N->getPointerInfo().getAddrSpace());
|
||||
Flags.setNoUnsignedWrap(true);
|
||||
if (ScaledOffset)
|
||||
*ScaledOffset += IncrementSize;
|
||||
Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, BytesIncrement);
|
||||
} else {
|
||||
MPI = N->getPointerInfo().getWithOffset(IncrementSize);
|
||||
|
@ -4844,7 +4848,7 @@ static EVT FindMemType(SelectionDAG& DAG, const TargetLowering &TLI,
|
|||
|
||||
// If we have one element to load/store, return it.
|
||||
EVT RetVT = WidenEltVT;
|
||||
if (Width == WidenEltWidth)
|
||||
if (!Scalable && Width == WidenEltWidth)
|
||||
return RetVT;
|
||||
|
||||
// See if there is larger legal integer than the element type to load/store.
|
||||
|
@ -5139,55 +5143,62 @@ void DAGTypeLegalizer::GenWidenVectorStores(SmallVectorImpl<SDValue> &StChain,
|
|||
SDLoc dl(ST);
|
||||
|
||||
EVT StVT = ST->getMemoryVT();
|
||||
unsigned StWidth = StVT.getSizeInBits();
|
||||
TypeSize StWidth = StVT.getSizeInBits();
|
||||
EVT ValVT = ValOp.getValueType();
|
||||
unsigned ValWidth = ValVT.getSizeInBits();
|
||||
TypeSize ValWidth = ValVT.getSizeInBits();
|
||||
EVT ValEltVT = ValVT.getVectorElementType();
|
||||
unsigned ValEltWidth = ValEltVT.getSizeInBits();
|
||||
unsigned ValEltWidth = ValEltVT.getSizeInBits().getFixedSize();
|
||||
assert(StVT.getVectorElementType() == ValEltVT);
|
||||
assert(StVT.isScalableVector() == ValVT.isScalableVector() &&
|
||||
"Mismatch between store and value types");
|
||||
|
||||
int Idx = 0; // current index to store
|
||||
unsigned Offset = 0; // offset from base to store
|
||||
while (StWidth != 0) {
|
||||
|
||||
MachinePointerInfo MPI = ST->getPointerInfo();
|
||||
uint64_t ScaledOffset = 0;
|
||||
while (StWidth.isNonZero()) {
|
||||
// Find the largest vector type we can store with.
|
||||
EVT NewVT = FindMemType(DAG, TLI, StWidth, ValVT);
|
||||
unsigned NewVTWidth = NewVT.getSizeInBits();
|
||||
unsigned Increment = NewVTWidth / 8;
|
||||
EVT NewVT = FindMemType(DAG, TLI, StWidth.getKnownMinSize(), ValVT);
|
||||
TypeSize NewVTWidth = NewVT.getSizeInBits();
|
||||
|
||||
if (NewVT.isVector()) {
|
||||
unsigned NumVTElts = NewVT.getVectorNumElements();
|
||||
unsigned NumVTElts = NewVT.getVectorMinNumElements();
|
||||
do {
|
||||
Align NewAlign = ScaledOffset == 0
|
||||
? ST->getOriginalAlign()
|
||||
: commonAlignment(ST->getAlign(), ScaledOffset);
|
||||
SDValue EOp = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, NewVT, ValOp,
|
||||
DAG.getVectorIdxConstant(Idx, dl));
|
||||
StChain.push_back(DAG.getStore(
|
||||
Chain, dl, EOp, BasePtr, ST->getPointerInfo().getWithOffset(Offset),
|
||||
ST->getOriginalAlign(), MMOFlags, AAInfo));
|
||||
SDValue PartStore = DAG.getStore(Chain, dl, EOp, BasePtr, MPI, NewAlign,
|
||||
MMOFlags, AAInfo);
|
||||
StChain.push_back(PartStore);
|
||||
|
||||
StWidth -= NewVTWidth;
|
||||
Offset += Increment;
|
||||
Idx += NumVTElts;
|
||||
|
||||
BasePtr =
|
||||
DAG.getObjectPtrOffset(dl, BasePtr, TypeSize::Fixed(Increment));
|
||||
} while (StWidth != 0 && StWidth >= NewVTWidth);
|
||||
IncrementPointer(cast<StoreSDNode>(PartStore), NewVT, MPI, BasePtr,
|
||||
&ScaledOffset);
|
||||
} while (StWidth.isNonZero() && StWidth >= NewVTWidth);
|
||||
} else {
|
||||
// Cast the vector to the scalar type we can store.
|
||||
unsigned NumElts = ValWidth / NewVTWidth;
|
||||
unsigned NumElts = ValWidth.getFixedSize() / NewVTWidth.getFixedSize();
|
||||
EVT NewVecVT = EVT::getVectorVT(*DAG.getContext(), NewVT, NumElts);
|
||||
SDValue VecOp = DAG.getNode(ISD::BITCAST, dl, NewVecVT, ValOp);
|
||||
// Readjust index position based on new vector type.
|
||||
Idx = Idx * ValEltWidth / NewVTWidth;
|
||||
Idx = Idx * ValEltWidth / NewVTWidth.getFixedSize();
|
||||
do {
|
||||
SDValue EOp = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, NewVT, VecOp,
|
||||
DAG.getVectorIdxConstant(Idx++, dl));
|
||||
StChain.push_back(DAG.getStore(
|
||||
Chain, dl, EOp, BasePtr, ST->getPointerInfo().getWithOffset(Offset),
|
||||
ST->getOriginalAlign(), MMOFlags, AAInfo));
|
||||
SDValue PartStore =
|
||||
DAG.getStore(Chain, dl, EOp, BasePtr, MPI, ST->getOriginalAlign(),
|
||||
MMOFlags, AAInfo);
|
||||
StChain.push_back(PartStore);
|
||||
|
||||
StWidth -= NewVTWidth;
|
||||
Offset += Increment;
|
||||
BasePtr =
|
||||
DAG.getObjectPtrOffset(dl, BasePtr, TypeSize::Fixed(Increment));
|
||||
} while (StWidth != 0 && StWidth >= NewVTWidth);
|
||||
IncrementPointer(cast<StoreSDNode>(PartStore), NewVT, MPI, BasePtr);
|
||||
} while (StWidth.isNonZero() && StWidth >= NewVTWidth);
|
||||
// Restore index back to be relative to the original widen element type.
|
||||
Idx = Idx * NewVTWidth / ValEltWidth;
|
||||
Idx = Idx * NewVTWidth.getFixedSize() / ValEltWidth;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -5210,8 +5221,13 @@ DAGTypeLegalizer::GenWidenVectorTruncStores(SmallVectorImpl<SDValue> &StChain,
|
|||
// It must be true that the wide vector type is bigger than where we need to
|
||||
// store.
|
||||
assert(StVT.isVector() && ValOp.getValueType().isVector());
|
||||
assert(StVT.isScalableVector() == ValOp.getValueType().isScalableVector());
|
||||
assert(StVT.bitsLT(ValOp.getValueType()));
|
||||
|
||||
if (StVT.isScalableVector())
|
||||
report_fatal_error("Generating widen scalable vector truncating stores not "
|
||||
"yet supported");
|
||||
|
||||
// For truncating stores, we can not play the tricks of chopping legal vector
|
||||
// types and bitcast it to the right type. Instead, we unroll the store.
|
||||
EVT StEltVT = StVT.getVectorElementType();
|
||||
|
|
|
@ -437,6 +437,69 @@ define void @stnt1d_f64(<vscale x 2 x double> %data, <vscale x 2 x i1> %pred, do
|
|||
}
|
||||
|
||||
|
||||
; Stores (tuples)
|
||||
|
||||
define void @store_i64_tuple3(<vscale x 6 x i64>* %out, <vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3) {
|
||||
; CHECK-LABEL: store_i64_tuple3
|
||||
; CHECK: st1d { z2.d }, p0, [x0, #2, mul vl]
|
||||
; CHECK-NEXT: st1d { z1.d }, p0, [x0, #1, mul vl]
|
||||
; CHECK-NEXT: st1d { z0.d }, p0, [x0]
|
||||
%tuple = tail call <vscale x 6 x i64> @llvm.aarch64.sve.tuple.create3.nxv6i64.nxv2i64(<vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3)
|
||||
store <vscale x 6 x i64> %tuple, <vscale x 6 x i64>* %out
|
||||
ret void
|
||||
}
|
||||
|
||||
define void @store_i64_tuple4(<vscale x 8 x i64>* %out, <vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3, <vscale x 2 x i64> %in4) {
|
||||
; CHECK-LABEL: store_i64_tuple4
|
||||
; CHECK: st1d { z3.d }, p0, [x0, #3, mul vl]
|
||||
; CHECK-NEXT: st1d { z2.d }, p0, [x0, #2, mul vl]
|
||||
; CHECK-NEXT: st1d { z1.d }, p0, [x0, #1, mul vl]
|
||||
; CHECK-NEXT: st1d { z0.d }, p0, [x0]
|
||||
%tuple = tail call <vscale x 8 x i64> @llvm.aarch64.sve.tuple.create4.nxv8i64.nxv2i64(<vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3, <vscale x 2 x i64> %in4)
|
||||
store <vscale x 8 x i64> %tuple, <vscale x 8 x i64>* %out
|
||||
ret void
|
||||
}
|
||||
|
||||
define void @store_i16_tuple2(<vscale x 16 x i16>* %out, <vscale x 8 x i16> %in1, <vscale x 8 x i16> %in2) {
|
||||
; CHECK-LABEL: store_i16_tuple2
|
||||
; CHECK: st1h { z1.h }, p0, [x0, #1, mul vl]
|
||||
; CHECK-NEXT: st1h { z0.h }, p0, [x0]
|
||||
%tuple = tail call <vscale x 16 x i16> @llvm.aarch64.sve.tuple.create2.nxv16i16.nxv8i16(<vscale x 8 x i16> %in1, <vscale x 8 x i16> %in2)
|
||||
store <vscale x 16 x i16> %tuple, <vscale x 16 x i16>* %out
|
||||
ret void
|
||||
}
|
||||
|
||||
define void @store_i16_tuple3(<vscale x 24 x i16>* %out, <vscale x 8 x i16> %in1, <vscale x 8 x i16> %in2, <vscale x 8 x i16> %in3) {
|
||||
; CHECK-LABEL: store_i16_tuple3
|
||||
; CHECK: st1h { z2.h }, p0, [x0, #2, mul vl]
|
||||
; CHECK-NEXT: st1h { z1.h }, p0, [x0, #1, mul vl]
|
||||
; CHECK-NEXT: st1h { z0.h }, p0, [x0]
|
||||
%tuple = tail call <vscale x 24 x i16> @llvm.aarch64.sve.tuple.create3.nxv24i16.nxv8i16(<vscale x 8 x i16> %in1, <vscale x 8 x i16> %in2, <vscale x 8 x i16> %in3)
|
||||
store <vscale x 24 x i16> %tuple, <vscale x 24 x i16>* %out
|
||||
ret void
|
||||
}
|
||||
|
||||
define void @store_f32_tuple3(<vscale x 12 x float>* %out, <vscale x 4 x float> %in1, <vscale x 4 x float> %in2, <vscale x 4 x float> %in3) {
|
||||
; CHECK-LABEL: store_f32_tuple3
|
||||
; CHECK: st1w { z2.s }, p0, [x0, #2, mul vl]
|
||||
; CHECK-NEXT: st1w { z1.s }, p0, [x0, #1, mul vl]
|
||||
; CHECK-NEXT: st1w { z0.s }, p0, [x0]
|
||||
%tuple = tail call <vscale x 12 x float> @llvm.aarch64.sve.tuple.create3.nxv12f32.nxv4f32(<vscale x 4 x float> %in1, <vscale x 4 x float> %in2, <vscale x 4 x float> %in3)
|
||||
store <vscale x 12 x float> %tuple, <vscale x 12 x float>* %out
|
||||
ret void
|
||||
}
|
||||
|
||||
define void @store_f32_tuple4(<vscale x 16 x float>* %out, <vscale x 4 x float> %in1, <vscale x 4 x float> %in2, <vscale x 4 x float> %in3, <vscale x 4 x float> %in4) {
|
||||
; CHECK-LABEL: store_f32_tuple4
|
||||
; CHECK: st1w { z3.s }, p0, [x0, #3, mul vl]
|
||||
; CHECK-NEXT: st1w { z2.s }, p0, [x0, #2, mul vl]
|
||||
; CHECK-NEXT: st1w { z1.s }, p0, [x0, #1, mul vl]
|
||||
; CHECK-NEXT: st1w { z0.s }, p0, [x0]
|
||||
%tuple = tail call <vscale x 16 x float> @llvm.aarch64.sve.tuple.create4.nxv16f32.nxv4f32(<vscale x 4 x float> %in1, <vscale x 4 x float> %in2, <vscale x 4 x float> %in3, <vscale x 4 x float> %in4)
|
||||
store <vscale x 16 x float> %tuple, <vscale x 16 x float>* %out
|
||||
ret void
|
||||
}
|
||||
|
||||
declare void @llvm.aarch64.sve.st2.nxv16i8(<vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i1>, i8*)
|
||||
declare void @llvm.aarch64.sve.st2.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i1>, i16*)
|
||||
declare void @llvm.aarch64.sve.st2.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32*)
|
||||
|
@ -473,5 +536,14 @@ declare void @llvm.aarch64.sve.stnt1.nxv8bf16(<vscale x 8 x bfloat>, <vscale x 8
|
|||
declare void @llvm.aarch64.sve.stnt1.nxv4f32(<vscale x 4 x float>, <vscale x 4 x i1>, float*)
|
||||
declare void @llvm.aarch64.sve.stnt1.nxv2f64(<vscale x 2 x double>, <vscale x 2 x i1>, double*)
|
||||
|
||||
declare <vscale x 6 x i64> @llvm.aarch64.sve.tuple.create3.nxv6i64.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64>)
|
||||
declare <vscale x 8 x i64> @llvm.aarch64.sve.tuple.create4.nxv8i64.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64>)
|
||||
|
||||
declare <vscale x 16 x i16> @llvm.aarch64.sve.tuple.create2.nxv16i16.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16>)
|
||||
declare <vscale x 24 x i16> @llvm.aarch64.sve.tuple.create3.nxv24i16.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>)
|
||||
|
||||
declare <vscale x 12 x float> @llvm.aarch64.sve.tuple.create3.nxv12f32.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>)
|
||||
declare <vscale x 16 x float> @llvm.aarch64.sve.tuple.create4.nxv16f32.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>)
|
||||
|
||||
; +bf16 is required for the bfloat version.
|
||||
attributes #0 = { "target-features"="+sve,+bf16" }
|
||||
|
|
|
@ -133,3 +133,37 @@ define void @store_nxv4f16(<vscale x 4 x half>* %out) {
|
|||
store <vscale x 4 x half> %splat, <vscale x 4 x half>* %out
|
||||
ret void
|
||||
}
|
||||
|
||||
; Splat stores of unusual FP scalable vector types
|
||||
|
||||
define void @store_nxv6f32(<vscale x 6 x float>* %out) {
|
||||
; CHECK-LABEL: store_nxv6f32:
|
||||
; CHECK: // %bb.0:
|
||||
; CHECK-NEXT: fmov z0.s, #1.00000000
|
||||
; CHECK-NEXT: ptrue p0.s
|
||||
; CHECK-NEXT: st1w { z0.s }, p0, [x0]
|
||||
; CHECK-NEXT: uunpklo z0.d, z0.s
|
||||
; CHECK-NEXT: ptrue p0.d
|
||||
; CHECK-NEXT: st1w { z0.d }, p0, [x0, #2, mul vl]
|
||||
; CHECK-NEXT: ret
|
||||
%ins = insertelement <vscale x 6 x float> undef, float 1.0, i32 0
|
||||
%splat = shufflevector <vscale x 6 x float> %ins, <vscale x 6 x float> undef, <vscale x 6 x i32> zeroinitializer
|
||||
store <vscale x 6 x float> %splat, <vscale x 6 x float>* %out
|
||||
ret void
|
||||
}
|
||||
|
||||
define void @store_nxv12f16(<vscale x 12 x half>* %out) {
|
||||
; CHECK-LABEL: store_nxv12f16:
|
||||
; CHECK: // %bb.0:
|
||||
; CHECK-NEXT: fmov z0.h, #1.00000000
|
||||
; CHECK-NEXT: ptrue p0.h
|
||||
; CHECK-NEXT: st1h { z0.h }, p0, [x0]
|
||||
; CHECK-NEXT: uunpklo z0.s, z0.h
|
||||
; CHECK-NEXT: ptrue p0.s
|
||||
; CHECK-NEXT: st1h { z0.s }, p0, [x0, #2, mul vl]
|
||||
; CHECK-NEXT: ret
|
||||
%ins = insertelement <vscale x 12 x half> undef, half 1.0, i32 0
|
||||
%splat = shufflevector <vscale x 12 x half> %ins, <vscale x 12 x half> undef, <vscale x 12 x i32> zeroinitializer
|
||||
store <vscale x 12 x half> %splat, <vscale x 12 x half>* %out
|
||||
ret void
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue