mirror of https://github.com/tracel-ai/burn.git
Feat/tensor/adaptive avg pool2d (#572)
This commit is contained in:
parent
597eab524d
commit
8436d4ff66
|
@ -38,12 +38,11 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
}
|
||||
|
||||
fn embedding_backward(
|
||||
weights: ADTensor<B, 2>,
|
||||
output: ADTensor<B, 3>,
|
||||
indices: IntTensor<B, 2>,
|
||||
_weights: ADTensor<B, 2>,
|
||||
_output: ADTensor<B, 3>,
|
||||
_indices: IntTensor<B, 2>,
|
||||
) -> ADTensor<B, 2> {
|
||||
let tensor = B::embedding_backward(weights.primitive, output.primitive, indices);
|
||||
ADTensor::new(tensor)
|
||||
panic!("Can't differentiate embedding backward.");
|
||||
}
|
||||
|
||||
fn conv2d(
|
||||
|
@ -446,16 +445,15 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn avg_pool2d_backward(
|
||||
x: ADTensor<B, 4>,
|
||||
grad: ADTensor<B, 4>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
_x: ADTensor<B, 4>,
|
||||
_grad: ADTensor<B, 4>,
|
||||
_kernel_size: [usize; 2],
|
||||
_stride: [usize; 2],
|
||||
_padding: [usize; 2],
|
||||
) -> ADTensor<B, 4> {
|
||||
let tensor =
|
||||
B::avg_pool2d_backward(x.primitive, grad.primitive, kernel_size, stride, padding);
|
||||
ADTensor::new(tensor)
|
||||
panic!("Can't differentiate avg pool 2d backward.");
|
||||
}
|
||||
|
||||
fn max_pool2d(
|
||||
|
@ -513,22 +511,50 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
}
|
||||
|
||||
fn max_pool2d_with_indices_backward(
|
||||
x: ADTensor<B, 4>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
output_grad: ADTensor<B, 4>,
|
||||
indices: IntTensor<B, 4>,
|
||||
_x: ADTensor<B, 4>,
|
||||
_kernel_size: [usize; 2],
|
||||
_stride: [usize; 2],
|
||||
_padding: [usize; 2],
|
||||
_output_grad: ADTensor<B, 4>,
|
||||
_indices: IntTensor<B, 4>,
|
||||
) -> MaxPool2dBackward<ADBackendDecorator<B>> {
|
||||
let output = B::max_pool2d_with_indices_backward(
|
||||
x.primitive,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
output_grad.primitive,
|
||||
indices,
|
||||
);
|
||||
MaxPool2dBackward::new(ADTensor::new(output.x_grad))
|
||||
panic!("Can't differentiate max pool2d with indices backward.");
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool2d(x: ADTensor<B, 4>, output_size: [usize; 2]) -> ADTensor<B, 4> {
|
||||
#[derive(Debug)]
|
||||
struct AdaptiveAvgPool2D;
|
||||
|
||||
impl<B: Backend> Backward<B, 4, 1> for AdaptiveAvgPool2D {
|
||||
type State = B::TensorPrimitive<4>;
|
||||
|
||||
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
|
||||
let [node_parent] = ops.parents;
|
||||
let grad = grads.consume::<B, 4>(&ops.node);
|
||||
|
||||
if let Some(node) = node_parent {
|
||||
let grad = B::adaptive_avg_pool2d_backward(ops.state, grad);
|
||||
grads.register::<B, 4>(node, grad);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match AdaptiveAvgPool2D.prepare([x.node], [x.graph]).statefull() {
|
||||
OpsKind::Tracked(prep) => prep.finish(
|
||||
x.primitive.clone(),
|
||||
B::adaptive_avg_pool2d(x.primitive, output_size),
|
||||
),
|
||||
OpsKind::UnTracked(prep) => {
|
||||
prep.finish(B::adaptive_avg_pool2d(x.primitive, output_size))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool2d_backward(
|
||||
_x: ADTensor<B, 4>,
|
||||
_grad: ADTensor<B, 4>,
|
||||
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<4> {
|
||||
panic!("Can't differentiate adaptive avg pool2d backward.");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
#[burn_tensor_testgen::testgen(ad_adaptive_avg_pool2d)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::module::adaptive_avg_pool2d;
|
||||
use burn_tensor::{Data, Shape, Tensor};
|
||||
|
||||
#[test]
|
||||
fn test_avg_pool2d_simple() {
|
||||
let test = AdaptiveAvgPool2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 2,
|
||||
height: 5,
|
||||
width: 3,
|
||||
output_size_1: 3,
|
||||
output_size_2: 2,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats([[
|
||||
[
|
||||
[0.2500, 0.5000, 0.2500],
|
||||
[0.4167, 0.8333, 0.4167],
|
||||
[0.1667, 0.3333, 0.1667],
|
||||
[0.4167, 0.8333, 0.4167],
|
||||
[0.2500, 0.5000, 0.2500],
|
||||
],
|
||||
[
|
||||
[0.2500, 0.5000, 0.2500],
|
||||
[0.4167, 0.8333, 0.4167],
|
||||
[0.1667, 0.3333, 0.1667],
|
||||
[0.4167, 0.8333, 0.4167],
|
||||
[0.2500, 0.5000, 0.2500],
|
||||
],
|
||||
]]));
|
||||
}
|
||||
|
||||
struct AdaptiveAvgPool2dTestCase {
|
||||
batch_size: usize,
|
||||
channels: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
output_size_1: usize,
|
||||
output_size_2: usize,
|
||||
}
|
||||
|
||||
impl AdaptiveAvgPool2dTestCase {
|
||||
fn assert_output(self, x_grad: TestTensor<4>) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
|
||||
let x = TestADTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements())
|
||||
.reshape(shape_x)
|
||||
.into_data()
|
||||
.convert(),
|
||||
)
|
||||
.require_grad();
|
||||
let output = adaptive_avg_pool2d(x.clone(), [self.output_size_1, self.output_size_2]);
|
||||
let grads = output.backward();
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
|
||||
x_grad
|
||||
.to_data()
|
||||
.assert_approx_eq(&x_grad_actual.into_data(), 3);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
#![allow(missing_docs)]
|
||||
|
||||
mod abs;
|
||||
mod adaptive_avgpool2d;
|
||||
mod add;
|
||||
mod aggregation;
|
||||
mod avgpool1d;
|
||||
|
@ -60,6 +61,7 @@ macro_rules! testgen_all {
|
|||
burn_autodiff::testgen_ad_max_pool2d!();
|
||||
burn_autodiff::testgen_ad_avg_pool1d!();
|
||||
burn_autodiff::testgen_ad_avg_pool2d!();
|
||||
burn_autodiff::testgen_ad_adaptive_avg_pool2d!();
|
||||
burn_autodiff::testgen_module_backward!();
|
||||
|
||||
// Tensor
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
use crate as burn;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::Tensor;
|
||||
use burn_tensor::module::adaptive_avg_pool2d;
|
||||
|
||||
/// Configuration to create a [2D adaptive avg pooling](AdaptiveAvgPool2d) layer.
|
||||
#[derive(Config)]
|
||||
pub struct AdaptiveAvgPool2dConfig {
|
||||
/// The size of the output.
|
||||
pub output_size: [usize; 2],
|
||||
}
|
||||
|
||||
/// Applies a 2D adaptive avg pooling over input tensors.
|
||||
#[derive(Module, Debug, Clone)]
|
||||
pub struct AdaptiveAvgPool2d {
|
||||
output_size: [usize; 2],
|
||||
}
|
||||
|
||||
impl AdaptiveAvgPool2dConfig {
|
||||
/// Initialize a new [adaptive avg pool 2d](AdaptiveAvgPool2d) module.
|
||||
pub fn init(&self) -> AdaptiveAvgPool2d {
|
||||
AdaptiveAvgPool2d {
|
||||
output_size: self.output_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AdaptiveAvgPool2d {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: [batch_size, channels, height_in, width_in],
|
||||
/// - output: [batch_size, channels, height_out, width_out],
|
||||
pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
adaptive_avg_pool2d(input, self.output_size)
|
||||
}
|
||||
}
|
|
@ -10,8 +10,6 @@ use burn_tensor::module::avg_pool1d;
|
|||
/// Configuration to create a [1D avg pooling](AvgPool1d) layer.
|
||||
#[derive(Config)]
|
||||
pub struct AvgPool1dConfig {
|
||||
/// The number of channels.
|
||||
pub channels: usize,
|
||||
/// The size of the kernel.
|
||||
pub kernel_size: usize,
|
||||
/// The stride.
|
||||
|
|
|
@ -10,8 +10,6 @@ use burn_tensor::module::avg_pool2d;
|
|||
/// Configuration to create a [2D avg pooling](AvgPool2d) layer.
|
||||
#[derive(Config)]
|
||||
pub struct AvgPool2dConfig {
|
||||
/// The number of channels.
|
||||
pub channels: usize,
|
||||
/// The size of the kernel.
|
||||
pub kernel_size: [usize; 2],
|
||||
/// The strides.
|
||||
|
|
|
@ -10,8 +10,6 @@ use burn_tensor::module::max_pool2d;
|
|||
/// Configuration to create an [2D max pooling](MaxPool2d) layer.
|
||||
#[derive(Debug, Config)]
|
||||
pub struct MaxPool2dConfig {
|
||||
/// The number of channels.
|
||||
pub channels: usize,
|
||||
/// The size of the kernel.
|
||||
pub kernel_size: [usize; 2],
|
||||
/// The strides.
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
mod adaptive_avg_pool2d;
|
||||
mod avg_pool1d;
|
||||
mod avg_pool2d;
|
||||
mod max_pool2d;
|
||||
|
||||
pub use adaptive_avg_pool2d::*;
|
||||
pub use avg_pool1d::*;
|
||||
pub use avg_pool2d::*;
|
||||
pub use max_pool2d::*;
|
||||
|
|
|
@ -48,7 +48,6 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for MaxPool2dNode {
|
|||
|
||||
fn field_init(&self, _with_record: bool) -> Option<TokenStream> {
|
||||
let name = &self.field.name;
|
||||
let channels = self.config.channels.to_tokens();
|
||||
let kernel_size = self.config.kernel_size.to_tokens();
|
||||
let strides = self.config.strides.to_tokens();
|
||||
let padding = self.config.padding.to_tokens();
|
||||
|
@ -58,7 +57,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for MaxPool2dNode {
|
|||
};
|
||||
|
||||
let tokens = quote! {
|
||||
let #name = MaxPool2dConfig::new(#channels, #kernel_size)
|
||||
let #name = MaxPool2dConfig::new(#kernel_size)
|
||||
.with_strides(#strides)
|
||||
.with_padding(#padding)
|
||||
.#init_line
|
||||
|
@ -105,7 +104,7 @@ mod tests {
|
|||
"max_pool2d",
|
||||
TensorType::new_float("input", 4),
|
||||
TensorType::new_float("output", 4),
|
||||
MaxPool2dConfig::new(1, [3, 3])
|
||||
MaxPool2dConfig::new([3, 3])
|
||||
.with_strides([1, 1])
|
||||
.with_padding(PaddingConfig2d::Valid),
|
||||
));
|
||||
|
@ -126,7 +125,7 @@ mod tests {
|
|||
|
||||
impl<B: Backend> Model <B> {
|
||||
pub fn new_with(record: ModelRecord<B>) -> Self {
|
||||
let max_pool2d = MaxPool2dConfig::new(1, [3, 3])
|
||||
let max_pool2d = MaxPool2dConfig::new([3, 3])
|
||||
.with_strides([1, 1])
|
||||
.with_padding(PaddingConfig2d::Valid)
|
||||
.init();
|
||||
|
|
|
@ -69,14 +69,12 @@ pub fn conv2d_config(curr: &Node) -> Conv2dConfig {
|
|||
|
||||
/// Create a MaxPool2dConfig from the attributes of the node
|
||||
pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig {
|
||||
let mut channels: i64 = 1;
|
||||
let mut kernel_shape = Vec::new();
|
||||
let mut strides = Vec::new();
|
||||
let mut pads = Vec::new();
|
||||
|
||||
for (key, value) in curr.attrs.iter() {
|
||||
match key.as_str() {
|
||||
"channels" => attr_value_i64(value, &mut channels),
|
||||
"kernel_shape" => attr_value_vec_i64(value, &mut kernel_shape),
|
||||
"strides" => attr_value_vec_i64(value, &mut strides),
|
||||
"pads" => attr_value_vec_i64(value, &mut pads),
|
||||
|
@ -86,12 +84,9 @@ pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig {
|
|||
|
||||
let padding = padding_config(&pads);
|
||||
|
||||
MaxPool2dConfig::new(
|
||||
channels as usize,
|
||||
[kernel_shape[0] as usize, kernel_shape[1] as usize],
|
||||
)
|
||||
.with_strides([strides[0] as usize, strides[1] as usize])
|
||||
.with_padding(padding)
|
||||
MaxPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize])
|
||||
.with_strides([strides[0] as usize, strides[1] as usize])
|
||||
.with_padding(padding)
|
||||
}
|
||||
|
||||
/// Create a FlattenConfig from the attributes of the node
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
use crate::{
|
||||
element::FloatNdArrayElement, iter_par, run_par, sharing::UnsafeSharedRef,
|
||||
tensor::NdArrayTensor,
|
||||
};
|
||||
use burn_tensor::ElementConversion;
|
||||
use ndarray::Array4;
|
||||
|
||||
pub(crate) fn adaptive_avg_pool2d<E: FloatNdArrayElement>(
|
||||
x: NdArrayTensor<E, 4>,
|
||||
output_size: [usize; 2],
|
||||
) -> NdArrayTensor<E, 4> {
|
||||
let [batch_size, channels, input_height, input_width] = x.shape().dims;
|
||||
|
||||
let x = x.array;
|
||||
let mut output = Array4::from_elem(
|
||||
(batch_size, channels, output_size[0], output_size[1]),
|
||||
0.elem(),
|
||||
);
|
||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||
|
||||
run_par!(|| {
|
||||
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||
let b = k / channels;
|
||||
let c = k % channels;
|
||||
|
||||
let output = unsafe_shared_out.get();
|
||||
for h in 0..output_size[0] {
|
||||
for w in 0..output_size[1] {
|
||||
let ih_start = start_index(h, output_size[0], input_height);
|
||||
let ih_end = end_index(h, output_size[0], input_height);
|
||||
let iw_start = start_index(w, output_size[1], input_width);
|
||||
let iw_end = end_index(w, output_size[1], input_width);
|
||||
|
||||
let mut sum_val: E = 0.elem();
|
||||
|
||||
for ih in ih_start..ih_end {
|
||||
for iw in iw_start..iw_end {
|
||||
sum_val += x[[b, c, ih, iw]];
|
||||
}
|
||||
}
|
||||
|
||||
let count: E = (((ih_end - ih_start) * (iw_end - iw_start)) as i32).elem();
|
||||
output[[b, c, h, w]] = sum_val / count.elem();
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
NdArrayTensor::new(output.into_dyn().into_shared())
|
||||
}
|
||||
|
||||
pub(crate) fn adaptive_avg_pool2d_backward<E: FloatNdArrayElement>(
|
||||
x: NdArrayTensor<E, 4>,
|
||||
grad: NdArrayTensor<E, 4>,
|
||||
) -> NdArrayTensor<E, 4> {
|
||||
let [_, _, input_height, input_width] = x.shape().dims;
|
||||
let [batch_size, channels, output_height, output_width] = grad.shape().dims;
|
||||
|
||||
let mut output_grad =
|
||||
Array4::from_elem((batch_size, channels, input_height, input_width), 0.elem());
|
||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output_grad);
|
||||
|
||||
run_par!(|| {
|
||||
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||
let b = k / channels;
|
||||
let c = k % channels;
|
||||
|
||||
let output_grad = unsafe_shared_out.get();
|
||||
for oh in 0..output_height {
|
||||
for ow in 0..output_width {
|
||||
let ih_start = start_index(oh, output_height, input_height);
|
||||
let ih_end = end_index(oh, output_height, input_height);
|
||||
|
||||
let iw_start = start_index(ow, output_width, input_width);
|
||||
let iw_end = end_index(ow, output_width, input_width);
|
||||
|
||||
let count: E = (((ih_end - ih_start) * (iw_end - iw_start)) as i32).elem();
|
||||
|
||||
for ih in ih_start..ih_end {
|
||||
for iw in iw_start..iw_end {
|
||||
output_grad[[b, c, ih, iw]] +=
|
||||
grad.array[[b, c, oh, ow]] / count.elem();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
NdArrayTensor::new(output_grad.into_dyn().into_shared())
|
||||
}
|
||||
|
||||
fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
|
||||
libm::floorf((output_size_index as f32 * input_size as f32) / output_size as f32) as usize
|
||||
}
|
||||
|
||||
fn end_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
|
||||
let index =
|
||||
libm::ceilf(((output_size_index + 1) as f32 * input_size as f32) / output_size as f32)
|
||||
as usize;
|
||||
|
||||
usize::min(index, input_size)
|
||||
}
|
|
@ -5,6 +5,7 @@ mod int_tensor;
|
|||
mod module;
|
||||
mod tensor;
|
||||
|
||||
pub(crate) mod adaptive_avgpool;
|
||||
pub(crate) mod avgpool;
|
||||
pub(crate) mod conv;
|
||||
pub(crate) mod macros;
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use super::{
|
||||
adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward},
|
||||
avgpool::{avg_pool2d, avg_pool2d_backward},
|
||||
conv::{conv2d, conv_transpose2d},
|
||||
maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
|
||||
|
@ -81,4 +82,15 @@ impl<E: FloatNdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E>
|
|||
indices,
|
||||
))
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool2d(x: NdArrayTensor<E, 4>, output_size: [usize; 2]) -> NdArrayTensor<E, 4> {
|
||||
adaptive_avg_pool2d(x, output_size)
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool2d_backward(
|
||||
x: NdArrayTensor<E, 4>,
|
||||
grad: NdArrayTensor<E, 4>,
|
||||
) -> NdArrayTensor<E, 4> {
|
||||
adaptive_avg_pool2d_backward(x, grad)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,7 +26,7 @@ impl<B: Backend> ConvBlock<B> {
|
|||
let conv = nn::conv::Conv2dConfig::new(config.channels, config.kernel_size)
|
||||
.with_padding(nn::PaddingConfig2d::Same)
|
||||
.init();
|
||||
let pool = nn::pool::MaxPool2dConfig::new(config.channels[1], config.kernel_size)
|
||||
let pool = nn::pool::MaxPool2dConfig::new(config.kernel_size)
|
||||
.with_padding(nn::PaddingConfig2d::Same)
|
||||
.init();
|
||||
let activation = nn::GELU::new();
|
||||
|
|
|
@ -220,4 +220,16 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
|
||||
MaxPool2dBackward::new(TchTensor::new(grad))
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool2d(x: TchTensor<E, 4>, output_size: [usize; 2]) -> TchTensor<E, 4> {
|
||||
let tensor = tch::Tensor::adaptive_avg_pool2d(&x.tensor, output_size.map(|e| e as i64));
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool2d_backward(x: TchTensor<E, 4>, grad: TchTensor<E, 4>) -> TchTensor<E, 4> {
|
||||
let tensor = tch::Tensor::internal_adaptive_avg_pool2d_backward(&x.tensor, &grad.tensor);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -137,3 +137,11 @@ where
|
|||
|
||||
(Tensor::new(output.output), Tensor::new(output.indices))
|
||||
}
|
||||
|
||||
/// Applies a [2D adaptive avg pooling](crate::ops::ModuleOps::adaptive_avg_pool2d).
|
||||
pub fn adaptive_avg_pool2d<B>(x: Tensor<B, 4>, output_size: [usize; 2]) -> Tensor<B, 4>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
Tensor::new(B::adaptive_avg_pool2d(x.primitive, output_size))
|
||||
}
|
||||
|
|
|
@ -262,6 +262,22 @@ pub trait ModuleOps<B: Backend> {
|
|||
padding: [usize; 2],
|
||||
) -> B::TensorPrimitive<4>;
|
||||
|
||||
/// Two dimensional adaptive avg pooling.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// x: [batch_size, channels, height, width],
|
||||
fn adaptive_avg_pool2d(
|
||||
x: B::TensorPrimitive<4>,
|
||||
output_size: [usize; 2],
|
||||
) -> B::TensorPrimitive<4>;
|
||||
|
||||
/// Backward pass for the [adaptive avg pooling 2d](ModuleOps::adaptive_avg_pool2d) operation.
|
||||
fn adaptive_avg_pool2d_backward(
|
||||
x: B::TensorPrimitive<4>,
|
||||
grad: B::TensorPrimitive<4>,
|
||||
) -> B::TensorPrimitive<4>;
|
||||
|
||||
/// Two dimensional max pooling.
|
||||
///
|
||||
/// # Shapes
|
||||
|
|
|
@ -23,6 +23,7 @@ macro_rules! testgen_all {
|
|||
burn_tensor::testgen_module_max_pool2d!();
|
||||
burn_tensor::testgen_module_avg_pool1d!();
|
||||
burn_tensor::testgen_module_avg_pool2d!();
|
||||
burn_tensor::testgen_module_adaptive_avg_pool2d!();
|
||||
|
||||
// test ops
|
||||
burn_tensor::testgen_add!();
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
#[burn_tensor_testgen::testgen(module_adaptive_avg_pool2d)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::module::adaptive_avg_pool2d;
|
||||
use burn_tensor::{Data, Shape, Tensor};
|
||||
|
||||
#[test]
|
||||
fn test_adaptive_avg_pool2d_simple() {
|
||||
let test = AdaptiveAvgPool2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 2,
|
||||
height: 8,
|
||||
width: 6,
|
||||
height_out: 4,
|
||||
width_out: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats([[
|
||||
[
|
||||
[3.5000, 4.5000, 6.5000, 7.5000],
|
||||
[15.5000, 16.5000, 18.5000, 19.5000],
|
||||
[27.5000, 28.5000, 30.5000, 31.5000],
|
||||
[39.5000, 40.5000, 42.5000, 43.5000],
|
||||
],
|
||||
[
|
||||
[51.5000, 52.5000, 54.5000, 55.5000],
|
||||
[63.5000, 64.5000, 66.5000, 67.5000],
|
||||
[75.5000, 76.5000, 78.5000, 79.5000],
|
||||
[87.5000, 88.5000, 90.5000, 91.5000],
|
||||
],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_adaptive_avg_pool2d_dyn_filter_size() {
|
||||
let test = AdaptiveAvgPool2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 2,
|
||||
height: 5,
|
||||
width: 7,
|
||||
height_out: 3,
|
||||
width_out: 2,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats([[
|
||||
[[5.0000, 8.0000], [15.5000, 18.5000], [26.0000, 29.0000]],
|
||||
[[40.0000, 43.0000], [50.5000, 53.5000], [61.0000, 64.0000]],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_adaptive_avg_pool2d_bigger_output() {
|
||||
let test = AdaptiveAvgPool2dTestCase {
|
||||
batch_size: 1,
|
||||
channels: 2,
|
||||
height: 4,
|
||||
width: 3,
|
||||
height_out: 5,
|
||||
width_out: 4,
|
||||
};
|
||||
|
||||
test.assert_output(TestTensor::from_floats([[
|
||||
[
|
||||
[0.0000, 0.5000, 1.5000, 2.0000],
|
||||
[1.5000, 2.0000, 3.0000, 3.5000],
|
||||
[4.5000, 5.0000, 6.0000, 6.5000],
|
||||
[7.5000, 8.0000, 9.0000, 9.5000],
|
||||
[9.0000, 9.5000, 10.5000, 11.0000],
|
||||
],
|
||||
[
|
||||
[12.0000, 12.5000, 13.5000, 14.0000],
|
||||
[13.5000, 14.0000, 15.0000, 15.5000],
|
||||
[16.5000, 17.0000, 18.0000, 18.5000],
|
||||
[19.5000, 20.0000, 21.0000, 21.5000],
|
||||
[21.0000, 21.5000, 22.5000, 23.0000],
|
||||
],
|
||||
]]));
|
||||
}
|
||||
|
||||
struct AdaptiveAvgPool2dTestCase {
|
||||
batch_size: usize,
|
||||
channels: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
height_out: usize,
|
||||
width_out: usize,
|
||||
}
|
||||
|
||||
impl AdaptiveAvgPool2dTestCase {
|
||||
fn assert_output(self, y: TestTensor<4>) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
|
||||
let x = TestTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements())
|
||||
.reshape(shape_x)
|
||||
.into_data()
|
||||
.convert(),
|
||||
);
|
||||
let output = adaptive_avg_pool2d(x, [self.height_out, self.width_out]);
|
||||
|
||||
y.to_data().assert_approx_eq(&output.into_data(), 3);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,3 +1,4 @@
|
|||
mod adaptive_avgpool2d;
|
||||
mod avgpool1d;
|
||||
mod avgpool2d;
|
||||
mod conv1d;
|
||||
|
|
|
@ -0,0 +1,108 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use burn_tensor::Shape;
|
||||
use wgpu::Buffer;
|
||||
|
||||
use crate::{
|
||||
element::WgpuElement,
|
||||
kernel::{elemwise_workgroup, KernelSettings},
|
||||
kernel_wgsl,
|
||||
tensor::WgpuTensor,
|
||||
};
|
||||
|
||||
kernel_wgsl!(
|
||||
AdaptiveAvgPool2d,
|
||||
"../../template/pool/adaptive_avg_pool2d.wgsl"
|
||||
);
|
||||
kernel_wgsl!(
|
||||
AdaptiveAvgPool2dBackward,
|
||||
"../../template/pool/adaptive_avg_pool2d_backward.wgsl"
|
||||
);
|
||||
|
||||
pub(crate) fn adaptive_avg_pool2d<E: WgpuElement>(
|
||||
x: WgpuTensor<E, 4>,
|
||||
output_size: [usize; 2],
|
||||
) -> WgpuTensor<E, 4> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let [batch_size, channels, _, _] = x.shape.dims;
|
||||
|
||||
let output_shape = Shape::new([batch_size, channels, output_size[0], output_size[1]]);
|
||||
let num_elems = output_shape.num_elements();
|
||||
let output_buffer = x
|
||||
.context
|
||||
.create_buffer(num_elems * core::mem::size_of::<E>());
|
||||
let output = WgpuTensor::new(x.context.clone(), output_shape, output_buffer);
|
||||
|
||||
let kernel = x
|
||||
.context
|
||||
.compile_static::<KernelSettings<AdaptiveAvgPool2d, E, i32, WORKGROUP, WORKGROUP, 1>>();
|
||||
|
||||
let info_buffer = build_info(&x, &output);
|
||||
|
||||
x.context.execute(
|
||||
elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
|
||||
kernel,
|
||||
&[&x.buffer, &output.buffer, &info_buffer],
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
pub(crate) fn adaptive_avg_pool2d_backward<E: WgpuElement>(
|
||||
x: WgpuTensor<E, 4>,
|
||||
out_grad: WgpuTensor<E, 4>,
|
||||
) -> WgpuTensor<E, 4> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let output_shape = x.shape.clone();
|
||||
let num_elems = output_shape.num_elements();
|
||||
let output_buffer = x
|
||||
.context
|
||||
.create_buffer(num_elems * core::mem::size_of::<E>());
|
||||
let output = WgpuTensor::new(x.context.clone(), output_shape, output_buffer);
|
||||
|
||||
let kernel = x.context.compile_static::<KernelSettings<
|
||||
AdaptiveAvgPool2dBackward,
|
||||
E,
|
||||
i32,
|
||||
WORKGROUP,
|
||||
WORKGROUP,
|
||||
1,
|
||||
>>();
|
||||
|
||||
let info_buffer = build_info(&x, &out_grad);
|
||||
|
||||
x.context.execute(
|
||||
elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
|
||||
kernel,
|
||||
&[&out_grad.buffer, &output.buffer, &info_buffer],
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
fn build_info<E: WgpuElement>(x: &WgpuTensor<E, 4>, output: &WgpuTensor<E, 4>) -> Arc<Buffer> {
|
||||
let mut info: [u32; 16] = [0; 16];
|
||||
info[0] = x.strides[0] as u32;
|
||||
info[1] = x.strides[1] as u32;
|
||||
info[2] = x.strides[2] as u32;
|
||||
info[3] = x.strides[3] as u32;
|
||||
info[4] = x.shape.dims[0] as u32;
|
||||
info[5] = x.shape.dims[1] as u32;
|
||||
info[6] = x.shape.dims[2] as u32;
|
||||
info[7] = x.shape.dims[3] as u32;
|
||||
|
||||
info[8] = output.strides[0] as u32;
|
||||
info[9] = output.strides[1] as u32;
|
||||
info[10] = output.strides[2] as u32;
|
||||
info[11] = output.strides[3] as u32;
|
||||
info[12] = output.shape.dims[0] as u32;
|
||||
info[13] = output.shape.dims[1] as u32;
|
||||
info[14] = output.shape.dims[2] as u32;
|
||||
info[15] = output.shape.dims[3] as u32;
|
||||
|
||||
output
|
||||
.context
|
||||
.create_buffer_with_data(bytemuck::cast_slice(&info))
|
||||
}
|
|
@ -1,7 +1,9 @@
|
|||
mod adaptive_avg_pool2d;
|
||||
mod avg_pool2d;
|
||||
mod base;
|
||||
mod max_pool2d;
|
||||
|
||||
pub(crate) use adaptive_avg_pool2d::*;
|
||||
pub use avg_pool2d::*;
|
||||
pub(super) use base::*;
|
||||
pub use max_pool2d::*;
|
||||
|
|
|
@ -90,4 +90,18 @@ where
|
|||
padding,
|
||||
))
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool2d(
|
||||
x: FloatTensor<Self, 4>,
|
||||
output_size: [usize; 2],
|
||||
) -> FloatTensor<Self, 4> {
|
||||
kernel::pool::adaptive_avg_pool2d(x, output_size)
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool2d_backward(
|
||||
x: FloatTensor<Self, 4>,
|
||||
grad: FloatTensor<Self, 4>,
|
||||
) -> FloatTensor<Self, 4> {
|
||||
kernel::pool::adaptive_avg_pool2d_backward(x, grad)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,74 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read> x: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read_write> output: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read> info: array<u32, 16>;
|
||||
|
||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
|
||||
|
||||
let input_stride_0 = info[0];
|
||||
let input_stride_1 = info[1];
|
||||
let input_stride_2 = info[2];
|
||||
let input_stride_3 = info[3];
|
||||
let input_shape_0 = info[4];
|
||||
let input_shape_1 = info[5];
|
||||
let input_shape_2 = info[6];
|
||||
let input_shape_3 = info[7];
|
||||
|
||||
let output_stride_0 = info[8];
|
||||
let output_stride_1 = info[9];
|
||||
let output_stride_2 = info[10];
|
||||
let output_stride_3 = info[11];
|
||||
let output_shape_0 = info[12];
|
||||
let output_shape_1 = info[13];
|
||||
let output_shape_2 = info[14];
|
||||
let output_shape_3 = info[15];
|
||||
|
||||
let b = id / output_stride_0 % output_shape_0;
|
||||
let c = id / output_stride_1 % output_shape_1;
|
||||
let oh = id / output_stride_2 % output_shape_2;
|
||||
let ow = id / output_stride_3 % output_shape_3;
|
||||
|
||||
let ih_start = start_index(oh, output_shape_2, input_shape_2);
|
||||
let ih_end = end_index(oh, output_shape_2, input_shape_2);
|
||||
|
||||
let iw_start = start_index(ow, output_shape_3, input_shape_3);
|
||||
let iw_end = end_index(ow, output_shape_3, input_shape_3);
|
||||
|
||||
var sum = 0.0;
|
||||
|
||||
for (var ih = ih_start; ih < ih_end; ih++) {
|
||||
for (var iw = iw_start; iw < iw_end; iw++) {
|
||||
let index_input = b * input_stride_0 + c * input_stride_1 + ih * input_stride_2 + iw * input_stride_3;
|
||||
sum += x[index_input];
|
||||
}
|
||||
}
|
||||
|
||||
let count = {{ elem }}((ih_end - ih_start) * (iw_end - iw_start));
|
||||
output[id] = sum / count;
|
||||
}
|
||||
|
||||
fn start_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
|
||||
return u32(floor((f32(output_size_index) * f32(input_size)) / f32(output_size)));
|
||||
}
|
||||
|
||||
fn end_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
|
||||
let index = u32(ceil((f32(output_size_index + 1u) * f32(input_size)) / f32(output_size)));
|
||||
|
||||
return min(index, input_size);
|
||||
}
|
||||
|
|
@ -0,0 +1,88 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read> grad: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read_write> output: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read> info: array<u32, 16>;
|
||||
|
||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
|
||||
|
||||
let input_stride_0 = info[0];
|
||||
let input_stride_1 = info[1];
|
||||
let input_stride_2 = info[2];
|
||||
let input_stride_3 = info[3];
|
||||
let input_shape_0 = info[4];
|
||||
let input_shape_1 = info[5];
|
||||
let input_shape_2 = info[6];
|
||||
let input_shape_3 = info[7];
|
||||
|
||||
let grad_stride_0 = info[8];
|
||||
let grad_stride_1 = info[9];
|
||||
let grad_stride_2 = info[10];
|
||||
let grad_stride_3 = info[11];
|
||||
let grad_shape_0 = info[12];
|
||||
let grad_shape_1 = info[13];
|
||||
let grad_shape_2 = info[14];
|
||||
let grad_shape_3 = info[15];
|
||||
|
||||
let b = id / input_stride_0 % input_shape_0;
|
||||
let c = id / input_stride_1 % input_shape_1;
|
||||
let ih = id / input_stride_2 % input_shape_2;
|
||||
let iw = id / input_stride_3 % input_shape_3;
|
||||
|
||||
let oh_start = start_index(ih, input_shape_2, grad_shape_2);
|
||||
let oh_end = end_index(ih, input_shape_2, grad_shape_2);
|
||||
|
||||
let ow_start = start_index(iw, input_shape_3, grad_shape_3);
|
||||
let ow_end = end_index(iw, input_shape_3, grad_shape_3);
|
||||
|
||||
var grad_acc = 0.0;
|
||||
|
||||
for (var oh = oh_start; oh < oh_end; oh++) {
|
||||
for (var ow = ow_start; ow < ow_end; ow++) {
|
||||
let ih_start = start_index(oh, grad_shape_2, input_shape_2);
|
||||
let ih_end = end_index(oh, grad_shape_2, input_shape_2);
|
||||
|
||||
let iw_start = start_index(ow, grad_shape_3, input_shape_3);
|
||||
let iw_end = end_index(ow, grad_shape_3, input_shape_3);
|
||||
|
||||
let contributed_h = ih >= ih_start && ih < ih_end;
|
||||
let contributed_w = iw >= iw_start && iw < iw_end;
|
||||
|
||||
// If no contribution skip
|
||||
if !contributed_h || !contributed_w {
|
||||
continue;
|
||||
}
|
||||
|
||||
let index = b * grad_stride_0 + c * grad_stride_1 + oh * grad_stride_2 + ow * grad_stride_3;
|
||||
let count = {{ elem }}((ih_end - ih_start) * (iw_end - iw_start));
|
||||
|
||||
grad_acc += grad[index] / count;
|
||||
}
|
||||
}
|
||||
|
||||
output[id] = grad_acc;
|
||||
}
|
||||
|
||||
fn start_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
|
||||
return u32(floor((f32(output_size_index) * f32(input_size)) / f32(output_size)));
|
||||
}
|
||||
|
||||
fn end_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
|
||||
let index = u32(ceil((f32(output_size_index + 1u) * f32(input_size)) / f32(output_size)));
|
||||
|
||||
return min(index, input_size);
|
||||
}
|
Loading…
Reference in New Issue