[InstCombine] Fold masked scatters to a uniform address

When a masked scatter does a uniform store, i.e. it stores a source
vector to a single (splat) destination address with an all-active mask,
this patch replaces the masked scatter with an extract of the last lane
of the source vector and a scalar store of that element to the
destination pointer. The last lane is the one to keep because a scatter
writes its lanes in ascending order, so when all addresses alias, the
highest active lane wins.
This patch also folds the case where the stored value in the masked
scatter is a splat. Here the constant mask only needs to be non-zero,
and the scatter folds to a scalar store of the splatted value to the
destination pointer.
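As an illustration (the value names here are placeholders; the shapes
mirror the tests added below), the all-active-mask fold rewrites

  %ptr.ins = insertelement <4 x i16*> poison, i16* %dst, i32 0
  %ptr.splat = shufflevector <4 x i16*> %ptr.ins, <4 x i16*> poison, <4 x i32> zeroinitializer
  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %vals, <4 x i16*> %ptr.splat, i32 2, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)

into

  %last = extractelement <4 x i16> %vals, i64 3
  store i16 %last, i16* %dst, align 2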

Differential Revision: https://reviews.llvm.org/D115724
Caroline Concatto 2021-12-13 12:03:26 +00:00
parent 20d9c51dc0
commit 8e5a5b619d
2 changed files with 135 additions and 1 deletion


@@ -362,7 +362,6 @@ Instruction *InstCombinerImpl::simplifyMaskedGather(IntrinsicInst &II) {
// * Single constant active lane -> store
// * Adjacent vector addresses -> masked.store
// * Narrow store width by halfs excluding zero/undef lanes
// * Vector splat address w/known mask -> scalar store
// * Vector incrementing address -> vector masked store
Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
  auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
@@ -373,6 +372,34 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
  if (ConstMask->isNullValue())
    return eraseInstFromFunction(II);

  // Vector splat address -> scalar store
  if (auto *SplatPtr = getSplatValue(II.getArgOperand(1))) {
    // scatter(splat(value), splat(ptr), non-zero-mask) -> store value, ptr
    if (auto *SplatValue = getSplatValue(II.getArgOperand(0))) {
      Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
      StoreInst *S =
          new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, Alignment);
      S->copyMetadata(II);
      return S;
    }
    // scatter(vector, splat(ptr), splat(true)) -> store extract(vector,
    // lastlane), ptr
    if (ConstMask->isAllOnesValue()) {
      Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
      VectorType *WideLoadTy = cast<VectorType>(II.getArgOperand(1)->getType());
      ElementCount VF = WideLoadTy->getElementCount();
      Constant *EC =
          ConstantInt::get(Builder.getInt32Ty(), VF.getKnownMinValue());
      Value *RunTimeVF = VF.isScalable() ? Builder.CreateVScale(EC) : EC;
      Value *LastLane = Builder.CreateSub(RunTimeVF, Builder.getInt32(1));
      Value *Extract =
          Builder.CreateExtractElement(II.getArgOperand(0), LastLane);
      StoreInst *S =
          new StoreInst(Extract, SplatPtr, /*IsVolatile=*/false, Alignment);
      S->copyMetadata(II);
      return S;
    }
  }

  if (isa<ScalableVectorType>(ConstMask->getType()))
    return nullptr;
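For scalable vectors the last lane is not a compile-time constant, so the
fold materialises it at run time as vscale * MinElts - 1. For a
<vscale x 4 x i16> scatter this produces IR along the following lines
(%vals and %dst are placeholders; compare the CHECK lines in the scalable
test below):

  %vscale = call i32 @llvm.vscale.i32()
  %runtime.vf = shl i32 %vscale, 2      ; vscale * 4
  %last.lane = add i32 %runtime.vf, -1  ; index of the last lane
  %elt = extractelement <vscale x 4 x i16> %vals, i32 %last.lane
  store i16 %elt, i16* %dst, align 2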


@@ -269,3 +269,110 @@ define void @scatter_demandedelts(double* %ptr, double %val) {
  call void @llvm.masked.scatter.v2f64.v2p0f64(<2 x double> %valvec2, <2 x double*> %ptrs, i32 8, <2 x i1> <i1 true, i1 false>)
  ret void
}

; Test scatters that can be simplified to scalar stores.
;; Value splat (mask is not used)
define void @scatter_v4i16_uniform_vals_uniform_ptrs_no_all_active_mask(i16* %dst, i16 %val) {
; CHECK-LABEL: @scatter_v4i16_uniform_vals_uniform_ptrs_no_all_active_mask(
; CHECK-NEXT: entry:
; CHECK-NEXT: store i16 [[VAL:%.*]], i16* [[DST:%.*]], align 2
; CHECK-NEXT: ret void
;
entry:
  %broadcast.splatinsert = insertelement <4 x i16*> poison, i16* %dst, i32 0
  %broadcast.splat = shufflevector <4 x i16*> %broadcast.splatinsert, <4 x i16*> poison, <4 x i32> zeroinitializer
  %broadcast.value = insertelement <4 x i16> poison, i16 %val, i32 0
  %broadcast.splatvalue = shufflevector <4 x i16> %broadcast.value, <4 x i16> poison, <4 x i32> zeroinitializer
  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %broadcast.splatvalue, <4 x i16*> %broadcast.splat, i32 2, <4 x i1> <i1 0, i1 0, i1 1, i1 1>)
  ret void
}

define void @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(i16* %dst, i16 %val) {
; CHECK-LABEL: @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(
; CHECK-NEXT: entry:
; CHECK-NEXT: store i16 [[VAL:%.*]], i16* [[DST:%.*]], align 2
; CHECK-NEXT: ret void
;
entry:
  %broadcast.splatinsert = insertelement <vscale x 4 x i16*> poison, i16* %dst, i32 0
  %broadcast.splat = shufflevector <vscale x 4 x i16*> %broadcast.splatinsert, <vscale x 4 x i16*> poison, <vscale x 4 x i32> zeroinitializer
  %broadcast.value = insertelement <vscale x 4 x i16> poison, i16 %val, i32 0
  %broadcast.splatvalue = shufflevector <vscale x 4 x i16> %broadcast.value, <vscale x 4 x i16> poison, <vscale x 4 x i32> zeroinitializer
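  ; Note: an all-true scalable-vector mask cannot be written as a literal;
  ; it is spelled as a constant splat (insertelement + shufflevector) below.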
  call void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16> %broadcast.splatvalue, <vscale x 4 x i16*> %broadcast.splat, i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> zeroinitializer, i1 true, i32 0), <vscale x 4 x i1> zeroinitializer, <vscale x 4 x i32> zeroinitializer))
  ret void
}

;; The pointer is splat and mask is all active, but value is not a splat
define void @scatter_v4i16_no_uniform_vals_uniform_ptrs_all_active_mask(i16* %dst, <4 x i16>* %src) {
; CHECK-LABEL: @scatter_v4i16_no_uniform_vals_uniform_ptrs_all_active_mask(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i16>, <4 x i16>* [[SRC:%.*]], align 2
; CHECK-NEXT: [[TMP0:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i64 3
; CHECK-NEXT: store i16 [[TMP0]], i16* [[DST:%.*]], align 2
; CHECK-NEXT: ret void
;
entry:
  %broadcast.splatinsert = insertelement <4 x i16*> poison, i16* %dst, i32 0
  %broadcast.splat = shufflevector <4 x i16*> %broadcast.splatinsert, <4 x i16*> poison, <4 x i32> zeroinitializer
  %wide.load = load <4 x i16>, <4 x i16>* %src, align 2
  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %wide.load, <4 x i16*> %broadcast.splat, i32 2, <4 x i1> <i1 1, i1 1, i1 1, i1 1>)
  ret void
}

define void @scatter_nxv4i16_no_uniform_vals_uniform_ptrs_all_active_mask(i16* %dst, <vscale x 4 x i16>* %src) {
; CHECK-LABEL: @scatter_nxv4i16_no_uniform_vals_uniform_ptrs_all_active_mask(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <vscale x 4 x i16>, <vscale x 4 x i16>* [[SRC:%.*]], align 2
; CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.vscale.i32()
; CHECK-NEXT: [[TMP1:%.*]] = shl i32 [[TMP0]], 2
; CHECK-NEXT: [[TMP2:%.*]] = add i32 [[TMP1]], -1
; CHECK-NEXT: [[TMP3:%.*]] = extractelement <vscale x 4 x i16> [[WIDE_LOAD]], i32 [[TMP2]]
; CHECK-NEXT: store i16 [[TMP3]], i16* [[DST:%.*]], align 2
; CHECK-NEXT: ret void
;
entry:
  %broadcast.splatinsert = insertelement <vscale x 4 x i16*> poison, i16* %dst, i32 0
  %broadcast.splat = shufflevector <vscale x 4 x i16*> %broadcast.splatinsert, <vscale x 4 x i16*> poison, <vscale x 4 x i32> zeroinitializer
  %wide.load = load <vscale x 4 x i16>, <vscale x 4 x i16>* %src, align 2
  call void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16> %wide.load, <vscale x 4 x i16*> %broadcast.splat, i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i32 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
  ret void
}

; Negative scatter tests
;; Pointer is splat, but mask is not all active and value is not a splat
define void @negative_scatter_v4i16_no_uniform_vals_uniform_ptrs_all_inactive_mask(i16* %dst, <4 x i16>* %src) {
; CHECK-LABEL: @negative_scatter_v4i16_no_uniform_vals_uniform_ptrs_all_inactive_mask(
; CHECK-NEXT: [[INSERT_ELT:%.*]] = insertelement <4 x i16*> poison, i16* [[DST:%.*]], i64 0
; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <4 x i16*> [[INSERT_ELT]], <4 x i16*> poison, <4 x i32> <i32 undef, i32 undef, i32 0, i32 0>
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i16>, <4 x i16>* [[SRC:%.*]], align 2
; CHECK-NEXT: call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> [[WIDE_LOAD]], <4 x i16*> [[BROADCAST_SPLAT]], i32 2, <4 x i1> <i1 false, i1 false, i1 true, i1 true>)
; CHECK-NEXT: ret void
;
  %insert.elt = insertelement <4 x i16*> poison, i16* %dst, i32 0
  %broadcast.splat = shufflevector <4 x i16*> %insert.elt, <4 x i16*> poison, <4 x i32> zeroinitializer
  %wide.load = load <4 x i16>, <4 x i16>* %src, align 2
  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %wide.load, <4 x i16*> %broadcast.splat, i32 2, <4 x i1> <i1 0, i1 0, i1 1, i1 1>)
  ret void
}

;; The pointer is NOT a splat
define void @negative_scatter_v4i16_no_uniform_vals_no_uniform_ptrs_all_active_mask(<4 x i16*> %inPtr, <4 x i16>* %src) {
; CHECK-LABEL: @negative_scatter_v4i16_no_uniform_vals_no_uniform_ptrs_all_active_mask(
; CHECK-NEXT: [[BROADCAST:%.*]] = shufflevector <4 x i16*> [[INPTR:%.*]], <4 x i16*> poison, <4 x i32> zeroinitializer
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i16>, <4 x i16>* [[SRC:%.*]], align 2
; CHECK-NEXT: call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> [[WIDE_LOAD]], <4 x i16*> [[BROADCAST]], i32 2, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
; CHECK-NEXT: ret void
;
  %broadcast = shufflevector <4 x i16*> %inPtr, <4 x i16*> poison, <4 x i32> zeroinitializer
  %wide.load = load <4 x i16>, <4 x i16>* %src, align 2
  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %wide.load, <4 x i16*> %broadcast, i32 2, <4 x i1> <i1 1, i1 1, i1 1, i1 1>)
  ret void
}

declare void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16>, <4 x i16*>, i32 immarg, <4 x i1>)
declare void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16>, <vscale x 4 x i16*>, i32 immarg, <vscale x 4 x i1>)
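The RUN line sits outside the visible hunks; tests in this file follow the
standard InstCombine pattern, i.e. something like:

; RUN: opt -instcombine -S < %s | FileCheck %s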