Support dilation in convolution operations (#301)

Nathaniel Simard 2023-04-18 10:01:11 -04:00 committed by GitHub
parent bd58922784
commit 78ac09fb7a
16 changed files with 459 additions and 141 deletions
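The new `dilation` parameter spaces kernel taps apart, and the output-size rule threaded through this change (see `calculate_conv_output_size` below) becomes `(size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1`, which reduces to the previous formula when `dilation = 1`. A minimal standalone sketch of that rule, not part of the diff:

fn output_size(kernel: usize, stride: usize, padding: usize, dilation: usize, size_in: usize) -> usize {
    // Dilated convolution: the kernel covers dilation * (kernel - 1) + 1 input elements.
    (size_in + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
}

fn main() {
    // kernel 3, stride 1, padding 1, dilation 2 on a length-4 input -> 2 outputs,
    // matching the conv1d dilation tests added in this commit.
    assert_eq!(output_size(3, 1, 1, 2, 4), 2);
}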

View File

@ -52,6 +52,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
bias: Option<ADTensor<B, 1>>,
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> ADTensor<B, 4> {
#[derive(Debug)]
struct Conv2DWithBias;
@ -64,14 +65,17 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
B::TensorPrimitive<4>,
B::TensorPrimitive<1>,
[usize; 2],
[usize; 2],
[usize; 2],
);
fn backward(self, ops: Ops<Self::State, 3>, grads: &mut Gradients) {
let [node_x, node_weight, node_bias] = ops.parents;
let grad = grads.consume::<B, 4>(&ops.node);
let (x, weight, bias, stride) = ops.state;
let backward = B::conv2d_backward(x, weight, Some(bias), stride, grad);
let (x, weight, bias, stride, padding, dilation) = ops.state;
let backward =
B::conv2d_backward(x, weight, Some(bias), stride, padding, dilation, grad);
if let Some(node) = node_x {
grads.register::<B, 4>(node, backward.x_grad)
@ -86,14 +90,20 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
}
impl<B: Backend> Backward<B, 4, 2> for Conv2DNoBias {
type State = (B::TensorPrimitive<4>, B::TensorPrimitive<4>, [usize; 2]);
type State = (
B::TensorPrimitive<4>,
B::TensorPrimitive<4>,
[usize; 2],
[usize; 2],
[usize; 2],
);
fn backward(self, ops: Ops<Self::State, 2>, grads: &mut Gradients) {
let [node_x, node_weight] = ops.parents;
let grad = grads.consume::<B, 4>(&ops.node);
let (x, weight, stride) = ops.state;
let backward = B::conv2d_backward(x, weight, None, stride, grad);
let (x, weight, stride, padding, dilation) = ops.state;
let backward = B::conv2d_backward(x, weight, None, stride, padding, dilation, grad);
if let Some(node) = node_x {
grads.register::<B, 4>(node, backward.x_grad)
@ -119,6 +129,8 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
weight.primitive.clone(),
bias.primitive.clone(),
stride,
padding,
dilation,
),
B::conv2d(
x.primitive,
@ -126,6 +138,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
Some(bias.primitive),
stride,
padding,
dilation,
),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv2d(
@ -134,6 +147,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
Some(bias.primitive),
stride,
padding,
dilation,
)),
}
}
@ -143,8 +157,21 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
.statefull()
{
OpsKind::Tracked(prep) => prep.finish(
(x.primitive.clone(), weight.primitive.clone(), stride),
B::conv2d(x.primitive, weight.primitive, None, stride, padding),
(
x.primitive.clone(),
weight.primitive.clone(),
stride,
padding,
dilation,
),
B::conv2d(
x.primitive,
weight.primitive,
None,
stride,
padding,
dilation,
),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv2d(
x.primitive,
@ -152,6 +179,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
None,
stride,
padding,
dilation,
)),
}
}
@ -164,7 +192,8 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
_bias: Option<ADTensor<B, 1>>,
_stride: [usize; 2],
_padding: [usize; 2],
_out_padding: [usize; 2],
_padding_out: [usize; 2],
_dilation: [usize; 2],
) -> ADTensor<B, 4> {
todo!("Transposed 2D convolution doesn't yet support backward.");
}
@ -175,7 +204,8 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
_bias: Option<ADTensor<B, 1>>,
_stride: usize,
_padding: usize,
_out_padding: usize,
_padding_out: usize,
_dilation: usize,
) -> ADTensor<B, 3> {
todo!("Transposed 1D convolution doesn't yet support backward.");
}
@ -186,6 +216,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
bias: Option<ADTensor<B, 1>>,
stride: usize,
padding: usize,
dilation: usize,
) -> ADTensor<B, 3> {
#[derive(Debug)]
struct Conv1DWithBias;
@ -198,14 +229,17 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
B::TensorPrimitive<3>,
B::TensorPrimitive<1>,
usize,
usize,
usize,
);
fn backward(self, ops: Ops<Self::State, 3>, grads: &mut Gradients) {
let [node_x, node_weight, node_bias] = ops.parents;
let grad = grads.consume::<B, 3>(&ops.node);
let (x, weight, bias, stride) = ops.state;
let backward = B::conv1d_backward(x, weight, Some(bias), stride, grad);
let (x, weight, bias, stride, padding, dilation) = ops.state;
let backward =
B::conv1d_backward(x, weight, Some(bias), stride, padding, dilation, grad);
if let Some(node) = node_x {
grads.register::<B, 3>(node, backward.x_grad)
@ -220,14 +254,20 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
}
impl<B: Backend> Backward<B, 3, 2> for Conv1DNoBias {
type State = (B::TensorPrimitive<3>, B::TensorPrimitive<3>, usize);
type State = (
B::TensorPrimitive<3>,
B::TensorPrimitive<3>,
usize,
usize,
usize,
);
fn backward(self, ops: Ops<Self::State, 2>, grads: &mut Gradients) {
let [node_x, node_weight] = ops.parents;
let grad = grads.consume::<B, 3>(&ops.node);
let (x, weight, stride) = ops.state;
let backward = B::conv1d_backward(x, weight, None, stride, grad);
let (x, weight, stride, padding, dilation) = ops.state;
let backward = B::conv1d_backward(x, weight, None, stride, padding, dilation, grad);
if let Some(node) = node_x {
grads.register::<B, 3>(node, backward.x_grad)
@ -252,6 +292,8 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
weight.primitive.clone(),
bias.primitive.clone(),
stride,
padding,
dilation,
),
B::conv1d(
x.primitive,
@ -259,6 +301,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
Some(bias.primitive),
stride,
padding,
dilation,
),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv1d(
@ -267,6 +310,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
Some(bias.primitive),
stride,
padding,
dilation,
)),
}
}
@ -276,8 +320,21 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
.statefull()
{
OpsKind::Tracked(prep) => prep.finish(
(x.primitive.clone(), weight.primitive.clone(), stride),
B::conv1d(x.primitive, weight.primitive, None, stride, padding),
(
x.primitive.clone(),
weight.primitive.clone(),
stride,
padding,
dilation,
),
B::conv1d(
x.primitive,
weight.primitive,
None,
stride,
padding,
dilation,
),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv1d(
x.primitive,
@ -285,6 +342,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
None,
stride,
padding,
dilation,
)),
}
}

View File

@ -12,6 +12,7 @@ mod tests {
kernel_size: 3,
padding: 1,
stride: 1,
dilation: 1,
length: 6,
};
let grads = Grads {
@ -46,6 +47,7 @@ mod tests {
kernel_size: 3,
padding: 1,
stride: 1,
dilation: 1,
length: 6,
};
let grads = Grads {
@ -72,6 +74,7 @@ mod tests {
kernel_size: 3,
padding: 2,
stride: 1,
dilation: 1,
length: 6,
};
let grads = Grads {
@ -97,6 +100,7 @@ mod tests {
kernel_size: 3,
padding: 1,
stride: 2,
dilation: 1,
length: 4,
};
let grads = Grads {
@ -113,6 +117,32 @@ mod tests {
test.assert_grads(grads);
}
#[test]
fn test_conv1d_dilation() {
let test = Conv1dTestCase {
batch_size: 2,
channels_in: 2,
channels_out: 2,
kernel_size: 3,
padding: 1,
stride: 1,
dilation: 2,
length: 4,
};
let grads = Grads {
x: TestTensor::from_floats([
[[2., 2., 2., 2.], [2., 2., 2., 2.]],
[[2., 2., 2., 2.], [2., 2., 2., 2.]],
]),
weight: TestTensor::from_floats([
[[2., 4., 2.], [2., 4., 2.]],
[[2., 4., 2.], [2., 4., 2.]],
]),
bias: TestTensor::from_floats([4., 4.]),
};
test.assert_grads(grads);
}
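With unit upstream gradients these expected values check out by counting taps: each bias entry accumulates 2 output positions × 2 batch elements = 4, and the middle kernel tap sees a one at both output positions (hence 4) while each outer tap hits the zero padding once (hence 2).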
struct Conv1dTestCase {
batch_size: usize,
channels_in: usize,
@ -120,6 +150,7 @@ mod tests {
kernel_size: usize,
padding: usize,
stride: usize,
dilation: usize,
length: usize,
}
@ -143,6 +174,7 @@ mod tests {
Some(bias.clone()),
self.stride,
self.padding,
self.dilation,
);
let grads = output.backward();

View File

@ -15,6 +15,8 @@ mod tests {
padding_2: 1,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
height: 6,
width: 6,
};
@ -107,6 +109,8 @@ mod tests {
padding_2: 1,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
height: 6,
width: 6,
};
@ -180,6 +184,8 @@ mod tests {
padding_2: 1,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
height: 6,
width: 6,
};
@ -265,6 +271,8 @@ mod tests {
padding_2: 2,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
height: 6,
width: 6,
};
@ -334,6 +342,8 @@ mod tests {
padding_2: 1,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
height: 6,
width: 5,
};
@ -403,6 +413,8 @@ mod tests {
padding_2: 1,
stride_1: 2,
stride_2: 2,
dilation_1: 1,
dilation_2: 1,
height: 8,
width: 8,
};
@ -480,6 +492,8 @@ mod tests {
padding_2: 1,
stride_1: 3,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
height: 8,
width: 8,
};
@ -545,6 +559,48 @@ mod tests {
test.assert_grads(grads);
}
#[test]
fn test_conv2d_complex() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 3,
kernel_size_1: 2,
kernel_size_2: 3,
padding_1: 1,
padding_2: 2,
stride_1: 1,
stride_2: 2,
dilation_1: 2,
dilation_2: 3,
height: 4,
width: 5,
};
let grads = Grads {
x: TestTensor::from_floats([[
[
[3., 3., 0., 3., 3.],
[6., 6., 0., 6., 6.],
[6., 6., 0., 6., 6.],
[3., 3., 0., 3., 3.],
],
[
[3., 3., 0., 3., 3.],
[6., 6., 0., 6., 6.],
[6., 6., 0., 6., 6.],
[3., 3., 0., 3., 3.],
],
]]),
weight: TestTensor::from_floats([
[[[3., 6., 3.], [3., 6., 3.]], [[3., 6., 3.], [3., 6., 3.]]],
[[[3., 6., 3.], [3., 6., 3.]], [[3., 6., 3.], [3., 6., 3.]]],
[[[3., 6., 3.], [3., 6., 3.]], [[3., 6., 3.], [3., 6., 3.]]],
]),
bias: TestTensor::from_floats([8., 8., 8.]),
};
test.assert_grads(grads);
}
struct Conv2dTestCase {
batch_size: usize,
channels_in: usize,
@ -555,6 +611,8 @@ mod tests {
padding_2: usize,
stride_1: usize,
stride_2: usize,
dilation_1: usize,
dilation_2: usize,
height: usize,
width: usize,
}
@ -584,6 +642,7 @@ mod tests {
Some(bias.clone()),
[self.stride_1, self.stride_2],
[self.padding_1, self.padding_2],
[self.dilation_1, self.dilation_2],
);
let grads = output.backward();

View File

@ -7,7 +7,7 @@ use crate::nn::Initializer;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use burn_tensor::module::conv1d;
use burn_tensor::ops::conv::calculate_padding;
use burn_tensor::ops::conv::calculate_conv_padding;
use libm::sqrt;
@ -23,6 +23,9 @@ pub struct Conv1dConfig {
/// The stride of the convolution.
#[config(default = "1")]
pub stride: usize,
/// Spacing between kernel elements.
#[config(default = "1")]
pub dilation: usize,
/// The padding configuration.
#[config(default = "Conv1dPaddingConfig::Valid")]
pub padding: Conv1dPaddingConfig,
@ -61,6 +64,7 @@ pub struct Conv1d<B: Backend> {
bias: Option<Param<Tensor<B, 1>>>,
stride: usize,
kernel_size: usize,
dilation: usize,
padding: Conv1dPaddingConfig,
}
@ -90,6 +94,7 @@ impl Conv1dConfig {
stride: 1, // TODO: Add the stride to the config when properly supported.
kernel_size: self.kernel_size,
padding: self.padding.clone(),
dilation: self.dilation,
}
}
/// Initialize a new [conv1d](Conv1d) module with a [record](Conv1dRecord).
@ -100,6 +105,7 @@ impl Conv1dConfig {
stride: 1, // TODO: Add the stride to the config when properly supported.
kernel_size: self.kernel_size,
padding: self.padding.clone(),
dilation: self.dilation,
}
}
}
@ -114,7 +120,7 @@ impl<B: Backend> Conv1d<B> {
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
let same_padding = || {
let [_batch_size, _channels_in, length] = input.dims();
calculate_padding(self.kernel_size, self.stride, length, length)
calculate_conv_padding(self.kernel_size, self.stride, length, length)
};
let padding = match &self.padding {
@ -129,6 +135,7 @@ impl<B: Backend> Conv1d<B> {
self.bias.as_ref().map(|bias| bias.val()),
self.stride,
padding,
self.dilation,
)
}
}
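Dilation is also exposed on the module configs. A hypothetical usage sketch, assuming the `new` constructor and `with_dilation` setter that the `#[derive(Config)]` macro generates for required and defaulted fields (neither is shown in this diff):

// Hypothetical generated API: Conv1dConfig::new(channels_in, channels_out, kernel_size)
// plus a with_dilation setter; dilation defaults to 1 when not set.
let config = Conv1dConfig::new(16, 32, 3).with_dilation(2);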

View File

@ -7,7 +7,7 @@ use crate::nn::Initializer;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use burn_tensor::module::conv2d;
use burn_tensor::ops::conv::calculate_padding;
use burn_tensor::ops::conv::calculate_conv_padding;
use libm::sqrt;
@ -21,6 +21,9 @@ pub struct Conv2dConfig {
/// The stride of the convolution.
#[config(default = "[1, 1]")]
pub stride: [usize; 2],
/// Spacing between kernel elements.
#[config(default = "[1, 1]")]
pub dilation: [usize; 2],
/// The padding configuration.
#[config(default = "Conv2dPaddingConfig::Valid")]
pub padding: Conv2dPaddingConfig,
@ -59,6 +62,7 @@ pub struct Conv2d<B: Backend> {
bias: Option<Param<Tensor<B, 1>>>,
stride: [usize; 2],
kernel_size: [usize; 2],
dilation: [usize; 2],
padding: Conv2dPaddingConfig,
}
@ -90,8 +94,9 @@ impl Conv2dConfig {
Conv2d {
weight: Param::from(weight),
bias: bias.map(Param::from),
stride: [1, 1], // TODO: Add the stride to the config when properly supported.
stride: self.stride,
kernel_size: self.kernel_size,
dilation: self.dilation,
padding: self.padding.clone(),
}
}
@ -102,6 +107,7 @@ impl Conv2dConfig {
weight: record.weight,
bias: record.bias,
stride: self.stride,
dilation: self.dilation,
kernel_size: self.kernel_size,
padding: self.padding.clone(),
}
@ -126,6 +132,7 @@ impl<B: Backend> Conv2d<B> {
self.bias.as_ref().map(|bias| bias.val()),
self.stride,
padding,
self.dilation,
)
}
}
@ -139,8 +146,8 @@ impl Conv2dPaddingConfig {
stride: &[usize; 2],
) -> [usize; 2] {
let same_padding = || {
let p1 = calculate_padding(kernel_size[0], stride[0], height, height);
let p2 = calculate_padding(kernel_size[1], stride[1], width, width);
let p1 = calculate_conv_padding(kernel_size[0], stride[0], height, height);
let p2 = calculate_conv_padding(kernel_size[1], stride[1], width, width);
[p1, p2]
};

View File

@ -1,4 +1,4 @@
use burn_tensor::ElementConversion;
use burn_tensor::{ops::conv::calculate_conv_output_size, ElementConversion};
use ndarray::{Array4, Dim};
use crate::{
@ -20,13 +20,20 @@ pub(crate) fn conv2d<E: FloatNdArrayElement>(
let [batch_size, _in_channels, in_height, in_width] = x.shape().dims;
let [out_channels, in_channels, kernel_height, kernel_width] = weight.shape().dims;
let out_height = (in_height + 2 * padding_height - dilatation_height * (kernel_height - 1) - 1)
/ stride_height
+ 1;
let out_width = (in_width + 2 * padding_width - dilatation_width * (kernel_width - 1) - 1)
/ stride_width
+ 1;
let out_height = calculate_conv_output_size(
kernel_height,
stride_height,
padding_height,
dilatation_height,
in_height,
);
let out_width = calculate_conv_output_size(
kernel_width,
stride_width,
padding_width,
dilatation_width,
in_width,
);
let x = apply_padding_4d(x, padding, 0i32.elem()).array;
@ -86,10 +93,15 @@ pub(crate) fn conv_transpose2d<E: FloatNdArrayElement>(
let [batch_size, _in_channels, in_height, in_width] = x.shape().dims;
let [in_channels, out_channels, kernel_height, kernel_width] = weight.shape().dims;
let out_height =
(in_height - 1) * stride_height + kernel_height + out_padding_height - 2 * padding_height;
let out_height = (in_height - 1) * stride_height
+ dilation_height * (kernel_height - 1)
+ out_padding_height
- 2 * padding_height
+ 1;
let out_width =
(in_width - 1) * stride_width + kernel_width + out_padding_width - 2 * padding_width;
(in_width - 1) * stride_width + dilation_width * (kernel_width - 1) + out_padding_width
- 2 * padding_width
+ 1;
let x = x.array;
let mut output = Array4::zeros(Dim([batch_size, out_channels, out_height, out_width]));

View File

@ -75,8 +75,9 @@ impl<E: FloatNdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E>
bias: Option<NdArrayTensor<E, 1>>,
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> NdArrayTensor<E, 4> {
conv2d(x, weight, bias, stride, padding, [1, 1])
conv2d(x, weight, bias, stride, padding, dilation)
}
fn conv_transpose2d(
@ -85,9 +86,10 @@ impl<E: FloatNdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E>
bias: Option<NdArrayTensor<E, 1>>,
stride: [usize; 2],
padding: [usize; 2],
out_padding: [usize; 2],
padding_out: [usize; 2],
dilation: [usize; 2],
) -> NdArrayTensor<E, 4> {
conv_transpose2d(x, weight, bias, stride, padding, out_padding, [1, 1])
conv_transpose2d(x, weight, bias, stride, padding, padding_out, dilation)
}
fn max_pool2d(

View File

@ -32,6 +32,7 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
bias: Option<TchTensor<E, 1>>,
stride: usize,
padding: usize,
dilation: usize,
) -> TchTensor<E, 3> {
let tensor = tch::Tensor::conv1d(
&x.tensor,
@ -39,7 +40,7 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
bias.map(|t| t.tensor),
&[stride as i64],
&[padding as i64],
&[1],
&[dilation as i64],
1,
);
@ -52,6 +53,7 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
bias: Option<TchTensor<E, 1>>,
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> TchTensor<E, 4> {
let tensor = tch::Tensor::conv2d(
&x.tensor,
@ -59,7 +61,7 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
bias.map(|t| t.tensor),
&[stride[0] as i64, stride[1] as i64],
&[padding[0] as i64, padding[1] as i64],
&[1, 1],
&[dilation[0] as i64, dilation[1] as i64],
1,
);
@ -72,7 +74,8 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
bias: Option<TchTensor<E, 1>>,
stride: [usize; 2],
padding: [usize; 2],
out_padding: [usize; 2],
padding_out: [usize; 2],
dilation: [usize; 2],
) -> TchTensor<E, 4> {
let tensor = tch::Tensor::conv_transpose2d(
&x.tensor,
@ -80,9 +83,9 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
bias.map(|t| t.tensor),
&[stride[0] as i64, stride[1] as i64],
&[padding[0] as i64, padding[1] as i64],
&[out_padding[0] as i64, out_padding[1] as i64],
&[padding_out[0] as i64, padding_out[1] as i64],
1,
&[1, 1],
&[dilation[0] as i64, dilation[1] as i64],
);
TchTensor::new(tensor)
@ -95,6 +98,7 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
stride: usize,
padding: usize,
padding_out: usize,
dilation: usize,
) -> TchTensor<E, 3> {
let tensor = tch::Tensor::conv_transpose1d(
&x.tensor,
@ -104,7 +108,7 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
&[padding as i64],
&[padding_out as i64],
1,
&[1],
&[dilation as i64],
);
TchTensor::new(tensor)

View File

@ -15,6 +15,7 @@ pub fn conv1d<B>(
bias: Option<Tensor<B, 1>>,
stride: usize,
padding: usize,
dilation: usize,
) -> Tensor<B, 3>
where
B: Backend,
@ -25,6 +26,7 @@ where
bias.map(|b| b.primitive),
stride,
padding,
dilation,
))
}
@ -35,6 +37,7 @@ pub fn conv2d<B>(
bias: Option<Tensor<B, 1>>,
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> Tensor<B, 4>
where
B: Backend,
@ -45,6 +48,7 @@ where
bias.map(|b| b.primitive),
stride,
padding,
dilation,
))
}
@ -56,6 +60,7 @@ pub fn conv_transpose1d<B>(
stride: usize,
padding: usize,
padding_out: usize,
dilation: usize,
) -> Tensor<B, 3>
where
B: Backend,
@ -67,6 +72,7 @@ where
stride,
padding,
padding_out,
dilation,
))
}
@ -77,7 +83,8 @@ pub fn conv_transpose2d<B>(
bias: Option<Tensor<B, 1>>,
stride: [usize; 2],
padding: [usize; 2],
out_padding: [usize; 2],
padding_out: [usize; 2],
dilation: [usize; 2],
) -> Tensor<B, 4>
where
B: Backend,
@ -88,7 +95,8 @@ where
bias.map(|b| b.primitive),
stride,
padding,
out_padding,
padding_out,
dilation,
))
}

View File

@ -53,6 +53,7 @@ pub trait ModuleOps<B: Backend> {
bias: Option<B::TensorPrimitive<1>>,
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> B::TensorPrimitive<4>;
/// Two dimensional transposed convolution.
///
@ -67,7 +68,8 @@ pub trait ModuleOps<B: Backend> {
bias: Option<B::TensorPrimitive<1>>,
stride: [usize; 2],
padding: [usize; 2],
out_padding: [usize; 2],
padding_out: [usize; 2],
dilation: [usize; 2],
) -> B::TensorPrimitive<4>;
/// Backward pass for the [conv2d](ModuleOps::conv2d) operation.
@ -76,9 +78,11 @@ pub trait ModuleOps<B: Backend> {
weight: B::TensorPrimitive<4>,
bias: Option<B::TensorPrimitive<1>>,
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
output_grad: B::TensorPrimitive<4>,
) -> Conv2dBackward<B> {
conv::conv2d_backward(x, weight, bias, stride, output_grad)
conv::conv2d_backward(x, weight, bias, stride, padding, dilation, output_grad)
}
/// One dimensional convolution.
///
@ -93,8 +97,9 @@ pub trait ModuleOps<B: Backend> {
bias: Option<B::TensorPrimitive<1>>,
stride: usize,
padding: usize,
dilation: usize,
) -> B::TensorPrimitive<3> {
conv::conv1d_from_conv2d::<B>(x, weight, bias, stride, padding)
conv::conv1d_from_conv2d::<B>(x, weight, bias, stride, padding, dilation)
}
/// One dimensional transposed convolution.
///
@ -110,6 +115,7 @@ pub trait ModuleOps<B: Backend> {
stride: usize,
padding: usize,
padding_out: usize,
dilation: usize,
) -> B::TensorPrimitive<3> {
conv::conv_transpose1d_from_conv_transpose2d::<B>(
x,
@ -118,6 +124,7 @@ pub trait ModuleOps<B: Backend> {
stride,
padding,
padding_out,
dilation,
)
}
/// Backward pass for the [conv1d](ModuleOps::conv1d) operation.
@ -126,9 +133,11 @@ pub trait ModuleOps<B: Backend> {
weight: B::TensorPrimitive<3>,
bias: Option<B::TensorPrimitive<1>>,
stride: usize,
padding: usize,
dilation: usize,
output_grad: B::TensorPrimitive<3>,
) -> Conv1dBackward<B> {
conv::conv1d_backward(x, weight, bias, stride, output_grad)
conv::conv1d_backward(x, weight, bias, stride, padding, dilation, output_grad)
}
/// Two dimensional max pooling.
///

View File

@ -2,9 +2,8 @@ use super::{Conv1dBackward, Conv2dBackward};
use crate::{backend::Backend, Shape};
use libm::ceilf;
/// Calculate the expected padding size required when applying a convolution with the specified
/// kernel size, stride, and input size to get the desired output size.
pub fn calculate_padding(
/// Calculate the expected padding size required when applying a convolution.
pub fn calculate_conv_padding(
kernel_size: usize,
stride: usize,
size_in: usize,
@ -21,29 +20,22 @@ pub fn calculate_padding(
padding as usize
}
/// Calculate the expected output size when applying a convolution with the specified kernel size,
/// stride and padding.
pub fn calculate_output_size(
/// Calculate the expected output size when doing a convolution operation.
pub fn calculate_conv_output_size(
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
size_in: usize,
) -> usize {
let kernel_size = kernel_size as f32;
let stride = stride as f32;
let padding = padding as f32;
let size_in = size_in as f32;
let size_out = (size_in + (2. * padding) - kernel_size) / stride;
let size_out = ceilf(size_out + 1.);
size_out as usize
(size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
}
fn calculate_padding_out(
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
size_in: usize,
size_out: usize,
) -> usize {
@ -51,8 +43,10 @@ fn calculate_padding_out(
return 0;
}
let out = calculate_output_size(kernel_size, stride, padding, size_out) as i64;
i64::max(0, out - size_in as i64) as usize
let out = 1 + libm::ceil(
(size_in + 2 * padding - dilation * (kernel_size - 1) - 1) as f64 / stride as f64,
) as usize;
i64::max(0, out as i64 - size_out as i64) as usize
}
/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass using convolutions.
@ -61,14 +55,22 @@ pub(crate) fn conv1d_backward<B: Backend>(
weight: B::TensorPrimitive<3>,
bias: Option<B::TensorPrimitive<1>>,
stride: usize,
padding: usize,
dilation: usize,
output_grad: B::TensorPrimitive<3>,
) -> Conv1dBackward<B> {
let [batch_size, channels_in, length_in] = B::shape(&x).dims;
let [_batch_size, channels_out, length_out] = B::shape(&output_grad).dims;
let [_, _, kernel_size] = B::shape(&weight).dims;
let padding = calculate_padding(kernel_size, stride, length_in, length_out);
let padding_out = calculate_padding_out(kernel_size, stride, padding, length_out, length_in);
let padding_out = calculate_padding_out(
kernel_size,
stride,
padding,
dilation,
length_in,
length_out,
);
let x_grad = B::conv_transpose1d(
output_grad.clone(),
@ -77,11 +79,19 @@ pub(crate) fn conv1d_backward<B: Backend>(
stride,
padding,
padding_out,
dilation,
);
let x_swapped = B::swap_dims(x, 0, 1);
let output_grad_swapped = B::swap_dims(output_grad, 0, 1);
let weight_grad_swapped = B::conv1d(x_swapped, output_grad_swapped.clone(), None, 1, padding);
let weight_grad_swapped = B::conv1d(
x_swapped,
output_grad_swapped.clone(),
None,
dilation,
padding,
stride,
);
let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1);
if B::shape(&weight_grad) != Shape::new([channels_out, channels_in, kernel_size]) {
@ -110,28 +120,39 @@ pub(crate) fn conv2d_backward<B: Backend>(
weight: B::TensorPrimitive<4>,
bias: Option<B::TensorPrimitive<1>>,
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
output_grad: B::TensorPrimitive<4>,
) -> Conv2dBackward<B> {
let [batch_size, channels_in, height_in, width_in] = B::shape(&x).dims;
let [_batch_size, channels_out, height_out, width_out] = B::shape(&output_grad).dims;
let [_, _, kernel_size_1, kernel_size_2] = B::shape(&weight).dims;
let [stride_1, stride_2] = stride;
let padding_1 = calculate_padding(kernel_size_1, stride_1, height_in, height_out);
let padding_2 = calculate_padding(kernel_size_2, stride_2, width_in, width_out);
let padding_1_out =
calculate_padding_out(kernel_size_1, stride_1, padding_1, height_out, height_in);
let padding_2_out =
calculate_padding_out(kernel_size_2, stride_2, padding_2, width_out, width_in);
let padding_1_out = calculate_padding_out(
kernel_size_1,
stride[0],
padding[0],
dilation[0],
height_in,
height_out,
);
let padding_2_out = calculate_padding_out(
kernel_size_2,
stride[1],
padding[1],
dilation[1],
width_in,
width_out,
);
let x_grad = B::conv_transpose2d(
output_grad.clone(),
weight,
None,
[stride_1, stride_2],
[padding_1, padding_2],
stride,
padding,
[padding_1_out, padding_2_out],
dilation,
);
let x_swapped = B::swap_dims(x, 0, 1);
@ -140,8 +161,9 @@ pub(crate) fn conv2d_backward<B: Backend>(
x_swapped,
output_grad_swapped.clone(),
None,
[1, 1],
[padding_1, padding_2],
dilation,
padding,
stride,
);
let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1);
@ -182,6 +204,7 @@ pub(crate) fn conv1d_from_conv2d<B: Backend>(
bias: Option<B::TensorPrimitive<1>>,
stride: usize,
padding: usize,
dilation: usize,
) -> B::TensorPrimitive<3> {
let [channels_out, _channels_in, kernel_size] = B::shape(&weight).dims;
let [batch_size, channels_in, length_in] = B::shape(&x).dims;
@ -192,7 +215,7 @@ pub(crate) fn conv1d_from_conv2d<B: Backend>(
);
let x = B::reshape(x, Shape::new([batch_size, channels_in, length_in, 1]));
let tensor = B::conv2d(x, weight, bias, [stride, 1], [padding, 0]);
let tensor = B::conv2d(x, weight, bias, [stride, 1], [padding, 0], [dilation, 1]);
let [batch_size, channels_out, height_out, _weight_out] = B::shape(&tensor).dims;
B::reshape(tensor, Shape::from([batch_size, channels_out, height_out]))
}
@ -205,6 +228,7 @@ pub(crate) fn conv_transpose1d_from_conv_transpose2d<B: Backend>(
stride: usize,
padding: usize,
padding_out: usize,
dilation: usize,
) -> B::TensorPrimitive<3> {
let [channels_in, channels_out, kernel_size] = B::shape(&weight).dims;
let [batch_size, _channels_in, length_in] = B::shape(&x).dims;
@ -215,7 +239,15 @@ pub(crate) fn conv_transpose1d_from_conv_transpose2d<B: Backend>(
);
let x = B::reshape(x, Shape::new([batch_size, channels_in, length_in, 1]));
let tensor = B::conv_transpose2d(x, weight, bias, [stride, 1], [padding, 0], [padding_out, 0]);
let tensor = B::conv_transpose2d(
x,
weight,
bias,
[stride, 1],
[padding, 0],
[padding_out, 0],
[dilation, 1],
);
let [batch_size, channels_out, height_out, _weight_out] = B::shape(&tensor).dims;
B::reshape(tensor, Shape::from([batch_size, channels_out, height_out]))
}
@ -230,8 +262,9 @@ mod tests {
let stride = 1;
let padding = 1;
let size_in = 3;
let dilation = 1;
let size_out = calculate_output_size(kernel_size, stride, padding, size_in);
let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);
assert_eq!(size_out, 3);
}
@ -242,20 +275,35 @@ mod tests {
let stride = 2;
let padding = 3;
let size_in = 27;
let dilation = 1;
let size_out = calculate_output_size(kernel_size, stride, padding, size_in);
let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);
assert_eq!(size_out, 15);
}
#[test]
fn test_calculate_output_size_3() {
let kernel_size = 5;
let stride = 2;
let padding = 3;
let size_in = 27;
let dilation = 2;
let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);
assert_eq!(size_out, 13);
}
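Sanity check for the dilated case above: (27 + 2 * 3 - 2 * (5 - 1) - 1) / 2 + 1 = 24 / 2 + 1 = 13, matching the assertion.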
#[test]
fn test_calculate_same_padding_1() {
let kernel_size = 3;
let stride = 1;
let size_in = 3;
let dilation = 1;
let padding = calculate_padding(kernel_size, stride, size_in, size_in);
let size_out = calculate_output_size(kernel_size, stride, padding, size_in);
let padding = calculate_conv_padding(kernel_size, stride, size_in, size_in);
let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);
assert_eq!(size_in, size_out, "Expected size");
}
@ -265,9 +313,10 @@ mod tests {
let kernel_size = 3;
let stride = 2;
let size_in = 7;
let dilation = 1;
let padding = calculate_padding(kernel_size, stride, size_in, size_in);
let size_out = calculate_output_size(kernel_size, stride, padding, size_in);
let padding = calculate_conv_padding(kernel_size, stride, size_in, size_in);
let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);
assert_eq!(size_in, size_out, "Expected size");
}
@ -278,9 +327,11 @@ mod tests {
let stride = 2;
let size_in = 7;
let size_out = 10;
let dilation = 1;
let padding = calculate_padding(kernel_size, stride, size_in, size_out);
let size_out_expected = calculate_output_size(kernel_size, stride, padding, size_in);
let padding = calculate_conv_padding(kernel_size, stride, size_in, size_out);
let size_out_expected =
calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);
assert_eq!(size_out, size_out_expected, "Expected size");
}

View File

@ -13,6 +13,7 @@ mod tests {
kernel_size: 3,
padding: 1,
stride: 1,
dilation: 1,
length: 6,
};
@ -30,6 +31,25 @@ mod tests {
]));
}
#[test]
fn test_conv1d_dilation() {
let test = Conv1dTestCase {
batch_size: 2,
channels_in: 2,
channels_out: 2,
kernel_size: 3,
padding: 1,
stride: 1,
dilation: 2,
length: 4,
};
test.assert_output(TestTensor::from_floats([
[[5., 5.], [5., 5.]],
[[5., 5.], [5., 5.]],
]));
}
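Worked check for the dilation case: padding the length-4 input with one zero on each side gives [0, 1, 1, 1, 1, 0]; with dilation 2 the two output positions read taps {0, 2, 4} and {1, 3, 5}, of which two land on ones, so each output is 2 taps × 2 input channels + bias = 5.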
#[test]
fn test_conv1d_complex() {
let test = Conv1dTestCase {
@ -39,6 +59,7 @@ mod tests {
kernel_size: 3,
padding: 1,
stride: 2,
dilation: 1,
length: 9,
};
@ -65,6 +86,7 @@ mod tests {
kernel_size: usize,
padding: usize,
stride: usize,
dilation: usize,
length: usize,
}
@ -73,7 +95,14 @@ mod tests {
let weights = TestTensor::ones([self.channels_out, self.channels_in, self.kernel_size]);
let bias = TestTensor::ones([self.channels_out]);
let x = TestTensor::ones([self.batch_size, self.channels_in, self.length]);
let output = conv1d(x, weights, Some(bias), self.stride, self.padding);
let output = conv1d(
x,
weights,
Some(bias),
self.stride,
self.padding,
self.dilation,
);
y.to_data().assert_approx_eq(&output.into_data(), 3);
}

View File

@ -16,6 +16,8 @@ mod tests {
padding_2: 1,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
height: 6,
width: 6,
};
@ -88,62 +90,24 @@ mod tests {
padding_2: 2,
stride_1: 2,
stride_2: 3,
height: 7,
width: 9,
dilation_1: 1,
dilation_2: 2,
height: 4,
width: 5,
};
test.assert_output(TestTensor::from_floats([
[
[
[1., 13., 13., 13.],
[1., 19., 19., 19.],
[1., 19., 19., 19.],
[1., 13., 13., 13.],
],
[
[1., 13., 13., 13.],
[1., 19., 19., 19.],
[1., 19., 19., 19.],
[1., 13., 13., 13.],
],
[
[1., 13., 13., 13.],
[1., 19., 19., 19.],
[1., 19., 19., 19.],
[1., 13., 13., 13.],
],
[
[1., 13., 13., 13.],
[1., 19., 19., 19.],
[1., 19., 19., 19.],
[1., 13., 13., 13.],
],
[[7., 13., 7.], [10., 19., 10.]],
[[7., 13., 7.], [10., 19., 10.]],
[[7., 13., 7.], [10., 19., 10.]],
[[7., 13., 7.], [10., 19., 10.]],
],
[
[
[1., 13., 13., 13.],
[1., 19., 19., 19.],
[1., 19., 19., 19.],
[1., 13., 13., 13.],
],
[
[1., 13., 13., 13.],
[1., 19., 19., 19.],
[1., 19., 19., 19.],
[1., 13., 13., 13.],
],
[
[1., 13., 13., 13.],
[1., 19., 19., 19.],
[1., 19., 19., 19.],
[1., 13., 13., 13.],
],
[
[1., 13., 13., 13.],
[1., 19., 19., 19.],
[1., 19., 19., 19.],
[1., 13., 13., 13.],
],
[[7., 13., 7.], [10., 19., 10.]],
[[7., 13., 7.], [10., 19., 10.]],
[[7., 13., 7.], [10., 19., 10.]],
[[7., 13., 7.], [10., 19., 10.]],
],
]));
}
@ -158,6 +122,8 @@ mod tests {
padding_2: usize,
stride_1: usize,
stride_2: usize,
dilation_1: usize,
dilation_2: usize,
height: usize,
width: usize,
}
@ -178,6 +144,7 @@ mod tests {
Some(bias),
[self.stride_1, self.stride_2],
[self.padding_1, self.padding_2],
[self.dilation_1, self.dilation_2],
);
y.to_data().assert_approx_eq(&output.into_data(), 3);

View File

@ -14,6 +14,7 @@ mod tests {
padding: 1,
padding_out: 0,
stride: 1,
dilation: 1,
length: 4,
};
@ -33,6 +34,7 @@ mod tests {
padding: 1,
padding_out: 1,
stride: 2,
dilation: 1,
length: 4,
};
@ -42,6 +44,26 @@ mod tests {
]]));
}
#[test]
fn test_conv_transpose1d_dilation() {
let test = ConvTranspose1dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size: 3,
padding: 1,
padding_out: 0,
stride: 1,
dilation: 2,
length: 4,
};
test.assert_output(TestTensor::from_floats([[
[30., 64., 78., 76., 94., 52.],
[49., 101., 127., 113., 143., 77.],
]]));
}
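The expected length follows the transposed-convolution rule used by the ndarray backend above: (4 - 1) * 1 + 2 * (3 - 1) + 0 - 2 * 1 + 1 = 6, which matches the six columns in the expected tensor.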
struct ConvTranspose1dTestCase {
batch_size: usize,
channels_in: usize,
@ -50,6 +72,7 @@ mod tests {
padding: usize,
padding_out: usize,
stride: usize,
dilation: usize,
length: usize,
}
@ -81,6 +104,7 @@ mod tests {
self.stride,
self.padding,
self.padding_out,
self.dilation,
);
y.to_data().assert_approx_eq(&output.into_data(), 3);

View File

@ -18,6 +18,8 @@ mod tests {
padding_out_2: 0,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
height: 2,
width: 2,
};
@ -38,6 +40,8 @@ mod tests {
padding_out_2: 0,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
height: 4,
width: 4,
};
@ -78,6 +82,8 @@ mod tests {
padding_out_2: 0,
stride_1: 2,
stride_2: 2,
dilation_1: 1,
dilation_2: 1,
height: 2,
width: 2,
};
@ -90,6 +96,44 @@ mod tests {
]]]));
}
#[test]
fn test_conv_transpose2d_dilation_2() {
let test = ConvTranspose2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 1,
padding_2: 1,
padding_out_1: 1,
padding_out_2: 1,
stride_1: 1,
stride_2: 1,
dilation_1: 2,
dilation_2: 2,
height: 2,
width: 2,
};
test.assert_output(TestTensor::from_floats([[
[
[126., 116., 136., 124., 146.],
[108., 88., 114., 92., 120.],
[156., 140., 166., 148., 176.],
[126., 100., 132., 104., 138.],
[186., 164., 196., 172., 206.],
],
[
[217., 189., 227., 197., 237.],
[163., 125., 169., 129., 175.],
[247., 213., 257., 221., 267.],
[181., 137., 187., 141., 193.],
[277., 237., 287., 245., 297.],
],
]]));
}
#[test]
fn test_conv_transpose2d_stride2_out_padding() {
let test = ConvTranspose2dTestCase {
@ -104,6 +148,8 @@ mod tests {
padding_out_2: 1,
stride_1: 2,
stride_2: 2,
dilation_1: 1,
dilation_2: 1,
height: 4,
width: 4,
};
@ -144,6 +190,8 @@ mod tests {
padding_out_2: usize,
stride_1: usize,
stride_2: usize,
dilation_1: usize,
dilation_2: usize,
height: usize,
width: usize,
}
@ -181,6 +229,7 @@ mod tests {
[self.stride_1, self.stride_2],
[self.padding_1, self.padding_2],
[self.padding_out_1, self.padding_out_2],
[self.dilation_1, self.dilation_2],
);
y.to_data().assert_approx_eq(&output.into_data(), 3);