forked from mindspore-Ecosystem/mindspore
!19894 topk and intopk float16 bug workaround
Merge pull request !19894 from Peilin/topk-float16-cast-r1.3-fix
commit 070307634b
@@ -22,6 +22,7 @@
 #include "backend/kernel_compiler/gpu/gpu_kernel.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/cast_impl.cuh"
 #include "backend/kernel_compiler/gpu/cuda_impl/in_top_k_impl.cuh"
 #include "backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh"
 
@@ -47,12 +48,29 @@ class InTopKGpuKernel : public GpuKernel {
     T *top_k_output_device = GetDeviceAddress<T>(workspace, 0);
     int32_t *top_k_indices_device = GetDeviceAddress<int32_t>(workspace, 1);
 
-    // topk sorts the input along the last dimension
-    FastTopK(outer_size_, inner_size_, predictions_device, static_cast<int32_t>(k_), top_k_output_device,
-             top_k_indices_device, top_k_init_, reinterpret_cast<cudaStream_t>(stream_ptr));
-
-    CalInTopK(predictions_device, targets_device, output_device, top_k_output_device, input_shape_[0], k_,
-              reinterpret_cast<cudaStream_t>(stream_ptr));
+    if (std::is_same<T, half>::value) {
+      // remove later! urgent fix for bug: topk has incorrect output for float16
+      float top_k_init = std::numeric_limits<float>::lowest();
+
+      // cast to float32
+      float *casted_float32_input = GetDeviceAddress<float>(workspace, 2);
+      float *top_k_output_device_float32 = GetDeviceAddress<float>(workspace, 3);
+
+      Cast(input_size_, predictions_device, casted_float32_input, reinterpret_cast<cudaStream_t>(stream_ptr));
+
+      FastTopK(outer_size_, inner_size_, casted_float32_input, static_cast<int32_t>(k_), top_k_output_device_float32,
+               top_k_indices_device, top_k_init, reinterpret_cast<cudaStream_t>(stream_ptr));
+
+      CalInTopK(casted_float32_input, targets_device, output_device, top_k_output_device_float32, input_shape_[0],
+                input_shape_[1], k_, reinterpret_cast<cudaStream_t>(stream_ptr));
+    } else {
+      // topk sorts the input along the last dimension
+      FastTopK(outer_size_, inner_size_, predictions_device, static_cast<int32_t>(k_), top_k_output_device,
+               top_k_indices_device, top_k_init_, reinterpret_cast<cudaStream_t>(stream_ptr));
+
+      CalInTopK(predictions_device, targets_device, output_device, top_k_output_device, input_shape_[0],
+                input_shape_[1], k_, reinterpret_cast<cudaStream_t>(stream_ptr));
+    }
 
     return true;
   }
@@ -114,6 +132,12 @@
     output_size_list_.push_back(input_shape_[0] * sizeof(bool));
     workspace_size_list_.push_back(input_shape_[0] * k_ * sizeof(T));
     workspace_size_list_.push_back(input_shape_[0] * k_ * sizeof(int32_t));
+
+    // remove later! urgent fix for bug: topk has incorrect output for float16
+    if (std::is_same<T, half>::value) {
+      workspace_size_list_.push_back(input_size_ * sizeof(float));
+      workspace_size_list_.push_back(input_shape_[0] * k_ * sizeof(float));
+    }
   }
 
  private:
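For reference, the behaviour the InTopK kernels are meant to compute, and why the CUDA kernel now takes batch_size and class_id_count as separate arguments (the old call passed only input_shape_[0], which the kernel used both as loop bound and as row stride, so it only happened to work when the class count equalled the batch size), can be sketched in NumPy. This is an illustrative reference only, not part of the patch, and the function name is made up:

import numpy as np

def in_top_k_reference(predictions, targets, k):
    # output[i] is True when the target class's score is among the k largest scores of row i
    predictions = np.asarray(predictions, dtype=np.float32)
    targets = np.asarray(targets, dtype=np.int32)
    batch_size = predictions.shape[0]        # one sample per row
    out = np.zeros(batch_size, dtype=bool)
    for i in range(batch_size):              # loop over the batch size, not the class count
        kth_largest = np.sort(predictions[i])[-k]        # smallest value still inside the top k
        out[i] = predictions[i, targets[i]] >= kth_largest
    return out

# mirrors the updated test data: 3 samples, 9 classes, k = 1
preds = [[9, 3, 8, 0, 0, 0, 0, 0, 0],
         [7, 9, 9, 0, 0, 0, 0, 0, 0],
         [9, 9, 9, 0, 0, 0, 0, 0, 0]]
print(in_top_k_reference(preds, [0, 1, 2], k=1))   # [ True  True  True]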
@@ -21,6 +21,7 @@
 #include <vector>
 #include "backend/kernel_compiler/gpu/gpu_kernel.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/cast_impl.cuh"
 #include "backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh"
 
 namespace mindspore {
@@ -42,20 +43,38 @@ class TopKGpuKernel : public GpuKernel {
     T *output_addr = GetDeviceAddress<T>(outputs, 0);
     S *indices = GetDeviceAddress<S>(outputs, 1);
 
-    T init_k = std::numeric_limits<T>::lowest();
-    if (std::is_same<T, half>::value) {
-      // min value representable by float16, std::numeric_limits doesn't support half
-      init_k = static_cast<half>(-65504.);
-    }
-
     S k_cut = 0;
     CHECK_CUDA_RET_WITH_EXCEPT(
       kernel_node_,
       cudaMemcpyAsync(&k_cut, k, sizeof(S), cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
       "cudaMemcpyAsync k_cut failed");
     CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed - TopK");
-    FastTopK(outer_size_, inner_size_, input_addr, k_cut, output_addr, indices, init_k,
-             reinterpret_cast<cudaStream_t>(stream_ptr));
+
+    if (std::is_same<T, half>::value) {
+      // remove later! urgent fix for bug: topk has incorrect output for float16
+      float init_k = std::numeric_limits<float>::lowest();
+
+      // cast to float32
+      float *casted_float32_input = GetDeviceAddress<float>(workspaces, 0);
+      float *casted_float32_top_k_output = GetDeviceAddress<float>(workspaces, 1);
+      Cast(outer_size_ * inner_size_, input_addr, casted_float32_input, reinterpret_cast<cudaStream_t>(stream_ptr));
+
+      // call FastTopK with workspace[n], workspace[n+1] as input, output
+      FastTopK(outer_size_, inner_size_, casted_float32_input, k_cut, casted_float32_top_k_output, indices, init_k,
+               reinterpret_cast<cudaStream_t>(stream_ptr));
+
+      // cast workspace[n+1] back to float16
+      Cast(outer_size_ * k_, casted_float32_top_k_output, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
+    } else {
+      T init_k = std::numeric_limits<T>::lowest();
+      CHECK_CUDA_RET_WITH_EXCEPT(
+        kernel_node_,
+        cudaMemcpyAsync(&k_cut, k, sizeof(S), cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
+        "cudaMemcpyAsync k_cut failed");
+      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed - TopK");
+      FastTopK(outer_size_, inner_size_, input_addr, k_cut, output_addr, indices, init_k,
+               reinterpret_cast<cudaStream_t>(stream_ptr));
+    }
     return true;
   }
 
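Both kernels apply the same workaround: when T is half, the input is cast into a float32 workspace buffer, FastTopK runs on the float32 copy, and (in the TopK kernel) the selected values are cast back to float16 for the output; the extra float32 buffers are reserved in the workspace_size_list_ changes below. A rough NumPy equivalent of that data flow, purely as an illustration of the idea (names are made up, nothing here is MindSpore API):

import numpy as np

def topk_float16_workaround(x_fp16, k):
    # cast to float32 (Cast into a workspace buffer), select top-k there (FastTopK),
    # then cast the selected values back to float16 for the output buffer
    x_fp32 = x_fp16.astype(np.float32)
    idx = np.argsort(-x_fp32, axis=-1, kind="stable")[..., :k]
    values_fp32 = np.take_along_axis(x_fp32, idx, axis=-1)
    return values_fp32.astype(np.float16), idx.astype(np.int32)

values, indices = topk_float16_workaround(np.array([[1.0, 3.0, 2.0]], dtype=np.float16), k=2)
print(values, indices)   # [[3. 2.]] [[1 2]]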
@@ -82,6 +101,12 @@
     input_size_list_.push_back(sizeof(S));
     output_size_list_.push_back(outer_size_ * k_ * sizeof(T));
     output_size_list_.push_back(outer_size_ * k_ * sizeof(S));
+
+    // remove later! urgent fix for bug: topk has incorrect output for float16
+    if (std::is_same<T, half>::value) {
+      workspace_size_list_.push_back(outer_size_ * inner_size_ * sizeof(float));
+      workspace_size_list_.push_back(outer_size_ * k_ * sizeof(float));
+    }
   }
 
  private:
@@ -21,9 +21,9 @@
 
 template <typename T>
 __global__ void InTopK(const T *predictions, const int32_t *targets, bool *output, const T *top_k_output,
-                       size_t class_id_count, int64_t k) {
+                       size_t batch_size, size_t class_id_count, int64_t k) {
   size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x;
-  for (; gt_id < class_id_count; gt_id += blockDim.x * gridDim.x) {
+  for (; gt_id < batch_size; gt_id += blockDim.x * gridDim.x) {
     int32_t target_index = targets[gt_id];
     T predicted_value = predictions[gt_id * class_id_count + target_index];
     T top_k_smallest_value = top_k_output[k - 1];
@@ -33,14 +33,15 @@ __global__ void InTopK(const T *predictions, const int32_t *targets, bool *output,
 }
 
 template <typename T>
-void CalInTopK(const T *predictions, const int32_t *targets, bool *output, const T *top_k_output, size_t class_id_count,
-               int64_t k, cudaStream_t cuda_stream) {
+void CalInTopK(const T *predictions, const int32_t *targets, bool *output, const T *top_k_output, size_t batch_size,
+               size_t class_id_count, int64_t k, cudaStream_t cuda_stream) {
   InTopK<<<GET_BLOCKS(class_id_count), GET_THREADS, 0, cuda_stream>>>(predictions, targets, output, top_k_output,
-                                                                      class_id_count, k);
+                                                                      batch_size, class_id_count, k);
 }
 
 template void CalInTopK<half>(const half *predictions, const int32_t *targets, bool *output, const half *top_k_output,
-                              size_t class_id_count, int64_t k, cudaStream_t cuda_stream);
+                              size_t batch_size, size_t class_id_count, int64_t k, cudaStream_t cuda_stream);
 
 template void CalInTopK<float>(const float *predictions, const int32_t *targets, bool *output,
-                               const float *top_k_output, size_t class_id_count, int64_t k, cudaStream_t cuda_stream);
+                               const float *top_k_output, size_t batch_size, size_t class_id_count, int64_t k,
+                               cudaStream_t cuda_stream);
@@ -20,7 +20,7 @@
 #include <cuda_runtime.h>
 
 template <typename T>
-void CalInTopK(const T *predictions, const int32_t *targets, bool *output, const T *top_k_output, size_t class_id_count,
-               int64_t k, cudaStream_t cuda_stream);
+void CalInTopK(const T *predictions, const int32_t *targets, bool *output, const T *top_k_output, size_t batch_size,
+               size_t class_id_count, int64_t k, cudaStream_t cuda_stream);
 
 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_IN_TOP_K_CUH_
@@ -32,9 +32,9 @@ class InTopKNet(nn.Cell):
 def in_top_k(nptype):
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
 
-    predictions = Tensor(np.array([[9, 3, 8],
-                                   [7, 9, 9],
-                                   [9, 9, 9]]).astype(nptype))
+    predictions = Tensor(np.array([[9, 3, 8, 0, 0, 0, 0, 0, 0],
+                                   [7, 9, 9, 0, 0, 0, 0, 0, 0],
+                                   [9, 9, 9, 0, 0, 0, 0, 0, 0]]).astype(nptype))
 
     k = 1
     in_top_k_net = InTopKNet(k)
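The test predictions are widened from 3 classes to 9, so the batch size and the class count no longer coincide and the corrected batch_size/class_id_count handling is actually exercised. A minimal stand-alone repro in the same style as the test, assuming the public mindspore.ops.InTopK primitive and a GPU build of MindSpore (sketch only, not part of this patch):

import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, context

class InTopKNet(nn.Cell):
    def __init__(self, k):
        super(InTopKNet, self).__init__()
        self.in_top_k = ops.InTopK(k)

    def construct(self, predictions, targets):
        return self.in_top_k(predictions, targets)

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
predictions = Tensor(np.array([[9, 3, 8, 0, 0, 0, 0, 0, 0],
                               [7, 9, 9, 0, 0, 0, 0, 0, 0],
                               [9, 9, 9, 0, 0, 0, 0, 0, 0]]).astype(np.float16))
targets = Tensor(np.array([0, 1, 2]).astype(np.int32))
print(InTopKNet(1)(predictions, targets))   # one boolean per sample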