Max pool1d (#602)

Caio Piccirillo, 2023-08-09 22:13:48 +02:00, committed by GitHub
parent 1f01fcb640
commit cb283a9e5b
12 changed files with 541 additions and 1 deletion

@@ -564,6 +564,79 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
panic!("Can't differentiate avg pool 2d backward.");
}
fn max_pool1d(
x: ADTensor<B, 3>,
kernel_size: usize,
stride: usize,
padding: usize,
) -> ADTensor<B, 3> {
match MaxPool1D.prepare([x.node], [x.graph]).statefull() {
OpsKind::Tracked(prep) => {
let output =
B::max_pool1d_with_indices(x.primitive.clone(), kernel_size, stride, padding);
prep.finish(
(x.primitive, output.indices, kernel_size, stride, padding),
output.output,
)
}
OpsKind::UnTracked(prep) => {
prep.finish(B::max_pool1d(x.primitive, kernel_size, stride, padding))
}
}
}
fn max_pool1d_with_indices(
x: ADTensor<B, 3>,
kernel_size: usize,
stride: usize,
padding: usize,
) -> MaxPool1dWithIndices<ADBackendDecorator<B>> {
match MaxPool1D.prepare([x.node], [x.graph]).statefull() {
OpsKind::Tracked(prep) => {
let output =
B::max_pool1d_with_indices(x.primitive.clone(), kernel_size, stride, padding);
let output_tensor = prep.finish(
(
x.primitive,
output.indices.clone(),
kernel_size,
stride,
padding,
),
output.output,
);
MaxPool1dWithIndices::new(output_tensor, output.indices)
}
OpsKind::UnTracked(prep) => {
let output = B::max_pool1d_with_indices(x.primitive, kernel_size, stride, padding);
let output_tensor = prep.finish(output.output);
MaxPool1dWithIndices::new(output_tensor, output.indices)
}
}
}
fn max_pool1d_with_indices_backward(
x: ADTensor<B, 3>,
kernel_size: usize,
stride: usize,
padding: usize,
output_grad: ADTensor<B, 3>,
indices: IntTensor<B, 3>,
) -> MaxPool1dBackward<ADBackendDecorator<B>> {
let output = B::max_pool1d_with_indices_backward(
x.primitive,
kernel_size,
stride,
padding,
output_grad.primitive,
indices,
);
MaxPool1dBackward::new(ADTensor::new(output.x_grad))
}
fn max_pool2d(
x: ADTensor<B, 4>,
kernel_size: [usize; 2],
@@ -694,6 +767,26 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
}
}
#[derive(Debug)]
struct MaxPool1D;
impl<B: Backend> Backward<B, 3, 1> for MaxPool1D {
type State = (B::TensorPrimitive<3>, IntTensor<B, 3>, usize, usize, usize);
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
let [node_parent] = ops.parents;
let grad = grads.consume::<B, 3>(&ops.node);
let (x, indices, kernel_size, stride, padding) = ops.state;
if let Some(node) = node_parent {
let grad =
B::max_pool1d_with_indices_backward(x, kernel_size, stride, padding, grad, indices);
grads.register::<B, 3>(node, grad.x_grad);
}
}
}
#[derive(Debug)]
struct MaxPool2D;
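The Tracked branch saves the input, the argmax indices, and the pooling parameters so that MaxPool1D::backward never recomputes an argmax: it only routes each output gradient back through the saved indices. A minimal sketch of that scatter step, using a hypothetical plain-Rust helper rather than the burn API:

fn scatter_max_grad(input_len: usize, output_grad: &[f32], indices: &[usize]) -> Vec<f32> {
    let mut x_grad = vec![0.0; input_len];
    for (&g, &i) in output_grad.iter().zip(indices) {
        // A position that wins several pooling windows accumulates
        // one gradient contribution per window.
        x_grad[i] += g;
    }
    x_grad
}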

@@ -0,0 +1,85 @@
#[burn_tensor_testgen::testgen(ad_max_pool1d)]
mod tests {
use super::*;
use burn_tensor::{module::max_pool1d, Data};
#[test]
fn test_max_pool1d_simple() {
let batch_size = 1;
let channels_in = 1;
let kernel_size = 4;
let padding = 0;
let stride = 1;
let x = TestADTensor::from_floats([[[0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221]]])
.require_grad();
let x_grad_expected = TestADTensor::from_floats([[[1., 1., 0., 0., 0., 1.]]]);
let output = max_pool1d(x.clone(), kernel_size, stride, padding);
let grads = output.backward();
// Asserts
let x_grad_actual = x.grad(&grads).unwrap();
x_grad_expected
.to_data()
.assert_approx_eq(&x_grad_actual.to_data(), 3);
}
#[test]
fn test_max_pool1d_complex() {
let batch_size = 1;
let channels_in = 1;
let kernel_size = 4;
let padding = 0;
let stride = 1;
let x = TestADTensor::from_floats([[[
0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230,
0.4610, 0.5365, 0.6880,
]]])
.require_grad();
let x_grad_expected = TestADTensor::from_floats([[[
0., 0., 0., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0.,
1., 1., 1.,
]]]);
let output = max_pool1d(x.clone(), kernel_size, stride, padding);
let grads = output.backward();
// Asserts
let x_grad_actual = x.grad(&grads).unwrap();
x_grad_expected
.to_data()
.assert_approx_eq(&x_grad_actual.to_data(), 3);
}
#[test]
fn test_max_pool1d_complex_with_padding() {
let batch_size = 1;
let channels_in = 1;
let kernel_size = 4;
let padding = 2;
let stride = 1;
let x = TestADTensor::from_floats([[[
0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230,
0.4610, 0.5365, 0.6880,
]]])
.require_grad();
let x_grad_expected = TestADTensor::from_floats([[[
1., 0., 1., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0.,
1., 1., 3.,
]]]);
let output = max_pool1d(x.clone(), kernel_size, stride, padding);
let grads = output.backward();
// Asserts
let x_grad_actual = x.grad(&grads).unwrap();
x_grad_expected
.to_data()
.assert_approx_eq(&x_grad_actual.to_data(), 3);
}
}
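As a sanity check on test_max_pool1d_simple: with kernel_size = 4, stride = 1 and padding = 0 over length 6 there are (6 - 4) / 1 + 1 = 3 pooling windows, whose maxima sit at input indices 0 (0.9861), 1 (0.5474) and 5 (0.8221). Each argmax receives that window's output gradient of 1.0, which reproduces x_grad_expected; a standalone sketch:

fn main() {
    let argmaxes = [0usize, 1, 5]; // per-window winners worked out above
    let mut x_grad = [0.0f32; 6];
    for i in argmaxes {
        x_grad[i] += 1.0;
    }
    assert_eq!(x_grad, [1.0, 1.0, 0.0, 0.0, 0.0, 1.0]); // = x_grad_expected
}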

@@ -27,6 +27,7 @@ mod log1p;
mod mask;
mod matmul;
mod maxmin;
mod maxpool1d;
mod maxpool2d;
mod mul;
mod multithread;
@@ -61,6 +62,7 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_conv2d!();
burn_autodiff::testgen_ad_conv_transpose1d!();
burn_autodiff::testgen_ad_conv_transpose2d!();
burn_autodiff::testgen_ad_max_pool1d!();
burn_autodiff::testgen_ad_max_pool2d!();
burn_autodiff::testgen_ad_avg_pool1d!();
burn_autodiff::testgen_ad_avg_pool2d!();

@@ -0,0 +1,59 @@
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::nn::PaddingConfig1d;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use burn_tensor::module::max_pool1d;
/// Configuration to create a [1D max pooling](MaxPool1d) layer.
#[derive(Config)]
pub struct MaxPool1dConfig {
/// The number of channels.
pub channels: usize,
/// The size of the kernel.
pub kernel_size: usize,
/// The stride.
#[config(default = "1")]
pub stride: usize,
/// The padding configuration.
#[config(default = "PaddingConfig1d::Valid")]
pub padding: PaddingConfig1d,
}
/// Applies a 1D max pooling over input tensors.
#[derive(Module, Debug, Clone)]
pub struct MaxPool1d {
stride: usize,
kernel_size: usize,
padding: PaddingConfig1d,
}
impl MaxPool1dConfig {
/// Initialize a new [max pool 1d](MaxPool1d) module.
pub fn init(&self) -> MaxPool1d {
MaxPool1d {
stride: self.stride,
kernel_size: self.kernel_size,
padding: self.padding.clone(),
}
}
}
impl MaxPool1d {
/// Applies the forward pass on the input tensor.
///
/// # Shapes
///
/// - input: [batch_size, channels, length_in],
/// - output: [batch_size, channels, length_out],
pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
let [_batch_size, _channels, length] = input.dims();
let padding = self
.padding
.calculate_padding_1d(length, self.kernel_size, self.stride);
max_pool1d(input, self.kernel_size, self.stride, padding)
}
}
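A minimal usage sketch of the new module, assuming the new/with_* builder methods that #[derive(Config)] generates, the PaddingConfig1d::Explicit variant from the existing 1D modules, and some Backend implementation in scope:

use burn::nn::pool::MaxPool1dConfig;
use burn::nn::PaddingConfig1d;
use burn::tensor::{backend::Backend, Tensor};

fn forward_example<B: Backend>(input: Tensor<B, 3>) -> Tensor<B, 3> {
    let pool = MaxPool1dConfig::new(/* channels */ 2, /* kernel_size */ 3)
        .with_stride(1)
        .with_padding(PaddingConfig1d::Explicit(1))
        .init();
    // [batch_size, channels, length_in] -> [batch_size, channels, length_out]
    pool.forward(input)
}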

@@ -2,10 +2,12 @@ mod adaptive_avg_pool1d;
mod adaptive_avg_pool2d;
mod avg_pool1d;
mod avg_pool2d;
mod max_pool1d;
mod max_pool2d;
pub use adaptive_avg_pool1d::*;
pub use adaptive_avg_pool2d::*;
pub use avg_pool1d::*;
pub use avg_pool2d::*;
pub use max_pool1d::*;
pub use max_pool2d::*;

@@ -1,6 +1,7 @@
use crate::{element::TchElement, TchBackend, TchTensor};
use burn_tensor::ops::{
ConvOptions, ConvTransposeOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
ConvOptions, ConvTransposeOptions, MaxPool1dWithIndices, MaxPool2dBackward,
MaxPool2dWithIndices, ModuleOps,
};
impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
@@ -163,6 +164,42 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
TchTensor::new(tensor)
}
fn max_pool1d(
x: TchTensor<E, 3>,
kernel_size: usize,
stride: usize,
padding: usize,
) -> TchTensor<E, 3> {
let tensor = tch::Tensor::max_pool1d(
&x.tensor,
kernel_size as i64,
stride as i64,
padding as i64,
1,     // dilation
false, // ceil_mode
);
TchTensor::new(tensor)
}
fn max_pool1d_with_indices(
x: TchTensor<E, 3>,
kernel_size: usize,
stride: usize,
padding: usize,
) -> MaxPool1dWithIndices<TchBackend<E>> {
let (tensor, indices) = tch::Tensor::max_pool1d_with_indices(
&x.tensor,
kernel_size as i64,
stride as i64,
padding as i64,
1,     // dilation
false, // ceil_mode
);
MaxPool1dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices))
}
fn max_pool2d(
x: TchTensor<E, 4>,
kernel_size: [usize; 2],

@@ -84,6 +84,19 @@ where
))
}
/// Applies a [1D max pooling](crate::ops::ModuleOps::max_pool1d).
pub fn max_pool1d<B>(
x: Tensor<B, 3>,
kernel_size: usize,
stride: usize,
padding: usize,
) -> Tensor<B, 3>
where
B: Backend,
{
Tensor::new(B::max_pool1d(x.primitive, kernel_size, stride, padding))
}
/// Applies a [2D max pooling](crate::ops::ModuleOps::max_pool2d).
pub fn max_pool2d<B>(
x: Tensor<B, 4>,
@@ -123,6 +136,21 @@ where
Tensor::new(B::avg_pool1d(x.primitive, kernel_size, stride, padding))
}
/// Applies a [1D max pooling with indices](crate::ops::ModuleOps::max_pool1d_with_indices).
pub fn max_pool1d_with_indices<B>(
x: Tensor<B, 3>,
kernel_size: usize,
stride: usize,
padding: usize,
) -> (Tensor<B, 3>, Tensor<B, 3, Int>)
where
B: Backend,
{
let output = B::max_pool1d_with_indices(x.primitive, kernel_size, stride, padding);
(Tensor::new(output.output), Tensor::new(output.indices))
}
/// Applies a [2D max pooling with indices](crate::ops::ModuleOps::max_pool2d_with_indices).
pub fn max_pool2d_with_indices<B>(
x: Tensor<B, 4>,
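The tensor-level entry points are thin wrappers over the backend ops; a short usage sketch (any Backend works, the names below mirror this diff):

use burn_tensor::module::{max_pool1d, max_pool1d_with_indices};
use burn_tensor::{backend::Backend, Int, Tensor};

fn pool_both<B: Backend>(x: Tensor<B, 3>) -> (Tensor<B, 3>, Tensor<B, 3, Int>) {
    // Arguments after the tensor are kernel_size, stride, padding.
    let _pooled: Tensor<B, 3> = max_pool1d(x.clone(), 3, 1, 1);
    max_pool1d_with_indices(x, 3, 1, 1)
}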

@@ -14,6 +14,23 @@ pub struct Conv2dBackward<B: Backend> {
pub bias_grad: Option<B::TensorPrimitive<1>>,
}
/// Gradient computed during the backward pass for each tensor used by [max_pool1d](ModuleOps::max_pool1d).
#[derive(new)]
pub struct MaxPool1dBackward<B: Backend> {
/// Gradient.
pub x_grad: B::TensorPrimitive<3>,
}
/// Results from [max_pool1d_with_indices](ModuleOps::max_pool1d_with_indices).
#[derive(new)]
pub struct MaxPool1dWithIndices<B: Backend> {
/// The output tensor.
pub output: B::TensorPrimitive<3>,
/// The indices tensor.
pub indices: B::IntTensorPrimitive<3>,
}
/// Gradient computed during the backward pass for each tensor used by [max_pool2d](ModuleOps::max_pool2d).
#[derive(new)]
pub struct MaxPool2dBackward<B: Backend> {
@@ -299,6 +316,52 @@ pub trait ModuleOps<B: Backend> {
) -> B::TensorPrimitive<3> {
pool::adaptive_avg_pool1d_backward_from_2d::<B>(x, grad)
}
/// One dimensional max pooling.
///
/// # Shapes
///
/// x: [batch_size, channels, length],
fn max_pool1d(
x: B::TensorPrimitive<3>,
kernel_size: usize,
stride: usize,
padding: usize,
) -> B::TensorPrimitive<3> {
pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding)
}
/// One dimensional max pooling with indices.
///
/// # Shapes
///
/// x: [batch_size, channels, length],
fn max_pool1d_with_indices(
x: B::TensorPrimitive<3>,
kernel_size: usize,
stride: usize,
padding: usize,
) -> MaxPool1dWithIndices<B> {
pool::max_pool1d_with_indices_from_2d::<B>(x, kernel_size, stride, padding)
}
/// Backward pass for the [max pooling 1d](ModuleOps::max_pool1d_with_indices) operation.
fn max_pool1d_with_indices_backward(
x: B::TensorPrimitive<3>,
kernel_size: usize,
stride: usize,
padding: usize,
output_grad: B::TensorPrimitive<3>,
indices: B::IntTensorPrimitive<3>,
) -> MaxPool1dBackward<B> {
pool::max_pool1d_with_indices_backward_from_2d::<B>(
x,
kernel_size,
stride,
padding,
output_grad,
indices,
)
}
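// Note: the three defaults above all inherit the usual pooling arithmetic,
// length_out = (length_in + 2 * padding - kernel_size) / stride + 1 (integer
// division). For example, length_in = 6, kernel_size = 4, stride = 1 and
// padding = 0 give length_out = 3, which is the window count the tests in
// this commit rely on.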
/// Two dimensional max pooling.
///
/// # Shapes

@@ -1,5 +1,7 @@
use crate::{backend::Backend, Shape};
use super::{MaxPool1dBackward, MaxPool1dWithIndices};
pub(crate) fn avg_pool1d_from_2d<B: Backend>(
x: B::TensorPrimitive<3>,
kernel_size: usize,
@@ -62,3 +64,72 @@ pub(crate) fn adaptive_avg_pool1d_backward_from_2d<B: Backend>(
B::reshape(grad_x, Shape::from([batch_size, channels, length_in]))
}
pub(crate) fn max_pool1d_from_2d<B: Backend>(
x: B::TensorPrimitive<3>,
kernel_size: usize,
stride: usize,
padding: usize,
) -> B::TensorPrimitive<3> {
let [batch_size, channels, length] = B::shape(&x).dims;
let x = B::reshape(x, Shape::from([batch_size, channels, length, 1]));
let x = B::max_pool2d(x, [kernel_size, 1], [stride, 1], [padding, 0]);
let [batch_size, channels, length, _] = B::shape(&x).dims;
B::reshape(x, Shape::from([batch_size, channels, length]))
}
pub(crate) fn max_pool1d_with_indices_from_2d<B: Backend>(
x: B::TensorPrimitive<3>,
kernel_size: usize,
stride: usize,
padding: usize,
) -> MaxPool1dWithIndices<B> {
let [batch_size, channels, length] = B::shape(&x).dims;
let x = B::reshape(x, Shape::from([batch_size, channels, 1, length]));
let x = B::max_pool2d_with_indices(x, [1, kernel_size], [1, stride], [0, padding]);
let [batch_size, channels, _, length] = B::shape(&x.output).dims;
let output = B::reshape(x.output, Shape::from([batch_size, channels, length]));
let indices = B::int_reshape(
x.indices,
Shape::from([batch_size, channels, length]),
);
MaxPool1dWithIndices::new(output, indices)
}
pub(crate) fn max_pool1d_with_indices_backward_from_2d<B: Backend>(
x: B::TensorPrimitive<3>,
kernel_size: usize,
stride: usize,
padding: usize,
output_grad: B::TensorPrimitive<3>,
indices: B::IntTensorPrimitive<3>,
) -> MaxPool1dBackward<B> {
let [batch_size, channels, length_in] = B::shape(&x).dims;
let [_, _, length_out] = B::shape(&output_grad).dims;
let x = B::reshape(x, Shape::from([batch_size, channels, length_in, 1]));
let grad_x = B::reshape(
output_grad,
Shape::from([batch_size, channels, length_out, 1]),
);
let indices = B::int_reshape(indices, Shape::from([batch_size, channels, length_out, 1]));
let grad_x = B::max_pool2d_with_indices_backward(
x,
[kernel_size, 1],
[stride, 1],
[padding, 0],
grad_x,
indices,
)
.x_grad;
MaxPool1dBackward::new(B::reshape(
grad_x,
Shape::from([batch_size, channels, length_in]),
))
}
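These helpers are what make the ModuleOps defaults work: any backend that provides max_pool2d gets max_pool1d for free by unsqueezing a singleton spatial dimension and squeezing it back afterwards. Note that the with-indices forward unsqueezes to [batch_size, channels, 1, length] while the backward uses [batch_size, channels, length_in, 1]; the flattened argmax indices still agree because the extra axis has size 1. For reference, a plain-Rust sketch of the 1D semantics the delegation must reproduce (not burn code; assumes implicit negative-infinity padding):

fn max_pool1d_reference(x: &[f32], kernel_size: usize, stride: usize, padding: usize) -> Vec<f32> {
    let length_out = (x.len() + 2 * padding - kernel_size) / stride + 1;
    (0..length_out)
        .map(|o| {
            // Window start in padded coordinates; out-of-range slots act as -inf.
            let start = (o * stride) as isize - padding as isize;
            (start..start + kernel_size as isize)
                .filter(|&j| j >= 0 && (j as usize) < x.len())
                .map(|j| x[j as usize])
                .fold(f32::NEG_INFINITY, f32::max)
        })
        .collect()
}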

@@ -20,6 +20,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_module_conv2d!();
burn_tensor::testgen_module_conv_transpose1d!();
burn_tensor::testgen_module_conv_transpose2d!();
burn_tensor::testgen_module_max_pool1d!();
burn_tensor::testgen_module_max_pool2d!();
burn_tensor::testgen_module_avg_pool1d!();
burn_tensor::testgen_module_avg_pool2d!();

@@ -0,0 +1,98 @@
#[burn_tensor_testgen::testgen(module_max_pool1d)]
mod tests {
use super::*;
use burn_tensor::module::{max_pool1d, max_pool1d_with_indices};
use burn_tensor::{backend::Backend, Data, Tensor};
type IntElem = <TestBackend as Backend>::IntElem;
#[test]
fn test_max_pool1d_simple() {
let batch_size = 2;
let channels_in = 2;
let kernel_size = 3;
let padding = 1;
let stride = 1;
let x = TestTensor::from_floats([[
[0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221],
[0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689],
]]);
let y = TestTensor::from_floats([[
[0.9861, 0.9861, 0.5474, 0.4477, 0.8221, 0.8221],
[0.8148, 0.9490, 0.9490, 0.9490, 0.7890, 0.5689],
]]);
let output = max_pool1d(x, kernel_size, stride, padding);
y.to_data().assert_approx_eq(&output.into_data(), 3);
}
#[test]
fn test_max_pool1d_different_padding_stride_kernel() {
let batch_size = 1;
let channels_in = 1;
let kernel_size = 3;
let padding = 1;
let stride = 2;
let x = TestTensor::from_floats([[[0.6309, 0.6112, 0.6998, 0.4708]]]);
let y = TestTensor::from_floats([[[0.6309, 0.6998]]]);
let output = max_pool1d(x, kernel_size, stride, padding);
y.to_data().assert_approx_eq(&output.into_data(), 3);
}
#[test]
fn test_max_pool1d_with_neg() {
let batch_size = 1;
let channels_in = 1;
let kernel_size = 3;
let padding = 1;
let stride = 1;
let x = TestTensor::from_floats([[[-0.6309, -0.6112, -0.6998, -0.4708]]]);
let y = TestTensor::from_floats([[[-0.6112, -0.6112, -0.4708, -0.4708]]]);
let output = max_pool1d(x, kernel_size, stride, padding);
y.to_data().assert_approx_eq(&output.into_data(), 3);
}
#[test]
fn test_max_pool1d_with_indices() {
let batch_size = 1;
let channels_in = 1;
let kernel_size = 2;
let padding = 1;
let stride = 1;
let x = TestTensor::from_floats([[[0.2479, 0.6386, 0.3166, 0.5742]]]);
let indices = Data::<IntElem, 3>::from([[[0, 1, 1, 3, 3]]]);
let y = TestTensor::from_floats([[[0.2479, 0.6386, 0.6386, 0.5742, 0.5742]]]);
let (output, output_indices) = max_pool1d_with_indices(x, kernel_size, stride, padding);
y.to_data().assert_approx_eq(&output.into_data(), 3);
assert_eq!(indices.value, output_indices.into_data().value);
}
#[test]
fn test_max_pool1d_complex() {
let batch_size = 1;
let channels_in = 1;
let kernel_size = 4;
let padding = 2;
let stride = 1;
let x = TestTensor::from_floats([[[0.5388, 0.0676, 0.7122, 0.8316, 0.0653]]]);
let indices = Data::<IntElem, 3>::from([[[0, 2, 3, 3, 3, 3]]]);
let y = TestTensor::from_floats([[[0.5388, 0.7122, 0.8316, 0.8316, 0.8316, 0.8316]]]);
let (output, output_indices) = max_pool1d_with_indices(x, kernel_size, stride, padding);
y.to_data().assert_approx_eq(&output.into_data(), 3);
assert_eq!(indices.value, output_indices.into_data().value);
}
}
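To see where the expected indices in test_max_pool1d_with_indices come from: kernel_size = 2, stride = 1 and padding = 1 over length 4 yield length_out = (4 + 2 - 2) / 1 + 1 = 5, and with the input padded by negative infinity the per-window argmaxes (in unpadded coordinates) are exactly [0, 1, 1, 3, 3]. A standalone check:

fn main() {
    let x = [0.2479f32, 0.6386, 0.3166, 0.5742];
    let (kernel_size, stride, padding) = (2usize, 1usize, 1usize);
    let length_out = (x.len() + 2 * padding - kernel_size) / stride + 1; // = 5
    let mut indices = Vec::new();
    for o in 0..length_out {
        // Window start in padded coordinates; out-of-range slots act as -inf.
        let start = (o * stride) as isize - padding as isize;
        let (mut best, mut best_i) = (f32::NEG_INFINITY, 0usize);
        for j in start..start + kernel_size as isize {
            if j >= 0 && (j as usize) < x.len() && x[j as usize] > best {
                best = x[j as usize];
                best_i = j as usize;
            }
        }
        indices.push(best_i);
    }
    assert_eq!(indices, vec![0, 1, 1, 3, 3]); // matches the test's expected Data
}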

@@ -7,4 +7,5 @@ mod conv2d;
mod conv_transpose1d;
mod conv_transpose2d;
mod forward;
mod maxpool1d;
mod maxpool2d;