From cb283a9e5b5aacc62a7c7684c329e270aefaca8d Mon Sep 17 00:00:00 2001
From: Caio Piccirillo <34453935+CaioPiccirillo@users.noreply.github.com>
Date: Wed, 9 Aug 2023 22:13:48 +0200
Subject: [PATCH] Max pool1d (#602)

---
 burn-autodiff/src/ops/module.rs            | 93 ++++++++++++++++++++
 burn-autodiff/src/tests/maxpool1d.rs       | 85 +++++++++++++++++++
 burn-autodiff/src/tests/mod.rs             |  2 +
 burn-core/src/nn/pool/max_pool1d.rs        | 59 +++++++++++++
 burn-core/src/nn/pool/mod.rs               |  2 +
 burn-tch/src/ops/module.rs                 | 39 ++++++++-
 burn-tensor/src/tensor/module.rs           | 28 +++++++
 burn-tensor/src/tensor/ops/modules/base.rs | 63 ++++++++++++++
 burn-tensor/src/tensor/ops/modules/pool.rs | 71 ++++++++++++++++
 burn-tensor/src/tests/mod.rs               |  1 +
 burn-tensor/src/tests/module/maxpool1d.rs  | 98 ++++++++++++++++++++++
 burn-tensor/src/tests/module/mod.rs        |  1 +
 12 files changed, 541 insertions(+), 1 deletion(-)
 create mode 100644 burn-autodiff/src/tests/maxpool1d.rs
 create mode 100644 burn-core/src/nn/pool/max_pool1d.rs
 create mode 100644 burn-tensor/src/tests/module/maxpool1d.rs

diff --git a/burn-autodiff/src/ops/module.rs b/burn-autodiff/src/ops/module.rs
index e99a4b924..ef458aefc 100644
--- a/burn-autodiff/src/ops/module.rs
+++ b/burn-autodiff/src/ops/module.rs
@@ -564,6 +564,79 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
         panic!("Can't differentiate avg pool 2d backward.");
     }
 
+    fn max_pool1d(
+        x: ADTensor<B, 3>,
+        kernel_size: usize,
+        stride: usize,
+        padding: usize,
+    ) -> ADTensor<B, 3> {
+        match MaxPool1D.prepare([x.node], [x.graph]).statefull() {
+            OpsKind::Tracked(prep) => {
+                let output =
+                    B::max_pool1d_with_indices(x.primitive.clone(), kernel_size, stride, padding);
+                prep.finish(
+                    (x.primitive, output.indices, kernel_size, stride, padding),
+                    output.output,
+                )
+            }
+            OpsKind::UnTracked(prep) => {
+                prep.finish(B::max_pool1d(x.primitive, kernel_size, stride, padding))
+            }
+        }
+    }
+
+    fn max_pool1d_with_indices(
+        x: ADTensor<B, 3>,
+        kernel_size: usize,
+        stride: usize,
+        padding: usize,
+    ) -> MaxPool1dWithIndices<ADBackendDecorator<B>> {
+        match MaxPool1D.prepare([x.node], [x.graph]).statefull() {
+            OpsKind::Tracked(prep) => {
+                let output =
+                    B::max_pool1d_with_indices(x.primitive.clone(), kernel_size, stride, padding);
+
+                let output_tensor = prep.finish(
+                    (
+                        x.primitive,
+                        output.indices.clone(),
+                        kernel_size,
+                        stride,
+                        padding,
+                    ),
+                    output.output,
+                );
+
+                MaxPool1dWithIndices::new(output_tensor, output.indices)
+            }
+            OpsKind::UnTracked(prep) => {
+                let output = B::max_pool1d_with_indices(x.primitive, kernel_size, stride, padding);
+                let output_tensor = prep.finish(output.output);
+
+                MaxPool1dWithIndices::new(output_tensor, output.indices)
+            }
+        }
+    }
+
+    fn max_pool1d_with_indices_backward(
+        x: ADTensor<B, 3>,
+        kernel_size: usize,
+        stride: usize,
+        padding: usize,
+        output_grad: ADTensor<B, 3>,
+        indices: IntTensor<B, 3>,
+    ) -> MaxPool1dBackward<ADBackendDecorator<B>> {
+        let output = B::max_pool1d_with_indices_backward(
+            x.primitive,
+            kernel_size,
+            stride,
+            padding,
+            output_grad.primitive,
+            indices,
+        );
+        MaxPool1dBackward::new(ADTensor::new(output.x_grad))
+    }
+
     fn max_pool2d(
         x: ADTensor<B, 4>,
         kernel_size: [usize; 2],
@@ -694,6 +767,26 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
     }
 }
 
+#[derive(Debug)]
+struct MaxPool1D;
+
+impl<B: Backend> Backward<B, 3, 1> for MaxPool1D {
+    type State = (B::TensorPrimitive<3>, IntTensor<B, 3>, usize, usize, usize);
+
+    fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
+        let [node_parent] = ops.parents;
+        let grad = grads.consume::<B>(&ops.node);
+        let (x, indices, kernel_size, stride, padding) = ops.state;
+
+        if let Some(node) = node_parent {
+            let grad =
+                B::max_pool1d_with_indices_backward(x, kernel_size, stride, padding, grad, indices);
+
+            grads.register::<B>(node, grad.x_grad);
+        }
+    }
+}
+
 #[derive(Debug)]
 struct MaxPool2D;
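Note: in the tracked branch above, max_pool1d deliberately calls max_pool1d_with_indices, because the backward pass needs the argmax indices to route gradients back to the winning input positions. A minimal sketch of exercising this path from user code follows; it assumes burn-ndarray as the inner backend and the example function is hypothetical, but max_pool1d, ADBackendDecorator, require_grad, backward, and grad are the same APIs the tests in this patch use.

use burn_autodiff::ADBackendDecorator;
use burn_ndarray::NdArrayBackend;
use burn_tensor::{module::max_pool1d, Tensor};

type B = ADBackendDecorator<NdArrayBackend<f32>>;

fn example() {
    // Shape [batch_size = 1, channels = 1, length = 4].
    let x = Tensor::<B, 3>::from_floats([[[0.2, 0.9, 0.1, 0.4]]]).require_grad();

    // kernel_size = 2, stride = 1, padding = 0 -> 3 output positions.
    let output = max_pool1d(x.clone(), 2, 1, 0);
    let grads = output.backward();

    // Expected gradient [[[0., 2., 0., 1.]]]: 0.9 wins two windows, 0.4 wins the last one.
    let x_grad = x.grad(&grads).unwrap();
    println!("{:?}", x_grad.to_data());
}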
diff --git a/burn-autodiff/src/tests/maxpool1d.rs b/burn-autodiff/src/tests/maxpool1d.rs
new file mode 100644
index 000000000..456ff8302
--- /dev/null
+++ b/burn-autodiff/src/tests/maxpool1d.rs
@@ -0,0 +1,85 @@
+#[burn_tensor_testgen::testgen(ad_max_pool1d)]
+mod tests {
+    use super::*;
+    use burn_tensor::{module::max_pool1d, Data};
+
+    #[test]
+    fn test_max_pool1d_simple() {
+        let batch_size = 1;
+        let channels_in = 1;
+        let kernel_size = 4;
+        let padding = 0;
+        let stride = 1;
+
+        let x = TestADTensor::from_floats([[[0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221]]])
+            .require_grad();
+        let x_grad_expected = TestADTensor::from_floats([[[1., 1., 0., 0., 0., 1.]]]);
+
+        let output = max_pool1d(x.clone(), kernel_size, stride, padding);
+        let grads = output.backward();
+
+        // Asserts
+        let x_grad_actual = x.grad(&grads).unwrap();
+        x_grad_expected
+            .to_data()
+            .assert_approx_eq(&x_grad_actual.to_data(), 3);
+    }
+
+    #[test]
+    fn test_max_pool1d_complex() {
+        let batch_size = 1;
+        let channels_in = 1;
+        let kernel_size = 4;
+        let padding = 0;
+        let stride = 1;
+
+        let x = TestADTensor::from_floats([[[
+            0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
+            0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230,
+            0.4610, 0.5365, 0.6880,
+        ]]])
+        .require_grad();
+        let x_grad_expected = TestADTensor::from_floats([[[
+            0., 0., 0., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0.,
+            1., 1., 1.,
+        ]]]);
+
+        let output = max_pool1d(x.clone(), kernel_size, stride, padding);
+        let grads = output.backward();
+
+        // Asserts
+        let x_grad_actual = x.grad(&grads).unwrap();
+        x_grad_expected
+            .to_data()
+            .assert_approx_eq(&x_grad_actual.to_data(), 3);
+    }
+
+    #[test]
+    fn test_max_pool1d_complex_with_padding() {
+        let batch_size = 1;
+        let channels_in = 1;
+        let kernel_size = 4;
+        let padding = 2;
+        let stride = 1;
+
+        let x = TestADTensor::from_floats([[[
+            0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
+            0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230,
+            0.4610, 0.5365, 0.6880,
+        ]]])
+        .require_grad();
+        let x_grad_expected = TestADTensor::from_floats([[[
+            1., 0., 1., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0.,
+            1., 1., 3.,
+        ]]]);
+
+        let output = max_pool1d(x.clone(), kernel_size, stride, padding);
+        let grads = output.backward();
+
+        // Asserts
+        let x_grad_actual = x.grad(&grads).unwrap();
+        x_grad_expected
+            .to_data()
+            .assert_approx_eq(&x_grad_actual.to_data(), 3);
+    }
+}
diff --git a/burn-autodiff/src/tests/mod.rs b/burn-autodiff/src/tests/mod.rs
index 9c0f3f4c5..9409ca017 100644
--- a/burn-autodiff/src/tests/mod.rs
+++ b/burn-autodiff/src/tests/mod.rs
@@ -27,6 +27,7 @@ mod log1p;
 mod mask;
 mod matmul;
 mod maxmin;
+mod maxpool1d;
 mod maxpool2d;
 mod mul;
 mod multithread;
@@ -61,6 +62,7 @@ macro_rules! testgen_all {
         burn_autodiff::testgen_ad_conv2d!();
         burn_autodiff::testgen_ad_conv_transpose1d!();
         burn_autodiff::testgen_ad_conv_transpose2d!();
+        burn_autodiff::testgen_ad_max_pool1d!();
         burn_autodiff::testgen_ad_max_pool2d!();
         burn_autodiff::testgen_ad_avg_pool1d!();
         burn_autodiff::testgen_ad_avg_pool2d!();
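Note: the expected gradients in test_max_pool1d_complex above can be read as win counts. With kernel_size 4 and stride 1, each input position is covered by up to four overlapping windows, and backward() on the pooled output routes a gradient of one from each window to its argmax, so an element's gradient equals the number of windows it wins. For example, 0.9154 at index 5 is the maximum of all four windows starting at indices 2 through 5, hence the expected value 4. at that position, while elements that never win stay at 0.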
diff --git a/burn-core/src/nn/pool/max_pool1d.rs b/burn-core/src/nn/pool/max_pool1d.rs
new file mode 100644
index 000000000..0178d1017
--- /dev/null
+++ b/burn-core/src/nn/pool/max_pool1d.rs
@@ -0,0 +1,59 @@
+use crate as burn;
+
+use crate::config::Config;
+use crate::module::Module;
+use crate::nn::PaddingConfig1d;
+use crate::tensor::backend::Backend;
+use crate::tensor::Tensor;
+use burn_tensor::module::max_pool1d;
+
+/// Configuration to create a [1D max pooling](MaxPool1d) layer.
+#[derive(Config)]
+pub struct MaxPool1dConfig {
+    /// The number of channels.
+    pub channels: usize,
+    /// The size of the kernel.
+    pub kernel_size: usize,
+    /// The stride.
+    #[config(default = "1")]
+    pub stride: usize,
+    /// The padding configuration.
+    #[config(default = "PaddingConfig1d::Valid")]
+    pub padding: PaddingConfig1d,
+}
+
+/// Applies a 1D max pooling over input tensors.
+#[derive(Module, Debug, Clone)]
+pub struct MaxPool1d {
+    stride: usize,
+    kernel_size: usize,
+    padding: PaddingConfig1d,
+}
+
+impl MaxPool1dConfig {
+    /// Initialize a new [max pool 1d](MaxPool1d) module.
+    pub fn init(&self) -> MaxPool1d {
+        MaxPool1d {
+            stride: self.stride,
+            kernel_size: self.kernel_size,
+            padding: self.padding.clone(),
+        }
+    }
+}
+
+impl MaxPool1d {
+    /// Applies the forward pass on the input tensor.
+    ///
+    /// # Shapes
+    ///
+    /// - input: [batch_size, channels, length_in],
+    /// - output: [batch_size, channels, length_out],
+    pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
+        let [_batch_size, _channels, length] = input.dims();
+        let padding = self
+            .padding
+            .calculate_padding_1d(length, self.kernel_size, self.stride);
+
+        max_pool1d(input, self.kernel_size, self.stride, padding)
+    }
+}
diff --git a/burn-core/src/nn/pool/mod.rs b/burn-core/src/nn/pool/mod.rs
index 0018397d4..622a4b66f 100644
--- a/burn-core/src/nn/pool/mod.rs
+++ b/burn-core/src/nn/pool/mod.rs
@@ -2,10 +2,12 @@ mod adaptive_avg_pool1d;
 mod adaptive_avg_pool2d;
 mod avg_pool1d;
 mod avg_pool2d;
+mod max_pool1d;
 mod max_pool2d;
 
 pub use adaptive_avg_pool1d::*;
 pub use adaptive_avg_pool2d::*;
 pub use avg_pool1d::*;
 pub use avg_pool2d::*;
+pub use max_pool1d::*;
 pub use max_pool2d::*;
diff --git a/burn-tch/src/ops/module.rs b/burn-tch/src/ops/module.rs
index 3e11afa19..4aae0f683 100644
--- a/burn-tch/src/ops/module.rs
+++ b/burn-tch/src/ops/module.rs
@@ -1,6 +1,7 @@
 use crate::{element::TchElement, TchBackend, TchTensor};
 use burn_tensor::ops::{
-    ConvOptions, ConvTransposeOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
+    ConvOptions, ConvTransposeOptions, MaxPool1dWithIndices, MaxPool2dBackward,
+    MaxPool2dWithIndices, ModuleOps,
 };
 
 impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
@@ -163,6 +164,42 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
         TchTensor::new(tensor)
     }
 
+    fn max_pool1d(
+        x: TchTensor<E, 3>,
+        kernel_size: usize,
+        stride: usize,
+        padding: usize,
+    ) -> TchTensor<E, 3> {
+        let tensor = tch::Tensor::max_pool1d(
+            &x.tensor,
+            kernel_size as i64,
+            stride as i64,
+            padding as i64,
+            1,
+            false,
+        );
+
+        TchTensor::new(tensor)
+    }
+
+    fn max_pool1d_with_indices(
+        x: TchTensor<E, 3>,
+        kernel_size: usize,
+        stride: usize,
+        padding: usize,
+    ) -> MaxPool1dWithIndices<TchBackend<E>> {
+        let (tensor, indices) = tch::Tensor::max_pool1d_with_indices(
+            &x.tensor,
+            kernel_size as i64,
+            stride as i64,
+            padding as i64,
+            1,
+            false,
+        );
+
+        MaxPool1dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices))
+    }
+
     fn max_pool2d(
         x: TchTensor<E, 4>,
         kernel_size: [usize; 2],
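Note: in the tch calls above, the two trailing literal arguments map to dilation (1) and ceil_mode (false) in LibTorch's max_pool1d signature; the burn API does not expose either yet, so they are pinned to the defaults. On the burn-core side, a usage sketch of the new module follows; the with_* builders are the ones the Config derive generates for the defaulted fields, and the concrete numbers are illustrative only.

use burn::nn::pool::{MaxPool1d, MaxPool1dConfig};
use burn::nn::PaddingConfig1d;
use burn::tensor::{backend::Backend, Tensor};

fn pool_layer<B: Backend>(x: Tensor<B, 3>) -> Tensor<B, 3> {
    // channels = 8, kernel_size = 3; stride defaults to 1 and padding to Valid.
    let pool: MaxPool1d = MaxPool1dConfig::new(8, 3)
        .with_stride(2)
        .with_padding(PaddingConfig1d::Explicit(1))
        .init();

    // [batch_size, channels, length_in] -> [batch_size, channels, length_out]
    pool.forward(x)
}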
diff --git a/burn-tensor/src/tensor/module.rs b/burn-tensor/src/tensor/module.rs
index b6464fc7e..b5ed70283 100644
--- a/burn-tensor/src/tensor/module.rs
+++ b/burn-tensor/src/tensor/module.rs
@@ -84,6 +84,19 @@ where
     ))
 }
 
+/// Applies a [1D max pooling](crate::ops::ModuleOps::max_pool1d).
+pub fn max_pool1d<B>(
+    x: Tensor<B, 3>,
+    kernel_size: usize,
+    stride: usize,
+    padding: usize,
+) -> Tensor<B, 3>
+where
+    B: Backend,
+{
+    Tensor::new(B::max_pool1d(x.primitive, kernel_size, stride, padding))
+}
+
 /// Applies a [2D max pooling](crate::ops::ModuleOps::max_pool2d).
 pub fn max_pool2d<B>(
     x: Tensor<B, 4>,
@@ -123,6 +136,21 @@ where
     Tensor::new(B::avg_pool1d(x.primitive, kernel_size, stride, padding))
 }
 
+/// Applies a [1D max pooling with indices](crate::ops::ModuleOps::max_pool1d_with_indices).
+pub fn max_pool1d_with_indices<B>(
+    x: Tensor<B, 3>,
+    kernel_size: usize,
+    stride: usize,
+    padding: usize,
+) -> (Tensor<B, 3>, Tensor<B, 3, Int>)
+where
+    B: Backend,
+{
+    let output = B::max_pool1d_with_indices(x.primitive, kernel_size, stride, padding);
+
+    (Tensor::new(output.output), Tensor::new(output.indices))
+}
+
 /// Applies a [2D max pooling with indices](crate::ops::ModuleOps::max_pool2d_with_indices).
 pub fn max_pool2d_with_indices<B>(
     x: Tensor<B, 4>,
diff --git a/burn-tensor/src/tensor/ops/modules/base.rs b/burn-tensor/src/tensor/ops/modules/base.rs
index 0b52671eb..5a1148386 100644
--- a/burn-tensor/src/tensor/ops/modules/base.rs
+++ b/burn-tensor/src/tensor/ops/modules/base.rs
@@ -14,6 +14,23 @@ pub struct Conv2dBackward<B: Backend> {
     pub bias_grad: Option<B::TensorPrimitive<1>>,
 }
 
+/// Gradient computed during the backward pass for each tensor used by [max_pool1d](ModuleOps::max_pool1d).
+#[derive(new)]
+pub struct MaxPool1dBackward<B: Backend> {
+    /// Gradient.
+    pub x_grad: B::TensorPrimitive<3>,
+}
+
+/// Results from [max_pool1d](ModuleOps::max_pool1d_with_indices).
+#[derive(new)]
+pub struct MaxPool1dWithIndices<B: Backend> {
+    /// The output tensor.
+    pub output: B::TensorPrimitive<3>,
+
+    /// The indices tensor.
+    pub indices: B::IntTensorPrimitive<3>,
+}
+
 /// Gradient computed during the backward pass for each tensor used by [max_pool2d](ModuleOps::max_pool2d).
 #[derive(new)]
 pub struct MaxPool2dBackward<B: Backend> {
@@ -299,6 +316,52 @@ pub trait ModuleOps<B: Backend> {
     ) -> B::TensorPrimitive<3> {
         pool::adaptive_avg_pool1d_backward_from_2d::<B>(x, grad)
     }
+    /// One dimensional max pooling.
+    ///
+    /// # Shapes
+    ///
+    /// x: [batch_size, channels, length],
+    fn max_pool1d(
+        x: B::TensorPrimitive<3>,
+        kernel_size: usize,
+        stride: usize,
+        padding: usize,
+    ) -> B::TensorPrimitive<3> {
+        pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding)
+    }
+
+    /// One dimensional max pooling with indices.
+    ///
+    /// # Shapes
+    ///
+    /// x: [batch_size, channels, length],
+    fn max_pool1d_with_indices(
+        x: B::TensorPrimitive<3>,
+        kernel_size: usize,
+        stride: usize,
+        padding: usize,
+    ) -> MaxPool1dWithIndices<B> {
+        pool::max_pool1d_with_indices_from_2d::<B>(x, kernel_size, stride, padding)
+    }
+    /// Backward pass for the [max pooling 1d](ModuleOps::max_pool1d_with_indices) operation.
+    fn max_pool1d_with_indices_backward(
+        x: B::TensorPrimitive<3>,
+        kernel_size: usize,
+        stride: usize,
+        padding: usize,
+        output_grad: B::TensorPrimitive<3>,
+        indices: B::IntTensorPrimitive<3>,
+    ) -> MaxPool1dBackward<B> {
+        pool::max_pool1d_with_indices_backward_from_2d::<B>(
+            x,
+            kernel_size,
+            stride,
+            padding,
+            output_grad,
+            indices,
+        )
+    }
+
     /// Two dimensional max pooling.
     ///
     /// # Shapes
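Note: the trait methods above ship default implementations, so a backend only has to provide the 2d ops to get max_pool1d for free; the pool.rs helpers below implement the fallback by viewing the [batch_size, channels, length] input as a 4d tensor with one unit dimension. The output length follows the usual pooling formula:

    length_out = floor((length_in + 2 * padding - kernel_size) / stride) + 1

e.g. for the padded test cases in this patch, length_in = 5, padding = 2, kernel_size = 4, stride = 1 gives (5 + 4 - 4) / 1 + 1 = 6 output positions.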
diff --git a/burn-tensor/src/tensor/ops/modules/pool.rs b/burn-tensor/src/tensor/ops/modules/pool.rs
index fb76f9309..f7486221c 100644
--- a/burn-tensor/src/tensor/ops/modules/pool.rs
+++ b/burn-tensor/src/tensor/ops/modules/pool.rs
@@ -1,5 +1,7 @@
 use crate::{backend::Backend, Shape};
 
+use super::{MaxPool1dBackward, MaxPool1dWithIndices};
+
 pub(crate) fn avg_pool1d_from_2d<B: Backend>(
     x: B::TensorPrimitive<3>,
     kernel_size: usize,
@@ -62,3 +64,72 @@ pub(crate) fn adaptive_avg_pool1d_backward_from_2d<B: Backend>(
 
     B::reshape(grad_x, Shape::from([batch_size, channels, length_in]))
 }
+
+pub(crate) fn max_pool1d_from_2d<B: Backend>(
+    x: B::TensorPrimitive<3>,
+    kernel_size: usize,
+    stride: usize,
+    padding: usize,
+) -> B::TensorPrimitive<3> {
+    let [batch_size, channels, length] = B::shape(&x).dims;
+
+    let x = B::reshape(x, Shape::from([batch_size, channels, length, 1]));
+    let x = B::max_pool2d(x, [kernel_size, 1], [stride, 1], [padding, 0]);
+
+    let [batch_size, channels, length, _] = B::shape(&x).dims;
+
+    B::reshape(x, Shape::from([batch_size, channels, length]))
+}
+
+pub(crate) fn max_pool1d_with_indices_from_2d<B: Backend>(
+    x: B::TensorPrimitive<3>,
+    kernel_size: usize,
+    stride: usize,
+    padding: usize,
+) -> MaxPool1dWithIndices<B> {
+    let [batch_size, channels, length] = B::shape(&x).dims;
+
+    let x = B::reshape(x, Shape::from([batch_size, channels, 1, length]));
+    let x = B::max_pool2d_with_indices(x, [1, kernel_size], [1, stride], [0, padding]);
+    let [batch_size, channels, _, length] = B::shape(&x.output).dims;
+    let output = B::reshape(x.output, Shape::from([batch_size, channels, length]));
+    let indices = B::int_reshape(
+        x.indices.clone(),
+        Shape::from([batch_size, channels, length]),
+    );
+    MaxPool1dWithIndices::new(output, indices)
+}
+
+pub(crate) fn max_pool1d_with_indices_backward_from_2d<B: Backend>(
+    x: B::TensorPrimitive<3>,
+    kernel_size: usize,
+    stride: usize,
+    padding: usize,
+    output_grad: B::TensorPrimitive<3>,
+    indices: B::IntTensorPrimitive<3>,
+) -> MaxPool1dBackward<B> {
+    let [batch_size, channels, length_in] = B::shape(&x).dims;
+    let [_, _, length_out] = B::shape(&output_grad).dims;
+
+    let x = B::reshape(x, Shape::from([batch_size, channels, length_in, 1]));
+    let grad_x = B::reshape(
+        output_grad,
+        Shape::from([batch_size, channels, length_out, 1]),
+    );
+    let indices = B::int_reshape(indices, Shape::from([batch_size, channels, length_out, 1]));
+
+    let grad_x = B::max_pool2d_with_indices_backward(
+        x,
+        [kernel_size, 1],
+        [stride, 1],
+        [padding, 0],
+        grad_x,
+        indices,
+    )
+    .x_grad;
+
+    MaxPool1dBackward::new(B::reshape(
+        grad_x,
+        Shape::from([batch_size, channels, length_in]),
+    ))
+}
diff --git a/burn-tensor/src/tests/mod.rs b/burn-tensor/src/tests/mod.rs
index f2e29f389..a1e0d352f 100644
--- a/burn-tensor/src/tests/mod.rs
+++ b/burn-tensor/src/tests/mod.rs
@@ -20,6 +20,7 @@ macro_rules! testgen_all {
         burn_tensor::testgen_module_conv2d!();
         burn_tensor::testgen_module_conv_transpose1d!();
         burn_tensor::testgen_module_conv_transpose2d!();
+        burn_tensor::testgen_module_max_pool1d!();
         burn_tensor::testgen_module_max_pool2d!();
         burn_tensor::testgen_module_avg_pool1d!();
         burn_tensor::testgen_module_avg_pool2d!();
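Note: the indices returned by the with-indices path are positions along the unpadded input length dimension; padded positions are filled with negative infinity for max pooling, so they are never selected. In test_max_pool1d_with_indices below, the first window covers [pad, 0.2479] and its maximum sits at input index 0, giving indices [0, 1, 1, 3, 3] across the five output positions, as the expected data in the test confirms.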
diff --git a/burn-tensor/src/tests/module/maxpool1d.rs b/burn-tensor/src/tests/module/maxpool1d.rs
new file mode 100644
index 000000000..3eed04421
--- /dev/null
+++ b/burn-tensor/src/tests/module/maxpool1d.rs
@@ -0,0 +1,98 @@
+#[burn_tensor_testgen::testgen(module_max_pool1d)]
+mod tests {
+    use super::*;
+    use burn_tensor::module::{max_pool1d, max_pool1d_with_indices};
+    use burn_tensor::{backend::Backend, Data, Tensor};
+
+    type IntElem = <TestBackend as Backend>::IntElem;
+
+    #[test]
+    fn test_max_pool1d_simple() {
+        let batch_size = 2;
+        let channels_in = 2;
+        let kernel_size = 3;
+        let padding = 1;
+        let stride = 1;
+
+        let x = TestTensor::from_floats([[
+            [0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221],
+            [0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689],
+        ]]);
+        let y = TestTensor::from_floats([[
+            [0.9861, 0.9861, 0.5474, 0.4477, 0.8221, 0.8221],
+            [0.8148, 0.9490, 0.9490, 0.9490, 0.7890, 0.5689],
+        ]]);
+
+        let output = max_pool1d(x, kernel_size, stride, padding);
+
+        y.to_data().assert_approx_eq(&output.into_data(), 3);
+    }
+
+    #[test]
+    fn test_max_pool1d_different_padding_stride_kernel() {
+        let batch_size = 1;
+        let channels_in = 1;
+        let kernel_size = 3;
+        let padding = 1;
+        let stride = 2;
+
+        let x = TestTensor::from_floats([[[0.6309, 0.6112, 0.6998, 0.4708]]]);
+        let y = TestTensor::from_floats([[[0.6309, 0.6998]]]);
+
+        let output = max_pool1d(x, kernel_size, stride, padding);
+
+        y.to_data().assert_approx_eq(&output.into_data(), 3);
+    }
+
+    #[test]
+    fn test_max_pool1d_with_neg() {
+        let batch_size = 1;
+        let channels_in = 1;
+        let kernel_size = 3;
+        let padding = 1;
+        let stride = 1;
+
+        let x = TestTensor::from_floats([[[-0.6309, -0.6112, -0.6998, -0.4708]]]);
+        let y = TestTensor::from_floats([[[-0.6112, -0.6112, -0.4708, -0.4708]]]);
+
+        let output = max_pool1d(x, kernel_size, stride, padding);
+
+        y.to_data().assert_approx_eq(&output.into_data(), 3);
+    }
+
+    #[test]
+    fn test_max_pool1d_with_indices() {
+        let batch_size = 1;
+        let channels_in = 1;
+        let kernel_size = 2;
+        let padding = 1;
+        let stride = 1;
+
+        let x = TestTensor::from_floats([[[0.2479, 0.6386, 0.3166, 0.5742]]]);
+        let indices = Data::<IntElem, 3>::from([[[0, 1, 1, 3, 3]]]);
+        let y = TestTensor::from_floats([[[0.2479, 0.6386, 0.6386, 0.5742, 0.5742]]]);
+
+        let (output, output_indices) = max_pool1d_with_indices(x, kernel_size, stride, padding);
+
+        y.to_data().assert_approx_eq(&output.into_data(), 3);
+        assert_eq!(indices.value, output_indices.into_data().value);
+    }
+
+    #[test]
+    fn test_max_pool1d_complex() {
+        let batch_size = 1;
+        let channels_in = 1;
+        let kernel_size = 4;
+        let padding = 2;
+        let stride = 1;
+
+        let x = TestTensor::from_floats([[[0.5388, 0.0676, 0.7122, 0.8316, 0.0653]]]);
+        let indices = Data::<IntElem, 3>::from([[[0, 2, 3, 3, 3, 3]]]);
+        let y = TestTensor::from_floats([[[0.5388, 0.7122, 0.8316, 0.8316, 0.8316, 0.8316]]]);
+
+        let (output, output_indices) = max_pool1d_with_indices(x, kernel_size, stride, padding);
+
+        y.to_data().assert_approx_eq(&output.into_data(), 3);
+        assert_eq!(indices.value, output_indices.into_data().value);
+    }
+}
diff --git a/burn-tensor/src/tests/module/mod.rs b/burn-tensor/src/tests/module/mod.rs
index 3ac680303..13aa4564e 100644
--- a/burn-tensor/src/tests/module/mod.rs
+++ b/burn-tensor/src/tests/module/mod.rs @@ -7,4 +7,5 @@ mod conv2d; mod conv_transpose1d; mod conv_transpose2d; mod forward; +mod maxpool1d; mod maxpool2d;
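A closing sketch of the functional with-indices API added in burn-tensor, as exercised by the tests above; the helper name is hypothetical, and the TestTensor/TestBackend aliases from the test harness are replaced by a generic backend parameter.

use burn_tensor::backend::Backend;
use burn_tensor::module::max_pool1d_with_indices;
use burn_tensor::{Int, Tensor};

// Returns the pooled values together with the argmax position of each window.
fn pool_with_argmax<B: Backend>(x: Tensor<B, 3>) -> (Tensor<B, 3>, Tensor<B, 3, Int>) {
    // kernel_size = 2, stride = 1, padding = 1, as in test_max_pool1d_with_indices.
    max_pool1d_with_indices(x, 2, 1, 1)
}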