Feat/tensor/adaptive avg pool2d (#572)

Nathaniel Simard 2023-08-04 10:23:59 -04:00 committed by GitHub
parent 597eab524d
commit 8436d4ff66
25 changed files with 713 additions and 47 deletions

View File

@@ -38,12 +38,11 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
}
fn embedding_backward(
weights: ADTensor<B, 2>,
output: ADTensor<B, 3>,
indices: IntTensor<B, 2>,
_weights: ADTensor<B, 2>,
_output: ADTensor<B, 3>,
_indices: IntTensor<B, 2>,
) -> ADTensor<B, 2> {
let tensor = B::embedding_backward(weights.primitive, output.primitive, indices);
ADTensor::new(tensor)
panic!("Can't differentiate embedding backward.");
}
fn conv2d(
@@ -446,16 +445,15 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
}
}
}
fn avg_pool2d_backward(
x: ADTensor<B, 4>,
grad: ADTensor<B, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
_x: ADTensor<B, 4>,
_grad: ADTensor<B, 4>,
_kernel_size: [usize; 2],
_stride: [usize; 2],
_padding: [usize; 2],
) -> ADTensor<B, 4> {
let tensor =
B::avg_pool2d_backward(x.primitive, grad.primitive, kernel_size, stride, padding);
ADTensor::new(tensor)
panic!("Can't differentiate avg pool 2d backward.");
}
fn max_pool2d(
@@ -513,22 +511,50 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
}
fn max_pool2d_with_indices_backward(
x: ADTensor<B, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
output_grad: ADTensor<B, 4>,
indices: IntTensor<B, 4>,
_x: ADTensor<B, 4>,
_kernel_size: [usize; 2],
_stride: [usize; 2],
_padding: [usize; 2],
_output_grad: ADTensor<B, 4>,
_indices: IntTensor<B, 4>,
) -> MaxPool2dBackward<ADBackendDecorator<B>> {
let output = B::max_pool2d_with_indices_backward(
x.primitive,
kernel_size,
stride,
padding,
output_grad.primitive,
indices,
);
MaxPool2dBackward::new(ADTensor::new(output.x_grad))
panic!("Can't differentiate max pool2d with indices backward.");
}
fn adaptive_avg_pool2d(x: ADTensor<B, 4>, output_size: [usize; 2]) -> ADTensor<B, 4> {
#[derive(Debug)]
struct AdaptiveAvgPool2D;
impl<B: Backend> Backward<B, 4, 1> for AdaptiveAvgPool2D {
type State = B::TensorPrimitive<4>;
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
let [node_parent] = ops.parents;
let grad = grads.consume::<B, 4>(&ops.node);
if let Some(node) = node_parent {
let grad = B::adaptive_avg_pool2d_backward(ops.state, grad);
grads.register::<B, 4>(node, grad);
}
}
}
match AdaptiveAvgPool2D.prepare([x.node], [x.graph]).statefull() {
OpsKind::Tracked(prep) => prep.finish(
x.primitive.clone(),
B::adaptive_avg_pool2d(x.primitive, output_size),
),
OpsKind::UnTracked(prep) => {
prep.finish(B::adaptive_avg_pool2d(x.primitive, output_size))
}
}
}
fn adaptive_avg_pool2d_backward(
_x: ADTensor<B, 4>,
_grad: ADTensor<B, 4>,
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<4> {
panic!("Can't differentiate adaptive avg pool2d backward.");
}
}
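The forward op above only records an AdaptiveAvgPool2D backward step when the input node is tracked: the OpsKind::Tracked arm stores x.primitive as state so the backward step can later call B::adaptive_avg_pool2d_backward, while the UnTracked arm skips the bookkeeping entirely. A minimal sketch of the tracked path from the caller's side, assuming a TestADBackend type alias for the decorated backend (the alias name is hypothetical):

use burn_tensor::module::adaptive_avg_pool2d;
use burn_tensor::Tensor;

// `require_grad` marks the tensor as tracked, so `statefull()` takes the
// OpsKind::Tracked arm and keeps `x.primitive` around for the backward pass.
let x = Tensor::<TestADBackend, 4>::ones([1, 2, 8, 6]).require_grad();
let y = adaptive_avg_pool2d(x.clone(), [4, 4]);
let grads = y.backward();
let x_grad = x.grad(&grads); // Some(_) only because x was tracked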

View File

@@ -0,0 +1,64 @@
#[burn_tensor_testgen::testgen(ad_adaptive_avg_pool2d)]
mod tests {
use super::*;
use burn_tensor::module::adaptive_avg_pool2d;
use burn_tensor::{Data, Shape, Tensor};
#[test]
fn test_avg_pool2d_simple() {
let test = AdaptiveAvgPool2dTestCase {
batch_size: 1,
channels: 2,
height: 5,
width: 3,
output_size_1: 3,
output_size_2: 2,
};
test.assert_output(TestTensor::from_floats([[
[
[0.2500, 0.5000, 0.2500],
[0.4167, 0.8333, 0.4167],
[0.1667, 0.3333, 0.1667],
[0.4167, 0.8333, 0.4167],
[0.2500, 0.5000, 0.2500],
],
[
[0.2500, 0.5000, 0.2500],
[0.4167, 0.8333, 0.4167],
[0.1667, 0.3333, 0.1667],
[0.4167, 0.8333, 0.4167],
[0.2500, 0.5000, 0.2500],
],
]]));
}
struct AdaptiveAvgPool2dTestCase {
batch_size: usize,
channels: usize,
height: usize,
width: usize,
output_size_1: usize,
output_size_2: usize,
}
impl AdaptiveAvgPool2dTestCase {
fn assert_output(self, x_grad: TestTensor<4>) {
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
let x = TestADTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements())
.reshape(shape_x)
.into_data()
.convert(),
)
.require_grad();
let output = adaptive_avg_pool2d(x.clone(), [self.output_size_1, self.output_size_2]);
let grads = output.backward();
let x_grad_actual = x.grad(&grads).unwrap();
x_grad
.to_data()
.assert_approx_eq(&x_grad_actual.into_data(), 3);
}
}
}
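A quick sanity check on the expected gradients above: for a 5x3 input pooled to 3x2, the row windows are [0,2), [1,4), [3,5) and the column windows are [0,2), [1,3). Since output.backward() seeds the output gradient with ones, each input cell receives 1/count from every window that covers it:

x_grad[0][0] = 1/(2*2)           = 0.2500  (one 2x2 window)
x_grad[1][0] = 1/(2*2) + 1/(3*2) = 0.4167  (row 1 lies in two row windows)
x_grad[2][1] = 2 * 1/(3*2)       = 0.3333  (col 1 lies in both column windows)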

View File

@@ -1,6 +1,7 @@
#![allow(missing_docs)]
mod abs;
mod adaptive_avgpool2d;
mod add;
mod aggregation;
mod avgpool1d;
@@ -60,6 +61,7 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_max_pool2d!();
burn_autodiff::testgen_ad_avg_pool1d!();
burn_autodiff::testgen_ad_avg_pool2d!();
burn_autodiff::testgen_ad_adaptive_avg_pool2d!();
burn_autodiff::testgen_module_backward!();
// Tensor

View File

@@ -0,0 +1,41 @@
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use burn_tensor::module::adaptive_avg_pool2d;
/// Configuration to create a [2D adaptive avg pooling](AdaptiveAvgPool2d) layer.
#[derive(Config)]
pub struct AdaptiveAvgPool2dConfig {
/// The size of the output.
pub output_size: [usize; 2],
}
/// Applies a 2D adaptive avg pooling over input tensors.
#[derive(Module, Debug, Clone)]
pub struct AdaptiveAvgPool2d {
output_size: [usize; 2],
}
impl AdaptiveAvgPool2dConfig {
/// Initialize a new [adaptive avg pool 2d](AdaptiveAvgPool2d) module.
pub fn init(&self) -> AdaptiveAvgPool2d {
AdaptiveAvgPool2d {
output_size: self.output_size,
}
}
}
impl AdaptiveAvgPool2d {
/// Applies the forward pass on the input tensor.
///
/// # Shapes
///
/// - input: [batch_size, channels, height_in, width_in],
/// - output: [batch_size, channels, height_out, width_out],
pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
adaptive_avg_pool2d(input, self.output_size)
}
}
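A minimal usage sketch for the new module, assuming #[derive(Config)] generates new(output_size) for the one required field, as it does for the other pooling configs in this commit; input is a hypothetical Tensor<B, 4>:

use burn::nn::pool::AdaptiveAvgPool2dConfig;

let pool = AdaptiveAvgPool2dConfig::new([4, 4]).init();
// input: [batch_size, channels, height_in, width_in]
let output = pool.forward(input); // [batch_size, channels, 4, 4]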

View File

@@ -10,8 +10,6 @@ use burn_tensor::module::avg_pool1d;
/// Configuration to create a [1D avg pooling](AvgPool1d) layer.
#[derive(Config)]
pub struct AvgPool1dConfig {
/// The number of channels.
pub channels: usize,
/// The size of the kernel.
pub kernel_size: usize,
/// The stride.

View File

@@ -10,8 +10,6 @@ use burn_tensor::module::avg_pool2d;
/// Configuration to create a [2D avg pooling](AvgPool2d) layer.
#[derive(Config)]
pub struct AvgPool2dConfig {
/// The number of channels.
pub channels: usize,
/// The size of the kernel.
pub kernel_size: [usize; 2],
/// The strides.

View File

@@ -10,8 +10,6 @@ use burn_tensor::module::max_pool2d;
/// Configuration to create a [2D max pooling](MaxPool2d) layer.
#[derive(Debug, Config)]
pub struct MaxPool2dConfig {
/// The number of channels.
pub channels: usize,
/// The size of the kernel.
pub kernel_size: [usize; 2],
/// The strides.

View File

@@ -1,7 +1,9 @@
mod adaptive_avg_pool2d;
mod avg_pool1d;
mod avg_pool2d;
mod max_pool2d;
pub use adaptive_avg_pool2d::*;
pub use avg_pool1d::*;
pub use avg_pool2d::*;
pub use max_pool2d::*;

View File

@@ -48,7 +48,6 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for MaxPool2dNode {
fn field_init(&self, _with_record: bool) -> Option<TokenStream> {
let name = &self.field.name;
let channels = self.config.channels.to_tokens();
let kernel_size = self.config.kernel_size.to_tokens();
let strides = self.config.strides.to_tokens();
let padding = self.config.padding.to_tokens();
@@ -58,7 +57,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for MaxPool2dNode {
};
let tokens = quote! {
let #name = MaxPool2dConfig::new(#channels, #kernel_size)
let #name = MaxPool2dConfig::new(#kernel_size)
.with_strides(#strides)
.with_padding(#padding)
.#init_line
@@ -105,7 +104,7 @@ mod tests {
"max_pool2d",
TensorType::new_float("input", 4),
TensorType::new_float("output", 4),
MaxPool2dConfig::new(1, [3, 3])
MaxPool2dConfig::new([3, 3])
.with_strides([1, 1])
.with_padding(PaddingConfig2d::Valid),
));
@@ -126,7 +125,7 @@
impl<B: Backend> Model <B> {
pub fn new_with(record: ModelRecord<B>) -> Self {
let max_pool2d = MaxPool2dConfig::new(1, [3, 3])
let max_pool2d = MaxPool2dConfig::new([3, 3])
.with_strides([1, 1])
.with_padding(PaddingConfig2d::Valid)
.init();

View File

@@ -69,14 +69,12 @@ pub fn conv2d_config(curr: &Node) -> Conv2dConfig {
/// Create a MaxPool2dConfig from the attributes of the node
pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig {
let mut channels: i64 = 1;
let mut kernel_shape = Vec::new();
let mut strides = Vec::new();
let mut pads = Vec::new();
for (key, value) in curr.attrs.iter() {
match key.as_str() {
"channels" => attr_value_i64(value, &mut channels),
"kernel_shape" => attr_value_vec_i64(value, &mut kernel_shape),
"strides" => attr_value_vec_i64(value, &mut strides),
"pads" => attr_value_vec_i64(value, &mut pads),
@@ -86,12 +84,9 @@ pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig {
let padding = padding_config(&pads);
MaxPool2dConfig::new(
channels as usize,
[kernel_shape[0] as usize, kernel_shape[1] as usize],
)
.with_strides([strides[0] as usize, strides[1] as usize])
.with_padding(padding)
MaxPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize])
.with_strides([strides[0] as usize, strides[1] as usize])
.with_padding(padding)
}
/// Create a FlattenConfig from the attributes of the node

View File

@@ -0,0 +1,103 @@
use crate::{
element::FloatNdArrayElement, iter_par, run_par, sharing::UnsafeSharedRef,
tensor::NdArrayTensor,
};
use burn_tensor::ElementConversion;
use ndarray::Array4;
pub(crate) fn adaptive_avg_pool2d<E: FloatNdArrayElement>(
x: NdArrayTensor<E, 4>,
output_size: [usize; 2],
) -> NdArrayTensor<E, 4> {
let [batch_size, channels, input_height, input_width] = x.shape().dims;
let x = x.array;
let mut output = Array4::from_elem(
(batch_size, channels, output_size[0], output_size[1]),
0.elem(),
);
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
run_par!(|| {
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
let b = k / channels;
let c = k % channels;
let output = unsafe_shared_out.get();
for h in 0..output_size[0] {
for w in 0..output_size[1] {
let ih_start = start_index(h, output_size[0], input_height);
let ih_end = end_index(h, output_size[0], input_height);
let iw_start = start_index(w, output_size[1], input_width);
let iw_end = end_index(w, output_size[1], input_width);
let mut sum_val: E = 0.elem();
for ih in ih_start..ih_end {
for iw in iw_start..iw_end {
sum_val += x[[b, c, ih, iw]];
}
}
let count: E = (((ih_end - ih_start) * (iw_end - iw_start)) as i32).elem();
output[[b, c, h, w]] = sum_val / count.elem();
}
}
})
});
NdArrayTensor::new(output.into_dyn().into_shared())
}
pub(crate) fn adaptive_avg_pool2d_backward<E: FloatNdArrayElement>(
x: NdArrayTensor<E, 4>,
grad: NdArrayTensor<E, 4>,
) -> NdArrayTensor<E, 4> {
let [_, _, input_height, input_width] = x.shape().dims;
let [batch_size, channels, output_height, output_width] = grad.shape().dims;
let mut output_grad =
Array4::from_elem((batch_size, channels, input_height, input_width), 0.elem());
let unsafe_shared_out = UnsafeSharedRef::new(&mut output_grad);
run_par!(|| {
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
let b = k / channels;
let c = k % channels;
let output_grad = unsafe_shared_out.get();
for oh in 0..output_height {
for ow in 0..output_width {
let ih_start = start_index(oh, output_height, input_height);
let ih_end = end_index(oh, output_height, input_height);
let iw_start = start_index(ow, output_width, input_width);
let iw_end = end_index(ow, output_width, input_width);
let count: E = (((ih_end - ih_start) * (iw_end - iw_start)) as i32).elem();
for ih in ih_start..ih_end {
for iw in iw_start..iw_end {
output_grad[[b, c, ih, iw]] +=
grad.array[[b, c, oh, ow]] / count.elem();
}
}
}
}
})
});
NdArrayTensor::new(output_grad.into_dyn().into_shared())
}
fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
libm::floorf((output_size_index as f32 * input_size as f32) / output_size as f32) as usize
}
fn end_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
let index =
libm::ceilf(((output_size_index + 1) as f32 * input_size as f32) / output_size as f32)
as usize;
usize::min(index, input_size)
}
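The two helpers implement the standard adaptive-pooling window split: output index i covers input rows [floor(i*input/output), ceil((i+1)*input/output)), so adjacent windows may overlap when the sizes don't divide evenly. A quick standalone check of the values they produce (a sketch using integer arithmetic, which matches the float formulas above for non-negative operands):

fn start_index(i: usize, output_size: usize, input_size: usize) -> usize {
    i * input_size / output_size // floor
}

fn end_index(i: usize, output_size: usize, input_size: usize) -> usize {
    // ceil, clamped to the input size
    usize::min(((i + 1) * input_size + output_size - 1) / output_size, input_size)
}

// input_size = 5, output_size = 3 -> windows [0,2), [1,4), [3,5);
// rows 1 and 3 each belong to two overlapping windows.
assert_eq!((start_index(0, 3, 5), end_index(0, 3, 5)), (0, 2));
assert_eq!((start_index(1, 3, 5), end_index(1, 3, 5)), (1, 4));
assert_eq!((start_index(2, 3, 5), end_index(2, 3, 5)), (3, 5));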

View File

@@ -5,6 +5,7 @@ mod int_tensor;
mod module;
mod tensor;
pub(crate) mod adaptive_avgpool;
pub(crate) mod avgpool;
pub(crate) mod conv;
pub(crate) mod macros;

View File

@@ -1,4 +1,5 @@
use super::{
adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward},
avgpool::{avg_pool2d, avg_pool2d_backward},
conv::{conv2d, conv_transpose2d},
maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
@@ -81,4 +82,15 @@ impl<E: FloatNdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E>
indices,
))
}
fn adaptive_avg_pool2d(x: NdArrayTensor<E, 4>, output_size: [usize; 2]) -> NdArrayTensor<E, 4> {
adaptive_avg_pool2d(x, output_size)
}
fn adaptive_avg_pool2d_backward(
x: NdArrayTensor<E, 4>,
grad: NdArrayTensor<E, 4>,
) -> NdArrayTensor<E, 4> {
adaptive_avg_pool2d_backward(x, grad)
}
}

View File

@@ -26,7 +26,7 @@ impl<B: Backend> ConvBlock<B> {
let conv = nn::conv::Conv2dConfig::new(config.channels, config.kernel_size)
.with_padding(nn::PaddingConfig2d::Same)
.init();
let pool = nn::pool::MaxPool2dConfig::new(config.channels[1], config.kernel_size)
let pool = nn::pool::MaxPool2dConfig::new(config.kernel_size)
.with_padding(nn::PaddingConfig2d::Same)
.init();
let activation = nn::GELU::new();

View File

@@ -220,4 +220,16 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
MaxPool2dBackward::new(TchTensor::new(grad))
}
fn adaptive_avg_pool2d(x: TchTensor<E, 4>, output_size: [usize; 2]) -> TchTensor<E, 4> {
let tensor = tch::Tensor::adaptive_avg_pool2d(&x.tensor, output_size.map(|e| e as i64));
TchTensor::new(tensor)
}
fn adaptive_avg_pool2d_backward(x: TchTensor<E, 4>, grad: TchTensor<E, 4>) -> TchTensor<E, 4> {
let tensor = tch::Tensor::internal_adaptive_avg_pool2d_backward(&x.tensor, &grad.tensor);
TchTensor::new(tensor)
}
}

View File

@@ -137,3 +137,11 @@ where
(Tensor::new(output.output), Tensor::new(output.indices))
}
/// Applies a [2D adaptive avg pooling](crate::ops::ModuleOps::adaptive_avg_pool2d).
pub fn adaptive_avg_pool2d<B>(x: Tensor<B, 4>, output_size: [usize; 2]) -> Tensor<B, 4>
where
B: Backend,
{
Tensor::new(B::adaptive_avg_pool2d(x.primitive, output_size))
}
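For example, requesting an output size of [1, 1] turns this into global average pooling; a hypothetical call site where x is any Tensor<B, 4>:

use burn_tensor::module::adaptive_avg_pool2d;

// x: [batch_size, channels, height, width]
let pooled = adaptive_avg_pool2d(x, [1, 1]); // [batch_size, channels, 1, 1]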

View File

@@ -262,6 +262,22 @@ pub trait ModuleOps<B: Backend> {
padding: [usize; 2],
) -> B::TensorPrimitive<4>;
/// Two dimensional adaptive avg pooling.
///
/// # Shapes
///
/// x: [batch_size, channels, height, width],
fn adaptive_avg_pool2d(
x: B::TensorPrimitive<4>,
output_size: [usize; 2],
) -> B::TensorPrimitive<4>;
/// Backward pass for the [adaptive avg pooling 2d](ModuleOps::adaptive_avg_pool2d) operation.
fn adaptive_avg_pool2d_backward(
x: B::TensorPrimitive<4>,
grad: B::TensorPrimitive<4>,
) -> B::TensorPrimitive<4>;
/// Two dimensional max pooling.
///
/// # Shapes

View File

@@ -23,6 +23,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_module_max_pool2d!();
burn_tensor::testgen_module_avg_pool1d!();
burn_tensor::testgen_module_avg_pool2d!();
burn_tensor::testgen_module_adaptive_avg_pool2d!();
// test ops
burn_tensor::testgen_add!();

View File

@@ -0,0 +1,103 @@
#[burn_tensor_testgen::testgen(module_adaptive_avg_pool2d)]
mod tests {
use super::*;
use burn_tensor::module::adaptive_avg_pool2d;
use burn_tensor::{Data, Shape, Tensor};
#[test]
fn test_adaptive_avg_pool2d_simple() {
let test = AdaptiveAvgPool2dTestCase {
batch_size: 1,
channels: 2,
height: 8,
width: 6,
height_out: 4,
width_out: 4,
};
test.assert_output(TestTensor::from_floats([[
[
[3.5000, 4.5000, 6.5000, 7.5000],
[15.5000, 16.5000, 18.5000, 19.5000],
[27.5000, 28.5000, 30.5000, 31.5000],
[39.5000, 40.5000, 42.5000, 43.5000],
],
[
[51.5000, 52.5000, 54.5000, 55.5000],
[63.5000, 64.5000, 66.5000, 67.5000],
[75.5000, 76.5000, 78.5000, 79.5000],
[87.5000, 88.5000, 90.5000, 91.5000],
],
]]));
}
#[test]
fn test_adaptive_avg_pool2d_dyn_filter_size() {
let test = AdaptiveAvgPool2dTestCase {
batch_size: 1,
channels: 2,
height: 5,
width: 7,
height_out: 3,
width_out: 2,
};
test.assert_output(TestTensor::from_floats([[
[[5.0000, 8.0000], [15.5000, 18.5000], [26.0000, 29.0000]],
[[40.0000, 43.0000], [50.5000, 53.5000], [61.0000, 64.0000]],
]]));
}
#[test]
fn test_adaptive_avg_pool2d_bigger_output() {
let test = AdaptiveAvgPool2dTestCase {
batch_size: 1,
channels: 2,
height: 4,
width: 3,
height_out: 5,
width_out: 4,
};
test.assert_output(TestTensor::from_floats([[
[
[0.0000, 0.5000, 1.5000, 2.0000],
[1.5000, 2.0000, 3.0000, 3.5000],
[4.5000, 5.0000, 6.0000, 6.5000],
[7.5000, 8.0000, 9.0000, 9.5000],
[9.0000, 9.5000, 10.5000, 11.0000],
],
[
[12.0000, 12.5000, 13.5000, 14.0000],
[13.5000, 14.0000, 15.0000, 15.5000],
[16.5000, 17.0000, 18.0000, 18.5000],
[19.5000, 20.0000, 21.0000, 21.5000],
[21.0000, 21.5000, 22.5000, 23.0000],
],
]]));
}
struct AdaptiveAvgPool2dTestCase {
batch_size: usize,
channels: usize,
height: usize,
width: usize,
height_out: usize,
width_out: usize,
}
impl AdaptiveAvgPool2dTestCase {
fn assert_output(self, y: TestTensor<4>) {
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
let x = TestTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements())
.reshape(shape_x)
.into_data()
.convert(),
);
let output = adaptive_avg_pool2d(x, [self.height_out, self.width_out]);
y.to_data().assert_approx_eq(&output.into_data(), 3);
}
}
}
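The expected values above can be checked by hand: the input is an arange tensor, so x[0][0][ih][iw] = ih * 6 + iw in the 8x6 case, and each output cell averages its window. For output[0][0][0][0] (rows [0,2), cols [0,2)) and output[0][0][0][1] (rows [0,2), cols [1,3)):

(0 + 1 + 6 + 7) / 4 = 3.5
(1 + 2 + 7 + 8) / 4 = 4.5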

View File

@@ -1,3 +1,4 @@
mod adaptive_avgpool2d;
mod avgpool1d;
mod avgpool2d;
mod conv1d;

View File

@@ -0,0 +1,108 @@
use std::sync::Arc;
use burn_tensor::Shape;
use wgpu::Buffer;
use crate::{
element::WgpuElement,
kernel::{elemwise_workgroup, KernelSettings},
kernel_wgsl,
tensor::WgpuTensor,
};
kernel_wgsl!(
AdaptiveAvgPool2d,
"../../template/pool/adaptive_avg_pool2d.wgsl"
);
kernel_wgsl!(
AdaptiveAvgPool2dBackward,
"../../template/pool/adaptive_avg_pool2d_backward.wgsl"
);
pub(crate) fn adaptive_avg_pool2d<E: WgpuElement>(
x: WgpuTensor<E, 4>,
output_size: [usize; 2],
) -> WgpuTensor<E, 4> {
const WORKGROUP: usize = 32;
let [batch_size, channels, _, _] = x.shape.dims;
let output_shape = Shape::new([batch_size, channels, output_size[0], output_size[1]]);
let num_elems = output_shape.num_elements();
let output_buffer = x
.context
.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(x.context.clone(), output_shape, output_buffer);
let kernel = x
.context
.compile_static::<KernelSettings<AdaptiveAvgPool2d, E, i32, WORKGROUP, WORKGROUP, 1>>();
let info_buffer = build_info(&x, &output);
x.context.execute(
elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
kernel,
&[&x.buffer, &output.buffer, &info_buffer],
);
output
}
pub(crate) fn adaptive_avg_pool2d_backward<E: WgpuElement>(
x: WgpuTensor<E, 4>,
out_grad: WgpuTensor<E, 4>,
) -> WgpuTensor<E, 4> {
const WORKGROUP: usize = 32;
let output_shape = x.shape.clone();
let num_elems = output_shape.num_elements();
let output_buffer = x
.context
.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(x.context.clone(), output_shape, output_buffer);
let kernel = x.context.compile_static::<KernelSettings<
AdaptiveAvgPool2dBackward,
E,
i32,
WORKGROUP,
WORKGROUP,
1,
>>();
let info_buffer = build_info(&x, &out_grad);
x.context.execute(
elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
kernel,
&[&out_grad.buffer, &output.buffer, &info_buffer],
);
output
}
fn build_info<E: WgpuElement>(x: &WgpuTensor<E, 4>, output: &WgpuTensor<E, 4>) -> Arc<Buffer> {
let mut info: [u32; 16] = [0; 16];
info[0] = x.strides[0] as u32;
info[1] = x.strides[1] as u32;
info[2] = x.strides[2] as u32;
info[3] = x.strides[3] as u32;
info[4] = x.shape.dims[0] as u32;
info[5] = x.shape.dims[1] as u32;
info[6] = x.shape.dims[2] as u32;
info[7] = x.shape.dims[3] as u32;
info[8] = output.strides[0] as u32;
info[9] = output.strides[1] as u32;
info[10] = output.strides[2] as u32;
info[11] = output.strides[3] as u32;
info[12] = output.shape.dims[0] as u32;
info[13] = output.shape.dims[1] as u32;
info[14] = output.shape.dims[2] as u32;
info[15] = output.shape.dims[3] as u32;
output
.context
.create_buffer_with_data(bytemuck::cast_slice(&info))
}
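Both kernels receive the 16-element info buffer built above (input strides and shape, then output strides and shape) and decode each invocation's (b, c, h, w) coordinates from its linear id. A Rust sketch of that decode for a contiguous tensor, mirroring the id / stride % shape lines in the WGSL templates below:

// Decode a linear index into (b, c, h, w) using strides; this is what
// `id / stride_n % shape_n` computes in the WGSL kernels.
fn decode(id: usize, strides: [usize; 4], shape: [usize; 4]) -> [usize; 4] {
    [
        id / strides[0] % shape[0],
        id / strides[1] % shape[1],
        id / strides[2] % shape[2],
        id / strides[3] % shape[3],
    ]
}

// Shape [1, 2, 4, 4] with contiguous strides [32, 16, 4, 1]:
assert_eq!(decode(21, [32, 16, 4, 1], [1, 2, 4, 4]), [0, 1, 1, 1]);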

View File

@@ -1,7 +1,9 @@
mod adaptive_avg_pool2d;
mod avg_pool2d;
mod base;
mod max_pool2d;
pub(crate) use adaptive_avg_pool2d::*;
pub use avg_pool2d::*;
pub(super) use base::*;
pub use max_pool2d::*;

View File

@@ -90,4 +90,18 @@
padding,
))
}
fn adaptive_avg_pool2d(
x: FloatTensor<Self, 4>,
output_size: [usize; 2],
) -> FloatTensor<Self, 4> {
kernel::pool::adaptive_avg_pool2d(x, output_size)
}
fn adaptive_avg_pool2d_backward(
x: FloatTensor<Self, 4>,
grad: FloatTensor<Self, 4>,
) -> FloatTensor<Self, 4> {
kernel::pool::adaptive_avg_pool2d_backward(x, grad)
}
}

View File

@@ -0,0 +1,74 @@
@group(0)
@binding(0)
var<storage, read> x: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(2)
var<storage, read> info: array<u32, 16>;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
let input_stride_0 = info[0];
let input_stride_1 = info[1];
let input_stride_2 = info[2];
let input_stride_3 = info[3];
let input_shape_0 = info[4];
let input_shape_1 = info[5];
let input_shape_2 = info[6];
let input_shape_3 = info[7];
let output_stride_0 = info[8];
let output_stride_1 = info[9];
let output_stride_2 = info[10];
let output_stride_3 = info[11];
let output_shape_0 = info[12];
let output_shape_1 = info[13];
let output_shape_2 = info[14];
let output_shape_3 = info[15];
let b = id / output_stride_0 % output_shape_0;
let c = id / output_stride_1 % output_shape_1;
let oh = id / output_stride_2 % output_shape_2;
let ow = id / output_stride_3 % output_shape_3;
let ih_start = start_index(oh, output_shape_2, input_shape_2);
let ih_end = end_index(oh, output_shape_2, input_shape_2);
let iw_start = start_index(ow, output_shape_3, input_shape_3);
let iw_end = end_index(ow, output_shape_3, input_shape_3);
var sum = 0.0;
for (var ih = ih_start; ih < ih_end; ih++) {
for (var iw = iw_start; iw < iw_end; iw++) {
let index_input = b * input_stride_0 + c * input_stride_1 + ih * input_stride_2 + iw * input_stride_3;
sum += x[index_input];
}
}
let count = {{ elem }}((ih_end - ih_start) * (iw_end - iw_start));
output[id] = sum / count;
}
fn start_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
return u32(floor((f32(output_size_index) * f32(input_size)) / f32(output_size)));
}
fn end_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
let index = u32(ceil((f32(output_size_index + 1u) * f32(input_size)) / f32(output_size)));
return min(index, input_size);
}

View File

@@ -0,0 +1,88 @@
@group(0)
@binding(0)
var<storage, read> grad: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(2)
var<storage, read> info: array<u32, 16>;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
let input_stride_0 = info[0];
let input_stride_1 = info[1];
let input_stride_2 = info[2];
let input_stride_3 = info[3];
let input_shape_0 = info[4];
let input_shape_1 = info[5];
let input_shape_2 = info[6];
let input_shape_3 = info[7];
let grad_stride_0 = info[8];
let grad_stride_1 = info[9];
let grad_stride_2 = info[10];
let grad_stride_3 = info[11];
let grad_shape_0 = info[12];
let grad_shape_1 = info[13];
let grad_shape_2 = info[14];
let grad_shape_3 = info[15];
let b = id / input_stride_0 % input_shape_0;
let c = id / input_stride_1 % input_shape_1;
let ih = id / input_stride_2 % input_shape_2;
let iw = id / input_stride_3 % input_shape_3;
let oh_start = start_index(ih, input_shape_2, grad_shape_2);
let oh_end = end_index(ih, input_shape_2, grad_shape_2);
let ow_start = start_index(iw, input_shape_3, grad_shape_3);
let ow_end = end_index(iw, input_shape_3, grad_shape_3);
var grad_acc = 0.0;
for (var oh = oh_start; oh < oh_end; oh++) {
for (var ow = ow_start; ow < ow_end; ow++) {
let ih_start = start_index(oh, grad_shape_2, input_shape_2);
let ih_end = end_index(oh, grad_shape_2, input_shape_2);
let iw_start = start_index(ow, grad_shape_3, input_shape_3);
let iw_end = end_index(ow, grad_shape_3, input_shape_3);
let contributed_h = ih >= ih_start && ih < ih_end;
let contributed_w = iw >= iw_start && iw < iw_end;
// Skip output windows that do not actually cover this input position.
if !contributed_h || !contributed_w {
continue;
}
let index = b * grad_stride_0 + c * grad_stride_1 + oh * grad_stride_2 + ow * grad_stride_3;
let count = {{ elem }}((ih_end - ih_start) * (iw_end - iw_start));
grad_acc += grad[index] / count;
}
}
output[id] = grad_acc;
}
fn start_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
return u32(floor((f32(output_size_index) * f32(input_size)) / f32(output_size)));
}
fn end_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
let index = u32(ceil((f32(output_size_index + 1u) * f32(input_size)) / f32(output_size)));
return min(index, input_size);
}
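Unlike the ndarray backward, which scatters grad/count into the input gradient, this kernel gathers: each input position scans the candidate output windows, re-derives each window's bounds, and accumulates grad/count only when it actually falls inside. The same logic as a standalone Rust sketch for one (batch, channel) plane (hypothetical helper, same floor/ceil window formulas as the kernel):

// Accumulate the gradient for input cell (ih, iw) of one (batch, channel) plane.
fn input_grad_at(
    ih: usize, iw: usize,
    grad: &[Vec<f32>], // [out_h][out_w] output gradient
    (in_h, in_w): (usize, usize),
    (out_h, out_w): (usize, usize),
) -> f32 {
    let start = |i: usize, out: usize, inp: usize| i * inp / out; // floor
    let end = |i: usize, out: usize, inp: usize| {
        usize::min(((i + 1) * inp + out - 1) / out, inp) // ceil, clamped
    };
    let mut acc = 0.0;
    // Candidate output rows/cols whose windows might contain (ih, iw).
    for oh in start(ih, in_h, out_h)..end(ih, in_h, out_h) {
        for ow in start(iw, in_w, out_w)..end(iw, in_w, out_w) {
            let (h0, h1) = (start(oh, out_h, in_h), end(oh, out_h, in_h));
            let (w0, w1) = (start(ow, out_w, in_w), end(ow, out_w, in_w));
            if ih >= h0 && ih < h1 && iw >= w0 && iw < w1 {
                acc += grad[oh][ow] / ((h1 - h0) * (w1 - w0)) as f32;
            }
        }
    }
    acc
}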