fix NaN value processing in ReLU ops backend
parent 86e0967260
commit f0c456c244
@@ -20,7 +20,8 @@
 template <typename T>
 __global__ void CalReLUKernel(int size, T *input_addr, T *output_addr) {
   for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
-    output_addr[pos] = input_addr[pos] > static_cast<T>(0) ? input_addr[pos] : static_cast<T>(0);
+    output_addr[pos] = std::isnan(static_cast<double>(input_addr[pos])) || (input_addr[pos] > static_cast<T>(0))
+                         ? input_addr[pos] : static_cast<T>(0);
   }
 }
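The behavioral point of this hunk: with the old ternary, a NaN input fails the `input_addr[pos] > 0` comparison (any comparison against NaN is false) and is silently clamped to 0; the new predicate checks for NaN first, so NaN propagates to the output. A minimal host-side sketch of the change (the names relu_old and relu_fixed are illustrative, not from the source):

#include <cmath>
#include <cstdio>
#include <limits>

// Old kernel logic: NaN > 0 is false, so NaN is silently clamped to 0.
template <typename T>
T relu_old(T x) {
  return x > static_cast<T>(0) ? x : static_cast<T>(0);
}

// Fixed kernel logic: check for NaN first so it propagates to the output.
template <typename T>
T relu_fixed(T x) {
  return std::isnan(static_cast<double>(x)) || (x > static_cast<T>(0)) ? x : static_cast<T>(0);
}

int main() {
  float nan = std::numeric_limits<float>::quiet_NaN();
  std::printf("old:   %f\n", relu_old(nan));    // prints 0.000000
  std::printf("fixed: %f\n", relu_fixed(nan));  // prints nan
  return 0;
}

Note that for integer T the added isnan check is always false after the cast to double, so integer types keep the original behavior.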
@@ -103,4 +104,3 @@ template CUDA_LIB_EXPORT void ReluGradV2(const size_t num, const int64_t *dy, co
                                           cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void ReluGradV2(const size_t num, const uint8_t *dy, const uint32_t *mask, uint8_t *dx,
                                           cudaStream_t cuda_stream);
-
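This hunk touches the explicit-instantiation list at the bottom of the .cu file: each `template CUDA_LIB_EXPORT void ReluGradV2(...)` line forces the compiler to emit the symbol for one dtype in this translation unit, so host code compiled elsewhere can link against it. A sketch of that pattern, assuming a one-bit-per-element mask layout (the real ReluGradV2 body and mask format are not shown in this diff):

#include <cstdint>
#include <cuda_runtime.h>

// Illustrative kernel: pass the gradient through wherever the mask bit is set.
template <typename T>
__global__ void ReluGradV2Kernel(const size_t num, const T *dy, const uint32_t *mask, T *dx) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += blockDim.x * gridDim.x) {
    bool active = (mask[i / 32] >> (i % 32)) & 1u;
    dx[i] = active ? dy[i] : static_cast<T>(0);
  }
}

template <typename T>
void ReluGradV2(const size_t num, const T *dy, const uint32_t *mask, T *dx, cudaStream_t cuda_stream) {
  ReluGradV2Kernel<<<(num + 255) / 256, 256, 0, cuda_stream>>>(num, dy, mask, dx);
}

// Explicit instantiation: emits the uint8_t symbol into this translation unit,
// matching the instantiation lines kept by the hunk above.
template void ReluGradV2(const size_t num, const uint8_t *dy, const uint32_t *mask, uint8_t *dx,
                         cudaStream_t cuda_stream);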