mirror of https://github.com/tracel-ai/burn.git
fix: conv bias backward (#248)
parent 6f43d983f7
commit a74e4cd0bc
@@ -1,5 +1,5 @@
 use super::{Conv1dBackward, Conv2dBackward};
-use crate::{backend::Backend, ElementConversion, Shape};
+use crate::{backend::Backend, Shape};
 use libm::ceilf;
 
 /// Calculate the expected padding size required when applying a convolution with the specified
@@ -50,7 +50,7 @@ pub(crate) fn conv1d_backward<B: Backend>(
 ) -> Conv1dBackward<B> {
     // TODO: Fix the backward pass when using stride > 1.
     let [batch_size, _channels_in, length_in] = B::shape(&x).dims;
-    let [_batch_size, _channels_out, length_out] = B::shape(&output_grad).dims;
+    let [_batch_size, channels_out, length_out] = B::shape(&output_grad).dims;
     let [_, _, kernel_size] = B::shape(&weight).dims;
 
     let output_grad_tmp = output_grad.clone();
@@ -63,7 +63,7 @@ pub(crate) fn conv1d_backward<B: Backend>(
     let padding = calculate_padding(length_out, stride, length_in, kernel_size);
 
     let x_tmp = B::swap_dims(x, 0, 1);
-    let output_grad_tmp = B::swap_dims(output_grad, 0, 1);
+    let output_grad_tmp = B::swap_dims(output_grad.clone(), 0, 1);
     let weight_grad = B::conv1d(x_tmp, output_grad_tmp, None, stride, padding);
     let weight_grad = B::swap_dims(weight_grad, 0, 1);
 
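The added `.clone()` is needed because `output_grad` is now consumed a second time further down, by the new bias-gradient code. For context, the unchanged lines here compute the weight gradient by reusing `B::conv1d` with the batch and channel dimensions swapped. A hypothetical sketch of the quantity that call computes, written as explicit loops over flat row-major buffers (stride 1; the helper name and layout are illustrative, not from the patch):

```rust
// Hypothetical illustration, not code from the patch: the conv1d weight
// gradient as explicit loops, assuming stride 1 and zero padding.
fn weight_grad_1d(
    x: &[f32],           // flat [batch, channels_in, length_in]
    output_grad: &[f32], // flat [batch, channels_out, length_out]
    (batch, channels_in, channels_out): (usize, usize, usize),
    (length_in, length_out): (usize, usize),
    kernel_size: usize,
    padding: usize,
) -> Vec<f32> {
    // grad_w[co][ci][k] = sum over n, l of
    //     x[n][ci][l + k - padding] * output_grad[n][co][l]
    let mut grad = vec![0.0f32; channels_out * channels_in * kernel_size];
    for n in 0..batch {
        for co in 0..channels_out {
            for ci in 0..channels_in {
                for k in 0..kernel_size {
                    for l in 0..length_out {
                        let i = (l + k) as isize - padding as isize;
                        if i < 0 || i >= length_in as isize {
                            continue; // this kernel tap falls in the zero padding
                        }
                        grad[(co * channels_in + ci) * kernel_size + k] +=
                            x[(n * channels_in + ci) * length_in + i as usize]
                                * output_grad[(n * channels_out + co) * length_out + l];
                    }
                }
            }
        }
    }
    grad
}
```

Swapping dims 0 and 1 lets the batch axis stand in for the channel axis that a convolution sums over, which is why the existing `conv1d` kernel can be reused instead of a dedicated backward kernel.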
@@ -71,12 +71,11 @@ pub(crate) fn conv1d_backward<B: Backend>(
         x_grad,
         weight_grad,
         bias.map(|b| {
-            let elem = batch_size * length_out;
-            let elem = (elem as i32).elem();
+            let grad = B::swap_dims(output_grad, 0, 1);
+            let grad = B::reshape(grad, Shape::new([channels_out, batch_size * length_out]));
+            let grad = B::sum_dim(grad, 1);
 
-            let b = B::zeros(B::shape(&b), &B::device(&b));
-
-            B::add_scalar(b, elem)
+            B::reshape(grad, B::shape(&b))
         }),
     )
 }
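This hunk is the actual fix. The bias is broadcast-added to every output position in the forward pass, so its gradient is the per-channel sum of `output_grad` over the batch and length dimensions; the new closure computes exactly that with `swap_dims`, `reshape`, and `sum_dim`. The old closure never read `output_grad` at all: it returned a tensor filled with the constant `batch_size * length_out`, which coincides with the correct value only when `output_grad` is all ones. A minimal sketch of the corrected reduction over plain slices (hypothetical helper, not code from the patch):

```rust
// Hypothetical illustration: grad_b[c] = sum over n, l of output_grad[n][c][l]
fn bias_grad_1d(
    output_grad: &[f32], // flat [batch_size, channels_out, length_out]
    batch_size: usize,
    channels_out: usize,
    length_out: usize,
) -> Vec<f32> {
    let mut grad = vec![0.0f32; channels_out];
    for n in 0..batch_size {
        for c in 0..channels_out {
            for l in 0..length_out {
                grad[c] += output_grad[(n * channels_out + c) * length_out + l];
            }
        }
    }
    grad
}
```

In the patch, `swap_dims` plus `reshape` lay the tensor out as `[channels_out, batch_size * length_out]` so that `sum_dim(grad, 1)` performs this per-channel reduction in one call. Dropping the `.elem()` scalar conversion is also what makes the `ElementConversion` import above unnecessary.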
@@ -90,7 +89,7 @@ pub(crate) fn conv2d_backward<B: Backend>(
 ) -> Conv2dBackward<B> {
     // TODO: Fix the backward pass when using stride > 1.
     let [batch_size, _channels_in, height_in, width_in] = B::shape(&x).dims;
-    let [_batch_size, _channels_out, height_out, width_out] = B::shape(&output_grad).dims;
+    let [_batch_size, channels_out, height_out, width_out] = B::shape(&output_grad).dims;
     let [_, _, kernel_size_1, kernel_size_2] = B::shape(&weight).dims;
     let [stride_1, stride_2] = stride;
 
@@ -112,7 +111,7 @@ pub(crate) fn conv2d_backward<B: Backend>(
     let padding_2 = calculate_padding(width_out, stride_2, width_in, kernel_size_2);
 
     let x_tmp = B::swap_dims(x, 0, 1);
-    let output_grad_tmp = B::swap_dims(output_grad, 0, 1);
+    let output_grad_tmp = B::swap_dims(output_grad.clone(), 0, 1);
     let weight_grad = B::conv2d(
         x_tmp,
         output_grad_tmp,
@@ -126,12 +125,14 @@ pub(crate) fn conv2d_backward<B: Backend>(
         x_grad,
         weight_grad,
         bias.map(|b| {
-            let elem = batch_size * width_out * height_out;
-            let elem = (elem as i32).elem();
+            let grad = B::swap_dims(output_grad, 0, 1);
+            let grad = B::reshape(
+                grad,
+                Shape::new([channels_out, batch_size * height_out * width_out]),
+            );
+            let grad = B::sum_dim(grad, 1);
 
-            let b = B::zeros(B::shape(&b), &B::device(&b));
-
-            B::add_scalar(b, elem)
+            B::reshape(grad, B::shape(&b))
         }),
     )
 }
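The 2d closure applies the same reduction with both spatial dimensions folded in: the bias gradient for channel `c` is the sum over n, h, w of `output_grad[n][c][h][w]`. A sketch under the same assumptions as the 1d helper above (hypothetical, not code from the patch):

```rust
// Hypothetical illustration: conv2d bias gradient as an explicit sum over
// batch and both spatial dimensions.
fn bias_grad_2d(
    output_grad: &[f32], // flat [batch_size, channels_out, height_out, width_out]
    batch_size: usize,
    channels_out: usize,
    height_out: usize,
    width_out: usize,
) -> Vec<f32> {
    let spatial = height_out * width_out;
    let mut grad = vec![0.0f32; channels_out];
    for n in 0..batch_size {
        for c in 0..channels_out {
            for s in 0..spatial {
                grad[c] += output_grad[(n * channels_out + c) * spatial + s];
            }
        }
    }
    grad
}
```

For an all-ones `output_grad` this sums to `batch_size * height_out * width_out` per channel, exactly the constant the old code hard-wired via `add_scalar`; for any other gradient the two diverge, which is the bug this commit fixes.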