!20168 fix bug prelu grad

Merge pull request !20168 from zhangbuxue/fix_bug_prelu_grad
This commit is contained in:
i-robot 2021-07-13 16:31:22 +00:00 committed by Gitee
commit 9ceea5a365
1 changed file with 5 additions and 4 deletions

View File

@@ -22,11 +22,12 @@ template <typename T>
__global__ void CalPReLUGradKernel(size_t size, size_t weight_size, size_t per_channel_size,
const T *dy, const T *x, const T *w, T *dx, float *dw_array) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
size_t channel_id = weight_size == 1 ? 0 : (pos / per_channel_size) % weight_size;
dx[pos] = pos[x] <= static_cast<T>(0) ? w[channel_id] * dy[pos] : dy[pos];
if (pos[x] < static_cast<T>(0)) {
size_t index = channel_id * blockDim.x * gridDim.x + pos;
size_t index = channel_id * blockDim.x * gridDim.x + thread_id;
T threshold = static_cast<T>(0);
dx[pos] = x[pos] <= threshold ? w[channel_id] * dy[pos] : dy[pos];
if (x[pos] < threshold) {
dw_array[index] += static_cast<float>(x[pos] * dy[pos]);
}
}