add general float16 support initial commit

This commit is contained in:
Peilin Wang 2021-07-05 18:20:03 -04:00
parent ede0139b06
commit 694d5eb97b
1 changed files with 7 additions and 1 deletions

View File

@ -41,7 +41,13 @@ class TopKGpuKernel : public GpuKernel {
S *k = GetDeviceAddress<S>(inputs, 1);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
S *indices = GetDeviceAddress<S>(outputs, 1);
const T init_k = std::numeric_limits<T>::lowest();
T init_k = std::numeric_limits<T>::lowest();
if (std::is_same<T, half>::value) {
// min value representable by float16, std::numeric_limits doesn't support half
init_k = static_cast<half>(-65504.);
}
S k_cut = 0;
CHECK_CUDA_RET_WITH_EXCEPT(
kernel_node_,