fix: conv bias backward (#248)

Nathaniel Simard 2023-03-23 16:10:26 -04:00 committed by GitHub
parent 6f43d983f7
commit a74e4cd0bc
1 changed file with 16 additions and 15 deletions


@@ -1,5 +1,5 @@
 use super::{Conv1dBackward, Conv2dBackward};
-use crate::{backend::Backend, ElementConversion, Shape};
+use crate::{backend::Backend, Shape};
 use libm::ceilf;
 
 /// Calculate the expected padding size required when applying a convolution with the specified
@@ -50,7 +50,7 @@ pub(crate) fn conv1d_backward<B: Backend>(
 ) -> Conv1dBackward<B> {
     // TODO: Fix the backward pass when using stride > 1.
     let [batch_size, _channels_in, length_in] = B::shape(&x).dims;
-    let [_batch_size, _channels_out, length_out] = B::shape(&output_grad).dims;
+    let [_batch_size, channels_out, length_out] = B::shape(&output_grad).dims;
     let [_, _, kernel_size] = B::shape(&weight).dims;
 
     let output_grad_tmp = output_grad.clone();
@@ -63,7 +63,7 @@ pub(crate) fn conv1d_backward<B: Backend>(
     let padding = calculate_padding(length_out, stride, length_in, kernel_size);
 
     let x_tmp = B::swap_dims(x, 0, 1);
-    let output_grad_tmp = B::swap_dims(output_grad, 0, 1);
+    let output_grad_tmp = B::swap_dims(output_grad.clone(), 0, 1);
     let weight_grad = B::conv1d(x_tmp, output_grad_tmp, None, stride, padding);
     let weight_grad = B::swap_dims(weight_grad, 0, 1);
 
@@ -71,12 +71,11 @@ pub(crate) fn conv1d_backward<B: Backend>(
         x_grad,
         weight_grad,
         bias.map(|b| {
-            let elem = batch_size * length_out;
-            let elem = (elem as i32).elem();
+            let grad = B::swap_dims(output_grad, 0, 1);
+            let grad = B::reshape(grad, Shape::new([channels_out, batch_size * length_out]));
+            let grad = B::sum_dim(grad, 1);
 
-            let b = B::zeros(B::shape(&b), &B::device(&b));
-
-            B::add_scalar(b, elem)
+            B::reshape(grad, B::shape(&b))
         }),
     )
 }
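
For reference, the change above replaces a constant bias gradient with an actual reduction of output_grad: it is summed over the batch and length dimensions, leaving one value per output channel. Below is a minimal standalone sketch of that reduction in plain Rust, independent of the Backend trait; the helper name, flattened memory layout, and example shapes are illustrative assumptions, not code from this repository.

// Bias gradient of a 1D convolution: sum the output gradient over the
// batch and length dimensions, one value per output channel.
fn conv1d_bias_grad(
    output_grad: &[f32], // flattened [batch_size, channels_out, length_out]
    batch_size: usize,
    channels_out: usize,
    length_out: usize,
) -> Vec<f32> {
    let mut bias_grad = vec![0.0f32; channels_out];
    for n in 0..batch_size {
        for c in 0..channels_out {
            for l in 0..length_out {
                bias_grad[c] += output_grad[(n * channels_out + c) * length_out + l];
            }
        }
    }
    bias_grad
}

fn main() {
    // When every gradient element is 1.0, the per-channel sum equals
    // batch_size * length_out, which is exactly the constant the old code
    // produced; for any other gradient the two differ.
    let all_ones = vec![1.0f32; 2 * 1 * 3];
    assert_eq!(conv1d_bias_grad(&all_ones, 2, 1, 3), vec![6.0]);
}
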
@@ -90,7 +89,7 @@ pub(crate) fn conv2d_backward<B: Backend>(
 ) -> Conv2dBackward<B> {
     // TODO: Fix the backward pass when using stride > 1.
     let [batch_size, _channels_in, height_in, width_in] = B::shape(&x).dims;
-    let [_batch_size, _channels_out, height_out, width_out] = B::shape(&output_grad).dims;
+    let [_batch_size, channels_out, height_out, width_out] = B::shape(&output_grad).dims;
     let [_, _, kernel_size_1, kernel_size_2] = B::shape(&weight).dims;
     let [stride_1, stride_2] = stride;
 
@@ -112,7 +111,7 @@ pub(crate) fn conv2d_backward<B: Backend>(
     let padding_2 = calculate_padding(width_out, stride_2, width_in, kernel_size_2);
 
     let x_tmp = B::swap_dims(x, 0, 1);
-    let output_grad_tmp = B::swap_dims(output_grad, 0, 1);
+    let output_grad_tmp = B::swap_dims(output_grad.clone(), 0, 1);
     let weight_grad = B::conv2d(
         x_tmp,
         output_grad_tmp,
@@ -126,12 +125,14 @@ pub(crate) fn conv2d_backward<B: Backend>(
         x_grad,
         weight_grad,
         bias.map(|b| {
-            let elem = batch_size * width_out * height_out;
-            let elem = (elem as i32).elem();
+            let grad = B::swap_dims(output_grad, 0, 1);
+            let grad = B::reshape(
+                grad,
+                Shape::new([channels_out, batch_size * height_out * width_out]),
+            );
+            let grad = B::sum_dim(grad, 1);
 
-            let b = B::zeros(B::shape(&b), &B::device(&b));
-
-            B::add_scalar(b, elem)
+            B::reshape(grad, B::shape(&b))
        }),
    )
 }
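
The conv2d change follows the same pattern: rather than filling the bias gradient with the constant batch_size * height_out * width_out, the output gradient is reshaped to [channels_out, batch_size * height_out * width_out] and summed along the second dimension. The sketch below (plain Rust, with assumed shapes and values for illustration) contrasts the old constant with the new reduction on a non-uniform gradient, where only the reduction gives the correct result.

// Contrast the old and new conv2d bias gradients on a gradient whose
// elements are not all 1.0.
fn main() {
    let (batch_size, channels_out, height_out, width_out) = (2usize, 1usize, 2usize, 2usize);
    let spatial = height_out * width_out;

    // Output gradient flattened as [batch_size, channels_out, height_out, width_out].
    let output_grad: Vec<f32> = (0..batch_size * channels_out * spatial)
        .map(|i| i as f32)
        .collect();

    // Old behaviour: a constant that ignores the gradient values entirely.
    let old_bias_grad = (batch_size * height_out * width_out) as f32;

    // New behaviour: sum over the batch and both spatial dimensions per channel.
    let mut new_bias_grad = vec![0.0f32; channels_out];
    for n in 0..batch_size {
        for c in 0..channels_out {
            for s in 0..spatial {
                new_bias_grad[c] += output_grad[(n * channels_out + c) * spatial + s];
            }
        }
    }

    // 0 + 1 + ... + 7 = 28.0, while the old constant is 8.0.
    assert_eq!(old_bias_grad, 8.0);
    assert_eq!(new_bias_grad, vec![28.0]);
}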