[mlir][gpu] NFC: Change gpu.launch_func ops to custom format.

This should fix the reason for the failures after ec7780ebda. I will roll forward in a separate change.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D90410
This commit is contained in:
Christian Sigg 2020-10-29 19:16:19 +01:00
parent 661797bd76
commit b22f111023
3 changed files with 8 additions and 11 deletions

View File

@ -40,9 +40,8 @@ module attributes {gpu.container_module, spv.target_env = #spv.target_env<#spv.v
func @main() {
%buffer = alloc() : memref<6xi32>
%one = constant 1 : index
"gpu.launch_func"(%one, %one, %one,
%one, %one, %one,
%buffer) {kernel = @foo::@bar} : (index, index, index, index, index, index, memref<6xi32>) -> ()
gpu.launch_func @foo::@bar blocks in (%one, %one, %one)
threads in (%one, %one, %one) args(%buffer : memref<6xi32>)
return
}
}

View File

@ -54,10 +54,9 @@ module attributes {
call @fillI32Buffer(%output_casted, %zero) : (memref<?xi32>, i32) -> ()
%one = constant 1 : index
"gpu.launch_func"(%one, %one, %one,
%one, %one, %one,
%input, %output) { kernel = @kernels::@double }
: (index, index, index, index, index, index, memref<6xi32>, memref<6xi32>) -> ()
gpu.launch_func @kernels::@double
blocks in (%one, %one, %one) threads in (%one, %one, %one)
args(%input : memref<6xi32>, %output : memref<6xi32>)
%result = memref_cast %output : memref<6xi32> to memref<*xi32>
call @print_memref_i32(%result) : (memref<*xi32>) -> ()
return

View File

@ -47,10 +47,9 @@ module attributes {
call @fillF32Buffer3D(%output_casted, %0) : (memref<?x?x?xf32>, f32) -> ()
%one = constant 1 : index
"gpu.launch_func"(%one, %one, %one,
%one, %one, %one,
%input1, %input2, %output) { kernel = @kernels::@sum }
: (index, index, index, index, index, index, memref<3xf32>, memref<3x3xf32>, memref<3x3x3xf32>) -> ()
gpu.launch_func @kernels::@sum
blocks in (%one, %one, %one) threads in (%one, %one, %one)
args(%input1 : memref<3xf32>, %input2 : memref<3x3xf32>, %output : memref<3x3x3xf32>)
%result = memref_cast %output : memref<3x3x3xf32> to memref<*xf32>
call @print_memref_f32(%result) : (memref<*xf32>) -> ()
return