// Review notes for this patch (kept ahead of the first `diff --git` header; every hunk below is unchanged).
// Purpose, per the added `remove later!` comments: work around incorrect TopK output for float16.
// The added float16 branches cast the input to float32 with Cast and run FastTopK on the float32 copy;
// TopKGpuKernel additionally casts the float32 result back into output_addr, InTopKGpuKernel instead hands
// the float32 buffers straight to CalInTopK. Both kernels push two extra float workspace sizes under the
// same `std::is_same` check. CalInTopK/InTopK gain a batch_size argument that now bounds the thread loop,
// while class_id_count stays the row stride used to index predictions.
// NOTE(review): text between angle brackets appears to have been lost in extraction (`#include` with no
//   header, `static_cast(k_)`, `InTopK<<>>`), so template arguments and the launch configuration are not
//   visible here - do not infer them from this copy.
// NOTE(review): in TopKGpuKernel the context lines before the new `if` already copy `k` into `k_cut` and
//   call cudaDeviceSynchronize(); the added `else` branch repeats both calls - looks redundant, confirm.
// NOTE(review): the float16 TopK path passes `k_cut` to FastTopK but casts back `outer_size_ * k_`
//   elements - presumably the two values agree; verify against the caller.
// NOTE(review): InTopK compares every row against `top_k_output[k - 1]`, while the top-k workspace is sized
//   `input_shape_[0] * k_` elements - a per-row index may be intended; confirm the FastTopK output layout.
// NOTE(review): the updated test keeps k = 1 and every row has 9 as its largest value, so it presumably
//   cannot tell a per-row threshold from a row-0 threshold; consider rows with different maxima.
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/in_top_k_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/in_top_k_gpu_kernel.h index 30f21e86d38..10e790b6711 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/in_top_k_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/in_top_k_gpu_kernel.h @@ -22,6 +22,7 @@ #include "backend/kernel_compiler/gpu/gpu_kernel.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/cast_impl.cuh" #include "backend/kernel_compiler/gpu/cuda_impl/in_top_k_impl.cuh" #include "backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh" @@ -47,12 +48,29 @@ class InTopKGpuKernel : public GpuKernel { T *top_k_output_device = GetDeviceAddress(workspace, 0); int32_t *top_k_indices_device = GetDeviceAddress(workspace, 1); - // topk sorts the input along the last dimension - FastTopK(outer_size_, inner_size_, predictions_device, static_cast(k_), top_k_output_device, - top_k_indices_device, top_k_init_, reinterpret_cast(stream_ptr)); + if (std::is_same::value) { + // remove later! 
urgent fix for bug: topk has incorrect output for float16 + float top_k_init = std::numeric_limits::lowest(); - CalInTopK(predictions_device, targets_device, output_device, top_k_output_device, input_shape_[0], k_, - reinterpret_cast(stream_ptr)); + // cast to float32 + float *casted_float32_input = GetDeviceAddress(workspace, 2); + float *top_k_output_device_float32 = GetDeviceAddress(workspace, 3); + + Cast(input_size_, predictions_device, casted_float32_input, reinterpret_cast(stream_ptr)); + + FastTopK(outer_size_, inner_size_, casted_float32_input, static_cast(k_), top_k_output_device_float32, + top_k_indices_device, top_k_init, reinterpret_cast(stream_ptr)); + + CalInTopK(casted_float32_input, targets_device, output_device, top_k_output_device_float32, input_shape_[0], + input_shape_[1], k_, reinterpret_cast(stream_ptr)); + } else { + // topk sorts the input along the last dimension + FastTopK(outer_size_, inner_size_, predictions_device, static_cast(k_), top_k_output_device, + top_k_indices_device, top_k_init_, reinterpret_cast(stream_ptr)); + + CalInTopK(predictions_device, targets_device, output_device, top_k_output_device, input_shape_[0], + input_shape_[1], k_, reinterpret_cast(stream_ptr)); + } return true; } @@ -114,6 +132,12 @@ class InTopKGpuKernel : public GpuKernel { output_size_list_.push_back(input_shape_[0] * sizeof(bool)); workspace_size_list_.push_back(input_shape_[0] * k_ * sizeof(T)); workspace_size_list_.push_back(input_shape_[0] * k_ * sizeof(int32_t)); + + // remove later! 
urgent fix for bug: topk has incorrect output for float16 + if (std::is_same::value) { + workspace_size_list_.push_back(input_size_ * sizeof(float)); + workspace_size_list_.push_back(input_shape_[0] * k_ * sizeof(float)); + } } private: diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.h index b07a37f621d..4fb7806fca8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.h @@ -21,6 +21,7 @@ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/cast_impl.cuh" #include "backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh" namespace mindspore { @@ -42,20 +43,38 @@ class TopKGpuKernel : public GpuKernel { T *output_addr = GetDeviceAddress(outputs, 0); S *indices = GetDeviceAddress(outputs, 1); - T init_k = std::numeric_limits::lowest(); - if (std::is_same::value) { - // min value representable by float16, std::numeric_limits doesn't support half - init_k = static_cast(-65504.); - } - S k_cut = 0; CHECK_CUDA_RET_WITH_EXCEPT( kernel_node_, cudaMemcpyAsync(&k_cut, k, sizeof(S), cudaMemcpyDeviceToHost, reinterpret_cast(stream_ptr)), "cudaMemcpyAsync k_cut failed"); CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed - TopK"); - FastTopK(outer_size_, inner_size_, input_addr, k_cut, output_addr, indices, init_k, - reinterpret_cast(stream_ptr)); + + if (std::is_same::value) { + // remove later! 
urgent fix for bug: topk has incorrect output for float16 + float init_k = std::numeric_limits::lowest(); + + // cast to float32 + float *casted_float32_input = GetDeviceAddress(workspaces, 0); + float *casted_float32_top_k_output = GetDeviceAddress(workspaces, 1); + Cast(outer_size_ * inner_size_, input_addr, casted_float32_input, reinterpret_cast(stream_ptr)); + + // call FastTopK with workspace[n], workspace[n+1] as input, output + FastTopK(outer_size_, inner_size_, casted_float32_input, k_cut, casted_float32_top_k_output, indices, init_k, + reinterpret_cast(stream_ptr)); + + // cast workspace[n+1] back to float16 + Cast(outer_size_ * k_, casted_float32_top_k_output, output_addr, reinterpret_cast(stream_ptr)); + } else { + T init_k = std::numeric_limits::lowest(); + CHECK_CUDA_RET_WITH_EXCEPT( + kernel_node_, + cudaMemcpyAsync(&k_cut, k, sizeof(S), cudaMemcpyDeviceToHost, reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync k_cut failed"); + CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed - TopK"); + FastTopK(outer_size_, inner_size_, input_addr, k_cut, output_addr, indices, init_k, + reinterpret_cast(stream_ptr)); + } return true; } @@ -82,6 +101,12 @@ class TopKGpuKernel : public GpuKernel { input_size_list_.push_back(sizeof(S)); output_size_list_.push_back(outer_size_ * k_ * sizeof(T)); output_size_list_.push_back(outer_size_ * k_ * sizeof(S)); + + // remove later! 
urgent fix for bug: topk has incorrect output for float16 + if (std::is_same::value) { + workspace_size_list_.push_back(outer_size_ * inner_size_ * sizeof(float)); + workspace_size_list_.push_back(outer_size_ * k_ * sizeof(float)); + } } private: diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/in_top_k_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/in_top_k_impl.cu index 19a13d631c0..e5da715acb1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/in_top_k_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/in_top_k_impl.cu @@ -21,9 +21,9 @@ template __global__ void InTopK(const T *predictions, const int32_t *targets, bool *output, const T *top_k_output, - size_t class_id_count, int64_t k) { + size_t batch_size, size_t class_id_count, int64_t k) { size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x; - for (; gt_id < class_id_count; gt_id += blockDim.x * gridDim.x) { + for (; gt_id < batch_size; gt_id += blockDim.x * gridDim.x) { int32_t target_index = targets[gt_id]; T predicted_value = predictions[gt_id * class_id_count + target_index]; T top_k_smallest_value = top_k_output[k - 1]; @@ -33,14 +33,15 @@ __global__ void InTopK(const T *predictions, const int32_t *targets, bool *outpu } template -void CalInTopK(const T *predictions, const int32_t *targets, bool *output, const T *top_k_output, size_t class_id_count, - int64_t k, cudaStream_t cuda_stream) { +void CalInTopK(const T *predictions, const int32_t *targets, bool *output, const T *top_k_output, size_t batch_size, + size_t class_id_count, int64_t k, cudaStream_t cuda_stream) { InTopK<<>>(predictions, targets, output, top_k_output, - class_id_count, k); + batch_size, class_id_count, k); } template void CalInTopK(const half *predictions, const int32_t *targets, bool *output, const half *top_k_output, - size_t class_id_count, int64_t k, cudaStream_t cuda_stream); + size_t batch_size, size_t class_id_count, int64_t k, cudaStream_t cuda_stream); 
template void CalInTopK(const float *predictions, const int32_t *targets, bool *output, - const float *top_k_output, size_t class_id_count, int64_t k, cudaStream_t cuda_stream); + const float *top_k_output, size_t batch_size, size_t class_id_count, int64_t k, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/in_top_k_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/in_top_k_impl.cuh index c5d2829bdf8..f72b20ab434 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/in_top_k_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/in_top_k_impl.cuh @@ -20,7 +20,7 @@ #include template -void CalInTopK(const T *predictions, const int32_t *targets, bool *output, const T *top_k_output, size_t class_id_count, - int64_t k, cudaStream_t cuda_stream); +void CalInTopK(const T *predictions, const int32_t *targets, bool *output, const T *top_k_output, size_t batch_size, + size_t class_id_count, int64_t k, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_IN_TOP_K_CUH_ diff --git a/tests/st/ops/gpu/test_in_top_k.py b/tests/st/ops/gpu/test_in_top_k.py index 32f7b753637..fe4df4170cb 100644 --- a/tests/st/ops/gpu/test_in_top_k.py +++ b/tests/st/ops/gpu/test_in_top_k.py @@ -32,9 +32,9 @@ class InTopKNet(nn.Cell): def in_top_k(nptype): context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - predictions = Tensor(np.array([[9, 3, 8], - [7, 9, 9], - [9, 9, 9]]).astype(nptype)) + predictions = Tensor(np.array([[9, 3, 8, 0, 0, 0, 0, 0, 0], + [7, 9, 9, 0, 0, 0, 0, 0, 0], + [9, 9, 9, 0, 0, 0, 0, 0, 0]]).astype(nptype)) k = 1 in_top_k_net = InTopKNet(k)