From 7515f4a9ef6207587f7c1c2cbe511e293bd1ffc9 Mon Sep 17 00:00:00 2001
From: buxue
Date: Tue, 13 Jul 2021 16:07:51 +0800
Subject: [PATCH] fix bug prelu grad

---
 .../kernel_compiler/gpu/cuda_impl/prelu_grad_impl.cu | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/prelu_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/prelu_grad_impl.cu
index 2b2cd7d5dbc..f151aea026d 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/prelu_grad_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/prelu_grad_impl.cu
@@ -22,11 +22,12 @@ template <typename T>
 __global__ void CalPReLUGradKernel(size_t size, size_t weight_size, size_t per_channel_size,
                                    const T *dy, const T *x, const T *w, T *dx, float *dw_array) {
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
+    size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
     size_t channel_id = weight_size == 1 ? 0 : (pos / per_channel_size) % weight_size;
-    dx[pos] = pos[x] <= static_cast<T>(0) ? w[channel_id] * dy[pos] : dy[pos];
-
-    if (pos[x] < static_cast<T>(0)) {
-      size_t index = channel_id * blockDim.x * gridDim.x + pos;
+    size_t index = channel_id * blockDim.x * gridDim.x + thread_id;
+    T threshold = static_cast<T>(0);
+    dx[pos] = x[pos] <= threshold ? w[channel_id] * dy[pos] : dy[pos];
+    if (x[pos] < threshold) {
       dw_array[index] += static_cast<float>(x[pos] * dy[pos]);
     }
   }