回退 'Pull Request !45031 : fix nan value process in relu ops backend'

This commit is contained in:
fengyixing 2022-11-04 07:18:11 +00:00 committed by Gitee
parent 845866ea8d
commit 98068cc552
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 2 additions and 2 deletions

View File

@ -20,8 +20,7 @@
template <typename T>
__global__ void CalReLUKernel(int size, T *input_addr, T *output_addr) {
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
output_addr[pos] = std::isnan(static_cast<double>(input_addr[pos])) || (input_addr[pos] > static_cast<T>(0))
? input_addr[pos] : static_cast<T>(0);
output_addr[pos] = input_addr[pos] > static_cast<T>(0) ? input_addr[pos] : static_cast<T>(0);
}
}
@ -104,3 +103,4 @@ template CUDA_LIB_EXPORT void ReluGradV2(const size_t num, const int64_t *dy, co
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ReluGradV2(const size_t num, const uint8_t *dy, const uint32_t *mask, uint8_t *dx,
cudaStream_t cuda_stream);