[mlir] [VectorOps] fixed bug in vector.insert_strided_slice lowering

Summary:
Rationale:
When lowering to LLVM for different rank insert (n vs k), the offset
arrays needs to drop one dimension (becomes n-1), but the strides
array needs to be preserved (remains k). With regression test.
Note that this example was actually in the documentation, so
extra important to do it right :-)

Reviewers: nicolasvasilache, andydavis1, ftynse

Reviewed By: nicolasvasilache, ftynse

Subscribers: Joonsoo, merge_guards_bot, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D73733
This commit is contained in:
aartbik 2020-01-31 10:56:22 -08:00
parent 73713f3e5e
commit c8fc76a99b
2 changed files with 110 additions and 48 deletions

View File

@ -548,7 +548,7 @@ public:
auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
loc, op.source(), extracted,
getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
getI64SubArray(op.strides(), /*dropFront=*/rankDiff));
getI64SubArray(op.strides(), /*dropFront=*/0));
rewriter.replaceOpWithNewOp<InsertOp>(
op, stridedSliceInnerOp.getResult(), op.dest(),
getI64SubArray(op.offsets(), /*dropFront=*/0,

View File

@ -4,7 +4,7 @@ func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
%0 = vector.broadcast %arg0 : f32 to vector<2xf32>
return %0 : vector<2xf32>
}
// CHECK-LABEL: broadcast_vec1d_from_scalar
// CHECK-LABEL: llvm.func @broadcast_vec1d_from_scalar
// CHECK: llvm.mlir.undef : !llvm<"<2 x float>">
// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
@ -15,7 +15,7 @@ func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
%0 = vector.broadcast %arg0 : f32 to vector<2x3xf32>
return %0 : vector<2x3xf32>
}
// CHECK-LABEL: broadcast_vec2d_from_scalar
// CHECK-LABEL: llvm.func @broadcast_vec2d_from_scalar
// CHECK: llvm.mlir.undef : !llvm<"<3 x float>">
// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>">
@ -29,7 +29,7 @@ func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
%0 = vector.broadcast %arg0 : f32 to vector<2x3x4xf32>
return %0 : vector<2x3x4xf32>
}
// CHECK-LABEL: broadcast_vec3d_from_scalar
// CHECK-LABEL: llvm.func @broadcast_vec3d_from_scalar
// CHECK: llvm.mlir.undef : !llvm<"<4 x float>">
// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>">
@ -47,14 +47,14 @@ func @broadcast_vec1d_from_vec1d(%arg0: vector<2xf32>) -> vector<2xf32> {
%0 = vector.broadcast %arg0 : vector<2xf32> to vector<2xf32>
return %0 : vector<2xf32>
}
// CHECK-LABEL: broadcast_vec1d_from_vec1d
// CHECK-LABEL: llvm.func @broadcast_vec1d_from_vec1d
// CHECK: llvm.return {{.*}} : !llvm<"<2 x float>">
func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> {
%0 = vector.broadcast %arg0 : vector<2xf32> to vector<3x2xf32>
return %0 : vector<3x2xf32>
}
// CHECK-LABEL: broadcast_vec2d_from_vec1d
// CHECK-LABEL: llvm.func @broadcast_vec2d_from_vec1d
// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]">
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]">
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]">
@ -65,7 +65,7 @@ func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32> {
%0 = vector.broadcast %arg0 : vector<2xf32> to vector<4x3x2xf32>
return %0 : vector<4x3x2xf32>
}
// CHECK-LABEL: broadcast_vec3d_from_vec1d
// CHECK-LABEL: llvm.func @broadcast_vec3d_from_vec1d
// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]">
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]">
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]">
@ -81,7 +81,7 @@ func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf32> {
%0 = vector.broadcast %arg0 : vector<3x2xf32> to vector<4x3x2xf32>
return %0 : vector<4x3x2xf32>
}
// CHECK-LABEL: broadcast_vec3d_from_vec2d
// CHECK-LABEL: llvm.func @broadcast_vec3d_from_vec2d
// CHECK: llvm.mlir.undef : !llvm<"[4 x [3 x <2 x float>]]">
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x [3 x <2 x float>]]">
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[4 x [3 x <2 x float>]]">
@ -93,7 +93,7 @@ func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> {
%0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32>
return %0 : vector<4xf32>
}
// CHECK-LABEL: broadcast_stretch
// CHECK-LABEL: llvm.func @broadcast_stretch
// CHECK: llvm.mlir.undef : !llvm<"<4 x float>">
// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>">
@ -106,7 +106,7 @@ func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> {
%0 = vector.broadcast %arg0 : vector<1x4xf32> to vector<3x4xf32>
return %0 : vector<3x4xf32>
}
// CHECK-LABEL: broadcast_stretch_at_start
// CHECK-LABEL: llvm.func @broadcast_stretch_at_start
// CHECK: llvm.mlir.undef : !llvm<"[3 x <4 x float>]">
// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <4 x float>]">
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <4 x float>]">
@ -120,7 +120,7 @@ func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> {
%0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<4x3xf32>
return %0 : vector<4x3xf32>
}
// CHECK-LABEL: broadcast_stretch_at_end
// CHECK-LABEL: llvm.func @broadcast_stretch_at_end
// CHECK: llvm.mlir.undef : !llvm<"[4 x <3 x float>]">
// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[4 x <1 x float>]">
// CHECK: llvm.mlir.undef : !llvm<"<3 x float>">
@ -160,7 +160,7 @@ func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2xf32>
%0 = vector.broadcast %arg0 : vector<4x1x2xf32> to vector<4x3x2xf32>
return %0 : vector<4x3x2xf32>
}
// CHECK-LABEL: broadcast_stretch_in_middle
// CHECK-LABEL: llvm.func @broadcast_stretch_in_middle
// CHECK: llvm.mlir.undef : !llvm<"[4 x [3 x <2 x float>]]">
// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[4 x [1 x <2 x float>]]">
// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]">
@ -204,7 +204,7 @@ func @outerproduct(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<2x3xf32
%2 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32>
return %2 : vector<2x3xf32>
}
// CHECK-LABEL: outerproduct
// CHECK-LABEL: llvm.func @outerproduct
// CHECK: llvm.mlir.undef : !llvm<"[2 x <3 x float>]">
// CHECK: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<3 x float>">
@ -218,7 +218,7 @@ func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector
%2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xf32>, vector<3xf32>
return %2 : vector<2x3xf32>
}
// CHECK-LABEL: outerproduct_add
// CHECK-LABEL: llvm.func @outerproduct_add
// CHECK: llvm.mlir.undef : !llvm<"[2 x <3 x float>]">
// CHECK: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]">
@ -234,34 +234,38 @@ func @shuffle_1D_direct(%arg0: vector<2xf32>, %arg1: vector<2xf32>) -> vector<2x
%1 = vector.shuffle %arg0, %arg1 [0, 1] : vector<2xf32>, vector<2xf32>
return %1 : vector<2xf32>
}
// CHECK-LABEL: shuffle_1D_direct(%arg0: !llvm<"<2 x float>">, %arg1: !llvm<"<2 x float>">)
// CHECK: %[[s:.*]] = llvm.shufflevector %arg0, %arg1 [0, 1] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
// CHECK-LABEL: llvm.func @shuffle_1D_direct
// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm<"<2 x float>">
// CHECK-SAME: %[[B:arg[0-9]+]]: !llvm<"<2 x float>">
// CHECK: %[[s:.*]] = llvm.shufflevector %[[A]], %[[B]] [0, 1] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
// CHECK: llvm.return %[[s]] : !llvm<"<2 x float>">
func @shuffle_1D(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<5xf32> {
%1 = vector.shuffle %arg0, %arg1 [4, 3, 2, 1, 0] : vector<2xf32>, vector<3xf32>
return %1 : vector<5xf32>
}
// CHECK-LABEL: shuffle_1D(%arg0: !llvm<"<2 x float>">, %arg1: !llvm<"<3 x float>">)
// CHECK-LABEL: llvm.func @shuffle_1D
// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm<"<2 x float>">
// CHECK-SAME: %[[B:arg[0-9]+]]: !llvm<"<3 x float>">
// CHECK: %[[u0:.*]] = llvm.mlir.undef : !llvm<"<5 x float>">
// CHECK: %[[c2:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
// CHECK: %[[e1:.*]] = llvm.extractelement %arg1[%[[c2]] : !llvm.i64] : !llvm<"<3 x float>">
// CHECK: %[[e1:.*]] = llvm.extractelement %[[B]][%[[c2]] : !llvm.i64] : !llvm<"<3 x float>">
// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: %[[i1:.*]] = llvm.insertelement %[[e1]], %[[u0]][%[[c0]] : !llvm.i64] : !llvm<"<5 x float>">
// CHECK: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: %[[e2:.*]] = llvm.extractelement %arg1[%[[c1]] : !llvm.i64] : !llvm<"<3 x float>">
// CHECK: %[[e2:.*]] = llvm.extractelement %[[B]][%[[c1]] : !llvm.i64] : !llvm<"<3 x float>">
// CHECK: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: %[[i2:.*]] = llvm.insertelement %[[e2]], %[[i1]][%[[c1]] : !llvm.i64] : !llvm<"<5 x float>">
// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: %[[e3:.*]] = llvm.extractelement %arg1[%[[c0]] : !llvm.i64] : !llvm<"<3 x float>">
// CHECK: %[[e3:.*]] = llvm.extractelement %[[B]][%[[c0]] : !llvm.i64] : !llvm<"<3 x float>">
// CHECK: %[[c2:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
// CHECK: %[[i3:.*]] = llvm.insertelement %[[e3]], %[[i2]][%[[c2]] : !llvm.i64] : !llvm<"<5 x float>">
// CHECK: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: %[[e4:.*]] = llvm.extractelement %arg0[%[[c1]] : !llvm.i64] : !llvm<"<2 x float>">
// CHECK: %[[e4:.*]] = llvm.extractelement %[[A]][%[[c1]] : !llvm.i64] : !llvm<"<2 x float>">
// CHECK: %[[c3:.*]] = llvm.mlir.constant(3 : index) : !llvm.i64
// CHECK: %[[i4:.*]] = llvm.insertelement %[[e4]], %[[i3]][%[[c3]] : !llvm.i64] : !llvm<"<5 x float>">
// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: %[[e5:.*]] = llvm.extractelement %arg0[%[[c0]] : !llvm.i64] : !llvm<"<2 x float>">
// CHECK: %[[e5:.*]] = llvm.extractelement %[[A]][%[[c0]] : !llvm.i64] : !llvm<"<2 x float>">
// CHECK: %[[c4:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64
// CHECK: %[[i5:.*]] = llvm.insertelement %[[e5]], %[[i4]][%[[c4]] : !llvm.i64] : !llvm<"<5 x float>">
// CHECK: llvm.return %[[i5]] : !llvm<"<5 x float>">
@ -270,13 +274,15 @@ func @shuffle_2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32> {
%1 = vector.shuffle %a, %b[1, 0, 2] : vector<1x4xf32>, vector<2x4xf32>
return %1 : vector<3x4xf32>
}
// CHECK-LABEL: shuffle_2D(%arg0: !llvm<"[1 x <4 x float>]">, %arg1: !llvm<"[2 x <4 x float>]">)
// CHECK-LABEL: llvm.func @shuffle_2D
// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm<"[1 x <4 x float>]">
// CHECK-SAME: %[[B:arg[0-9]+]]: !llvm<"[2 x <4 x float>]">
// CHECK: %[[u0:.*]] = llvm.mlir.undef : !llvm<"[3 x <4 x float>]">
// CHECK: %[[e1:.*]] = llvm.extractvalue %arg1[0] : !llvm<"[2 x <4 x float>]">
// CHECK: %[[e1:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
// CHECK: %[[i1:.*]] = llvm.insertvalue %[[e1]], %[[u0]][0] : !llvm<"[3 x <4 x float>]">
// CHECK: %[[e2:.*]] = llvm.extractvalue %arg0[0] : !llvm<"[1 x <4 x float>]">
// CHECK: %[[e2:.*]] = llvm.extractvalue %[[A]][0] : !llvm<"[1 x <4 x float>]">
// CHECK: %[[i2:.*]] = llvm.insertvalue %[[e2]], %[[i1]][1] : !llvm<"[3 x <4 x float>]">
// CHECK: %[[e3:.*]] = llvm.extractvalue %arg1[1] : !llvm<"[2 x <4 x float>]">
// CHECK: %[[e3:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"[2 x <4 x float>]">
// CHECK: %[[i3:.*]] = llvm.insertvalue %[[e3]], %[[i2]][2] : !llvm<"[3 x <4 x float>]">
// CHECK: llvm.return %[[i3]] : !llvm<"[3 x <4 x float>]">
@ -285,16 +291,17 @@ func @extract_element(%arg0: vector<16xf32>) -> f32 {
%1 = vector.extractelement %arg0[%0 : i32]: vector<16xf32>
return %1 : f32
}
// CHECK-LABEL: extract_element(%arg0: !llvm<"<16 x float>">)
// CHECK-LABEL: llvm.func @extract_element
// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm<"<16 x float>">
// CHECK: %[[c:.*]] = llvm.mlir.constant(15 : i32) : !llvm.i32
// CHECK: %[[x:.*]] = llvm.extractelement %arg0[%[[c]] : !llvm.i32] : !llvm<"<16 x float>">
// CHECK: %[[x:.*]] = llvm.extractelement %[[A]][%[[c]] : !llvm.i32] : !llvm<"<16 x float>">
// CHECK: llvm.return %[[x]] : !llvm.float
func @extract_element_from_vec_1d(%arg0: vector<16xf32>) -> f32 {
%0 = vector.extract %arg0[15]: vector<16xf32>
return %0 : f32
}
// CHECK-LABEL: extract_element_from_vec_1d
// CHECK-LABEL: llvm.func @extract_element_from_vec_1d
// CHECK: llvm.mlir.constant(15 : i64) : !llvm.i64
// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<16 x float>">
// CHECK: llvm.return {{.*}} : !llvm.float
@ -303,7 +310,7 @@ func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32>
%0 = vector.extract %arg0[0]: vector<4x3x16xf32>
return %0 : vector<3x16xf32>
}
// CHECK-LABEL: extract_vec_2d_from_vec_3d
// CHECK-LABEL: llvm.func @extract_vec_2d_from_vec_3d
// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[4 x [3 x <16 x float>]]">
// CHECK: llvm.return {{.*}} : !llvm<"[3 x <16 x float>]">
@ -311,7 +318,7 @@ func @extract_vec_1d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<16xf32> {
%0 = vector.extract %arg0[0, 0]: vector<4x3x16xf32>
return %0 : vector<16xf32>
}
// CHECK-LABEL: extract_vec_1d_from_vec_3d
// CHECK-LABEL: llvm.func @extract_vec_1d_from_vec_3d
// CHECK: llvm.extractvalue {{.*}}[0, 0] : !llvm<"[4 x [3 x <16 x float>]]">
// CHECK: llvm.return {{.*}} : !llvm<"<16 x float>">
@ -319,7 +326,7 @@ func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
%0 = vector.extract %arg0[0, 0, 0]: vector<4x3x16xf32>
return %0 : f32
}
// CHECK-LABEL: extract_element_from_vec_3d
// CHECK-LABEL: llvm.func @extract_element_from_vec_3d
// CHECK: llvm.extractvalue {{.*}}[0, 0] : !llvm<"[4 x [3 x <16 x float>]]">
// CHECK: llvm.mlir.constant(0 : i64) : !llvm.i64
// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<16 x float>">
@ -330,16 +337,18 @@ func @insert_element(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
%1 = vector.insertelement %arg0, %arg1[%0 : i32] : vector<4xf32>
return %1 : vector<4xf32>
}
// CHECK-LABEL: insert_element(%arg0: !llvm.float, %arg1: !llvm<"<4 x float>">)
// CHECK-LABEL: llvm.func @insert_element
// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm.float
// CHECK-SAME: %[[B:arg[0-9]+]]: !llvm<"<4 x float>">
// CHECK: %[[c:.*]] = llvm.mlir.constant(3 : i32) : !llvm.i32
// CHECK: %[[x:.*]] = llvm.insertelement %arg0, %arg1[%[[c]] : !llvm.i32] : !llvm<"<4 x float>">
// CHECK: %[[x:.*]] = llvm.insertelement %[[A]], %[[B]][%[[c]] : !llvm.i32] : !llvm<"<4 x float>">
// CHECK: llvm.return %[[x]] : !llvm<"<4 x float>">
func @insert_element_into_vec_1d(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
%0 = vector.insert %arg0, %arg1[3] : f32 into vector<4xf32>
return %0 : vector<4xf32>
}
// CHECK-LABEL: insert_element_into_vec_1d
// CHECK-LABEL: llvm.func @insert_element_into_vec_1d
// CHECK: llvm.mlir.constant(3 : i64) : !llvm.i64
// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>">
// CHECK: llvm.return {{.*}} : !llvm<"<4 x float>">
@ -348,7 +357,7 @@ func @insert_vec_2d_into_vec_3d(%arg0: vector<8x16xf32>, %arg1: vector<4x8x16xf3
%0 = vector.insert %arg0, %arg1[3] : vector<8x16xf32> into vector<4x8x16xf32>
return %0 : vector<4x8x16xf32>
}
// CHECK-LABEL: insert_vec_2d_into_vec_3d
// CHECK-LABEL: llvm.func @insert_vec_2d_into_vec_3d
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x [8 x <16 x float>]]">
// CHECK: llvm.return {{.*}} : !llvm<"[4 x [8 x <16 x float>]]">
@ -356,7 +365,7 @@ func @insert_vec_1d_into_vec_3d(%arg0: vector<16xf32>, %arg1: vector<4x8x16xf32>
%0 = vector.insert %arg0, %arg1[3, 7] : vector<16xf32> into vector<4x8x16xf32>
return %0 : vector<4x8x16xf32>
}
// CHECK-LABEL: insert_vec_1d_into_vec_3d
// CHECK-LABEL: llvm.func @insert_vec_1d_into_vec_3d
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3, 7] : !llvm<"[4 x [8 x <16 x float>]]">
// CHECK: llvm.return {{.*}} : !llvm<"[4 x [8 x <16 x float>]]">
@ -364,7 +373,7 @@ func @insert_element_into_vec_3d(%arg0: f32, %arg1: vector<4x8x16xf32>) -> vecto
%0 = vector.insert %arg0, %arg1[3, 7, 15] : f32 into vector<4x8x16xf32>
return %0 : vector<4x8x16xf32>
}
// CHECK-LABEL: insert_element_into_vec_3d
// CHECK-LABEL: llvm.func @insert_element_into_vec_3d
// CHECK: llvm.extractvalue {{.*}}[3, 7] : !llvm<"[4 x [8 x <16 x float>]]">
// CHECK: llvm.mlir.constant(15 : i64) : !llvm.i64
// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<16 x float>">
@ -375,7 +384,7 @@ func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32>> {
%0 = vector.type_cast %arg0: memref<8x8x8xf32> to memref<vector<8x8x8xf32>>
return %0 : memref<vector<8x8x8xf32>>
}
// CHECK-LABEL: vector_type_cast
// CHECK-LABEL: llvm.func @vector_type_cast
// CHECK: llvm.mlir.undef : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
// CHECK: %[[allocated:.*]] = llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: %[[allocatedBit:.*]] = llvm.bitcast %[[allocated]] : !llvm<"float*"> to !llvm<"[8 x [8 x <8 x float>]]*">
@ -390,17 +399,19 @@ func @vector_print_scalar(%arg0: f32) {
vector.print %arg0 : f32
return
}
// CHECK-LABEL: vector_print_scalar(%arg0: !llvm.float)
// CHECK: llvm.call @print_f32(%arg0) : (!llvm.float) -> ()
// CHECK-LABEL: llvm.func @vector_print_scalar
// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm.float
// CHECK: llvm.call @print_f32(%[[A]]) : (!llvm.float) -> ()
// CHECK: llvm.call @print_newline() : () -> ()
func @vector_print_vector(%arg0: vector<2x2xf32>) {
vector.print %arg0 : vector<2x2xf32>
return
}
// CHECK-LABEL: vector_print_vector(%arg0: !llvm<"[2 x <2 x float>]">)
// CHECK-LABEL: llvm.func @vector_print_vector
// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm<"[2 x <2 x float>]">
// CHECK: llvm.call @print_open() : () -> ()
// CHECK: %[[x0:.*]] = llvm.extractvalue %arg0[0] : !llvm<"[2 x <2 x float>]">
// CHECK: %[[x0:.*]] = llvm.extractvalue %[[A]][0] : !llvm<"[2 x <2 x float>]">
// CHECK: llvm.call @print_open() : () -> ()
// CHECK: %[[x1:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: %[[x2:.*]] = llvm.extractelement %[[x0]][%[[x1]] : !llvm.i64] : !llvm<"<2 x float>">
@ -411,7 +422,7 @@ func @vector_print_vector(%arg0: vector<2x2xf32>) {
// CHECK: llvm.call @print_f32(%[[x4]]) : (!llvm.float) -> ()
// CHECK: llvm.call @print_close() : () -> ()
// CHECK: llvm.call @print_comma() : () -> ()
// CHECK: %[[x5:.*]] = llvm.extractvalue %arg0[1] : !llvm<"[2 x <2 x float>]">
// CHECK: %[[x5:.*]] = llvm.extractvalue %[[A]][1] : !llvm<"[2 x <2 x float>]">
// CHECK: llvm.call @print_open() : () -> ()
// CHECK: %[[x6:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: %[[x7:.*]] = llvm.extractelement %[[x5]][%[[x6]] : !llvm.i64] : !llvm<"<2 x float>">
@ -492,7 +503,7 @@ func @insert_strided_slice1(%b: vector<4x4xf32>, %c: vector<4x4x4xf32>) -> vecto
%0 = vector.insert_strided_slice %b, %c {offsets = [2, 0, 0], strides = [1, 1]} : vector<4x4xf32> into vector<4x4x4xf32>
return %0 : vector<4x4x4xf32>
}
// CHECK-LABEL: @insert_strided_slice1
// CHECK-LABEL: llvm.func @insert_strided_slice1
// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x [4 x <4 x float>]]">
// CHECK-NEXT: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x [4 x <4 x float>]]">
@ -500,7 +511,7 @@ func @insert_strided_slice2(%a: vector<2x2xf32>, %b: vector<4x4xf32>) -> vector<
%0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
return %0 : vector<4x4xf32>
}
// CHECK-LABEL: @insert_strided_slice2
// CHECK-LABEL: llvm.func @insert_strided_slice2
//
// Subvector vector<2xf32> @0 into vector<4xf32> @2
// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[2 x <2 x float>]">
@ -532,15 +543,66 @@ func @insert_strided_slice2(%a: vector<2x2xf32>, %b: vector<4x4xf32>) -> vector<
// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>">
// CHECK-NEXT: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x <4 x float>]">
func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf32>) -> vector<16x4x8xf32> {
%0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 2], strides = [1, 1]}:
vector<2x4xf32> into vector<16x4x8xf32>
return %0 : vector<16x4x8xf32>
}
// CHECK-LABEL: llvm.func @insert_strided_slice3
// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm<"[2 x <4 x float>]">
// CHECK-SAME: %[[B:arg[0-9]+]]: !llvm<"[16 x [4 x <8 x float>]]">
// CHECK: %[[s0:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[16 x [4 x <8 x float>]]">
// CHECK: %[[s1:.*]] = llvm.extractvalue %[[A]][0] : !llvm<"[2 x <4 x float>]">
// CHECK: %[[s2:.*]] = llvm.extractvalue %[[s0]][0] : !llvm<"[4 x <8 x float>]">
// CHECK: %[[s3:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: %[[s4:.*]] = llvm.extractelement %[[s1]][%[[s3]] : !llvm.i64] : !llvm<"<4 x float>">
// CHECK: %[[s5:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
// CHECK: %[[s6:.*]] = llvm.insertelement %[[s4]], %[[s2]][%[[s5]] : !llvm.i64] : !llvm<"<8 x float>">
// CHECK: %[[s7:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: %[[s8:.*]] = llvm.extractelement %[[s1]][%[[s7]] : !llvm.i64] : !llvm<"<4 x float>">
// CHECK: %[[s9:.*]] = llvm.mlir.constant(3 : index) : !llvm.i64
// CHECK: %[[s10:.*]] = llvm.insertelement %[[s8]], %[[s6]][%[[s9]] : !llvm.i64] : !llvm<"<8 x float>">
// CHECK: %[[s11:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
// CHECK: %[[s12:.*]] = llvm.extractelement %[[s1]][%[[s11]] : !llvm.i64] : !llvm<"<4 x float>">
// CHECK: %[[s13:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64
// CHECK: %[[s14:.*]] = llvm.insertelement %[[s12]], %[[s10]][%[[s13]] : !llvm.i64] : !llvm<"<8 x float>">
// CHECK: %[[s15:.*]] = llvm.mlir.constant(3 : index) : !llvm.i64
// CHECK: %[[s16:.*]] = llvm.extractelement %[[s1]][%[[s15]] : !llvm.i64] : !llvm<"<4 x float>">
// CHECK: %[[s17:.*]] = llvm.mlir.constant(5 : index) : !llvm.i64
// CHECK: %[[s18:.*]] = llvm.insertelement %[[s16]], %[[s14]][%[[s17]] : !llvm.i64] : !llvm<"<8 x float>">
// CHECK: %[[s19:.*]] = llvm.insertvalue %[[s18]], %[[s0]][0] : !llvm<"[4 x <8 x float>]">
// CHECK: %[[s20:.*]] = llvm.extractvalue %[[A]][1] : !llvm<"[2 x <4 x float>]">
// CHECK: %[[s21:.*]] = llvm.extractvalue %[[s0]][1] : !llvm<"[4 x <8 x float>]">
// CHECK: %[[s22:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: %[[s23:.*]] = llvm.extractelement %[[s20]][%[[s22]] : !llvm.i64] : !llvm<"<4 x float>">
// CHECK: %[[s24:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
// CHECK: %[[s25:.*]] = llvm.insertelement %[[s23]], %[[s21]][%[[s24]] : !llvm.i64] : !llvm<"<8 x float>">
// CHECK: %[[s26:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: %[[s27:.*]] = llvm.extractelement %[[s20]][%[[s26]] : !llvm.i64] : !llvm<"<4 x float>">
// CHECK: %[[s28:.*]] = llvm.mlir.constant(3 : index) : !llvm.i64
// CHECK: %[[s29:.*]] = llvm.insertelement %[[s27]], %[[s25]][%[[s28]] : !llvm.i64] : !llvm<"<8 x float>">
// CHECK: %[[s30:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
// CHECK: %[[s31:.*]] = llvm.extractelement %[[s20]][%[[s30]] : !llvm.i64] : !llvm<"<4 x float>">
// CHECK: %[[s32:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64
// CHECK: %[[s33:.*]] = llvm.insertelement %[[s31]], %[[s29]][%[[s32]] : !llvm.i64] : !llvm<"<8 x float>">
// CHECK: %[[s34:.*]] = llvm.mlir.constant(3 : index) : !llvm.i64
// CHECK: %[[s35:.*]] = llvm.extractelement %[[s20]][%[[s34]] : !llvm.i64] : !llvm<"<4 x float>">
// CHECK: %[[s36:.*]] = llvm.mlir.constant(5 : index) : !llvm.i64
// CHECK: %[[s37:.*]] = llvm.insertelement %[[s35]], %[[s33]][%[[s36]] : !llvm.i64] : !llvm<"<8 x float>">
// CHECK: %[[s38:.*]] = llvm.insertvalue %[[s37]], %[[s19]][1] : !llvm<"[4 x <8 x float>]">
// CHECK: %[[s39:.*]] = llvm.insertvalue %[[s38]], %[[B]][0] : !llvm<"[16 x [4 x <8 x float>]]">
// CHECK: llvm.return %[[s39]] : !llvm<"[16 x [4 x <8 x float>]]">
func @extract_strides(%arg0: vector<3x3xf32>) -> vector<1x1xf32> {
%0 = vector.extract_slices %arg0, [2, 2], [1, 1]
: vector<3x3xf32> into tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>>
%1 = vector.tuple_get %0, 3 : tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>>
return %1 : vector<1x1xf32>
}
// CHECK-LABEL: extract_strides(%arg0: !llvm<"[3 x <3 x float>]">)
// CHECK-LABEL: llvm.func @extract_strides
// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm<"[3 x <3 x float>]">
// CHECK: %[[s0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1x1xf32>) : !llvm<"[1 x <1 x float>]">
// CHECK: %[[s1:.*]] = llvm.extractvalue %arg0[2] : !llvm<"[3 x <3 x float>]">
// CHECK: %[[s1:.*]] = llvm.extractvalue %[[A]][2] : !llvm<"[3 x <3 x float>]">
// CHECK: %[[s3:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1xf32>) : !llvm<"<1 x float>">
// CHECK: %[[s4:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
// CHECK: %[[s5:.*]] = llvm.extractelement %[[s1]][%[[s4]] : !llvm.i64] : !llvm<"<3 x float>">