[SVE][CodeGen] Fix scalable vector issues in DAGTypeLegalizer::GenWidenVectorStores

In DAGTypeLegalizer::GenWidenVectorStores the algorithm assumes it only
ever deals with fixed width types, hence the offsets for each individual
store never take 'vscale' into account. I've changed the main loop in
that function to use TypeSize instead of unsigned for tracking the
remaining store amount and offset increment. In addition, I've changed
the loop to use the new IncrementPointer helper function for updating
the addresses in each iteration, since this handles scalable vector
types.

Whilst fixing this function I also fixed a minor issue in
IncrementPointer whereby we were not adding the no-unsigned-wrap flag
for the add instruction in the same way as the fixed width case does.

Also, I've added a report_fatal_error in GenWidenVectorTruncStores,
since this code currently uses a sequence of element-by-element scalar
stores.

I've added new tests in

  CodeGen/AArch64/sve-intrinsics-stores.ll
  CodeGen/AArch64/sve-st1-addressing-mode-reg-imm.ll

for the changes in GenWidenVectorStores.

Differential Revision: https://reviews.llvm.org/D84937
This commit is contained in:
David Sherwood 2020-07-31 13:56:02 +01:00
parent 3ec3fcb97a
commit 6af1677161
5 changed files with 167 additions and 31 deletions

View File

@ -131,6 +131,20 @@ public:
return { MinSize / RHS, IsScalable }; return { MinSize / RHS, IsScalable };
} }
TypeSize &operator-=(TypeSize RHS) {
assert(IsScalable == RHS.IsScalable &&
"Subtraction using mixed scalable and fixed types");
MinSize -= RHS.MinSize;
return *this;
}
TypeSize &operator+=(TypeSize RHS) {
assert(IsScalable == RHS.IsScalable &&
"Addition using mixed scalable and fixed types");
MinSize += RHS.MinSize;
return *this;
}
// Return the minimum size with the assumption that the size is exact. // Return the minimum size with the assumption that the size is exact.
// Use in places where a scalable size doesn't make sense (e.g. non-vector // Use in places where a scalable size doesn't make sense (e.g. non-vector
// types, or vectors in backends which don't support scalable vectors). // types, or vectors in backends which don't support scalable vectors).

View File

@ -780,8 +780,8 @@ private:
// Helper function for incrementing the pointer when splitting // Helper function for incrementing the pointer when splitting
// memory operations // memory operations
void IncrementPointer(MemSDNode *N, EVT MemVT, void IncrementPointer(MemSDNode *N, EVT MemVT, MachinePointerInfo &MPI,
MachinePointerInfo &MPI, SDValue &Ptr); SDValue &Ptr, uint64_t *ScaledOffset = nullptr);
// Vector Result Splitting: <128 x ty> -> 2 x <64 x ty>. // Vector Result Splitting: <128 x ty> -> 2 x <64 x ty>.
void SplitVectorResult(SDNode *N, unsigned ResNo); void SplitVectorResult(SDNode *N, unsigned ResNo);

View File

@ -984,16 +984,20 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
} }
void DAGTypeLegalizer::IncrementPointer(MemSDNode *N, EVT MemVT, void DAGTypeLegalizer::IncrementPointer(MemSDNode *N, EVT MemVT,
MachinePointerInfo &MPI, MachinePointerInfo &MPI, SDValue &Ptr,
SDValue &Ptr) { uint64_t *ScaledOffset) {
SDLoc DL(N); SDLoc DL(N);
unsigned IncrementSize = MemVT.getSizeInBits().getKnownMinSize() / 8; unsigned IncrementSize = MemVT.getSizeInBits().getKnownMinSize() / 8;
if (MemVT.isScalableVector()) { if (MemVT.isScalableVector()) {
SDNodeFlags Flags;
SDValue BytesIncrement = DAG.getVScale( SDValue BytesIncrement = DAG.getVScale(
DL, Ptr.getValueType(), DL, Ptr.getValueType(),
APInt(Ptr.getValueSizeInBits().getFixedSize(), IncrementSize)); APInt(Ptr.getValueSizeInBits().getFixedSize(), IncrementSize));
MPI = MachinePointerInfo(N->getPointerInfo().getAddrSpace()); MPI = MachinePointerInfo(N->getPointerInfo().getAddrSpace());
Flags.setNoUnsignedWrap(true);
if (ScaledOffset)
*ScaledOffset += IncrementSize;
Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, BytesIncrement); Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, BytesIncrement);
} else { } else {
MPI = N->getPointerInfo().getWithOffset(IncrementSize); MPI = N->getPointerInfo().getWithOffset(IncrementSize);
@ -4844,7 +4848,7 @@ static EVT FindMemType(SelectionDAG& DAG, const TargetLowering &TLI,
// If we have one element to load/store, return it. // If we have one element to load/store, return it.
EVT RetVT = WidenEltVT; EVT RetVT = WidenEltVT;
if (Width == WidenEltWidth) if (!Scalable && Width == WidenEltWidth)
return RetVT; return RetVT;
// See if there is larger legal integer than the element type to load/store. // See if there is larger legal integer than the element type to load/store.
@ -5139,55 +5143,62 @@ void DAGTypeLegalizer::GenWidenVectorStores(SmallVectorImpl<SDValue> &StChain,
SDLoc dl(ST); SDLoc dl(ST);
EVT StVT = ST->getMemoryVT(); EVT StVT = ST->getMemoryVT();
unsigned StWidth = StVT.getSizeInBits(); TypeSize StWidth = StVT.getSizeInBits();
EVT ValVT = ValOp.getValueType(); EVT ValVT = ValOp.getValueType();
unsigned ValWidth = ValVT.getSizeInBits(); TypeSize ValWidth = ValVT.getSizeInBits();
EVT ValEltVT = ValVT.getVectorElementType(); EVT ValEltVT = ValVT.getVectorElementType();
unsigned ValEltWidth = ValEltVT.getSizeInBits(); unsigned ValEltWidth = ValEltVT.getSizeInBits().getFixedSize();
assert(StVT.getVectorElementType() == ValEltVT); assert(StVT.getVectorElementType() == ValEltVT);
assert(StVT.isScalableVector() == ValVT.isScalableVector() &&
"Mismatch between store and value types");
int Idx = 0; // current index to store int Idx = 0; // current index to store
unsigned Offset = 0; // offset from base to store
while (StWidth != 0) { MachinePointerInfo MPI = ST->getPointerInfo();
uint64_t ScaledOffset = 0;
while (StWidth.isNonZero()) {
// Find the largest vector type we can store with. // Find the largest vector type we can store with.
EVT NewVT = FindMemType(DAG, TLI, StWidth, ValVT); EVT NewVT = FindMemType(DAG, TLI, StWidth.getKnownMinSize(), ValVT);
unsigned NewVTWidth = NewVT.getSizeInBits(); TypeSize NewVTWidth = NewVT.getSizeInBits();
unsigned Increment = NewVTWidth / 8;
if (NewVT.isVector()) { if (NewVT.isVector()) {
unsigned NumVTElts = NewVT.getVectorNumElements(); unsigned NumVTElts = NewVT.getVectorMinNumElements();
do { do {
Align NewAlign = ScaledOffset == 0
? ST->getOriginalAlign()
: commonAlignment(ST->getAlign(), ScaledOffset);
SDValue EOp = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, NewVT, ValOp, SDValue EOp = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, NewVT, ValOp,
DAG.getVectorIdxConstant(Idx, dl)); DAG.getVectorIdxConstant(Idx, dl));
StChain.push_back(DAG.getStore( SDValue PartStore = DAG.getStore(Chain, dl, EOp, BasePtr, MPI, NewAlign,
Chain, dl, EOp, BasePtr, ST->getPointerInfo().getWithOffset(Offset), MMOFlags, AAInfo);
ST->getOriginalAlign(), MMOFlags, AAInfo)); StChain.push_back(PartStore);
StWidth -= NewVTWidth; StWidth -= NewVTWidth;
Offset += Increment;
Idx += NumVTElts; Idx += NumVTElts;
BasePtr = IncrementPointer(cast<StoreSDNode>(PartStore), NewVT, MPI, BasePtr,
DAG.getObjectPtrOffset(dl, BasePtr, TypeSize::Fixed(Increment)); &ScaledOffset);
} while (StWidth != 0 && StWidth >= NewVTWidth); } while (StWidth.isNonZero() && StWidth >= NewVTWidth);
} else { } else {
// Cast the vector to the scalar type we can store. // Cast the vector to the scalar type we can store.
unsigned NumElts = ValWidth / NewVTWidth; unsigned NumElts = ValWidth.getFixedSize() / NewVTWidth.getFixedSize();
EVT NewVecVT = EVT::getVectorVT(*DAG.getContext(), NewVT, NumElts); EVT NewVecVT = EVT::getVectorVT(*DAG.getContext(), NewVT, NumElts);
SDValue VecOp = DAG.getNode(ISD::BITCAST, dl, NewVecVT, ValOp); SDValue VecOp = DAG.getNode(ISD::BITCAST, dl, NewVecVT, ValOp);
// Readjust index position based on new vector type. // Readjust index position based on new vector type.
Idx = Idx * ValEltWidth / NewVTWidth; Idx = Idx * ValEltWidth / NewVTWidth.getFixedSize();
do { do {
SDValue EOp = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, NewVT, VecOp, SDValue EOp = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, NewVT, VecOp,
DAG.getVectorIdxConstant(Idx++, dl)); DAG.getVectorIdxConstant(Idx++, dl));
StChain.push_back(DAG.getStore( SDValue PartStore =
Chain, dl, EOp, BasePtr, ST->getPointerInfo().getWithOffset(Offset), DAG.getStore(Chain, dl, EOp, BasePtr, MPI, ST->getOriginalAlign(),
ST->getOriginalAlign(), MMOFlags, AAInfo)); MMOFlags, AAInfo);
StChain.push_back(PartStore);
StWidth -= NewVTWidth; StWidth -= NewVTWidth;
Offset += Increment; IncrementPointer(cast<StoreSDNode>(PartStore), NewVT, MPI, BasePtr);
BasePtr = } while (StWidth.isNonZero() && StWidth >= NewVTWidth);
DAG.getObjectPtrOffset(dl, BasePtr, TypeSize::Fixed(Increment));
} while (StWidth != 0 && StWidth >= NewVTWidth);
// Restore index back to be relative to the original widen element type. // Restore index back to be relative to the original widen element type.
Idx = Idx * NewVTWidth / ValEltWidth; Idx = Idx * NewVTWidth.getFixedSize() / ValEltWidth;
} }
} }
} }
@ -5210,8 +5221,13 @@ DAGTypeLegalizer::GenWidenVectorTruncStores(SmallVectorImpl<SDValue> &StChain,
// It must be true that the wide vector type is bigger than where we need to // It must be true that the wide vector type is bigger than where we need to
// store. // store.
assert(StVT.isVector() && ValOp.getValueType().isVector()); assert(StVT.isVector() && ValOp.getValueType().isVector());
assert(StVT.isScalableVector() == ValOp.getValueType().isScalableVector());
assert(StVT.bitsLT(ValOp.getValueType())); assert(StVT.bitsLT(ValOp.getValueType()));
if (StVT.isScalableVector())
report_fatal_error("Generating widen scalable vector truncating stores not "
"yet supported");
// For truncating stores, we can not play the tricks of chopping legal vector // For truncating stores, we can not play the tricks of chopping legal vector
// types and bitcast it to the right type. Instead, we unroll the store. // types and bitcast it to the right type. Instead, we unroll the store.
EVT StEltVT = StVT.getVectorElementType(); EVT StEltVT = StVT.getVectorElementType();

View File

@ -437,6 +437,69 @@ define void @stnt1d_f64(<vscale x 2 x double> %data, <vscale x 2 x i1> %pred, do
} }
; Stores (tuples)
define void @store_i64_tuple3(<vscale x 6 x i64>* %out, <vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3) {
; CHECK-LABEL: store_i64_tuple3
; CHECK: st1d { z2.d }, p0, [x0, #2, mul vl]
; CHECK-NEXT: st1d { z1.d }, p0, [x0, #1, mul vl]
; CHECK-NEXT: st1d { z0.d }, p0, [x0]
%tuple = tail call <vscale x 6 x i64> @llvm.aarch64.sve.tuple.create3.nxv6i64.nxv2i64(<vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3)
store <vscale x 6 x i64> %tuple, <vscale x 6 x i64>* %out
ret void
}
define void @store_i64_tuple4(<vscale x 8 x i64>* %out, <vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3, <vscale x 2 x i64> %in4) {
; CHECK-LABEL: store_i64_tuple4
; CHECK: st1d { z3.d }, p0, [x0, #3, mul vl]
; CHECK-NEXT: st1d { z2.d }, p0, [x0, #2, mul vl]
; CHECK-NEXT: st1d { z1.d }, p0, [x0, #1, mul vl]
; CHECK-NEXT: st1d { z0.d }, p0, [x0]
%tuple = tail call <vscale x 8 x i64> @llvm.aarch64.sve.tuple.create4.nxv8i64.nxv2i64(<vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3, <vscale x 2 x i64> %in4)
store <vscale x 8 x i64> %tuple, <vscale x 8 x i64>* %out
ret void
}
define void @store_i16_tuple2(<vscale x 16 x i16>* %out, <vscale x 8 x i16> %in1, <vscale x 8 x i16> %in2) {
; CHECK-LABEL: store_i16_tuple2
; CHECK: st1h { z1.h }, p0, [x0, #1, mul vl]
; CHECK-NEXT: st1h { z0.h }, p0, [x0]
%tuple = tail call <vscale x 16 x i16> @llvm.aarch64.sve.tuple.create2.nxv16i16.nxv8i16(<vscale x 8 x i16> %in1, <vscale x 8 x i16> %in2)
store <vscale x 16 x i16> %tuple, <vscale x 16 x i16>* %out
ret void
}
define void @store_i16_tuple3(<vscale x 24 x i16>* %out, <vscale x 8 x i16> %in1, <vscale x 8 x i16> %in2, <vscale x 8 x i16> %in3) {
; CHECK-LABEL: store_i16_tuple3
; CHECK: st1h { z2.h }, p0, [x0, #2, mul vl]
; CHECK-NEXT: st1h { z1.h }, p0, [x0, #1, mul vl]
; CHECK-NEXT: st1h { z0.h }, p0, [x0]
%tuple = tail call <vscale x 24 x i16> @llvm.aarch64.sve.tuple.create3.nxv24i16.nxv8i16(<vscale x 8 x i16> %in1, <vscale x 8 x i16> %in2, <vscale x 8 x i16> %in3)
store <vscale x 24 x i16> %tuple, <vscale x 24 x i16>* %out
ret void
}
define void @store_f32_tuple3(<vscale x 12 x float>* %out, <vscale x 4 x float> %in1, <vscale x 4 x float> %in2, <vscale x 4 x float> %in3) {
; CHECK-LABEL: store_f32_tuple3
; CHECK: st1w { z2.s }, p0, [x0, #2, mul vl]
; CHECK-NEXT: st1w { z1.s }, p0, [x0, #1, mul vl]
; CHECK-NEXT: st1w { z0.s }, p0, [x0]
%tuple = tail call <vscale x 12 x float> @llvm.aarch64.sve.tuple.create3.nxv12f32.nxv4f32(<vscale x 4 x float> %in1, <vscale x 4 x float> %in2, <vscale x 4 x float> %in3)
store <vscale x 12 x float> %tuple, <vscale x 12 x float>* %out
ret void
}
define void @store_f32_tuple4(<vscale x 16 x float>* %out, <vscale x 4 x float> %in1, <vscale x 4 x float> %in2, <vscale x 4 x float> %in3, <vscale x 4 x float> %in4) {
; CHECK-LABEL: store_f32_tuple4
; CHECK: st1w { z3.s }, p0, [x0, #3, mul vl]
; CHECK-NEXT: st1w { z2.s }, p0, [x0, #2, mul vl]
; CHECK-NEXT: st1w { z1.s }, p0, [x0, #1, mul vl]
; CHECK-NEXT: st1w { z0.s }, p0, [x0]
%tuple = tail call <vscale x 16 x float> @llvm.aarch64.sve.tuple.create4.nxv16f32.nxv4f32(<vscale x 4 x float> %in1, <vscale x 4 x float> %in2, <vscale x 4 x float> %in3, <vscale x 4 x float> %in4)
store <vscale x 16 x float> %tuple, <vscale x 16 x float>* %out
ret void
}
declare void @llvm.aarch64.sve.st2.nxv16i8(<vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i1>, i8*) declare void @llvm.aarch64.sve.st2.nxv16i8(<vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i1>, i8*)
declare void @llvm.aarch64.sve.st2.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i1>, i16*) declare void @llvm.aarch64.sve.st2.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i1>, i16*)
declare void @llvm.aarch64.sve.st2.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32*) declare void @llvm.aarch64.sve.st2.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32*)
@ -473,5 +536,14 @@ declare void @llvm.aarch64.sve.stnt1.nxv8bf16(<vscale x 8 x bfloat>, <vscale x 8
declare void @llvm.aarch64.sve.stnt1.nxv4f32(<vscale x 4 x float>, <vscale x 4 x i1>, float*) declare void @llvm.aarch64.sve.stnt1.nxv4f32(<vscale x 4 x float>, <vscale x 4 x i1>, float*)
declare void @llvm.aarch64.sve.stnt1.nxv2f64(<vscale x 2 x double>, <vscale x 2 x i1>, double*) declare void @llvm.aarch64.sve.stnt1.nxv2f64(<vscale x 2 x double>, <vscale x 2 x i1>, double*)
declare <vscale x 6 x i64> @llvm.aarch64.sve.tuple.create3.nxv6i64.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64>)
declare <vscale x 8 x i64> @llvm.aarch64.sve.tuple.create4.nxv8i64.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64>)
declare <vscale x 16 x i16> @llvm.aarch64.sve.tuple.create2.nxv16i16.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16>)
declare <vscale x 24 x i16> @llvm.aarch64.sve.tuple.create3.nxv24i16.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>)
declare <vscale x 12 x float> @llvm.aarch64.sve.tuple.create3.nxv12f32.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>)
declare <vscale x 16 x float> @llvm.aarch64.sve.tuple.create4.nxv16f32.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>)
; +bf16 is required for the bfloat version. ; +bf16 is required for the bfloat version.
attributes #0 = { "target-features"="+sve,+bf16" } attributes #0 = { "target-features"="+sve,+bf16" }

View File

@ -133,3 +133,37 @@ define void @store_nxv4f16(<vscale x 4 x half>* %out) {
store <vscale x 4 x half> %splat, <vscale x 4 x half>* %out store <vscale x 4 x half> %splat, <vscale x 4 x half>* %out
ret void ret void
} }
; Splat stores of unusual FP scalable vector types
define void @store_nxv6f32(<vscale x 6 x float>* %out) {
; CHECK-LABEL: store_nxv6f32:
; CHECK: // %bb.0:
; CHECK-NEXT: fmov z0.s, #1.00000000
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: st1w { z0.s }, p0, [x0]
; CHECK-NEXT: uunpklo z0.d, z0.s
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: st1w { z0.d }, p0, [x0, #2, mul vl]
; CHECK-NEXT: ret
%ins = insertelement <vscale x 6 x float> undef, float 1.0, i32 0
%splat = shufflevector <vscale x 6 x float> %ins, <vscale x 6 x float> undef, <vscale x 6 x i32> zeroinitializer
store <vscale x 6 x float> %splat, <vscale x 6 x float>* %out
ret void
}
define void @store_nxv12f16(<vscale x 12 x half>* %out) {
; CHECK-LABEL: store_nxv12f16:
; CHECK: // %bb.0:
; CHECK-NEXT: fmov z0.h, #1.00000000
; CHECK-NEXT: ptrue p0.h
; CHECK-NEXT: st1h { z0.h }, p0, [x0]
; CHECK-NEXT: uunpklo z0.s, z0.h
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: st1h { z0.s }, p0, [x0, #2, mul vl]
; CHECK-NEXT: ret
%ins = insertelement <vscale x 12 x half> undef, half 1.0, i32 0
%splat = shufflevector <vscale x 12 x half> %ins, <vscale x 12 x half> undef, <vscale x 12 x i32> zeroinitializer
store <vscale x 12 x half> %splat, <vscale x 12 x half>* %out
ret void
}