forked from OSchip/llvm-project
[mlir][Linalg] Add a roundtrip test for indexed_generic op with tensors.
Summary: After D72555 landed, `linalg.indexed_generic` also accepts ranked tensors as inputs and outputs. Add a round-trip test for this. Differential Revision: https://reviews.llvm.org/D74267
This commit is contained in:
parent
bc8e442188
commit
4687822b9e
|
@ -310,6 +310,41 @@ func @generic_with_tensor_input_and_output(
|
|||
|
||||
// -----
|
||||
|
||||
// Helper function referenced by the #trait2 attribute dictionary (`fun = @foo`).
// It takes three `index` arguments, one vector<3x4xi4> value and one f32 value,
// ignores all of them, and returns the f32 constant 0.0.
func @foo(%i: index, %j: index, %k: index, %0: vector<3x4xi4>, %1: f32) -> f32 {
|
||||
%f0 = constant 0.0 : f32
|
||||
return %f0 : f32
|
||||
}
|
||||
|
||||
// Two affine maps over the 3-d space (i, j, k), used as `indexing_maps` by
// #trait2: the first produces the 2-d result (j, i), the second produces the
// 3-d result (i, k, i + j).
#accesses = [
|
||||
affine_map<(i, j, k) -> (j, i)>,
|
||||
affine_map<(i, j, k) -> (i, k, i + j)>
|
||||
]
|
||||
|
||||
// Attribute dictionary passed to the `linalg.indexed_generic` op in the test
// function below. It declares two inputs and one output, reuses #accesses as
// the indexing maps, marks all three iterators as "parallel", names @foo as the
// function, and carries a library call name.
#trait2 = {
|
||||
args_in = 2,
|
||||
args_out = 1,
|
||||
indexing_maps = #accesses,
|
||||
iterator_types = ["parallel", "parallel", "parallel"],
|
||||
fun = @foo,
|
||||
library_call = "some_external_function_name_1"
|
||||
}
|
||||
|
||||
// Round-trip test for `linalg.indexed_generic` on ranked tensors: both operands
// are tensors (one with vector<3x4xi4> elements, one with f32 elements) and the
// op yields a tensor<?x?x?xf32> result, which is returned. The op also carries
// the extra attribute `foo = 1`. The FileCheck directives that follow this
// function pin the expected printed form of the op.
func @indexed_generic_with_tensor_input_and_output(
|
||||
%arg0: tensor<?x?xvector<3x4xi4>>, %arg1: tensor<?x?x?xf32>)
|
||||
-> (tensor<?x?x?xf32>) {
|
||||
%0 = linalg.indexed_generic #trait2 %arg0, %arg1 {foo = 1} :
|
||||
tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
|
||||
return %0 : tensor<?x?x?xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @indexed_generic_with_tensor_input_and_output
|
||||
// CHECK: linalg.indexed_generic {args_in = 2 : i64, args_out = 1 : i64, fun = @foo,
|
||||
// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
|
||||
// CHECK-SAME: library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}:
|
||||
// CHECK-SAME: tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
|
||||
// CHECK: return {{.*}} : tensor<?x?x?xf32>
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
|
||||
// CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
|
||||
|
||||
|
|
Loading…
Reference in New Issue