[RISCV] Handle vector of pointer in getTgtMemIntrinsic for strided load/store.

getScalarSizeInBits() doesn't work if the scalar type is a pointer.
For that we need to go through DataLayout.
Commit c4803bd416 by Craig Topper, 2021-10-06 17:14:08 -07:00
(parent d456fed1a9). 2 changed files with 160 additions and 5 deletions.

View File

@ -953,6 +953,7 @@ bool RISCVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
const CallInst &I,
MachineFunction &MF,
unsigned Intrinsic) const {
auto &DL = I.getModule()->getDataLayout();
switch (Intrinsic) {
default:
return false;
@ -978,17 +979,19 @@ bool RISCVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
case Intrinsic::riscv_masked_strided_load:
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.ptrVal = I.getArgOperand(1);
Info.memVT = MVT::getVT(I.getType()->getScalarType());
Info.align = Align(I.getType()->getScalarSizeInBits() / 8);
Info.memVT = getValueType(DL, I.getType()->getScalarType());
Info.align = Align(DL.getTypeSizeInBits(I.getType()->getScalarType()) / 8);
Info.size = MemoryLocation::UnknownSize;
Info.flags |= MachineMemOperand::MOLoad;
return true;
case Intrinsic::riscv_masked_strided_store:
Info.opc = ISD::INTRINSIC_VOID;
Info.ptrVal = I.getArgOperand(1);
Info.memVT = MVT::getVT(I.getArgOperand(0)->getType()->getScalarType());
Info.align =
Align(I.getArgOperand(0)->getType()->getScalarSizeInBits() / 8);
Info.memVT =
getValueType(DL, I.getArgOperand(0)->getType()->getScalarType());
Info.align = Align(
DL.getTypeSizeInBits(I.getArgOperand(0)->getType()->getScalarType()) /
8);
Info.size = MemoryLocation::UnknownSize;
Info.flags |= MachineMemOperand::MOStore;
return true;

View File

@ -826,3 +826,155 @@ declare <32 x i8> @llvm.masked.gather.v32i8.v32p0i8(<32 x i8*>, i32 immarg, <32
declare <8 x i32> @llvm.masked.gather.v8i32.v8p0i32(<8 x i32*>, i32 immarg, <8 x i1>, <8 x i32>)
declare void @llvm.masked.scatter.v32i8.v32p0i8(<32 x i8>, <32 x i8*>, i32 immarg, <32 x i1>)
declare void @llvm.masked.scatter.v8i32.v8p0i32(<8 x i32>, <8 x i32*>, i32 immarg, <8 x i1>)
; Make sure we don't crash in getTgtMemIntrinsic for a vector of pointers.
define void @gather_of_pointers(i32** noalias nocapture %0, i32** noalias nocapture readonly %1) {
; CHECK-LABEL: @gather_of_pointers(
; CHECK-NEXT: br label [[TMP3:%.*]]
; CHECK: 3:
; CHECK-NEXT: [[TMP4:%.*]] = phi i64 [ 0, [[TMP2:%.*]] ], [ [[TMP15:%.*]], [[TMP3]] ]
; CHECK-NEXT: [[DOTSCALAR:%.*]] = phi i64 [ 0, [[TMP2]] ], [ [[DOTSCALAR1:%.*]], [[TMP3]] ]
; CHECK-NEXT: [[DOTSCALAR2:%.*]] = phi i64 [ 10, [[TMP2]] ], [ [[DOTSCALAR3:%.*]], [[TMP3]] ]
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr i32*, i32** [[TMP1:%.*]], i64 [[DOTSCALAR]]
; CHECK-NEXT: [[TMP6:%.*]] = bitcast i32** [[TMP5]] to i8*
; CHECK-NEXT: [[TMP7:%.*]] = getelementptr i32*, i32** [[TMP1]], i64 [[DOTSCALAR2]]
; CHECK-NEXT: [[TMP8:%.*]] = bitcast i32** [[TMP7]] to i8*
; CHECK-NEXT: [[TMP9:%.*]] = call <2 x i32*> @llvm.riscv.masked.strided.load.v2p0i32.p0i8.i64(<2 x i32*> undef, i8* [[TMP6]], i64 40, <2 x i1> <i1 true, i1 true>)
; CHECK-NEXT: [[TMP10:%.*]] = call <2 x i32*> @llvm.riscv.masked.strided.load.v2p0i32.p0i8.i64(<2 x i32*> undef, i8* [[TMP8]], i64 40, <2 x i1> <i1 true, i1 true>)
; CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds i32*, i32** [[TMP0:%.*]], i64 [[TMP4]]
; CHECK-NEXT: [[TMP12:%.*]] = bitcast i32** [[TMP11]] to <2 x i32*>*
; CHECK-NEXT: store <2 x i32*> [[TMP9]], <2 x i32*>* [[TMP12]], align 8
; CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds i32*, i32** [[TMP11]], i64 2
; CHECK-NEXT: [[TMP14:%.*]] = bitcast i32** [[TMP13]] to <2 x i32*>*
; CHECK-NEXT: store <2 x i32*> [[TMP10]], <2 x i32*>* [[TMP14]], align 8
; CHECK-NEXT: [[TMP15]] = add nuw i64 [[TMP4]], 4
; CHECK-NEXT: [[DOTSCALAR1]] = add i64 [[DOTSCALAR]], 20
; CHECK-NEXT: [[DOTSCALAR3]] = add i64 [[DOTSCALAR2]], 20
; CHECK-NEXT: [[TMP16:%.*]] = icmp eq i64 [[TMP15]], 1024
; CHECK-NEXT: br i1 [[TMP16]], label [[TMP17:%.*]], label [[TMP3]]
; CHECK: 17:
; CHECK-NEXT: ret void
;
; CHECK-ASM-LABEL: gather_of_pointers:
; CHECK-ASM: # %bb.0:
; CHECK-ASM-NEXT: addi a0, a0, 16
; CHECK-ASM-NEXT: addi a2, zero, 1024
; CHECK-ASM-NEXT: addi a3, zero, 40
; CHECK-ASM-NEXT: .LBB10_1: # =>This Inner Loop Header: Depth=1
; CHECK-ASM-NEXT: addi a4, a1, 80
; CHECK-ASM-NEXT: vsetivli zero, 2, e64, m1, ta, mu
; CHECK-ASM-NEXT: vlse64.v v25, (a1), a3
; CHECK-ASM-NEXT: vlse64.v v26, (a4), a3
; CHECK-ASM-NEXT: addi a4, a0, -16
; CHECK-ASM-NEXT: vse64.v v25, (a4)
; CHECK-ASM-NEXT: vse64.v v26, (a0)
; CHECK-ASM-NEXT: addi a2, a2, -4
; CHECK-ASM-NEXT: addi a0, a0, 32
; CHECK-ASM-NEXT: addi a1, a1, 160
; CHECK-ASM-NEXT: bnez a2, .LBB10_1
; CHECK-ASM-NEXT: # %bb.2:
; CHECK-ASM-NEXT: ret
; Loop over 1024 elements, 4 per iteration. Each masked.gather below loads
; <2 x i32*> (pointer elements!) from %1 at indices 5*i and 5*i+10 — a
; constant stride of 5 pointers = 40 bytes, which the RISCV gather/scatter
; lowering turns into @llvm.riscv.masked.strided.load (vlse64 in asm).
; The element type being a pointer is the regression being tested: sizing
; it via getScalarSizeInBits() crashed; DataLayout must be used instead.
  br label %3
3: ; preds = %3, %2
  %4 = phi i64 [ 0, %2 ], [ %17, %3 ]
  %5 = phi <2 x i64> [ <i64 0, i64 1>, %2 ], [ %18, %3 ]
  %6 = mul nuw nsw <2 x i64> %5, <i64 5, i64 5>
  %7 = mul <2 x i64> %5, <i64 5, i64 5>
  %8 = add <2 x i64> %7, <i64 10, i64 10>
; Vector-of-pointers GEPs feeding the gathers; both have stride 5 * 8 bytes.
  %9 = getelementptr inbounds i32*, i32** %1, <2 x i64> %6
  %10 = getelementptr inbounds i32*, i32** %1, <2 x i64> %8
  %11 = call <2 x i32*> @llvm.masked.gather.v2p0i32.v2p0p0i32(<2 x i32**> %9, i32 8, <2 x i1> <i1 true, i1 true>, <2 x i32*> undef)
  %12 = call <2 x i32*> @llvm.masked.gather.v2p0i32.v2p0p0i32(<2 x i32**> %10, i32 8, <2 x i1> <i1 true, i1 true>, <2 x i32*> undef)
; Gathered pointers are stored contiguously into %0 as two <2 x i32*> chunks.
  %13 = getelementptr inbounds i32*, i32** %0, i64 %4
  %14 = bitcast i32** %13 to <2 x i32*>*
  store <2 x i32*> %11, <2 x i32*>* %14, align 8
  %15 = getelementptr inbounds i32*, i32** %13, i64 2
  %16 = bitcast i32** %15 to <2 x i32*>*
  store <2 x i32*> %12, <2 x i32*>* %16, align 8
  %17 = add nuw i64 %4, 4
  %18 = add <2 x i64> %5, <i64 4, i64 4>
  %19 = icmp eq i64 %17, 1024
  br i1 %19, label %20, label %3
20: ; preds = %3
  ret void
}
declare <2 x i32*> @llvm.masked.gather.v2p0i32.v2p0p0i32(<2 x i32**>, i32 immarg, <2 x i1>, <2 x i32*>)
; Make sure we don't crash in getTgtMemIntrinsic for a vector of pointers.
define void @scatter_of_pointers(i32** noalias nocapture %0, i32** noalias nocapture readonly %1) {
; CHECK-LABEL: @scatter_of_pointers(
; CHECK-NEXT: br label [[TMP3:%.*]]
; CHECK: 3:
; CHECK-NEXT: [[TMP4:%.*]] = phi i64 [ 0, [[TMP2:%.*]] ], [ [[TMP15:%.*]], [[TMP3]] ]
; CHECK-NEXT: [[DOTSCALAR:%.*]] = phi i64 [ 0, [[TMP2]] ], [ [[DOTSCALAR1:%.*]], [[TMP3]] ]
; CHECK-NEXT: [[DOTSCALAR2:%.*]] = phi i64 [ 10, [[TMP2]] ], [ [[DOTSCALAR3:%.*]], [[TMP3]] ]
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds i32*, i32** [[TMP1:%.*]], i64 [[TMP4]]
; CHECK-NEXT: [[TMP6:%.*]] = bitcast i32** [[TMP5]] to <2 x i32*>*
; CHECK-NEXT: [[TMP7:%.*]] = load <2 x i32*>, <2 x i32*>* [[TMP6]], align 8
; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i32*, i32** [[TMP5]], i64 2
; CHECK-NEXT: [[TMP9:%.*]] = bitcast i32** [[TMP8]] to <2 x i32*>*
; CHECK-NEXT: [[TMP10:%.*]] = load <2 x i32*>, <2 x i32*>* [[TMP9]], align 8
; CHECK-NEXT: [[TMP11:%.*]] = getelementptr i32*, i32** [[TMP0:%.*]], i64 [[DOTSCALAR]]
; CHECK-NEXT: [[TMP12:%.*]] = bitcast i32** [[TMP11]] to i8*
; CHECK-NEXT: [[TMP13:%.*]] = getelementptr i32*, i32** [[TMP0]], i64 [[DOTSCALAR2]]
; CHECK-NEXT: [[TMP14:%.*]] = bitcast i32** [[TMP13]] to i8*
; CHECK-NEXT: call void @llvm.riscv.masked.strided.store.v2p0i32.p0i8.i64(<2 x i32*> [[TMP7]], i8* [[TMP12]], i64 40, <2 x i1> <i1 true, i1 true>)
; CHECK-NEXT: call void @llvm.riscv.masked.strided.store.v2p0i32.p0i8.i64(<2 x i32*> [[TMP10]], i8* [[TMP14]], i64 40, <2 x i1> <i1 true, i1 true>)
; CHECK-NEXT: [[TMP15]] = add nuw i64 [[TMP4]], 4
; CHECK-NEXT: [[DOTSCALAR1]] = add i64 [[DOTSCALAR]], 20
; CHECK-NEXT: [[DOTSCALAR3]] = add i64 [[DOTSCALAR2]], 20
; CHECK-NEXT: [[TMP16:%.*]] = icmp eq i64 [[TMP15]], 1024
; CHECK-NEXT: br i1 [[TMP16]], label [[TMP17:%.*]], label [[TMP3]]
; CHECK: 17:
; CHECK-NEXT: ret void
;
; CHECK-ASM-LABEL: scatter_of_pointers:
; CHECK-ASM: # %bb.0:
; CHECK-ASM-NEXT: addi a1, a1, 16
; CHECK-ASM-NEXT: addi a2, zero, 1024
; CHECK-ASM-NEXT: addi a3, zero, 40
; CHECK-ASM-NEXT: .LBB11_1: # =>This Inner Loop Header: Depth=1
; CHECK-ASM-NEXT: addi a4, a1, -16
; CHECK-ASM-NEXT: vsetivli zero, 2, e64, m1, ta, mu
; CHECK-ASM-NEXT: vle64.v v25, (a4)
; CHECK-ASM-NEXT: vle64.v v26, (a1)
; CHECK-ASM-NEXT: addi a4, a0, 80
; CHECK-ASM-NEXT: vsse64.v v25, (a0), a3
; CHECK-ASM-NEXT: vsse64.v v26, (a4), a3
; CHECK-ASM-NEXT: addi a2, a2, -4
; CHECK-ASM-NEXT: addi a1, a1, 32
; CHECK-ASM-NEXT: addi a0, a0, 160
; CHECK-ASM-NEXT: bnez a2, .LBB11_1
; CHECK-ASM-NEXT: # %bb.2:
; CHECK-ASM-NEXT: ret
; Mirror of gather_of_pointers: contiguous <2 x i32*> loads from %1, then
; masked.scatter of the pointer elements into %0 at indices 5*i and 5*i+10.
; The 40-byte constant stride (5 pointers * 8 bytes) lets the lowering emit
; @llvm.riscv.masked.strided.store (vsse64 in asm). Again the stored scalar
; type is a pointer, exercising the DataLayout path in getTgtMemIntrinsic.
  br label %3
3: ; preds = %3, %2
  %4 = phi i64 [ 0, %2 ], [ %17, %3 ]
  %5 = phi <2 x i64> [ <i64 0, i64 1>, %2 ], [ %18, %3 ]
; Two contiguous <2 x i32*> loads provide the values to scatter.
  %6 = getelementptr inbounds i32*, i32** %1, i64 %4
  %7 = bitcast i32** %6 to <2 x i32*>*
  %8 = load <2 x i32*>, <2 x i32*>* %7, align 8
  %9 = getelementptr inbounds i32*, i32** %6, i64 2
  %10 = bitcast i32** %9 to <2 x i32*>*
  %11 = load <2 x i32*>, <2 x i32*>* %10, align 8
  %12 = mul nuw nsw <2 x i64> %5, <i64 5, i64 5>
  %13 = mul <2 x i64> %5, <i64 5, i64 5>
  %14 = add <2 x i64> %13, <i64 10, i64 10>
; Strided destination addresses for the scatters (stride 5 * 8 bytes).
  %15 = getelementptr inbounds i32*, i32** %0, <2 x i64> %12
  %16 = getelementptr inbounds i32*, i32** %0, <2 x i64> %14
  call void @llvm.masked.scatter.v2p0i32.v2p0p0i32(<2 x i32*> %8, <2 x i32**> %15, i32 8, <2 x i1> <i1 true, i1 true>)
  call void @llvm.masked.scatter.v2p0i32.v2p0p0i32(<2 x i32*> %11, <2 x i32**> %16, i32 8, <2 x i1> <i1 true, i1 true>)
  %17 = add nuw i64 %4, 4
  %18 = add <2 x i64> %5, <i64 4, i64 4>
  %19 = icmp eq i64 %17, 1024
  br i1 %19, label %20, label %3
20: ; preds = %3
  ret void
}
declare void @llvm.masked.scatter.v2p0i32.v2p0p0i32(<2 x i32*>, <2 x i32**>, i32 immarg, <2 x i1>)