From 34d233cd3e1bb160c913b3b65a47dc896c3c8659 Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Sat, 21 Jan 2023 15:39:21 -0500 Subject: [PATCH] Feat/max pooling backend (#152) --- burn-autodiff/src/ops/module.rs | 127 ++++++++++ burn-autodiff/src/tests/maxpool2d.rs | 85 +++++++ burn-autodiff/src/tests/mod.rs | 2 + burn-ndarray/src/lib.rs | 2 - burn-ndarray/src/{conv/mod.rs => ops/conv.rs} | 53 ++-- burn-ndarray/src/ops/maxpool.rs | 144 +++++++++++ burn-ndarray/src/ops/mod.rs | 4 + burn-ndarray/src/ops/module.rs | 51 +++- burn-ndarray/src/ops/padding.rs | 24 ++ burn-ndarray/src/ops/tensor.rs | 5 +- burn-tch/src/ops/module.rs | 90 ++++++- burn-tensor/src/tensor/module.rs | 28 +++ burn-tensor/src/tensor/ops/modules/base.rs | 44 ++++ burn-tensor/src/tests/mod.rs | 1 + burn-tensor/src/tests/module/maxpool2d.rs | 230 ++++++++++++++++++ burn-tensor/src/tests/module/mod.rs | 1 + 16 files changed, 849 insertions(+), 42 deletions(-) create mode 100644 burn-autodiff/src/tests/maxpool2d.rs rename burn-ndarray/src/{conv/mod.rs => ops/conv.rs} (74%) create mode 100644 burn-ndarray/src/ops/maxpool.rs create mode 100644 burn-ndarray/src/ops/padding.rs create mode 100644 burn-tensor/src/tests/module/maxpool2d.rs diff --git a/burn-autodiff/src/ops/module.rs b/burn-autodiff/src/ops/module.rs index b7a8b103f..e6969e8f0 100644 --- a/burn-autodiff/src/ops/module.rs +++ b/burn-autodiff/src/ops/module.rs @@ -104,6 +104,7 @@ impl ModuleOps> for ADBackendDecorator { if let Some(bias) = bias { order = usize::max(order, bias.node.order); } + order += 1; let ops = ForwardConv::::new( x.node.clone(), @@ -118,6 +119,78 @@ impl ModuleOps> for ADBackendDecorator { ADTensor { node, shape } } + + fn max_pool2d( + x: & as Backend>::TensorPrimitive<4>, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + ) -> as Backend>::TensorPrimitive<4> { + let output = B::max_pool2d_with_indexes(x.tensor_ref(), kernel_size, stride, padding); + let shape = *B::shape(&output.output); + let order = x.node.order + 1; + + let ops = ForwardMaxPool::::new( + x.node.clone(), + Arc::new(output.indexes), + kernel_size, + stride, + padding, + ); + let ops = Box::new(ops); + let state = ForwardNodeState::new(output.output); + let node = ForwardNode::new(order, state, ops); + let node = std::sync::Arc::new(node); + + ADTensor { node, shape } + } + + fn max_pool2d_with_indexes( + x: & as Backend>::TensorPrimitive<4>, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + ) -> MaxPool2dWithIndexes> { + let output = B::max_pool2d_with_indexes(x.tensor_ref(), kernel_size, stride, padding); + let shape = *B::shape(&output.output); + let order = x.node.order + 1; + + let ops = ForwardMaxPool::::new( + x.node.clone(), + Arc::new(output.indexes.clone()), + kernel_size, + stride, + padding, + ); + let ops = Box::new(ops); + let state = ForwardNodeState::new(output.output); + let node = ForwardNode::new(order, state, ops); + let node = std::sync::Arc::new(node); + + MaxPool2dWithIndexes::new(ADTensor { node, shape }, output.indexes) + } + + fn max_pool2d_with_indexes_backward( + x: & as Backend>::TensorPrimitive<4>, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + output_grad: & as Backend>::TensorPrimitive<4>, + indexes: &< as Backend>::IntegerBackend as Backend>::TensorPrimitive< + 4, + >, + ) -> MaxPool2dBackward> { + let tensor = B::max_pool2d_with_indexes_backward( + x.tensor_ref(), + kernel_size, + stride, + padding, + output_grad.tensor_ref(), + indexes, + ); + + 
MaxPool2dBackward::new(ADTensor::from_tensor(tensor.x_grad)) + } } #[derive(new, Debug)] @@ -231,3 +304,57 @@ impl BackwardRecordedOps> for BackwardConv { + x: ForwardNodeRef>, + indexes: Arc<::TensorPrimitive>, + kernel_size: [usize; D], + stride: [usize; D], + padding: [usize; D], +} + +#[derive(new, Debug)] +pub struct BackwardMaxPool { + x: BackwardNodeRef>, + indexes: Arc<::TensorPrimitive>, + kernel_size: [usize; D], + stride: [usize; D], + padding: [usize; D], +} + +impl ForwardRecordedOps> for ForwardMaxPool { + fn to_backward( + &self, + graph: &mut Forward2BackwardGraphConverter, + ) -> BackwardRecordedOpsBoxed> { + let ops = BackwardMaxPool::::new( + Arc::new(BackwardNode::from_node(&self.x, graph)), + self.indexes.clone(), + self.kernel_size, + self.stride, + self.padding, + ); + + Box::new(ops) + } +} + +impl BackwardRecordedOps> for BackwardMaxPool { + fn backward_step(&self, state: &BackwardNodeState>) { + let grads = B::max_pool2d_with_indexes_backward( + &self.x.state.value, + self.kernel_size, + self.stride, + self.padding, + &state.grad.borrow(), + &self.indexes, + ); + + self.x.state.update_grad(grads.x_grad); + } + + fn backward_parents(&self) -> Vec { + vec![self.x.clone()] + } +} diff --git a/burn-autodiff/src/tests/maxpool2d.rs b/burn-autodiff/src/tests/maxpool2d.rs new file mode 100644 index 000000000..380dda486 --- /dev/null +++ b/burn-autodiff/src/tests/maxpool2d.rs @@ -0,0 +1,85 @@ +#[burn_tensor_testgen::testgen(ad_max_pool2d)] +mod tests { + use super::*; + use burn_tensor::{module::max_pool2d, Data}; + + #[test] + fn test_max_pool2d_simple() { + let batch_size = 1; + let channels_in = 1; + let kernel_size_1 = 2; + let kernel_size_2 = 2; + let padding_1 = 1; + let padding_2 = 1; + let stride_1 = 1; + let stride_2 = 1; + + let x = TestADTensor::from_floats([[[ + [0.2479, 0.6386, 0.3166, 0.5742], + [0.7065, 0.1940, 0.6305, 0.8959], + [0.5416, 0.8602, 0.8129, 0.1662], + [0.3358, 0.3059, 0.8293, 0.0990], + ]]]); + let x_grad_expected = TestADTensor::from_floats([[[ + [1., 3., 0., 2.], + [3., 0., 0., 4.], + [1., 4., 0., 1.], + [2., 0., 3., 1.], + ]]]); + + let output = max_pool2d( + &x, + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + ); + let grads = output.backward(); + + // Asserts + let x_grad_actual = x.grad(&grads).unwrap(); + x_grad_expected + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + } + + #[test] + fn test_max_pool2d_complex() { + let batch_size = 1; + let channels_in = 1; + let kernel_size_1 = 4; + let kernel_size_2 = 2; + let padding_1 = 2; + let padding_2 = 1; + let stride_1 = 1; + let stride_2 = 2; + + let x = TestADTensor::from_floats([[[ + [0.5388, 0.0676, 0.7122, 0.8316, 0.0653], + [0.9154, 0.1536, 0.9089, 0.8016, 0.7518], + [0.2073, 0.0501, 0.8811, 0.5604, 0.5075], + [0.4384, 0.9963, 0.9698, 0.4988, 0.2609], + [0.3391, 0.2230, 0.4610, 0.5365, 0.6880], + ]]]); + let x_grad_expected = TestADTensor::from_floats([[[ + [0., 0., 0., 3., 0.], + [4., 0., 2., 1., 0.], + [0., 0., 0., 0., 0.], + [2., 4., 0., 0., 0.], + [0., 0., 0., 0., 2.], + ]]]); + + let output = max_pool2d( + &x, + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + ); + let grads = output.backward(); + + // Asserts + let x_grad_actual = x.grad(&grads).unwrap(); + x_grad_expected + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + } +} diff --git a/burn-autodiff/src/tests/mod.rs b/burn-autodiff/src/tests/mod.rs index 26e554916..60a0a0f54 100644 --- a/burn-autodiff/src/tests/mod.rs +++ 
b/burn-autodiff/src/tests/mod.rs @@ -13,6 +13,7 @@ mod index; mod log; mod mask; mod matmul; +mod maxpool2d; mod mul; mod multithread; mod neg; @@ -34,6 +35,7 @@ macro_rules! testgen_all { // Modules burn_autodiff::testgen_ad_conv1d!(); burn_autodiff::testgen_ad_conv2d!(); + burn_autodiff::testgen_ad_max_pool2d!(); burn_autodiff::testgen_module_backward!(); // Tensor diff --git a/burn-ndarray/src/lib.rs b/burn-ndarray/src/lib.rs index 14532f151..7f81d75a5 100644 --- a/burn-ndarray/src/lib.rs +++ b/burn-ndarray/src/lib.rs @@ -16,8 +16,6 @@ mod tensor; pub use backend::*; pub(crate) use tensor::*; -pub(crate) mod conv; - #[cfg(test)] mod tests { type TestBackend = crate::NdArrayBackend; diff --git a/burn-ndarray/src/conv/mod.rs b/burn-ndarray/src/ops/conv.rs similarity index 74% rename from burn-ndarray/src/conv/mod.rs rename to burn-ndarray/src/ops/conv.rs index 8936bc2fe..df676ea04 100644 --- a/burn-ndarray/src/conv/mod.rs +++ b/burn-ndarray/src/ops/conv.rs @@ -1,9 +1,32 @@ +use super::padding::apply_padding2d; use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend, NdArrayDevice}; use burn_tensor::{ops::TensorOps, Shape}; /// This method is not the most efficient, but it serves as a basic implementation that is easy to understand. /// A more optimized version should be used in its place. pub(crate) fn conv2d_naive( + x: &NdArrayTensor, + weight: &NdArrayTensor, + bias: Option<&NdArrayTensor>, + stride: [usize; 2], + padding: [usize; 2], +) -> NdArrayTensor { + let [batch_size, channels_in, heigth, width] = x.shape.dims; + let mut results = Vec::with_capacity(batch_size); + + for b in 0..batch_size { + let x = NdArrayBackend::index(x, [b..b + 1, 0..channels_in, 0..heigth, 0..width]); + let x = NdArrayBackend::reshape(&x, Shape::new([channels_in, heigth, width])); + + results.push(conv2d_naive_no_batch_size( + &x, weight, bias, stride, padding, + )); + } + + NdArrayBackend::cat(&results, 0) +} + +pub(crate) fn conv2d_naive_no_batch_size( x: &NdArrayTensor, weight: &NdArrayTensor, bias: Option<&NdArrayTensor>, @@ -11,6 +34,7 @@ pub(crate) fn conv2d_naive( padding: [usize; 2], ) -> NdArrayTensor { let [channels_out, channels_in, k1, k2] = weight.shape.dims; + let [_, heigth, width] = x.shape.dims; let mut results = Vec::new(); for co in 0..channels_out { @@ -20,7 +44,9 @@ pub(crate) fn conv2d_naive( let kernel = NdArrayBackend::index(weight, [co..co + 1, ci..ci + 1, 0..k1, 0..k2]); let kernel = NdArrayBackend::reshape(&kernel, Shape::new([k1, k2])); - let x = apply_padding(x, ci, padding); + let x = NdArrayBackend::index(x, [ci..ci + 1, 0..heigth, 0..width]); + let x = NdArrayBackend::reshape(&x, Shape::new([heigth, width])); + let x = apply_padding2d(&x, padding); let matrix = conv2d_with_kernel(x, kernel, stride); let [heigth, width] = matrix.shape.dims; @@ -45,31 +71,6 @@ pub(crate) fn conv2d_naive( result } -fn apply_padding( - x: &NdArrayTensor, - channel: usize, - padding: [usize; 2], -) -> NdArrayTensor { - let [_, heigth, width] = x.shape.dims; - let heigth_new = heigth + (2 * padding[0]); - let width_new = width + (2 * padding[1]); - - let x = NdArrayBackend::index(x, [channel..channel + 1, 0..heigth, 0..width]); - let x = NdArrayBackend::reshape(&x, Shape::new([heigth, width])); - - let mut x_new = NdArrayBackend::zeros(Shape::new([heigth_new, width_new]), NdArrayDevice::Cpu); - x_new = NdArrayBackend::index_assign( - &x_new, - [ - padding[0]..heigth + padding[0], - padding[1]..width + padding[1], - ], - &x, - ); - - x_new -} - fn conv2d_with_kernel( x: 
NdArrayTensor, kernel: NdArrayTensor, diff --git a/burn-ndarray/src/ops/maxpool.rs b/burn-ndarray/src/ops/maxpool.rs new file mode 100644 index 000000000..71a5fa05f --- /dev/null +++ b/burn-ndarray/src/ops/maxpool.rs @@ -0,0 +1,144 @@ +use super::padding::apply_padding2d; +use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend, NdArrayDevice}; +use burn_tensor::{ops::TensorOps, Data, Shape}; + +/// This method is not the most efficient, but it serves as a basic implementation that is easy to understand. +/// A more optimized version should be used in its place. +pub(crate) fn max_pool2d_with_indexes_naive( + x: &NdArrayTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], +) -> (NdArrayTensor, NdArrayTensor) { + let mut batches = Vec::new(); + let mut batches_indexes = Vec::new(); + + let [batch_size, channels, heigth, width] = x.shape.dims; + + for b in 0..batch_size { + let mut batch = Vec::new(); + let mut batch_indexes = Vec::new(); + + for c in 0..channels { + let x = NdArrayBackend::index(x, [b..b + 1, c..c + 1, 0..heigth, 0..width]); + let x = NdArrayBackend::reshape(&x, Shape::new([heigth, width])); + let x = apply_padding2d(&x, padding); + + let (matrix, indexes) = max_pool2d_with_kernel(x, kernel_size, stride, padding); + let [heigth, width] = matrix.shape.dims; + + let matrix = NdArrayBackend::reshape(&matrix, Shape::new([1, 1, heigth, width])); + let indexes = NdArrayBackend::reshape(&indexes, Shape::new([1, 1, heigth, width])); + + batch.push(matrix); + batch_indexes.push(indexes); + } + let batch = NdArrayBackend::cat(&batch, 1); + let batch_indexes = NdArrayBackend::cat(&batch_indexes, 1); + + batches.push(batch); + batches_indexes.push(batch_indexes); + } + + ( + NdArrayBackend::cat(&batches, 0), + NdArrayBackend::cat(&batches_indexes, 0), + ) +} + +pub(crate) fn max_pool2d_backward_naive( + x: &NdArrayTensor, + _kernel_size: [usize; 2], + _stride: [usize; 2], + _padding: [usize; 2], + output_grad: &NdArrayTensor, + indexes: &NdArrayTensor, +) -> NdArrayTensor { + let [_batch_size, _channels, heigth, width] = output_grad.shape.dims; + let [batch_size, channels, heigth_x, width_x] = x.shape.dims; + + let output_grad_flatten = NdArrayBackend::reshape( + output_grad, + Shape::new([batch_size, channels, heigth * width]), + ); + let indexes_flatten = + NdArrayBackend::reshape(indexes, Shape::new([batch_size, channels, heigth * width])); + let mut output_flatten = NdArrayBackend::zeros( + Shape::new([batch_size, channels, heigth_x * width_x]), + NdArrayDevice::Cpu, + ); + + for b in 0..batch_size { + for c in 0..channels { + for i in 0..(heigth * width) { + let index = NdArrayBackend::index(&indexes_flatten, [b..b + 1, c..c + 1, i..i + 1]); + let index = NdArrayBackend::into_data(index).value[0] as usize; + + let current_value = + NdArrayBackend::index(&output_flatten, [b..b + 1, c..c + 1, index..index + 1]); + let output_grad = + NdArrayBackend::index(&output_grad_flatten, [b..b + 1, c..c + 1, i..i + 1]); + let updated_value = NdArrayBackend::add(¤t_value, &output_grad); + + output_flatten = NdArrayBackend::index_assign( + &output_flatten, + [b..b + 1, c..c + 1, index..index + 1], + &updated_value, + ); + } + } + } + + NdArrayBackend::reshape(&output_flatten, x.shape) +} + +fn max_pool2d_with_kernel( + x: NdArrayTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], +) -> (NdArrayTensor, NdArrayTensor) { + let [k1, k2] = kernel_size; + let [p1, p2] = padding; + let [heigth, width] = x.shape.dims; + + let 
heigth_new = f32::ceil((heigth - k1 + 1) as f32 / stride[0] as f32) as usize; + let width_new = f32::ceil((width - k2 + 1) as f32 / stride[1] as f32) as usize; + let mut output = NdArrayBackend::empty(Shape::new([heigth_new, width_new]), NdArrayDevice::Cpu); + let mut indexes = + NdArrayBackend::empty(Shape::new([heigth_new, width_new]), NdArrayDevice::Cpu); + + for i in 0..heigth_new { + for j in 0..width_new { + let i_x = i * stride[0]; + let j_x = j * stride[1]; + + let x_ij = NdArrayBackend::index(&x, [i_x..i_x + k1, j_x..j_x + k2]); + let x_flatten = NdArrayBackend::reshape(&x_ij, Shape::new([k1 * k2])); + let index = NdArrayBackend::argmax(&x_flatten, 0); + let index = NdArrayBackend::into_data(index).value[0]; + let value = NdArrayBackend::into_data(x_flatten).value[index as usize]; + let value = NdArrayBackend::from_data( + Data::new(vec![value], Shape::new([1, 1])), + NdArrayDevice::Cpu, + ); + + let index_i = index / k2 as i64; + let index_j = index - (index_i * k2 as i64); + let ii = i64::max(0, i_x as i64 - p1 as i64 + index_i); + let jj = i64::max(0, j_x as i64 - p2 as i64 + index_j); + let h = (heigth - (2 * p1)) as i64; + let index = ii * h + jj; + + let index = NdArrayBackend::from_data( + Data::new(vec![index], Shape::new([1, 1])), + NdArrayDevice::Cpu, + ); + + indexes = NdArrayBackend::index_assign(&indexes, [i..i + 1, j..j + 1], &index); + output = NdArrayBackend::index_assign(&output, [i..i + 1, j..j + 1], &value); + } + } + + (output, indexes) +} diff --git a/burn-ndarray/src/ops/mod.rs b/burn-ndarray/src/ops/mod.rs index 66111d0fd..4833c014a 100644 --- a/burn-ndarray/src/ops/mod.rs +++ b/burn-ndarray/src/ops/mod.rs @@ -1,3 +1,7 @@ mod creation; mod module; mod tensor; + +pub(crate) mod conv; +pub(crate) mod maxpool; +pub(crate) mod padding; diff --git a/burn-ndarray/src/ops/module.rs b/burn-ndarray/src/ops/module.rs index 1a6cd3c05..faff6c2f1 100644 --- a/burn-ndarray/src/ops/module.rs +++ b/burn-ndarray/src/ops/module.rs @@ -1,7 +1,12 @@ -use crate::{conv::conv2d_naive, element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend}; +use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend}; use burn_tensor::{ops::*, Shape}; use std::ops::Add; +use super::{ + conv::conv2d_naive, + maxpool::{max_pool2d_backward_naive, max_pool2d_with_indexes_naive}, +}; + impl ModuleOps> for NdArrayBackend { fn embedding( weights: &NdArrayTensor, @@ -68,16 +73,44 @@ impl ModuleOps> for NdArrayBackend { stride: [usize; 2], padding: [usize; 2], ) -> NdArrayTensor { - let [batch_size, channels_in, heigth, width] = x.shape.dims; - let mut results = Vec::with_capacity(batch_size); + conv2d_naive(x, weight, bias, stride, padding) + } - for b in 0..batch_size { - let x = NdArrayBackend::index(x, [b..b + 1, 0..channels_in, 0..heigth, 0..width]); - let x = NdArrayBackend::reshape(&x, Shape::new([channels_in, heigth, width])); + fn max_pool2d( + x: &NdArrayTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + ) -> NdArrayTensor { + max_pool2d_with_indexes_naive(x, kernel_size, stride, padding).0 + } - results.push(conv2d_naive(&x, weight, bias, stride, padding)); - } + fn max_pool2d_with_indexes( + x: &NdArrayTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + ) -> MaxPool2dWithIndexes> { + let (output, indexes) = max_pool2d_with_indexes_naive(x, kernel_size, stride, padding); - NdArrayBackend::cat(&results, 0) + MaxPool2dWithIndexes::new(output, indexes) + } + + fn max_pool2d_with_indexes_backward( + x: 
&NdArrayTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + output_grad: &NdArrayTensor, + indexes: &NdArrayTensor, + ) -> MaxPool2dBackward> { + MaxPool2dBackward::new(max_pool2d_backward_naive( + x, + kernel_size, + stride, + padding, + output_grad, + indexes, + )) } } diff --git a/burn-ndarray/src/ops/padding.rs b/burn-ndarray/src/ops/padding.rs new file mode 100644 index 000000000..45a779e2c --- /dev/null +++ b/burn-ndarray/src/ops/padding.rs @@ -0,0 +1,24 @@ +use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend, NdArrayDevice}; +use burn_tensor::{ops::TensorOps, Shape}; + +pub(crate) fn apply_padding2d( + x: &NdArrayTensor, + padding: [usize; 2], +) -> NdArrayTensor { + let [heigth, width] = x.shape.dims; + + let heigth_new = heigth + (2 * padding[0]); + let width_new = width + (2 * padding[1]); + + let mut x_new = NdArrayBackend::zeros(Shape::new([heigth_new, width_new]), NdArrayDevice::Cpu); + x_new = NdArrayBackend::index_assign( + &x_new, + [ + padding[0]..heigth + padding[0], + padding[1]..width + padding[1], + ], + x, + ); + + x_new +} diff --git a/burn-ndarray/src/ops/tensor.rs b/burn-ndarray/src/ops/tensor.rs index 5c60af81b..e2010fd86 100644 --- a/burn-ndarray/src/ops/tensor.rs +++ b/burn-ndarray/src/ops/tensor.rs @@ -1,6 +1,3 @@ -use std::cmp::Ordering; -use std::ops::Range; - use crate::tensor::BatchMatrix; use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend}; use crate::{to_nd_array_tensor, NdArrayDevice, SEED}; @@ -9,6 +6,8 @@ use burn_tensor::{backend::Backend, ops::TensorOps, Data, ElementConversion, Sha use ndarray::{ArcArray, Axis, Dim, Dimension, IxDyn, SliceInfoElem}; use rand::rngs::StdRng; use rand::SeedableRng; +use std::cmp::Ordering; +use std::ops::Range; macro_rules! 
keepdim { ( diff --git a/burn-tch/src/ops/module.rs b/burn-tch/src/ops/module.rs index 3875f38ab..a5190ebad 100644 --- a/burn-tch/src/ops/module.rs +++ b/burn-tch/src/ops/module.rs @@ -1,5 +1,8 @@ -use crate::{element::TchElement, TchBackend, TchTensor}; -use burn_tensor::{ops::ModuleOps, Shape}; +use crate::{element::TchElement, TchBackend, TchKind, TchTensor}; +use burn_tensor::{ + ops::{MaxPool2dBackward, MaxPool2dWithIndexes, ModuleOps}, + Shape, +}; impl ModuleOps> for TchBackend { fn embedding(weights: &TchTensor, indexes: &TchTensor) -> TchTensor { @@ -85,4 +88,87 @@ impl ModuleOps> for TchBackend { shape, } } + + fn max_pool2d( + x: &TchTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + ) -> TchTensor { + let tensor = tch::Tensor::max_pool2d( + &x.tensor, + &[kernel_size[0] as i64, kernel_size[1] as i64], + &[stride[0] as i64, stride[1] as i64], + &[padding[0] as i64, padding[1] as i64], + &[1, 1], + false, + ); + let shape = Shape::from(tensor.size()); + + TchTensor { + kind: x.kind, + tensor, + shape, + } + } + + fn max_pool2d_with_indexes( + x: &TchTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + ) -> MaxPool2dWithIndexes> { + let (tensor, indexes) = tch::Tensor::max_pool2d_with_indices( + &x.tensor, + &[kernel_size[0] as i64, kernel_size[1] as i64], + &[stride[0] as i64, stride[1] as i64], + &[padding[0] as i64, padding[1] as i64], + &[1, 1], + false, + ); + let shape = Shape::from(tensor.size()); + + let output = TchTensor { + kind: x.kind, + tensor, + shape, + }; + let shape = Shape::from(indexes.size()); + let indexes = TchTensor { + kind: TchKind::::new(), + tensor: indexes, + shape, + }; + + MaxPool2dWithIndexes::new(output, indexes) + } + + fn max_pool2d_with_indexes_backward( + x: &TchTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + output_grad: &TchTensor, + indexes: &TchTensor, + ) -> MaxPool2dBackward> { + let grad = tch::Tensor::max_pool2d_with_indices_backward( + &x.tensor, + &output_grad.tensor, + &[kernel_size[0] as i64, kernel_size[1] as i64], + &[stride[0] as i64, stride[1] as i64], + &[padding[0] as i64, padding[1] as i64], + &[1, 1], + false, + &indexes.tensor, + ); + + let shape = Shape::from(grad.size()); + let tensor = TchTensor { + kind: x.kind, + tensor: grad, + shape, + }; + + MaxPool2dBackward::new(tensor) + } } diff --git a/burn-tensor/src/tensor/module.rs b/burn-tensor/src/tensor/module.rs index 23aa47832..a9ab4ad19 100644 --- a/burn-tensor/src/tensor/module.rs +++ b/burn-tensor/src/tensor/module.rs @@ -47,3 +47,31 @@ where padding, )) } + +/// Applies a [2D max pooling](crate::ops::ModuleOps::max_pool2d). +pub fn max_pool2d( + x: &Tensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], +) -> Tensor +where + B: Backend, +{ + Tensor::new(B::max_pool2d(&x.value, kernel_size, stride, padding)) +} + +/// Applies a [2D max pooling with indexes](crate::ops::ModuleOps::max_pool2d_with_indexes). 
+pub fn max_pool2d_with_indexes( + x: &Tensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], +) -> (Tensor, Tensor) +where + B: Backend, +{ + let output = B::max_pool2d_with_indexes(&x.value, kernel_size, stride, padding); + + (Tensor::new(output.output), Tensor::new(output.indexes)) +} diff --git a/burn-tensor/src/tensor/ops/modules/base.rs b/burn-tensor/src/tensor/ops/modules/base.rs index 718817054..206db4994 100644 --- a/burn-tensor/src/tensor/ops/modules/base.rs +++ b/burn-tensor/src/tensor/ops/modules/base.rs @@ -9,6 +9,19 @@ pub struct Conv2dBackward { pub bias_grad: Option>, } +/// Gradient computed during the backward pass for each tensor used by [max_pool2d](ModuleOps::max_pool2d). +#[derive(new)] +pub struct MaxPool2dBackward { + pub x_grad: B::TensorPrimitive<4>, +} + +/// Results from [max_pool2d](ModuleOps::max_pool2d_with_indexes). +#[derive(new)] +pub struct MaxPool2dWithIndexes { + pub output: B::TensorPrimitive<4>, + pub indexes: ::TensorPrimitive<4>, +} + /// Gradient computed during the backward pass for each tensor used by [conv1d](ModuleOps::conv1d). #[derive(new)] pub struct Conv1dBackward { @@ -77,4 +90,35 @@ pub trait ModuleOps { ) -> Conv1dBackward { conv::conv1d_backward(x, weight, bias, stride, output_grad) } + /// Two dimensional max pooling. + /// + /// # Shapes + /// + /// x: [batch_size, channels, height, width], + fn max_pool2d( + x: &B::TensorPrimitive<4>, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + ) -> B::TensorPrimitive<4>; + /// Two dimensional max pooling with indexes. + /// + /// # Shapes + /// + /// x: [batch_size, channels, height, width], + fn max_pool2d_with_indexes( + x: &B::TensorPrimitive<4>, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + ) -> MaxPool2dWithIndexes; + /// Backward pass for the [max pooling 2d](ModuleOps::max_pool2d_with_indexes) operation. + fn max_pool2d_with_indexes_backward( + x: &B::TensorPrimitive<4>, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + output_grad: &B::TensorPrimitive<4>, + indexes: &::TensorPrimitive<4>, + ) -> MaxPool2dBackward; } diff --git a/burn-tensor/src/tests/mod.rs b/burn-tensor/src/tests/mod.rs index bc5e2ab3c..12804a20e 100644 --- a/burn-tensor/src/tests/mod.rs +++ b/burn-tensor/src/tests/mod.rs @@ -15,6 +15,7 @@ macro_rules! 
testgen_all { burn_tensor::testgen_module_forward!(); burn_tensor::testgen_module_conv1d!(); burn_tensor::testgen_module_conv2d!(); + burn_tensor::testgen_module_max_pool2d!(); // test ops burn_tensor::testgen_add!(); diff --git a/burn-tensor/src/tests/module/maxpool2d.rs b/burn-tensor/src/tests/module/maxpool2d.rs new file mode 100644 index 000000000..4e4ea0135 --- /dev/null +++ b/burn-tensor/src/tests/module/maxpool2d.rs @@ -0,0 +1,230 @@ +#[burn_tensor_testgen::testgen(module_max_pool2d)] +mod tests { + use super::*; + use burn_tensor::module::{max_pool2d, max_pool2d_with_indexes}; + use burn_tensor::{Data, Tensor}; + + #[test] + fn test_max_pool2d_simple() { + let batch_size = 2; + let channels_in = 2; + let kernel_size_1 = 3; + let kernel_size_2 = 3; + let padding_1 = 1; + let padding_2 = 1; + let stride_1 = 1; + let stride_2 = 1; + + let x = TestTensor::from_floats([ + [ + [ + [0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221], + [0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689], + [0.5986, 0.2059, 0.4897, 0.6136, 0.2965, 0.6182], + [0.1485, 0.9540, 0.4023, 0.6176, 0.7111, 0.3392], + [0.3703, 0.0472, 0.2771, 0.1868, 0.8855, 0.5605], + [0.5063, 0.1638, 0.9432, 0.7836, 0.8696, 0.1068], + ], + [ + [0.8872, 0.0137, 0.1652, 0.5505, 0.6127, 0.6473], + [0.1128, 0.0888, 0.1152, 0.5456, 0.6199, 0.7947], + [0.5911, 0.7781, 0.7256, 0.6578, 0.0989, 0.9149], + [0.5879, 0.5189, 0.6561, 0.0578, 0.7025, 0.6426], + [0.9590, 0.0325, 0.6455, 0.6248, 0.2009, 0.1544], + [0.7339, 0.1369, 0.6598, 0.5528, 0.6775, 0.1572], + ], + ], + [ + [ + [0.6853, 0.6439, 0.4639, 0.5573, 0.2723, 0.5910], + [0.5419, 0.7729, 0.6743, 0.8956, 0.2997, 0.9546], + [0.0334, 0.2178, 0.6917, 0.4958, 0.3357, 0.6584], + [0.7358, 0.9074, 0.2462, 0.5159, 0.6420, 0.2441], + [0.7602, 0.6297, 0.6073, 0.5937, 0.8037, 0.4881], + [0.8859, 0.0974, 0.3954, 0.6763, 0.1078, 0.7467], + ], + [ + [0.2991, 0.5012, 0.8024, 0.7653, 0.9378, 0.7952], + [0.7393, 0.2336, 0.9521, 0.2719, 0.8445, 0.0454], + [0.6479, 0.9822, 0.7905, 0.0318, 0.2474, 0.0628], + [0.9955, 0.7591, 0.4140, 0.3215, 0.4349, 0.1527], + [0.8064, 0.0164, 0.4002, 0.2024, 0.6128, 0.5827], + [0.5368, 0.7895, 0.8727, 0.7793, 0.0910, 0.3421], + ], + ], + ]); + let y = TestTensor::from_floats([ + [ + [ + [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221], + [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221], + [0.9540, 0.9540, 0.9540, 0.9490, 0.7890, 0.7111], + [0.9540, 0.9540, 0.9540, 0.8855, 0.8855, 0.8855], + [0.9540, 0.9540, 0.9540, 0.9432, 0.8855, 0.8855], + [0.5063, 0.9432, 0.9432, 0.9432, 0.8855, 0.8855], + ], + [ + [0.8872, 0.8872, 0.5505, 0.6199, 0.7947, 0.7947], + [0.8872, 0.8872, 0.7781, 0.7256, 0.9149, 0.9149], + [0.7781, 0.7781, 0.7781, 0.7256, 0.9149, 0.9149], + [0.9590, 0.9590, 0.7781, 0.7256, 0.9149, 0.9149], + [0.9590, 0.9590, 0.6598, 0.7025, 0.7025, 0.7025], + [0.9590, 0.9590, 0.6598, 0.6775, 0.6775, 0.6775], + ], + ], + [ + [ + [0.7729, 0.7729, 0.8956, 0.8956, 0.9546, 0.9546], + [0.7729, 0.7729, 0.8956, 0.8956, 0.9546, 0.9546], + [0.9074, 0.9074, 0.9074, 0.8956, 0.9546, 0.9546], + [0.9074, 0.9074, 0.9074, 0.8037, 0.8037, 0.8037], + [0.9074, 0.9074, 0.9074, 0.8037, 0.8037, 0.8037], + [0.8859, 0.8859, 0.6763, 0.8037, 0.8037, 0.8037], + ], + [ + [0.7393, 0.9521, 0.9521, 0.9521, 0.9378, 0.9378], + [0.9822, 0.9822, 0.9822, 0.9521, 0.9378, 0.9378], + [0.9955, 0.9955, 0.9822, 0.9521, 0.8445, 0.8445], + [0.9955, 0.9955, 0.9822, 0.7905, 0.6128, 0.6128], + [0.9955, 0.9955, 0.8727, 0.8727, 0.7793, 0.6128], + [0.8064, 0.8727, 0.8727, 0.8727, 0.7793, 0.6128], + ], + ], + ]); + + let output = 
max_pool2d( + &x, + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + ); + + y.to_data().assert_approx_eq(&output.into_data(), 3); + } + + #[test] + fn test_max_pool2d_different_padding_stride_kernel() { + let batch_size = 1; + let channels_in = 1; + let kernel_size_1 = 3; + let kernel_size_2 = 1; + let padding_1 = 1; + let padding_2 = 0; + let stride_1 = 1; + let stride_2 = 2; + + let x = TestTensor::from_floats([[[ + [0.6309, 0.6112, 0.6998], + [0.4708, 0.9161, 0.5402], + [0.4577, 0.7397, 0.9870], + [0.6380, 0.4352, 0.5884], + [0.6277, 0.5139, 0.4525], + [0.9333, 0.9846, 0.5006], + ]]]); + let y = TestTensor::from_floats([[[ + [0.6309, 0.6998], + [0.6309, 0.9870], + [0.6380, 0.9870], + [0.6380, 0.9870], + [0.9333, 0.5884], + [0.9333, 0.5006], + ]]]); + + let output = max_pool2d( + &x, + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + ); + + y.to_data().assert_approx_eq(&output.into_data(), 3); + } + + #[test] + fn test_max_pool2d_with_indexes() { + let batch_size = 1; + let channels_in = 1; + let kernel_size_1 = 2; + let kernel_size_2 = 2; + let padding_1 = 1; + let padding_2 = 1; + let stride_1 = 1; + let stride_2 = 1; + + let x = TestTensor::from_floats([[[ + [0.2479, 0.6386, 0.3166, 0.5742], + [0.7065, 0.1940, 0.6305, 0.8959], + [0.5416, 0.8602, 0.8129, 0.1662], + [0.3358, 0.3059, 0.8293, 0.0990], + ]]]); + let indexes = Data::::from([[[ + [0, 1, 1, 3, 3], + [4, 4, 1, 7, 7], + [4, 9, 9, 7, 7], + [8, 9, 9, 14, 11], + [12, 12, 14, 14, 15], + ]]]); + let y = TestTensor::from_floats([[[ + [0.2479, 0.6386, 0.6386, 0.5742, 0.5742], + [0.7065, 0.7065, 0.6386, 0.8959, 0.8959], + [0.7065, 0.8602, 0.8602, 0.8959, 0.8959], + [0.5416, 0.8602, 0.8602, 0.8293, 0.1662], + [0.3358, 0.3358, 0.8293, 0.8293, 0.0990], + ]]]); + + let (output, output_indexes) = max_pool2d_with_indexes( + &x, + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + ); + + y.to_data().assert_approx_eq(&output.into_data(), 3); + assert_eq!(indexes.value, output_indexes.into_data().value); + } + + #[test] + fn test_max_pool2d_complex() { + let batch_size = 1; + let channels_in = 1; + let kernel_size_1 = 4; + let kernel_size_2 = 2; + let padding_1 = 2; + let padding_2 = 1; + let stride_1 = 1; + let stride_2 = 2; + + let x = TestTensor::from_floats([[[ + [0.5388, 0.0676, 0.7122, 0.8316, 0.0653], + [0.9154, 0.1536, 0.9089, 0.8016, 0.7518], + [0.2073, 0.0501, 0.8811, 0.5604, 0.5075], + [0.4384, 0.9963, 0.9698, 0.4988, 0.2609], + [0.3391, 0.2230, 0.4610, 0.5365, 0.6880], + ]]]); + let indexes = Data::::from([[[ + [5, 7, 3], + [5, 7, 3], + [5, 16, 3], + [5, 16, 8], + [15, 16, 24], + [15, 16, 24], + ]]]); + let y = TestTensor::from_floats([[[ + [0.9154, 0.9089, 0.8316], + [0.9154, 0.9089, 0.8316], + [0.9154, 0.9963, 0.8316], + [0.9154, 0.9963, 0.8016], + [0.4384, 0.9963, 0.688], + [0.4384, 0.9963, 0.688], + ]]]); + let (output, output_indexes) = max_pool2d_with_indexes( + &x, + [kernel_size_1, kernel_size_2], + [stride_1, stride_2], + [padding_1, padding_2], + ); + + y.to_data().assert_approx_eq(&output.into_data(), 3); + assert_eq!(indexes.value, output_indexes.into_data().value); + } +} diff --git a/burn-tensor/src/tests/module/mod.rs b/burn-tensor/src/tests/module/mod.rs index 0574a48f7..571d8b77c 100644 --- a/burn-tensor/src/tests/module/mod.rs +++ b/burn-tensor/src/tests/module/mod.rs @@ -1,3 +1,4 @@ mod conv1d; mod conv2d; mod forward; +mod maxpool2d;
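
Usage note (not part of the patch): a minimal sketch of the two module-level functions this diff exposes, `burn_tensor::module::max_pool2d` and `burn_tensor::module::max_pool2d_with_indexes`, mirroring the shapes used in the new tests. The `NdArrayBackend<f32>` backend alias, the `Tensor::from_floats` constructor, and the `to_data` calls are assumed from the surrounding crates rather than defined in this diff, so treat this as an illustration of the intended API, not as code taken from the PR.

use burn_ndarray::NdArrayBackend;
use burn_tensor::module::{max_pool2d, max_pool2d_with_indexes};
use burn_tensor::Tensor;

// Assumed backend alias; any backend implementing the new ModuleOps methods should work.
type B = NdArrayBackend<f32>;

fn main() {
    // Input shaped [batch_size, channels, height, width] = [1, 1, 4, 4],
    // the same values as the "simple" test above.
    let x = Tensor::<B, 4>::from_floats([[[
        [0.2479, 0.6386, 0.3166, 0.5742],
        [0.7065, 0.1940, 0.6305, 0.8959],
        [0.5416, 0.8602, 0.8129, 0.1662],
        [0.3358, 0.3059, 0.8293, 0.0990],
    ]]]);

    // 2x2 kernel, stride 1, padding 1 -> a 5x5 output, as in the tests.
    let output = max_pool2d(&x, [2, 2], [1, 1], [1, 1]);

    // Same pooling, but also returns the flattened argmax indexes that the
    // backward pass (max_pool2d_with_indexes_backward) consumes.
    let (pooled, indexes) = max_pool2d_with_indexes(&x, [2, 2], [1, 1], [1, 1]);

    println!("{:?}", output.to_data());
    println!("{:?}", pooled.to_data());
    println!("{:?}", indexes.to_data());
}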