!20168 fix bug prelu grad
Merge pull request !20168 from zhangbuxue/fix_bug_prelu_grad
This commit is contained in:
commit
9ceea5a365
|
@ -22,11 +22,12 @@ template <typename T>
|
|||
__global__ void CalPReLUGradKernel(size_t size, size_t weight_size, size_t per_channel_size,
|
||||
const T *dy, const T *x, const T *w, T *dx, float *dw_array) {
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
|
||||
size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
size_t channel_id = weight_size == 1 ? 0 : (pos / per_channel_size) % weight_size;
|
||||
dx[pos] = pos[x] <= static_cast<T>(0) ? w[channel_id] * dy[pos] : dy[pos];
|
||||
|
||||
if (pos[x] < static_cast<T>(0)) {
|
||||
size_t index = channel_id * blockDim.x * gridDim.x + pos;
|
||||
size_t index = channel_id * blockDim.x * gridDim.x + thread_id;
|
||||
T threshold = static_cast<T>(0);
|
||||
dx[pos] = x[pos] <= threshold ? w[channel_id] * dy[pos] : dy[pos];
|
||||
if (x[pos] < threshold) {
|
||||
dw_array[index] += static_cast<float>(x[pos] * dy[pos]);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue