diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/prelu_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/prelu_grad_impl.cu index 2b2cd7d5dbc..f151aea026d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/prelu_grad_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/prelu_grad_impl.cu @@ -22,11 +22,12 @@ template __global__ void CalPReLUGradKernel(size_t size, size_t weight_size, size_t per_channel_size, const T *dy, const T *x, const T *w, T *dx, float *dw_array) { for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; size_t channel_id = weight_size == 1 ? 0 : (pos / per_channel_size) % weight_size; - dx[pos] = pos[x] <= static_cast(0) ? w[channel_id] * dy[pos] : dy[pos]; - - if (pos[x] < static_cast(0)) { - size_t index = channel_id * blockDim.x * gridDim.x + pos; + size_t index = channel_id * blockDim.x * gridDim.x + thread_id; + T threshold = static_cast(0); + dx[pos] = x[pos] <= threshold ? w[channel_id] * dy[pos] : dy[pos]; + if (x[pos] < threshold) { dw_array[index] += static_cast(x[pos] * dy[pos]); } }