From f0c456c244ab9734217d55537cd2ba7c837b5938 Mon Sep 17 00:00:00 2001
From: lei-yuanzhe
Date: Wed, 2 Nov 2022 18:08:24 +0800
Subject: [PATCH] fix nan value process in relu ops backend

---
 .../plugin/device/gpu/kernel/cuda_impl/cuda_ops/relu_impl.cu | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/relu_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/relu_impl.cu
index a4cef64d7f5..20fc562e361 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/relu_impl.cu
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/relu_impl.cu
@@ -20,7 +20,8 @@
 template <typename T>
 __global__ void CalReLUKernel(int size, T *input_addr, T *output_addr) {
   for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
-    output_addr[pos] = input_addr[pos] > static_cast<T>(0) ? input_addr[pos] : static_cast<T>(0);
+    output_addr[pos] = std::isnan(static_cast<float>(input_addr[pos])) || (input_addr[pos] > static_cast<T>(0))
+                         ? input_addr[pos] : static_cast<T>(0);
   }
 }
 
@@ -103,4 +104,3 @@ template CUDA_LIB_EXPORT void ReluGradV2(const size_t num, const int64_t *dy, co
                                          cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void ReluGradV2(const size_t num, const uint8_t *dy, const uint32_t *mask, uint8_t *dx,
                                          cudaStream_t cuda_stream);
-