Dilation maxpool (#668)

Caio Piccirillo 2023-08-21 20:14:25 +02:00 committed by GitHub
parent dd5ea5251c
commit 2fefc82099
26 changed files with 442 additions and 112 deletions

View File

@ -618,19 +618,36 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
) -> ADTensor<B, 3> {
match MaxPool1D.prepare([x.node], [x.graph]).statefull() {
OpsKind::Tracked(prep) => {
let output =
B::max_pool1d_with_indices(x.primitive.clone(), kernel_size, stride, padding);
let output = B::max_pool1d_with_indices(
x.primitive.clone(),
kernel_size,
stride,
padding,
dilation,
);
prep.finish(
(x.primitive, output.indices, kernel_size, stride, padding),
(
x.primitive,
output.indices,
kernel_size,
stride,
padding,
dilation,
),
output.output,
)
}
OpsKind::UnTracked(prep) => {
prep.finish(B::max_pool1d(x.primitive, kernel_size, stride, padding))
}
OpsKind::UnTracked(prep) => prep.finish(B::max_pool1d(
x.primitive,
kernel_size,
stride,
padding,
dilation,
)),
}
}
@ -639,11 +656,17 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
) -> MaxPool1dWithIndices<ADBackendDecorator<B>> {
match MaxPool1D.prepare([x.node], [x.graph]).statefull() {
OpsKind::Tracked(prep) => {
let output =
B::max_pool1d_with_indices(x.primitive.clone(), kernel_size, stride, padding);
let output = B::max_pool1d_with_indices(
x.primitive.clone(),
kernel_size,
stride,
padding,
dilation,
);
let output_tensor = prep.finish(
(
@ -652,6 +675,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
kernel_size,
stride,
padding,
dilation,
),
output.output,
);
@ -659,7 +683,8 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
MaxPool1dWithIndices::new(output_tensor, output.indices)
}
OpsKind::UnTracked(prep) => {
let output = B::max_pool1d_with_indices(x.primitive, kernel_size, stride, padding);
let output =
B::max_pool1d_with_indices(x.primitive, kernel_size, stride, padding, dilation);
let output_tensor = prep.finish(output.output);
MaxPool1dWithIndices::new(output_tensor, output.indices)
@ -672,6 +697,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
output_grad: ADTensor<B, 3>,
indices: IntTensor<B, 3>,
) -> MaxPool1dBackward<ADBackendDecorator<B>> {
@ -680,6 +706,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
kernel_size,
stride,
padding,
dilation,
output_grad.primitive,
indices,
);
@ -691,19 +718,36 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> ADTensor<B, 4> {
match MaxPool2D.prepare([x.node], [x.graph]).statefull() {
OpsKind::Tracked(prep) => {
let output =
B::max_pool2d_with_indices(x.primitive.clone(), kernel_size, stride, padding);
let output = B::max_pool2d_with_indices(
x.primitive.clone(),
kernel_size,
stride,
padding,
dilation,
);
prep.finish(
(x.primitive, output.indices, kernel_size, stride, padding),
(
x.primitive,
output.indices,
kernel_size,
stride,
padding,
dilation,
),
output.output,
)
}
OpsKind::UnTracked(prep) => {
prep.finish(B::max_pool2d(x.primitive, kernel_size, stride, padding))
}
OpsKind::UnTracked(prep) => prep.finish(B::max_pool2d(
x.primitive,
kernel_size,
stride,
padding,
dilation,
)),
}
}
@ -712,11 +756,17 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> MaxPool2dWithIndices<ADBackendDecorator<B>> {
match MaxPool2D.prepare([x.node], [x.graph]).statefull() {
OpsKind::Tracked(prep) => {
let output =
B::max_pool2d_with_indices(x.primitive.clone(), kernel_size, stride, padding);
let output = B::max_pool2d_with_indices(
x.primitive.clone(),
kernel_size,
stride,
padding,
dilation,
);
let output_tensor = prep.finish(
(
@ -725,6 +775,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
kernel_size,
stride,
padding,
dilation,
),
output.output,
);
@ -732,7 +783,8 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
MaxPool2dWithIndices::new(output_tensor, output.indices)
}
OpsKind::UnTracked(prep) => {
let output = B::max_pool2d_with_indices(x.primitive, kernel_size, stride, padding);
let output =
B::max_pool2d_with_indices(x.primitive, kernel_size, stride, padding, dilation);
let output_tensor = prep.finish(output.output);
MaxPool2dWithIndices::new(output_tensor, output.indices)
@ -745,6 +797,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
_kernel_size: [usize; 2],
_stride: [usize; 2],
_padding: [usize; 2],
_dilation: [usize; 2],
_output_grad: ADTensor<B, 4>,
_indices: IntTensor<B, 4>,
) -> MaxPool2dBackward<ADBackendDecorator<B>> {
@ -820,16 +873,30 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
struct MaxPool1D;
impl<B: Backend> Backward<B, 3, 1> for MaxPool1D {
type State = (B::TensorPrimitive<3>, IntTensor<B, 3>, usize, usize, usize);
type State = (
B::TensorPrimitive<3>,
IntTensor<B, 3>,
usize,
usize,
usize,
usize,
);
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
let [node_parent] = ops.parents;
let grad = grads.consume::<B, 3>(&ops.node);
let (x, indices, kernel_size, stride, padding) = ops.state;
let (x, indices, kernel_size, stride, padding, dilation) = ops.state;
if let Some(node) = node_parent {
let grad =
B::max_pool1d_with_indices_backward(x, kernel_size, stride, padding, grad, indices);
let grad = B::max_pool1d_with_indices_backward(
x,
kernel_size,
stride,
padding,
dilation,
grad,
indices,
);
grads.register::<B, 3>(node, grad.x_grad);
}
@ -846,16 +913,24 @@ impl<B: Backend> Backward<B, 4, 1> for MaxPool2D {
[usize; 2],
[usize; 2],
[usize; 2],
[usize; 2],
);
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
let [node_parent] = ops.parents;
let grad = grads.consume::<B, 4>(&ops.node);
let (x, indices, kernel_size, stride, padding) = ops.state;
let (x, indices, kernel_size, stride, padding, dilation) = ops.state;
if let Some(node) = node_parent {
let grad =
B::max_pool2d_with_indices_backward(x, kernel_size, stride, padding, grad, indices);
let grad = B::max_pool2d_with_indices_backward(
x,
kernel_size,
stride,
padding,
dilation,
grad,
indices,
);
grads.register::<B, 4>(node, grad.x_grad);
}

View File

@ -5,17 +5,44 @@ mod tests {
#[test]
fn test_max_pool1d_simple() {
let batch_size = 1;
let channels_in = 1;
let kernel_size = 4;
let padding = 0;
let stride = 1;
let dilation = 1;
let x = TestADTensor::from_floats([[[0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221]]])
.require_grad();
let x_grad_expected = TestADTensor::from_floats([[[1., 1., 0., 0., 0., 1.]]]);
let output = max_pool1d(x.clone(), kernel_size, stride, padding);
let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation);
let grads = output.backward();
// Asserts
let x_grad_actual = x.grad(&grads).unwrap();
x_grad_expected
.to_data()
.assert_approx_eq(&x_grad_actual.to_data(), 3);
}
#[test]
fn test_max_pool1d_with_dilation() {
let kernel_size = 4;
let padding = 0;
let stride = 1;
let dilation = 2;
let x = TestADTensor::from_floats([[[
0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230,
0.4610, 0.5365, 0.6880,
]]])
.require_grad();
let x_grad_expected = TestADTensor::from_floats([[[
0., 0., 1., 0., 0., 3., 0., 1., 2., 1., 0., 0., 2., 0., 0., 0., 4., 4., 0., 0., 0., 0.,
0., 0., 1.,
]]]);
let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation);
let grads = output.backward();
// Asserts
@ -27,11 +54,10 @@ mod tests {
#[test]
fn test_max_pool1d_complex() {
let batch_size = 1;
let channels_in = 1;
let kernel_size = 4;
let padding = 0;
let stride = 1;
let dilation = 1;
let x = TestADTensor::from_floats([[[
0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
@ -44,7 +70,7 @@ mod tests {
1., 1., 1.,
]]]);
let output = max_pool1d(x.clone(), kernel_size, stride, padding);
let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation);
let grads = output.backward();
// Asserts
@ -56,11 +82,10 @@ mod tests {
#[test]
fn test_max_pool1d_complex_with_padding() {
let batch_size = 1;
let channels_in = 1;
let kernel_size = 4;
let padding = 2;
let stride = 1;
let dilation = 1;
let x = TestADTensor::from_floats([[[
0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
@ -73,7 +98,7 @@ mod tests {
1., 1., 3.,
]]]);
let output = max_pool1d(x.clone(), kernel_size, stride, padding);
let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation);
let grads = output.backward();
// Asserts
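For reference, the semantics these gradient checks assert: the backward pass routes each output gradient to the input index recorded during the forward pass, so an input element that won k (possibly dilated) windows receives gradient k. A minimal reference sketch, not the burn implementation:

// Reference semantics only: scatter-add each output gradient to its argmax input.
fn max_pool1d_backward_ref(grad_out: &[f32], indices: &[usize], input_len: usize) -> Vec<f32> {
    let mut grad_in = vec![0.0f32; input_len];
    for (g, &i) in grad_out.iter().zip(indices.iter()) {
        grad_in[i] += *g;
    }
    grad_in
}

With every output gradient equal to 1, this reproduces counts like the 3. and 4. entries in x_grad_expected above.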

View File

@ -11,6 +11,8 @@ mod tests {
let padding_2 = 0;
let stride_1 = 1;
let stride_2 = 1;
let dilation_1 = 1;
let dilation_2 = 1;
let x = TestADTensor::from_floats([[[
[0.2479, 0.6386, 0.3166, 0.5742],
@ -31,6 +33,7 @@ mod tests {
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
[dilation_1, dilation_2],
);
let grads = output.backward();
@ -49,6 +52,8 @@ mod tests {
let padding_2 = 1;
let stride_1 = 1;
let stride_2 = 1;
let dilation_1 = 1;
let dilation_2 = 1;
let x = TestADTensor::from_floats([[[
[0.2479, 0.6386, 0.3166, 0.5742],
@ -69,6 +74,48 @@ mod tests {
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
[dilation_1, dilation_2],
);
let grads = output.backward();
// Asserts
let x_grad_actual = x.grad(&grads).unwrap();
x_grad_expected
.to_data()
.assert_approx_eq(&x_grad_actual.to_data(), 3);
}
#[test]
fn test_max_pool2d_with_dilation() {
let kernel_size_1 = 2;
let kernel_size_2 = 2;
let padding_1 = 1;
let padding_2 = 1;
let stride_1 = 1;
let stride_2 = 1;
let dilation_1 = 2;
let dilation_2 = 2;
let x = TestADTensor::from_floats([[[
[0.2479, 0.6386, 0.3166, 0.5742],
[0.7065, 0.1940, 0.6305, 0.8959],
[0.5416, 0.8602, 0.8129, 0.1662],
[0.3358, 0.3059, 0.8293, 0.0990],
]]])
.require_grad();
let x_grad_expected = TestADTensor::from_floats([[[
[0., 0., 0., 0.],
[1., 1., 1., 2.],
[0., 4., 4., 0.],
[0., 1., 2., 0.],
]]]);
let output = max_pool2d(
x.clone(),
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
[dilation_1, dilation_2],
);
let grads = output.backward();
@ -87,6 +134,8 @@ mod tests {
let padding_2 = 1;
let stride_1 = 1;
let stride_2 = 2;
let dilation_1 = 1;
let dilation_2 = 1;
let x = TestADTensor::from_floats([[[
[0.5388, 0.0676, 0.7122, 0.8316, 0.0653],
@ -109,6 +158,7 @@ mod tests {
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
[dilation_1, dilation_2],
);
let grads = output.backward();

View File

@ -56,6 +56,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> FloatTensor<Self, 4> {
todo!()
}
@ -65,6 +66,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> MaxPool2dWithIndices<CandleBackend<F, I>> {
todo!()
}
@ -74,6 +76,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
output_grad: FloatTensor<Self, 4>,
indices: IntTensor<Self, 4>,
) -> MaxPool2dBackward<CandleBackend<F, I>> {

View File

@ -18,6 +18,9 @@ pub struct MaxPool1dConfig {
/// The padding configuration.
#[config(default = "PaddingConfig1d::Valid")]
pub padding: PaddingConfig1d,
/// The dilation.
#[config(default = "1")]
pub dilation: usize,
}
/// Applies a 1D max pooling over input tensors.
@ -26,6 +29,7 @@ pub struct MaxPool1d {
stride: usize,
kernel_size: usize,
padding: PaddingConfig1d,
dilation: usize,
}
impl MaxPool1dConfig {
@ -35,6 +39,7 @@ impl MaxPool1dConfig {
stride: self.stride,
kernel_size: self.kernel_size,
padding: self.padding.clone(),
dilation: self.dilation,
}
}
}
@ -52,6 +57,6 @@ impl MaxPool1d {
.padding
.calculate_padding_1d(length, self.kernel_size, self.stride);
max_pool1d(input, self.kernel_size, self.stride, padding)
max_pool1d(input, self.kernel_size, self.stride, padding, self.dilation)
}
}
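A usage sketch for the new config field; the constructor and setter shapes are assumed from burn's Config derive (with_dilation on the 2d config appears verbatim later in this diff):

// Hypothetical usage: kernel_size 4, defaults elsewhere, dilation 2.
let pool = MaxPool1dConfig::new(4).with_dilation(2).init();
// let output = pool.forward(input); // input: Tensor<B, 3> as [batch, channels, length]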

View File

@ -18,6 +18,9 @@ pub struct MaxPool2dConfig {
/// The padding configuration.
#[config(default = "PaddingConfig2d::Valid")]
pub padding: PaddingConfig2d,
/// The dilation.
#[config(default = "[1, 1]")]
pub dilation: [usize; 2],
}
/// Applies a 2D max pooling over input tensors.
@ -26,6 +29,7 @@ pub struct MaxPool2d {
stride: [usize; 2],
kernel_size: [usize; 2],
padding: PaddingConfig2d,
dilation: [usize; 2],
}
impl MaxPool2dConfig {
@ -35,6 +39,7 @@ impl MaxPool2dConfig {
stride: self.strides,
kernel_size: self.kernel_size,
padding: self.padding.clone(),
dilation: self.dilation,
}
}
}
@ -52,6 +57,6 @@ impl MaxPool2d {
self.padding
.calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride);
max_pool2d(input, self.kernel_size, self.stride, padding)
max_pool2d(input, self.kernel_size, self.stride, padding, self.dilation)
}
}

View File

@ -10,8 +10,7 @@ class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
# TODO support dilation=(3, 1) (see https://github.com/burn-rs/burn/issues/622)
self.maxpool2d1 = nn.MaxPool2d((4, 2), stride=(2, 1), padding=(2, 1))
self.maxpool2d1 = nn.MaxPool2d((4, 2), stride=(2, 1), padding=(2, 1), dilation=(1, 3))
def forward(self, x):
x = self.maxpool2d1(x)

View File

@ -305,9 +305,9 @@ mod tests {
]]]);
let output = model.forward(input);
let expected = Data::from([[[
[1.927, 1.927, 1.487, 0.901, 0.678, 0.678],
[1.927, 1.927, 1.487, 0.901, 0.803, 0.678],
[-0.217, 0.241, 0.241, 0.803, 0.803, -0.622],
[0.901, 1.927, 1.487, 0.901],
[0.901, 1.927, 1.487, 0.901],
[-0.396, 0.803, 0.241, -0.396],
]]]);
assert_eq!(output.to_data(), expected);

View File

@ -51,6 +51,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for MaxPool2dNode {
let kernel_size = self.config.kernel_size.to_tokens();
let strides = self.config.strides.to_tokens();
let padding = self.config.padding.to_tokens();
let dilation = self.config.dilation.to_tokens();
let init_line = quote! {
init();
@ -60,6 +61,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for MaxPool2dNode {
let #name = MaxPool2dConfig::new(#kernel_size)
.with_strides(#strides)
.with_padding(#padding)
.with_dilation(#dilation)
.#init_line
};
@ -111,7 +113,8 @@ mod tests {
TensorType::new_float("output", 4),
MaxPool2dConfig::new([3, 3])
.with_strides([1, 1])
.with_padding(PaddingConfig2d::Valid),
.with_padding(PaddingConfig2d::Valid)
.with_dilation([1, 1]),
));
graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
@ -137,6 +140,7 @@ mod tests {
let max_pool2d = MaxPool2dConfig::new([3, 3])
.with_strides([1, 1])
.with_padding(PaddingConfig2d::Valid)
.with_dilation([1, 1])
.init();
Self {

View File

@ -123,7 +123,7 @@ pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig {
let mut kernel_shape = Vec::new();
let mut strides = Vec::new();
let mut pads = Vec::new();
let mut dilations = Vec::new();
let mut dilations = vec![1, 1];
for (key, value) in curr.attrs.iter() {
match key.as_str() {
@ -135,15 +135,12 @@ pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig {
}
}
if !dilations.is_empty() && (dilations[0] != 1 || dilations[1] != 1) {
todo!("MaxPool2d: dilations are not supported. See https://github.com/burn-rs/burn/issues/622");
}
let padding = padding_config(&pads);
MaxPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize])
.with_strides([strides[0] as usize, strides[1] as usize])
.with_padding(padding)
.with_dilation([dilations[0] as usize, dilations[1] as usize])
}
/// Create an AvgPool2dConfig from the attributes of the node

View File

@ -11,15 +11,21 @@ pub(crate) fn max_pool2d<E: FloatNdArrayElement>(
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> NdArrayTensor<E, 4> {
let [kernel_height, kernel_width] = kernel_size;
let [padding_height, padding_width] = padding;
let [stride_height, stride_width] = stride;
let [dilation_height, dilation_width] = dilation;
let [batch_size, channels, x_height, x_width] = x.shape().dims;
let inf = (-f32::INFINITY).elem::<E>();
let out_height = ((x_height + 2 * padding_height - kernel_height) / stride_height) + 1;
let out_width = ((x_width + 2 * padding_width - kernel_width) / stride_width) + 1;
let out_height = ((x_height + 2 * padding_height - dilation_height * (kernel_height - 1) - 1)
/ stride_height)
+ 1;
let out_width = ((x_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1)
/ stride_width)
+ 1;
let x = apply_padding_4d(x, padding, inf).array;
@ -38,10 +44,10 @@ pub(crate) fn max_pool2d<E: FloatNdArrayElement>(
let mut max_val = inf;
for kh in 0..kernel_height {
let ih = oh * stride_height + kh;
let ih = oh * stride_height + kh * dilation_height;
for kw in 0..kernel_width {
let iw = ow * stride_width + kw;
let iw = ow * stride_width + kw * dilation_width;
let val = x[[b, c, ih, iw]];
@ -65,15 +71,21 @@ pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement>(
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> (NdArrayTensor<E, 4>, NdArrayTensor<i64, 4>) {
let [kernel_height, kernel_width] = kernel_size;
let [padding_height, padding_width] = padding;
let [stride_height, stride_width] = stride;
let [dilation_height, dilation_width] = dilation;
let [batch_size, channels, x_height, x_width] = x.shape().dims;
let inf = (-f32::INFINITY).elem::<E>();
let out_height = ((x_height + 2 * padding_height - kernel_height) / stride_height) + 1;
let out_width = ((x_width + 2 * padding_width - kernel_width) / stride_width) + 1;
let out_height = ((x_height + 2 * padding_height - dilation_height * (kernel_height - 1) - 1)
/ stride_height)
+ 1;
let out_width = ((x_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1)
/ stride_width)
+ 1;
let x = apply_padding_4d(x, padding, inf).array;
@ -97,10 +109,10 @@ pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement>(
let mut index = 0;
for kh in 0..kernel_height {
let ih = oh * stride_height + kh;
let ih = oh * stride_height + kh * dilation_height;
for kw in 0..kernel_width {
let iw = ow * stride_width + kw;
let iw = ow * stride_width + kw * dilation_width;
let val = x[[b, c, ih, iw]];
if val > max_val {
@ -132,6 +144,7 @@ pub(crate) fn max_pool2d_backward<E: FloatNdArrayElement>(
_kernel_size: [usize; 2],
_stride: [usize; 2],
_padding: [usize; 2],
_dilation: [usize; 2],
output_grad: NdArrayTensor<E, 4>,
indices: NdArrayTensor<i64, 4>,
) -> NdArrayTensor<E, 4> {
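The new out_height/out_width expressions above are the standard dilated-pooling size formula, out = (in + 2·pad − dilation·(kernel − 1) − 1) / stride + 1. A standalone sketch with a worked value:

// Mirrors the output-size computation above (sketch, not library code).
fn pooled_len(input: usize, kernel: usize, stride: usize, pad: usize, dilation: usize) -> usize {
    (input + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1
}

// Example: length 6, kernel 2, stride 1, padding 1, dilation 2 gives
// (6 + 2 - 2 - 1) / 1 + 1 = 6, matching test_max_pool1d_with_dilation later in this diff.
assert_eq!(pooled_len(6, 2, 1, 1, 2), 6);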

View File

@ -52,8 +52,9 @@ impl<E: FloatNdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E>
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> NdArrayTensor<E, 4> {
max_pool2d(x, kernel_size, stride, padding)
max_pool2d(x, kernel_size, stride, padding, dilation)
}
fn max_pool2d_with_indices(
@ -61,8 +62,9 @@ impl<E: FloatNdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E>
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> MaxPool2dWithIndices<NdArrayBackend<E>> {
let (output, indices) = max_pool2d_with_indices(x, kernel_size, stride, padding);
let (output, indices) = max_pool2d_with_indices(x, kernel_size, stride, padding, dilation);
MaxPool2dWithIndices::new(output, indices)
}
@ -72,6 +74,7 @@ impl<E: FloatNdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E>
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
output_grad: NdArrayTensor<E, 4>,
indices: NdArrayTensor<i64, 4>,
) -> MaxPool2dBackward<NdArrayBackend<E>> {
@ -80,6 +83,7 @@ impl<E: FloatNdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E>
kernel_size,
stride,
padding,
dilation,
output_grad,
indices,
))

View File

@ -172,13 +172,14 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
) -> TchTensor<E, 3> {
let tensor = tch::Tensor::max_pool1d(
&x.tensor,
kernel_size as i64,
stride as i64,
padding as i64,
1,
dilation as i64,
false,
);
@ -190,13 +191,14 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
) -> MaxPool1dWithIndices<TchBackend<E>> {
let (tensor, indices) = tch::Tensor::max_pool1d_with_indices(
&x.tensor,
kernel_size as i64,
stride as i64,
padding as i64,
1,
dilation as i64,
false,
);
@ -208,13 +210,14 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> TchTensor<E, 4> {
let tensor = tch::Tensor::max_pool2d(
&x.tensor,
[kernel_size[0] as i64, kernel_size[1] as i64],
[stride[0] as i64, stride[1] as i64],
[padding[0] as i64, padding[1] as i64],
[1, 1],
[dilation[0] as i64, dilation[1] as i64],
false,
);
@ -226,13 +229,14 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> MaxPool2dWithIndices<TchBackend<E>> {
let (tensor, indices) = tch::Tensor::max_pool2d_with_indices(
&x.tensor,
[kernel_size[0] as i64, kernel_size[1] as i64],
[stride[0] as i64, stride[1] as i64],
[padding[0] as i64, padding[1] as i64],
[1, 1],
[dilation[0] as i64, dilation[1] as i64],
false,
);
@ -244,6 +248,7 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
output_grad: TchTensor<E, 4>,
indices: TchTensor<i64, 4>,
) -> MaxPool2dBackward<TchBackend<E>> {
@ -253,7 +258,7 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
[kernel_size[0] as i64, kernel_size[1] as i64],
[stride[0] as i64, stride[1] as i64],
[padding[0] as i64, padding[1] as i64],
[1, 1],
[dilation[0] as i64, dilation[1] as i64],
false,
&indices.tensor,
);

View File

@ -90,11 +90,18 @@ pub fn max_pool1d<B>(
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
) -> Tensor<B, 3>
where
B: Backend,
{
Tensor::new(B::max_pool1d(x.primitive, kernel_size, stride, padding))
Tensor::new(B::max_pool1d(
x.primitive,
kernel_size,
stride,
padding,
dilation,
))
}
/// Applies a [2D max pooling](crate::ops::ModuleOps::max_pool2d).
@ -103,11 +110,18 @@ pub fn max_pool2d<B>(
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> Tensor<B, 4>
where
B: Backend,
{
Tensor::new(B::max_pool2d(x.primitive, kernel_size, stride, padding))
Tensor::new(B::max_pool2d(
x.primitive,
kernel_size,
stride,
padding,
dilation,
))
}
/// Applies a [2D avg pooling](crate::ops::ModuleOps::avg_pool2d).
@ -156,11 +170,12 @@ pub fn max_pool1d_with_indices<B>(
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
) -> (Tensor<B, 3>, Tensor<B, 3, Int>)
where
B: Backend,
{
let output = B::max_pool1d_with_indices(x.primitive, kernel_size, stride, padding);
let output = B::max_pool1d_with_indices(x.primitive, kernel_size, stride, padding, dilation);
(Tensor::new(output.output), Tensor::new(output.indices))
}
@ -171,11 +186,12 @@ pub fn max_pool2d_with_indices<B>(
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> (Tensor<B, 4>, Tensor<B, 4, Int>)
where
B: Backend,
{
let output = B::max_pool2d_with_indices(x.primitive, kernel_size, stride, padding);
let output = B::max_pool2d_with_indices(x.primitive, kernel_size, stride, padding, dilation);
(Tensor::new(output.output), Tensor::new(output.indices))
}
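Migration note for call sites: each functional entry point gains a trailing dilation argument, and passing 1 (or [1, 1]) reproduces the old behavior. A sketch with assumed tensor variables:

// x1: Tensor<B, 3>, x2: Tensor<B, 4>; unit dilation keeps the previous semantics.
let y1 = max_pool1d(x1, kernel_size, stride, padding, 1);
let y2 = max_pool2d(x2, [3, 3], [2, 2], [1, 1], [1, 1]);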

View File

@ -337,8 +337,9 @@ pub trait ModuleOps<B: Backend> {
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
) -> B::TensorPrimitive<3> {
pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding)
pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding, dilation)
}
/// One dimensional max pooling with indices.
@ -351,8 +352,9 @@ pub trait ModuleOps<B: Backend> {
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
) -> MaxPool1dWithIndices<B> {
pool::max_pool1d_with_indices_from_2d::<B>(x, kernel_size, stride, padding)
pool::max_pool1d_with_indices_from_2d::<B>(x, kernel_size, stride, padding, dilation)
}
/// Backward pass for the [max pooling 1d](ModuleOps::max_pool1d_with_indices) operation.
fn max_pool1d_with_indices_backward(
@ -360,6 +362,7 @@ pub trait ModuleOps<B: Backend> {
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
output_grad: B::TensorPrimitive<3>,
indices: B::IntTensorPrimitive<3>,
) -> MaxPool1dBackward<B> {
@ -368,6 +371,7 @@ pub trait ModuleOps<B: Backend> {
kernel_size,
stride,
padding,
dilation,
output_grad,
indices,
)
@ -383,6 +387,7 @@ pub trait ModuleOps<B: Backend> {
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> B::TensorPrimitive<4>;
/// Two dimensional max pooling with indices.
@ -395,6 +400,7 @@ pub trait ModuleOps<B: Backend> {
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> MaxPool2dWithIndices<B>;
/// Backward pass for the [max pooling 2d](ModuleOps::max_pool2d_with_indices) operation.
fn max_pool2d_with_indices_backward(
@ -402,6 +408,7 @@ pub trait ModuleOps<B: Backend> {
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
output_grad: B::TensorPrimitive<4>,
indices: B::IntTensorPrimitive<4>,
) -> MaxPool2dBackward<B>;

View File

@ -85,11 +85,18 @@ pub(crate) fn max_pool1d_from_2d<B: Backend>(
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
) -> B::TensorPrimitive<3> {
let [batch_size, channels, length] = B::shape(&x).dims;
let x = B::reshape(x, Shape::from([batch_size, channels, length, 1]));
let x = B::max_pool2d(x, [kernel_size, 1], [stride, 1], [padding, 0]);
let x = B::max_pool2d(
x,
[kernel_size, 1],
[stride, 1],
[padding, 0],
[dilation, 1],
);
let [batch_size, channels, length, _] = B::shape(&x).dims;
@ -101,11 +108,18 @@ pub(crate) fn max_pool1d_with_indices_from_2d<B: Backend>(
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
) -> MaxPool1dWithIndices<B> {
let [batch_size, channels, length] = B::shape(&x).dims;
let x = B::reshape(x, Shape::from([batch_size, channels, 1, length]));
let x = B::max_pool2d_with_indices(x, [1, kernel_size], [1, stride], [0, padding]);
let x = B::max_pool2d_with_indices(
x,
[1, kernel_size],
[1, stride],
[0, padding],
[1, dilation],
);
let [batch_size, channels, _, length] = B::shape(&x.output).dims;
let output = B::reshape(x.output, Shape::from([batch_size, channels, length]));
let indices = B::int_reshape(
@ -120,6 +134,7 @@ pub(crate) fn max_pool1d_with_indices_backward_from_2d<B: Backend>(
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
output_grad: B::TensorPrimitive<3>,
indices: B::IntTensorPrimitive<3>,
) -> MaxPool1dBackward<B> {
@ -138,6 +153,7 @@ pub(crate) fn max_pool1d_with_indices_backward_from_2d<B: Backend>(
[kernel_size, 1],
[stride, 1],
[padding, 0],
[dilation, 1],
grad_x,
indices,
)

View File

@ -8,11 +8,10 @@ mod tests {
#[test]
fn test_max_pool1d_simple() {
let batch_size = 2;
let channels_in = 2;
let kernel_size = 3;
let padding = 1;
let stride = 1;
let dilation = 1;
let x = TestTensor::from_floats([[
[0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221],
@ -23,56 +22,75 @@ mod tests {
[0.8148, 0.9490, 0.9490, 0.9490, 0.7890, 0.5689],
]]);
let output = max_pool1d(x, kernel_size, stride, padding);
let output = max_pool1d(x, kernel_size, stride, padding, dilation);
y.to_data().assert_approx_eq(&output.into_data(), 3);
}
#[test]
fn test_max_pool1d_different_padding_stride_kernel() {
let batch_size = 1;
let channels_in = 1;
let kernel_size = 3;
let padding = 1;
let stride = 2;
let dilation = 1;
let x = TestTensor::from_floats([[[0.6309, 0.6112, 0.6998, 0.4708]]]);
let y = TestTensor::from_floats([[[0.6309, 0.6998]]]);
let output = max_pool1d(x, kernel_size, stride, padding);
let output = max_pool1d(x, kernel_size, stride, padding, dilation);
y.to_data().assert_approx_eq(&output.into_data(), 3);
}
#[test]
fn test_max_pool1d_with_neg() {
let batch_size = 1;
let channels_in = 1;
let kernel_size = 3;
let padding = 1;
let stride = 1;
let dilation = 1;
let x = TestTensor::from_floats([[[-0.6309, -0.6112, -0.6998, -0.4708]]]);
let y = TestTensor::from_floats([[[-0.6112, -0.6112, -0.4708, -0.4708]]]);
let output = max_pool1d(x, kernel_size, stride, padding);
let output = max_pool1d(x, kernel_size, stride, padding, dilation);
y.to_data().assert_approx_eq(&output.into_data(), 3);
}
#[test]
fn test_max_pool1d_with_dilation() {
let kernel_size = 2;
let padding = 1;
let stride = 1;
let dilation = 2;
let x = TestTensor::from_floats([[
[0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221],
[0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689],
]]);
let y = TestTensor::from_floats([[
[0.5474, 0.9861, 0.5474, 0.4477, 0.8221, 0.3548],
[0.5474, 0.9490, 0.7890, 0.9490, 0.7890, 0.5537],
]]);
let output = max_pool1d(x, kernel_size, stride, padding, dilation);
y.to_data().assert_approx_eq(&output.into_data(), 3);
}
#[test]
fn test_max_pool1d_with_indices() {
let batch_size = 1;
let channels_in = 1;
let kernel_size = 2;
let padding = 1;
let stride = 1;
let dilation = 1;
let x = TestTensor::from_floats([[[0.2479, 0.6386, 0.3166, 0.5742]]]);
let indices = Data::<IntElem, 3>::from([[[0, 1, 1, 3, 3]]]);
let y = TestTensor::from_floats([[[0.2479, 0.6386, 0.6386, 0.5742, 0.5742]]]);
let (output, output_indices) = max_pool1d_with_indices(x, kernel_size, stride, padding);
let (output, output_indices) =
max_pool1d_with_indices(x, kernel_size, stride, padding, dilation);
y.to_data().assert_approx_eq(&output.into_data(), 3);
assert_eq!(indices.value, output_indices.into_data().value);
@ -80,17 +98,17 @@ mod tests {
#[test]
fn test_max_pool1d_complex() {
let batch_size = 1;
let channels_in = 1;
let kernel_size = 4;
let padding = 2;
let stride = 1;
let dilation = 1;
let x = TestTensor::from_floats([[[0.5388, 0.0676, 0.7122, 0.8316, 0.0653]]]);
let indices = Data::<IntElem, 3>::from([[[0, 2, 3, 3, 3, 3]]]);
let y = TestTensor::from_floats([[[0.5388, 0.7122, 0.8316, 0.8316, 0.8316, 0.8316]]]);
let (output, output_indices) = max_pool1d_with_indices(x, kernel_size, stride, padding);
let (output, output_indices) =
max_pool1d_with_indices(x, kernel_size, stride, padding, dilation);
y.to_data().assert_approx_eq(&output.into_data(), 3);
assert_eq!(indices.value, output_indices.into_data().value);
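A reference sketch of the index mapping these tests exercise: output position o reads inputs at o·stride + k·dilation − padding for k in 0..kernel_size, and padded positions never win since padding acts as negative infinity. This is not burn code, but it reproduces test_max_pool1d_with_dilation above:

// Reference dilated 1D max pool; out-of-range (padded) positions are skipped.
fn max_pool1d_ref(x: &[f32], kernel: usize, stride: usize, pad: usize, dilation: usize) -> Vec<f32> {
    let out_len = (x.len() + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1;
    (0..out_len)
        .map(|o| {
            (0..kernel)
                .filter_map(|k| (o * stride + k * dilation).checked_sub(pad).and_then(|i| x.get(i)))
                .fold(f32::NEG_INFINITY, |acc, &v| acc.max(v))
        })
        .collect()
}

Feeding it the first row of x above yields [0.5474, 0.9861, 0.5474, 0.4477, 0.8221, 0.3548], the expected first row of y.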

View File

@ -16,6 +16,8 @@ mod tests {
let padding_2 = 1;
let stride_1 = 1;
let stride_2 = 1;
let dilation_1 = 1;
let dilation_2 = 1;
let x = TestTensor::from_floats([
[
@ -99,6 +101,7 @@ mod tests {
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
[dilation_1, dilation_2],
);
y.to_data().assert_approx_eq(&output.into_data(), 3);
@ -114,6 +117,8 @@ mod tests {
let padding_2 = 0;
let stride_1 = 1;
let stride_2 = 2;
let dilation_1 = 1;
let dilation_2 = 1;
let x = TestTensor::from_floats([[[
[0.6309, 0.6112, 0.6998],
@ -137,6 +142,7 @@ mod tests {
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
[dilation_1, dilation_2],
);
y.to_data().assert_approx_eq(&output.into_data(), 3);
@ -152,6 +158,8 @@ mod tests {
let padding_2 = 1;
let stride_1 = 1;
let stride_2 = 1;
let dilation_1 = 1;
let dilation_2 = 1;
let x = TestTensor::from_floats([[[
[0.6309, 0.6112, 0.6998],
@ -176,12 +184,51 @@ mod tests {
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
[dilation_1, dilation_2],
);
y.to_data().assert_approx_eq(&output.into_data(), 3);
}
#[test]
fn test_max_pool2d_with_dilation() {
let batch_size = 1;
let channels_in = 1;
let kernel_size_1 = 2;
let kernel_size_2 = 2;
let padding_1 = 0;
let padding_2 = 0;
let stride_1 = 1;
let stride_2 = 1;
let dilation_1 = 2;
let dilation_2 = 2;
let x = TestTensor::from_floats([[[
[0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221],
[0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221],
[0.9540, 0.9540, 0.9540, 0.9490, 0.7890, 0.7111],
[0.9540, 0.9540, 0.9540, 0.8855, 0.8855, 0.8855],
[0.9540, 0.9540, 0.9540, 0.9432, 0.8855, 0.8855],
[0.5063, 0.9432, 0.9432, 0.9432, 0.8855, 0.8855],
]]]);
let y = TestTensor::from_floats([[[
[0.9861, 0.9861, 0.9540, 0.9490],
[0.9861, 0.9861, 0.9540, 0.9490],
[0.9540, 0.9540, 0.9540, 0.9490],
[0.9540, 0.9540, 0.9540, 0.9432],
]]]);
let output = max_pool2d(
x,
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
[dilation_1, dilation_2],
);
y.to_data().assert_approx_eq(&output.into_data(), 3);
}
#[test]
fn test_max_pool2d_with_indices() {
let batch_size = 1;
let channels_in = 1;
@ -191,6 +238,8 @@ mod tests {
let padding_2 = 1;
let stride_1 = 1;
let stride_2 = 1;
let dilation_1 = 1;
let dilation_2 = 1;
let x = TestTensor::from_floats([[[
[0.2479, 0.6386, 0.3166, 0.5742],
@ -218,6 +267,7 @@ mod tests {
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
[dilation_1, dilation_2],
);
y.to_data().assert_approx_eq(&output.into_data(), 3);
@ -234,6 +284,8 @@ mod tests {
let padding_2 = 1;
let stride_1 = 1;
let stride_2 = 2;
let dilation_1 = 1;
let dilation_2 = 1;
let x = TestTensor::from_floats([[[
[0.5388, 0.0676, 0.7122, 0.8316, 0.0653],
@ -263,6 +315,7 @@ mod tests {
[kernel_size_1, kernel_size_2],
[stride_1, stride_2],
[padding_1, padding_2],
[dilation_1, dilation_2],
);
y.to_data().assert_approx_eq(&output.into_data(), 3);

View File

@ -41,7 +41,8 @@ pub(crate) fn avg_pool2d<E: WgpuElement>(
) -> WgpuTensor<E, 4> {
const WORKGROUP: usize = 32;
let (info_buffer, output) = build_output_and_info_pool2d(&x, kernel_size, stride, padding);
let (info_buffer, output) =
build_output_and_info_pool2d(&x, kernel_size, stride, padding, [1, 1]);
let kernel = match count_include_pad {
true => x
.context
@ -77,7 +78,7 @@ pub(crate) fn avg_pool2d_backward<E: WgpuElement>(
.context
.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(x.context.clone(), x.shape.clone(), buffer);
let info_buffer = build_pool2d_info(&x, &grad, kernel_size, stride, padding);
let info_buffer = build_pool2d_info(&x, &grad, kernel_size, stride, padding, [1, 1]);
let kernel =
match count_include_pad {

View File

@ -9,14 +9,20 @@ pub fn build_output_and_info_pool2d<E: WgpuElement>(
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> (Arc<Buffer>, WgpuTensor<E, 4>) {
let [kernel_height, kernel_width] = kernel_size;
let [padding_height, padding_width] = padding;
let [stride_height, stride_width] = stride;
let [dilation_height, dilation_width] = dilation;
let [batch_size, channels, x_height, x_width] = x.shape.dims;
let out_height = ((x_height + 2 * padding_height - kernel_height) / stride_height) + 1;
let out_width = ((x_width + 2 * padding_width - kernel_width) / stride_width) + 1;
let out_height = ((x_height + 2 * padding_height - dilation_height * (kernel_height - 1) - 1)
/ stride_height)
+ 1;
let out_width = ((x_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1)
/ stride_width)
+ 1;
let shape_out = Shape::new([batch_size, channels, out_height, out_width]);
let num_elems = shape_out.num_elements();
@ -25,7 +31,7 @@ pub fn build_output_and_info_pool2d<E: WgpuElement>(
.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(x.context.clone(), shape_out, buffer);
let info_buffer = build_pool2d_info(x, &output, kernel_size, stride, padding);
let info_buffer = build_pool2d_info(x, &output, kernel_size, stride, padding, dilation);
(info_buffer, output)
}
@ -36,8 +42,9 @@ pub fn build_pool2d_info<E: WgpuElement>(
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> Arc<Buffer> {
let mut info: [u32; 22] = [0; 22];
let mut info: [u32; 24] = [0; 24];
info[0] = input.strides[0] as u32;
info[1] = input.strides[1] as u32;
info[2] = input.strides[2] as u32;
@ -62,6 +69,8 @@ pub fn build_pool2d_info<E: WgpuElement>(
info[19] = stride[1] as u32;
info[20] = padding[0] as u32;
info[21] = padding[1] as u32;
info[22] = dilation[0] as u32;
info[23] = dilation[1] as u32;
let info_buffer = input
.context

View File

@ -24,10 +24,12 @@ pub(crate) fn max_pool2d<E: WgpuElement>(
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> WgpuTensor<E, 4> {
const WORKGROUP: usize = 32;
let (info_buffer, output) = build_output_and_info_pool2d(&x, kernel_size, stride, padding);
let (info_buffer, output) =
build_output_and_info_pool2d(&x, kernel_size, stride, padding, dilation);
let kernel = x
.context
.compile_static::<KernelSettings<MaxPool2d, E, i32, WORKGROUP, WORKGROUP, 1>>();
@ -46,10 +48,12 @@ pub(crate) fn max_pool2d_with_indices<E: WgpuElement, I: WgpuElement>(
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> (WgpuTensor<E, 4>, WgpuTensor<I, 4>) {
const WORKGROUP: usize = 32;
let (info_buffer, output) = build_output_and_info_pool2d(&x, kernel_size, stride, padding);
let (info_buffer, output) =
build_output_and_info_pool2d(&x, kernel_size, stride, padding, dilation);
let num_elems = output.shape.num_elements();
let indices = WgpuTensor::new(
@ -79,6 +83,7 @@ pub(crate) fn max_pool2d_with_indices_backward<E: WgpuElement, I: WgpuElement>(
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> WgpuTensor<E, 4> {
const WORKGROUP: usize = 32;
@ -91,7 +96,7 @@ pub(crate) fn max_pool2d_with_indices_backward<E: WgpuElement, I: WgpuElement>(
.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(x.context.clone(), x.shape.clone(), buffer);
let info_buffer = build_pool2d_info(&x, &grad, kernel_size, stride, padding);
let info_buffer = build_pool2d_info(&x, &grad, kernel_size, stride, padding, dilation);
let kernel = x.context.compile_static::<KernelSettings<
MaxPool2dWithIndicesBackward,
@ -122,9 +127,10 @@ mod tests {
let kernel_size = [3, 3];
let stride = [2, 2];
let padding = [1, 1];
let dilation = [1, 1];
let pooled = module::max_pool2d(tensor, kernel_size, stride, padding);
let pooled_ref = module::max_pool2d(tensor_ref, kernel_size, stride, padding);
let pooled = module::max_pool2d(tensor, kernel_size, stride, padding, dilation);
let pooled_ref = module::max_pool2d(tensor_ref, kernel_size, stride, padding, dilation);
pooled
.into_data()
@ -138,11 +144,12 @@ mod tests {
let kernel_size = [3, 3];
let stride = [2, 2];
let padding = [1, 1];
let dilation = [1, 1];
let (pooled, indices) =
module::max_pool2d_with_indices(tensor, kernel_size, stride, padding);
module::max_pool2d_with_indices(tensor, kernel_size, stride, padding, dilation);
let (pooled_ref, indices_ref) =
module::max_pool2d_with_indices(tensor_ref, kernel_size, stride, padding);
module::max_pool2d_with_indices(tensor_ref, kernel_size, stride, padding, dilation);
pooled
.into_data()
@ -159,16 +166,23 @@ mod tests {
let kernel_size = [3, 3];
let stride = [2, 2];
let padding = [1, 1];
let dilation = [1, 1];
let (_, indices) =
module::max_pool2d_with_indices(tensor.clone(), kernel_size, stride, padding);
let (_, indices_ref) =
module::max_pool2d_with_indices(tensor_ref.clone(), kernel_size, stride, padding);
module::max_pool2d_with_indices(tensor.clone(), kernel_size, stride, padding, dilation);
let (_, indices_ref) = module::max_pool2d_with_indices(
tensor_ref.clone(),
kernel_size,
stride,
padding,
dilation,
);
let grad = TestBackend::max_pool2d_with_indices_backward(
tensor.into_primitive(),
kernel_size,
stride,
padding,
dilation,
grad_output.into_primitive(),
indices.into_primitive(),
)
@ -178,6 +192,7 @@ mod tests {
kernel_size,
stride,
padding,
dilation,
grad_output_ref.into_primitive(),
indices_ref.into_primitive(),
)

View File

@ -59,8 +59,9 @@ where
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> FloatTensor<Self, 4> {
kernel::pool::max_pool2d(x, kernel_size, stride, padding)
kernel::pool::max_pool2d(x, kernel_size, stride, padding, dilation)
}
fn max_pool2d_with_indices(
@ -68,9 +69,10 @@ where
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> MaxPool2dWithIndices<WgpuBackend<G, F, I>> {
let (output, indices) =
kernel::pool::max_pool2d_with_indices(x, kernel_size, stride, padding);
kernel::pool::max_pool2d_with_indices(x, kernel_size, stride, padding, dilation);
MaxPool2dWithIndices::new(output, indices)
}
@ -80,6 +82,7 @@ where
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
output_grad: FloatTensor<Self, 4>,
indices: IntTensor<Self, 4>,
) -> MaxPool2dBackward<WgpuBackend<G, F, I>> {
@ -90,6 +93,7 @@ where
kernel_size,
stride,
padding,
dilation,
))
}

View File

@ -8,7 +8,7 @@ var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(2)
var<storage, read> info: array<u32, 22>;
var<storage, read> info: array<u32, 24>;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@ -44,6 +44,8 @@ fn main(
let pool_stride_1 = info[19];
let padding_0 = info[20];
let padding_1 = info[21];
let dilation_0 = info[22];
let dilation_1 = info[23];
let b = id / output_stride_0 % output_shape_0;
let c = id / output_stride_1 % output_shape_1;
@ -53,7 +55,7 @@ fn main(
var max_val = -32767.0;
for (var kh = 0u; kh < kernel_size_0; kh++) {
let ih = oh * pool_stride_0 + kh;
let ih = oh * pool_stride_0 + kh * dilation_0;
// Padding
if ih < padding_0 || ih >= input_shape_2 + padding_0 {
@ -61,7 +63,7 @@ fn main(
}
for (var kw = 0u; kw < kernel_size_1; kw++) {
let iw = ow * pool_stride_1 + kw;
let iw = ow * pool_stride_1 + kw * dilation_1;
// Padding
if iw < padding_1 || iw >= input_shape_3 + padding_1 {

View File

@ -12,7 +12,7 @@ var<storage, read_write> indices: array<{{ int }}>;
@group(0)
@binding(3)
var<storage, read> info: array<u32, 22>;
var<storage, read> info: array<u32, 24>;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@ -48,6 +48,8 @@ fn main(
let pool_stride_1 = info[19];
let padding_0 = info[20];
let padding_1 = info[21];
let dilation_0 = info[22];
let dilation_1 = info[23];
let b = id / output_stride_0 % output_shape_0;
let c = id / output_stride_1 % output_shape_1;
@ -58,7 +60,7 @@ fn main(
var index = 0u;
for (var kh = 0u; kh < kernel_size_0; kh++) {
let ih = oh * pool_stride_0 + kh;
let ih = oh * pool_stride_0 + kh * dilation_0;
// Padding
if ih < padding_0 || ih >= input_shape_2 + padding_0 {
@ -66,7 +68,7 @@ fn main(
}
for (var kw = 0u; kw < kernel_size_1; kw++) {
let iw = ow * pool_stride_1 + kw;
let iw = ow * pool_stride_1 + kw * dilation_1;
// Padding
if iw < padding_1 || iw >= input_shape_3 + padding_1 {

View File

@ -13,7 +13,7 @@ var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(3)
var<storage, read> info: array<u32, 22>;
var<storage, read> info: array<u32, 24>;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@ -49,6 +49,8 @@ fn main(
let pool_stride_1 = info[19];
let padding_0 = info[20];
let padding_1 = info[21];
let dilation_0 = info[22];
let dilation_1 = info[23];
let b = id / input_stride_0 % input_shape_0;
let c = id / input_stride_1 % input_shape_1;
@ -56,8 +58,8 @@ fn main(
let iw = id / input_stride_3 % input_shape_3;
// The maximum number of overlapping filters that may contain the current index.
let kms_0 = i32(kernel_size_0) - i32(pool_stride_0);
let kms_1 = i32(kernel_size_1) - i32(pool_stride_1);
let kms_0 = i32(kernel_size_0 * dilation_0) - i32(pool_stride_0);
let kms_1 = i32(kernel_size_1 * dilation_1) - i32(pool_stride_1);
let oh_start_tmp = (i32(ih + padding_0) - kms_0) / i32(pool_stride_0);
let ow_start_tmp = (i32(iw + padding_1) - kms_1) / i32(pool_stride_1);
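The widened kms_* terms above bound the backward search: with dilation, a window spans up to kernel_size · dilation input rows (a conservative bound on the exact dilation·(kernel − 1) + 1 extent), so the earliest output row that might still cover a given input row moves accordingly. A sketch of the same bound, assuming the shader clamps oh_start_tmp to 0 afterwards as before:

// Not the shader itself: first candidate output row whose dilated window may cover input row ih.
fn oh_start(ih: i32, padding: i32, kernel: i32, dilation: i32, stride: i32) -> i32 {
    let kms = kernel * dilation - stride; // conservative dilated window extent minus stride
    ((ih + padding - kms) / stride).max(0)
}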