Feat/max pooling backend (#152)

Nathaniel Simard 2023-01-21 15:39:21 -05:00 committed by GitHub
parent 1d0d92a269
commit 34d233cd3e
16 changed files with 849 additions and 42 deletions
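This change adds 2D max pooling to the tensor module API and wires it through the autodiff, ndarray, and tch backends. Below is a minimal usage sketch of the new functions; the backend alias and the from_floats constructor are taken from the tests in this diff, so treat it as illustrative rather than canonical.

use burn_tensor::module::{max_pool2d, max_pool2d_with_indexes};
use burn_tensor::Tensor;

// Assumed backend alias, mirroring the ndarray test setup in this diff.
type B = burn_ndarray::NdArrayBackend<f32>;

fn demo() {
    // Shape: [batch_size, channels, height, width]
    let x = Tensor::<B, 4>::from_floats([[[
        [0.2479, 0.6386, 0.3166, 0.5742],
        [0.7065, 0.1940, 0.6305, 0.8959],
        [0.5416, 0.8602, 0.8129, 0.1662],
        [0.3358, 0.3059, 0.8293, 0.0990],
    ]]]);

    // Arguments: kernel_size, stride, padding.
    let pooled = max_pool2d(&x, [2, 2], [1, 1], [1, 1]);

    // The with_indexes variant also returns, for each output element, the flat
    // position of the input element that won the window; the backward pass reuses it.
    let (pooled_again, indexes) = max_pool2d_with_indexes(&x, [2, 2], [1, 1], [1, 1]);
    let _ = (pooled, pooled_again, indexes);
}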

View File

@ -104,6 +104,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
if let Some(bias) = bias {
order = usize::max(order, bias.node.order);
}
order += 1;
let ops = ForwardConv::<B, 1, 3>::new(
x.node.clone(),
@ -118,6 +119,78 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
ADTensor { node, shape }
}
fn max_pool2d(
x: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<4> {
let output = B::max_pool2d_with_indexes(x.tensor_ref(), kernel_size, stride, padding);
let shape = *B::shape(&output.output);
let order = x.node.order + 1;
let ops = ForwardMaxPool::<B, 2, 4>::new(
x.node.clone(),
Arc::new(output.indexes),
kernel_size,
stride,
padding,
);
let ops = Box::new(ops);
let state = ForwardNodeState::new(output.output);
let node = ForwardNode::new(order, state, ops);
let node = std::sync::Arc::new(node);
ADTensor { node, shape }
}
fn max_pool2d_with_indexes(
x: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
) -> MaxPool2dWithIndexes<ADBackendDecorator<B>> {
let output = B::max_pool2d_with_indexes(x.tensor_ref(), kernel_size, stride, padding);
let shape = *B::shape(&output.output);
let order = x.node.order + 1;
let ops = ForwardMaxPool::<B, 2, 4>::new(
x.node.clone(),
Arc::new(output.indexes.clone()),
kernel_size,
stride,
padding,
);
let ops = Box::new(ops);
let state = ForwardNodeState::new(output.output);
let node = ForwardNode::new(order, state, ops);
let node = std::sync::Arc::new(node);
MaxPool2dWithIndexes::new(ADTensor { node, shape }, output.indexes)
}
fn max_pool2d_with_indexes_backward(
x: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
output_grad: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<4>,
indexes: &<<ADBackendDecorator<B> as Backend>::IntegerBackend as Backend>::TensorPrimitive<
4,
>,
) -> MaxPool2dBackward<ADBackendDecorator<B>> {
let tensor = B::max_pool2d_with_indexes_backward(
x.tensor_ref(),
kernel_size,
stride,
padding,
output_grad.tensor_ref(),
indexes,
);
MaxPool2dBackward::new(ADTensor::from_tensor(tensor.x_grad))
}
}
#[derive(new, Debug)]
@ -231,3 +304,57 @@ impl<B: Backend> BackwardRecordedOps<B::TensorPrimitive<3>> for BackwardConv<B,
}
}
}
#[derive(new, Debug)]
pub struct ForwardMaxPool<B: Backend, const D: usize, const S: usize> {
x: ForwardNodeRef<B::TensorPrimitive<S>>,
indexes: Arc<<B::IntegerBackend as Backend>::TensorPrimitive<S>>,
kernel_size: [usize; D],
stride: [usize; D],
padding: [usize; D],
}
#[derive(new, Debug)]
pub struct BackwardMaxPool<B: Backend, const D: usize, const S: usize> {
x: BackwardNodeRef<B::TensorPrimitive<S>>,
indexes: Arc<<B::IntegerBackend as Backend>::TensorPrimitive<S>>,
kernel_size: [usize; D],
stride: [usize; D],
padding: [usize; D],
}
impl<B: Backend> ForwardRecordedOps<B::TensorPrimitive<4>> for ForwardMaxPool<B, 2, 4> {
fn to_backward(
&self,
graph: &mut Forward2BackwardGraphConverter,
) -> BackwardRecordedOpsBoxed<B::TensorPrimitive<4>> {
let ops = BackwardMaxPool::<B, 2, 4>::new(
Arc::new(BackwardNode::from_node(&self.x, graph)),
self.indexes.clone(),
self.kernel_size,
self.stride,
self.padding,
);
Box::new(ops)
}
}
impl<B: Backend> BackwardRecordedOps<B::TensorPrimitive<4>> for BackwardMaxPool<B, 2, 4> {
fn backward_step(&self, state: &BackwardNodeState<B::TensorPrimitive<4>>) {
let grads = B::max_pool2d_with_indexes_backward(
&self.x.state.value,
self.kernel_size,
self.stride,
self.padding,
&state.grad.borrow(),
&self.indexes,
);
self.x.state.update_grad(grads.x_grad);
}
fn backward_parents(&self) -> Vec<RecordedOpsParentRef> {
vec![self.x.clone()]
}
}
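The backward pass never recomputes the pooling: ForwardMaxPool caches the argmax indexes, and BackwardMaxPool::backward_step hands them straight to max_pool2d_with_indexes_backward. In our notation (not the crate's), the gradient routed to input position p is

$$\frac{\partial L}{\partial x_p} \;=\; \sum_{o \,:\, \mathrm{idx}(o) = p} \frac{\partial L}{\partial y_o},$$

i.e. each output gradient is scattered back to the input element that won its pooling window, and an element that wins several windows accumulates the sum.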

View File

@ -0,0 +1,85 @@
#[burn_tensor_testgen::testgen(ad_max_pool2d)]
mod tests {
use super::*;
use burn_tensor::{module::max_pool2d, Data};
#[test]
fn test_max_pool2d_simple() {
let batch_size = 1;
let channels_in = 1;
let kernel_size_1 = 2;
let kernel_size_2 = 2;
let padding_1 = 1;
let padding_2 = 1;
let stride_1 = 1;
let stride_2 = 1;
let x = TestADTensor::from_floats([[[
[0.2479, 0.6386, 0.3166, 0.5742],
[0.7065, 0.1940, 0.6305, 0.8959],
[0.5416, 0.8602, 0.8129, 0.1662],
[0.3358, 0.3059, 0.8293, 0.0990],
]]]);
let x_grad_expected = TestADTensor::from_floats([[[
[1., 3., 0., 2.],
[3., 0., 0., 4.],
[1., 4., 0., 1.],
[2., 0., 3., 1.],
]]]);
let output = max_pool2d(
&x,
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
);
let grads = output.backward();
// Asserts
let x_grad_actual = x.grad(&grads).unwrap();
x_grad_expected
.to_data()
.assert_approx_eq(&x_grad_actual.to_data(), 3);
}
#[test]
fn test_max_pool2d_complex() {
let batch_size = 1;
let channels_in = 1;
let kernel_size_1 = 4;
let kernel_size_2 = 2;
let padding_1 = 2;
let padding_2 = 1;
let stride_1 = 1;
let stride_2 = 2;
let x = TestADTensor::from_floats([[[
[0.5388, 0.0676, 0.7122, 0.8316, 0.0653],
[0.9154, 0.1536, 0.9089, 0.8016, 0.7518],
[0.2073, 0.0501, 0.8811, 0.5604, 0.5075],
[0.4384, 0.9963, 0.9698, 0.4988, 0.2609],
[0.3391, 0.2230, 0.4610, 0.5365, 0.6880],
]]]);
let x_grad_expected = TestADTensor::from_floats([[[
[0., 0., 0., 3., 0.],
[4., 0., 2., 1., 0.],
[0., 0., 0., 0., 0.],
[2., 4., 0., 0., 0.],
[0., 0., 0., 0., 2.],
]]]);
let output = max_pool2d(
&x,
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
);
let grads = output.backward();
// Asserts
let x_grad_actual = x.grad(&grads).unwrap();
x_grad_expected
.to_data()
.assert_approx_eq(&x_grad_actual.to_data(), 3);
}
}

View File

@ -13,6 +13,7 @@ mod index;
mod log;
mod mask;
mod matmul;
mod maxpool2d;
mod mul;
mod multithread;
mod neg;
@ -34,6 +35,7 @@ macro_rules! testgen_all {
// Modules
burn_autodiff::testgen_ad_conv1d!();
burn_autodiff::testgen_ad_conv2d!();
burn_autodiff::testgen_ad_max_pool2d!();
burn_autodiff::testgen_module_backward!();
// Tensor

View File

@ -16,8 +16,6 @@ mod tensor;
pub use backend::*;
pub(crate) use tensor::*;
pub(crate) mod conv;
#[cfg(test)]
mod tests {
type TestBackend = crate::NdArrayBackend<f32>;

View File

@ -1,9 +1,32 @@
use super::padding::apply_padding2d;
use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend, NdArrayDevice};
use burn_tensor::{ops::TensorOps, Shape};
/// This method is not the most efficient, but it serves as a basic implementation that is easy to understand.
/// A more optimized version should be used in its place.
pub(crate) fn conv2d_naive<E: NdArrayElement>(
x: &NdArrayTensor<E, 4>,
weight: &NdArrayTensor<E, 4>,
bias: Option<&NdArrayTensor<E, 1>>,
stride: [usize; 2],
padding: [usize; 2],
) -> NdArrayTensor<E, 4> {
let [batch_size, channels_in, heigth, width] = x.shape.dims;
let mut results = Vec::with_capacity(batch_size);
for b in 0..batch_size {
let x = NdArrayBackend::index(x, [b..b + 1, 0..channels_in, 0..heigth, 0..width]);
let x = NdArrayBackend::reshape(&x, Shape::new([channels_in, heigth, width]));
results.push(conv2d_naive_no_batch_size(
&x, weight, bias, stride, padding,
));
}
NdArrayBackend::cat(&results, 0)
}
pub(crate) fn conv2d_naive_no_batch_size<E: NdArrayElement>(
x: &NdArrayTensor<E, 3>,
weight: &NdArrayTensor<E, 4>,
bias: Option<&NdArrayTensor<E, 1>>,
@ -11,6 +34,7 @@ pub(crate) fn conv2d_naive<E: NdArrayElement>(
padding: [usize; 2],
) -> NdArrayTensor<E, 4> {
let [channels_out, channels_in, k1, k2] = weight.shape.dims;
let [_, heigth, width] = x.shape.dims;
let mut results = Vec::new();
for co in 0..channels_out {
@ -20,7 +44,9 @@ pub(crate) fn conv2d_naive<E: NdArrayElement>(
let kernel = NdArrayBackend::index(weight, [co..co + 1, ci..ci + 1, 0..k1, 0..k2]);
let kernel = NdArrayBackend::reshape(&kernel, Shape::new([k1, k2]));
let x = apply_padding(x, ci, padding);
let x = NdArrayBackend::index(x, [ci..ci + 1, 0..heigth, 0..width]);
let x = NdArrayBackend::reshape(&x, Shape::new([heigth, width]));
let x = apply_padding2d(&x, padding);
let matrix = conv2d_with_kernel(x, kernel, stride);
let [heigth, width] = matrix.shape.dims;
@ -45,31 +71,6 @@ pub(crate) fn conv2d_naive<E: NdArrayElement>(
result
}
fn apply_padding<E: NdArrayElement>(
x: &NdArrayTensor<E, 3>,
channel: usize,
padding: [usize; 2],
) -> NdArrayTensor<E, 2> {
let [_, heigth, width] = x.shape.dims;
let heigth_new = heigth + (2 * padding[0]);
let width_new = width + (2 * padding[1]);
let x = NdArrayBackend::index(x, [channel..channel + 1, 0..heigth, 0..width]);
let x = NdArrayBackend::reshape(&x, Shape::new([heigth, width]));
let mut x_new = NdArrayBackend::zeros(Shape::new([heigth_new, width_new]), NdArrayDevice::Cpu);
x_new = NdArrayBackend::index_assign(
&x_new,
[
padding[0]..heigth + padding[0],
padding[1]..width + padding[1],
],
&x,
);
x_new
}
fn conv2d_with_kernel<E: NdArrayElement>(
x: NdArrayTensor<E, 2>,
kernel: NdArrayTensor<E, 2>,

View File

@ -0,0 +1,144 @@
use super::padding::apply_padding2d;
use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend, NdArrayDevice};
use burn_tensor::{ops::TensorOps, Data, Shape};
/// This method is not the most efficient, but it serves as a basic implementation that is easy to understand.
/// A more optimized version should be used in its place.
pub(crate) fn max_pool2d_with_indexes_naive<E: NdArrayElement>(
x: &NdArrayTensor<E, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
) -> (NdArrayTensor<E, 4>, NdArrayTensor<i64, 4>) {
let mut batches = Vec::new();
let mut batches_indexes = Vec::new();
let [batch_size, channels, heigth, width] = x.shape.dims;
for b in 0..batch_size {
let mut batch = Vec::new();
let mut batch_indexes = Vec::new();
for c in 0..channels {
let x = NdArrayBackend::index(x, [b..b + 1, c..c + 1, 0..heigth, 0..width]);
let x = NdArrayBackend::reshape(&x, Shape::new([heigth, width]));
let x = apply_padding2d(&x, padding);
let (matrix, indexes) = max_pool2d_with_kernel(x, kernel_size, stride, padding);
let [heigth, width] = matrix.shape.dims;
let matrix = NdArrayBackend::reshape(&matrix, Shape::new([1, 1, heigth, width]));
let indexes = NdArrayBackend::reshape(&indexes, Shape::new([1, 1, heigth, width]));
batch.push(matrix);
batch_indexes.push(indexes);
}
let batch = NdArrayBackend::cat(&batch, 1);
let batch_indexes = NdArrayBackend::cat(&batch_indexes, 1);
batches.push(batch);
batches_indexes.push(batch_indexes);
}
(
NdArrayBackend::cat(&batches, 0),
NdArrayBackend::cat(&batches_indexes, 0),
)
}
pub(crate) fn max_pool2d_backward_naive<E: NdArrayElement>(
x: &NdArrayTensor<E, 4>,
_kernel_size: [usize; 2],
_stride: [usize; 2],
_padding: [usize; 2],
output_grad: &NdArrayTensor<E, 4>,
indexes: &NdArrayTensor<i64, 4>,
) -> NdArrayTensor<E, 4> {
let [_batch_size, _channels, heigth, width] = output_grad.shape.dims;
let [batch_size, channels, heigth_x, width_x] = x.shape.dims;
let output_grad_flatten = NdArrayBackend::reshape(
output_grad,
Shape::new([batch_size, channels, heigth * width]),
);
let indexes_flatten =
NdArrayBackend::reshape(indexes, Shape::new([batch_size, channels, heigth * width]));
let mut output_flatten = NdArrayBackend::zeros(
Shape::new([batch_size, channels, heigth_x * width_x]),
NdArrayDevice::Cpu,
);
for b in 0..batch_size {
for c in 0..channels {
for i in 0..(heigth * width) {
let index = NdArrayBackend::index(&indexes_flatten, [b..b + 1, c..c + 1, i..i + 1]);
let index = NdArrayBackend::into_data(index).value[0] as usize;
let current_value =
NdArrayBackend::index(&output_flatten, [b..b + 1, c..c + 1, index..index + 1]);
let output_grad =
NdArrayBackend::index(&output_grad_flatten, [b..b + 1, c..c + 1, i..i + 1]);
let updated_value = NdArrayBackend::add(&current_value, &output_grad);
output_flatten = NdArrayBackend::index_assign(
&output_flatten,
[b..b + 1, c..c + 1, index..index + 1],
&updated_value,
);
}
}
}
NdArrayBackend::reshape(&output_flatten, x.shape)
}
fn max_pool2d_with_kernel<E: NdArrayElement>(
x: NdArrayTensor<E, 2>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
) -> (NdArrayTensor<E, 2>, NdArrayTensor<i64, 2>) {
let [k1, k2] = kernel_size;
let [p1, p2] = padding;
let [heigth, width] = x.shape.dims;
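// x is already padded here, so a "valid" sweep is enough:
// output size per dimension = ceil((padded_size - kernel_size + 1) / stride).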
let heigth_new = f32::ceil((heigth - k1 + 1) as f32 / stride[0] as f32) as usize;
let width_new = f32::ceil((width - k2 + 1) as f32 / stride[1] as f32) as usize;
let mut output = NdArrayBackend::empty(Shape::new([heigth_new, width_new]), NdArrayDevice::Cpu);
let mut indexes =
NdArrayBackend::empty(Shape::new([heigth_new, width_new]), NdArrayDevice::Cpu);
for i in 0..heigth_new {
for j in 0..width_new {
let i_x = i * stride[0];
let j_x = j * stride[1];
let x_ij = NdArrayBackend::index(&x, [i_x..i_x + k1, j_x..j_x + k2]);
let x_flatten = NdArrayBackend::reshape(&x_ij, Shape::new([k1 * k2]));
let index = NdArrayBackend::argmax(&x_flatten, 0);
let index = NdArrayBackend::into_data(index).value[0];
let value = NdArrayBackend::into_data(x_flatten).value[index as usize];
let value = NdArrayBackend::from_data(
Data::new(vec![value], Shape::new([1, 1])),
NdArrayDevice::Cpu,
);
let index_i = index / k2 as i64;
let index_j = index - (index_i * k2 as i64);
let ii = i64::max(0, i_x as i64 - p1 as i64 + index_i);
let jj = i64::max(0, j_x as i64 - p2 as i64 + index_j);
// Flatten (ii, jj) into the unpadded input, row-major: the stride is the unpadded width.
let w = (width - (2 * p2)) as i64;
let index = ii * w + jj;
let index = NdArrayBackend::from_data(
Data::new(vec![index], Shape::new([1, 1])),
NdArrayDevice::Cpu,
);
indexes = NdArrayBackend::index_assign(&indexes, [i..i + 1, j..j + 1], &index);
output = NdArrayBackend::index_assign(&output, [i..i + 1, j..j + 1], &value);
}
}
(output, indexes)
}
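The backward implementation in this file reduces to a scatter-add over flattened buffers, one (batch, channel) plane at a time. A stripped-down sketch of that idea, with hypothetical names and plain slices instead of burn tensors:

// Illustrative only: the core of max_pool2d_backward_naive as a scatter-add.
// `indexes[o]` is the flat position (row * in_width + col) of the input element
// that produced output `o`; inputs that win several windows accumulate gradient.
fn scatter_add_pool_backward(output_grad: &[f32], indexes: &[usize], input_len: usize) -> Vec<f32> {
    let mut x_grad = vec![0.0_f32; input_len];
    for (o, &idx) in indexes.iter().enumerate() {
        x_grad[idx] += output_grad[o];
    }
    x_grad
}

// e.g. a 2x2 input where position 3 won two windows:
// scatter_add_pool_backward(&[1.0, 1.0, 1.0], &[0, 3, 3], 4) == vec![1.0, 0.0, 0.0, 2.0]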

View File

@ -1,3 +1,7 @@
mod creation;
mod module;
mod tensor;
pub(crate) mod conv;
pub(crate) mod maxpool;
pub(crate) mod padding;

View File

@ -1,7 +1,12 @@
use crate::{conv::conv2d_naive, element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
use burn_tensor::{ops::*, Shape};
use std::ops::Add;
use super::{
conv::conv2d_naive,
maxpool::{max_pool2d_backward_naive, max_pool2d_with_indexes_naive},
};
impl<E: NdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E> {
fn embedding(
weights: &NdArrayTensor<E, 2>,
@ -68,16 +73,44 @@ impl<E: NdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E> {
stride: [usize; 2],
padding: [usize; 2],
) -> NdArrayTensor<E, 4> {
conv2d_naive(x, weight, bias, stride, padding)
}
fn max_pool2d(
x: &NdArrayTensor<E, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
) -> NdArrayTensor<E, 4> {
max_pool2d_with_indexes_naive(x, kernel_size, stride, padding).0
}
fn max_pool2d_with_indexes(
x: &NdArrayTensor<E, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
) -> MaxPool2dWithIndexes<NdArrayBackend<E>> {
let (output, indexes) = max_pool2d_with_indexes_naive(x, kernel_size, stride, padding);
MaxPool2dWithIndexes::new(output, indexes)
}
fn max_pool2d_with_indexes_backward(
x: &NdArrayTensor<E, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
output_grad: &NdArrayTensor<E, 4>,
indexes: &NdArrayTensor<i64, 4>,
) -> MaxPool2dBackward<NdArrayBackend<E>> {
MaxPool2dBackward::new(max_pool2d_backward_naive(
x,
kernel_size,
stride,
padding,
output_grad,
indexes,
))
}
}

View File

@ -0,0 +1,24 @@
use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend, NdArrayDevice};
use burn_tensor::{ops::TensorOps, Shape};
pub(crate) fn apply_padding2d<E: NdArrayElement>(
x: &NdArrayTensor<E, 2>,
padding: [usize; 2],
) -> NdArrayTensor<E, 2> {
let [heigth, width] = x.shape.dims;
let heigth_new = heigth + (2 * padding[0]);
let width_new = width + (2 * padding[1]);
let mut x_new = NdArrayBackend::zeros(Shape::new([heigth_new, width_new]), NdArrayDevice::Cpu);
x_new = NdArrayBackend::index_assign(
&x_new,
[
padding[0]..heigth + padding[0],
padding[1]..width + padding[1],
],
x,
);
x_new
}

View File

@ -1,6 +1,3 @@
use std::cmp::Ordering;
use std::ops::Range;
use crate::tensor::BatchMatrix;
use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
use crate::{to_nd_array_tensor, NdArrayDevice, SEED};
@ -9,6 +6,8 @@ use burn_tensor::{backend::Backend, ops::TensorOps, Data, ElementConversion, Sha
use ndarray::{ArcArray, Axis, Dim, Dimension, IxDyn, SliceInfoElem};
use rand::rngs::StdRng;
use rand::SeedableRng;
use std::cmp::Ordering;
use std::ops::Range;
macro_rules! keepdim {
(

View File

@ -1,5 +1,8 @@
use crate::{element::TchElement, TchBackend, TchTensor};
use burn_tensor::{ops::ModuleOps, Shape};
use crate::{element::TchElement, TchBackend, TchKind, TchTensor};
use burn_tensor::{
ops::{MaxPool2dBackward, MaxPool2dWithIndexes, ModuleOps},
Shape,
};
impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
fn embedding(weights: &TchTensor<E, 2>, indexes: &TchTensor<i64, 2>) -> TchTensor<E, 3> {
@ -85,4 +88,87 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
shape,
}
}
fn max_pool2d(
x: &TchTensor<E, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
) -> TchTensor<E, 4> {
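// tch's max_pool2d arguments are (kernel_size, stride, padding, dilation, ceil_mode);
// dilation is fixed to [1, 1] and ceil_mode to false here.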
let tensor = tch::Tensor::max_pool2d(
&x.tensor,
&[kernel_size[0] as i64, kernel_size[1] as i64],
&[stride[0] as i64, stride[1] as i64],
&[padding[0] as i64, padding[1] as i64],
&[1, 1],
false,
);
let shape = Shape::from(tensor.size());
TchTensor {
kind: x.kind,
tensor,
shape,
}
}
fn max_pool2d_with_indexes(
x: &TchTensor<E, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
) -> MaxPool2dWithIndexes<TchBackend<E>> {
let (tensor, indexes) = tch::Tensor::max_pool2d_with_indices(
&x.tensor,
&[kernel_size[0] as i64, kernel_size[1] as i64],
&[stride[0] as i64, stride[1] as i64],
&[padding[0] as i64, padding[1] as i64],
&[1, 1],
false,
);
let shape = Shape::from(tensor.size());
let output = TchTensor {
kind: x.kind,
tensor,
shape,
};
let shape = Shape::from(indexes.size());
let indexes = TchTensor {
kind: TchKind::<i64>::new(),
tensor: indexes,
shape,
};
MaxPool2dWithIndexes::new(output, indexes)
}
fn max_pool2d_with_indexes_backward(
x: &TchTensor<E, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
output_grad: &TchTensor<E, 4>,
indexes: &TchTensor<i64, 4>,
) -> MaxPool2dBackward<TchBackend<E>> {
let grad = tch::Tensor::max_pool2d_with_indices_backward(
&x.tensor,
&output_grad.tensor,
&[kernel_size[0] as i64, kernel_size[1] as i64],
&[stride[0] as i64, stride[1] as i64],
&[padding[0] as i64, padding[1] as i64],
&[1, 1],
false,
&indexes.tensor,
);
let shape = Shape::from(grad.size());
let tensor = TchTensor {
kind: x.kind,
tensor: grad,
shape,
};
MaxPool2dBackward::new(tensor)
}
}

View File

@ -47,3 +47,31 @@ where
padding,
))
}
/// Applies [2D max pooling](crate::ops::ModuleOps::max_pool2d).
pub fn max_pool2d<B>(
x: &Tensor<B, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
) -> Tensor<B, 4>
where
B: Backend,
{
Tensor::new(B::max_pool2d(&x.value, kernel_size, stride, padding))
}
/// Applies [2D max pooling with indexes](crate::ops::ModuleOps::max_pool2d_with_indexes).
pub fn max_pool2d_with_indexes<B>(
x: &Tensor<B, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
) -> (Tensor<B, 4>, Tensor<B::IntegerBackend, 4>)
where
B: Backend,
{
let output = B::max_pool2d_with_indexes(&x.value, kernel_size, stride, padding);
(Tensor::new(output.output), Tensor::new(output.indexes))
}

View File

@ -9,6 +9,19 @@ pub struct Conv2dBackward<B: Backend> {
pub bias_grad: Option<B::TensorPrimitive<1>>,
}
/// Gradient computed during the backward pass for each tensor used by [max_pool2d](ModuleOps::max_pool2d).
#[derive(new)]
pub struct MaxPool2dBackward<B: Backend> {
pub x_grad: B::TensorPrimitive<4>,
}
/// Results from [max_pool2d_with_indexes](ModuleOps::max_pool2d_with_indexes).
#[derive(new)]
pub struct MaxPool2dWithIndexes<B: Backend> {
pub output: B::TensorPrimitive<4>,
pub indexes: <B::IntegerBackend as Backend>::TensorPrimitive<4>,
}
/// Gradient computed during the backward pass for each tensor used by [conv1d](ModuleOps::conv1d).
#[derive(new)]
pub struct Conv1dBackward<B: Backend> {
@ -77,4 +90,35 @@ pub trait ModuleOps<B: Backend> {
) -> Conv1dBackward<B> {
conv::conv1d_backward(x, weight, bias, stride, output_grad)
}
/// Two dimensional max pooling.
///
/// # Shapes
///
/// x: [batch_size, channels, height, width],
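///
/// The output spatial size follows the usual convention (consistent with the tests in
/// this change): out = floor((size + 2 * padding - kernel_size) / stride) + 1 per dimension.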
fn max_pool2d(
x: &B::TensorPrimitive<4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
) -> B::TensorPrimitive<4>;
/// Two dimensional max pooling with indexes.
///
/// # Shapes
///
/// x: [batch_size, channels, height, width],
fn max_pool2d_with_indexes(
x: &B::TensorPrimitive<4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
) -> MaxPool2dWithIndexes<B>;
/// Backward pass for the [max pooling 2d](ModuleOps::max_pool2d_with_indexes) operation.
fn max_pool2d_with_indexes_backward(
x: &B::TensorPrimitive<4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
output_grad: &B::TensorPrimitive<4>,
indexes: &<B::IntegerBackend as Backend>::TensorPrimitive<4>,
) -> MaxPool2dBackward<B>;
}

View File

@ -15,6 +15,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_module_forward!();
burn_tensor::testgen_module_conv1d!();
burn_tensor::testgen_module_conv2d!();
burn_tensor::testgen_module_max_pool2d!();
// test ops
burn_tensor::testgen_add!();

View File

@ -0,0 +1,230 @@
#[burn_tensor_testgen::testgen(module_max_pool2d)]
mod tests {
use super::*;
use burn_tensor::module::{max_pool2d, max_pool2d_with_indexes};
use burn_tensor::{Data, Tensor};
#[test]
fn test_max_pool2d_simple() {
let batch_size = 2;
let channels_in = 2;
let kernel_size_1 = 3;
let kernel_size_2 = 3;
let padding_1 = 1;
let padding_2 = 1;
let stride_1 = 1;
let stride_2 = 1;
let x = TestTensor::from_floats([
[
[
[0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221],
[0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689],
[0.5986, 0.2059, 0.4897, 0.6136, 0.2965, 0.6182],
[0.1485, 0.9540, 0.4023, 0.6176, 0.7111, 0.3392],
[0.3703, 0.0472, 0.2771, 0.1868, 0.8855, 0.5605],
[0.5063, 0.1638, 0.9432, 0.7836, 0.8696, 0.1068],
],
[
[0.8872, 0.0137, 0.1652, 0.5505, 0.6127, 0.6473],
[0.1128, 0.0888, 0.1152, 0.5456, 0.6199, 0.7947],
[0.5911, 0.7781, 0.7256, 0.6578, 0.0989, 0.9149],
[0.5879, 0.5189, 0.6561, 0.0578, 0.7025, 0.6426],
[0.9590, 0.0325, 0.6455, 0.6248, 0.2009, 0.1544],
[0.7339, 0.1369, 0.6598, 0.5528, 0.6775, 0.1572],
],
],
[
[
[0.6853, 0.6439, 0.4639, 0.5573, 0.2723, 0.5910],
[0.5419, 0.7729, 0.6743, 0.8956, 0.2997, 0.9546],
[0.0334, 0.2178, 0.6917, 0.4958, 0.3357, 0.6584],
[0.7358, 0.9074, 0.2462, 0.5159, 0.6420, 0.2441],
[0.7602, 0.6297, 0.6073, 0.5937, 0.8037, 0.4881],
[0.8859, 0.0974, 0.3954, 0.6763, 0.1078, 0.7467],
],
[
[0.2991, 0.5012, 0.8024, 0.7653, 0.9378, 0.7952],
[0.7393, 0.2336, 0.9521, 0.2719, 0.8445, 0.0454],
[0.6479, 0.9822, 0.7905, 0.0318, 0.2474, 0.0628],
[0.9955, 0.7591, 0.4140, 0.3215, 0.4349, 0.1527],
[0.8064, 0.0164, 0.4002, 0.2024, 0.6128, 0.5827],
[0.5368, 0.7895, 0.8727, 0.7793, 0.0910, 0.3421],
],
],
]);
let y = TestTensor::from_floats([
[
[
[0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221],
[0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221],
[0.9540, 0.9540, 0.9540, 0.9490, 0.7890, 0.7111],
[0.9540, 0.9540, 0.9540, 0.8855, 0.8855, 0.8855],
[0.9540, 0.9540, 0.9540, 0.9432, 0.8855, 0.8855],
[0.5063, 0.9432, 0.9432, 0.9432, 0.8855, 0.8855],
],
[
[0.8872, 0.8872, 0.5505, 0.6199, 0.7947, 0.7947],
[0.8872, 0.8872, 0.7781, 0.7256, 0.9149, 0.9149],
[0.7781, 0.7781, 0.7781, 0.7256, 0.9149, 0.9149],
[0.9590, 0.9590, 0.7781, 0.7256, 0.9149, 0.9149],
[0.9590, 0.9590, 0.6598, 0.7025, 0.7025, 0.7025],
[0.9590, 0.9590, 0.6598, 0.6775, 0.6775, 0.6775],
],
],
[
[
[0.7729, 0.7729, 0.8956, 0.8956, 0.9546, 0.9546],
[0.7729, 0.7729, 0.8956, 0.8956, 0.9546, 0.9546],
[0.9074, 0.9074, 0.9074, 0.8956, 0.9546, 0.9546],
[0.9074, 0.9074, 0.9074, 0.8037, 0.8037, 0.8037],
[0.9074, 0.9074, 0.9074, 0.8037, 0.8037, 0.8037],
[0.8859, 0.8859, 0.6763, 0.8037, 0.8037, 0.8037],
],
[
[0.7393, 0.9521, 0.9521, 0.9521, 0.9378, 0.9378],
[0.9822, 0.9822, 0.9822, 0.9521, 0.9378, 0.9378],
[0.9955, 0.9955, 0.9822, 0.9521, 0.8445, 0.8445],
[0.9955, 0.9955, 0.9822, 0.7905, 0.6128, 0.6128],
[0.9955, 0.9955, 0.8727, 0.8727, 0.7793, 0.6128],
[0.8064, 0.8727, 0.8727, 0.8727, 0.7793, 0.6128],
],
],
]);
let output = max_pool2d(
&x,
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
);
y.to_data().assert_approx_eq(&output.into_data(), 3);
}
#[test]
fn test_max_pool2d_different_padding_stride_kernel() {
let batch_size = 1;
let channels_in = 1;
let kernel_size_1 = 3;
let kernel_size_2 = 1;
let padding_1 = 1;
let padding_2 = 0;
let stride_1 = 1;
let stride_2 = 2;
let x = TestTensor::from_floats([[[
[0.6309, 0.6112, 0.6998],
[0.4708, 0.9161, 0.5402],
[0.4577, 0.7397, 0.9870],
[0.6380, 0.4352, 0.5884],
[0.6277, 0.5139, 0.4525],
[0.9333, 0.9846, 0.5006],
]]]);
let y = TestTensor::from_floats([[[
[0.6309, 0.6998],
[0.6309, 0.9870],
[0.6380, 0.9870],
[0.6380, 0.9870],
[0.9333, 0.5884],
[0.9333, 0.5006],
]]]);
let output = max_pool2d(
&x,
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
);
y.to_data().assert_approx_eq(&output.into_data(), 3);
}
#[test]
fn test_max_pool2d_with_indexes() {
let batch_size = 1;
let channels_in = 1;
let kernel_size_1 = 2;
let kernel_size_2 = 2;
let padding_1 = 1;
let padding_2 = 1;
let stride_1 = 1;
let stride_2 = 1;
let x = TestTensor::from_floats([[[
[0.2479, 0.6386, 0.3166, 0.5742],
[0.7065, 0.1940, 0.6305, 0.8959],
[0.5416, 0.8602, 0.8129, 0.1662],
[0.3358, 0.3059, 0.8293, 0.0990],
]]]);
let indexes = Data::<i64, 4>::from([[[
[0, 1, 1, 3, 3],
[4, 4, 1, 7, 7],
[4, 9, 9, 7, 7],
[8, 9, 9, 14, 11],
[12, 12, 14, 14, 15],
]]]);
let y = TestTensor::from_floats([[[
[0.2479, 0.6386, 0.6386, 0.5742, 0.5742],
[0.7065, 0.7065, 0.6386, 0.8959, 0.8959],
[0.7065, 0.8602, 0.8602, 0.8959, 0.8959],
[0.5416, 0.8602, 0.8602, 0.8293, 0.1662],
[0.3358, 0.3358, 0.8293, 0.8293, 0.0990],
]]]);
let (output, output_indexes) = max_pool2d_with_indexes(
&x,
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
);
y.to_data().assert_approx_eq(&output.into_data(), 3);
assert_eq!(indexes.value, output_indexes.into_data().value);
}
#[test]
fn test_max_pool2d_complex() {
let batch_size = 1;
let channels_in = 1;
let kernel_size_1 = 4;
let kernel_size_2 = 2;
let padding_1 = 2;
let padding_2 = 1;
let stride_1 = 1;
let stride_2 = 2;
let x = TestTensor::from_floats([[[
[0.5388, 0.0676, 0.7122, 0.8316, 0.0653],
[0.9154, 0.1536, 0.9089, 0.8016, 0.7518],
[0.2073, 0.0501, 0.8811, 0.5604, 0.5075],
[0.4384, 0.9963, 0.9698, 0.4988, 0.2609],
[0.3391, 0.2230, 0.4610, 0.5365, 0.6880],
]]]);
let indexes = Data::<i64, 4>::from([[[
[5, 7, 3],
[5, 7, 3],
[5, 16, 3],
[5, 16, 8],
[15, 16, 24],
[15, 16, 24],
]]]);
let y = TestTensor::from_floats([[[
[0.9154, 0.9089, 0.8316],
[0.9154, 0.9089, 0.8316],
[0.9154, 0.9963, 0.8316],
[0.9154, 0.9963, 0.8016],
[0.4384, 0.9963, 0.688],
[0.4384, 0.9963, 0.688],
]]]);
let (output, output_indexes) = max_pool2d_with_indexes(
&x,
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
);
y.to_data().assert_approx_eq(&output.into_data(), 3);
assert_eq!(indexes.value, output_indexes.into_data().value);
}
}

View File

@ -1,3 +1,4 @@
mod conv1d;
mod conv2d;
mod forward;
mod maxpool2d;