forked from OSchip/llvm-project
[InstCombine] Fold for masked scatters to a uniform address
When a masked scatter intrinsic performs a uniform store — writing a source vector through a splat destination address — and the mask is all ones, this patch replaces the masked scatter with an extract of the last lane of the source vector followed by a scalar store to the destination pointer. The patch also folds the case where the scattered value is a splat: the mask must then be known non-zero, and the scatter folds to a scalar store of the splat value through the destination pointer. Differential Revision: https://reviews.llvm.org/D115724
This commit is contained in:
parent
20d9c51dc0
commit
8e5a5b619d
|
@ -362,7 +362,6 @@ Instruction *InstCombinerImpl::simplifyMaskedGather(IntrinsicInst &II) {
|
|||
// * Single constant active lane -> store
|
||||
// * Adjacent vector addresses -> masked.store
|
||||
// * Narrow store width by halfs excluding zero/undef lanes
|
||||
// * Vector splat address w/known mask -> scalar store
|
||||
// * Vector incrementing address -> vector masked store
|
||||
Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
|
||||
auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
|
||||
|
@ -373,6 +372,34 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
|
|||
if (ConstMask->isNullValue())
|
||||
return eraseInstFromFunction(II);
|
||||
|
||||
// Vector splat address -> scalar store
|
||||
if (auto *SplatPtr = getSplatValue(II.getArgOperand(1))) {
|
||||
// scatter(splat(value), splat(ptr), non-zero-mask) -> store value, ptr
|
||||
if (auto *SplatValue = getSplatValue(II.getArgOperand(0))) {
|
||||
Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
|
||||
StoreInst *S =
|
||||
new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, Alignment);
|
||||
S->copyMetadata(II);
|
||||
return S;
|
||||
}
|
||||
// scatter(vector, splat(ptr), splat(true)) -> store extract(vector,
|
||||
// lastlane), ptr
|
||||
if (ConstMask->isAllOnesValue()) {
|
||||
Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
|
||||
VectorType *WideLoadTy = cast<VectorType>(II.getArgOperand(1)->getType());
|
||||
ElementCount VF = WideLoadTy->getElementCount();
|
||||
Constant *EC =
|
||||
ConstantInt::get(Builder.getInt32Ty(), VF.getKnownMinValue());
|
||||
Value *RunTimeVF = VF.isScalable() ? Builder.CreateVScale(EC) : EC;
|
||||
Value *LastLane = Builder.CreateSub(RunTimeVF, Builder.getInt32(1));
|
||||
Value *Extract =
|
||||
Builder.CreateExtractElement(II.getArgOperand(0), LastLane);
|
||||
StoreInst *S =
|
||||
new StoreInst(Extract, SplatPtr, /*IsVolatile=*/false, Alignment);
|
||||
S->copyMetadata(II);
|
||||
return S;
|
||||
}
|
||||
}
|
||||
if (isa<ScalableVectorType>(ConstMask->getType()))
|
||||
return nullptr;
|
||||
|
||||
|
|
|
@ -269,3 +269,110 @@ define void @scatter_demandedelts(double* %ptr, double %val) {
|
|||
call void @llvm.masked.scatter.v2f64.v2p0f64(<2 x double> %valvec2, <2 x double*> %ptrs, i32 8, <2 x i1> <i1 true, i1 false>)
|
||||
ret void
|
||||
}
|
||||
|
||||
|
||||
; Test scatters that can be simplified to scalar stores.
|
||||
|
||||
;; Value splat (mask is not used)
|
||||
;; Both the stored value and the pointer are splats; the mask is not all
;; active but is a non-zero constant (<0,0,1,1>), so the scatter folds to a
;; single scalar store of %val through %dst (the mask contents beyond
;; "not all zero" are irrelevant when the value is uniform).
define void @scatter_v4i16_uniform_vals_uniform_ptrs_no_all_active_mask(i16* %dst, i16 %val) {
; CHECK-LABEL: @scatter_v4i16_uniform_vals_uniform_ptrs_no_all_active_mask(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    store i16 [[VAL:%.*]], i16* [[DST:%.*]], align 2
; CHECK-NEXT:    ret void
;
entry:
  ; Broadcast the destination pointer to all 4 lanes.
  %broadcast.splatinsert = insertelement <4 x i16*> poison, i16* %dst, i32 0
  %broadcast.splat = shufflevector <4 x i16*> %broadcast.splatinsert, <4 x i16*> poison, <4 x i32> zeroinitializer
  ; Broadcast the stored value to all 4 lanes.
  %broadcast.value = insertelement <4 x i16> poison, i16 %val, i32 0
  %broadcast.splatvalue = shufflevector <4 x i16> %broadcast.value, <4 x i16> poison, <4 x i32> zeroinitializer
  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %broadcast.splatvalue, <4 x i16*> %broadcast.splat, i32 2, <4 x i1> <i1 0, i1 0, i1 1, i1 1>)
  ret void
}
|
||||
|
||||
;; Scalable-vector variant: value and pointer are both splats and the mask
;; is a splat of true (all active), so the scatter folds to one scalar
;; store of %val through %dst.
define void @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(i16* %dst, i16 %val) {
; CHECK-LABEL: @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    store i16 [[VAL:%.*]], i16* [[DST:%.*]], align 2
; CHECK-NEXT:    ret void
;
entry:
  ; Broadcast the destination pointer to every (scalable) lane.
  %broadcast.splatinsert = insertelement <vscale x 4 x i16*> poison, i16* %dst, i32 0
  %broadcast.splat = shufflevector <vscale x 4 x i16*> %broadcast.splatinsert, <vscale x 4 x i16*> poison, <vscale x 4 x i32> zeroinitializer
  ; Broadcast the stored value to every lane.
  %broadcast.value = insertelement <vscale x 4 x i16> poison, i16 %val, i32 0
  %broadcast.splatvalue = shufflevector <vscale x 4 x i16> %broadcast.value, <vscale x 4 x i16> poison, <vscale x 4 x i32> zeroinitializer
  ; The mask constant below is the canonical scalable splat of i1 true.
  call void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16> %broadcast.splatvalue, <vscale x 4 x i16*> %broadcast.splat, i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> zeroinitializer , i1 true, i32 0), <vscale x 4 x i1> zeroinitializer, <vscale x 4 x i32> zeroinitializer))
  ret void
}
|
||||
|
||||
;; The pointer is splat and mask is all active, but value is not a splat
|
||||
;; Pointer is a splat and the mask is all active, but the value is not a
;; splat: every lane stores to the same address, so only the last lane's
;; value survives. Folds to extractelement of lane 3 plus a scalar store.
define void @scatter_v4i16_no_uniform_vals_uniform_ptrs_all_active_mask(i16* %dst, <4 x i16>* %src) {
; CHECK-LABEL: @scatter_v4i16_no_uniform_vals_uniform_ptrs_all_active_mask(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i16>, <4 x i16>* [[SRC:%.*]], align 2
; CHECK-NEXT:    [[TMP0:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i64 3
; CHECK-NEXT:    store i16 [[TMP0]], i16* [[DST:%.*]], align 2
; CHECK-NEXT:    ret void
;
entry:
  ; Broadcast the destination pointer to all 4 lanes.
  %broadcast.splatinsert = insertelement <4 x i16*> poison, i16* %dst, i32 0
  %broadcast.splat = shufflevector <4 x i16*> %broadcast.splatinsert, <4 x i16*> poison, <4 x i32> zeroinitializer
  ; Non-uniform source vector loaded from memory.
  %wide.load = load <4 x i16>, <4 x i16>* %src, align 2
  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %wide.load, <4 x i16*> %broadcast.splat, i32 2, <4 x i1> <i1 1, i1 1, i1 1, i1 1>)
  ret void
}
|
||||
|
||||
;; Scalable-vector variant of the extract-last-lane fold: the last lane
;; index is not a compile-time constant, so it is computed at run time as
;; vscale * 4 - 1 (the CHECK lines show llvm.vscale, shl by 2, add -1).
define void @scatter_nxv4i16_no_uniform_vals_uniform_ptrs_all_active_mask(i16* %dst, <vscale x 4 x i16>* %src) {
; CHECK-LABEL: @scatter_nxv4i16_no_uniform_vals_uniform_ptrs_all_active_mask(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x i16>, <vscale x 4 x i16>* [[SRC:%.*]], align 2
; CHECK-NEXT:    [[TMP0:%.*]] = call i32 @llvm.vscale.i32()
; CHECK-NEXT:    [[TMP1:%.*]] = shl i32 [[TMP0]], 2
; CHECK-NEXT:    [[TMP2:%.*]] = add i32 [[TMP1]], -1
; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <vscale x 4 x i16> [[WIDE_LOAD]], i32 [[TMP2]]
; CHECK-NEXT:    store i16 [[TMP3]], i16* [[DST:%.*]], align 2
; CHECK-NEXT:    ret void
;
entry:
  ; Broadcast the destination pointer to every (scalable) lane.
  %broadcast.splatinsert = insertelement <vscale x 4 x i16*> poison, i16* %dst, i32 0
  %broadcast.splat = shufflevector <vscale x 4 x i16*> %broadcast.splatinsert, <vscale x 4 x i16*> poison, <vscale x 4 x i32> zeroinitializer
  ; Non-uniform source vector loaded from memory.
  %wide.load = load <vscale x 4 x i16>, <vscale x 4 x i16>* %src, align 2
  ; Mask is the canonical scalable splat of i1 true (all active).
  call void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16> %wide.load, <vscale x 4 x i16*> %broadcast.splat, i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i32 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
  ret void
}
|
||||
|
||||
; Negative scatter tests
|
||||
|
||||
;; Pointer is splat, but mask is not all active and value is not a splat
|
||||
;; Negative test: the pointer is a splat but the mask is not all active and
;; the value is not a splat, so neither fold applies and the scatter must
;; remain. The CHECK lines show demanded-elements simplification narrowing
;; the pointer splat's shuffle mask to the two active lanes.
define void @negative_scatter_v4i16_no_uniform_vals_uniform_ptrs_all_inactive_mask(i16* %dst, <4 x i16>* %src) {
; CHECK-LABEL: @negative_scatter_v4i16_no_uniform_vals_uniform_ptrs_all_inactive_mask(
; CHECK-NEXT:    [[INSERT_ELT:%.*]] = insertelement <4 x i16*> poison, i16* [[DST:%.*]], i64 0
; CHECK-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <4 x i16*> [[INSERT_ELT]], <4 x i16*> poison, <4 x i32> <i32 undef, i32 undef, i32 0, i32 0>
; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i16>, <4 x i16>* [[SRC:%.*]], align 2
; CHECK-NEXT:    call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> [[WIDE_LOAD]], <4 x i16*> [[BROADCAST_SPLAT]], i32 2, <4 x i1> <i1 false, i1 false, i1 true, i1 true>)
; CHECK-NEXT:    ret void
;
  %insert.elt = insertelement <4 x i16*> poison, i16* %dst, i32 0
  %broadcast.splat = shufflevector <4 x i16*> %insert.elt, <4 x i16*> poison, <4 x i32> zeroinitializer
  %wide.load = load <4 x i16>, <4 x i16>* %src, align 2
  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %wide.load, <4 x i16*> %broadcast.splat, i32 2, <4 x i1> <i1 0, i1 0, i1 1, i1 1>)
  ret void
}
|
||||
|
||||
;; The pointer is NOT a splat
|
||||
;; Negative test: the pointer vector comes straight from a function
;; argument. NOTE(review): the shuffle does broadcast %inPtr's lane 0, but
;; there is no insertelement of a scalar to recover, so presumably the
;; splat-pointer match fails and the scatter is left intact (the CHECK
;; lines confirm it survives) — verify against getSplatValue's matching.
define void @negative_scatter_v4i16_no_uniform_vals_no_uniform_ptrs_all_active_mask(<4 x i16*> %inPtr, <4 x i16>* %src) {
; CHECK-LABEL: @negative_scatter_v4i16_no_uniform_vals_no_uniform_ptrs_all_active_mask(
; CHECK-NEXT:    [[BROADCAST:%.*]] = shufflevector <4 x i16*> [[INPTR:%.*]], <4 x i16*> poison, <4 x i32> zeroinitializer
; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i16>, <4 x i16>* [[SRC:%.*]], align 2
; CHECK-NEXT:    call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> [[WIDE_LOAD]], <4 x i16*> [[BROADCAST]], i32 2, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
; CHECK-NEXT:    ret void
;
  %broadcast= shufflevector <4 x i16*> %inPtr, <4 x i16*> poison, <4 x i32> zeroinitializer
  %wide.load = load <4 x i16>, <4 x i16>* %src, align 2
  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %wide.load, <4 x i16*> %broadcast, i32 2, <4 x i1> <i1 1, i1 1, i1 1, i1 1> )
  ret void
}
|
||||
|
||||
|
||||
; Function Attrs:
|
||||
declare void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16>, <4 x i16*>, i32 immarg, <4 x i1>)
|
||||
declare void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16>, <vscale x 4 x i16*>, i32 immarg, <vscale x 4 x i1>)
|
||||
|
|
Loading…
Reference in New Issue