[mlir][Vector] Add some missing tests for `broadcast` and `splat`

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D114853
This commit is contained in:
Michal Terepeta 2021-12-03 07:56:21 +00:00 committed by Nicolas Vasilache
parent 829b29b619
commit 8e2b373396
3 changed files with 25 additions and 7 deletions

View File

@ -62,3 +62,10 @@ func @constant_complex_f64() -> complex<f64> {
%result = constant [0.1 : f64, -1.0 : f64] : complex<f64>
return %result : complex<f64>
}
// CHECK-LABEL: func @vector_splat_0d(
func @vector_splat_0d(%a: f32) -> vector<f32> {
// CHECK: splat %{{.*}} : vector<f32>
%0 = splat %a : vector<f32>
return %0 : vector<f32>
}

View File

@ -16,6 +16,13 @@ func @broadcast_rank_too_high(%arg0: vector<4x4xf32>) {
// -----
func @broadcast_rank_too_high_0d(%arg0: vector<1xf32>) {
// expected-error@+1 {{'vector.broadcast' op source rank higher than destination rank}}
%1 = vector.broadcast %arg0 : vector<1xf32> to vector<f32>
}
// -----
func @broadcast_dim1_mismatch(%arg0: vector<7xf32>) {
// expected-error@+1 {{'vector.broadcast' op dimension mismatch (7 vs. 3)}}
%1 = vector.broadcast %arg0 : vector<7xf32> to vector<3xf32>
@ -79,7 +86,7 @@ func @extract_element(%arg0: vector<f32>) {
}
// -----
func @extract_element(%arg0: vector<4xf32>) {
%c = arith.constant 3 : i32
// expected-error@+1 {{expected position for 1-D vector}}

View File

@ -149,16 +149,20 @@ func @vector_transfer_ops_tensor(%arg0: tensor<?x?xf32>,
}
// CHECK-LABEL: @vector_broadcast
func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<1x16xf32>, %d: vector<8x1xf32>) -> vector<8x16xf32> {
func @vector_broadcast(%a: f32, %b: vector<f32>, %c: vector<16xf32>, %d: vector<1x16xf32>, %e: vector<8x1xf32>) -> vector<8x16xf32> {
// CHECK: vector.broadcast %{{.*}} : f32 to vector<f32>
%0 = vector.broadcast %a : f32 to vector<f32>
// CHECK: vector.broadcast %{{.*}} : vector<f32> to vector<4xf32>
%1 = vector.broadcast %b : vector<f32> to vector<4xf32>
// CHECK: vector.broadcast %{{.*}} : f32 to vector<16xf32>
%0 = vector.broadcast %a : f32 to vector<16xf32>
%2 = vector.broadcast %a : f32 to vector<16xf32>
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<16xf32> to vector<8x16xf32>
%1 = vector.broadcast %b : vector<16xf32> to vector<8x16xf32>
%3 = vector.broadcast %c : vector<16xf32> to vector<8x16xf32>
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<1x16xf32> to vector<8x16xf32>
%2 = vector.broadcast %c : vector<1x16xf32> to vector<8x16xf32>
%4 = vector.broadcast %d : vector<1x16xf32> to vector<8x16xf32>
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<8x1xf32> to vector<8x16xf32>
%3 = vector.broadcast %d : vector<8x1xf32> to vector<8x16xf32>
return %3 : vector<8x16xf32>
%5 = vector.broadcast %e : vector<8x1xf32> to vector<8x16xf32>
return %4 : vector<8x16xf32>
}
// CHECK-LABEL: @shuffle1D