Mirror of https://github.com/tracel-ai/burn.git

Feat/group conv (#306)

This commit is contained in:
parent 78ac09fb7a
commit c5e31b272f
@@ -50,9 +50,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
        x: ADTensor<B, 4>,
        weight: ADTensor<B, 4>,
        bias: Option<ADTensor<B, 1>>,
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        options: ConvOptions<2>,
    ) -> ADTensor<B, 4> {
        #[derive(Debug)]
        struct Conv2DWithBias;
|
@ -64,18 +62,15 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
B::TensorPrimitive<4>,
|
||||
B::TensorPrimitive<4>,
|
||||
B::TensorPrimitive<1>,
|
||||
[usize; 2],
|
||||
[usize; 2],
|
||||
[usize; 2],
|
||||
ConvOptions<2>,
|
||||
);
|
||||
|
||||
fn backward(self, ops: Ops<Self::State, 3>, grads: &mut Gradients) {
|
||||
let [node_x, node_weight, node_bias] = ops.parents;
|
||||
let grad = grads.consume::<B, 4>(&ops.node);
|
||||
|
||||
let (x, weight, bias, stride, padding, dilation) = ops.state;
|
||||
let backward =
|
||||
B::conv2d_backward(x, weight, Some(bias), stride, padding, dilation, grad);
|
||||
let (x, weight, bias, options) = ops.state;
|
||||
let backward = B::conv2d_backward(x, weight, Some(bias), grad, options);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 4>(node, backward.x_grad)
|
||||
|
@ -90,20 +85,14 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
}
|
||||
|
||||
impl<B: Backend> Backward<B, 4, 2> for Conv2DNoBias {
|
||||
type State = (
|
||||
B::TensorPrimitive<4>,
|
||||
B::TensorPrimitive<4>,
|
||||
[usize; 2],
|
||||
[usize; 2],
|
||||
[usize; 2],
|
||||
);
|
||||
type State = (B::TensorPrimitive<4>, B::TensorPrimitive<4>, ConvOptions<2>);
|
||||
|
||||
fn backward(self, ops: Ops<Self::State, 2>, grads: &mut Gradients) {
|
||||
let [node_x, node_weight] = ops.parents;
|
||||
let grad = grads.consume::<B, 4>(&ops.node);
|
||||
|
||||
let (x, weight, stride, padding, dilation) = ops.state;
|
||||
let backward = B::conv2d_backward(x, weight, None, stride, padding, dilation, grad);
|
||||
let (x, weight, options) = ops.state;
|
||||
let backward = B::conv2d_backward(x, weight, None, grad, options);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 4>(node, backward.x_grad)
|
||||
|
@ -128,26 +117,15 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
x.primitive.clone(),
|
||||
weight.primitive.clone(),
|
||||
bias.primitive.clone(),
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
),
|
||||
B::conv2d(
|
||||
x.primitive,
|
||||
weight.primitive,
|
||||
Some(bias.primitive),
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
options.clone(),
|
||||
),
|
||||
B::conv2d(x.primitive, weight.primitive, Some(bias.primitive), options),
|
||||
),
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::conv2d(
|
||||
x.primitive,
|
||||
weight.primitive,
|
||||
Some(bias.primitive),
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
options,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
@ -160,27 +138,13 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
(
|
||||
x.primitive.clone(),
|
||||
weight.primitive.clone(),
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
),
|
||||
B::conv2d(
|
||||
x.primitive,
|
||||
weight.primitive,
|
||||
None,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
options.clone(),
|
||||
),
|
||||
B::conv2d(x.primitive, weight.primitive, None, options),
|
||||
),
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::conv2d(
|
||||
x.primitive,
|
||||
weight.primitive,
|
||||
None,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
)),
|
||||
OpsKind::UnTracked(prep) => {
|
||||
prep.finish(B::conv2d(x.primitive, weight.primitive, None, options))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -190,33 +154,16 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
_x: ADTensor<B, 4>,
|
||||
_weight: ADTensor<B, 4>,
|
||||
_bias: Option<ADTensor<B, 1>>,
|
||||
_stride: [usize; 2],
|
||||
_padding: [usize; 2],
|
||||
_padding_out: [usize; 2],
|
||||
_dilation: [usize; 2],
|
||||
_options: ConvTransposeOptions<2>,
|
||||
) -> ADTensor<B, 4> {
|
||||
todo!("Transposed 2D convolution doesn't yet support backward.");
|
||||
}
|
||||
|
||||
fn conv_transpose1d(
|
||||
_x: ADTensor<B, 3>,
|
||||
_weight: ADTensor<B, 3>,
|
||||
_bias: Option<ADTensor<B, 1>>,
|
||||
_stride: usize,
|
||||
_padding: usize,
|
||||
_padding_out: usize,
|
||||
_dilation: usize,
|
||||
) -> ADTensor<B, 3> {
|
||||
todo!("Transposed 1D convolution doesn't yet support backward.");
|
||||
}
|
||||
|
||||
fn conv1d(
|
||||
x: ADTensor<B, 3>,
|
||||
weight: ADTensor<B, 3>,
|
||||
bias: Option<ADTensor<B, 1>>,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
options: ConvOptions<1>,
|
||||
) -> ADTensor<B, 3> {
|
||||
#[derive(Debug)]
|
||||
struct Conv1DWithBias;
|
||||
|
@ -228,18 +175,15 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
B::TensorPrimitive<3>,
|
||||
B::TensorPrimitive<3>,
|
||||
B::TensorPrimitive<1>,
|
||||
usize,
|
||||
usize,
|
||||
usize,
|
||||
ConvOptions<1>,
|
||||
);
|
||||
|
||||
fn backward(self, ops: Ops<Self::State, 3>, grads: &mut Gradients) {
|
||||
let [node_x, node_weight, node_bias] = ops.parents;
|
||||
let grad = grads.consume::<B, 3>(&ops.node);
|
||||
|
||||
let (x, weight, bias, stride, padding, dilation) = ops.state;
|
||||
let backward =
|
||||
B::conv1d_backward(x, weight, Some(bias), stride, padding, dilation, grad);
|
||||
let (x, weight, bias, options) = ops.state;
|
||||
let backward = B::conv1d_backward(x, weight, Some(bias), grad, options);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 3>(node, backward.x_grad)
|
||||
|
@ -254,20 +198,14 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
}
|
||||
|
||||
impl<B: Backend> Backward<B, 3, 2> for Conv1DNoBias {
|
||||
type State = (
|
||||
B::TensorPrimitive<3>,
|
||||
B::TensorPrimitive<3>,
|
||||
usize,
|
||||
usize,
|
||||
usize,
|
||||
);
|
||||
type State = (B::TensorPrimitive<3>, B::TensorPrimitive<3>, ConvOptions<1>);
|
||||
|
||||
fn backward(self, ops: Ops<Self::State, 2>, grads: &mut Gradients) {
|
||||
let [node_x, node_weight] = ops.parents;
|
||||
let grad = grads.consume::<B, 3>(&ops.node);
|
||||
|
||||
let (x, weight, stride, padding, dilation) = ops.state;
|
||||
let backward = B::conv1d_backward(x, weight, None, stride, padding, dilation, grad);
|
||||
let (x, weight, options) = ops.state;
|
||||
let backward = B::conv1d_backward(x, weight, None, grad, options);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 3>(node, backward.x_grad)
|
||||
|
@ -291,26 +229,15 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
x.primitive.clone(),
|
||||
weight.primitive.clone(),
|
||||
bias.primitive.clone(),
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
),
|
||||
B::conv1d(
|
||||
x.primitive,
|
||||
weight.primitive,
|
||||
Some(bias.primitive),
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
options.clone(),
|
||||
),
|
||||
B::conv1d(x.primitive, weight.primitive, Some(bias.primitive), options),
|
||||
),
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::conv1d(
|
||||
x.primitive,
|
||||
weight.primitive,
|
||||
Some(bias.primitive),
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
options,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
@ -323,32 +250,27 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
(
|
||||
x.primitive.clone(),
|
||||
weight.primitive.clone(),
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
),
|
||||
B::conv1d(
|
||||
x.primitive,
|
||||
weight.primitive,
|
||||
None,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
options.clone(),
|
||||
),
|
||||
B::conv1d(x.primitive, weight.primitive, None, options),
|
||||
),
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::conv1d(
|
||||
x.primitive,
|
||||
weight.primitive,
|
||||
None,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
)),
|
||||
OpsKind::UnTracked(prep) => {
|
||||
prep.finish(B::conv1d(x.primitive, weight.primitive, None, options))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn conv_transpose1d(
|
||||
_x: ADTensor<B, 3>,
|
||||
_weight: ADTensor<B, 3>,
|
||||
_bias: Option<ADTensor<B, 1>>,
|
||||
_options: ConvTransposeOptions<1>,
|
||||
) -> ADTensor<B, 3> {
|
||||
todo!("Transposed 1D convolution doesn't yet support backward.");
|
||||
}
|
||||
|
||||
fn max_pool2d(
|
||||
x: ADTensor<B, 4>,
|
||||
kernel_size: [usize; 2],
|
||||
|
|
|
@ -1,39 +1,31 @@
|
|||
#[burn_tensor_testgen::testgen(ad_conv1d)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::{module::conv1d, Data};
|
||||
use burn_tensor::{module::conv1d, ops::ConvOptions, Data, Shape};
|
||||
|
||||
#[test]
|
||||
fn test_conv1d_basic() {
|
||||
let test = Conv1dTestCase {
|
||||
batch_size: 2,
|
||||
channels_in: 3,
|
||||
channels_out: 3,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
length: 6,
|
||||
groups: 1,
|
||||
length: 4,
|
||||
};
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats([
|
||||
[
|
||||
[6., 9., 9., 9., 9., 6.],
|
||||
[6., 9., 9., 9., 9., 6.],
|
||||
[6., 9., 9., 9., 9., 6.],
|
||||
],
|
||||
[
|
||||
[6., 9., 9., 9., 9., 6.],
|
||||
[6., 9., 9., 9., 9., 6.],
|
||||
[6., 9., 9., 9., 9., 6.],
|
||||
],
|
||||
[[14., 24., 24., 18.], [26., 42., 42., 30.]],
|
||||
[[14., 24., 24., 18.], [26., 42., 42., 30.]],
|
||||
]),
|
||||
weight: TestTensor::from_floats([
|
||||
[[10., 12., 10.], [10., 12., 10.], [10., 12., 10.]],
|
||||
[[10., 12., 10.], [10., 12., 10.], [10., 12., 10.]],
|
||||
[[10., 12., 10.], [10., 12., 10.], [10., 12., 10.]],
|
||||
[[30., 44., 36.], [54., 76., 60.]],
|
||||
[[30., 44., 36.], [54., 76., 60.]],
|
||||
]),
|
||||
bias: TestTensor::from_floats([12., 12., 12.]),
|
||||
bias: TestTensor::from_floats([8., 8.]),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
@ -48,19 +40,20 @@ mod tests {
|
|||
padding: 1,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
length: 6,
|
||||
groups: 1,
|
||||
length: 4,
|
||||
};
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats([
|
||||
[[6., 9., 9., 9., 9., 6.], [6., 9., 9., 9., 9., 6.]],
|
||||
[[6., 9., 9., 9., 9., 6.], [6., 9., 9., 9., 9., 6.]],
|
||||
[[39., 63., 63., 45.], [57., 90., 90., 63.]],
|
||||
[[39., 63., 63., 45.], [57., 90., 90., 63.]],
|
||||
]),
|
||||
weight: TestTensor::from_floats([
|
||||
[[10., 12., 10.], [10., 12., 10.]],
|
||||
[[10., 12., 10.], [10., 12., 10.]],
|
||||
[[10., 12., 10.], [10., 12., 10.]],
|
||||
[[30., 44., 36.], [54., 76., 60.]],
|
||||
[[30., 44., 36.], [54., 76., 60.]],
|
||||
[[30., 44., 36.], [54., 76., 60.]],
|
||||
]),
|
||||
bias: TestTensor::from_floats([12., 12., 12.]),
|
||||
bias: TestTensor::from_floats([8., 8., 8.]),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
@ -75,18 +68,19 @@ mod tests {
|
|||
padding: 2,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
length: 6,
|
||||
groups: 1,
|
||||
length: 4,
|
||||
};
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats([
|
||||
[[6., 6., 6., 6., 6., 6.], [6., 6., 6., 6., 6., 6.]],
|
||||
[[6., 6., 6., 6., 6., 6.], [6., 6., 6., 6., 6., 6.]],
|
||||
[[24., 24., 24., 24.], [42., 42., 42., 42.]],
|
||||
[[24., 24., 24., 24.], [42., 42., 42., 42.]],
|
||||
]),
|
||||
weight: TestTensor::from_floats([
|
||||
[[12., 12., 12.], [12., 12., 12.]],
|
||||
[[12., 12., 12.], [12., 12., 12.]],
|
||||
[[44., 44., 44.], [76., 76., 76.]],
|
||||
[[44., 44., 44.], [76., 76., 76.]],
|
||||
]),
|
||||
bias: TestTensor::from_floats([16., 16.]),
|
||||
bias: TestTensor::from_floats([12., 12.]),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
@ -101,16 +95,17 @@ mod tests {
|
|||
padding: 1,
|
||||
stride: 2,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
length: 4,
|
||||
};
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats([
|
||||
[[2., 4., 2., 2.], [2., 4., 2., 2.]],
|
||||
[[2., 4., 2., 2.], [2., 4., 2., 2.]],
|
||||
[[8., 16., 8., 10.], [14., 28., 14., 16.]],
|
||||
[[8., 16., 8., 10.], [14., 28., 14., 16.]],
|
||||
]),
|
||||
weight: TestTensor::from_floats([
|
||||
[[2., 4., 4.], [2., 4., 4.]],
|
||||
[[2., 4., 4.], [2., 4., 4.]],
|
||||
[[10., 20., 24.], [18., 36., 40.]],
|
||||
[[10., 20., 24.], [18., 36., 40.]],
|
||||
]),
|
||||
bias: TestTensor::from_floats([4., 4.]),
|
||||
};
|
||||
|
@ -127,22 +122,47 @@ mod tests {
|
|||
padding: 1,
|
||||
stride: 1,
|
||||
dilation: 2,
|
||||
groups: 1,
|
||||
length: 4,
|
||||
};
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats([
|
||||
[[2., 2., 2., 2.], [2., 2., 2., 2.]],
|
||||
[[2., 2., 2., 2.], [2., 2., 2., 2.]],
|
||||
[[6., 8., 8., 10.], [12., 14., 14., 16.]],
|
||||
[[6., 8., 8., 10.], [12., 14., 14., 16.]],
|
||||
]),
|
||||
weight: TestTensor::from_floats([
|
||||
[[2., 4., 2.], [2., 4., 2.]],
|
||||
[[2., 4., 2.], [2., 4., 2.]],
|
||||
[[8., 22., 14.], [16., 38., 22.]],
|
||||
[[8., 22., 14.], [16., 38., 22.]],
|
||||
]),
|
||||
bias: TestTensor::from_floats([4., 4.]),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv1d_groups() {
|
||||
let test = Conv1dTestCase {
|
||||
batch_size: 2,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 2,
|
||||
length: 4,
|
||||
};
|
||||
let grads = Grads {
|
||||
x: TestTensor::from_floats([
|
||||
[[1., 3., 3., 3.], [7., 12., 12., 9.]],
|
||||
[[1., 3., 3., 3.], [7., 12., 12., 9.]],
|
||||
]),
|
||||
weight: TestTensor::from_floats([[[30., 44., 36.]], [[54., 76., 60.]]]),
|
||||
bias: TestTensor::from_floats([8., 8.]),
|
||||
};
|
||||
test.assert_grads(grads);
|
||||
}
|
||||
|
||||
struct Conv1dTestCase {
|
||||
batch_size: usize,
|
||||
channels_in: usize,
|
||||
|
@ -151,6 +171,7 @@ mod tests {
|
|||
padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
length: usize,
|
||||
}
|
||||
|
||||
|
@ -162,19 +183,38 @@ mod tests {
|
|||
|
||||
impl Conv1dTestCase {
|
||||
fn assert_grads(self, expected_grads: Grads) {
|
||||
let weight =
|
||||
TestADTensor::ones([self.channels_out, self.channels_in, self.kernel_size])
|
||||
.require_grad();
|
||||
let bias = TestADTensor::ones([self.channels_out]).require_grad();
|
||||
let x =
|
||||
TestADTensor::ones([self.batch_size, self.channels_in, self.length]).require_grad();
|
||||
let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]);
|
||||
let shape_weight = Shape::new([
|
||||
self.channels_out,
|
||||
self.channels_in / self.groups,
|
||||
self.kernel_size,
|
||||
]);
|
||||
let weight = TestADTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_weight.num_elements())
|
||||
.reshape(shape_weight)
|
||||
.into_data()
|
||||
.convert(),
|
||||
)
|
||||
.require_grad();
|
||||
let bias = TestADTensor::from_data(
|
||||
TestTensorInt::arange(0..self.channels_out)
|
||||
.into_data()
|
||||
.convert(),
|
||||
)
|
||||
.require_grad();
|
||||
let x = TestADTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements())
|
||||
.reshape(shape_x)
|
||||
.into_data()
|
||||
.convert(),
|
||||
)
|
||||
.require_grad();
|
||||
|
||||
let output = conv1d(
|
||||
x.clone(),
|
||||
weight.clone(),
|
||||
Some(bias.clone()),
|
||||
self.stride,
|
||||
self.padding,
|
||||
self.dilation,
|
||||
ConvOptions::new([self.stride], [self.padding], [self.dilation], self.groups),
|
||||
);
|
||||
let grads = output.backward();
|
||||
|
||||
|
|
File diff suppressed because it is too large
|
@@ -8,6 +8,7 @@ use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use burn_tensor::module::conv1d;
use burn_tensor::ops::conv::calculate_conv_padding;
use burn_tensor::ops::ConvOptions;

use libm::sqrt;

@@ -26,6 +27,9 @@ pub struct Conv1dConfig {
    /// Spacing between kernel elements.
    #[config(default = "1")]
    pub dilation: usize,
    /// Controls the connections between input and output channels.
    #[config(default = "1")]
    pub groups: usize,
    /// The padding configuration.
    #[config(default = "Conv1dPaddingConfig::Valid")]
    pub padding: Conv1dPaddingConfig,

@@ -65,6 +69,7 @@ pub struct Conv1d<B: Backend> {
    stride: usize,
    kernel_size: usize,
    dilation: usize,
    groups: usize,
    padding: Conv1dPaddingConfig,
}

@@ -95,6 +100,7 @@ impl Conv1dConfig {
            kernel_size: self.kernel_size,
            padding: self.padding.clone(),
            dilation: self.dilation,
            groups: self.groups,
        }
    }
    /// Initialize a new [conv1d](Conv1d) module with a [record](Conv1dRecord).

@@ -106,6 +112,7 @@ impl Conv1dConfig {
            kernel_size: self.kernel_size,
            padding: self.padding.clone(),
            dilation: self.dilation,
            groups: self.groups,
        }
    }
}

@@ -133,9 +140,7 @@ impl<B: Backend> Conv1d<B> {
            input,
            self.weight.val(),
            self.bias.as_ref().map(|bias| bias.val()),
            self.stride,
            padding,
            self.dilation,
            ConvOptions::new([self.stride], [padding], [self.dilation], self.groups),
        )
    }
}
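For orientation, a minimal sketch (not part of the commit) of reaching the functional API that this module now delegates to. The NdArrayBackend alias is only an assumption for illustration; any type implementing Backend works the same way, and the weight layout [channels_out, channels_in / groups, kernel_size] follows the tests further down.

use burn_tensor::{module::conv1d, ops::ConvOptions, Tensor};
use burn_ndarray::NdArrayBackend; // assumed backend, purely for illustration

type B = NdArrayBackend<f32>;

fn grouped_conv1d_example() -> Tensor<B, 3> {
    let x = Tensor::<B, 3>::ones([2, 4, 8]); // [batch, channels_in, length]
    // With groups = 2, each filter only spans channels_in / groups = 2 input channels.
    let weight = Tensor::<B, 3>::ones([4, 2, 3]); // [channels_out, channels_in / groups, kernel_size]
    let bias = Tensor::<B, 1>::ones([4]);

    // Stride, padding, dilation and groups now travel together as ConvOptions<1>.
    conv1d(x, weight, Some(bias), ConvOptions::new([1], [1], [1], 2))
}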
@@ -8,6 +8,7 @@ use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use burn_tensor::module::conv2d;
use burn_tensor::ops::conv::calculate_conv_padding;
use burn_tensor::ops::ConvOptions;

use libm::sqrt;

@@ -24,6 +25,9 @@ pub struct Conv2dConfig {
    /// Spacing between kernel elements.
    #[config(default = "[1, 1]")]
    pub dilation: [usize; 2],
    /// Controls the connections between input and output channels.
    #[config(default = "1")]
    pub groups: usize,
    /// The padding configuration.
    #[config(default = "Conv2dPaddingConfig::Valid")]
    pub padding: Conv2dPaddingConfig,

@@ -63,6 +67,7 @@ pub struct Conv2d<B: Backend> {
    stride: [usize; 2],
    kernel_size: [usize; 2],
    dilation: [usize; 2],
    groups: usize,
    padding: Conv2dPaddingConfig,
}

@@ -98,6 +103,7 @@ impl Conv2dConfig {
            kernel_size: self.kernel_size,
            dilation: self.dilation,
            padding: self.padding.clone(),
            groups: self.groups,
        }
    }

@@ -110,6 +116,7 @@ impl Conv2dConfig {
            dilation: self.dilation,
            kernel_size: self.kernel_size,
            padding: self.padding.clone(),
            groups: self.groups,
        }
    }
}

@@ -130,9 +137,7 @@ impl<B: Backend> Conv2d<B> {
            input,
            self.weight.val(),
            self.bias.as_ref().map(|bias| bias.val()),
            self.stride,
            padding,
            self.dilation,
            ConvOptions::new(self.stride, padding, self.dilation, self.groups),
        )
    }
}
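The 2D path is analogous; a small sketch under the same illustrative backend assumption, showing that both channels_in and channels_out must be divisible by groups:

use burn_tensor::{module::conv2d, ops::ConvOptions, Tensor};
use burn_ndarray::NdArrayBackend; // assumed backend, purely for illustration

type B = NdArrayBackend<f32>;

fn grouped_conv2d_example() -> Tensor<B, 4> {
    let x = Tensor::<B, 4>::ones([1, 4, 8, 8]); // [batch, channels_in, height, width]
    let weight = Tensor::<B, 4>::ones([6, 2, 3, 3]); // [channels_out, channels_in / groups, k1, k2]
    let bias = Tensor::<B, 1>::ones([6]);

    conv2d(x, weight, Some(bias), ConvOptions::new([1, 1], [1, 1], [1, 1], 2))
}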
@@ -15,6 +15,7 @@ pub(crate) trait NdArrayElement:
    + ndarray::ScalarOperand
    + ExpElement
    + num_traits::FromPrimitive
    + core::ops::AddAssign
    + core::cmp::PartialEq
    + core::cmp::PartialOrd<Self>
{
@ -260,7 +260,7 @@ where
|
|||
|
||||
for (i, index) in indexes.iter().enumerate() {
|
||||
let index = *index as usize;
|
||||
tensor[[b, index]] = tensor[[b, index]] + value[[b, i]];
|
||||
tensor[[b, index]] += value[[b, i]];
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -351,7 +351,7 @@ where
|
|||
let mut view = output_array.index_axis_mut(Axis(dim), index as usize);
|
||||
let value = value.array.index_axis(Axis(0), index_value);
|
||||
|
||||
view.zip_mut_with(&value, |a, b| *a = *a + *b);
|
||||
view.zip_mut_with(&value, |a, b| *a += *b);
|
||||
}
|
||||
|
||||
NdArrayTensor::new(output_array.into_shared())
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
use burn_tensor::{ops::conv::calculate_conv_output_size, ElementConversion};
|
||||
use burn_tensor::{
|
||||
ops::{conv::calculate_conv_output_size, ConvOptions, ConvTransposeOptions},
|
||||
ElementConversion,
|
||||
};
|
||||
use ndarray::{Array4, Dim};
|
||||
|
||||
use crate::{
|
||||
|
@ -10,13 +13,11 @@ pub(crate) fn conv2d<E: FloatNdArrayElement>(
|
|||
x: NdArrayTensor<E, 4>,
|
||||
weight: NdArrayTensor<E, 4>,
|
||||
bias: Option<NdArrayTensor<E, 1>>,
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilatation: [usize; 2],
|
||||
options: ConvOptions<2>,
|
||||
) -> NdArrayTensor<E, 4> {
|
||||
let [dilatation_height, dilatation_width] = dilatation;
|
||||
let [padding_height, padding_width] = padding;
|
||||
let [stride_height, stride_width] = stride;
|
||||
let [dilatation_height, dilatation_width] = options.dilation;
|
||||
let [padding_height, padding_width] = options.padding;
|
||||
let [stride_height, stride_width] = options.stride;
|
||||
let [batch_size, _in_channels, in_height, in_width] = x.shape().dims;
|
||||
let [out_channels, in_channels, kernel_height, kernel_width] = weight.shape().dims;
|
||||
|
||||
|
@ -35,7 +36,7 @@ pub(crate) fn conv2d<E: FloatNdArrayElement>(
|
|||
in_width,
|
||||
);
|
||||
|
||||
let x = apply_padding_4d(x, padding, 0i32.elem()).array;
|
||||
let x = apply_padding_4d(x, options.padding, 0i32.elem()).array;
|
||||
|
||||
let mut output = Array4::zeros(Dim([batch_size, out_channels, out_height, out_width]));
|
||||
|
||||
|
@ -45,10 +46,11 @@ pub(crate) fn conv2d<E: FloatNdArrayElement>(
|
|||
iter_par!(0, batch_size * out_channels).for_each(|k| unsafe {
|
||||
let b = k / out_channels;
|
||||
let oc = k % out_channels;
|
||||
let g = k % options.groups;
|
||||
|
||||
let output = unsafe_shared_out.get();
|
||||
|
||||
for ic in 0..in_channels {
|
||||
for ic in (in_channels * g)..(in_channels * (g + 1)) {
|
||||
for kh in 0..kernel_height {
|
||||
for kw in 0..kernel_width {
|
||||
for oh in 0..out_height {
|
||||
|
@ -56,8 +58,9 @@ pub(crate) fn conv2d<E: FloatNdArrayElement>(
|
|||
let ih = oh * stride_height + kh * dilatation_height;
|
||||
let iw = ow * stride_width + kw * dilatation_width;
|
||||
|
||||
output[[b, oc, oh, ow]] = output[[b, oc, oh, ow]]
|
||||
+ x[[b, ic, ih, iw]] * weight.array[[oc, ic, kh, kw]];
|
||||
let weight_ic = ic - (g * in_channels);
|
||||
output[[b, oc, oh, ow]] +=
|
||||
x[[b, ic, ih, iw]] * weight.array[[oc, weight_ic, kh, kw]];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -67,7 +70,7 @@ pub(crate) fn conv2d<E: FloatNdArrayElement>(
|
|||
if let Some(bias) = &bias {
|
||||
for oh in 0..out_height {
|
||||
for ow in 0..out_width {
|
||||
output[[b, oc, oh, ow]] = output[[b, oc, oh, ow]] + bias.array[oc];
|
||||
output[[b, oc, oh, ow]] += bias.array[oc];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -81,15 +84,12 @@ pub(crate) fn conv_transpose2d<E: FloatNdArrayElement>(
|
|||
x: NdArrayTensor<E, 4>,
|
||||
weight: NdArrayTensor<E, 4>,
|
||||
bias: Option<NdArrayTensor<E, 1>>,
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
out_padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> NdArrayTensor<E, 4> {
|
||||
let [dilation_height, dilation_width] = dilation;
|
||||
let [padding_height, padding_width] = padding;
|
||||
let [stride_height, stride_width] = stride;
|
||||
let [out_padding_height, out_padding_width] = out_padding;
|
||||
let [dilation_height, dilation_width] = options.dilation;
|
||||
let [padding_height, padding_width] = options.padding;
|
||||
let [stride_height, stride_width] = options.stride;
|
||||
let [out_padding_height, out_padding_width] = options.padding_out;
|
||||
let [batch_size, _in_channels, in_height, in_width] = x.shape().dims;
|
||||
let [in_channels, out_channels, kernel_height, kernel_width] = weight.shape().dims;
|
||||
|
||||
|
@ -104,18 +104,28 @@ pub(crate) fn conv_transpose2d<E: FloatNdArrayElement>(
|
|||
+ 1;
|
||||
|
||||
let x = x.array;
|
||||
let mut output = Array4::zeros(Dim([batch_size, out_channels, out_height, out_width]));
|
||||
let mut output = Array4::zeros(Dim([
|
||||
batch_size,
|
||||
out_channels * options.groups,
|
||||
out_height,
|
||||
out_width,
|
||||
]));
|
||||
|
||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||
|
||||
run_par!(|| {
|
||||
iter_par!(0, batch_size * out_channels).for_each(|k| unsafe {
|
||||
let b = k / out_channels;
|
||||
iter_par!(0, batch_size * out_channels * options.groups).for_each(|k| unsafe {
|
||||
let b = k / (out_channels * options.groups);
|
||||
let oc = k % out_channels;
|
||||
let g = k % options.groups;
|
||||
|
||||
let output = unsafe_shared_out.get();
|
||||
|
||||
for ic in 0..in_channels {
|
||||
let oc_out = oc + (out_channels * g);
|
||||
let ic_start = g * (in_channels / options.groups);
|
||||
let ic_end = ic_start + in_channels / options.groups;
|
||||
|
||||
for ic in ic_start..ic_end {
|
||||
for ih in 0..in_height {
|
||||
for iw in 0..in_width {
|
||||
for kh in 0..kernel_height {
|
||||
|
@ -134,8 +144,8 @@ pub(crate) fn conv_transpose2d<E: FloatNdArrayElement>(
|
|||
let oh = oh - padding_height;
|
||||
let ow = ow - padding_width;
|
||||
|
||||
output[[b, oc, oh, ow]] = output[[b, oc, oh, ow]]
|
||||
+ x[[b, ic, ih, iw]] * weight.array[[ic, oc, kh, kw]];
|
||||
output[[b, oc_out, oh, ow]] +=
|
||||
x[[b, ic, ih, iw]] * weight.array[[ic, oc, kh, kw]];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -145,7 +155,7 @@ pub(crate) fn conv_transpose2d<E: FloatNdArrayElement>(
|
|||
if let Some(bias) = &bias {
|
||||
for oh in 0..out_height {
|
||||
for ow in 0..out_width {
|
||||
output[[b, oc, oh, ow]] = output[[b, oc, oh, ow]] + bias.array[oc];
|
||||
output[[b, oc_out, oh, ow]] += bias.array[oc_out];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -160,7 +160,7 @@ pub(crate) fn max_pool2d_backward<E: FloatNdArrayElement>(
|
|||
let index_h = index as usize / width_x;
|
||||
let index_w = index as usize % width_x;
|
||||
|
||||
output[[b, c, index_h, index_w]] = output[[b, c, index_h, index_w]] + grad;
|
||||
output[[b, c, index_h, index_w]] += grad;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
|
|
@ -73,23 +73,18 @@ impl<E: FloatNdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E>
|
|||
x: NdArrayTensor<E, 4>,
|
||||
weight: NdArrayTensor<E, 4>,
|
||||
bias: Option<NdArrayTensor<E, 1>>,
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
options: ConvOptions<2>,
|
||||
) -> NdArrayTensor<E, 4> {
|
||||
conv2d(x, weight, bias, stride, padding, dilation)
|
||||
conv2d(x, weight, bias, options)
|
||||
}
|
||||
|
||||
fn conv_transpose2d(
|
||||
x: NdArrayTensor<E, 4>,
|
||||
weight: NdArrayTensor<E, 4>,
|
||||
bias: Option<NdArrayTensor<E, 1>>,
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
padding_out: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> NdArrayTensor<E, 4> {
|
||||
conv_transpose2d(x, weight, bias, stride, padding, padding_out, dilation)
|
||||
conv_transpose2d(x, weight, bias, options)
|
||||
}
|
||||
|
||||
fn max_pool2d(
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
use crate::{element::TchElement, TchBackend, TchTensor};
|
||||
use burn_tensor::ops::{MaxPool2dBackward, MaxPool2dWithIndexes, ModuleOps};
|
||||
use burn_tensor::ops::{
|
||||
ConvOptions, ConvTransposeOptions, MaxPool2dBackward, MaxPool2dWithIndexes, ModuleOps,
|
||||
};
|
||||
|
||||
impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
||||
fn embedding(weights: TchTensor<E, 2>, indexes: TchTensor<i64, 2>) -> TchTensor<E, 3> {
|
||||
|
@ -30,18 +32,16 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
x: TchTensor<E, 3>,
|
||||
weight: TchTensor<E, 3>,
|
||||
bias: Option<TchTensor<E, 1>>,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
options: ConvOptions<1>,
|
||||
) -> TchTensor<E, 3> {
|
||||
let tensor = tch::Tensor::conv1d(
|
||||
&x.tensor,
|
||||
&weight.tensor,
|
||||
bias.map(|t| t.tensor),
|
||||
&[stride as i64],
|
||||
&[padding as i64],
|
||||
&[dilation as i64],
|
||||
1,
|
||||
&options.stride.map(|i| i as i64),
|
||||
&options.padding.map(|i| i as i64),
|
||||
&options.dilation.map(|i| i as i64),
|
||||
options.groups as i64,
|
||||
);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
|
@ -51,18 +51,16 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
x: TchTensor<E, 4>,
|
||||
weight: TchTensor<E, 4>,
|
||||
bias: Option<TchTensor<E, 1>>,
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
options: ConvOptions<2>,
|
||||
) -> TchTensor<E, 4> {
|
||||
let tensor = tch::Tensor::conv2d(
|
||||
&x.tensor,
|
||||
&weight.tensor,
|
||||
bias.map(|t| t.tensor),
|
||||
&[stride[0] as i64, stride[1] as i64],
|
||||
&[padding[0] as i64, padding[1] as i64],
|
||||
&[dilation[0] as i64, dilation[1] as i64],
|
||||
1,
|
||||
&options.stride.map(|i| i as i64),
|
||||
&options.padding.map(|i| i as i64),
|
||||
&options.dilation.map(|i| i as i64),
|
||||
options.groups as i64,
|
||||
);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
|
@ -72,20 +70,17 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
x: TchTensor<E, 4>,
|
||||
weight: TchTensor<E, 4>,
|
||||
bias: Option<TchTensor<E, 1>>,
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
padding_out: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> TchTensor<E, 4> {
|
||||
let tensor = tch::Tensor::conv_transpose2d(
|
||||
&x.tensor,
|
||||
&weight.tensor,
|
||||
bias.map(|t| t.tensor),
|
||||
&[stride[0] as i64, stride[1] as i64],
|
||||
&[padding[0] as i64, padding[1] as i64],
|
||||
&[padding_out[0] as i64, padding_out[1] as i64],
|
||||
1,
|
||||
&[dilation[0] as i64, dilation[1] as i64],
|
||||
&options.stride.map(|i| i as i64),
|
||||
&options.padding.map(|i| i as i64),
|
||||
&options.padding_out.map(|i| i as i64),
|
||||
options.groups as i64,
|
||||
&options.dilation.map(|i| i as i64),
|
||||
);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
|
@ -95,20 +90,17 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
x: TchTensor<E, 3>,
|
||||
weight: TchTensor<E, 3>,
|
||||
bias: Option<TchTensor<E, 1>>,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
padding_out: usize,
|
||||
dilation: usize,
|
||||
options: ConvTransposeOptions<1>,
|
||||
) -> TchTensor<E, 3> {
|
||||
let tensor = tch::Tensor::conv_transpose1d(
|
||||
&x.tensor,
|
||||
&weight.tensor,
|
||||
bias.map(|t| t.tensor),
|
||||
&[stride as i64],
|
||||
&[padding as i64],
|
||||
&[padding_out as i64],
|
||||
1,
|
||||
&[dilation as i64],
|
||||
&options.stride.map(|i| i as i64),
|
||||
&options.padding.map(|i| i as i64),
|
||||
&options.padding_out.map(|i| i as i64),
|
||||
options.groups as i64,
|
||||
&options.dilation.map(|i| i as i64),
|
||||
);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
|
|
|
@ -1,4 +1,8 @@
|
|||
use crate::{backend::Backend, Int, Tensor};
|
||||
use crate::{
|
||||
backend::Backend,
|
||||
ops::{ConvOptions, ConvTransposeOptions},
|
||||
Int, Tensor,
|
||||
};
|
||||
|
||||
/// Applies the [embedding module](crate::ops::ModuleOps::embedding).
|
||||
pub fn embedding<B>(weights: Tensor<B, 2>, indexes: Tensor<B, 2, Int>) -> Tensor<B, 3>
|
||||
|
@ -13,9 +17,7 @@ pub fn conv1d<B>(
|
|||
x: Tensor<B, 3>,
|
||||
weight: Tensor<B, 3>,
|
||||
bias: Option<Tensor<B, 1>>,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
options: ConvOptions<1>,
|
||||
) -> Tensor<B, 3>
|
||||
where
|
||||
B: Backend,
|
||||
|
@ -24,9 +26,7 @@ where
|
|||
x.primitive,
|
||||
weight.primitive,
|
||||
bias.map(|b| b.primitive),
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
options,
|
||||
))
|
||||
}
|
||||
|
||||
|
@ -35,9 +35,7 @@ pub fn conv2d<B>(
|
|||
x: Tensor<B, 4>,
|
||||
weight: Tensor<B, 4>,
|
||||
bias: Option<Tensor<B, 1>>,
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
options: ConvOptions<2>,
|
||||
) -> Tensor<B, 4>
|
||||
where
|
||||
B: Backend,
|
||||
|
@ -46,9 +44,7 @@ where
|
|||
x.primitive,
|
||||
weight.primitive,
|
||||
bias.map(|b| b.primitive),
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
options,
|
||||
))
|
||||
}
|
||||
|
||||
|
@ -57,10 +53,7 @@ pub fn conv_transpose1d<B>(
|
|||
x: Tensor<B, 3>,
|
||||
weight: Tensor<B, 3>,
|
||||
bias: Option<Tensor<B, 1>>,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
padding_out: usize,
|
||||
dilation: usize,
|
||||
options: ConvTransposeOptions<1>,
|
||||
) -> Tensor<B, 3>
|
||||
where
|
||||
B: Backend,
|
||||
|
@ -69,10 +62,7 @@ where
|
|||
x.primitive,
|
||||
weight.primitive,
|
||||
bias.map(|b| b.primitive),
|
||||
stride,
|
||||
padding,
|
||||
padding_out,
|
||||
dilation,
|
||||
options,
|
||||
))
|
||||
}
|
||||
|
||||
|
@ -81,10 +71,7 @@ pub fn conv_transpose2d<B>(
|
|||
x: Tensor<B, 4>,
|
||||
weight: Tensor<B, 4>,
|
||||
bias: Option<Tensor<B, 1>>,
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
padding_out: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> Tensor<B, 4>
|
||||
where
|
||||
B: Backend,
|
||||
|
@ -93,10 +80,7 @@ where
|
|||
x.primitive,
|
||||
weight.primitive,
|
||||
bias.map(|b| b.primitive),
|
||||
stride,
|
||||
padding,
|
||||
padding_out,
|
||||
dilation,
|
||||
options,
|
||||
))
|
||||
}
|
||||
|
||||
|
|
|
@@ -30,6 +30,25 @@ pub struct Conv1dBackward<B: Backend> {
    pub bias_grad: Option<B::TensorPrimitive<1>>,
}

/// Convolution options.
#[derive(new, Debug, Clone)]
pub struct ConvOptions<const N: usize> {
    pub stride: [usize; N],
    pub padding: [usize; N],
    pub dilation: [usize; N],
    pub groups: usize,
}

/// Transposed convolution options.
#[derive(new, Debug, Clone)]
pub struct ConvTransposeOptions<const N: usize> {
    pub stride: [usize; N],
    pub padding: [usize; N],
    pub padding_out: [usize; N],
    pub dilation: [usize; N],
    pub groups: usize,
}
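A small usage sketch for these option structs (the new constructors come from #[derive(new)] and take the fields in declaration order):

use burn_tensor::ops::{ConvOptions, ConvTransposeOptions};

fn options_example() -> (ConvOptions<2>, ConvTransposeOptions<2>) {
    // Regular convolution: stride, padding, dilation, groups.
    let opts = ConvOptions::new([1, 1], [1, 1], [1, 1], 2);
    // Transposed convolution additionally takes padding_out as the third argument.
    let t_opts = ConvTransposeOptions::new([2, 2], [1, 1], [0, 0], [1, 1], 1);
    (opts, t_opts)
}

Bundling these fields keeps the ModuleOps signatures stable when a new parameter such as groups is introduced.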
|
||||
|
||||
pub trait ModuleOps<B: Backend> {
|
||||
fn embedding(
|
||||
weights: B::TensorPrimitive<2>,
|
||||
|
@ -51,9 +70,7 @@ pub trait ModuleOps<B: Backend> {
|
|||
x: B::TensorPrimitive<4>,
|
||||
weight: B::TensorPrimitive<4>,
|
||||
bias: Option<B::TensorPrimitive<1>>,
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
options: ConvOptions<2>,
|
||||
) -> B::TensorPrimitive<4>;
|
||||
/// Two dimensional transposed convolution.
|
||||
///
|
||||
|
@ -66,10 +83,7 @@ pub trait ModuleOps<B: Backend> {
|
|||
x: B::TensorPrimitive<4>,
|
||||
weight: B::TensorPrimitive<4>,
|
||||
bias: Option<B::TensorPrimitive<1>>,
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
padding_out: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> B::TensorPrimitive<4>;
|
||||
|
||||
/// Backward pass for the [conv2d](ModuleOps::conv2d) operation.
|
||||
|
@ -77,12 +91,10 @@ pub trait ModuleOps<B: Backend> {
|
|||
x: B::TensorPrimitive<4>,
|
||||
weight: B::TensorPrimitive<4>,
|
||||
bias: Option<B::TensorPrimitive<1>>,
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
output_grad: B::TensorPrimitive<4>,
|
||||
options: ConvOptions<2>,
|
||||
) -> Conv2dBackward<B> {
|
||||
conv::conv2d_backward(x, weight, bias, stride, padding, dilation, output_grad)
|
||||
conv::conv2d_backward(x, weight, bias, output_grad, options)
|
||||
}
|
||||
/// One dimensional convolution.
|
||||
///
|
||||
|
@ -95,11 +107,9 @@ pub trait ModuleOps<B: Backend> {
|
|||
x: B::TensorPrimitive<3>,
|
||||
weight: B::TensorPrimitive<3>,
|
||||
bias: Option<B::TensorPrimitive<1>>,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
options: ConvOptions<1>,
|
||||
) -> B::TensorPrimitive<3> {
|
||||
conv::conv1d_from_conv2d::<B>(x, weight, bias, stride, padding, dilation)
|
||||
conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
|
||||
}
|
||||
/// One dimensional transposed convolution.
|
||||
///
|
||||
|
@ -112,32 +122,19 @@ pub trait ModuleOps<B: Backend> {
|
|||
x: B::TensorPrimitive<3>,
|
||||
weight: B::TensorPrimitive<3>,
|
||||
bias: Option<B::TensorPrimitive<1>>,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
padding_out: usize,
|
||||
dilation: usize,
|
||||
options: ConvTransposeOptions<1>,
|
||||
) -> B::TensorPrimitive<3> {
|
||||
conv::conv_transpose1d_from_conv_transpose2d::<B>(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
stride,
|
||||
padding,
|
||||
padding_out,
|
||||
dilation,
|
||||
)
|
||||
conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
|
||||
}
|
||||
/// Backward pass for the [conv1d](ModuleOps::conv1d) operation.
|
||||
fn conv1d_backward(
|
||||
x: B::TensorPrimitive<3>,
|
||||
weight: B::TensorPrimitive<3>,
|
||||
bias: Option<B::TensorPrimitive<1>>,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
output_grad: B::TensorPrimitive<3>,
|
||||
options: ConvOptions<1>,
|
||||
) -> Conv1dBackward<B> {
|
||||
conv::conv1d_backward(x, weight, bias, stride, padding, dilation, output_grad)
|
||||
conv::conv1d_backward(x, weight, bias, output_grad, options)
|
||||
}
|
||||
/// Two dimensional max pooling.
|
||||
///
|
||||
|
|
|
@@ -1,4 +1,4 @@
use super::{Conv1dBackward, Conv2dBackward};
use super::{Conv1dBackward, Conv2dBackward, ConvOptions, ConvTransposeOptions};
use crate::{backend::Backend, Shape};
use libm::ceilf;

@@ -31,43 +31,26 @@ pub fn calculate_conv_output_size(
    (size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
}
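A quick sanity check of the formula above, restated locally rather than calling the library function (whose full parameter list is not visible in this hunk):

fn output_size(size_in: usize, kernel_size: usize, stride: usize, padding: usize, dilation: usize) -> usize {
    (size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
}

#[test]
fn output_size_sanity_check() {
    assert_eq!(output_size(6, 3, 1, 1, 1), 6); // stride 1, padding 1: "same" length
    assert_eq!(output_size(6, 3, 2, 1, 1), 3); // stride 2 roughly halves the length
}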
|
||||
|
||||
fn calculate_padding_out(
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
size_in: usize,
|
||||
size_out: usize,
|
||||
) -> usize {
|
||||
if stride <= 1 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let out = 1 + libm::ceil(
|
||||
(size_in + 2 * padding - dilation * (kernel_size - 1) - 1) as f64 / stride as f64,
|
||||
) as usize;
|
||||
i64::max(0, out as i64 - size_out as i64) as usize
|
||||
}
|
||||
|
||||
/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass using convolutions.
|
||||
pub(crate) fn conv1d_backward<B: Backend>(
|
||||
x: B::TensorPrimitive<3>,
|
||||
weight: B::TensorPrimitive<3>,
|
||||
bias: Option<B::TensorPrimitive<1>>,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
output_grad: B::TensorPrimitive<3>,
|
||||
options: ConvOptions<1>,
|
||||
) -> Conv1dBackward<B> {
|
||||
let [batch_size, channels_in, length_in] = B::shape(&x).dims;
|
||||
let weight_shape = B::shape(&weight);
|
||||
let weight_device = B::device(&weight);
|
||||
|
||||
let [batch_size, _, length_in] = B::shape(&x).dims;
|
||||
let [_batch_size, channels_out, length_out] = B::shape(&output_grad).dims;
|
||||
let [_, _, kernel_size] = B::shape(&weight).dims;
|
||||
let [_, _, kernel_size] = weight_shape.dims;
|
||||
|
||||
let padding_out = calculate_padding_out(
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
options.stride[0],
|
||||
options.padding[0],
|
||||
options.dilation[0],
|
||||
length_in,
|
||||
length_out,
|
||||
);
|
||||
|
@ -76,36 +59,30 @@ pub(crate) fn conv1d_backward<B: Backend>(
|
|||
output_grad.clone(),
|
||||
weight,
|
||||
None,
|
||||
stride,
|
||||
padding,
|
||||
padding_out,
|
||||
dilation,
|
||||
ConvTransposeOptions::new(
|
||||
options.stride,
|
||||
options.padding,
|
||||
[padding_out],
|
||||
options.dilation,
|
||||
options.groups,
|
||||
),
|
||||
);
|
||||
|
||||
let x_swapped = B::swap_dims(x, 0, 1);
|
||||
let output_grad_swapped = B::swap_dims(output_grad, 0, 1);
|
||||
let weight_grad_swapped = B::conv1d(
|
||||
x_swapped,
|
||||
output_grad_swapped.clone(),
|
||||
None,
|
||||
dilation,
|
||||
padding,
|
||||
stride,
|
||||
);
|
||||
let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1);
|
||||
|
||||
if B::shape(&weight_grad) != Shape::new([channels_out, channels_in, kernel_size]) {
|
||||
weight_grad = B::index(
|
||||
weight_grad,
|
||||
[0..channels_out, 0..channels_in, 0..kernel_size],
|
||||
);
|
||||
}
|
||||
let weight_grad = match options.groups == 1 {
|
||||
true => conv1d_weight_grad_no_groups::<B>(x, output_grad.clone(), weight_shape, options),
|
||||
false => conv1d_weight_grad_groups::<B>(
|
||||
x,
|
||||
B::zeros(weight_shape, &weight_device),
|
||||
output_grad.clone(),
|
||||
options,
|
||||
),
|
||||
};
|
||||
|
||||
Conv1dBackward::new(
|
||||
x_grad,
|
||||
weight_grad,
|
||||
bias.map(|b| {
|
||||
let grad = output_grad_swapped;
|
||||
let grad = B::swap_dims(output_grad, 0, 1);
|
||||
let grad = B::reshape(grad, Shape::new([channels_out, batch_size * length_out]));
|
||||
let grad = B::sum_dim(grad, 1);
|
||||
|
||||
|
@ -119,28 +96,29 @@ pub(crate) fn conv2d_backward<B: Backend>(
|
|||
x: B::TensorPrimitive<4>,
|
||||
weight: B::TensorPrimitive<4>,
|
||||
bias: Option<B::TensorPrimitive<1>>,
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
output_grad: B::TensorPrimitive<4>,
|
||||
options: ConvOptions<2>,
|
||||
) -> Conv2dBackward<B> {
|
||||
let [batch_size, channels_in, height_in, width_in] = B::shape(&x).dims;
|
||||
let [_batch_size, channels_out, height_out, width_out] = B::shape(&output_grad).dims;
|
||||
let [_, _, kernel_size_1, kernel_size_2] = B::shape(&weight).dims;
|
||||
let weight_shape = B::shape(&weight);
|
||||
let weight_device = B::device(&weight);
|
||||
|
||||
let [batch_size, _channels_in, height_in, width_in] = B::shape(&x).dims;
|
||||
let [_, _, height_out, width_out] = B::shape(&output_grad).dims;
|
||||
let [channels_out, _, kernel_size_1, kernel_size_2] = weight_shape.dims;
|
||||
|
||||
let padding_1_out = calculate_padding_out(
|
||||
kernel_size_1,
|
||||
stride[0],
|
||||
padding[0],
|
||||
dilation[0],
|
||||
options.stride[0],
|
||||
options.padding[0],
|
||||
options.dilation[0],
|
||||
height_in,
|
||||
height_out,
|
||||
);
|
||||
let padding_2_out = calculate_padding_out(
|
||||
kernel_size_2,
|
||||
stride[1],
|
||||
padding[1],
|
||||
dilation[1],
|
||||
options.stride[1],
|
||||
options.padding[1],
|
||||
options.dilation[1],
|
||||
width_in,
|
||||
width_out,
|
||||
);
|
||||
|
@ -149,43 +127,30 @@ pub(crate) fn conv2d_backward<B: Backend>(
|
|||
output_grad.clone(),
|
||||
weight,
|
||||
None,
|
||||
stride,
|
||||
padding,
|
||||
[padding_1_out, padding_2_out],
|
||||
dilation,
|
||||
ConvTransposeOptions::new(
|
||||
options.stride,
|
||||
options.padding,
|
||||
[padding_1_out, padding_2_out],
|
||||
options.dilation,
|
||||
options.groups,
|
||||
),
|
||||
);
|
||||
|
||||
let x_swapped = B::swap_dims(x, 0, 1);
|
||||
let output_grad_swapped = B::swap_dims(output_grad, 0, 1);
|
||||
let weight_grad_swapped = B::conv2d(
|
||||
x_swapped,
|
||||
output_grad_swapped.clone(),
|
||||
None,
|
||||
dilation,
|
||||
padding,
|
||||
stride,
|
||||
);
|
||||
let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1);
|
||||
|
||||
if B::shape(&weight_grad)
|
||||
!= Shape::new([channels_out, channels_in, kernel_size_1, kernel_size_2])
|
||||
{
|
||||
weight_grad = B::index(
|
||||
weight_grad,
|
||||
[
|
||||
0..channels_out,
|
||||
0..channels_in,
|
||||
0..kernel_size_1,
|
||||
0..kernel_size_2,
|
||||
],
|
||||
);
|
||||
}
|
||||
let weight_grad = match options.groups == 1 {
|
||||
true => conv2d_weight_grad_no_groups::<B>(x, output_grad.clone(), weight_shape, options),
|
||||
false => conv2d_weight_grad_groups::<B>(
|
||||
x,
|
||||
B::zeros(weight_shape, &weight_device),
|
||||
output_grad.clone(),
|
||||
options,
|
||||
),
|
||||
};
|
||||
|
||||
Conv2dBackward::new(
|
||||
x_grad,
|
||||
weight_grad,
|
||||
bias.map(|b| {
|
||||
let grad = output_grad_swapped;
|
||||
let grad = B::swap_dims(output_grad, 0, 1);
|
||||
let grad = B::reshape(
|
||||
grad,
|
||||
Shape::new([channels_out, batch_size * height_out * width_out]),
|
||||
|
@ -202,20 +167,28 @@ pub(crate) fn conv1d_from_conv2d<B: Backend>(
|
|||
x: B::TensorPrimitive<3>,
|
||||
weight: B::TensorPrimitive<3>,
|
||||
bias: Option<B::TensorPrimitive<1>>,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
options: ConvOptions<1>,
|
||||
) -> B::TensorPrimitive<3> {
|
||||
let [channels_out, _channels_in, kernel_size] = B::shape(&weight).dims;
|
||||
let [batch_size, channels_in, length_in] = B::shape(&x).dims;
|
||||
|
||||
let weight = B::reshape(
|
||||
weight,
|
||||
Shape::new([channels_out, channels_in, kernel_size, 1]),
|
||||
Shape::new([channels_out, channels_in / options.groups, kernel_size, 1]),
|
||||
);
|
||||
let x = B::reshape(x, Shape::new([batch_size, channels_in, length_in, 1]));
|
||||
|
||||
let tensor = B::conv2d(x, weight, bias, [stride, 1], [padding, 0], [dilation, 1]);
|
||||
let tensor = B::conv2d(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
ConvOptions::new(
|
||||
[options.stride[0], 1],
|
||||
[options.padding[0], 0],
|
||||
[options.dilation[0], 1],
|
||||
options.groups,
|
||||
),
|
||||
);
|
||||
let [batch_size, channels_out, height_out, _weight_out] = B::shape(&tensor).dims;
|
||||
B::reshape(tensor, Shape::from([batch_size, channels_out, height_out]))
|
||||
}
|
||||
|
@ -225,10 +198,7 @@ pub(crate) fn conv_transpose1d_from_conv_transpose2d<B: Backend>(
|
|||
x: B::TensorPrimitive<3>,
|
||||
weight: B::TensorPrimitive<3>,
|
||||
bias: Option<B::TensorPrimitive<1>>,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
padding_out: usize,
|
||||
dilation: usize,
|
||||
options: ConvTransposeOptions<1>,
|
||||
) -> B::TensorPrimitive<3> {
|
||||
let [channels_in, channels_out, kernel_size] = B::shape(&weight).dims;
|
||||
let [batch_size, _channels_in, length_in] = B::shape(&x).dims;
|
||||
|
@ -243,15 +213,174 @@ pub(crate) fn conv_transpose1d_from_conv_transpose2d<B: Backend>(
|
|||
x,
|
||||
weight,
|
||||
bias,
|
||||
[stride, 1],
|
||||
[padding, 0],
|
||||
[padding_out, 0],
|
||||
[dilation, 1],
|
||||
ConvTransposeOptions::new(
|
||||
[options.stride[0], 1],
|
||||
[options.padding[0], 0],
|
||||
[options.padding_out[0], 0],
|
||||
[options.dilation[0], 1],
|
||||
options.groups,
|
||||
),
|
||||
);
|
||||
let [batch_size, channels_out, height_out, _weight_out] = B::shape(&tensor).dims;
|
||||
B::reshape(tensor, Shape::from([batch_size, channels_out, height_out]))
|
||||
}
|
||||
|
||||
fn conv1d_weight_grad_groups<B: Backend>(
|
||||
x: B::TensorPrimitive<3>,
|
||||
mut weight_grad: B::TensorPrimitive<3>,
|
||||
output_grad: B::TensorPrimitive<3>,
|
||||
options: ConvOptions<1>,
|
||||
) -> B::TensorPrimitive<3> {
|
||||
let [channels_out, increment_ci, kernel_size] = B::shape(&weight_grad).dims;
|
||||
let increment_co = channels_out / options.groups;
|
||||
|
||||
let x_swapped = B::swap_dims(x, 0, 1);
|
||||
let output_grad_swapped = B::swap_dims(output_grad, 0, 1);
|
||||
|
||||
for g in 0..options.groups {
|
||||
let start_idx_ci = g * increment_ci;
|
||||
let end_idx_ci = (g + 1) * increment_ci;
|
||||
let start_idx_co = g * increment_co;
|
||||
let end_idx_co = (g + 1) * increment_co;
|
||||
|
||||
let x = B::index(x_swapped.clone(), [start_idx_ci..end_idx_ci]);
|
||||
let grad = B::index(output_grad_swapped.clone(), [start_idx_co..end_idx_co]);
|
||||
let mut weight_grad_tmp = B::conv1d(
|
||||
x,
|
||||
grad,
|
||||
None,
|
||||
ConvOptions::new(options.dilation, options.padding, options.stride, 1),
|
||||
);
|
||||
weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1);
|
||||
weight_grad = B::index_assign(
|
||||
weight_grad,
|
||||
[start_idx_co..end_idx_co, 0..increment_ci, 0..kernel_size],
|
||||
weight_grad_tmp,
|
||||
);
|
||||
}
|
||||
|
||||
weight_grad
|
||||
}
|
||||
|
||||
fn conv2d_weight_grad_groups<B: Backend>(
|
||||
x: B::TensorPrimitive<4>,
|
||||
mut weight_grad: B::TensorPrimitive<4>,
|
||||
output_grad: B::TensorPrimitive<4>,
|
||||
options: ConvOptions<2>,
|
||||
) -> B::TensorPrimitive<4> {
|
||||
let [channels_out, increment_ci, kernel_size_1, kernel_size_2] = B::shape(&weight_grad).dims;
|
||||
let increment_co = channels_out / options.groups;
|
||||
|
||||
let x_swapped = B::swap_dims(x, 0, 1);
|
||||
let output_grad_swapped = B::swap_dims(output_grad, 0, 1);
|
||||
|
||||
for g in 0..options.groups {
|
||||
let start_idx_ci = g * increment_ci;
|
||||
let end_idx_ci = (g + 1) * increment_ci;
|
||||
let start_idx_co = g * increment_co;
|
||||
let end_idx_co = (g + 1) * increment_co;
|
||||
|
||||
let x = B::index(x_swapped.clone(), [start_idx_ci..end_idx_ci]);
|
||||
let grad = B::index(output_grad_swapped.clone(), [start_idx_co..end_idx_co]);
|
||||
let mut weight_grad_tmp = B::conv2d(
|
||||
x,
|
||||
grad,
|
||||
None,
|
||||
ConvOptions::new(options.dilation, options.padding, options.stride, 1),
|
||||
);
|
||||
weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1);
|
||||
weight_grad = B::index_assign(
|
||||
weight_grad,
|
||||
[
|
||||
start_idx_co..end_idx_co,
|
||||
0..increment_ci,
|
||||
0..kernel_size_1,
|
||||
0..kernel_size_2,
|
||||
],
|
||||
weight_grad_tmp,
|
||||
);
|
||||
}
|
||||
|
||||
weight_grad
|
||||
}
|
||||
|
||||
fn conv1d_weight_grad_no_groups<B: Backend>(
|
||||
x: B::TensorPrimitive<3>,
|
||||
output_grad: B::TensorPrimitive<3>,
|
||||
weight_shape: Shape<3>,
|
||||
options: ConvOptions<1>,
|
||||
) -> B::TensorPrimitive<3> {
|
||||
let x_swapped = B::swap_dims(x, 0, 1);
|
||||
let output_grad_swapped = B::swap_dims(output_grad, 0, 1);
|
||||
let weight_grad_swapped = B::conv1d(
|
||||
x_swapped,
|
||||
output_grad_swapped,
|
||||
None,
|
||||
ConvOptions::new(options.dilation, options.padding, options.stride, 1),
|
||||
);
|
||||
let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1);
|
||||
|
||||
if B::shape(&weight_grad) != weight_shape {
|
||||
weight_grad = B::index(
|
||||
weight_grad,
|
||||
[
|
||||
0..weight_shape.dims[0],
|
||||
0..weight_shape.dims[1],
|
||||
0..weight_shape.dims[2],
|
||||
],
|
||||
);
|
||||
}
|
||||
weight_grad
|
||||
}
|
||||
|
||||
fn conv2d_weight_grad_no_groups<B: Backend>(
|
||||
x: B::TensorPrimitive<4>,
|
||||
output_grad: B::TensorPrimitive<4>,
|
||||
weight_shape: Shape<4>,
|
||||
options: ConvOptions<2>,
|
||||
) -> B::TensorPrimitive<4> {
|
||||
let x_swapped = B::swap_dims(x, 0, 1);
|
||||
let output_grad_swapped = B::swap_dims(output_grad, 0, 1);
|
||||
let weight_grad_swapped = B::conv2d(
|
||||
x_swapped,
|
||||
output_grad_swapped,
|
||||
None,
|
||||
ConvOptions::new(options.dilation, options.padding, options.stride, 1),
|
||||
);
|
||||
let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1);
|
||||
|
||||
if B::shape(&weight_grad) != weight_shape {
|
||||
weight_grad = B::index(
|
||||
weight_grad,
|
||||
[
|
||||
0..weight_shape.dims[0],
|
||||
0..weight_shape.dims[1],
|
||||
0..weight_shape.dims[2],
|
||||
0..weight_shape.dims[3],
|
||||
],
|
||||
);
|
||||
}
|
||||
weight_grad
|
||||
}
|
||||
|
||||
fn calculate_padding_out(
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
size_in: usize,
|
||||
size_out: usize,
|
||||
) -> usize {
|
||||
if stride <= 1 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let out = 1 + libm::ceil(
|
||||
(size_in + 2 * padding - dilation * (kernel_size - 1) - 1) as f64 / stride as f64,
|
||||
) as usize;
|
||||
i64::max(0, out as i64 - size_out as i64) as usize
|
||||
}
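A worked example for this helper, written as a hypothetical test inside this file's tests module (the function is private, so it can only be exercised from there):

#[test]
fn padding_out_recovers_input_length() {
    // Forward conv: length 4, kernel 3, stride 2, padding 1, dilation 1 -> output length 2.
    // Here: out = 1 + ceil((4 + 2*1 - 1*(3 - 1) - 1) / 2) = 3, so padding_out = max(0, 3 - 2) = 1,
    // and the transposed conv used in the backward pass maps length 2 back to
    // (2 - 1)*2 - 2*1 + 1*(3 - 1) + 1 + 1 = 4, the original input length.
    assert_eq!(calculate_padding_out(3, 2, 1, 1, 4, 2), 1);
}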
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
@ -2,32 +2,26 @@
|
|||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::module::conv1d;
|
||||
use burn_tensor::{Data, Tensor};
|
||||
use burn_tensor::ops::ConvOptions;
|
||||
use burn_tensor::{Data, Shape, Tensor};
|
||||
|
||||
#[test]
|
||||
fn test_conv1d_simple() {
|
||||
let test = Conv1dTestCase {
|
||||
batch_size: 2,
|
||||
channels_in: 3,
|
||||
channels_out: 3,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
length: 6,
|
||||
groups: 1,
|
||||
length: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats([
|
||||
[
|
||||
[7., 10., 10., 10., 10., 7.],
|
||||
[7., 10., 10., 10., 10., 7.],
|
||||
[7., 10., 10., 10., 10., 7.],
|
||||
],
|
||||
[
|
||||
[7., 10., 10., 10., 10., 7.],
|
||||
[7., 10., 10., 10., 10., 7.],
|
||||
[7., 10., 10., 10., 10., 7.],
|
||||
],
|
||||
[[43., 67., 82., 49.], [104., 176., 227., 158.]],
|
||||
[[139., 187., 202., 113.], [392., 584., 635., 414.]],
|
||||
]));
|
||||
}
|
||||
|
||||
|
@ -41,12 +35,33 @@ mod tests {
|
|||
padding: 1,
|
||||
stride: 1,
|
||||
dilation: 2,
|
||||
groups: 1,
|
||||
length: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats([
|
||||
[[5., 5.], [5., 5.]],
|
||||
[[5., 5.], [5., 5.]],
|
||||
[[62., 38.], [159., 111.]],
|
||||
[[158., 102.], [447., 367.]],
|
||||
]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv1d_groups() {
|
||||
let test = Conv1dTestCase {
|
||||
batch_size: 2,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size: 3,
|
||||
padding: 1,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 2,
|
||||
length: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats([
|
||||
[[2., 5., 8., 3.], [42., 63., 75., 47.]],
|
||||
[[26., 29., 32., 11.], [114., 159., 171., 103.]],
|
||||
]));
|
||||
}
|
||||
|
||||
|
@ -60,22 +75,13 @@ mod tests {
|
|||
padding: 1,
|
||||
stride: 2,
|
||||
dilation: 1,
|
||||
length: 9,
|
||||
groups: 1,
|
||||
length: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats([
|
||||
[
|
||||
[7., 10., 10., 10., 7.],
|
||||
[7., 10., 10., 10., 7.],
|
||||
[7., 10., 10., 10., 7.],
|
||||
[7., 10., 10., 10., 7.],
|
||||
],
|
||||
[
|
||||
[7., 10., 10., 10., 7.],
|
||||
[7., 10., 10., 10., 7.],
|
||||
[7., 10., 10., 10., 7.],
|
||||
[7., 10., 10., 10., 7.],
|
||||
],
|
||||
[[171., 294.], [415., 781.], [659., 1268.], [903., 1755.]],
|
||||
[[495., 726.], [1387., 2185.], [2279., 3644.], [3171., 5103.]],
|
||||
]));
|
||||
}
|
||||
|
||||
|
@ -87,21 +93,40 @@ mod tests {
|
|||
padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
length: usize,
|
||||
}
|
||||
|
||||
impl Conv1dTestCase {
|
||||
fn assert_output(self, y: TestTensor<3>) {
|
||||
let weights = TestTensor::ones([self.channels_out, self.channels_in, self.kernel_size]);
|
||||
let bias = TestTensor::ones([self.channels_out]);
|
||||
let x = TestTensor::ones([self.batch_size, self.channels_in, self.length]);
|
||||
let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]);
|
||||
let shape_weight = Shape::new([
|
||||
self.channels_out,
|
||||
self.channels_in / self.groups,
|
||||
self.kernel_size,
|
||||
]);
|
||||
let weight = TestTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_weight.num_elements())
|
||||
.reshape(shape_weight)
|
||||
.into_data()
|
||||
.convert(),
|
||||
);
|
||||
let bias = TestTensor::from_data(
|
||||
TestTensorInt::arange(0..self.channels_out)
|
||||
.into_data()
|
||||
.convert(),
|
||||
);
|
||||
let x = TestTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements())
|
||||
.reshape(shape_x)
|
||||
.into_data()
|
||||
.convert(),
|
||||
);
|
||||
let output = conv1d(
|
||||
x,
|
||||
weights,
|
||||
weight,
|
||||
Some(bias),
|
||||
self.stride,
|
||||
self.padding,
|
||||
self.dilation,
|
||||
ConvOptions::new([self.stride], [self.padding], [self.dilation], self.groups),
|
||||
);
|
||||
|
||||
y.to_data().assert_approx_eq(&output.into_data(), 3);
|
||||
|
|
|
@ -2,14 +2,15 @@
|
|||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::module::conv2d;
|
||||
use burn_tensor::{Data, Tensor};
|
||||
use burn_tensor::ops::ConvOptions;
|
||||
use burn_tensor::{Data, Shape, Tensor};
|
||||
|
||||
#[test]
|
||||
fn test_conv2d_simple() {
|
||||
let test = Conv2dTestCase {
|
||||
batch_size: 2,
|
||||
channels_in: 3,
|
||||
channels_out: 3,
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 2,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 1,
|
||||
|
@ -18,64 +19,54 @@ mod tests {
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
height: 6,
width: 6,
groups: 1,
height: 4,
width: 4,
};

test.assert_output(TestTensor::from_floats([
test.assert_output(TestTensor::from_floats([[
[
[
[13., 19., 19., 19., 19., 13.],
[19., 28., 28., 28., 28., 19.],
[19., 28., 28., 28., 28., 19.],
[19., 28., 28., 28., 28., 19.],
[19., 28., 28., 28., 28., 19.],
[13., 19., 19., 19., 19., 13.],
],
[
[13., 19., 19., 19., 19., 13.],
[19., 28., 28., 28., 28., 19.],
[19., 28., 28., 28., 28., 19.],
[19., 28., 28., 28., 28., 19.],
[19., 28., 28., 28., 28., 19.],
[13., 19., 19., 19., 19., 13.],
],
[
[13., 19., 19., 19., 19., 13.],
[19., 28., 28., 28., 28., 19.],
[19., 28., 28., 28., 28., 19.],
[19., 28., 28., 28., 28., 19.],
[19., 28., 28., 28., 28., 19.],
[13., 19., 19., 19., 19., 13.],
],
[1196., 1796., 1916., 1264.],
[1881., 2793., 2946., 1923.],
[2313., 3405., 3558., 2307.],
[1424., 2072., 2156., 1380.],
],
[
[
[13., 19., 19., 19., 19., 13.],
[19., 28., 28., 28., 28., 19.],
[19., 28., 28., 28., 28., 19.],
[19., 28., 28., 28., 28., 19.],
[19., 28., 28., 28., 28., 19.],
[13., 19., 19., 19., 19., 13.],
],
[
[13., 19., 19., 19., 19., 13.],
[19., 28., 28., 28., 28., 19.],
[19., 28., 28., 28., 28., 19.],
[19., 28., 28., 28., 28., 19.],
[19., 28., 28., 28., 28., 19.],
[13., 19., 19., 19., 19., 13.],
],
[
[13., 19., 19., 19., 19., 13.],
[19., 28., 28., 28., 28., 19.],
[19., 28., 28., 28., 28., 19.],
[19., 28., 28., 28., 28., 19.],
[19., 28., 28., 28., 28., 19.],
[13., 19., 19., 19., 19., 13.],
],
[2709., 4173., 4509., 3065.],
[4582., 7006., 7483., 5056.],
[5878., 8914., 9391., 6304.],
[4089., 6177., 6477., 4333.],
],
]));
]]));
}

#[test]
fn test_conv2d_groups() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 0,
padding_2: 0,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 2,
height: 5,
width: 5,
};

test.assert_output(TestTensor::from_floats([[
[[312., 348., 384.], [492., 528., 564.], [672., 708., 744.]],
[
[3724., 3841., 3958.],
[4309., 4426., 4543.],
[4894., 5011., 5128.],
],
]]));
}
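The same bookkeeping applies in 2-D: test_conv2d_groups uses a weight of shape `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2]`, so with `groups: 2` output channel 0 depends only on input channel 0. A quick hand-check of the first expected element (a sketch assuming the arange-initialised tensors built by Conv2dTestCase):

```rust
fn main() {
    // weight[0][0]: arange(0..9) as a 3x3 kernel; x[0][0]: the top-left 3x3
    // patch of arange(0..25) reshaped to 5x5; bias[0] = 0, padding 0, stride 1.
    let w: [[f32; 3]; 3] = [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]];
    let x: [[f32; 3]; 3] = [[0., 1., 2.], [5., 6., 7.], [10., 11., 12.]];
    let mut y = 0.0f32;
    for i in 0..3 {
        for j in 0..3 {
            y += w[i][j] * x[i][j];
        }
    }
    // Matches the first element of the expected output for output channel 0.
    assert_eq!(y, 312.0);
}
```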

#[test]
@ -92,22 +83,23 @@ mod tests {
stride_2: 3,
dilation_1: 1,
dilation_2: 2,
groups: 1,
height: 4,
width: 5,
};

test.assert_output(TestTensor::from_floats([
[
[[7., 13., 7.], [10., 19., 10.]],
[[7., 13., 7.], [10., 19., 10.]],
[[7., 13., 7.], [10., 19., 10.]],
[[7., 13., 7.], [10., 19., 10.]],
[[1845., 3789., 1926.], [3210., 6465., 3228.]],
[[4276., 9082., 4789.], [8071., 16834., 8737.]],
[[6707., 14375., 7652.], [12932., 27203., 14246.]],
[[9138., 19668., 10515.], [17793., 37572., 19755.]],
],
[
[[7., 13., 7.], [10., 19., 10.]],
[[7., 13., 7.], [10., 19., 10.]],
[[7., 13., 7.], [10., 19., 10.]],
[[7., 13., 7.], [10., 19., 10.]],
[[5445., 10629., 5166.], [8070., 15645., 7548.]],
[[14356., 28882., 14509.], [22651., 45454., 22777.]],
[[23267., 47135., 23852.], [37232., 75263., 38006.]],
[[32178., 65388., 33195.], [51813., 105072., 53235.]],
],
]));
}
@ -124,27 +116,47 @@ mod tests {
stride_2: usize,
dilation_1: usize,
dilation_2: usize,
groups: usize,
height: usize,
width: usize,
}

impl Conv2dTestCase {
fn assert_output(self, y: TestTensor<4>) {
let weights = TestTensor::ones([
let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);
let shape_weight = Shape::new([
self.channels_out,
self.channels_in,
self.channels_in / self.groups,
self.kernel_size_1,
self.kernel_size_2,
]);
let bias = TestTensor::ones([self.channels_out]);
let x = TestTensor::ones([self.batch_size, self.channels_in, self.height, self.width]);
let weight = TestTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements())
.reshape(shape_weight)
.into_data()
.convert(),
);
let bias = TestTensor::from_data(
TestTensorInt::arange(0..self.channels_out)
.into_data()
.convert(),
);
let x = TestTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements())
.reshape(shape_x)
.into_data()
.convert(),
);
let output = conv2d(
x,
weights,
weight,
Some(bias),
[self.stride_1, self.stride_2],
[self.padding_1, self.padding_2],
[self.dilation_1, self.dilation_2],
ConvOptions::new(
[self.stride_1, self.stride_2],
[self.padding_1, self.padding_2],
[self.dilation_1, self.dilation_2],
self.groups,
),
);

y.to_data().assert_approx_eq(&output.into_data(), 3);
@ -2,6 +2,7 @@
mod tests {
use super::*;
use burn_tensor::module::conv_transpose1d;
use burn_tensor::ops::ConvTransposeOptions;
use burn_tensor::{Data, Shape, Tensor};

#[test]
@ -15,6 +16,7 @@ mod tests {
padding_out: 0,
stride: 1,
dilation: 1,
groups: 1,
length: 4,
};

@ -35,6 +37,7 @@ mod tests {
padding_out: 1,
stride: 2,
dilation: 1,
groups: 1,
length: 4,
};

@ -55,6 +58,7 @@ mod tests {
padding_out: 0,
stride: 1,
dilation: 2,
groups: 1,
length: 4,
};
@ -64,6 +68,27 @@ mod tests {
]]));
}

#[test]
fn test_conv_transpose1d_groups() {
let test = ConvTranspose1dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size: 3,
padding: 1,
padding_out: 0,
stride: 1,
dilation: 1,
groups: 2,
length: 4,
};

test.assert_output(TestTensor::from_floats([[
[0., 1., 4., 7.],
[32., 59., 71., 59.],
]]));
}
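For transposed convolution the grouped weight has shape `[channels_in, channels_out / groups, kernel_size]`, and before the padding is cropped each full-output position accumulates x[i] * w[t - i] (stride 1, dilation 1). A minimal sketch (plain Rust, illustrative only) reproducing output channel 0 of the expected tensor:

```rust
fn main() {
    // Input channel 0 and its group-0 kernel from the arange-initialised test
    // tensors above; bias[0] = 0, stride 1, dilation 1, padding 1.
    let x = [0.0f32, 1., 2., 3.];
    let w = [0.0f32, 1., 2.];
    let padding = 1usize;

    // Full transposed convolution: full[i + j] += x[i] * w[j].
    let full_len = x.len() + w.len() - 1;
    let mut full = vec![0.0f32; full_len];
    for (i, xi) in x.iter().enumerate() {
        for (j, wj) in w.iter().enumerate() {
            full[i + j] += xi * wj;
        }
    }

    // Cropping `padding` elements from each side yields the expected channel.
    assert_eq!(&full[padding..full_len - padding], &[0.0f32, 1., 4., 7.]);
}
```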
struct ConvTranspose1dTestCase {
batch_size: usize,
channels_in: usize,
@ -73,13 +98,18 @@ mod tests {
padding_out: usize,
stride: usize,
dilation: usize,
groups: usize,
length: usize,
}

impl ConvTranspose1dTestCase {
fn assert_output(self, y: TestTensor<3>) {
let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]);
let shape_weights = Shape::new([self.channels_in, self.channels_out, self.kernel_size]);
let shape_weights = Shape::new([
self.channels_in,
self.channels_out / self.groups,
self.kernel_size,
]);
let weights = TestTensor::from_data(
TestTensorInt::arange(0..shape_weights.num_elements())
.reshape(shape_weights)
@ -101,10 +131,13 @@ mod tests {
x,
weights,
Some(bias),
self.stride,
self.padding,
self.padding_out,
self.dilation,
ConvTransposeOptions::new(
[self.stride],
[self.padding],
[self.padding_out],
[self.dilation],
self.groups,
),
);

y.to_data().assert_approx_eq(&output.into_data(), 3);
@ -2,6 +2,7 @@
mod tests {
use super::*;
use burn_tensor::module::conv_transpose2d;
use burn_tensor::ops::ConvTransposeOptions;
use burn_tensor::{Data, Shape, Tensor};

#[test]
@ -20,6 +21,7 @@ mod tests {
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 1,
height: 2,
width: 2,
};

@ -42,6 +44,7 @@ mod tests {
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 1,
height: 4,
width: 4,
};

@ -84,6 +87,7 @@ mod tests {
stride_2: 2,
dilation_1: 1,
dilation_2: 1,
groups: 1,
height: 2,
width: 2,
};

@ -112,6 +116,7 @@ mod tests {
stride_2: 1,
dilation_1: 2,
dilation_2: 2,
groups: 1,
height: 2,
width: 2,
};

@ -150,6 +155,7 @@ mod tests {
stride_2: 2,
dilation_1: 1,
dilation_2: 1,
groups: 1,
height: 4,
width: 4,
};
@ -178,6 +184,94 @@ mod tests {
]]));
}

#[test]
fn test_conv_transpose2d_groups_2() {
let test = ConvTranspose2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 1,
padding_2: 1,
padding_out_1: 0,
padding_out_2: 0,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 2,
height: 2,
width: 2,
};

test.assert_output(TestTensor::from_floats([[
[[5., 11.], [23., 29.]],
[[236., 258.], [302., 324.]],
]]));
}
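As in the 1-D case, `groups: 2` here means output channel 0 is produced solely from input channel 0 (hence the small values) and output channel 1 only from input channel 1. A tiny sketch of the channel mapping these grouped tests rely on (illustrative helper, not part of the burn API):

```rust
use std::ops::Range;

// Which input channels feed a given output channel in a grouped
// (transposed) convolution. Illustrative only; assumes channels_in and
// channels_out are both divisible by `groups`.
fn input_channels_for(
    out_channel: usize,
    channels_in: usize,
    channels_out: usize,
    groups: usize,
) -> Range<usize> {
    let group = out_channel / (channels_out / groups);
    let per_group = channels_in / groups;
    group * per_group..(group + 1) * per_group
}

fn main() {
    // channels_in = 2, channels_out = 2, groups = 2, as in the test above.
    assert_eq!(input_channels_for(0, 2, 2, 2), 0..1);
    assert_eq!(input_channels_for(1, 2, 2, 2), 1..2);
    // And for the following test (channels_in 2 -> channels_out 6, groups 2):
    assert_eq!(input_channels_for(4, 2, 6, 2), 1..2);
}
```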

#[test]
fn test_conv_transpose2d_groups_different_channels() {
let test = ConvTranspose2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 6,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 0,
padding_2: 0,
padding_out_1: 0,
padding_out_2: 0,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 2,
height: 2,
width: 2,
};

test.assert_output(TestTensor::from_floats([[
[
[0.0000e+00, 0.0000e+00, 1.0000e+00, 2.0000e+00],
[0.0000e+00, 5.0000e+00, 1.1000e+01, 1.1000e+01],
[6.0000e+00, 2.3000e+01, 2.9000e+01, 2.3000e+01],
[1.2000e+01, 3.2000e+01, 3.7000e+01, 2.4000e+01],
],
[
[1.0000e+00, 1.0000e+01, 1.1000e+01, 1.2000e+01],
[1.9000e+01, 6.0000e+01, 6.6000e+01, 4.8000e+01],
[2.5000e+01, 7.8000e+01, 8.4000e+01, 6.0000e+01],
[3.1000e+01, 7.8000e+01, 8.3000e+01, 5.2000e+01],
],
[
[2.0000e+00, 2.0000e+01, 2.1000e+01, 2.2000e+01],
[3.8000e+01, 1.1500e+02, 1.2100e+02, 8.5000e+01],
[4.4000e+01, 1.3300e+02, 1.3900e+02, 9.7000e+01],
[5.0000e+01, 1.2400e+02, 1.2900e+02, 8.0000e+01],
],
[
[1.1100e+02, 2.5000e+02, 2.5900e+02, 1.4800e+02],
[2.8500e+02, 6.3400e+02, 6.5600e+02, 3.6600e+02],
[3.1500e+02, 7.0000e+02, 7.2200e+02, 4.0200e+02],
[2.0100e+02, 4.3800e+02, 4.5100e+02, 2.4800e+02],
],
[
[1.4800e+02, 3.3200e+02, 3.4100e+02, 1.9400e+02],
[3.7600e+02, 8.3300e+02, 8.5500e+02, 4.7500e+02],
[4.0600e+02, 8.9900e+02, 9.2100e+02, 5.1100e+02],
[2.5600e+02, 5.5600e+02, 5.6900e+02, 3.1200e+02],
],
[
[1.8500e+02, 4.1400e+02, 4.2300e+02, 2.4000e+02],
[4.6700e+02, 1.0320e+03, 1.0540e+03, 5.8400e+02],
[4.9700e+02, 1.0980e+03, 1.1200e+03, 6.2000e+02],
[3.1100e+02, 6.7400e+02, 6.8700e+02, 3.7600e+02],
],
]]));
}
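The 4x4 spatial size of the expected output follows from the usual transposed-convolution size formula, (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + padding_out + 1; a sketch, assuming burn follows this standard definition:

```rust
// Standard transposed-convolution output size; assumed here rather than
// taken from the burn source.
fn conv_transpose_out_size(
    input: usize,
    kernel: usize,
    stride: usize,
    padding: usize,
    padding_out: usize,
    dilation: usize,
) -> usize {
    (input - 1) * stride + dilation * (kernel - 1) + padding_out + 1 - 2 * padding
}

fn main() {
    // height = width = 2, kernel 3x3, stride 1, padding 0, dilation 1 -> 4x4.
    assert_eq!(conv_transpose_out_size(2, 3, 1, 0, 0, 1), 4);
    // With padding 1 (test_conv_transpose2d_groups_2) the size stays 2x2.
    assert_eq!(conv_transpose_out_size(2, 3, 1, 1, 0, 1), 2);
}
```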
struct ConvTranspose2dTestCase {
batch_size: usize,
channels_in: usize,
@ -192,6 +286,7 @@ mod tests {
stride_2: usize,
dilation_1: usize,
dilation_2: usize,
groups: usize,
height: usize,
width: usize,
}

@ -201,7 +296,7 @@ mod tests {
let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);
let shape_weights = Shape::new([
self.channels_in,
self.channels_out,
self.channels_out / self.groups,
self.kernel_size_1,
self.kernel_size_2,
]);
@ -226,10 +321,13 @@ mod tests {
x,
weights,
Some(bias),
[self.stride_1, self.stride_2],
[self.padding_1, self.padding_2],
[self.padding_out_1, self.padding_out_2],
[self.dilation_1, self.dilation_2],
ConvTransposeOptions::new(
[self.stride_1, self.stride_2],
[self.padding_1, self.padding_2],
[self.padding_out_1, self.padding_out_2],
[self.dilation_1, self.dilation_2],
self.groups,
),
);

y.to_data().assert_approx_eq(&output.into_data(), 3);