fix NaN value handling in relu ops backend

lei-yuanzhe 2022-11-02 18:08:24 +08:00
parent 86e0967260
commit f0c456c244
1 changed file with 2 additions and 2 deletions


@@ -20,7 +20,8 @@
 template <typename T>
 __global__ void CalReLUKernel(int size, T *input_addr, T *output_addr) {
   for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
-    output_addr[pos] = input_addr[pos] > static_cast<T>(0) ? input_addr[pos] : static_cast<T>(0);
+    output_addr[pos] = std::isnan(static_cast<double>(input_addr[pos])) || (input_addr[pos] > static_cast<T>(0))
+                         ? input_addr[pos] : static_cast<T>(0);
   }
 }
@@ -103,4 +104,3 @@ template CUDA_LIB_EXPORT void ReluGradV2(const size_t num, const int64_t *dy, co
                                          cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void ReluGradV2(const size_t num, const uint8_t *dy, const uint32_t *mask, uint8_t *dx,
                                          cudaStream_t cuda_stream);
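
For context, a minimal standalone sketch (not part of the commit) of the behavior this patch targets: a NaN input now propagates to the output instead of being clamped to zero, since NaN compares false against 0 and the old ternary therefore returned 0 for it. The kernel body mirrors the patched CalReLUKernel; the host harness, launch configuration, and float specialization are illustrative assumptions, and the check here uses the device-side isnan overload.

// Sketch only: demonstrates NaN propagation through the patched ReLU kernel.
#include <cmath>
#include <cstdio>
#include <cuda_runtime.h>

template <typename T>
__global__ void CalReLUKernel(int size, T *input_addr, T *output_addr) {
  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
    // NaN > 0 is false, so without the explicit isnan check a NaN input
    // would be flushed to 0; with it, the NaN is passed through unchanged.
    output_addr[pos] = isnan(static_cast<double>(input_addr[pos])) || (input_addr[pos] > static_cast<T>(0))
                         ? input_addr[pos] : static_cast<T>(0);
  }
}

int main() {
  const int size = 4;
  float h_in[size] = {-1.0f, 2.0f, nanf(""), 0.5f};
  float h_out[size] = {0};
  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, size * sizeof(float));
  cudaMalloc(&d_out, size * sizeof(float));
  cudaMemcpy(d_in, h_in, size * sizeof(float), cudaMemcpyHostToDevice);
  CalReLUKernel<<<1, 32>>>(size, d_in, d_out);
  cudaMemcpy(h_out, d_out, size * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < size; ++i) printf("%g ", h_out[i]);  // expected: 0 2 nan 0.5
  printf("\n");
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}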