forked from mindspore-Ecosystem/mindspore
add general float16 support initial commit
This commit is contained in:
parent
ede0139b06
commit
694d5eb97b
|
@ -41,7 +41,13 @@ class TopKGpuKernel : public GpuKernel {
|
|||
S *k = GetDeviceAddress<S>(inputs, 1);
|
||||
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
S *indices = GetDeviceAddress<S>(outputs, 1);
|
||||
const T init_k = std::numeric_limits<T>::lowest();
|
||||
|
||||
T init_k = std::numeric_limits<T>::lowest();
|
||||
if (std::is_same<T, half>::value) {
|
||||
// min value representable by float16, std::numeric_limits doesn't support half
|
||||
init_k = static_cast<half>(-65504.);
|
||||
}
|
||||
|
||||
S k_cut = 0;
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
|
|
Loading…
Reference in New Issue