diff --git a/burn-autodiff/src/ops/module.rs b/burn-autodiff/src/ops/module.rs index 73ff6b440..6200f64f2 100644 --- a/burn-autodiff/src/ops/module.rs +++ b/burn-autodiff/src/ops/module.rs @@ -52,6 +52,7 @@ impl ModuleOps> for ADBackendDecorator { bias: Option>, stride: [usize; 2], padding: [usize; 2], + dilation: [usize; 2], ) -> ADTensor { #[derive(Debug)] struct Conv2DWithBias; @@ -64,14 +65,17 @@ impl ModuleOps> for ADBackendDecorator { B::TensorPrimitive<4>, B::TensorPrimitive<1>, [usize; 2], + [usize; 2], + [usize; 2], ); fn backward(self, ops: Ops, grads: &mut Gradients) { let [node_x, node_weight, node_bias] = ops.parents; let grad = grads.consume::(&ops.node); - let (x, weight, bias, stride) = ops.state; - let backward = B::conv2d_backward(x, weight, Some(bias), stride, grad); + let (x, weight, bias, stride, padding, dilation) = ops.state; + let backward = + B::conv2d_backward(x, weight, Some(bias), stride, padding, dilation, grad); if let Some(node) = node_x { grads.register::(node, backward.x_grad) @@ -86,14 +90,20 @@ impl ModuleOps> for ADBackendDecorator { } impl Backward for Conv2DNoBias { - type State = (B::TensorPrimitive<4>, B::TensorPrimitive<4>, [usize; 2]); + type State = ( + B::TensorPrimitive<4>, + B::TensorPrimitive<4>, + [usize; 2], + [usize; 2], + [usize; 2], + ); fn backward(self, ops: Ops, grads: &mut Gradients) { let [node_x, node_weight] = ops.parents; let grad = grads.consume::(&ops.node); - let (x, weight, stride) = ops.state; - let backward = B::conv2d_backward(x, weight, None, stride, grad); + let (x, weight, stride, padding, dilation) = ops.state; + let backward = B::conv2d_backward(x, weight, None, stride, padding, dilation, grad); if let Some(node) = node_x { grads.register::(node, backward.x_grad) @@ -119,6 +129,8 @@ impl ModuleOps> for ADBackendDecorator { weight.primitive.clone(), bias.primitive.clone(), stride, + padding, + dilation, ), B::conv2d( x.primitive, @@ -126,6 +138,7 @@ impl ModuleOps> for ADBackendDecorator { Some(bias.primitive), stride, padding, + dilation, ), ), OpsKind::UnTracked(prep) => prep.finish(B::conv2d( @@ -134,6 +147,7 @@ impl ModuleOps> for ADBackendDecorator { Some(bias.primitive), stride, padding, + dilation, )), } } @@ -143,8 +157,21 @@ impl ModuleOps> for ADBackendDecorator { .statefull() { OpsKind::Tracked(prep) => prep.finish( - (x.primitive.clone(), weight.primitive.clone(), stride), - B::conv2d(x.primitive, weight.primitive, None, stride, padding), + ( + x.primitive.clone(), + weight.primitive.clone(), + stride, + padding, + dilation, + ), + B::conv2d( + x.primitive, + weight.primitive, + None, + stride, + padding, + dilation, + ), ), OpsKind::UnTracked(prep) => prep.finish(B::conv2d( x.primitive, @@ -152,6 +179,7 @@ impl ModuleOps> for ADBackendDecorator { None, stride, padding, + dilation, )), } } @@ -164,7 +192,8 @@ impl ModuleOps> for ADBackendDecorator { _bias: Option>, _stride: [usize; 2], _padding: [usize; 2], - _out_padding: [usize; 2], + _padding_out: [usize; 2], + _dilation: [usize; 2], ) -> ADTensor { todo!("Transposed 2D convolution doesn't yet support backward."); } @@ -175,7 +204,8 @@ impl ModuleOps> for ADBackendDecorator { _bias: Option>, _stride: usize, _padding: usize, - _out_padding: usize, + _padding_out: usize, + _dilation: usize, ) -> ADTensor { todo!("Transposed 1D convolution doesn't yet support backward."); } @@ -186,6 +216,7 @@ impl ModuleOps> for ADBackendDecorator { bias: Option>, stride: usize, padding: usize, + dilation: usize, ) -> ADTensor { #[derive(Debug)] struct Conv1DWithBias; @@ 
-198,14 +229,17 @@ impl ModuleOps> for ADBackendDecorator { B::TensorPrimitive<3>, B::TensorPrimitive<1>, usize, + usize, + usize, ); fn backward(self, ops: Ops, grads: &mut Gradients) { let [node_x, node_weight, node_bias] = ops.parents; let grad = grads.consume::(&ops.node); - let (x, weight, bias, stride) = ops.state; - let backward = B::conv1d_backward(x, weight, Some(bias), stride, grad); + let (x, weight, bias, stride, padding, dilation) = ops.state; + let backward = + B::conv1d_backward(x, weight, Some(bias), stride, padding, dilation, grad); if let Some(node) = node_x { grads.register::(node, backward.x_grad) @@ -220,14 +254,20 @@ impl ModuleOps> for ADBackendDecorator { } impl Backward for Conv1DNoBias { - type State = (B::TensorPrimitive<3>, B::TensorPrimitive<3>, usize); + type State = ( + B::TensorPrimitive<3>, + B::TensorPrimitive<3>, + usize, + usize, + usize, + ); fn backward(self, ops: Ops, grads: &mut Gradients) { let [node_x, node_weight] = ops.parents; let grad = grads.consume::(&ops.node); - let (x, weight, stride) = ops.state; - let backward = B::conv1d_backward(x, weight, None, stride, grad); + let (x, weight, stride, padding, dilation) = ops.state; + let backward = B::conv1d_backward(x, weight, None, stride, padding, dilation, grad); if let Some(node) = node_x { grads.register::(node, backward.x_grad) @@ -252,6 +292,8 @@ impl ModuleOps> for ADBackendDecorator { weight.primitive.clone(), bias.primitive.clone(), stride, + padding, + dilation, ), B::conv1d( x.primitive, @@ -259,6 +301,7 @@ impl ModuleOps> for ADBackendDecorator { Some(bias.primitive), stride, padding, + dilation, ), ), OpsKind::UnTracked(prep) => prep.finish(B::conv1d( @@ -267,6 +310,7 @@ impl ModuleOps> for ADBackendDecorator { Some(bias.primitive), stride, padding, + dilation, )), } } @@ -276,8 +320,21 @@ impl ModuleOps> for ADBackendDecorator { .statefull() { OpsKind::Tracked(prep) => prep.finish( - (x.primitive.clone(), weight.primitive.clone(), stride), - B::conv1d(x.primitive, weight.primitive, None, stride, padding), + ( + x.primitive.clone(), + weight.primitive.clone(), + stride, + padding, + dilation, + ), + B::conv1d( + x.primitive, + weight.primitive, + None, + stride, + padding, + dilation, + ), ), OpsKind::UnTracked(prep) => prep.finish(B::conv1d( x.primitive, @@ -285,6 +342,7 @@ impl ModuleOps> for ADBackendDecorator { None, stride, padding, + dilation, )), } } diff --git a/burn-autodiff/src/tests/conv1d.rs b/burn-autodiff/src/tests/conv1d.rs index ac584e8aa..7ebe067ce 100644 --- a/burn-autodiff/src/tests/conv1d.rs +++ b/burn-autodiff/src/tests/conv1d.rs @@ -12,6 +12,7 @@ mod tests { kernel_size: 3, padding: 1, stride: 1, + dilation: 1, length: 6, }; let grads = Grads { @@ -46,6 +47,7 @@ mod tests { kernel_size: 3, padding: 1, stride: 1, + dilation: 1, length: 6, }; let grads = Grads { @@ -72,6 +74,7 @@ mod tests { kernel_size: 3, padding: 2, stride: 1, + dilation: 1, length: 6, }; let grads = Grads { @@ -97,6 +100,7 @@ mod tests { kernel_size: 3, padding: 1, stride: 2, + dilation: 1, length: 4, }; let grads = Grads { @@ -113,6 +117,32 @@ mod tests { test.assert_grads(grads); } + #[test] + fn test_conv1d_dilation() { + let test = Conv1dTestCase { + batch_size: 2, + channels_in: 2, + channels_out: 2, + kernel_size: 3, + padding: 1, + stride: 1, + dilation: 2, + length: 4, + }; + let grads = Grads { + x: TestTensor::from_floats([ + [[2., 2., 2., 2.], [2., 2., 2., 2.]], + [[2., 2., 2., 2.], [2., 2., 2., 2.]], + ]), + weight: TestTensor::from_floats([ + [[2., 4., 2.], [2., 4., 2.]], + [[2., 
4., 2.], [2., 4., 2.]], + ]), + bias: TestTensor::from_floats([4., 4.]), + }; + test.assert_grads(grads); + } + struct Conv1dTestCase { batch_size: usize, channels_in: usize, @@ -120,6 +150,7 @@ mod tests { kernel_size: usize, padding: usize, stride: usize, + dilation: usize, length: usize, } @@ -143,6 +174,7 @@ mod tests { Some(bias.clone()), self.stride, self.padding, + self.dilation, ); let grads = output.backward(); diff --git a/burn-autodiff/src/tests/conv2d.rs b/burn-autodiff/src/tests/conv2d.rs index a1c2eac36..c8e8c55de 100644 --- a/burn-autodiff/src/tests/conv2d.rs +++ b/burn-autodiff/src/tests/conv2d.rs @@ -15,6 +15,8 @@ mod tests { padding_2: 1, stride_1: 1, stride_2: 1, + dilation_1: 1, + dilation_2: 1, height: 6, width: 6, }; @@ -107,6 +109,8 @@ mod tests { padding_2: 1, stride_1: 1, stride_2: 1, + dilation_1: 1, + dilation_2: 1, height: 6, width: 6, }; @@ -180,6 +184,8 @@ mod tests { padding_2: 1, stride_1: 1, stride_2: 1, + dilation_1: 1, + dilation_2: 1, height: 6, width: 6, }; @@ -265,6 +271,8 @@ mod tests { padding_2: 2, stride_1: 1, stride_2: 1, + dilation_1: 1, + dilation_2: 1, height: 6, width: 6, }; @@ -334,6 +342,8 @@ mod tests { padding_2: 1, stride_1: 1, stride_2: 1, + dilation_1: 1, + dilation_2: 1, height: 6, width: 5, }; @@ -403,6 +413,8 @@ mod tests { padding_2: 1, stride_1: 2, stride_2: 2, + dilation_1: 1, + dilation_2: 1, height: 8, width: 8, }; @@ -480,6 +492,8 @@ mod tests { padding_2: 1, stride_1: 3, stride_2: 1, + dilation_1: 1, + dilation_2: 1, height: 8, width: 8, }; @@ -545,6 +559,48 @@ mod tests { test.assert_grads(grads); } + #[test] + fn test_conv2d_complex() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 3, + kernel_size_1: 2, + kernel_size_2: 3, + padding_1: 1, + padding_2: 2, + stride_1: 1, + stride_2: 2, + dilation_1: 2, + dilation_2: 3, + height: 4, + width: 5, + }; + let grads = Grads { + x: TestTensor::from_floats([[ + [ + [3., 3., 0., 3., 3.], + [6., 6., 0., 6., 6.], + [6., 6., 0., 6., 6.], + [3., 3., 0., 3., 3.], + ], + [ + [3., 3., 0., 3., 3.], + [6., 6., 0., 6., 6.], + [6., 6., 0., 6., 6.], + [3., 3., 0., 3., 3.], + ], + ]]), + weight: TestTensor::from_floats([ + [[[3., 6., 3.], [3., 6., 3.]], [[3., 6., 3.], [3., 6., 3.]]], + [[[3., 6., 3.], [3., 6., 3.]], [[3., 6., 3.], [3., 6., 3.]]], + [[[3., 6., 3.], [3., 6., 3.]], [[3., 6., 3.], [3., 6., 3.]]], + ]), + bias: TestTensor::from_floats([8., 8., 8.]), + }; + test.assert_grads(grads); + } + struct Conv2dTestCase { batch_size: usize, channels_in: usize, @@ -555,6 +611,8 @@ mod tests { padding_2: usize, stride_1: usize, stride_2: usize, + dilation_1: usize, + dilation_2: usize, height: usize, width: usize, } @@ -584,6 +642,7 @@ mod tests { Some(bias.clone()), [self.stride_1, self.stride_2], [self.padding_1, self.padding_2], + [self.dilation_1, self.dilation_2], ); let grads = output.backward(); diff --git a/burn-core/src/nn/conv/conv1d.rs b/burn-core/src/nn/conv/conv1d.rs index f1e94c6cc..20558dd62 100644 --- a/burn-core/src/nn/conv/conv1d.rs +++ b/burn-core/src/nn/conv/conv1d.rs @@ -7,7 +7,7 @@ use crate::nn::Initializer; use crate::tensor::backend::Backend; use crate::tensor::Tensor; use burn_tensor::module::conv1d; -use burn_tensor::ops::conv::calculate_padding; +use burn_tensor::ops::conv::calculate_conv_padding; use libm::sqrt; @@ -23,6 +23,9 @@ pub struct Conv1dConfig { /// The stride of the convolution. #[config(default = "1")] pub stride: usize, + /// Spacing between kernel elements. 
+ #[config(default = "1")] + pub dilation: usize, /// The padding configuration. #[config(default = "Conv1dPaddingConfig::Valid")] pub padding: Conv1dPaddingConfig, @@ -61,6 +64,7 @@ pub struct Conv1d { bias: Option>>, stride: usize, kernel_size: usize, + dilation: usize, padding: Conv1dPaddingConfig, } @@ -90,6 +94,7 @@ impl Conv1dConfig { stride: 1, // TODO: Add the stride to the config when properly supported. kernel_size: self.kernel_size, padding: self.padding.clone(), + dilation: self.dilation, } } /// Initialize a new [conv1d](Conv1d) module with a [record](Conv1dRecord). @@ -100,6 +105,7 @@ impl Conv1dConfig { stride: 1, // TODO: Add the stride to the config when properly supported. kernel_size: self.kernel_size, padding: self.padding.clone(), + dilation: self.dilation, } } } @@ -114,7 +120,7 @@ impl Conv1d { pub fn forward(&self, input: Tensor) -> Tensor { let same_padding = || { let [_batch_size, _channels_in, length] = input.dims(); - calculate_padding(self.kernel_size, self.stride, length, length) + calculate_conv_padding(self.kernel_size, self.stride, length, length) }; let padding = match &self.padding { @@ -129,6 +135,7 @@ impl Conv1d { self.bias.as_ref().map(|bias| bias.val()), self.stride, padding, + self.dilation, ) } } diff --git a/burn-core/src/nn/conv/conv2d.rs b/burn-core/src/nn/conv/conv2d.rs index 00fcf9b23..4f1871458 100644 --- a/burn-core/src/nn/conv/conv2d.rs +++ b/burn-core/src/nn/conv/conv2d.rs @@ -7,7 +7,7 @@ use crate::nn::Initializer; use crate::tensor::backend::Backend; use crate::tensor::Tensor; use burn_tensor::module::conv2d; -use burn_tensor::ops::conv::calculate_padding; +use burn_tensor::ops::conv::calculate_conv_padding; use libm::sqrt; @@ -21,6 +21,9 @@ pub struct Conv2dConfig { /// The stride of the convolution. #[config(default = "[1, 1]")] pub stride: [usize; 2], + /// Spacing between kernel elements. + #[config(default = "[1, 1]")] + pub dilation: [usize; 2], /// The padding configuration. #[config(default = "Conv2dPaddingConfig::Valid")] pub padding: Conv2dPaddingConfig, @@ -59,6 +62,7 @@ pub struct Conv2d { bias: Option>>, stride: [usize; 2], kernel_size: [usize; 2], + dilation: [usize; 2], padding: Conv2dPaddingConfig, } @@ -90,8 +94,9 @@ impl Conv2dConfig { Conv2d { weight: Param::from(weight), bias: bias.map(Param::from), - stride: [1, 1], // TODO: Add the stride to the config when properly supported. 
+ stride: self.stride, kernel_size: self.kernel_size, + dilation: self.dilation, padding: self.padding.clone(), } } @@ -102,6 +107,7 @@ impl Conv2dConfig { weight: record.weight, bias: record.bias, stride: self.stride, + dilation: self.dilation, kernel_size: self.kernel_size, padding: self.padding.clone(), } @@ -126,6 +132,7 @@ impl Conv2d { self.bias.as_ref().map(|bias| bias.val()), self.stride, padding, + self.dilation, ) } } @@ -139,8 +146,8 @@ impl Conv2dPaddingConfig { stride: &[usize; 2], ) -> [usize; 2] { let same_padding = || { - let p1 = calculate_padding(kernel_size[0], stride[0], height, height); - let p2 = calculate_padding(kernel_size[1], stride[1], width, width); + let p1 = calculate_conv_padding(kernel_size[0], stride[0], height, height); + let p2 = calculate_conv_padding(kernel_size[1], stride[1], width, width); [p1, p2] }; diff --git a/burn-ndarray/src/ops/conv.rs b/burn-ndarray/src/ops/conv.rs index 08bec9912..102cbd0ff 100644 --- a/burn-ndarray/src/ops/conv.rs +++ b/burn-ndarray/src/ops/conv.rs @@ -1,4 +1,4 @@ -use burn_tensor::ElementConversion; +use burn_tensor::{ops::conv::calculate_conv_output_size, ElementConversion}; use ndarray::{Array4, Dim}; use crate::{ @@ -20,13 +20,20 @@ pub(crate) fn conv2d( let [batch_size, _in_channels, in_height, in_width] = x.shape().dims; let [out_channels, in_channels, kernel_height, kernel_width] = weight.shape().dims; - let out_height = (in_height + 2 * padding_height - dilatation_height * (kernel_height - 1) - 1) - / stride_height - + 1; - - let out_width = (in_width + 2 * padding_width - dilatation_width * (kernel_width - 1) - 1) - / stride_width - + 1; + let out_height = calculate_conv_output_size( + kernel_height, + stride_height, + padding_height, + dilatation_height, + in_height, + ); + let out_width = calculate_conv_output_size( + kernel_width, + stride_width, + padding_width, + dilatation_width, + in_width, + ); let x = apply_padding_4d(x, padding, 0i32.elem()).array; @@ -86,10 +93,15 @@ pub(crate) fn conv_transpose2d( let [batch_size, _in_channels, in_height, in_width] = x.shape().dims; let [in_channels, out_channels, kernel_height, kernel_width] = weight.shape().dims; - let out_height = - (in_height - 1) * stride_height + kernel_height + out_padding_height - 2 * padding_height; + let out_height = (in_height - 1) * stride_height + + dilation_height * (kernel_height - 1) + + out_padding_height + - 2 * padding_height + + 1; let out_width = - (in_width - 1) * stride_width + kernel_width + out_padding_width - 2 * padding_width; + (in_width - 1) * stride_width + dilation_width * (kernel_width - 1) + out_padding_width + - 2 * padding_width + + 1; let x = x.array; let mut output = Array4::zeros(Dim([batch_size, out_channels, out_height, out_width])); diff --git a/burn-ndarray/src/ops/module.rs b/burn-ndarray/src/ops/module.rs index f530dbc2e..e84ed6d61 100644 --- a/burn-ndarray/src/ops/module.rs +++ b/burn-ndarray/src/ops/module.rs @@ -75,8 +75,9 @@ impl ModuleOps> for NdArrayBackend bias: Option>, stride: [usize; 2], padding: [usize; 2], + dilation: [usize; 2], ) -> NdArrayTensor { - conv2d(x, weight, bias, stride, padding, [1, 1]) + conv2d(x, weight, bias, stride, padding, dilation) } fn conv_transpose2d( @@ -85,9 +86,10 @@ impl ModuleOps> for NdArrayBackend bias: Option>, stride: [usize; 2], padding: [usize; 2], - out_padding: [usize; 2], + padding_out: [usize; 2], + dilation: [usize; 2], ) -> NdArrayTensor { - conv_transpose2d(x, weight, bias, stride, padding, out_padding, [1, 1]) + conv_transpose2d(x, weight, bias, stride, 
padding, padding_out, dilation) } fn max_pool2d( diff --git a/burn-tch/src/ops/module.rs b/burn-tch/src/ops/module.rs index 868dadf55..e64c4f28d 100644 --- a/burn-tch/src/ops/module.rs +++ b/burn-tch/src/ops/module.rs @@ -32,6 +32,7 @@ impl ModuleOps> for TchBackend { bias: Option>, stride: usize, padding: usize, + dilation: usize, ) -> TchTensor { let tensor = tch::Tensor::conv1d( &x.tensor, @@ -39,7 +40,7 @@ impl ModuleOps> for TchBackend { bias.map(|t| t.tensor), &[stride as i64], &[padding as i64], - &[1], + &[dilation as i64], 1, ); @@ -52,6 +53,7 @@ impl ModuleOps> for TchBackend { bias: Option>, stride: [usize; 2], padding: [usize; 2], + dilation: [usize; 2], ) -> TchTensor { let tensor = tch::Tensor::conv2d( &x.tensor, @@ -59,7 +61,7 @@ impl ModuleOps> for TchBackend { bias.map(|t| t.tensor), &[stride[0] as i64, stride[1] as i64], &[padding[0] as i64, padding[1] as i64], - &[1, 1], + &[dilation[0] as i64, dilation[1] as i64], 1, ); @@ -72,7 +74,8 @@ impl ModuleOps> for TchBackend { bias: Option>, stride: [usize; 2], padding: [usize; 2], - out_padding: [usize; 2], + padding_out: [usize; 2], + dilation: [usize; 2], ) -> TchTensor { let tensor = tch::Tensor::conv_transpose2d( &x.tensor, @@ -80,9 +83,9 @@ impl ModuleOps> for TchBackend { bias.map(|t| t.tensor), &[stride[0] as i64, stride[1] as i64], &[padding[0] as i64, padding[1] as i64], - &[out_padding[0] as i64, out_padding[1] as i64], + &[padding_out[0] as i64, padding_out[1] as i64], 1, - &[1, 1], + &[dilation[0] as i64, dilation[1] as i64], ); TchTensor::new(tensor) @@ -95,6 +98,7 @@ impl ModuleOps> for TchBackend { stride: usize, padding: usize, padding_out: usize, + dilation: usize, ) -> TchTensor { let tensor = tch::Tensor::conv_transpose1d( &x.tensor, @@ -104,7 +108,7 @@ impl ModuleOps> for TchBackend { &[padding as i64], &[padding_out as i64], 1, - &[1], + &[dilation as i64], ); TchTensor::new(tensor) diff --git a/burn-tensor/src/tensor/module.rs b/burn-tensor/src/tensor/module.rs index 376cbf761..63c4aed1d 100644 --- a/burn-tensor/src/tensor/module.rs +++ b/burn-tensor/src/tensor/module.rs @@ -15,6 +15,7 @@ pub fn conv1d( bias: Option>, stride: usize, padding: usize, + dilation: usize, ) -> Tensor where B: Backend, @@ -25,6 +26,7 @@ where bias.map(|b| b.primitive), stride, padding, + dilation, )) } @@ -35,6 +37,7 @@ pub fn conv2d( bias: Option>, stride: [usize; 2], padding: [usize; 2], + dilation: [usize; 2], ) -> Tensor where B: Backend, @@ -45,6 +48,7 @@ where bias.map(|b| b.primitive), stride, padding, + dilation, )) } @@ -56,6 +60,7 @@ pub fn conv_transpose1d( stride: usize, padding: usize, padding_out: usize, + dilation: usize, ) -> Tensor where B: Backend, @@ -67,6 +72,7 @@ where stride, padding, padding_out, + dilation, )) } @@ -77,7 +83,8 @@ pub fn conv_transpose2d( bias: Option>, stride: [usize; 2], padding: [usize; 2], - out_padding: [usize; 2], + padding_out: [usize; 2], + dilation: [usize; 2], ) -> Tensor where B: Backend, @@ -88,7 +95,8 @@ where bias.map(|b| b.primitive), stride, padding, - out_padding, + padding_out, + dilation, )) } diff --git a/burn-tensor/src/tensor/ops/modules/base.rs b/burn-tensor/src/tensor/ops/modules/base.rs index c5fc8f030..d23a71c84 100644 --- a/burn-tensor/src/tensor/ops/modules/base.rs +++ b/burn-tensor/src/tensor/ops/modules/base.rs @@ -53,6 +53,7 @@ pub trait ModuleOps { bias: Option>, stride: [usize; 2], padding: [usize; 2], + dilation: [usize; 2], ) -> B::TensorPrimitive<4>; /// Two dimensional transposed convolution. 
/// @@ -67,7 +68,8 @@ pub trait ModuleOps { bias: Option>, stride: [usize; 2], padding: [usize; 2], - out_padding: [usize; 2], + padding_out: [usize; 2], + dilation: [usize; 2], ) -> B::TensorPrimitive<4>; /// Backward pass for the [conv2d](ModuleOps::conv2d) operation. @@ -76,9 +78,11 @@ pub trait ModuleOps { weight: B::TensorPrimitive<4>, bias: Option>, stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], output_grad: B::TensorPrimitive<4>, ) -> Conv2dBackward { - conv::conv2d_backward(x, weight, bias, stride, output_grad) + conv::conv2d_backward(x, weight, bias, stride, padding, dilation, output_grad) } /// One dimensional convolution. /// @@ -93,8 +97,9 @@ pub trait ModuleOps { bias: Option>, stride: usize, padding: usize, + dilation: usize, ) -> B::TensorPrimitive<3> { - conv::conv1d_from_conv2d::(x, weight, bias, stride, padding) + conv::conv1d_from_conv2d::(x, weight, bias, stride, padding, dilation) } /// One dimensional transposed convolution. /// @@ -110,6 +115,7 @@ pub trait ModuleOps { stride: usize, padding: usize, padding_out: usize, + dilation: usize, ) -> B::TensorPrimitive<3> { conv::conv_transpose1d_from_conv_transpose2d::( x, @@ -118,6 +124,7 @@ pub trait ModuleOps { stride, padding, padding_out, + dilation, ) } /// Backward pass for the [conv1d](ModuleOps::conv1d) operation. @@ -126,9 +133,11 @@ pub trait ModuleOps { weight: B::TensorPrimitive<3>, bias: Option>, stride: usize, + padding: usize, + dilation: usize, output_grad: B::TensorPrimitive<3>, ) -> Conv1dBackward { - conv::conv1d_backward(x, weight, bias, stride, output_grad) + conv::conv1d_backward(x, weight, bias, stride, padding, dilation, output_grad) } /// Two dimensional max pooling. /// diff --git a/burn-tensor/src/tensor/ops/modules/conv.rs b/burn-tensor/src/tensor/ops/modules/conv.rs index bf0d9737d..395b9fd00 100644 --- a/burn-tensor/src/tensor/ops/modules/conv.rs +++ b/burn-tensor/src/tensor/ops/modules/conv.rs @@ -2,9 +2,8 @@ use super::{Conv1dBackward, Conv2dBackward}; use crate::{backend::Backend, Shape}; use libm::ceilf; -/// Calculate the expected padding size required when applying a convolution with the specified -/// kernel size, stride, and input size to get the desired output size. -pub fn calculate_padding( +/// Calculate the expected padding size required when applying a convolution. +pub fn calculate_conv_padding( kernel_size: usize, stride: usize, size_in: usize, @@ -21,29 +20,22 @@ pub fn calculate_padding( padding as usize } -/// Calculate the expected output size when applying a convolution with the specified kernel size, -/// stride and padding. -pub fn calculate_output_size( +/// Calculate the expected output size when doing a convolution operation. +pub fn calculate_conv_output_size( kernel_size: usize, stride: usize, padding: usize, + dilation: usize, size_in: usize, ) -> usize { - let kernel_size = kernel_size as f32; - let stride = stride as f32; - let padding = padding as f32; - let size_in = size_in as f32; - - let size_out = (size_in + (2. 
* padding) - kernel_size) / stride; - let size_out = ceilf(size_out + 1.); - - size_out as usize + (size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1 } fn calculate_padding_out( kernel_size: usize, stride: usize, padding: usize, + dilation: usize, size_in: usize, size_out: usize, ) -> usize { @@ -51,8 +43,10 @@ fn calculate_padding_out( return 0; } - let out = calculate_output_size(kernel_size, stride, padding, size_out) as i64; - i64::max(0, out - size_in as i64) as usize + let out = 1 + libm::ceil( + (size_in + 2 * padding - dilation * (kernel_size - 1) - 1) as f64 / stride as f64, + ) as usize; + i64::max(0, out as i64 - size_out as i64) as usize } /// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass using convolutions. @@ -61,14 +55,22 @@ pub(crate) fn conv1d_backward( weight: B::TensorPrimitive<3>, bias: Option>, stride: usize, + padding: usize, + dilation: usize, output_grad: B::TensorPrimitive<3>, ) -> Conv1dBackward { let [batch_size, channels_in, length_in] = B::shape(&x).dims; let [_batch_size, channels_out, length_out] = B::shape(&output_grad).dims; let [_, _, kernel_size] = B::shape(&weight).dims; - let padding = calculate_padding(kernel_size, stride, length_in, length_out); - let padding_out = calculate_padding_out(kernel_size, stride, padding, length_out, length_in); + let padding_out = calculate_padding_out( + kernel_size, + stride, + padding, + dilation, + length_in, + length_out, + ); let x_grad = B::conv_transpose1d( output_grad.clone(), @@ -77,11 +79,19 @@ pub(crate) fn conv1d_backward( stride, padding, padding_out, + dilation, ); let x_swapped = B::swap_dims(x, 0, 1); let output_grad_swapped = B::swap_dims(output_grad, 0, 1); - let weight_grad_swapped = B::conv1d(x_swapped, output_grad_swapped.clone(), None, 1, padding); + let weight_grad_swapped = B::conv1d( + x_swapped, + output_grad_swapped.clone(), + None, + dilation, + padding, + stride, + ); let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1); if B::shape(&weight_grad) != Shape::new([channels_out, channels_in, kernel_size]) { @@ -110,28 +120,39 @@ pub(crate) fn conv2d_backward( weight: B::TensorPrimitive<4>, bias: Option>, stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], output_grad: B::TensorPrimitive<4>, ) -> Conv2dBackward { let [batch_size, channels_in, height_in, width_in] = B::shape(&x).dims; let [_batch_size, channels_out, height_out, width_out] = B::shape(&output_grad).dims; let [_, _, kernel_size_1, kernel_size_2] = B::shape(&weight).dims; - let [stride_1, stride_2] = stride; - let padding_1 = calculate_padding(kernel_size_1, stride_1, height_in, height_out); - let padding_2 = calculate_padding(kernel_size_2, stride_2, width_in, width_out); - - let padding_1_out = - calculate_padding_out(kernel_size_1, stride_1, padding_1, height_out, height_in); - let padding_2_out = - calculate_padding_out(kernel_size_2, stride_2, padding_2, width_out, width_in); + let padding_1_out = calculate_padding_out( + kernel_size_1, + stride[0], + padding[0], + dilation[0], + height_in, + height_out, + ); + let padding_2_out = calculate_padding_out( + kernel_size_2, + stride[1], + padding[1], + dilation[1], + width_in, + width_out, + ); let x_grad = B::conv_transpose2d( output_grad.clone(), weight, None, - [stride_1, stride_2], - [padding_1, padding_2], + stride, + padding, [padding_1_out, padding_2_out], + dilation, ); let x_swapped = B::swap_dims(x, 0, 1); @@ -140,8 +161,9 @@ pub(crate) fn conv2d_backward( x_swapped, output_grad_swapped.clone(), None, 
- [1, 1], - [padding_1, padding_2], + dilation, + padding, + stride, ); let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1); @@ -182,6 +204,7 @@ pub(crate) fn conv1d_from_conv2d( bias: Option>, stride: usize, padding: usize, + dilation: usize, ) -> B::TensorPrimitive<3> { let [channels_out, _channels_in, kernel_size] = B::shape(&weight).dims; let [batch_size, channels_in, length_in] = B::shape(&x).dims; @@ -192,7 +215,7 @@ pub(crate) fn conv1d_from_conv2d( ); let x = B::reshape(x, Shape::new([batch_size, channels_in, length_in, 1])); - let tensor = B::conv2d(x, weight, bias, [stride, 1], [padding, 0]); + let tensor = B::conv2d(x, weight, bias, [stride, 1], [padding, 0], [dilation, 1]); let [batch_size, channels_out, height_out, _weight_out] = B::shape(&tensor).dims; B::reshape(tensor, Shape::from([batch_size, channels_out, height_out])) } @@ -205,6 +228,7 @@ pub(crate) fn conv_transpose1d_from_conv_transpose2d( stride: usize, padding: usize, padding_out: usize, + dilation: usize, ) -> B::TensorPrimitive<3> { let [channels_in, channels_out, kernel_size] = B::shape(&weight).dims; let [batch_size, _channels_in, length_in] = B::shape(&x).dims; @@ -215,7 +239,15 @@ pub(crate) fn conv_transpose1d_from_conv_transpose2d( ); let x = B::reshape(x, Shape::new([batch_size, channels_in, length_in, 1])); - let tensor = B::conv_transpose2d(x, weight, bias, [stride, 1], [padding, 0], [padding_out, 0]); + let tensor = B::conv_transpose2d( + x, + weight, + bias, + [stride, 1], + [padding, 0], + [padding_out, 0], + [dilation, 1], + ); let [batch_size, channels_out, height_out, _weight_out] = B::shape(&tensor).dims; B::reshape(tensor, Shape::from([batch_size, channels_out, height_out])) } @@ -230,8 +262,9 @@ mod tests { let stride = 1; let padding = 1; let size_in = 3; + let dilation = 1; - let size_out = calculate_output_size(kernel_size, stride, padding, size_in); + let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); assert_eq!(size_out, 3); } @@ -242,20 +275,35 @@ mod tests { let stride = 2; let padding = 3; let size_in = 27; + let dilation = 1; - let size_out = calculate_output_size(kernel_size, stride, padding, size_in); + let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); assert_eq!(size_out, 15); } + #[test] + fn test_calculate_output_size_3() { + let kernel_size = 5; + let stride = 2; + let padding = 3; + let size_in = 27; + let dilation = 2; + + let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); + + assert_eq!(size_out, 13); + } + #[test] fn test_calculate_same_padding_1() { let kernel_size = 3; let stride = 1; let size_in = 3; + let dilation = 1; - let padding = calculate_padding(kernel_size, stride, size_in, size_in); - let size_out = calculate_output_size(kernel_size, stride, padding, size_in); + let padding = calculate_conv_padding(kernel_size, stride, size_in, size_in); + let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); assert_eq!(size_in, size_out, "Expected size"); } @@ -265,9 +313,10 @@ mod tests { let kernel_size = 3; let stride = 2; let size_in = 7; + let dilation = 1; - let padding = calculate_padding(kernel_size, stride, size_in, size_in); - let size_out = calculate_output_size(kernel_size, stride, padding, size_in); + let padding = calculate_conv_padding(kernel_size, stride, size_in, size_in); + let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); assert_eq!(size_in, size_out, 
"Expected size"); } @@ -278,9 +327,11 @@ mod tests { let stride = 2; let size_in = 7; let size_out = 10; + let dilation = 1; - let padding = calculate_padding(kernel_size, stride, size_in, size_out); - let size_out_expected = calculate_output_size(kernel_size, stride, padding, size_in); + let padding = calculate_conv_padding(kernel_size, stride, size_in, size_out); + let size_out_expected = + calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); assert_eq!(size_out, size_out_expected, "Expected size"); } diff --git a/burn-tensor/src/tests/module/conv1d.rs b/burn-tensor/src/tests/module/conv1d.rs index f723046c7..28cf829ad 100644 --- a/burn-tensor/src/tests/module/conv1d.rs +++ b/burn-tensor/src/tests/module/conv1d.rs @@ -13,6 +13,7 @@ mod tests { kernel_size: 3, padding: 1, stride: 1, + dilation: 1, length: 6, }; @@ -30,6 +31,25 @@ mod tests { ])); } + #[test] + fn test_conv1d_dilation() { + let test = Conv1dTestCase { + batch_size: 2, + channels_in: 2, + channels_out: 2, + kernel_size: 3, + padding: 1, + stride: 1, + dilation: 2, + length: 4, + }; + + test.assert_output(TestTensor::from_floats([ + [[5., 5.], [5., 5.]], + [[5., 5.], [5., 5.]], + ])); + } + #[test] fn test_conv1d_complex() { let test = Conv1dTestCase { @@ -39,6 +59,7 @@ mod tests { kernel_size: 3, padding: 1, stride: 2, + dilation: 1, length: 9, }; @@ -65,6 +86,7 @@ mod tests { kernel_size: usize, padding: usize, stride: usize, + dilation: usize, length: usize, } @@ -73,7 +95,14 @@ mod tests { let weights = TestTensor::ones([self.channels_out, self.channels_in, self.kernel_size]); let bias = TestTensor::ones([self.channels_out]); let x = TestTensor::ones([self.batch_size, self.channels_in, self.length]); - let output = conv1d(x, weights, Some(bias), self.stride, self.padding); + let output = conv1d( + x, + weights, + Some(bias), + self.stride, + self.padding, + self.dilation, + ); y.to_data().assert_approx_eq(&output.into_data(), 3); } diff --git a/burn-tensor/src/tests/module/conv2d.rs b/burn-tensor/src/tests/module/conv2d.rs index 247e342df..121bdc287 100644 --- a/burn-tensor/src/tests/module/conv2d.rs +++ b/burn-tensor/src/tests/module/conv2d.rs @@ -16,6 +16,8 @@ mod tests { padding_2: 1, stride_1: 1, stride_2: 1, + dilation_1: 1, + dilation_2: 1, height: 6, width: 6, }; @@ -88,62 +90,24 @@ mod tests { padding_2: 2, stride_1: 2, stride_2: 3, - height: 7, - width: 9, + dilation_1: 1, + dilation_2: 2, + height: 4, + width: 5, }; test.assert_output(TestTensor::from_floats([ [ - [ - [1., 13., 13., 13.], - [1., 19., 19., 19.], - [1., 19., 19., 19.], - [1., 13., 13., 13.], - ], - [ - [1., 13., 13., 13.], - [1., 19., 19., 19.], - [1., 19., 19., 19.], - [1., 13., 13., 13.], - ], - [ - [1., 13., 13., 13.], - [1., 19., 19., 19.], - [1., 19., 19., 19.], - [1., 13., 13., 13.], - ], - [ - [1., 13., 13., 13.], - [1., 19., 19., 19.], - [1., 19., 19., 19.], - [1., 13., 13., 13.], - ], + [[7., 13., 7.], [10., 19., 10.]], + [[7., 13., 7.], [10., 19., 10.]], + [[7., 13., 7.], [10., 19., 10.]], + [[7., 13., 7.], [10., 19., 10.]], ], [ - [ - [1., 13., 13., 13.], - [1., 19., 19., 19.], - [1., 19., 19., 19.], - [1., 13., 13., 13.], - ], - [ - [1., 13., 13., 13.], - [1., 19., 19., 19.], - [1., 19., 19., 19.], - [1., 13., 13., 13.], - ], - [ - [1., 13., 13., 13.], - [1., 19., 19., 19.], - [1., 19., 19., 19.], - [1., 13., 13., 13.], - ], - [ - [1., 13., 13., 13.], - [1., 19., 19., 19.], - [1., 19., 19., 19.], - [1., 13., 13., 13.], - ], + [[7., 13., 7.], [10., 19., 10.]], + [[7., 13., 7.], [10., 19., 10.]], + [[7., 13., 
7.], [10., 19., 10.]], + [[7., 13., 7.], [10., 19., 10.]], ], ])); } @@ -158,6 +122,8 @@ mod tests { padding_2: usize, stride_1: usize, stride_2: usize, + dilation_1: usize, + dilation_2: usize, height: usize, width: usize, } @@ -178,6 +144,7 @@ mod tests { Some(bias), [self.stride_1, self.stride_2], [self.padding_1, self.padding_2], + [self.dilation_1, self.dilation_2], ); y.to_data().assert_approx_eq(&output.into_data(), 3); diff --git a/burn-tensor/src/tests/module/conv_transpose1d.rs b/burn-tensor/src/tests/module/conv_transpose1d.rs index 91a902bc1..2ebdf8067 100644 --- a/burn-tensor/src/tests/module/conv_transpose1d.rs +++ b/burn-tensor/src/tests/module/conv_transpose1d.rs @@ -14,6 +14,7 @@ mod tests { padding: 1, padding_out: 0, stride: 1, + dilation: 1, length: 4, }; @@ -33,6 +34,7 @@ mod tests { padding: 1, padding_out: 1, stride: 2, + dilation: 1, length: 4, }; @@ -42,6 +44,26 @@ mod tests { ]])); } + #[test] + fn test_conv_transpose1d_dilation() { + let test = ConvTranspose1dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size: 3, + padding: 1, + padding_out: 0, + stride: 1, + dilation: 2, + length: 4, + }; + + test.assert_output(TestTensor::from_floats([[ + [30., 64., 78., 76., 94., 52.], + [49., 101., 127., 113., 143., 77.], + ]])); + } + struct ConvTranspose1dTestCase { batch_size: usize, channels_in: usize, @@ -50,6 +72,7 @@ mod tests { padding: usize, padding_out: usize, stride: usize, + dilation: usize, length: usize, } @@ -81,6 +104,7 @@ mod tests { self.stride, self.padding, self.padding_out, + self.dilation, ); y.to_data().assert_approx_eq(&output.into_data(), 3); diff --git a/burn-tensor/src/tests/module/conv_transpose2d.rs b/burn-tensor/src/tests/module/conv_transpose2d.rs index 37c1e9431..7cd2e9cde 100644 --- a/burn-tensor/src/tests/module/conv_transpose2d.rs +++ b/burn-tensor/src/tests/module/conv_transpose2d.rs @@ -18,6 +18,8 @@ mod tests { padding_out_2: 0, stride_1: 1, stride_2: 1, + dilation_1: 1, + dilation_2: 1, height: 2, width: 2, }; @@ -38,6 +40,8 @@ mod tests { padding_out_2: 0, stride_1: 1, stride_2: 1, + dilation_1: 1, + dilation_2: 1, height: 4, width: 4, }; @@ -78,6 +82,8 @@ mod tests { padding_out_2: 0, stride_1: 2, stride_2: 2, + dilation_1: 1, + dilation_2: 1, height: 2, width: 2, }; @@ -90,6 +96,44 @@ mod tests { ]]])); } + #[test] + fn test_conv_transpose2d_dilation_2() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 1, + padding_2: 1, + padding_out_1: 1, + padding_out_2: 1, + stride_1: 1, + stride_2: 1, + dilation_1: 2, + dilation_2: 2, + height: 2, + width: 2, + }; + + test.assert_output(TestTensor::from_floats([[ + [ + [126., 116., 136., 124., 146.], + [108., 88., 114., 92., 120.], + [156., 140., 166., 148., 176.], + [126., 100., 132., 104., 138.], + [186., 164., 196., 172., 206.], + ], + [ + [217., 189., 227., 197., 237.], + [163., 125., 169., 129., 175.], + [247., 213., 257., 221., 267.], + [181., 137., 187., 141., 193.], + [277., 237., 287., 245., 297.], + ], + ]])); + } + #[test] fn test_conv_transpose2d_stride2_out_padding() { let test = ConvTranspose2dTestCase { @@ -104,6 +148,8 @@ mod tests { padding_out_2: 1, stride_1: 2, stride_2: 2, + dilation_1: 1, + dilation_2: 1, height: 4, width: 4, }; @@ -144,6 +190,8 @@ mod tests { padding_out_2: usize, stride_1: usize, stride_2: usize, + dilation_1: usize, + dilation_2: usize, height: usize, width: usize, } @@ -181,6 +229,7 @@ mod tests { [self.stride_1, self.stride_2], 
[self.padding_1, self.padding_2], [self.padding_out_1, self.padding_out_2], + [self.dilation_1, self.dilation_2], ); y.to_data().assert_approx_eq(&output.into_data(), 3); diff --git a/examples/mnist-inference-web/model.bin b/examples/mnist-inference-web/model.bin index cc565aeb9..9b76d2cc1 100644 Binary files a/examples/mnist-inference-web/model.bin and b/examples/mnist-inference-web/model.bin differ
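
The substance of this change is the size arithmetic that dilation introduces: the forward output size now subtracts the dilated kernel extent `dilation * (kernel_size - 1)` rather than the raw kernel size, and the ndarray transposed convolution grows its output by the same extent. Below is a minimal standalone sketch of both formulas, checked against test values added in this diff; the helper names are illustrative only, not part of the Burn API.

/// Forward output size along one dimension, mirroring `calculate_conv_output_size`.
fn conv_output_size(
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    size_in: usize,
) -> usize {
    // A dilated kernel covers `dilation * (kernel_size - 1) + 1` input positions.
    (size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
}

/// Transposed-convolution output size along one dimension, mirroring the
/// updated formula in `burn-ndarray/src/ops/conv.rs`.
fn conv_transpose_output_size(
    kernel_size: usize,
    stride: usize,
    padding: usize,
    padding_out: usize,
    dilation: usize,
    size_in: usize,
) -> usize {
    (size_in - 1) * stride + dilation * (kernel_size - 1) + padding_out - 2 * padding + 1
}

fn main() {
    // `test_calculate_output_size_3`: kernel 5, stride 2, padding 3, dilation 2, input 27 -> 13.
    assert_eq!(conv_output_size(5, 2, 3, 2, 27), 13);
    // `test_conv1d_dilation`: kernel 3, padding 1, stride 1, dilation 2, length 4 -> 2 output steps.
    assert_eq!(conv_output_size(3, 1, 1, 2, 4), 2);
    // `test_conv_transpose1d_dilation`: kernel 3, stride 1, padding 1, padding_out 0,
    // dilation 2, length 4 -> output length 6.
    assert_eq!(conv_transpose_output_size(3, 1, 1, 0, 2, 4), 6);
}

With `dilation = 1` both expressions reduce to the undilated formulas, which is why every pre-existing test case in this diff only needed a `dilation: 1` (or `[1, 1]`) field added to keep its expected values unchanged.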