Feat/avg pool1d (#349)

Nathaniel Simard 2023-05-15 08:29:45 -04:00 committed by GitHub
parent 43154ce7e0
commit 976102fec0
15 changed files with 343 additions and 12 deletions

View File

@@ -270,6 +270,40 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
) -> ADTensor<B, 3> {
todo!("Transposed 1D convolution doesn't yet support backward.");
}
fn avg_pool1d(
x: ADTensor<B, 3>,
kernel_size: usize,
stride: usize,
padding: usize,
) -> ADTensor<B, 3> {
#[derive(Debug)]
struct AvgPool1D;
impl<B: Backend> Backward<B, 3, 1> for AvgPool1D {
type State = (B::TensorPrimitive<3>, usize, usize, usize);
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
let [node_parent] = ops.parents;
let grad = grads.consume::<B, 3>(&ops.node);
let (x, kernel_size, stride, padding) = ops.state;
if let Some(node) = node_parent {
let grad = B::avg_pool1d_backward(x, grad, kernel_size, stride, padding);
grads.register::<B, 3>(node, grad);
}
}
}
match AvgPool1D.prepare([x.node], [x.graph]).statefull() {
OpsKind::Tracked(prep) => {
let output = B::avg_pool1d(x.primitive.clone(), kernel_size, stride, padding);
prep.finish((x.primitive, kernel_size, stride, padding), output)
}
OpsKind::UnTracked(prep) => {
prep.finish(B::avg_pool1d(x.primitive, kernel_size, stride, padding))
}
}
}
fn avg_pool2d(
x: ADTensor<B, 4>,

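For context on what `avg_pool1d_backward` has to produce here: each output element is the mean of `kernel_size` inputs, so during the backward pass every input position receives 1/kernel_size of the gradient from each window that covers it. Ignoring border subtleties (the backend's count_include_pad behavior), the rule is, as a sketch:

\frac{\partial L}{\partial x_i} = \frac{1}{k} \sum_{j \,:\, js - p \,\le\, i \,<\, js - p + k} \frac{\partial L}{\partial y_j}, \qquad k = \text{kernel\_size},\; s = \text{stride},\; p = \text{padding}.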
View File

@@ -0,0 +1,68 @@
#[burn_tensor_testgen::testgen(ad_avg_pool1d)]
mod tests {
use super::*;
use burn_tensor::module::avg_pool1d;
use burn_tensor::{Data, Shape, Tensor};
#[test]
fn test_avg_pool1d_simple() {
let test = AvgPool1dTestCase {
batch_size: 1,
channels: 1,
kernel_size: 3,
padding: 0,
stride: 1,
length: 6,
};
test.assert_output(TestTensor::from_floats([[[
0.3333, 0.6667, 1.0000, 1.0000, 0.6667, 0.3333,
]]]));
}
#[test]
fn test_avg_pool1d_complex() {
let test = AvgPool1dTestCase {
batch_size: 1,
channels: 2,
kernel_size: 3,
padding: 1,
stride: 2,
length: 6,
};
test.assert_output(TestTensor::from_floats([[
[0.3333, 0.6667, 0.3333, 0.6667, 0.3333, 0.3333],
[0.3333, 0.6667, 0.3333, 0.6667, 0.3333, 0.3333],
]]));
}
struct AvgPool1dTestCase {
batch_size: usize,
channels: usize,
kernel_size: usize,
padding: usize,
stride: usize,
length: usize,
}
impl AvgPool1dTestCase {
fn assert_output(self, x_grad: TestTensor<3>) {
let shape_x = Shape::new([self.batch_size, self.channels, self.length]);
let x = TestADTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements())
.reshape(shape_x)
.into_data()
.convert(),
)
.require_grad();
let output = avg_pool1d(x.clone(), self.kernel_size, self.stride, self.padding);
let grads = output.backward();
let x_grad_actual = x.grad(&grads).unwrap();
x_grad
.to_data()
.assert_approx_eq(&x_grad_actual.into_data(), 3);
}
}
}
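To see where the expected gradient in `test_avg_pool1d_simple` comes from: with kernel_size 3, stride 1, and no padding, the 6-element input yields 4 windows, `backward()` seeds the output gradient with ones, and each input position collects 1/3 per covering window:

grad_x[i] = (number of windows covering i) / kernel_size
counts    = [1, 2, 3, 3, 2, 1], kernel_size = 3
grad_x    = [0.3333, 0.6667, 1.0000, 1.0000, 0.6667, 0.3333]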

View File

@@ -1,5 +1,6 @@
mod add;
mod aggregation;
mod avgpool1d;
mod avgpool2d;
mod backward;
mod broadcast;
@@ -52,6 +53,7 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_conv1d!();
burn_autodiff::testgen_ad_conv2d!();
burn_autodiff::testgen_ad_max_pool2d!();
burn_autodiff::testgen_ad_avg_pool1d!();
burn_autodiff::testgen_ad_avg_pool2d!();
burn_autodiff::testgen_module_backward!();

View File

@@ -125,16 +125,10 @@ impl<B: Backend> Conv1d<B> {
/// - input: [batch_size, channels_in, length_in],
/// - output: [batch_size, channels_out, length_out],
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
let same_padding = || {
let [_batch_size, _channels_in, length] = input.dims();
calculate_conv_padding(self.kernel_size, self.stride, length, length)
};
let padding = match &self.padding {
Conv1dPaddingConfig::Valid => 0,
Conv1dPaddingConfig::Same => same_padding(),
Conv1dPaddingConfig::Explicit(value) => *value,
};
let [_batch_size, _channels, length] = input.dims();
let padding = self
.padding
.calculate_padding_1d(length, self.kernel_size, self.stride);
conv1d(
input,
@@ -145,6 +139,23 @@ impl<B: Backend> Conv1d<B> {
}
}
impl Conv1dPaddingConfig {
pub(crate) fn calculate_padding_1d(
&self,
length: usize,
kernel_size: usize,
stride: usize,
) -> usize {
let same_padding = || calculate_conv_padding(kernel_size, stride, length, length);
match self {
Conv1dPaddingConfig::Valid => 0,
Conv1dPaddingConfig::Same => same_padding(),
Conv1dPaddingConfig::Explicit(value) => *value,
}
}
}
#[cfg(test)]
mod tests {
use burn_tensor::Data;

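This refactor moves the `Same` padding computation into `Conv1dPaddingConfig::calculate_padding_1d` so the new pooling module can reuse it. As a rough sketch of the arithmetic behind `calculate_conv_padding` (the helper's body is not shown in this diff, so this standalone version is an assumption): solve `length_out = (length_in + 2p - kernel_size) / stride + 1` for `p` with `length_out == length_in`.

// Hypothetical re-derivation, not burn's actual calculate_conv_padding.
fn same_padding_1d(kernel_size: usize, stride: usize, length: usize) -> usize {
    // Total padding required so the output keeps the input length.
    let needed = (length - 1) * stride + kernel_size;
    let total = needed.saturating_sub(length);
    (total + 1) / 2 // per side, rounding up when the total is odd
}
// e.g. same_padding_1d(3, 1, 6) == 1: kernel 3, stride 1 => pad 1 each side.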
View File

@@ -0,0 +1,62 @@
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::nn::conv::Conv1dPaddingConfig;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use burn_tensor::module::avg_pool1d;
/// Configuration to create a [1D avg pooling](AvgPool1d) layer.
#[derive(Config)]
pub struct AvgPool1dConfig {
/// The number of channels.
pub channels: usize,
/// The size of the kernel.
pub kernel_size: usize,
/// The stride.
#[config(default = "1")]
pub stride: usize,
/// The padding configuration.
#[config(default = "AvgPool1dPaddingConfig::Valid")]
pub padding: AvgPool1dPaddingConfig,
}
/// Padding configuration for 1D avg pooling [config](AvgPool1dConfig).
pub type AvgPool1dPaddingConfig = Conv1dPaddingConfig;
/// Applies a 1D avg pooling over input tensors.
#[derive(Module, Debug, Clone)]
pub struct AvgPool1d {
stride: usize,
kernel_size: usize,
padding: AvgPool1dPaddingConfig,
}
impl AvgPool1dConfig {
/// Initialize a new [avg pool 1d](AvgPool1d) module.
pub fn init(&self) -> AvgPool1d {
AvgPool1d {
stride: self.stride,
kernel_size: self.kernel_size,
padding: self.padding.clone(),
}
}
}
impl AvgPool1d {
/// Applies the forward pass on the input tensor.
///
/// # Shapes
///
/// - input: [batch_size, channels, length_in],
/// - output: [batch_size, channels, length_out],
pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
let [_batch_size, _channels, length] = input.dims();
let padding = self
.padding
.calculate_padding_1d(length, self.kernel_size, self.stride);
avg_pool1d(input, self.kernel_size, self.stride, padding)
}
}
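A minimal usage sketch for the new module, assuming the usual burn `Config` derive conventions (`new` takes the required `channels` and `kernel_size` fields in order; `with_*` setters cover the defaulted ones), given some `input: Tensor<B, 3>` of shape [batch_size, channels, length]:

// Sketch only; builder names follow the Config derive conventions.
let config = AvgPool1dConfig::new(2, 3)
    .with_stride(2)
    .with_padding(AvgPool1dPaddingConfig::Same);
let pool = config.init();
// input: [batch_size, channels=2, length] -> output: [batch_size, 2, length_out]
let output = pool.forward(input);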

View File

@@ -7,7 +7,7 @@ use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use burn_tensor::module::avg_pool2d;
/// Configuration to create an [2D avg pooling](AvgPool2d) layer.
/// Configuration to create a [2D avg pooling](AvgPool2d) layer.
#[derive(Config)]
pub struct AvgPool2dConfig {
/// The number of channels.

View File

@@ -1,5 +1,7 @@
mod avg_pool1d;
mod avg_pool2d;
mod max_pool2d;
pub use avg_pool1d::*;
pub use avg_pool2d::*;
pub use max_pool2d::*;

View File

@@ -106,6 +106,23 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
TchTensor::new(tensor)
}
fn avg_pool1d(
x: TchTensor<E, 3>,
kernel_size: usize,
stride: usize,
padding: usize,
) -> TchTensor<E, 3> {
let tensor = tch::Tensor::avg_pool1d(
&x.tensor,
[kernel_size as i64],
[stride as i64],
[padding as i64],
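        // Trailing flags: ceil_mode = false and count_include_pad = true,
        // matching PyTorch's avg_pool1d defaults.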
false,
true,
);
TchTensor::new(tensor)
}
fn avg_pool2d(
x: TchTensor<E, 4>,
kernel_size: [usize; 2],

View File

@@ -110,6 +110,19 @@ where
Tensor::new(B::avg_pool2d(x.primitive, kernel_size, stride, padding))
}
/// Applies a [1D avg pooling](crate::ops::ModuleOps::avg_pool1d).
pub fn avg_pool1d<B>(
x: Tensor<B, 3>,
kernel_size: usize,
stride: usize,
padding: usize,
) -> Tensor<B, 3>
where
B: Backend,
{
Tensor::new(B::avg_pool1d(x.primitive, kernel_size, stride, padding))
}
/// Applies a [2D max pooling with indexes](crate::ops::ModuleOps::max_pool2d_with_indexes).
pub fn max_pool2d_with_indexes<B>(
x: Tensor<B, 4>,

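A quick sketch of calling the new functional API (the values match the module tests later in this diff):

use burn_tensor::module::avg_pool1d;

// x: [batch_size, channels, length] = [1, 1, 6]
let x = Tensor::<B, 3>::from_floats([[[0., 1., 2., 3., 4., 5.]]]);
let y = avg_pool1d(x, 3, 1, 0); // kernel_size = 3, stride = 1, padding = 0
// y: [1, 1, 4], values [[[1.0, 2.0, 3.0, 4.0]]]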
View File

@@ -1,4 +1,4 @@
use super::conv;
use super::{conv, pool};
use crate::backend::Backend;
/// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d).
@@ -136,6 +136,29 @@ pub trait ModuleOps<B: Backend> {
) -> Conv1dBackward<B> {
conv::conv1d_backward(x, weight, bias, output_grad, options)
}
/// One dimensional avg pooling.
///
/// # Shapes
///
/// x: [batch_size, channels, length],
fn avg_pool1d(
x: B::TensorPrimitive<3>,
kernel_size: usize,
stride: usize,
padding: usize,
) -> B::TensorPrimitive<3> {
pool::avg_pool1d_from_avg_pool2d::<B>(x, kernel_size, stride, padding)
}
/// Backward pass for the [avg pooling 1d](ModuleOps::avg_pool1d) operation.
fn avg_pool1d_backward(
x: B::TensorPrimitive<3>,
grad: B::TensorPrimitive<3>,
kernel_size: usize,
stride: usize,
padding: usize,
) -> B::TensorPrimitive<3> {
pool::avg_pool1d_backward_from_avg_pool2d::<B>(x, grad, kernel_size, stride, padding)
}
/// Two dimensional avg pooling.
///
/// # Shapes

View File

@@ -1,4 +1,5 @@
pub mod conv;
pub mod pool;
mod base;

View File

@@ -0,0 +1,35 @@
use crate::{backend::Backend, Shape};
pub(crate) fn avg_pool1d_from_avg_pool2d<B: Backend>(
x: B::TensorPrimitive<3>,
kernel_size: usize,
stride: usize,
padding: usize,
) -> B::TensorPrimitive<3> {
let [batch_size, channels, length] = B::shape(&x).dims;
let x = B::reshape(x, Shape::from([batch_size, channels, length, 1]));
let x = B::avg_pool2d(x, [kernel_size, 1], [stride, 1], [padding, 0]);
let [batch_size, channels, length, _] = B::shape(&x).dims;
B::reshape(x, Shape::from([batch_size, channels, length]))
}
pub(crate) fn avg_pool1d_backward_from_avg_pool2d<B: Backend>(
x: B::TensorPrimitive<3>,
grad: B::TensorPrimitive<3>,
kernel_size: usize,
stride: usize,
padding: usize,
) -> B::TensorPrimitive<3> {
let [batch_size, channels, length_in] = B::shape(&x).dims;
let [_, _, length_out] = B::shape(&grad).dims;
let x = B::reshape(x, Shape::from([batch_size, channels, length_in, 1]));
let grad_x = B::reshape(grad, Shape::from([batch_size, channels, length_out, 1]));
let grad_x = B::avg_pool2d_backward(x, grad_x, [kernel_size, 1], [stride, 1], [padding, 0]);
B::reshape(grad_x, Shape::from([batch_size, channels, length_in]))
}
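Thanks to this unit-width reshape, backends only have to implement `avg_pool2d` to get `avg_pool1d` for free. A naive plain-Rust reference (an illustrative sketch, not part of the crate) shows the 1D semantics the trick must reproduce, including count_include_pad-style borders where padded taps contribute zero but still divide by `kernel_size`:

// Naive reference implementation of avg_pool1d over a flat slice.
fn avg_pool1d_ref(x: &[f32], kernel_size: usize, stride: usize, padding: usize) -> Vec<f32> {
    let length_out = (x.len() + 2 * padding - kernel_size) / stride + 1;
    (0..length_out)
        .map(|j| {
            // First tap of window j, which may start inside the left padding.
            let start = (j * stride) as isize - padding as isize;
            let sum: f32 = (0..kernel_size as isize)
                .map(|k| start + k)
                .filter(|&i| i >= 0 && (i as usize) < x.len())
                .map(|i| x[i as usize])
                .sum();
            sum / kernel_size as f32 // padded taps still count in the divisor
        })
        .collect()
}
// Matches the complex test case below:
// avg_pool1d_ref(&[0., 1., 2., 3., 4., 5.], 3, 2, 1) == [0.3333, 2.0, 4.0]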

View File

@@ -19,6 +19,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_module_conv_transpose1d!();
burn_tensor::testgen_module_conv_transpose2d!();
burn_tensor::testgen_module_max_pool2d!();
burn_tensor::testgen_module_avg_pool1d!();
burn_tensor::testgen_module_avg_pool2d!();
// test ops

View File

@@ -0,0 +1,61 @@
#[burn_tensor_testgen::testgen(module_avg_pool1d)]
mod tests {
use super::*;
use burn_tensor::module::avg_pool1d;
use burn_tensor::{Data, Shape, Tensor};
#[test]
fn test_avg_pool1d_simple() {
let test = AvgPool1dTestCase {
batch_size: 1,
channels: 1,
kernel_size: 3,
padding: 0,
stride: 1,
length: 6,
};
test.assert_output(TestTensor::from_floats([[[1., 2., 3., 4.]]]));
}
#[test]
fn test_avg_pool1d_complex() {
let test = AvgPool1dTestCase {
batch_size: 1,
channels: 2,
kernel_size: 3,
padding: 1,
stride: 2,
length: 6,
};
test.assert_output(TestTensor::from_floats([[
[0.3333, 2.0000, 4.0000],
[4.3333, 8.0000, 10.0000],
]]));
}
struct AvgPool1dTestCase {
batch_size: usize,
channels: usize,
kernel_size: usize,
padding: usize,
stride: usize,
length: usize,
}
impl AvgPool1dTestCase {
fn assert_output(self, y: TestTensor<3>) {
let shape_x = Shape::new([self.batch_size, self.channels, self.length]);
let x = TestTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements())
.reshape(shape_x)
.into_data()
.convert(),
);
let output = avg_pool1d(x, self.kernel_size, self.stride, self.padding);
y.to_data().assert_approx_eq(&output.into_data(), 3);
}
}
}
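The expected values in `test_avg_pool1d_simple` follow directly from the definition: the input is `arange(0..6)` and each output is the mean of a 3-wide window:

y[j] = (x[j] + x[j+1] + x[j+2]) / 3
y    = [(0+1+2)/3, (1+2+3)/3, (2+3+4)/3, (3+4+5)/3] = [1., 2., 3., 4.]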

View File

@@ -1,3 +1,4 @@
mod avgpool1d;
mod avgpool2d;
mod conv1d;
mod conv2d;