mirror of https://github.com/tracel-ai/burn.git
Feat/max pooling backend (#152)
This commit is contained in:
parent 1d0d92a269
commit 34d233cd3e
@@ -104,6 +104,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
        if let Some(bias) = bias {
            order = usize::max(order, bias.node.order);
        }
        order += 1;

        let ops = ForwardConv::<B, 1, 3>::new(
            x.node.clone(),
@@ -118,6 +119,78 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {

        ADTensor { node, shape }
    }

    fn max_pool2d(
        x: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<4>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
    ) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<4> {
        let output = B::max_pool2d_with_indexes(x.tensor_ref(), kernel_size, stride, padding);
        let shape = *B::shape(&output.output);
        let order = x.node.order + 1;

        let ops = ForwardMaxPool::<B, 2, 4>::new(
            x.node.clone(),
            Arc::new(output.indexes),
            kernel_size,
            stride,
            padding,
        );
        let ops = Box::new(ops);
        let state = ForwardNodeState::new(output.output);
        let node = ForwardNode::new(order, state, ops);
        let node = std::sync::Arc::new(node);

        ADTensor { node, shape }
    }

    fn max_pool2d_with_indexes(
        x: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<4>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
    ) -> MaxPool2dWithIndexes<ADBackendDecorator<B>> {
        let output = B::max_pool2d_with_indexes(x.tensor_ref(), kernel_size, stride, padding);
        let shape = *B::shape(&output.output);
        let order = x.node.order + 1;

        let ops = ForwardMaxPool::<B, 2, 4>::new(
            x.node.clone(),
            Arc::new(output.indexes.clone()),
            kernel_size,
            stride,
            padding,
        );
        let ops = Box::new(ops);
        let state = ForwardNodeState::new(output.output);
        let node = ForwardNode::new(order, state, ops);
        let node = std::sync::Arc::new(node);

        MaxPool2dWithIndexes::new(ADTensor { node, shape }, output.indexes)
    }

    fn max_pool2d_with_indexes_backward(
        x: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<4>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        output_grad: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<4>,
        indexes: &<<ADBackendDecorator<B> as Backend>::IntegerBackend as Backend>::TensorPrimitive<
            4,
        >,
    ) -> MaxPool2dBackward<ADBackendDecorator<B>> {
        let tensor = B::max_pool2d_with_indexes_backward(
            x.tensor_ref(),
            kernel_size,
            stride,
            padding,
            output_grad.tensor_ref(),
            indexes,
        );

        MaxPool2dBackward::new(ADTensor::from_tensor(tensor.x_grad))
    }
}

#[derive(new, Debug)]
@@ -231,3 +304,57 @@ impl<B: Backend> BackwardRecordedOps<B::TensorPrimitive<3>> for BackwardConv<B,
        }
    }
}

#[derive(new, Debug)]
pub struct ForwardMaxPool<B: Backend, const D: usize, const S: usize> {
    x: ForwardNodeRef<B::TensorPrimitive<S>>,
    indexes: Arc<<B::IntegerBackend as Backend>::TensorPrimitive<S>>,
    kernel_size: [usize; D],
    stride: [usize; D],
    padding: [usize; D],
}

#[derive(new, Debug)]
pub struct BackwardMaxPool<B: Backend, const D: usize, const S: usize> {
    x: BackwardNodeRef<B::TensorPrimitive<S>>,
    indexes: Arc<<B::IntegerBackend as Backend>::TensorPrimitive<S>>,
    kernel_size: [usize; D],
    stride: [usize; D],
    padding: [usize; D],
}

impl<B: Backend> ForwardRecordedOps<B::TensorPrimitive<4>> for ForwardMaxPool<B, 2, 4> {
    fn to_backward(
        &self,
        graph: &mut Forward2BackwardGraphConverter,
    ) -> BackwardRecordedOpsBoxed<B::TensorPrimitive<4>> {
        let ops = BackwardMaxPool::<B, 2, 4>::new(
            Arc::new(BackwardNode::from_node(&self.x, graph)),
            self.indexes.clone(),
            self.kernel_size,
            self.stride,
            self.padding,
        );

        Box::new(ops)
    }
}

impl<B: Backend> BackwardRecordedOps<B::TensorPrimitive<4>> for BackwardMaxPool<B, 2, 4> {
    fn backward_step(&self, state: &BackwardNodeState<B::TensorPrimitive<4>>) {
        let grads = B::max_pool2d_with_indexes_backward(
            &self.x.state.value,
            self.kernel_size,
            self.stride,
            self.padding,
            &state.grad.borrow(),
            &self.indexes,
        );

        self.x.state.update_grad(grads.x_grad);
    }

    fn backward_parents(&self) -> Vec<RecordedOpsParentRef> {
        vec![self.x.clone()]
    }
}
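
For intuition: max_pool2d_with_indexes_backward routes each output gradient back to the input element that produced the maximum, using the indexes recorded during the forward pass. A minimal standalone sketch of that scatter-add idea (illustrative only, not the backend code above):

// Each pooled output sends its gradient to the input position that was its max.
fn max_pool_backward_sketch(output_grad: &[f32], indexes: &[usize], input_len: usize) -> Vec<f32> {
    let mut x_grad = vec![0.0_f32; input_len];
    for (grad, &idx) in output_grad.iter().zip(indexes.iter()) {
        x_grad[idx] += *grad;
    }
    x_grad
}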
@@ -0,0 +1,85 @@
#[burn_tensor_testgen::testgen(ad_max_pool2d)]
mod tests {
    use super::*;
    use burn_tensor::{module::max_pool2d, Data};

    #[test]
    fn test_max_pool2d_simple() {
        let batch_size = 1;
        let channels_in = 1;
        let kernel_size_1 = 2;
        let kernel_size_2 = 2;
        let padding_1 = 1;
        let padding_2 = 1;
        let stride_1 = 1;
        let stride_2 = 1;

        let x = TestADTensor::from_floats([[[
            [0.2479, 0.6386, 0.3166, 0.5742],
            [0.7065, 0.1940, 0.6305, 0.8959],
            [0.5416, 0.8602, 0.8129, 0.1662],
            [0.3358, 0.3059, 0.8293, 0.0990],
        ]]]);
        let x_grad_expected = TestADTensor::from_floats([[[
            [1., 3., 0., 2.],
            [3., 0., 0., 4.],
            [1., 4., 0., 1.],
            [2., 0., 3., 1.],
        ]]]);

        let output = max_pool2d(
            &x,
            [kernel_size_1, kernel_size_2],
            [stride_1, stride_2],
            [padding_1, padding_2],
        );
        let grads = output.backward();

        // Asserts
        let x_grad_actual = x.grad(&grads).unwrap();
        x_grad_expected
            .to_data()
            .assert_approx_eq(&x_grad_actual.to_data(), 3);
    }

    #[test]
    fn test_max_pool2d_complex() {
        let batch_size = 1;
        let channels_in = 1;
        let kernel_size_1 = 4;
        let kernel_size_2 = 2;
        let padding_1 = 2;
        let padding_2 = 1;
        let stride_1 = 1;
        let stride_2 = 2;

        let x = TestADTensor::from_floats([[[
            [0.5388, 0.0676, 0.7122, 0.8316, 0.0653],
            [0.9154, 0.1536, 0.9089, 0.8016, 0.7518],
            [0.2073, 0.0501, 0.8811, 0.5604, 0.5075],
            [0.4384, 0.9963, 0.9698, 0.4988, 0.2609],
            [0.3391, 0.2230, 0.4610, 0.5365, 0.6880],
        ]]]);
        let x_grad_expected = TestADTensor::from_floats([[[
            [0., 0., 0., 3., 0.],
            [4., 0., 2., 1., 0.],
            [0., 0., 0., 0., 0.],
            [2., 4., 0., 0., 0.],
            [0., 0., 0., 0., 2.],
        ]]]);

        let output = max_pool2d(
            &x,
            [kernel_size_1, kernel_size_2],
            [stride_1, stride_2],
            [padding_1, padding_2],
        );
        let grads = output.backward();

        // Asserts
        let x_grad_actual = x.grad(&grads).unwrap();
        x_grad_expected
            .to_data()
            .assert_approx_eq(&x_grad_actual.to_data(), 3);
    }
}
@@ -13,6 +13,7 @@ mod index;
mod log;
mod mask;
mod matmul;
mod maxpool2d;
mod mul;
mod multithread;
mod neg;
@@ -34,6 +35,7 @@ macro_rules! testgen_all {
        // Modules
        burn_autodiff::testgen_ad_conv1d!();
        burn_autodiff::testgen_ad_conv2d!();
        burn_autodiff::testgen_ad_max_pool2d!();
        burn_autodiff::testgen_module_backward!();

        // Tensor
@@ -16,8 +16,6 @@ mod tensor;
pub use backend::*;
pub(crate) use tensor::*;

pub(crate) mod conv;

#[cfg(test)]
mod tests {
    type TestBackend = crate::NdArrayBackend<f32>;
@@ -1,9 +1,32 @@
use super::padding::apply_padding2d;
use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend, NdArrayDevice};
use burn_tensor::{ops::TensorOps, Shape};

/// This method is not the most efficient, but it serves as a basic implementation that is easy to understand.
/// A more optimized version should be used in its place.
pub(crate) fn conv2d_naive<E: NdArrayElement>(
    x: &NdArrayTensor<E, 4>,
    weight: &NdArrayTensor<E, 4>,
    bias: Option<&NdArrayTensor<E, 1>>,
    stride: [usize; 2],
    padding: [usize; 2],
) -> NdArrayTensor<E, 4> {
    let [batch_size, channels_in, heigth, width] = x.shape.dims;
    let mut results = Vec::with_capacity(batch_size);

    for b in 0..batch_size {
        let x = NdArrayBackend::index(x, [b..b + 1, 0..channels_in, 0..heigth, 0..width]);
        let x = NdArrayBackend::reshape(&x, Shape::new([channels_in, heigth, width]));

        results.push(conv2d_naive_no_batch_size(
            &x, weight, bias, stride, padding,
        ));
    }

    NdArrayBackend::cat(&results, 0)
}

pub(crate) fn conv2d_naive_no_batch_size<E: NdArrayElement>(
    x: &NdArrayTensor<E, 3>,
    weight: &NdArrayTensor<E, 4>,
    bias: Option<&NdArrayTensor<E, 1>>,
@@ -11,6 +34,7 @@ pub(crate) fn conv2d_naive<E: NdArrayElement>(
    padding: [usize; 2],
) -> NdArrayTensor<E, 4> {
    let [channels_out, channels_in, k1, k2] = weight.shape.dims;
    let [_, heigth, width] = x.shape.dims;
    let mut results = Vec::new();

    for co in 0..channels_out {
@@ -20,7 +44,9 @@ pub(crate) fn conv2d_naive<E: NdArrayElement>(
            let kernel = NdArrayBackend::index(weight, [co..co + 1, ci..ci + 1, 0..k1, 0..k2]);
            let kernel = NdArrayBackend::reshape(&kernel, Shape::new([k1, k2]));

            let x = apply_padding(x, ci, padding);
            let x = NdArrayBackend::index(x, [ci..ci + 1, 0..heigth, 0..width]);
            let x = NdArrayBackend::reshape(&x, Shape::new([heigth, width]));
            let x = apply_padding2d(&x, padding);

            let matrix = conv2d_with_kernel(x, kernel, stride);
            let [heigth, width] = matrix.shape.dims;
@@ -45,31 +71,6 @@ pub(crate) fn conv2d_naive<E: NdArrayElement>(
    result
}

fn apply_padding<E: NdArrayElement>(
    x: &NdArrayTensor<E, 3>,
    channel: usize,
    padding: [usize; 2],
) -> NdArrayTensor<E, 2> {
    let [_, heigth, width] = x.shape.dims;
    let heigth_new = heigth + (2 * padding[0]);
    let width_new = width + (2 * padding[1]);

    let x = NdArrayBackend::index(x, [channel..channel + 1, 0..heigth, 0..width]);
    let x = NdArrayBackend::reshape(&x, Shape::new([heigth, width]));

    let mut x_new = NdArrayBackend::zeros(Shape::new([heigth_new, width_new]), NdArrayDevice::Cpu);
    x_new = NdArrayBackend::index_assign(
        &x_new,
        [
            padding[0]..heigth + padding[0],
            padding[1]..width + padding[1],
        ],
        &x,
    );

    x_new
}

fn conv2d_with_kernel<E: NdArrayElement>(
    x: NdArrayTensor<E, 2>,
    kernel: NdArrayTensor<E, 2>,
@@ -0,0 +1,144 @@
use super::padding::apply_padding2d;
use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend, NdArrayDevice};
use burn_tensor::{ops::TensorOps, Data, Shape};

/// This method is not the most efficient, but it serves as a basic implementation that is easy to understand.
/// A more optimized version should be used in its place.
pub(crate) fn max_pool2d_with_indexes_naive<E: NdArrayElement>(
    x: &NdArrayTensor<E, 4>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
) -> (NdArrayTensor<E, 4>, NdArrayTensor<i64, 4>) {
    let mut batches = Vec::new();
    let mut batches_indexes = Vec::new();

    let [batch_size, channels, heigth, width] = x.shape.dims;

    for b in 0..batch_size {
        let mut batch = Vec::new();
        let mut batch_indexes = Vec::new();

        for c in 0..channels {
            let x = NdArrayBackend::index(x, [b..b + 1, c..c + 1, 0..heigth, 0..width]);
            let x = NdArrayBackend::reshape(&x, Shape::new([heigth, width]));
            let x = apply_padding2d(&x, padding);

            let (matrix, indexes) = max_pool2d_with_kernel(x, kernel_size, stride, padding);
            let [heigth, width] = matrix.shape.dims;

            let matrix = NdArrayBackend::reshape(&matrix, Shape::new([1, 1, heigth, width]));
            let indexes = NdArrayBackend::reshape(&indexes, Shape::new([1, 1, heigth, width]));

            batch.push(matrix);
            batch_indexes.push(indexes);
        }
        let batch = NdArrayBackend::cat(&batch, 1);
        let batch_indexes = NdArrayBackend::cat(&batch_indexes, 1);

        batches.push(batch);
        batches_indexes.push(batch_indexes);
    }

    (
        NdArrayBackend::cat(&batches, 0),
        NdArrayBackend::cat(&batches_indexes, 0),
    )
}

pub(crate) fn max_pool2d_backward_naive<E: NdArrayElement>(
    x: &NdArrayTensor<E, 4>,
    _kernel_size: [usize; 2],
    _stride: [usize; 2],
    _padding: [usize; 2],
    output_grad: &NdArrayTensor<E, 4>,
    indexes: &NdArrayTensor<i64, 4>,
) -> NdArrayTensor<E, 4> {
    let [_batch_size, _channels, heigth, width] = output_grad.shape.dims;
    let [batch_size, channels, heigth_x, width_x] = x.shape.dims;

    let output_grad_flatten = NdArrayBackend::reshape(
        output_grad,
        Shape::new([batch_size, channels, heigth * width]),
    );
    let indexes_flatten =
        NdArrayBackend::reshape(indexes, Shape::new([batch_size, channels, heigth * width]));
    let mut output_flatten = NdArrayBackend::zeros(
        Shape::new([batch_size, channels, heigth_x * width_x]),
        NdArrayDevice::Cpu,
    );

    for b in 0..batch_size {
        for c in 0..channels {
            for i in 0..(heigth * width) {
                let index = NdArrayBackend::index(&indexes_flatten, [b..b + 1, c..c + 1, i..i + 1]);
                let index = NdArrayBackend::into_data(index).value[0] as usize;

                let current_value =
                    NdArrayBackend::index(&output_flatten, [b..b + 1, c..c + 1, index..index + 1]);
                let output_grad =
                    NdArrayBackend::index(&output_grad_flatten, [b..b + 1, c..c + 1, i..i + 1]);
                let updated_value = NdArrayBackend::add(&current_value, &output_grad);

                output_flatten = NdArrayBackend::index_assign(
                    &output_flatten,
                    [b..b + 1, c..c + 1, index..index + 1],
                    &updated_value,
                );
            }
        }
    }

    NdArrayBackend::reshape(&output_flatten, x.shape)
}

fn max_pool2d_with_kernel<E: NdArrayElement>(
    x: NdArrayTensor<E, 2>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
) -> (NdArrayTensor<E, 2>, NdArrayTensor<i64, 2>) {
    let [k1, k2] = kernel_size;
    let [p1, p2] = padding;
    let [heigth, width] = x.shape.dims;

    let heigth_new = f32::ceil((heigth - k1 + 1) as f32 / stride[0] as f32) as usize;
    let width_new = f32::ceil((width - k2 + 1) as f32 / stride[1] as f32) as usize;
    let mut output = NdArrayBackend::empty(Shape::new([heigth_new, width_new]), NdArrayDevice::Cpu);
    let mut indexes =
        NdArrayBackend::empty(Shape::new([heigth_new, width_new]), NdArrayDevice::Cpu);

    for i in 0..heigth_new {
        for j in 0..width_new {
            let i_x = i * stride[0];
            let j_x = j * stride[1];

            let x_ij = NdArrayBackend::index(&x, [i_x..i_x + k1, j_x..j_x + k2]);
            let x_flatten = NdArrayBackend::reshape(&x_ij, Shape::new([k1 * k2]));
            let index = NdArrayBackend::argmax(&x_flatten, 0);
            let index = NdArrayBackend::into_data(index).value[0];
            let value = NdArrayBackend::into_data(x_flatten).value[index as usize];
            let value = NdArrayBackend::from_data(
                Data::new(vec![value], Shape::new([1, 1])),
                NdArrayDevice::Cpu,
            );

            let index_i = index / k2 as i64;
            let index_j = index - (index_i * k2 as i64);
            let ii = i64::max(0, i_x as i64 - p1 as i64 + index_i);
            let jj = i64::max(0, j_x as i64 - p2 as i64 + index_j);
            let h = (heigth - (2 * p1)) as i64;
            let index = ii * h + jj;

            let index = NdArrayBackend::from_data(
                Data::new(vec![index], Shape::new([1, 1])),
                NdArrayDevice::Cpu,
            );

            indexes = NdArrayBackend::index_assign(&indexes, [i..i + 1, j..j + 1], &index);
            output = NdArrayBackend::index_assign(&output, [i..i + 1, j..j + 1], &value);
        }
    }

    (output, indexes)
}
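
A note on the output size used in max_pool2d_with_kernel above: x arrives already padded by apply_padding2d, so each spatial dimension produces ceil((padded - kernel + 1) / stride) positions, which equals the usual floor((size + 2 * padding - kernel) / stride) + 1. A minimal sketch checking this against the 4x4 input, 2x2 kernel, stride 1, padding 1 case exercised by the tests (the helper name is ours, purely illustrative):

// Output length per spatial dimension, matching the formula used above.
fn pool_output_len(size: usize, kernel: usize, stride: usize, padding: usize) -> usize {
    let padded = size + 2 * padding;
    (padded - kernel + 1 + stride - 1) / stride // integer ceil((padded - kernel + 1) / stride)
}
// pool_output_len(4, 2, 1, 1) == 5, hence the 5x5 outputs in the max_pool2d tests.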
@@ -1,3 +1,7 @@
mod creation;
mod module;
mod tensor;

pub(crate) mod conv;
pub(crate) mod maxpool;
pub(crate) mod padding;
@@ -1,7 +1,12 @@
use crate::{conv::conv2d_naive, element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
use burn_tensor::{ops::*, Shape};
use std::ops::Add;

use super::{
    conv::conv2d_naive,
    maxpool::{max_pool2d_backward_naive, max_pool2d_with_indexes_naive},
};

impl<E: NdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E> {
    fn embedding(
        weights: &NdArrayTensor<E, 2>,
@@ -68,16 +73,44 @@ impl<E: NdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E> {
        stride: [usize; 2],
        padding: [usize; 2],
    ) -> NdArrayTensor<E, 4> {
        let [batch_size, channels_in, heigth, width] = x.shape.dims;
        let mut results = Vec::with_capacity(batch_size);
        conv2d_naive(x, weight, bias, stride, padding)
    }

        for b in 0..batch_size {
            let x = NdArrayBackend::index(x, [b..b + 1, 0..channels_in, 0..heigth, 0..width]);
            let x = NdArrayBackend::reshape(&x, Shape::new([channels_in, heigth, width]));
    fn max_pool2d(
        x: &NdArrayTensor<E, 4>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
    ) -> NdArrayTensor<E, 4> {
        max_pool2d_with_indexes_naive(x, kernel_size, stride, padding).0
    }

            results.push(conv2d_naive(&x, weight, bias, stride, padding));
        }
    fn max_pool2d_with_indexes(
        x: &NdArrayTensor<E, 4>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
    ) -> MaxPool2dWithIndexes<NdArrayBackend<E>> {
        let (output, indexes) = max_pool2d_with_indexes_naive(x, kernel_size, stride, padding);

        NdArrayBackend::cat(&results, 0)
        MaxPool2dWithIndexes::new(output, indexes)
    }

    fn max_pool2d_with_indexes_backward(
        x: &NdArrayTensor<E, 4>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        output_grad: &NdArrayTensor<E, 4>,
        indexes: &NdArrayTensor<i64, 4>,
    ) -> MaxPool2dBackward<NdArrayBackend<E>> {
        MaxPool2dBackward::new(max_pool2d_backward_naive(
            x,
            kernel_size,
            stride,
            padding,
            output_grad,
            indexes,
        ))
    }
}
@@ -0,0 +1,24 @@
use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend, NdArrayDevice};
use burn_tensor::{ops::TensorOps, Shape};

pub(crate) fn apply_padding2d<E: NdArrayElement>(
    x: &NdArrayTensor<E, 2>,
    padding: [usize; 2],
) -> NdArrayTensor<E, 2> {
    let [heigth, width] = x.shape.dims;

    let heigth_new = heigth + (2 * padding[0]);
    let width_new = width + (2 * padding[1]);

    let mut x_new = NdArrayBackend::zeros(Shape::new([heigth_new, width_new]), NdArrayDevice::Cpu);
    x_new = NdArrayBackend::index_assign(
        &x_new,
        [
            padding[0]..heigth + padding[0],
            padding[1]..width + padding[1],
        ],
        x,
    );

    x_new
}
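
As a quick illustration of apply_padding2d (hypothetical values, not from the commit): a [2, 2] matrix padded with padding = [1, 1] becomes a [4, 4] matrix with the original values in its interior.

// [a, b]        [0, 0, 0, 0]
// [c, d]   ->   [0, a, b, 0]
//               [0, c, d, 0]
//               [0, 0, 0, 0]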
@@ -1,6 +1,3 @@
use std::cmp::Ordering;
use std::ops::Range;

use crate::tensor::BatchMatrix;
use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
use crate::{to_nd_array_tensor, NdArrayDevice, SEED};
@@ -9,6 +6,8 @@ use burn_tensor::{backend::Backend, ops::TensorOps, Data, ElementConversion, Sha
use ndarray::{ArcArray, Axis, Dim, Dimension, IxDyn, SliceInfoElem};
use rand::rngs::StdRng;
use rand::SeedableRng;
use std::cmp::Ordering;
use std::ops::Range;

macro_rules! keepdim {
    (
@@ -1,5 +1,8 @@
use crate::{element::TchElement, TchBackend, TchTensor};
use burn_tensor::{ops::ModuleOps, Shape};
use crate::{element::TchElement, TchBackend, TchKind, TchTensor};
use burn_tensor::{
    ops::{MaxPool2dBackward, MaxPool2dWithIndexes, ModuleOps},
    Shape,
};

impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
    fn embedding(weights: &TchTensor<E, 2>, indexes: &TchTensor<i64, 2>) -> TchTensor<E, 3> {
@@ -85,4 +88,87 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
            shape,
        }
    }

    fn max_pool2d(
        x: &TchTensor<E, 4>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
    ) -> TchTensor<E, 4> {
        let tensor = tch::Tensor::max_pool2d(
            &x.tensor,
            &[kernel_size[0] as i64, kernel_size[1] as i64],
            &[stride[0] as i64, stride[1] as i64],
            &[padding[0] as i64, padding[1] as i64],
            &[1, 1],
            false,
        );
        let shape = Shape::from(tensor.size());

        TchTensor {
            kind: x.kind,
            tensor,
            shape,
        }
    }

    fn max_pool2d_with_indexes(
        x: &TchTensor<E, 4>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
    ) -> MaxPool2dWithIndexes<TchBackend<E>> {
        let (tensor, indexes) = tch::Tensor::max_pool2d_with_indices(
            &x.tensor,
            &[kernel_size[0] as i64, kernel_size[1] as i64],
            &[stride[0] as i64, stride[1] as i64],
            &[padding[0] as i64, padding[1] as i64],
            &[1, 1],
            false,
        );
        let shape = Shape::from(tensor.size());

        let output = TchTensor {
            kind: x.kind,
            tensor,
            shape,
        };
        let shape = Shape::from(indexes.size());
        let indexes = TchTensor {
            kind: TchKind::<i64>::new(),
            tensor: indexes,
            shape,
        };

        MaxPool2dWithIndexes::new(output, indexes)
    }

    fn max_pool2d_with_indexes_backward(
        x: &TchTensor<E, 4>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        output_grad: &TchTensor<E, 4>,
        indexes: &TchTensor<i64, 4>,
    ) -> MaxPool2dBackward<TchBackend<E>> {
        let grad = tch::Tensor::max_pool2d_with_indices_backward(
            &x.tensor,
            &output_grad.tensor,
            &[kernel_size[0] as i64, kernel_size[1] as i64],
            &[stride[0] as i64, stride[1] as i64],
            &[padding[0] as i64, padding[1] as i64],
            &[1, 1],
            false,
            &indexes.tensor,
        );

        let shape = Shape::from(grad.size());
        let tensor = TchTensor {
            kind: x.kind,
            tensor: grad,
            shape,
        };

        MaxPool2dBackward::new(tensor)
    }
}
@@ -47,3 +47,31 @@ where
        padding,
    ))
}

/// Applies a [2D max pooling](crate::ops::ModuleOps::max_pool2d).
pub fn max_pool2d<B>(
    x: &Tensor<B, 4>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
) -> Tensor<B, 4>
where
    B: Backend,
{
    Tensor::new(B::max_pool2d(&x.value, kernel_size, stride, padding))
}

/// Applies a [2D max pooling with indexes](crate::ops::ModuleOps::max_pool2d_with_indexes).
pub fn max_pool2d_with_indexes<B>(
    x: &Tensor<B, 4>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
) -> (Tensor<B, 4>, Tensor<B::IntegerBackend, 4>)
where
    B: Backend,
{
    let output = B::max_pool2d_with_indexes(&x.value, kernel_size, stride, padding);

    (Tensor::new(output.output), Tensor::new(output.indexes))
}
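
A minimal usage sketch of the new module-level functions, mirroring the tests added later in this commit (TestTensor is the backend-specific alias used by the test suites):

use burn_tensor::module::{max_pool2d, max_pool2d_with_indexes};

// 1x1x4x4 input, 2x2 kernel, stride 1, padding 1 -> 1x1x5x5 output.
let x = TestTensor::from_floats([[[
    [0.2479, 0.6386, 0.3166, 0.5742],
    [0.7065, 0.1940, 0.6305, 0.8959],
    [0.5416, 0.8602, 0.8129, 0.1662],
    [0.3358, 0.3059, 0.8293, 0.0990],
]]]);

let output = max_pool2d(&x, [2, 2], [1, 1], [1, 1]);
let (output, indexes) = max_pool2d_with_indexes(&x, [2, 2], [1, 1], [1, 1]);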
@@ -9,6 +9,19 @@ pub struct Conv2dBackward<B: Backend> {
    pub bias_grad: Option<B::TensorPrimitive<1>>,
}

/// Gradient computed during the backward pass for each tensor used by [max_pool2d](ModuleOps::max_pool2d).
#[derive(new)]
pub struct MaxPool2dBackward<B: Backend> {
    pub x_grad: B::TensorPrimitive<4>,
}

/// Results from [max_pool2d](ModuleOps::max_pool2d_with_indexes).
#[derive(new)]
pub struct MaxPool2dWithIndexes<B: Backend> {
    pub output: B::TensorPrimitive<4>,
    pub indexes: <B::IntegerBackend as Backend>::TensorPrimitive<4>,
}

/// Gradient computed during the backward pass for each tensor used by [conv1d](ModuleOps::conv1d).
#[derive(new)]
pub struct Conv1dBackward<B: Backend> {
@@ -77,4 +90,35 @@ pub trait ModuleOps<B: Backend> {
    ) -> Conv1dBackward<B> {
        conv::conv1d_backward(x, weight, bias, stride, output_grad)
    }
    /// Two dimensional max pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, height, width],
    fn max_pool2d(
        x: &B::TensorPrimitive<4>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
    ) -> B::TensorPrimitive<4>;
    /// Two dimensional max pooling with indexes.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, height, width],
    fn max_pool2d_with_indexes(
        x: &B::TensorPrimitive<4>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
    ) -> MaxPool2dWithIndexes<B>;
    /// Backward pass for the [max pooling 2d](ModuleOps::max_pool2d_with_indexes) operation.
    fn max_pool2d_with_indexes_backward(
        x: &B::TensorPrimitive<4>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        output_grad: &B::TensorPrimitive<4>,
        indexes: &<B::IntegerBackend as Backend>::TensorPrimitive<4>,
    ) -> MaxPool2dBackward<B>;
}
@@ -15,6 +15,7 @@ macro_rules! testgen_all {
        burn_tensor::testgen_module_forward!();
        burn_tensor::testgen_module_conv1d!();
        burn_tensor::testgen_module_conv2d!();
        burn_tensor::testgen_module_max_pool2d!();

        // test ops
        burn_tensor::testgen_add!();
@@ -0,0 +1,230 @@
#[burn_tensor_testgen::testgen(module_max_pool2d)]
mod tests {
    use super::*;
    use burn_tensor::module::{max_pool2d, max_pool2d_with_indexes};
    use burn_tensor::{Data, Tensor};

    #[test]
    fn test_max_pool2d_simple() {
        let batch_size = 2;
        let channels_in = 2;
        let kernel_size_1 = 3;
        let kernel_size_2 = 3;
        let padding_1 = 1;
        let padding_2 = 1;
        let stride_1 = 1;
        let stride_2 = 1;

        let x = TestTensor::from_floats([
            [
                [
                    [0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221],
                    [0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689],
                    [0.5986, 0.2059, 0.4897, 0.6136, 0.2965, 0.6182],
                    [0.1485, 0.9540, 0.4023, 0.6176, 0.7111, 0.3392],
                    [0.3703, 0.0472, 0.2771, 0.1868, 0.8855, 0.5605],
                    [0.5063, 0.1638, 0.9432, 0.7836, 0.8696, 0.1068],
                ],
                [
                    [0.8872, 0.0137, 0.1652, 0.5505, 0.6127, 0.6473],
                    [0.1128, 0.0888, 0.1152, 0.5456, 0.6199, 0.7947],
                    [0.5911, 0.7781, 0.7256, 0.6578, 0.0989, 0.9149],
                    [0.5879, 0.5189, 0.6561, 0.0578, 0.7025, 0.6426],
                    [0.9590, 0.0325, 0.6455, 0.6248, 0.2009, 0.1544],
                    [0.7339, 0.1369, 0.6598, 0.5528, 0.6775, 0.1572],
                ],
            ],
            [
                [
                    [0.6853, 0.6439, 0.4639, 0.5573, 0.2723, 0.5910],
                    [0.5419, 0.7729, 0.6743, 0.8956, 0.2997, 0.9546],
                    [0.0334, 0.2178, 0.6917, 0.4958, 0.3357, 0.6584],
                    [0.7358, 0.9074, 0.2462, 0.5159, 0.6420, 0.2441],
                    [0.7602, 0.6297, 0.6073, 0.5937, 0.8037, 0.4881],
                    [0.8859, 0.0974, 0.3954, 0.6763, 0.1078, 0.7467],
                ],
                [
                    [0.2991, 0.5012, 0.8024, 0.7653, 0.9378, 0.7952],
                    [0.7393, 0.2336, 0.9521, 0.2719, 0.8445, 0.0454],
                    [0.6479, 0.9822, 0.7905, 0.0318, 0.2474, 0.0628],
                    [0.9955, 0.7591, 0.4140, 0.3215, 0.4349, 0.1527],
                    [0.8064, 0.0164, 0.4002, 0.2024, 0.6128, 0.5827],
                    [0.5368, 0.7895, 0.8727, 0.7793, 0.0910, 0.3421],
                ],
            ],
        ]);
        let y = TestTensor::from_floats([
            [
                [
                    [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221],
                    [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221],
                    [0.9540, 0.9540, 0.9540, 0.9490, 0.7890, 0.7111],
                    [0.9540, 0.9540, 0.9540, 0.8855, 0.8855, 0.8855],
                    [0.9540, 0.9540, 0.9540, 0.9432, 0.8855, 0.8855],
                    [0.5063, 0.9432, 0.9432, 0.9432, 0.8855, 0.8855],
                ],
                [
                    [0.8872, 0.8872, 0.5505, 0.6199, 0.7947, 0.7947],
                    [0.8872, 0.8872, 0.7781, 0.7256, 0.9149, 0.9149],
                    [0.7781, 0.7781, 0.7781, 0.7256, 0.9149, 0.9149],
                    [0.9590, 0.9590, 0.7781, 0.7256, 0.9149, 0.9149],
                    [0.9590, 0.9590, 0.6598, 0.7025, 0.7025, 0.7025],
                    [0.9590, 0.9590, 0.6598, 0.6775, 0.6775, 0.6775],
                ],
            ],
            [
                [
                    [0.7729, 0.7729, 0.8956, 0.8956, 0.9546, 0.9546],
                    [0.7729, 0.7729, 0.8956, 0.8956, 0.9546, 0.9546],
                    [0.9074, 0.9074, 0.9074, 0.8956, 0.9546, 0.9546],
                    [0.9074, 0.9074, 0.9074, 0.8037, 0.8037, 0.8037],
                    [0.9074, 0.9074, 0.9074, 0.8037, 0.8037, 0.8037],
                    [0.8859, 0.8859, 0.6763, 0.8037, 0.8037, 0.8037],
                ],
                [
                    [0.7393, 0.9521, 0.9521, 0.9521, 0.9378, 0.9378],
                    [0.9822, 0.9822, 0.9822, 0.9521, 0.9378, 0.9378],
                    [0.9955, 0.9955, 0.9822, 0.9521, 0.8445, 0.8445],
                    [0.9955, 0.9955, 0.9822, 0.7905, 0.6128, 0.6128],
                    [0.9955, 0.9955, 0.8727, 0.8727, 0.7793, 0.6128],
                    [0.8064, 0.8727, 0.8727, 0.8727, 0.7793, 0.6128],
                ],
            ],
        ]);

        let output = max_pool2d(
            &x,
            [kernel_size_1, kernel_size_2],
            [stride_1, stride_2],
            [padding_1, padding_2],
        );

        y.to_data().assert_approx_eq(&output.into_data(), 3);
    }

    #[test]
    fn test_max_pool2d_different_padding_stride_kernel() {
        let batch_size = 1;
        let channels_in = 1;
        let kernel_size_1 = 3;
        let kernel_size_2 = 1;
        let padding_1 = 1;
        let padding_2 = 0;
        let stride_1 = 1;
        let stride_2 = 2;

        let x = TestTensor::from_floats([[[
            [0.6309, 0.6112, 0.6998],
            [0.4708, 0.9161, 0.5402],
            [0.4577, 0.7397, 0.9870],
            [0.6380, 0.4352, 0.5884],
            [0.6277, 0.5139, 0.4525],
            [0.9333, 0.9846, 0.5006],
        ]]]);
        let y = TestTensor::from_floats([[[
            [0.6309, 0.6998],
            [0.6309, 0.9870],
            [0.6380, 0.9870],
            [0.6380, 0.9870],
            [0.9333, 0.5884],
            [0.9333, 0.5006],
        ]]]);

        let output = max_pool2d(
            &x,
            [kernel_size_1, kernel_size_2],
            [stride_1, stride_2],
            [padding_1, padding_2],
        );

        y.to_data().assert_approx_eq(&output.into_data(), 3);
    }

    #[test]
    fn test_max_pool2d_with_indexes() {
        let batch_size = 1;
        let channels_in = 1;
        let kernel_size_1 = 2;
        let kernel_size_2 = 2;
        let padding_1 = 1;
        let padding_2 = 1;
        let stride_1 = 1;
        let stride_2 = 1;

        let x = TestTensor::from_floats([[[
            [0.2479, 0.6386, 0.3166, 0.5742],
            [0.7065, 0.1940, 0.6305, 0.8959],
            [0.5416, 0.8602, 0.8129, 0.1662],
            [0.3358, 0.3059, 0.8293, 0.0990],
        ]]]);
        let indexes = Data::<i64, 4>::from([[[
            [0, 1, 1, 3, 3],
            [4, 4, 1, 7, 7],
            [4, 9, 9, 7, 7],
            [8, 9, 9, 14, 11],
            [12, 12, 14, 14, 15],
        ]]]);
        let y = TestTensor::from_floats([[[
            [0.2479, 0.6386, 0.6386, 0.5742, 0.5742],
            [0.7065, 0.7065, 0.6386, 0.8959, 0.8959],
            [0.7065, 0.8602, 0.8602, 0.8959, 0.8959],
            [0.5416, 0.8602, 0.8602, 0.8293, 0.1662],
            [0.3358, 0.3358, 0.8293, 0.8293, 0.0990],
        ]]]);

        let (output, output_indexes) = max_pool2d_with_indexes(
            &x,
            [kernel_size_1, kernel_size_2],
            [stride_1, stride_2],
            [padding_1, padding_2],
        );

        y.to_data().assert_approx_eq(&output.into_data(), 3);
        assert_eq!(indexes.value, output_indexes.into_data().value);
    }

    #[test]
    fn test_max_pool2d_complex() {
        let batch_size = 1;
        let channels_in = 1;
        let kernel_size_1 = 4;
        let kernel_size_2 = 2;
        let padding_1 = 2;
        let padding_2 = 1;
        let stride_1 = 1;
        let stride_2 = 2;

        let x = TestTensor::from_floats([[[
            [0.5388, 0.0676, 0.7122, 0.8316, 0.0653],
            [0.9154, 0.1536, 0.9089, 0.8016, 0.7518],
            [0.2073, 0.0501, 0.8811, 0.5604, 0.5075],
            [0.4384, 0.9963, 0.9698, 0.4988, 0.2609],
            [0.3391, 0.2230, 0.4610, 0.5365, 0.6880],
        ]]]);
        let indexes = Data::<i64, 4>::from([[[
            [5, 7, 3],
            [5, 7, 3],
            [5, 16, 3],
            [5, 16, 8],
            [15, 16, 24],
            [15, 16, 24],
        ]]]);
        let y = TestTensor::from_floats([[[
            [0.9154, 0.9089, 0.8316],
            [0.9154, 0.9089, 0.8316],
            [0.9154, 0.9963, 0.8316],
            [0.9154, 0.9963, 0.8016],
            [0.4384, 0.9963, 0.688],
            [0.4384, 0.9963, 0.688],
        ]]]);
        let (output, output_indexes) = max_pool2d_with_indexes(
            &x,
            [kernel_size_1, kernel_size_2],
            [stride_1, stride_2],
            [padding_1, padding_2],
        );

        y.to_data().assert_approx_eq(&output.into_data(), 3);
        assert_eq!(indexes.value, output_indexes.into_data().value);
    }
}
@@ -1,3 +1,4 @@
mod conv1d;
mod conv2d;
mod forward;
mod maxpool2d;