mirror of https://github.com/tracel-ai/burn.git
Support dilation in convolution operations (#301)
This commit is contained in:
parent
bd58922784
commit
78ac09fb7a
|
@ -52,6 +52,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
bias: Option<ADTensor<B, 1>>,
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> ADTensor<B, 4> {
|
||||
#[derive(Debug)]
|
||||
struct Conv2DWithBias;
|
||||
|
@ -64,14 +65,17 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
B::TensorPrimitive<4>,
|
||||
B::TensorPrimitive<1>,
|
||||
[usize; 2],
|
||||
[usize; 2],
|
||||
[usize; 2],
|
||||
);
|
||||
|
||||
fn backward(self, ops: Ops<Self::State, 3>, grads: &mut Gradients) {
|
||||
let [node_x, node_weight, node_bias] = ops.parents;
|
||||
let grad = grads.consume::<B, 4>(&ops.node);
|
||||
|
||||
let (x, weight, bias, stride) = ops.state;
|
||||
let backward = B::conv2d_backward(x, weight, Some(bias), stride, grad);
|
||||
let (x, weight, bias, stride, padding, dilation) = ops.state;
|
||||
let backward =
|
||||
B::conv2d_backward(x, weight, Some(bias), stride, padding, dilation, grad);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 4>(node, backward.x_grad)
|
||||
|
@ -86,14 +90,20 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
}
|
||||
|
||||
impl<B: Backend> Backward<B, 4, 2> for Conv2DNoBias {
|
||||
type State = (B::TensorPrimitive<4>, B::TensorPrimitive<4>, [usize; 2]);
|
||||
type State = (
|
||||
B::TensorPrimitive<4>,
|
||||
B::TensorPrimitive<4>,
|
||||
[usize; 2],
|
||||
[usize; 2],
|
||||
[usize; 2],
|
||||
);
|
||||
|
||||
fn backward(self, ops: Ops<Self::State, 2>, grads: &mut Gradients) {
|
||||
let [node_x, node_weight] = ops.parents;
|
||||
let grad = grads.consume::<B, 4>(&ops.node);
|
||||
|
||||
let (x, weight, stride) = ops.state;
|
||||
let backward = B::conv2d_backward(x, weight, None, stride, grad);
|
||||
let (x, weight, stride, padding, dilation) = ops.state;
|
||||
let backward = B::conv2d_backward(x, weight, None, stride, padding, dilation, grad);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 4>(node, backward.x_grad)
|
||||
|
@ -119,6 +129,8 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
weight.primitive.clone(),
|
||||
bias.primitive.clone(),
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
),
|
||||
B::conv2d(
|
||||
x.primitive,
|
||||
|
@ -126,6 +138,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
Some(bias.primitive),
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
),
|
||||
),
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::conv2d(
|
||||
|
@ -134,6 +147,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
Some(bias.primitive),
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
@ -143,8 +157,21 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
.statefull()
|
||||
{
|
||||
OpsKind::Tracked(prep) => prep.finish(
|
||||
(x.primitive.clone(), weight.primitive.clone(), stride),
|
||||
B::conv2d(x.primitive, weight.primitive, None, stride, padding),
|
||||
(
|
||||
x.primitive.clone(),
|
||||
weight.primitive.clone(),
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
),
|
||||
B::conv2d(
|
||||
x.primitive,
|
||||
weight.primitive,
|
||||
None,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
),
|
||||
),
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::conv2d(
|
||||
x.primitive,
|
||||
|
@ -152,6 +179,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
None,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
@ -164,7 +192,8 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
_bias: Option<ADTensor<B, 1>>,
|
||||
_stride: [usize; 2],
|
||||
_padding: [usize; 2],
|
||||
_out_padding: [usize; 2],
|
||||
_padding_out: [usize; 2],
|
||||
_dilation: [usize; 2],
|
||||
) -> ADTensor<B, 4> {
|
||||
todo!("Transposed 2D convolution doesn't yet support backward.");
|
||||
}
|
||||
|
@ -175,7 +204,8 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
_bias: Option<ADTensor<B, 1>>,
|
||||
_stride: usize,
|
||||
_padding: usize,
|
||||
_out_padding: usize,
|
||||
_padding_out: usize,
|
||||
_dilation: usize,
|
||||
) -> ADTensor<B, 3> {
|
||||
todo!("Transposed 1D convolution doesn't yet support backward.");
|
||||
}
|
||||
|
@ -186,6 +216,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
bias: Option<ADTensor<B, 1>>,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
) -> ADTensor<B, 3> {
|
||||
#[derive(Debug)]
|
||||
struct Conv1DWithBias;
|
||||
|
@ -198,14 +229,17 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
B::TensorPrimitive<3>,
|
||||
B::TensorPrimitive<1>,
|
||||
usize,
|
||||
usize,
|
||||
usize,
|
||||
);
|
||||
|
||||
fn backward(self, ops: Ops<Self::State, 3>, grads: &mut Gradients) {
|
||||
let [node_x, node_weight, node_bias] = ops.parents;
|
||||
let grad = grads.consume::<B, 3>(&ops.node);
|
||||
|
||||
let (x, weight, bias, stride) = ops.state;
|
||||
let backward = B::conv1d_backward(x, weight, Some(bias), stride, grad);
|
||||
let (x, weight, bias, stride, padding, dilation) = ops.state;
|
||||
let backward =
|
||||
B::conv1d_backward(x, weight, Some(bias), stride, padding, dilation, grad);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 3>(node, backward.x_grad)
|
||||
|
@ -220,14 +254,20 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
}
|
||||
|
||||
impl<B: Backend> Backward<B, 3, 2> for Conv1DNoBias {
|
||||
type State = (B::TensorPrimitive<3>, B::TensorPrimitive<3>, usize);
|
||||
type State = (
|
||||
B::TensorPrimitive<3>,
|
||||
B::TensorPrimitive<3>,
|
||||
usize,
|
||||
usize,
|
||||
usize,
|
||||
);
|
||||
|
||||
fn backward(self, ops: Ops<Self::State, 2>, grads: &mut Gradients) {
|
||||
let [node_x, node_weight] = ops.parents;
|
||||
let grad = grads.consume::<B, 3>(&ops.node);
|
||||
|
||||
let (x, weight, stride) = ops.state;
|
||||
let backward = B::conv1d_backward(x, weight, None, stride, grad);
|
||||
let (x, weight, stride, padding, dilation) = ops.state;
|
||||
let backward = B::conv1d_backward(x, weight, None, stride, padding, dilation, grad);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 3>(node, backward.x_grad)
|
||||
|
@ -252,6 +292,8 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
weight.primitive.clone(),
|
||||
bias.primitive.clone(),
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
),
|
||||
B::conv1d(
|
||||
x.primitive,
|
||||
|
@ -259,6 +301,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
Some(bias.primitive),
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
),
|
||||
),
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::conv1d(
|
||||
|
@ -267,6 +310,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
Some(bias.primitive),
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
@ -276,8 +320,21 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
.statefull()
|
||||
{
|
||||
OpsKind::Tracked(prep) => prep.finish(
|
||||
(x.primitive.clone(), weight.primitive.clone(), stride),
|
||||
B::conv1d(x.primitive, weight.primitive, None, stride, padding),
|
||||
(
|
||||
x.primitive.clone(),
|
||||
weight.primitive.clone(),
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
),
|
||||
B::conv1d(
|
||||
x.primitive,
|
||||
weight.primitive,
|
||||
None,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
),
|
||||
),
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::conv1d(
|
||||
x.primitive,
|
||||
|
@ -285,6 +342,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
None,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ mod tests {
|
|||
kernel_size: 3,
|
||||
padding: 1,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
length: 6,
|
||||
};
|
||||
let grads = Grads {
|
||||
|
@ -46,6 +47,7 @@ mod tests {
|
|||
kernel_size: 3,
|
||||
padding: 1,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
length: 6,
|
||||
};
|
||||
let grads = Grads {
|
||||
|
@ -72,6 +74,7 @@ mod tests {
|
|||
kernel_size: 3,
|
||||
padding: 2,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
length: 6,
|
||||
};
|
||||
let grads = Grads {
|
||||
|
@ -97,6 +100,7 @@ mod tests {
|
|||
kernel_size: 3,
|
||||
padding: 1,
|
||||
stride: 2,
|
||||
dilation: 1,
|
||||
length: 4,
|
||||
};
|
||||
let grads = Grads {
|
||||
|
@ -113,6 +117,32 @@ mod tests {
|
|||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv1d_dilation() {
|
||||
let test = Conv1dTestCase {
|
||||
batch_size: 2,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
stride: 1,
|
||||
dilation: 2,
|
||||
length: 4,
|
||||
};
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats([
|
||||
[[2., 2., 2., 2.], [2., 2., 2., 2.]],
|
||||
[[2., 2., 2., 2.], [2., 2., 2., 2.]],
|
||||
]),
|
||||
weight: TestTensor::from_floats([
|
||||
[[2., 4., 2.], [2., 4., 2.]],
|
||||
[[2., 4., 2.], [2., 4., 2.]],
|
||||
]),
|
||||
bias: TestTensor::from_floats([4., 4.]),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
struct Conv1dTestCase {
|
||||
batch_size: usize,
|
||||
channels_in: usize,
|
||||
|
@ -120,6 +150,7 @@ mod tests {
|
|||
kernel_size: usize,
|
||||
padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
length: usize,
|
||||
}
|
||||
|
||||
|
@ -143,6 +174,7 @@ mod tests {
|
|||
Some(bias.clone()),
|
||||
self.stride,
|
||||
self.padding,
|
||||
self.dilation,
|
||||
);
|
||||
let grads = output.backward();
|
||||
|
||||
|
|
|
@ -15,6 +15,8 @@ mod tests {
|
|||
padding_2: 1,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
height: 6,
|
||||
width: 6,
|
||||
};
|
||||
|
@ -107,6 +109,8 @@ mod tests {
|
|||
padding_2: 1,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
height: 6,
|
||||
width: 6,
|
||||
};
|
||||
|
@ -180,6 +184,8 @@ mod tests {
|
|||
padding_2: 1,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
height: 6,
|
||||
width: 6,
|
||||
};
|
||||
|
@ -265,6 +271,8 @@ mod tests {
|
|||
padding_2: 2,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
height: 6,
|
||||
width: 6,
|
||||
};
|
||||
|
@ -334,6 +342,8 @@ mod tests {
|
|||
padding_2: 1,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
height: 6,
|
||||
width: 5,
|
||||
};
|
||||
|
@ -403,6 +413,8 @@ mod tests {
|
|||
padding_2: 1,
|
||||
stride_1: 2,
|
||||
stride_2: 2,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
height: 8,
|
||||
width: 8,
|
||||
};
|
||||
|
@ -480,6 +492,8 @@ mod tests {
|
|||
padding_2: 1,
|
||||
stride_1: 3,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
height: 8,
|
||||
width: 8,
|
||||
};
|
||||
|
@ -545,6 +559,48 @@ mod tests {
|
|||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv2d_complex() {
|
||||
let test = Conv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 3,
|
||||
kernel_size_1: 2,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 2,
|
||||
stride_1: 1,
|
||||
stride_2: 2,
|
||||
dilation_1: 2,
|
||||
dilation_2: 3,
|
||||
height: 4,
|
||||
width: 5,
|
||||
};
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats([[
|
||||
[
|
||||
[3., 3., 0., 3., 3.],
|
||||
[6., 6., 0., 6., 6.],
|
||||
[6., 6., 0., 6., 6.],
|
||||
[3., 3., 0., 3., 3.],
|
||||
],
|
||||
[
|
||||
[3., 3., 0., 3., 3.],
|
||||
[6., 6., 0., 6., 6.],
|
||||
[6., 6., 0., 6., 6.],
|
||||
[3., 3., 0., 3., 3.],
|
||||
],
|
||||
]]),
|
||||
weight: TestTensor::from_floats([
|
||||
[[[3., 6., 3.], [3., 6., 3.]], [[3., 6., 3.], [3., 6., 3.]]],
|
||||
[[[3., 6., 3.], [3., 6., 3.]], [[3., 6., 3.], [3., 6., 3.]]],
|
||||
[[[3., 6., 3.], [3., 6., 3.]], [[3., 6., 3.], [3., 6., 3.]]],
|
||||
]),
|
||||
bias: TestTensor::from_floats([8., 8., 8.]),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
struct Conv2dTestCase {
|
||||
batch_size: usize,
|
||||
channels_in: usize,
|
||||
|
@ -555,6 +611,8 @@ mod tests {
|
|||
padding_2: usize,
|
||||
stride_1: usize,
|
||||
stride_2: usize,
|
||||
dilation_1: usize,
|
||||
dilation_2: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
}
|
||||
|
@ -584,6 +642,7 @@ mod tests {
|
|||
Some(bias.clone()),
|
||||
[self.stride_1, self.stride_2],
|
||||
[self.padding_1, self.padding_2],
|
||||
[self.dilation_1, self.dilation_2],
|
||||
);
|
||||
let grads = output.backward();
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ use crate::nn::Initializer;
|
|||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::Tensor;
|
||||
use burn_tensor::module::conv1d;
|
||||
use burn_tensor::ops::conv::calculate_padding;
|
||||
use burn_tensor::ops::conv::calculate_conv_padding;
|
||||
|
||||
use libm::sqrt;
|
||||
|
||||
|
@ -23,6 +23,9 @@ pub struct Conv1dConfig {
|
|||
/// The stride of the convolution.
|
||||
#[config(default = "1")]
|
||||
pub stride: usize,
|
||||
/// Spacing between kernel elements.
|
||||
#[config(default = "1")]
|
||||
pub dilation: usize,
|
||||
/// The padding configuration.
|
||||
#[config(default = "Conv1dPaddingConfig::Valid")]
|
||||
pub padding: Conv1dPaddingConfig,
|
||||
|
@ -61,6 +64,7 @@ pub struct Conv1d<B: Backend> {
|
|||
bias: Option<Param<Tensor<B, 1>>>,
|
||||
stride: usize,
|
||||
kernel_size: usize,
|
||||
dilation: usize,
|
||||
padding: Conv1dPaddingConfig,
|
||||
}
|
||||
|
||||
|
@ -90,6 +94,7 @@ impl Conv1dConfig {
|
|||
stride: 1, // TODO: Add the stride to the config when properly supported.
|
||||
kernel_size: self.kernel_size,
|
||||
padding: self.padding.clone(),
|
||||
dilation: self.dilation,
|
||||
}
|
||||
}
|
||||
/// Initialize a new [conv1d](Conv1d) module with a [record](Conv1dRecord).
|
||||
|
@ -100,6 +105,7 @@ impl Conv1dConfig {
|
|||
stride: 1, // TODO: Add the stride to the config when properly supported.
|
||||
kernel_size: self.kernel_size,
|
||||
padding: self.padding.clone(),
|
||||
dilation: self.dilation,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -114,7 +120,7 @@ impl<B: Backend> Conv1d<B> {
|
|||
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
|
||||
let same_padding = || {
|
||||
let [_batch_size, _channels_in, length] = input.dims();
|
||||
calculate_padding(self.kernel_size, self.stride, length, length)
|
||||
calculate_conv_padding(self.kernel_size, self.stride, length, length)
|
||||
};
|
||||
|
||||
let padding = match &self.padding {
|
||||
|
@ -129,6 +135,7 @@ impl<B: Backend> Conv1d<B> {
|
|||
self.bias.as_ref().map(|bias| bias.val()),
|
||||
self.stride,
|
||||
padding,
|
||||
self.dilation,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ use crate::nn::Initializer;
|
|||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::Tensor;
|
||||
use burn_tensor::module::conv2d;
|
||||
use burn_tensor::ops::conv::calculate_padding;
|
||||
use burn_tensor::ops::conv::calculate_conv_padding;
|
||||
|
||||
use libm::sqrt;
|
||||
|
||||
|
@ -21,6 +21,9 @@ pub struct Conv2dConfig {
|
|||
/// The stride of the convolution.
|
||||
#[config(default = "[1, 1]")]
|
||||
pub stride: [usize; 2],
|
||||
/// Spacing between kernel elements.
|
||||
#[config(default = "[1, 1]")]
|
||||
pub dilation: [usize; 2],
|
||||
/// The padding configuration.
|
||||
#[config(default = "Conv2dPaddingConfig::Valid")]
|
||||
pub padding: Conv2dPaddingConfig,
|
||||
|
@ -59,6 +62,7 @@ pub struct Conv2d<B: Backend> {
|
|||
bias: Option<Param<Tensor<B, 1>>>,
|
||||
stride: [usize; 2],
|
||||
kernel_size: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
padding: Conv2dPaddingConfig,
|
||||
}
|
||||
|
||||
|
@ -90,8 +94,9 @@ impl Conv2dConfig {
|
|||
Conv2d {
|
||||
weight: Param::from(weight),
|
||||
bias: bias.map(Param::from),
|
||||
stride: [1, 1], // TODO: Add the stride to the config when properly supported.
|
||||
stride: self.stride,
|
||||
kernel_size: self.kernel_size,
|
||||
dilation: self.dilation,
|
||||
padding: self.padding.clone(),
|
||||
}
|
||||
}
|
||||
|
@ -102,6 +107,7 @@ impl Conv2dConfig {
|
|||
weight: record.weight,
|
||||
bias: record.bias,
|
||||
stride: self.stride,
|
||||
dilation: self.dilation,
|
||||
kernel_size: self.kernel_size,
|
||||
padding: self.padding.clone(),
|
||||
}
|
||||
|
@ -126,6 +132,7 @@ impl<B: Backend> Conv2d<B> {
|
|||
self.bias.as_ref().map(|bias| bias.val()),
|
||||
self.stride,
|
||||
padding,
|
||||
self.dilation,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -139,8 +146,8 @@ impl Conv2dPaddingConfig {
|
|||
stride: &[usize; 2],
|
||||
) -> [usize; 2] {
|
||||
let same_padding = || {
|
||||
let p1 = calculate_padding(kernel_size[0], stride[0], height, height);
|
||||
let p2 = calculate_padding(kernel_size[1], stride[1], width, width);
|
||||
let p1 = calculate_conv_padding(kernel_size[0], stride[0], height, height);
|
||||
let p2 = calculate_conv_padding(kernel_size[1], stride[1], width, width);
|
||||
|
||||
[p1, p2]
|
||||
};
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use burn_tensor::ElementConversion;
|
||||
use burn_tensor::{ops::conv::calculate_conv_output_size, ElementConversion};
|
||||
use ndarray::{Array4, Dim};
|
||||
|
||||
use crate::{
|
||||
|
@ -20,13 +20,20 @@ pub(crate) fn conv2d<E: FloatNdArrayElement>(
|
|||
let [batch_size, _in_channels, in_height, in_width] = x.shape().dims;
|
||||
let [out_channels, in_channels, kernel_height, kernel_width] = weight.shape().dims;
|
||||
|
||||
let out_height = (in_height + 2 * padding_height - dilatation_height * (kernel_height - 1) - 1)
|
||||
/ stride_height
|
||||
+ 1;
|
||||
|
||||
let out_width = (in_width + 2 * padding_width - dilatation_width * (kernel_width - 1) - 1)
|
||||
/ stride_width
|
||||
+ 1;
|
||||
let out_height = calculate_conv_output_size(
|
||||
kernel_height,
|
||||
stride_height,
|
||||
padding_height,
|
||||
dilatation_height,
|
||||
in_height,
|
||||
);
|
||||
let out_width = calculate_conv_output_size(
|
||||
kernel_width,
|
||||
stride_width,
|
||||
padding_width,
|
||||
dilatation_width,
|
||||
in_width,
|
||||
);
|
||||
|
||||
let x = apply_padding_4d(x, padding, 0i32.elem()).array;
|
||||
|
||||
|
@ -86,10 +93,15 @@ pub(crate) fn conv_transpose2d<E: FloatNdArrayElement>(
|
|||
let [batch_size, _in_channels, in_height, in_width] = x.shape().dims;
|
||||
let [in_channels, out_channels, kernel_height, kernel_width] = weight.shape().dims;
|
||||
|
||||
let out_height =
|
||||
(in_height - 1) * stride_height + kernel_height + out_padding_height - 2 * padding_height;
|
||||
let out_height = (in_height - 1) * stride_height
|
||||
+ dilation_height * (kernel_height - 1)
|
||||
+ out_padding_height
|
||||
- 2 * padding_height
|
||||
+ 1;
|
||||
let out_width =
|
||||
(in_width - 1) * stride_width + kernel_width + out_padding_width - 2 * padding_width;
|
||||
(in_width - 1) * stride_width + dilation_width * (kernel_width - 1) + out_padding_width
|
||||
- 2 * padding_width
|
||||
+ 1;
|
||||
|
||||
let x = x.array;
|
||||
let mut output = Array4::zeros(Dim([batch_size, out_channels, out_height, out_width]));
|
||||
|
|
|
@ -75,8 +75,9 @@ impl<E: FloatNdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E>
|
|||
bias: Option<NdArrayTensor<E, 1>>,
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> NdArrayTensor<E, 4> {
|
||||
conv2d(x, weight, bias, stride, padding, [1, 1])
|
||||
conv2d(x, weight, bias, stride, padding, dilation)
|
||||
}
|
||||
|
||||
fn conv_transpose2d(
|
||||
|
@ -85,9 +86,10 @@ impl<E: FloatNdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E>
|
|||
bias: Option<NdArrayTensor<E, 1>>,
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
out_padding: [usize; 2],
|
||||
padding_out: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> NdArrayTensor<E, 4> {
|
||||
conv_transpose2d(x, weight, bias, stride, padding, out_padding, [1, 1])
|
||||
conv_transpose2d(x, weight, bias, stride, padding, padding_out, dilation)
|
||||
}
|
||||
|
||||
fn max_pool2d(
|
||||
|
|
|
@ -32,6 +32,7 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
bias: Option<TchTensor<E, 1>>,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
) -> TchTensor<E, 3> {
|
||||
let tensor = tch::Tensor::conv1d(
|
||||
&x.tensor,
|
||||
|
@ -39,7 +40,7 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
bias.map(|t| t.tensor),
|
||||
&[stride as i64],
|
||||
&[padding as i64],
|
||||
&[1],
|
||||
&[dilation as i64],
|
||||
1,
|
||||
);
|
||||
|
||||
|
@ -52,6 +53,7 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
bias: Option<TchTensor<E, 1>>,
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> TchTensor<E, 4> {
|
||||
let tensor = tch::Tensor::conv2d(
|
||||
&x.tensor,
|
||||
|
@ -59,7 +61,7 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
bias.map(|t| t.tensor),
|
||||
&[stride[0] as i64, stride[1] as i64],
|
||||
&[padding[0] as i64, padding[1] as i64],
|
||||
&[1, 1],
|
||||
&[dilation[0] as i64, dilation[1] as i64],
|
||||
1,
|
||||
);
|
||||
|
||||
|
@ -72,7 +74,8 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
bias: Option<TchTensor<E, 1>>,
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
out_padding: [usize; 2],
|
||||
padding_out: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> TchTensor<E, 4> {
|
||||
let tensor = tch::Tensor::conv_transpose2d(
|
||||
&x.tensor,
|
||||
|
@ -80,9 +83,9 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
bias.map(|t| t.tensor),
|
||||
&[stride[0] as i64, stride[1] as i64],
|
||||
&[padding[0] as i64, padding[1] as i64],
|
||||
&[out_padding[0] as i64, out_padding[1] as i64],
|
||||
&[padding_out[0] as i64, padding_out[1] as i64],
|
||||
1,
|
||||
&[1, 1],
|
||||
&[dilation[0] as i64, dilation[1] as i64],
|
||||
);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
|
@ -95,6 +98,7 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
stride: usize,
|
||||
padding: usize,
|
||||
padding_out: usize,
|
||||
dilation: usize,
|
||||
) -> TchTensor<E, 3> {
|
||||
let tensor = tch::Tensor::conv_transpose1d(
|
||||
&x.tensor,
|
||||
|
@ -104,7 +108,7 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
&[padding as i64],
|
||||
&[padding_out as i64],
|
||||
1,
|
||||
&[1],
|
||||
&[dilation as i64],
|
||||
);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
|
|
|
@ -15,6 +15,7 @@ pub fn conv1d<B>(
|
|||
bias: Option<Tensor<B, 1>>,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
) -> Tensor<B, 3>
|
||||
where
|
||||
B: Backend,
|
||||
|
@ -25,6 +26,7 @@ where
|
|||
bias.map(|b| b.primitive),
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
))
|
||||
}
|
||||
|
||||
|
@ -35,6 +37,7 @@ pub fn conv2d<B>(
|
|||
bias: Option<Tensor<B, 1>>,
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> Tensor<B, 4>
|
||||
where
|
||||
B: Backend,
|
||||
|
@ -45,6 +48,7 @@ where
|
|||
bias.map(|b| b.primitive),
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
))
|
||||
}
|
||||
|
||||
|
@ -56,6 +60,7 @@ pub fn conv_transpose1d<B>(
|
|||
stride: usize,
|
||||
padding: usize,
|
||||
padding_out: usize,
|
||||
dilation: usize,
|
||||
) -> Tensor<B, 3>
|
||||
where
|
||||
B: Backend,
|
||||
|
@ -67,6 +72,7 @@ where
|
|||
stride,
|
||||
padding,
|
||||
padding_out,
|
||||
dilation,
|
||||
))
|
||||
}
|
||||
|
||||
|
@ -77,7 +83,8 @@ pub fn conv_transpose2d<B>(
|
|||
bias: Option<Tensor<B, 1>>,
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
out_padding: [usize; 2],
|
||||
padding_out: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> Tensor<B, 4>
|
||||
where
|
||||
B: Backend,
|
||||
|
@ -88,7 +95,8 @@ where
|
|||
bias.map(|b| b.primitive),
|
||||
stride,
|
||||
padding,
|
||||
out_padding,
|
||||
padding_out,
|
||||
dilation,
|
||||
))
|
||||
}
|
||||
|
||||
|
|
|
@ -53,6 +53,7 @@ pub trait ModuleOps<B: Backend> {
|
|||
bias: Option<B::TensorPrimitive<1>>,
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> B::TensorPrimitive<4>;
|
||||
/// Two dimensional transposed convolution.
|
||||
///
|
||||
|
@ -67,7 +68,8 @@ pub trait ModuleOps<B: Backend> {
|
|||
bias: Option<B::TensorPrimitive<1>>,
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
out_padding: [usize; 2],
|
||||
padding_out: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> B::TensorPrimitive<4>;
|
||||
|
||||
/// Backward pass for the [conv2d](ModuleOps::conv2d) operation.
|
||||
|
@ -76,9 +78,11 @@ pub trait ModuleOps<B: Backend> {
|
|||
weight: B::TensorPrimitive<4>,
|
||||
bias: Option<B::TensorPrimitive<1>>,
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
output_grad: B::TensorPrimitive<4>,
|
||||
) -> Conv2dBackward<B> {
|
||||
conv::conv2d_backward(x, weight, bias, stride, output_grad)
|
||||
conv::conv2d_backward(x, weight, bias, stride, padding, dilation, output_grad)
|
||||
}
|
||||
/// One dimensional convolution.
|
||||
///
|
||||
|
@ -93,8 +97,9 @@ pub trait ModuleOps<B: Backend> {
|
|||
bias: Option<B::TensorPrimitive<1>>,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
) -> B::TensorPrimitive<3> {
|
||||
conv::conv1d_from_conv2d::<B>(x, weight, bias, stride, padding)
|
||||
conv::conv1d_from_conv2d::<B>(x, weight, bias, stride, padding, dilation)
|
||||
}
|
||||
/// One dimensional transposed convolution.
|
||||
///
|
||||
|
@ -110,6 +115,7 @@ pub trait ModuleOps<B: Backend> {
|
|||
stride: usize,
|
||||
padding: usize,
|
||||
padding_out: usize,
|
||||
dilation: usize,
|
||||
) -> B::TensorPrimitive<3> {
|
||||
conv::conv_transpose1d_from_conv_transpose2d::<B>(
|
||||
x,
|
||||
|
@ -118,6 +124,7 @@ pub trait ModuleOps<B: Backend> {
|
|||
stride,
|
||||
padding,
|
||||
padding_out,
|
||||
dilation,
|
||||
)
|
||||
}
|
||||
/// Backward pass for the [conv1d](ModuleOps::conv1d) operation.
|
||||
|
@ -126,9 +133,11 @@ pub trait ModuleOps<B: Backend> {
|
|||
weight: B::TensorPrimitive<3>,
|
||||
bias: Option<B::TensorPrimitive<1>>,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
output_grad: B::TensorPrimitive<3>,
|
||||
) -> Conv1dBackward<B> {
|
||||
conv::conv1d_backward(x, weight, bias, stride, output_grad)
|
||||
conv::conv1d_backward(x, weight, bias, stride, padding, dilation, output_grad)
|
||||
}
|
||||
/// Two dimensional max pooling.
|
||||
///
|
||||
|
|
|
@ -2,9 +2,8 @@ use super::{Conv1dBackward, Conv2dBackward};
|
|||
use crate::{backend::Backend, Shape};
|
||||
use libm::ceilf;
|
||||
|
||||
/// Calculate the expected padding size required when applying a convolution with the specified
|
||||
/// kernel size, stride, and input size to get the desired output size.
|
||||
pub fn calculate_padding(
|
||||
/// Calculate the expected padding size required when applying a convolution.
|
||||
pub fn calculate_conv_padding(
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
size_in: usize,
|
||||
|
@ -21,29 +20,22 @@ pub fn calculate_padding(
|
|||
padding as usize
|
||||
}
|
||||
|
||||
/// Calculate the expected output size when applying a convolution with the specified kernel size,
|
||||
/// stride and padding.
|
||||
pub fn calculate_output_size(
|
||||
/// Calculate the expected output size when doing a convolution operation.
|
||||
pub fn calculate_conv_output_size(
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
size_in: usize,
|
||||
) -> usize {
|
||||
let kernel_size = kernel_size as f32;
|
||||
let stride = stride as f32;
|
||||
let padding = padding as f32;
|
||||
let size_in = size_in as f32;
|
||||
|
||||
let size_out = (size_in + (2. * padding) - kernel_size) / stride;
|
||||
let size_out = ceilf(size_out + 1.);
|
||||
|
||||
size_out as usize
|
||||
(size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
|
||||
}
|
||||
|
||||
fn calculate_padding_out(
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
size_in: usize,
|
||||
size_out: usize,
|
||||
) -> usize {
|
||||
|
@ -51,8 +43,10 @@ fn calculate_padding_out(
|
|||
return 0;
|
||||
}
|
||||
|
||||
let out = calculate_output_size(kernel_size, stride, padding, size_out) as i64;
|
||||
i64::max(0, out - size_in as i64) as usize
|
||||
let out = 1 + libm::ceil(
|
||||
(size_in + 2 * padding - dilation * (kernel_size - 1) - 1) as f64 / stride as f64,
|
||||
) as usize;
|
||||
i64::max(0, out as i64 - size_out as i64) as usize
|
||||
}
|
||||
|
||||
/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass using convolutions.
|
||||
|
@ -61,14 +55,22 @@ pub(crate) fn conv1d_backward<B: Backend>(
|
|||
weight: B::TensorPrimitive<3>,
|
||||
bias: Option<B::TensorPrimitive<1>>,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
output_grad: B::TensorPrimitive<3>,
|
||||
) -> Conv1dBackward<B> {
|
||||
let [batch_size, channels_in, length_in] = B::shape(&x).dims;
|
||||
let [_batch_size, channels_out, length_out] = B::shape(&output_grad).dims;
|
||||
let [_, _, kernel_size] = B::shape(&weight).dims;
|
||||
|
||||
let padding = calculate_padding(kernel_size, stride, length_in, length_out);
|
||||
let padding_out = calculate_padding_out(kernel_size, stride, padding, length_out, length_in);
|
||||
let padding_out = calculate_padding_out(
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
length_in,
|
||||
length_out,
|
||||
);
|
||||
|
||||
let x_grad = B::conv_transpose1d(
|
||||
output_grad.clone(),
|
||||
|
@ -77,11 +79,19 @@ pub(crate) fn conv1d_backward<B: Backend>(
|
|||
stride,
|
||||
padding,
|
||||
padding_out,
|
||||
dilation,
|
||||
);
|
||||
|
||||
let x_swapped = B::swap_dims(x, 0, 1);
|
||||
let output_grad_swapped = B::swap_dims(output_grad, 0, 1);
|
||||
let weight_grad_swapped = B::conv1d(x_swapped, output_grad_swapped.clone(), None, 1, padding);
|
||||
let weight_grad_swapped = B::conv1d(
|
||||
x_swapped,
|
||||
output_grad_swapped.clone(),
|
||||
None,
|
||||
dilation,
|
||||
padding,
|
||||
stride,
|
||||
);
|
||||
let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1);
|
||||
|
||||
if B::shape(&weight_grad) != Shape::new([channels_out, channels_in, kernel_size]) {
|
||||
|
@ -110,28 +120,39 @@ pub(crate) fn conv2d_backward<B: Backend>(
|
|||
weight: B::TensorPrimitive<4>,
|
||||
bias: Option<B::TensorPrimitive<1>>,
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
output_grad: B::TensorPrimitive<4>,
|
||||
) -> Conv2dBackward<B> {
|
||||
let [batch_size, channels_in, height_in, width_in] = B::shape(&x).dims;
|
||||
let [_batch_size, channels_out, height_out, width_out] = B::shape(&output_grad).dims;
|
||||
let [_, _, kernel_size_1, kernel_size_2] = B::shape(&weight).dims;
|
||||
let [stride_1, stride_2] = stride;
|
||||
|
||||
let padding_1 = calculate_padding(kernel_size_1, stride_1, height_in, height_out);
|
||||
let padding_2 = calculate_padding(kernel_size_2, stride_2, width_in, width_out);
|
||||
|
||||
let padding_1_out =
|
||||
calculate_padding_out(kernel_size_1, stride_1, padding_1, height_out, height_in);
|
||||
let padding_2_out =
|
||||
calculate_padding_out(kernel_size_2, stride_2, padding_2, width_out, width_in);
|
||||
let padding_1_out = calculate_padding_out(
|
||||
kernel_size_1,
|
||||
stride[0],
|
||||
padding[0],
|
||||
dilation[0],
|
||||
height_in,
|
||||
height_out,
|
||||
);
|
||||
let padding_2_out = calculate_padding_out(
|
||||
kernel_size_2,
|
||||
stride[1],
|
||||
padding[1],
|
||||
dilation[1],
|
||||
width_in,
|
||||
width_out,
|
||||
);
|
||||
|
||||
let x_grad = B::conv_transpose2d(
|
||||
output_grad.clone(),
|
||||
weight,
|
||||
None,
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
stride,
|
||||
padding,
|
||||
[padding_1_out, padding_2_out],
|
||||
dilation,
|
||||
);
|
||||
|
||||
let x_swapped = B::swap_dims(x, 0, 1);
|
||||
|
@ -140,8 +161,9 @@ pub(crate) fn conv2d_backward<B: Backend>(
|
|||
x_swapped,
|
||||
output_grad_swapped.clone(),
|
||||
None,
|
||||
[1, 1],
|
||||
[padding_1, padding_2],
|
||||
dilation,
|
||||
padding,
|
||||
stride,
|
||||
);
|
||||
let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1);
|
||||
|
||||
|
@ -182,6 +204,7 @@ pub(crate) fn conv1d_from_conv2d<B: Backend>(
|
|||
bias: Option<B::TensorPrimitive<1>>,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
) -> B::TensorPrimitive<3> {
|
||||
let [channels_out, _channels_in, kernel_size] = B::shape(&weight).dims;
|
||||
let [batch_size, channels_in, length_in] = B::shape(&x).dims;
|
||||
|
@ -192,7 +215,7 @@ pub(crate) fn conv1d_from_conv2d<B: Backend>(
|
|||
);
|
||||
let x = B::reshape(x, Shape::new([batch_size, channels_in, length_in, 1]));
|
||||
|
||||
let tensor = B::conv2d(x, weight, bias, [stride, 1], [padding, 0]);
|
||||
let tensor = B::conv2d(x, weight, bias, [stride, 1], [padding, 0], [dilation, 1]);
|
||||
let [batch_size, channels_out, height_out, _weight_out] = B::shape(&tensor).dims;
|
||||
B::reshape(tensor, Shape::from([batch_size, channels_out, height_out]))
|
||||
}
|
||||
|
@ -205,6 +228,7 @@ pub(crate) fn conv_transpose1d_from_conv_transpose2d<B: Backend>(
|
|||
stride: usize,
|
||||
padding: usize,
|
||||
padding_out: usize,
|
||||
dilation: usize,
|
||||
) -> B::TensorPrimitive<3> {
|
||||
let [channels_in, channels_out, kernel_size] = B::shape(&weight).dims;
|
||||
let [batch_size, _channels_in, length_in] = B::shape(&x).dims;
|
||||
|
@ -215,7 +239,15 @@ pub(crate) fn conv_transpose1d_from_conv_transpose2d<B: Backend>(
|
|||
);
|
||||
let x = B::reshape(x, Shape::new([batch_size, channels_in, length_in, 1]));
|
||||
|
||||
let tensor = B::conv_transpose2d(x, weight, bias, [stride, 1], [padding, 0], [padding_out, 0]);
|
||||
let tensor = B::conv_transpose2d(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
[stride, 1],
|
||||
[padding, 0],
|
||||
[padding_out, 0],
|
||||
[dilation, 1],
|
||||
);
|
||||
let [batch_size, channels_out, height_out, _weight_out] = B::shape(&tensor).dims;
|
||||
B::reshape(tensor, Shape::from([batch_size, channels_out, height_out]))
|
||||
}
|
||||
|
@ -230,8 +262,9 @@ mod tests {
|
|||
let stride = 1;
|
||||
let padding = 1;
|
||||
let size_in = 3;
|
||||
let dilation = 1;
|
||||
|
||||
let size_out = calculate_output_size(kernel_size, stride, padding, size_in);
|
||||
let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);
|
||||
|
||||
assert_eq!(size_out, 3);
|
||||
}
|
||||
|
@ -242,20 +275,35 @@ mod tests {
|
|||
let stride = 2;
|
||||
let padding = 3;
|
||||
let size_in = 27;
|
||||
let dilation = 1;
|
||||
|
||||
let size_out = calculate_output_size(kernel_size, stride, padding, size_in);
|
||||
let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);
|
||||
|
||||
assert_eq!(size_out, 15);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_calculate_output_size_3() {
|
||||
let kernel_size = 5;
|
||||
let stride = 2;
|
||||
let padding = 3;
|
||||
let size_in = 27;
|
||||
let dilation = 2;
|
||||
|
||||
let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);
|
||||
|
||||
assert_eq!(size_out, 13);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_calculate_same_padding_1() {
|
||||
let kernel_size = 3;
|
||||
let stride = 1;
|
||||
let size_in = 3;
|
||||
let dilation = 1;
|
||||
|
||||
let padding = calculate_padding(kernel_size, stride, size_in, size_in);
|
||||
let size_out = calculate_output_size(kernel_size, stride, padding, size_in);
|
||||
let padding = calculate_conv_padding(kernel_size, stride, size_in, size_in);
|
||||
let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);
|
||||
|
||||
assert_eq!(size_in, size_out, "Expected size");
|
||||
}
|
||||
|
@ -265,9 +313,10 @@ mod tests {
|
|||
let kernel_size = 3;
|
||||
let stride = 2;
|
||||
let size_in = 7;
|
||||
let dilation = 1;
|
||||
|
||||
let padding = calculate_padding(kernel_size, stride, size_in, size_in);
|
||||
let size_out = calculate_output_size(kernel_size, stride, padding, size_in);
|
||||
let padding = calculate_conv_padding(kernel_size, stride, size_in, size_in);
|
||||
let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);
|
||||
|
||||
assert_eq!(size_in, size_out, "Expected size");
|
||||
}
|
||||
|
@ -278,9 +327,11 @@ mod tests {
|
|||
let stride = 2;
|
||||
let size_in = 7;
|
||||
let size_out = 10;
|
||||
let dilation = 1;
|
||||
|
||||
let padding = calculate_padding(kernel_size, stride, size_in, size_out);
|
||||
let size_out_expected = calculate_output_size(kernel_size, stride, padding, size_in);
|
||||
let padding = calculate_conv_padding(kernel_size, stride, size_in, size_out);
|
||||
let size_out_expected =
|
||||
calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);
|
||||
|
||||
assert_eq!(size_out, size_out_expected, "Expected size");
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ mod tests {
|
|||
kernel_size: 3,
|
||||
padding: 1,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
length: 6,
|
||||
};
|
||||
|
||||
|
@ -30,6 +31,25 @@ mod tests {
|
|||
]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv1d_dilation() {
|
||||
let test = Conv1dTestCase {
|
||||
batch_size: 2,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
stride: 1,
|
||||
dilation: 2,
|
||||
length: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats([
|
||||
[[5., 5.], [5., 5.]],
|
||||
[[5., 5.], [5., 5.]],
|
||||
]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv1d_complex() {
|
||||
let test = Conv1dTestCase {
|
||||
|
@ -39,6 +59,7 @@ mod tests {
|
|||
kernel_size: 3,
|
||||
padding: 1,
|
||||
stride: 2,
|
||||
dilation: 1,
|
||||
length: 9,
|
||||
};
|
||||
|
||||
|
@ -65,6 +86,7 @@ mod tests {
|
|||
kernel_size: usize,
|
||||
padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
length: usize,
|
||||
}
|
||||
|
||||
|
@ -73,7 +95,14 @@ mod tests {
|
|||
let weights = TestTensor::ones([self.channels_out, self.channels_in, self.kernel_size]);
|
||||
let bias = TestTensor::ones([self.channels_out]);
|
||||
let x = TestTensor::ones([self.batch_size, self.channels_in, self.length]);
|
||||
let output = conv1d(x, weights, Some(bias), self.stride, self.padding);
|
||||
let output = conv1d(
|
||||
x,
|
||||
weights,
|
||||
Some(bias),
|
||||
self.stride,
|
||||
self.padding,
|
||||
self.dilation,
|
||||
);
|
||||
|
||||
y.to_data().assert_approx_eq(&output.into_data(), 3);
|
||||
}
|
||||
|
|
|
@ -16,6 +16,8 @@ mod tests {
|
|||
padding_2: 1,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
height: 6,
|
||||
width: 6,
|
||||
};
|
||||
|
@ -88,62 +90,24 @@ mod tests {
|
|||
padding_2: 2,
|
||||
stride_1: 2,
|
||||
stride_2: 3,
|
||||
height: 7,
|
||||
width: 9,
|
||||
dilation_1: 1,
|
||||
dilation_2: 2,
|
||||
height: 4,
|
||||
width: 5,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats([
|
||||
[
|
||||
[
|
||||
[1., 13., 13., 13.],
|
||||
[1., 19., 19., 19.],
|
||||
[1., 19., 19., 19.],
|
||||
[1., 13., 13., 13.],
|
||||
[[7., 13., 7.], [10., 19., 10.]],
|
||||
[[7., 13., 7.], [10., 19., 10.]],
|
||||
[[7., 13., 7.], [10., 19., 10.]],
|
||||
[[7., 13., 7.], [10., 19., 10.]],
|
||||
],
|
||||
[
|
||||
[1., 13., 13., 13.],
|
||||
[1., 19., 19., 19.],
|
||||
[1., 19., 19., 19.],
|
||||
[1., 13., 13., 13.],
|
||||
],
|
||||
[
|
||||
[1., 13., 13., 13.],
|
||||
[1., 19., 19., 19.],
|
||||
[1., 19., 19., 19.],
|
||||
[1., 13., 13., 13.],
|
||||
],
|
||||
[
|
||||
[1., 13., 13., 13.],
|
||||
[1., 19., 19., 19.],
|
||||
[1., 19., 19., 19.],
|
||||
[1., 13., 13., 13.],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[1., 13., 13., 13.],
|
||||
[1., 19., 19., 19.],
|
||||
[1., 19., 19., 19.],
|
||||
[1., 13., 13., 13.],
|
||||
],
|
||||
[
|
||||
[1., 13., 13., 13.],
|
||||
[1., 19., 19., 19.],
|
||||
[1., 19., 19., 19.],
|
||||
[1., 13., 13., 13.],
|
||||
],
|
||||
[
|
||||
[1., 13., 13., 13.],
|
||||
[1., 19., 19., 19.],
|
||||
[1., 19., 19., 19.],
|
||||
[1., 13., 13., 13.],
|
||||
],
|
||||
[
|
||||
[1., 13., 13., 13.],
|
||||
[1., 19., 19., 19.],
|
||||
[1., 19., 19., 19.],
|
||||
[1., 13., 13., 13.],
|
||||
],
|
||||
[[7., 13., 7.], [10., 19., 10.]],
|
||||
[[7., 13., 7.], [10., 19., 10.]],
|
||||
[[7., 13., 7.], [10., 19., 10.]],
|
||||
[[7., 13., 7.], [10., 19., 10.]],
|
||||
],
|
||||
]));
|
||||
}
|
||||
|
@ -158,6 +122,8 @@ mod tests {
|
|||
padding_2: usize,
|
||||
stride_1: usize,
|
||||
stride_2: usize,
|
||||
dilation_1: usize,
|
||||
dilation_2: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
}
|
||||
|
@ -178,6 +144,7 @@ mod tests {
|
|||
Some(bias),
|
||||
[self.stride_1, self.stride_2],
|
||||
[self.padding_1, self.padding_2],
|
||||
[self.dilation_1, self.dilation_2],
|
||||
);
|
||||
|
||||
y.to_data().assert_approx_eq(&output.into_data(), 3);
|
||||
|
|
|
@ -14,6 +14,7 @@ mod tests {
|
|||
padding: 1,
|
||||
padding_out: 0,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
length: 4,
|
||||
};
|
||||
|
||||
|
@ -33,6 +34,7 @@ mod tests {
|
|||
padding: 1,
|
||||
padding_out: 1,
|
||||
stride: 2,
|
||||
dilation: 1,
|
||||
length: 4,
|
||||
};
|
||||
|
||||
|
@ -42,6 +44,26 @@ mod tests {
|
|||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose1d_dilation() {
|
||||
let test = ConvTranspose1dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
padding_out: 0,
|
||||
stride: 1,
|
||||
dilation: 2,
|
||||
length: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats([[
|
||||
[30., 64., 78., 76., 94., 52.],
|
||||
[49., 101., 127., 113., 143., 77.],
|
||||
]]));
|
||||
}
|
||||
|
||||
struct ConvTranspose1dTestCase {
|
||||
batch_size: usize,
|
||||
channels_in: usize,
|
||||
|
@ -50,6 +72,7 @@ mod tests {
|
|||
padding: usize,
|
||||
padding_out: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
length: usize,
|
||||
}
|
||||
|
||||
|
@ -81,6 +104,7 @@ mod tests {
|
|||
self.stride,
|
||||
self.padding,
|
||||
self.padding_out,
|
||||
self.dilation,
|
||||
);
|
||||
|
||||
y.to_data().assert_approx_eq(&output.into_data(), 3);
|
||||
|
|
|
@ -18,6 +18,8 @@ mod tests {
|
|||
padding_out_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
height: 2,
|
||||
width: 2,
|
||||
};
|
||||
|
@ -38,6 +40,8 @@ mod tests {
|
|||
padding_out_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
|
@ -78,6 +82,8 @@ mod tests {
|
|||
padding_out_2: 0,
|
||||
stride_1: 2,
|
||||
stride_2: 2,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
height: 2,
|
||||
width: 2,
|
||||
};
|
||||
|
@ -90,6 +96,44 @@ mod tests {
|
|||
]]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose2d_dilation_2() {
|
||||
let test = ConvTranspose2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 1,
|
||||
padding_2: 1,
|
||||
padding_out_1: 1,
|
||||
padding_out_2: 1,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 2,
|
||||
dilation_2: 2,
|
||||
height: 2,
|
||||
width: 2,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats([[
|
||||
[
|
||||
[126., 116., 136., 124., 146.],
|
||||
[108., 88., 114., 92., 120.],
|
||||
[156., 140., 166., 148., 176.],
|
||||
[126., 100., 132., 104., 138.],
|
||||
[186., 164., 196., 172., 206.],
|
||||
],
|
||||
[
|
||||
[217., 189., 227., 197., 237.],
|
||||
[163., 125., 169., 129., 175.],
|
||||
[247., 213., 257., 221., 267.],
|
||||
[181., 137., 187., 141., 193.],
|
||||
[277., 237., 287., 245., 297.],
|
||||
],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose2d_stride2_out_padding() {
|
||||
let test = ConvTranspose2dTestCase {
|
||||
|
@ -104,6 +148,8 @@ mod tests {
|
|||
padding_out_2: 1,
|
||||
stride_1: 2,
|
||||
stride_2: 2,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
|
@ -144,6 +190,8 @@ mod tests {
|
|||
padding_out_2: usize,
|
||||
stride_1: usize,
|
||||
stride_2: usize,
|
||||
dilation_1: usize,
|
||||
dilation_2: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
}
|
||||
|
@ -181,6 +229,7 @@ mod tests {
|
|||
[self.stride_1, self.stride_2],
|
||||
[self.padding_1, self.padding_2],
|
||||
[self.padding_out_1, self.padding_out_2],
|
||||
[self.dilation_1, self.dilation_2],
|
||||
);
|
||||
|
||||
y.to_data().assert_approx_eq(&output.into_data(), 3);
|
||||
|
|
Binary file not shown.
Loading…
Reference in New Issue