mirror of https://github.com/tracel-ai/burn.git
Max pool1d (#602)
This commit is contained in:
parent
1f01fcb640
commit
cb283a9e5b
|
@ -564,6 +564,79 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
||||||
panic!("Can't differentiate avg pool 2d backward.");
|
panic!("Can't differentiate avg pool 2d backward.");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn max_pool1d(
|
||||||
|
x: ADTensor<B, 3>,
|
||||||
|
kernel_size: usize,
|
||||||
|
stride: usize,
|
||||||
|
padding: usize,
|
||||||
|
) -> ADTensor<B, 3> {
|
||||||
|
match MaxPool1D.prepare([x.node], [x.graph]).statefull() {
|
||||||
|
OpsKind::Tracked(prep) => {
|
||||||
|
let output =
|
||||||
|
B::max_pool1d_with_indices(x.primitive.clone(), kernel_size, stride, padding);
|
||||||
|
prep.finish(
|
||||||
|
(x.primitive, output.indices, kernel_size, stride, padding),
|
||||||
|
output.output,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
OpsKind::UnTracked(prep) => {
|
||||||
|
prep.finish(B::max_pool1d(x.primitive, kernel_size, stride, padding))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn max_pool1d_with_indices(
|
||||||
|
x: ADTensor<B, 3>,
|
||||||
|
kernel_size: usize,
|
||||||
|
stride: usize,
|
||||||
|
padding: usize,
|
||||||
|
) -> MaxPool1dWithIndices<ADBackendDecorator<B>> {
|
||||||
|
match MaxPool1D.prepare([x.node], [x.graph]).statefull() {
|
||||||
|
OpsKind::Tracked(prep) => {
|
||||||
|
let output =
|
||||||
|
B::max_pool1d_with_indices(x.primitive.clone(), kernel_size, stride, padding);
|
||||||
|
|
||||||
|
let output_tensor = prep.finish(
|
||||||
|
(
|
||||||
|
x.primitive,
|
||||||
|
output.indices.clone(),
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
),
|
||||||
|
output.output,
|
||||||
|
);
|
||||||
|
|
||||||
|
MaxPool1dWithIndices::new(output_tensor, output.indices)
|
||||||
|
}
|
||||||
|
OpsKind::UnTracked(prep) => {
|
||||||
|
let output = B::max_pool1d_with_indices(x.primitive, kernel_size, stride, padding);
|
||||||
|
let output_tensor = prep.finish(output.output);
|
||||||
|
|
||||||
|
MaxPool1dWithIndices::new(output_tensor, output.indices)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn max_pool1d_with_indices_backward(
|
||||||
|
x: ADTensor<B, 3>,
|
||||||
|
kernel_size: usize,
|
||||||
|
stride: usize,
|
||||||
|
padding: usize,
|
||||||
|
output_grad: ADTensor<B, 3>,
|
||||||
|
indices: IntTensor<B, 3>,
|
||||||
|
) -> MaxPool1dBackward<ADBackendDecorator<B>> {
|
||||||
|
let output = B::max_pool1d_with_indices_backward(
|
||||||
|
x.primitive,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
output_grad.primitive,
|
||||||
|
indices,
|
||||||
|
);
|
||||||
|
MaxPool1dBackward::new(ADTensor::new(output.x_grad))
|
||||||
|
}
|
||||||
|
|
||||||
fn max_pool2d(
|
fn max_pool2d(
|
||||||
x: ADTensor<B, 4>,
|
x: ADTensor<B, 4>,
|
||||||
kernel_size: [usize; 2],
|
kernel_size: [usize; 2],
|
||||||
|
@ -694,6 +767,26 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct MaxPool1D;
|
||||||
|
|
||||||
|
impl<B: Backend> Backward<B, 3, 1> for MaxPool1D {
|
||||||
|
type State = (B::TensorPrimitive<3>, IntTensor<B, 3>, usize, usize, usize);
|
||||||
|
|
||||||
|
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
|
||||||
|
let [node_parent] = ops.parents;
|
||||||
|
let grad = grads.consume::<B, 3>(&ops.node);
|
||||||
|
let (x, indices, kernel_size, stride, padding) = ops.state;
|
||||||
|
|
||||||
|
if let Some(node) = node_parent {
|
||||||
|
let grad =
|
||||||
|
B::max_pool1d_with_indices_backward(x, kernel_size, stride, padding, grad, indices);
|
||||||
|
|
||||||
|
grads.register::<B, 3>(node, grad.x_grad);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct MaxPool2D;
|
struct MaxPool2D;
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,85 @@
|
||||||
|
#[burn_tensor_testgen::testgen(ad_max_pool1d)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use burn_tensor::{module::max_pool1d, Data};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_max_pool1d_simple() {
|
||||||
|
let batch_size = 1;
|
||||||
|
let channels_in = 1;
|
||||||
|
let kernel_size = 4;
|
||||||
|
let padding = 0;
|
||||||
|
let stride = 1;
|
||||||
|
|
||||||
|
let x = TestADTensor::from_floats([[[0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221]]])
|
||||||
|
.require_grad();
|
||||||
|
let x_grad_expected = TestADTensor::from_floats([[[1., 1., 0., 0., 0., 1.]]]);
|
||||||
|
|
||||||
|
let output = max_pool1d(x.clone(), kernel_size, stride, padding);
|
||||||
|
let grads = output.backward();
|
||||||
|
|
||||||
|
// Asserts
|
||||||
|
let x_grad_actual = x.grad(&grads).unwrap();
|
||||||
|
x_grad_expected
|
||||||
|
.to_data()
|
||||||
|
.assert_approx_eq(&x_grad_actual.to_data(), 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_max_pool1d_complex() {
|
||||||
|
let batch_size = 1;
|
||||||
|
let channels_in = 1;
|
||||||
|
let kernel_size = 4;
|
||||||
|
let padding = 0;
|
||||||
|
let stride = 1;
|
||||||
|
|
||||||
|
let x = TestADTensor::from_floats([[[
|
||||||
|
0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
|
||||||
|
0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230,
|
||||||
|
0.4610, 0.5365, 0.6880,
|
||||||
|
]]])
|
||||||
|
.require_grad();
|
||||||
|
let x_grad_expected = TestADTensor::from_floats([[[
|
||||||
|
0., 0., 0., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0.,
|
||||||
|
1., 1., 1.,
|
||||||
|
]]]);
|
||||||
|
|
||||||
|
let output = max_pool1d(x.clone(), kernel_size, stride, padding);
|
||||||
|
let grads = output.backward();
|
||||||
|
|
||||||
|
// Asserts
|
||||||
|
let x_grad_actual = x.grad(&grads).unwrap();
|
||||||
|
x_grad_expected
|
||||||
|
.to_data()
|
||||||
|
.assert_approx_eq(&x_grad_actual.to_data(), 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_max_pool1d_complex_with_padding() {
|
||||||
|
let batch_size = 1;
|
||||||
|
let channels_in = 1;
|
||||||
|
let kernel_size = 4;
|
||||||
|
let padding = 2;
|
||||||
|
let stride = 1;
|
||||||
|
|
||||||
|
let x = TestADTensor::from_floats([[[
|
||||||
|
0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
|
||||||
|
0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230,
|
||||||
|
0.4610, 0.5365, 0.6880,
|
||||||
|
]]])
|
||||||
|
.require_grad();
|
||||||
|
let x_grad_expected = TestADTensor::from_floats([[[
|
||||||
|
1., 0., 1., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0.,
|
||||||
|
1., 1., 3.,
|
||||||
|
]]]);
|
||||||
|
|
||||||
|
let output = max_pool1d(x.clone(), kernel_size, stride, padding);
|
||||||
|
let grads = output.backward();
|
||||||
|
|
||||||
|
// Asserts
|
||||||
|
let x_grad_actual = x.grad(&grads).unwrap();
|
||||||
|
x_grad_expected
|
||||||
|
.to_data()
|
||||||
|
.assert_approx_eq(&x_grad_actual.to_data(), 3);
|
||||||
|
}
|
||||||
|
}
|
|
@ -27,6 +27,7 @@ mod log1p;
|
||||||
mod mask;
|
mod mask;
|
||||||
mod matmul;
|
mod matmul;
|
||||||
mod maxmin;
|
mod maxmin;
|
||||||
|
mod maxpool1d;
|
||||||
mod maxpool2d;
|
mod maxpool2d;
|
||||||
mod mul;
|
mod mul;
|
||||||
mod multithread;
|
mod multithread;
|
||||||
|
@ -61,6 +62,7 @@ macro_rules! testgen_all {
|
||||||
burn_autodiff::testgen_ad_conv2d!();
|
burn_autodiff::testgen_ad_conv2d!();
|
||||||
burn_autodiff::testgen_ad_conv_transpose1d!();
|
burn_autodiff::testgen_ad_conv_transpose1d!();
|
||||||
burn_autodiff::testgen_ad_conv_transpose2d!();
|
burn_autodiff::testgen_ad_conv_transpose2d!();
|
||||||
|
burn_autodiff::testgen_ad_max_pool1d!();
|
||||||
burn_autodiff::testgen_ad_max_pool2d!();
|
burn_autodiff::testgen_ad_max_pool2d!();
|
||||||
burn_autodiff::testgen_ad_avg_pool1d!();
|
burn_autodiff::testgen_ad_avg_pool1d!();
|
||||||
burn_autodiff::testgen_ad_avg_pool2d!();
|
burn_autodiff::testgen_ad_avg_pool2d!();
|
||||||
|
|
|
@ -0,0 +1,59 @@
|
||||||
|
use crate as burn;
|
||||||
|
|
||||||
|
use crate::config::Config;
|
||||||
|
use crate::module::Module;
|
||||||
|
use crate::nn::PaddingConfig1d;
|
||||||
|
use crate::tensor::backend::Backend;
|
||||||
|
use crate::tensor::Tensor;
|
||||||
|
use burn_tensor::module::max_pool1d;
|
||||||
|
|
||||||
|
/// Configuration to create a [1D max pooling](MaxPool1d) layer.
|
||||||
|
#[derive(Config)]
|
||||||
|
pub struct MaxPool1dConfig {
|
||||||
|
/// The number of channels.
|
||||||
|
pub channels: usize,
|
||||||
|
/// The size of the kernel.
|
||||||
|
pub kernel_size: usize,
|
||||||
|
/// The stride.
|
||||||
|
#[config(default = "1")]
|
||||||
|
pub stride: usize,
|
||||||
|
/// The padding configuration.
|
||||||
|
#[config(default = "PaddingConfig1d::Valid")]
|
||||||
|
pub padding: PaddingConfig1d,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Applies a 1D max pooling over input tensors.
|
||||||
|
#[derive(Module, Debug, Clone)]
|
||||||
|
pub struct MaxPool1d {
|
||||||
|
stride: usize,
|
||||||
|
kernel_size: usize,
|
||||||
|
padding: PaddingConfig1d,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MaxPool1dConfig {
|
||||||
|
/// Initialize a new [max pool 1d](MaxPool1d) module.
|
||||||
|
pub fn init(&self) -> MaxPool1d {
|
||||||
|
MaxPool1d {
|
||||||
|
stride: self.stride,
|
||||||
|
kernel_size: self.kernel_size,
|
||||||
|
padding: self.padding.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MaxPool1d {
|
||||||
|
/// Applies the forward pass on the input tensor.
|
||||||
|
///
|
||||||
|
/// # Shapes
|
||||||
|
///
|
||||||
|
/// - input: [batch_size, channels, length_in],
|
||||||
|
/// - output: [batch_size, channels, length_out],
|
||||||
|
pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
|
||||||
|
let [_batch_size, _channels, length] = input.dims();
|
||||||
|
let padding = self
|
||||||
|
.padding
|
||||||
|
.calculate_padding_1d(length, self.kernel_size, self.stride);
|
||||||
|
|
||||||
|
max_pool1d(input, self.kernel_size, self.stride, padding)
|
||||||
|
}
|
||||||
|
}
|
|
@ -2,10 +2,12 @@ mod adaptive_avg_pool1d;
|
||||||
mod adaptive_avg_pool2d;
|
mod adaptive_avg_pool2d;
|
||||||
mod avg_pool1d;
|
mod avg_pool1d;
|
||||||
mod avg_pool2d;
|
mod avg_pool2d;
|
||||||
|
mod max_pool1d;
|
||||||
mod max_pool2d;
|
mod max_pool2d;
|
||||||
|
|
||||||
pub use adaptive_avg_pool1d::*;
|
pub use adaptive_avg_pool1d::*;
|
||||||
pub use adaptive_avg_pool2d::*;
|
pub use adaptive_avg_pool2d::*;
|
||||||
pub use avg_pool1d::*;
|
pub use avg_pool1d::*;
|
||||||
pub use avg_pool2d::*;
|
pub use avg_pool2d::*;
|
||||||
|
pub use max_pool1d::*;
|
||||||
pub use max_pool2d::*;
|
pub use max_pool2d::*;
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
use crate::{element::TchElement, TchBackend, TchTensor};
|
use crate::{element::TchElement, TchBackend, TchTensor};
|
||||||
use burn_tensor::ops::{
|
use burn_tensor::ops::{
|
||||||
ConvOptions, ConvTransposeOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
|
ConvOptions, ConvTransposeOptions, MaxPool1dWithIndices, MaxPool2dBackward,
|
||||||
|
MaxPool2dWithIndices, ModuleOps,
|
||||||
};
|
};
|
||||||
|
|
||||||
impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
||||||
|
@ -163,6 +164,42 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
||||||
TchTensor::new(tensor)
|
TchTensor::new(tensor)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn max_pool1d(
|
||||||
|
x: TchTensor<E, 3>,
|
||||||
|
kernel_size: usize,
|
||||||
|
stride: usize,
|
||||||
|
padding: usize,
|
||||||
|
) -> TchTensor<E, 3> {
|
||||||
|
let tensor = tch::Tensor::max_pool1d(
|
||||||
|
&x.tensor,
|
||||||
|
kernel_size as i64,
|
||||||
|
stride as i64,
|
||||||
|
padding as i64,
|
||||||
|
1,
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
|
TchTensor::new(tensor)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn max_pool1d_with_indices(
|
||||||
|
x: TchTensor<E, 3>,
|
||||||
|
kernel_size: usize,
|
||||||
|
stride: usize,
|
||||||
|
padding: usize,
|
||||||
|
) -> MaxPool1dWithIndices<TchBackend<E>> {
|
||||||
|
let (tensor, indices) = tch::Tensor::max_pool1d_with_indices(
|
||||||
|
&x.tensor,
|
||||||
|
kernel_size as i64,
|
||||||
|
stride as i64,
|
||||||
|
padding as i64,
|
||||||
|
1,
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
|
MaxPool1dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices))
|
||||||
|
}
|
||||||
|
|
||||||
fn max_pool2d(
|
fn max_pool2d(
|
||||||
x: TchTensor<E, 4>,
|
x: TchTensor<E, 4>,
|
||||||
kernel_size: [usize; 2],
|
kernel_size: [usize; 2],
|
||||||
|
|
|
@ -84,6 +84,19 @@ where
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Applies a [1D max pooling](crate::ops::ModuleOps::max_pool1d).
|
||||||
|
pub fn max_pool1d<B>(
|
||||||
|
x: Tensor<B, 3>,
|
||||||
|
kernel_size: usize,
|
||||||
|
stride: usize,
|
||||||
|
padding: usize,
|
||||||
|
) -> Tensor<B, 3>
|
||||||
|
where
|
||||||
|
B: Backend,
|
||||||
|
{
|
||||||
|
Tensor::new(B::max_pool1d(x.primitive, kernel_size, stride, padding))
|
||||||
|
}
|
||||||
|
|
||||||
/// Applies a [2D max pooling](crate::ops::ModuleOps::max_pool2d).
|
/// Applies a [2D max pooling](crate::ops::ModuleOps::max_pool2d).
|
||||||
pub fn max_pool2d<B>(
|
pub fn max_pool2d<B>(
|
||||||
x: Tensor<B, 4>,
|
x: Tensor<B, 4>,
|
||||||
|
@ -123,6 +136,21 @@ where
|
||||||
Tensor::new(B::avg_pool1d(x.primitive, kernel_size, stride, padding))
|
Tensor::new(B::avg_pool1d(x.primitive, kernel_size, stride, padding))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Applies a [1D max pooling](crate::ops::ModuleOps::max_pool1d).
|
||||||
|
pub fn max_pool1d_with_indices<B>(
|
||||||
|
x: Tensor<B, 3>,
|
||||||
|
kernel_size: usize,
|
||||||
|
stride: usize,
|
||||||
|
padding: usize,
|
||||||
|
) -> (Tensor<B, 3>, Tensor<B, 3, Int>)
|
||||||
|
where
|
||||||
|
B: Backend,
|
||||||
|
{
|
||||||
|
let output = B::max_pool1d_with_indices(x.primitive, kernel_size, stride, padding);
|
||||||
|
|
||||||
|
(Tensor::new(output.output), Tensor::new(output.indices))
|
||||||
|
}
|
||||||
|
|
||||||
/// Applies a [2D max pooling with indices](crate::ops::ModuleOps::max_pool2d_with_indices).
|
/// Applies a [2D max pooling with indices](crate::ops::ModuleOps::max_pool2d_with_indices).
|
||||||
pub fn max_pool2d_with_indices<B>(
|
pub fn max_pool2d_with_indices<B>(
|
||||||
x: Tensor<B, 4>,
|
x: Tensor<B, 4>,
|
||||||
|
|
|
@ -14,6 +14,23 @@ pub struct Conv2dBackward<B: Backend> {
|
||||||
pub bias_grad: Option<B::TensorPrimitive<1>>,
|
pub bias_grad: Option<B::TensorPrimitive<1>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Gradient computed during the backward pass for each tensor used by [max_pool1d](ModuleOps::max_pool1d).
|
||||||
|
#[derive(new)]
|
||||||
|
pub struct MaxPool1dBackward<B: Backend> {
|
||||||
|
/// Gradient.
|
||||||
|
pub x_grad: B::TensorPrimitive<3>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Results from [max_pool1d](ModuleOps::max_pool1d_with_indices).
|
||||||
|
#[derive(new)]
|
||||||
|
pub struct MaxPool1dWithIndices<B: Backend> {
|
||||||
|
/// The output tensor.
|
||||||
|
pub output: B::TensorPrimitive<3>,
|
||||||
|
|
||||||
|
/// The indices tensor.
|
||||||
|
pub indices: B::IntTensorPrimitive<3>,
|
||||||
|
}
|
||||||
|
|
||||||
/// Gradient computed during the backward pass for each tensor used by [max_pool2d](ModuleOps::max_pool2d).
|
/// Gradient computed during the backward pass for each tensor used by [max_pool2d](ModuleOps::max_pool2d).
|
||||||
#[derive(new)]
|
#[derive(new)]
|
||||||
pub struct MaxPool2dBackward<B: Backend> {
|
pub struct MaxPool2dBackward<B: Backend> {
|
||||||
|
@ -299,6 +316,52 @@ pub trait ModuleOps<B: Backend> {
|
||||||
) -> B::TensorPrimitive<3> {
|
) -> B::TensorPrimitive<3> {
|
||||||
pool::adaptive_avg_pool1d_backward_from_2d::<B>(x, grad)
|
pool::adaptive_avg_pool1d_backward_from_2d::<B>(x, grad)
|
||||||
}
|
}
|
||||||
|
/// One dimensional max pooling.
|
||||||
|
///
|
||||||
|
/// # Shapes
|
||||||
|
///
|
||||||
|
/// x: [batch_size, channels, length],
|
||||||
|
fn max_pool1d(
|
||||||
|
x: B::TensorPrimitive<3>,
|
||||||
|
kernel_size: usize,
|
||||||
|
stride: usize,
|
||||||
|
padding: usize,
|
||||||
|
) -> B::TensorPrimitive<3> {
|
||||||
|
pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// One dimensional max pooling with indices.
|
||||||
|
///
|
||||||
|
/// # Shapes
|
||||||
|
///
|
||||||
|
/// x: [batch_size, channels, height, width],
|
||||||
|
fn max_pool1d_with_indices(
|
||||||
|
x: B::TensorPrimitive<3>,
|
||||||
|
kernel_size: usize,
|
||||||
|
stride: usize,
|
||||||
|
padding: usize,
|
||||||
|
) -> MaxPool1dWithIndices<B> {
|
||||||
|
pool::max_pool1d_with_indices_from_2d::<B>(x, kernel_size, stride, padding)
|
||||||
|
}
|
||||||
|
/// Backward pass for the [max pooling 1d](ModuleOps::max_pool1d_with_indices) operation.
|
||||||
|
fn max_pool1d_with_indices_backward(
|
||||||
|
x: B::TensorPrimitive<3>,
|
||||||
|
kernel_size: usize,
|
||||||
|
stride: usize,
|
||||||
|
padding: usize,
|
||||||
|
output_grad: B::TensorPrimitive<3>,
|
||||||
|
indices: B::IntTensorPrimitive<3>,
|
||||||
|
) -> MaxPool1dBackward<B> {
|
||||||
|
pool::max_pool1d_with_indices_backward_from_2d::<B>(
|
||||||
|
x,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
output_grad,
|
||||||
|
indices,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
/// Two dimensional max pooling.
|
/// Two dimensional max pooling.
|
||||||
///
|
///
|
||||||
/// # Shapes
|
/// # Shapes
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
use crate::{backend::Backend, Shape};
|
use crate::{backend::Backend, Shape};
|
||||||
|
|
||||||
|
use super::{MaxPool1dBackward, MaxPool1dWithIndices};
|
||||||
|
|
||||||
pub(crate) fn avg_pool1d_from_2d<B: Backend>(
|
pub(crate) fn avg_pool1d_from_2d<B: Backend>(
|
||||||
x: B::TensorPrimitive<3>,
|
x: B::TensorPrimitive<3>,
|
||||||
kernel_size: usize,
|
kernel_size: usize,
|
||||||
|
@ -62,3 +64,72 @@ pub(crate) fn adaptive_avg_pool1d_backward_from_2d<B: Backend>(
|
||||||
|
|
||||||
B::reshape(grad_x, Shape::from([batch_size, channels, length_in]))
|
B::reshape(grad_x, Shape::from([batch_size, channels, length_in]))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn max_pool1d_from_2d<B: Backend>(
|
||||||
|
x: B::TensorPrimitive<3>,
|
||||||
|
kernel_size: usize,
|
||||||
|
stride: usize,
|
||||||
|
padding: usize,
|
||||||
|
) -> B::TensorPrimitive<3> {
|
||||||
|
let [batch_size, channels, length] = B::shape(&x).dims;
|
||||||
|
|
||||||
|
let x = B::reshape(x, Shape::from([batch_size, channels, length, 1]));
|
||||||
|
let x = B::max_pool2d(x, [kernel_size, 1], [stride, 1], [padding, 0]);
|
||||||
|
|
||||||
|
let [batch_size, channels, length, _] = B::shape(&x).dims;
|
||||||
|
|
||||||
|
B::reshape(x, Shape::from([batch_size, channels, length]))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn max_pool1d_with_indices_from_2d<B: Backend>(
|
||||||
|
x: B::TensorPrimitive<3>,
|
||||||
|
kernel_size: usize,
|
||||||
|
stride: usize,
|
||||||
|
padding: usize,
|
||||||
|
) -> MaxPool1dWithIndices<B> {
|
||||||
|
let [batch_size, channels, length] = B::shape(&x).dims;
|
||||||
|
|
||||||
|
let x = B::reshape(x, Shape::from([batch_size, channels, 1, length]));
|
||||||
|
let x = B::max_pool2d_with_indices(x, [1, kernel_size], [1, stride], [0, padding]);
|
||||||
|
let [batch_size, channels, _, length] = B::shape(&x.output).dims;
|
||||||
|
let output = B::reshape(x.output, Shape::from([batch_size, channels, length]));
|
||||||
|
let indices = B::int_reshape(
|
||||||
|
x.indices.clone(),
|
||||||
|
Shape::from([batch_size, channels, length]),
|
||||||
|
);
|
||||||
|
MaxPool1dWithIndices::new(output, indices)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn max_pool1d_with_indices_backward_from_2d<B: Backend>(
|
||||||
|
x: B::TensorPrimitive<3>,
|
||||||
|
kernel_size: usize,
|
||||||
|
stride: usize,
|
||||||
|
padding: usize,
|
||||||
|
output_grad: B::TensorPrimitive<3>,
|
||||||
|
indices: B::IntTensorPrimitive<3>,
|
||||||
|
) -> MaxPool1dBackward<B> {
|
||||||
|
let [batch_size, channels, length_in] = B::shape(&x).dims;
|
||||||
|
let [_, _, length_out] = B::shape(&output_grad).dims;
|
||||||
|
|
||||||
|
let x = B::reshape(x, Shape::from([batch_size, channels, length_in, 1]));
|
||||||
|
let grad_x = B::reshape(
|
||||||
|
output_grad,
|
||||||
|
Shape::from([batch_size, channels, length_out, 1]),
|
||||||
|
);
|
||||||
|
let indices = B::int_reshape(indices, Shape::from([batch_size, channels, length_out, 1]));
|
||||||
|
|
||||||
|
let grad_x = B::max_pool2d_with_indices_backward(
|
||||||
|
x,
|
||||||
|
[kernel_size, 1],
|
||||||
|
[stride, 1],
|
||||||
|
[padding, 0],
|
||||||
|
grad_x,
|
||||||
|
indices,
|
||||||
|
)
|
||||||
|
.x_grad;
|
||||||
|
|
||||||
|
MaxPool1dBackward::new(B::reshape(
|
||||||
|
grad_x,
|
||||||
|
Shape::from([batch_size, channels, length_in]),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
|
@ -20,6 +20,7 @@ macro_rules! testgen_all {
|
||||||
burn_tensor::testgen_module_conv2d!();
|
burn_tensor::testgen_module_conv2d!();
|
||||||
burn_tensor::testgen_module_conv_transpose1d!();
|
burn_tensor::testgen_module_conv_transpose1d!();
|
||||||
burn_tensor::testgen_module_conv_transpose2d!();
|
burn_tensor::testgen_module_conv_transpose2d!();
|
||||||
|
burn_tensor::testgen_module_max_pool1d!();
|
||||||
burn_tensor::testgen_module_max_pool2d!();
|
burn_tensor::testgen_module_max_pool2d!();
|
||||||
burn_tensor::testgen_module_avg_pool1d!();
|
burn_tensor::testgen_module_avg_pool1d!();
|
||||||
burn_tensor::testgen_module_avg_pool2d!();
|
burn_tensor::testgen_module_avg_pool2d!();
|
||||||
|
|
|
@ -0,0 +1,98 @@
|
||||||
|
#[burn_tensor_testgen::testgen(module_max_pool1d)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use burn_tensor::module::{max_pool1d, max_pool1d_with_indices};
|
||||||
|
use burn_tensor::{backend::Backend, Data, Tensor};
|
||||||
|
|
||||||
|
type IntElem = <TestBackend as Backend>::IntElem;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_max_pool1d_simple() {
|
||||||
|
let batch_size = 2;
|
||||||
|
let channels_in = 2;
|
||||||
|
let kernel_size = 3;
|
||||||
|
let padding = 1;
|
||||||
|
let stride = 1;
|
||||||
|
|
||||||
|
let x = TestTensor::from_floats([[
|
||||||
|
[0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221],
|
||||||
|
[0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689],
|
||||||
|
]]);
|
||||||
|
let y = TestTensor::from_floats([[
|
||||||
|
[0.9861, 0.9861, 0.5474, 0.4477, 0.8221, 0.8221],
|
||||||
|
[0.8148, 0.9490, 0.9490, 0.9490, 0.7890, 0.5689],
|
||||||
|
]]);
|
||||||
|
|
||||||
|
let output = max_pool1d(x, kernel_size, stride, padding);
|
||||||
|
|
||||||
|
y.to_data().assert_approx_eq(&output.into_data(), 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_max_pool1d_different_padding_stride_kernel() {
|
||||||
|
let batch_size = 1;
|
||||||
|
let channels_in = 1;
|
||||||
|
let kernel_size = 3;
|
||||||
|
let padding = 1;
|
||||||
|
let stride = 2;
|
||||||
|
|
||||||
|
let x = TestTensor::from_floats([[[0.6309, 0.6112, 0.6998, 0.4708]]]);
|
||||||
|
let y = TestTensor::from_floats([[[0.6309, 0.6998]]]);
|
||||||
|
|
||||||
|
let output = max_pool1d(x, kernel_size, stride, padding);
|
||||||
|
|
||||||
|
y.to_data().assert_approx_eq(&output.into_data(), 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_max_pool1d_with_neg() {
|
||||||
|
let batch_size = 1;
|
||||||
|
let channels_in = 1;
|
||||||
|
let kernel_size = 3;
|
||||||
|
let padding = 1;
|
||||||
|
let stride = 1;
|
||||||
|
|
||||||
|
let x = TestTensor::from_floats([[[-0.6309, -0.6112, -0.6998, -0.4708]]]);
|
||||||
|
let y = TestTensor::from_floats([[[-0.6112, -0.6112, -0.4708, -0.4708]]]);
|
||||||
|
|
||||||
|
let output = max_pool1d(x, kernel_size, stride, padding);
|
||||||
|
|
||||||
|
y.to_data().assert_approx_eq(&output.into_data(), 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_max_pool1d_with_indices() {
|
||||||
|
let batch_size = 1;
|
||||||
|
let channels_in = 1;
|
||||||
|
let kernel_size = 2;
|
||||||
|
let padding = 1;
|
||||||
|
let stride = 1;
|
||||||
|
|
||||||
|
let x = TestTensor::from_floats([[[0.2479, 0.6386, 0.3166, 0.5742]]]);
|
||||||
|
let indices = Data::<IntElem, 3>::from([[[0, 1, 1, 3, 3]]]);
|
||||||
|
let y = TestTensor::from_floats([[[0.2479, 0.6386, 0.6386, 0.5742, 0.5742]]]);
|
||||||
|
|
||||||
|
let (output, output_indices) = max_pool1d_with_indices(x, kernel_size, stride, padding);
|
||||||
|
|
||||||
|
y.to_data().assert_approx_eq(&output.into_data(), 3);
|
||||||
|
assert_eq!(indices.value, output_indices.into_data().value);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_max_pool1d_complex() {
|
||||||
|
let batch_size = 1;
|
||||||
|
let channels_in = 1;
|
||||||
|
let kernel_size = 4;
|
||||||
|
let padding = 2;
|
||||||
|
let stride = 1;
|
||||||
|
|
||||||
|
let x = TestTensor::from_floats([[[0.5388, 0.0676, 0.7122, 0.8316, 0.0653]]]);
|
||||||
|
let indices = Data::<IntElem, 3>::from([[[0, 2, 3, 3, 3, 3]]]);
|
||||||
|
let y = TestTensor::from_floats([[[0.5388, 0.7122, 0.8316, 0.8316, 0.8316, 0.8316]]]);
|
||||||
|
|
||||||
|
let (output, output_indices) = max_pool1d_with_indices(x, kernel_size, stride, padding);
|
||||||
|
|
||||||
|
y.to_data().assert_approx_eq(&output.into_data(), 3);
|
||||||
|
assert_eq!(indices.value, output_indices.into_data().value);
|
||||||
|
}
|
||||||
|
}
|
|
@ -7,4 +7,5 @@ mod conv2d;
|
||||||
mod conv_transpose1d;
|
mod conv_transpose1d;
|
||||||
mod conv_transpose2d;
|
mod conv_transpose2d;
|
||||||
mod forward;
|
mod forward;
|
||||||
|
mod maxpool1d;
|
||||||
mod maxpool2d;
|
mod maxpool2d;
|
||||||
|
|
Loading…
Reference in New Issue