mirror of https://github.com/tracel-ai/burn.git

Dilation maxpool (#668)

parent dd5ea5251c
commit 2fefc82099
@@ -618,19 +618,36 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
         kernel_size: usize,
         stride: usize,
         padding: usize,
+        dilation: usize,
     ) -> ADTensor<B, 3> {
         match MaxPool1D.prepare([x.node], [x.graph]).statefull() {
             OpsKind::Tracked(prep) => {
-                let output =
-                    B::max_pool1d_with_indices(x.primitive.clone(), kernel_size, stride, padding);
+                let output = B::max_pool1d_with_indices(
+                    x.primitive.clone(),
+                    kernel_size,
+                    stride,
+                    padding,
+                    dilation,
+                );
                 prep.finish(
-                    (x.primitive, output.indices, kernel_size, stride, padding),
+                    (
+                        x.primitive,
+                        output.indices,
+                        kernel_size,
+                        stride,
+                        padding,
+                        dilation,
+                    ),
                     output.output,
                 )
             }
-            OpsKind::UnTracked(prep) => {
-                prep.finish(B::max_pool1d(x.primitive, kernel_size, stride, padding))
-            }
+            OpsKind::UnTracked(prep) => prep.finish(B::max_pool1d(
+                x.primitive,
+                kernel_size,
+                stride,
+                padding,
+                dilation,
+            )),
         }
     }
@ -639,11 +656,17 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
) -> MaxPool1dWithIndices<ADBackendDecorator<B>> {
|
||||
match MaxPool1D.prepare([x.node], [x.graph]).statefull() {
|
||||
OpsKind::Tracked(prep) => {
|
||||
let output =
|
||||
B::max_pool1d_with_indices(x.primitive.clone(), kernel_size, stride, padding);
|
||||
let output = B::max_pool1d_with_indices(
|
||||
x.primitive.clone(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
);
|
||||
|
||||
let output_tensor = prep.finish(
|
||||
(
|
||||
|
@ -652,6 +675,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
),
|
||||
output.output,
|
||||
);
|
||||
|
@ -659,7 +683,8 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
MaxPool1dWithIndices::new(output_tensor, output.indices)
|
||||
}
|
||||
OpsKind::UnTracked(prep) => {
|
||||
let output = B::max_pool1d_with_indices(x.primitive, kernel_size, stride, padding);
|
||||
let output =
|
||||
B::max_pool1d_with_indices(x.primitive, kernel_size, stride, padding, dilation);
|
||||
let output_tensor = prep.finish(output.output);
|
||||
|
||||
MaxPool1dWithIndices::new(output_tensor, output.indices)
|
||||
|
@ -672,6 +697,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
output_grad: ADTensor<B, 3>,
|
||||
indices: IntTensor<B, 3>,
|
||||
) -> MaxPool1dBackward<ADBackendDecorator<B>> {
|
||||
|
@ -680,6 +706,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
output_grad.primitive,
|
||||
indices,
|
||||
);
|
||||
|
@ -691,19 +718,36 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> ADTensor<B, 4> {
|
||||
match MaxPool2D.prepare([x.node], [x.graph]).statefull() {
|
||||
OpsKind::Tracked(prep) => {
|
||||
let output =
|
||||
B::max_pool2d_with_indices(x.primitive.clone(), kernel_size, stride, padding);
|
||||
let output = B::max_pool2d_with_indices(
|
||||
x.primitive.clone(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
);
|
||||
prep.finish(
|
||||
(x.primitive, output.indices, kernel_size, stride, padding),
|
||||
(
|
||||
x.primitive,
|
||||
output.indices,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
),
|
||||
output.output,
|
||||
)
|
||||
}
|
||||
OpsKind::UnTracked(prep) => {
|
||||
prep.finish(B::max_pool2d(x.primitive, kernel_size, stride, padding))
|
||||
}
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::max_pool2d(
|
||||
x.primitive,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -712,11 +756,17 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> MaxPool2dWithIndices<ADBackendDecorator<B>> {
|
||||
match MaxPool2D.prepare([x.node], [x.graph]).statefull() {
|
||||
OpsKind::Tracked(prep) => {
|
||||
let output =
|
||||
B::max_pool2d_with_indices(x.primitive.clone(), kernel_size, stride, padding);
|
||||
let output = B::max_pool2d_with_indices(
|
||||
x.primitive.clone(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
);
|
||||
|
||||
let output_tensor = prep.finish(
|
||||
(
|
||||
|
@ -725,6 +775,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
),
|
||||
output.output,
|
||||
);
|
||||
|
@ -732,7 +783,8 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
MaxPool2dWithIndices::new(output_tensor, output.indices)
|
||||
}
|
||||
OpsKind::UnTracked(prep) => {
|
||||
let output = B::max_pool2d_with_indices(x.primitive, kernel_size, stride, padding);
|
||||
let output =
|
||||
B::max_pool2d_with_indices(x.primitive, kernel_size, stride, padding, dilation);
|
||||
let output_tensor = prep.finish(output.output);
|
||||
|
||||
MaxPool2dWithIndices::new(output_tensor, output.indices)
|
||||
|
@ -745,6 +797,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
_kernel_size: [usize; 2],
|
||||
_stride: [usize; 2],
|
||||
_padding: [usize; 2],
|
||||
_dilation: [usize; 2],
|
||||
_output_grad: ADTensor<B, 4>,
|
||||
_indices: IntTensor<B, 4>,
|
||||
) -> MaxPool2dBackward<ADBackendDecorator<B>> {
|
||||
|
@@ -820,16 +873,30 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
 struct MaxPool1D;

 impl<B: Backend> Backward<B, 3, 1> for MaxPool1D {
-    type State = (B::TensorPrimitive<3>, IntTensor<B, 3>, usize, usize, usize);
+    type State = (
+        B::TensorPrimitive<3>,
+        IntTensor<B, 3>,
+        usize,
+        usize,
+        usize,
+        usize,
+    );

     fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
         let [node_parent] = ops.parents;
         let grad = grads.consume::<B, 3>(&ops.node);
-        let (x, indices, kernel_size, stride, padding) = ops.state;
+        let (x, indices, kernel_size, stride, padding, dilation) = ops.state;

         if let Some(node) = node_parent {
-            let grad =
-                B::max_pool1d_with_indices_backward(x, kernel_size, stride, padding, grad, indices);
+            let grad = B::max_pool1d_with_indices_backward(
+                x,
+                kernel_size,
+                stride,
+                padding,
+                dilation,
+                grad,
+                indices,
+            );

             grads.register::<B, 3>(node, grad.x_grad);
         }
@ -846,16 +913,24 @@ impl<B: Backend> Backward<B, 4, 1> for MaxPool2D {
|
|||
[usize; 2],
|
||||
[usize; 2],
|
||||
[usize; 2],
|
||||
[usize; 2],
|
||||
);
|
||||
|
||||
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
|
||||
let [node_parent] = ops.parents;
|
||||
let grad = grads.consume::<B, 4>(&ops.node);
|
||||
let (x, indices, kernel_size, stride, padding) = ops.state;
|
||||
let (x, indices, kernel_size, stride, padding, dilation) = ops.state;
|
||||
|
||||
if let Some(node) = node_parent {
|
||||
let grad =
|
||||
B::max_pool2d_with_indices_backward(x, kernel_size, stride, padding, grad, indices);
|
||||
let grad = B::max_pool2d_with_indices_backward(
|
||||
x,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
grad,
|
||||
indices,
|
||||
);
|
||||
|
||||
grads.register::<B, 4>(node, grad.x_grad);
|
||||
}
|
||||
|
|
|
@ -5,17 +5,44 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_max_pool1d_simple() {
|
||||
let batch_size = 1;
|
||||
let channels_in = 1;
|
||||
let kernel_size = 4;
|
||||
let padding = 0;
|
||||
let stride = 1;
|
||||
let dilation = 1;
|
||||
|
||||
let x = TestADTensor::from_floats([[[0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221]]])
|
||||
.require_grad();
|
||||
let x_grad_expected = TestADTensor::from_floats([[[1., 1., 0., 0., 0., 1.]]]);
|
||||
|
||||
let output = max_pool1d(x.clone(), kernel_size, stride, padding);
|
||||
let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation);
|
||||
let grads = output.backward();
|
||||
|
||||
// Asserts
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
x_grad_expected
|
||||
.to_data()
|
||||
.assert_approx_eq(&x_grad_actual.to_data(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool1d_with_dilation() {
|
||||
let kernel_size = 4;
|
||||
let padding = 0;
|
||||
let stride = 1;
|
||||
let dilation = 2;
|
||||
|
||||
let x = TestADTensor::from_floats([[[
|
||||
0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
|
||||
0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230,
|
||||
0.4610, 0.5365, 0.6880,
|
||||
]]])
|
||||
.require_grad();
|
||||
let x_grad_expected = TestADTensor::from_floats([[[
|
||||
0., 0., 1., 0., 0., 3., 0., 1., 2., 1., 0., 0., 2., 0., 0., 0., 4., 4., 0., 0., 0., 0.,
|
||||
0., 0., 1.,
|
||||
]]]);
|
||||
|
||||
let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation);
|
||||
let grads = output.backward();
|
||||
|
||||
// Asserts
|
||||
|
@ -27,11 +54,10 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_max_pool1d_complex() {
|
||||
let batch_size = 1;
|
||||
let channels_in = 1;
|
||||
let kernel_size = 4;
|
||||
let padding = 0;
|
||||
let stride = 1;
|
||||
let dilation = 1;
|
||||
|
||||
let x = TestADTensor::from_floats([[[
|
||||
0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
|
||||
|
@ -44,7 +70,7 @@ mod tests {
|
|||
1., 1., 1.,
|
||||
]]]);
|
||||
|
||||
let output = max_pool1d(x.clone(), kernel_size, stride, padding);
|
||||
let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation);
|
||||
let grads = output.backward();
|
||||
|
||||
// Asserts
|
||||
|
@ -56,11 +82,10 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_max_pool1d_complex_with_padding() {
|
||||
let batch_size = 1;
|
||||
let channels_in = 1;
|
||||
let kernel_size = 4;
|
||||
let padding = 2;
|
||||
let stride = 1;
|
||||
let dilation = 1;
|
||||
|
||||
let x = TestADTensor::from_floats([[[
|
||||
0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,
|
||||
|
@ -73,7 +98,7 @@ mod tests {
|
|||
1., 1., 3.,
|
||||
]]]);
|
||||
|
||||
let output = max_pool1d(x.clone(), kernel_size, stride, padding);
|
||||
let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation);
|
||||
let grads = output.backward();
|
||||
|
||||
// Asserts
|
||||
|
|
|
@ -11,6 +11,8 @@ mod tests {
|
|||
let padding_2 = 0;
|
||||
let stride_1 = 1;
|
||||
let stride_2 = 1;
|
||||
let dilation_1 = 1;
|
||||
let dilation_2 = 1;
|
||||
|
||||
let x = TestADTensor::from_floats([[[
|
||||
[0.2479, 0.6386, 0.3166, 0.5742],
|
||||
|
@ -31,6 +33,7 @@ mod tests {
|
|||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
[dilation_1, dilation_2],
|
||||
);
|
||||
let grads = output.backward();
|
||||
|
||||
|
@ -49,6 +52,8 @@ mod tests {
|
|||
let padding_2 = 1;
|
||||
let stride_1 = 1;
|
||||
let stride_2 = 1;
|
||||
let dilation_1 = 1;
|
||||
let dilation_2 = 1;
|
||||
|
||||
let x = TestADTensor::from_floats([[[
|
||||
[0.2479, 0.6386, 0.3166, 0.5742],
|
||||
|
@ -69,6 +74,48 @@ mod tests {
|
|||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
[dilation_1, dilation_2],
|
||||
);
|
||||
let grads = output.backward();
|
||||
|
||||
// Asserts
|
||||
let x_grad_actual = x.grad(&grads).unwrap();
|
||||
x_grad_expected
|
||||
.to_data()
|
||||
.assert_approx_eq(&x_grad_actual.to_data(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool2d_with_dilation() {
|
||||
let kernel_size_1 = 2;
|
||||
let kernel_size_2 = 2;
|
||||
let padding_1 = 1;
|
||||
let padding_2 = 1;
|
||||
let stride_1 = 1;
|
||||
let stride_2 = 1;
|
||||
let dilation_1 = 2;
|
||||
let dilation_2 = 2;
|
||||
|
||||
let x = TestADTensor::from_floats([[[
|
||||
[0.2479, 0.6386, 0.3166, 0.5742],
|
||||
[0.7065, 0.1940, 0.6305, 0.8959],
|
||||
[0.5416, 0.8602, 0.8129, 0.1662],
|
||||
[0.3358, 0.3059, 0.8293, 0.0990],
|
||||
]]])
|
||||
.require_grad();
|
||||
let x_grad_expected = TestADTensor::from_floats([[[
|
||||
[0., 0., 0., 0.],
|
||||
[1., 1., 1., 2.],
|
||||
[0., 4., 4., 0.],
|
||||
[0., 1., 2., 0.],
|
||||
]]]);
|
||||
|
||||
let output = max_pool2d(
|
||||
x.clone(),
|
||||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
[dilation_1, dilation_2],
|
||||
);
|
||||
let grads = output.backward();
|
||||
|
||||
|
@ -87,6 +134,8 @@ mod tests {
|
|||
let padding_2 = 1;
|
||||
let stride_1 = 1;
|
||||
let stride_2 = 2;
|
||||
let dilation_1 = 1;
|
||||
let dilation_2 = 1;
|
||||
|
||||
let x = TestADTensor::from_floats([[[
|
||||
[0.5388, 0.0676, 0.7122, 0.8316, 0.0653],
|
||||
|
@ -109,6 +158,7 @@ mod tests {
|
|||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
[dilation_1, dilation_2],
|
||||
);
|
||||
let grads = output.backward();
|
||||
|
||||
|
|
|
@ -56,6 +56,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
|
|||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> FloatTensor<Self, 4> {
|
||||
todo!()
|
||||
}
|
||||
|
@ -65,6 +66,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
|
|||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> MaxPool2dWithIndices<CandleBackend<F, I>> {
|
||||
todo!()
|
||||
}
|
||||
|
@ -74,6 +76,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<CandleBackend<F, I>>
|
|||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
output_grad: FloatTensor<Self, 4>,
|
||||
indices: IntTensor<Self, 4>,
|
||||
) -> MaxPool2dBackward<CandleBackend<F, I>> {
|
||||
|
|
|
@@ -18,6 +18,9 @@ pub struct MaxPool1dConfig {
     /// The padding configuration.
     #[config(default = "PaddingConfig1d::Valid")]
     pub padding: PaddingConfig1d,
+    /// The dilation.
+    #[config(default = "1")]
+    pub dilation: usize,
 }

 /// Applies a 1D max pooling over input tensors.
@@ -26,6 +29,7 @@ pub struct MaxPool1d {
     stride: usize,
     kernel_size: usize,
     padding: PaddingConfig1d,
+    dilation: usize,
 }

 impl MaxPool1dConfig {
@@ -35,6 +39,7 @@ impl MaxPool1dConfig {
             stride: self.stride,
             kernel_size: self.kernel_size,
             padding: self.padding.clone(),
+            dilation: self.dilation,
         }
     }
 }
@@ -52,6 +57,6 @@ impl MaxPool1d {
             .padding
             .calculate_padding_1d(length, self.kernel_size, self.stride);

-        max_pool1d(input, self.kernel_size, self.stride, padding)
+        max_pool1d(input, self.kernel_size, self.stride, padding, self.dilation)
     }
 }
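Aside (not part of the commit): the new `dilation` argument spaces out the samples inside each pooling window, so element `i` of a window is read at `start + i * dilation` and a kernel of size `k` spans `dilation * (k - 1) + 1` input positions. A minimal, self-contained Rust sketch of that semantics, with illustrative names only (the real kernels live in the backend crates touched below):

/// Reference 1D max pooling with dilation over a plain slice (no padding),
/// written only to illustrate the semantics of the new parameter.
fn max_pool1d_ref(x: &[f32], kernel: usize, stride: usize, dilation: usize) -> Vec<f32> {
    let effective = dilation * (kernel - 1) + 1;
    if x.len() < effective {
        return Vec::new();
    }
    let out_len = (x.len() - effective) / stride + 1;
    (0..out_len)
        .map(|o| {
            (0..kernel)
                .map(|k| x[o * stride + k * dilation])
                .fold(f32::NEG_INFINITY, f32::max)
        })
        .collect()
}

With `dilation = 1` this degenerates to the ordinary max pool, which is why the defaults of `1` / `[1, 1]` in the configs keep existing models unchanged.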
@@ -18,6 +18,9 @@ pub struct MaxPool2dConfig {
     /// The padding configuration.
     #[config(default = "PaddingConfig2d::Valid")]
     pub padding: PaddingConfig2d,
+    /// The dilation.
+    #[config(default = "[1, 1]")]
+    pub dilation: [usize; 2],
 }

 /// Applies a 2D max pooling over input tensors.
@@ -26,6 +29,7 @@ pub struct MaxPool2d {
     stride: [usize; 2],
     kernel_size: [usize; 2],
     padding: PaddingConfig2d,
+    dilation: [usize; 2],
 }

 impl MaxPool2dConfig {
@@ -35,6 +39,7 @@ impl MaxPool2dConfig {
             stride: self.strides,
             kernel_size: self.kernel_size,
             padding: self.padding.clone(),
+            dilation: self.dilation,
         }
     }
 }
@@ -52,6 +57,6 @@ impl MaxPool2d {
         self.padding
             .calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride);

-        max_pool2d(input, self.kernel_size, self.stride, padding)
+        max_pool2d(input, self.kernel_size, self.stride, padding, self.dilation)
     }
 }
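Aside (not part of the commit): a usage sketch of the updated 2D config. The builder chain mirrors the burn-import test further down in this diff; the import paths are assumed and may differ slightly.

use burn::nn::pool::{MaxPool2d, MaxPool2dConfig};
use burn::nn::PaddingConfig2d;

fn build_pool() -> MaxPool2d {
    MaxPool2dConfig::new([3, 3])
        .with_strides([2, 2])
        .with_padding(PaddingConfig2d::Valid)
        // New in this commit; the default [1, 1] reproduces the old behaviour.
        .with_dilation([2, 2])
        .init()
}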
Binary file not shown.
@@ -10,8 +10,7 @@ class Model(nn.Module):
     def __init__(self):
         super(Model, self).__init__()

-        # TODO support dilation=(3, 1) (see https://github.com/burn-rs/burn/issues/622)
-        self.maxpool2d1 = nn.MaxPool2d((4, 2), stride=(2, 1), padding=(2, 1))
+        self.maxpool2d1 = nn.MaxPool2d((4, 2), stride=(2, 1), padding=(2, 1), dilation=(1, 3))

     def forward(self, x):
         x = self.maxpool2d1(x)
@@ -305,9 +305,9 @@ mod tests {
         ]]]);
         let output = model.forward(input);
         let expected = Data::from([[[
-            [1.927, 1.927, 1.487, 0.901, 0.678, 0.678],
-            [1.927, 1.927, 1.487, 0.901, 0.803, 0.678],
-            [-0.217, 0.241, 0.241, 0.803, 0.803, -0.622],
+            [0.901, 1.927, 1.487, 0.901],
+            [0.901, 1.927, 1.487, 0.901],
+            [-0.396, 0.803, 0.241, -0.396],
         ]]]);

         assert_eq!(output.to_data(), expected);
|
@ -51,6 +51,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for MaxPool2dNode {
|
|||
let kernel_size = self.config.kernel_size.to_tokens();
|
||||
let strides = self.config.strides.to_tokens();
|
||||
let padding = self.config.padding.to_tokens();
|
||||
let dilation = self.config.dilation.to_tokens();
|
||||
|
||||
let init_line = quote! {
|
||||
init();
|
||||
|
@ -60,6 +61,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for MaxPool2dNode {
|
|||
let #name = MaxPool2dConfig::new(#kernel_size)
|
||||
.with_strides(#strides)
|
||||
.with_padding(#padding)
|
||||
.with_dilation(#dilation)
|
||||
.#init_line
|
||||
};
|
||||
|
||||
|
@ -111,7 +113,8 @@ mod tests {
|
|||
TensorType::new_float("output", 4),
|
||||
MaxPool2dConfig::new([3, 3])
|
||||
.with_strides([1, 1])
|
||||
.with_padding(PaddingConfig2d::Valid),
|
||||
.with_padding(PaddingConfig2d::Valid)
|
||||
.with_dilation([1, 1]),
|
||||
));
|
||||
|
||||
graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
|
||||
|
@ -137,6 +140,7 @@ mod tests {
|
|||
let max_pool2d = MaxPool2dConfig::new([3, 3])
|
||||
.with_strides([1, 1])
|
||||
.with_padding(PaddingConfig2d::Valid)
|
||||
.with_dilation([1, 1])
|
||||
.init();
|
||||
|
||||
Self {
|
||||
|
|
|
@@ -123,7 +123,7 @@ pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig {
     let mut kernel_shape = Vec::new();
     let mut strides = Vec::new();
     let mut pads = Vec::new();
-    let mut dilations = Vec::new();
+    let mut dilations = vec![1, 1];

     for (key, value) in curr.attrs.iter() {
         match key.as_str() {
@@ -135,15 +135,12 @@ pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig {
         }
     }

-    if !dilations.is_empty() && (dilations[0] != 1 || dilations[1] != 1) {
-        todo!("MaxPool2d: dilations are not supported. See https://github.com/burn-rs/burn/issues/622");
-    }
-
     let padding = padding_config(&pads);

     MaxPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize])
         .with_strides([strides[0] as usize, strides[1] as usize])
         .with_padding(padding)
+        .with_dilation([dilations[0] as usize, dilations[1] as usize])
 }

 /// Create a AvgPool2dConfig from the attributes of the node
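Aside (not part of the commit): with the `todo!` removed, an ONNX MaxPool node that carries a `dilations` attribute now flows straight into the config. For hypothetical attribute values `kernel_shape = [3, 3]`, `strides = [1, 1]`, `pads = [0, 0, 0, 0]`, `dilations = [2, 3]`, the function produces roughly:

let config = MaxPool2dConfig::new([3, 3])
    .with_strides([1, 1])
    .with_padding(PaddingConfig2d::Valid)
    .with_dilation([2, 3]);

A node without the attribute keeps the `vec![1, 1]` default, i.e. no dilation.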
@@ -11,15 +11,21 @@ pub(crate) fn max_pool2d<E: FloatNdArrayElement>(
     kernel_size: [usize; 2],
     stride: [usize; 2],
     padding: [usize; 2],
+    dilation: [usize; 2],
 ) -> NdArrayTensor<E, 4> {
     let [kernel_height, kernel_width] = kernel_size;
     let [padding_height, padding_width] = padding;
     let [stride_height, stride_width] = stride;
+    let [dilation_height, dilation_width] = dilation;
     let [batch_size, channels, x_height, x_width] = x.shape().dims;
     let inf = (-f32::INFINITY).elem::<E>();

-    let out_height = ((x_height + 2 * padding_height - kernel_height) / stride_height) + 1;
-    let out_width = ((x_width + 2 * padding_width - kernel_width) / stride_width) + 1;
+    let out_height = ((x_height + 2 * padding_height - dilation_height * (kernel_height - 1) - 1)
+        / stride_height)
+        + 1;
+    let out_width = ((x_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1)
+        / stride_width)
+        + 1;

     let x = apply_padding_4d(x, padding, inf).array;

@@ -38,10 +44,10 @@ pub(crate) fn max_pool2d<E: FloatNdArrayElement>(
     let mut max_val = inf;

     for kh in 0..kernel_height {
-        let ih = oh * stride_height + kh;
+        let ih = oh * stride_height + kh * dilation_height;

         for kw in 0..kernel_width {
-            let iw = ow * stride_width + kw;
+            let iw = ow * stride_width + kw * dilation_width;

             let val = x[[b, c, ih, iw]];
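Aside (not part of the commit): the new expression is the usual output-size formula for dilated pooling. A small helper with a worked check against the 1D dilation test added elsewhere in this commit (helper name is illustrative):

/// Output extent along one dimension, matching the expression above.
fn pooled_len(input: usize, kernel: usize, stride: usize, padding: usize, dilation: usize) -> usize {
    (input + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
}

// Worked example from the 1D dilation test in this commit:
// 25 values, kernel 4, stride 1, padding 0, dilation 2 -> 19 outputs,
// i.e. pooled_len(25, 4, 1, 0, 2) == 19. With dilation = 1 the formula
// reduces to the old (input + 2 * padding - kernel) / stride + 1.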
@ -65,15 +71,21 @@ pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement>(
|
|||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> (NdArrayTensor<E, 4>, NdArrayTensor<i64, 4>) {
|
||||
let [kernel_height, kernel_width] = kernel_size;
|
||||
let [padding_height, padding_width] = padding;
|
||||
let [stride_height, stride_width] = stride;
|
||||
let [dilation_height, dilation_width] = dilation;
|
||||
let [batch_size, channels, x_height, x_width] = x.shape().dims;
|
||||
let inf = (-f32::INFINITY).elem::<E>();
|
||||
|
||||
let out_height = ((x_height + 2 * padding_height - kernel_height) / stride_height) + 1;
|
||||
let out_width = ((x_width + 2 * padding_width - kernel_width) / stride_width) + 1;
|
||||
let out_height = ((x_height + 2 * padding_height - dilation_height * (kernel_height - 1) - 1)
|
||||
/ stride_height)
|
||||
+ 1;
|
||||
let out_width = ((x_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1)
|
||||
/ stride_width)
|
||||
+ 1;
|
||||
|
||||
let x = apply_padding_4d(x, padding, inf).array;
|
||||
|
||||
|
@ -97,10 +109,10 @@ pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement>(
|
|||
let mut index = 0;
|
||||
|
||||
for kh in 0..kernel_height {
|
||||
let ih = oh * stride_height + kh;
|
||||
let ih = oh * stride_height + kh * dilation_height;
|
||||
|
||||
for kw in 0..kernel_width {
|
||||
let iw = ow * stride_width + kw;
|
||||
let iw = ow * stride_width + kw * dilation_width;
|
||||
let val = x[[b, c, ih, iw]];
|
||||
|
||||
if val > max_val {
|
||||
|
@ -132,6 +144,7 @@ pub(crate) fn max_pool2d_backward<E: FloatNdArrayElement>(
|
|||
_kernel_size: [usize; 2],
|
||||
_stride: [usize; 2],
|
||||
_padding: [usize; 2],
|
||||
_dilation: [usize; 2],
|
||||
output_grad: NdArrayTensor<E, 4>,
|
||||
indices: NdArrayTensor<i64, 4>,
|
||||
) -> NdArrayTensor<E, 4> {
|
||||
|
|
|
@ -52,8 +52,9 @@ impl<E: FloatNdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E>
|
|||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> NdArrayTensor<E, 4> {
|
||||
max_pool2d(x, kernel_size, stride, padding)
|
||||
max_pool2d(x, kernel_size, stride, padding, dilation)
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indices(
|
||||
|
@ -61,8 +62,9 @@ impl<E: FloatNdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E>
|
|||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> MaxPool2dWithIndices<NdArrayBackend<E>> {
|
||||
let (output, indices) = max_pool2d_with_indices(x, kernel_size, stride, padding);
|
||||
let (output, indices) = max_pool2d_with_indices(x, kernel_size, stride, padding, dilation);
|
||||
|
||||
MaxPool2dWithIndices::new(output, indices)
|
||||
}
|
||||
|
@ -72,6 +74,7 @@ impl<E: FloatNdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E>
|
|||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
output_grad: NdArrayTensor<E, 4>,
|
||||
indices: NdArrayTensor<i64, 4>,
|
||||
) -> MaxPool2dBackward<NdArrayBackend<E>> {
|
||||
|
@ -80,6 +83,7 @@ impl<E: FloatNdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E>
|
|||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
output_grad,
|
||||
indices,
|
||||
))
|
||||
|
|
|
@ -172,13 +172,14 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
) -> TchTensor<E, 3> {
|
||||
let tensor = tch::Tensor::max_pool1d(
|
||||
&x.tensor,
|
||||
kernel_size as i64,
|
||||
stride as i64,
|
||||
padding as i64,
|
||||
1,
|
||||
dilation as i64,
|
||||
false,
|
||||
);
|
||||
|
||||
|
@ -190,13 +191,14 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
) -> MaxPool1dWithIndices<TchBackend<E>> {
|
||||
let (tensor, indices) = tch::Tensor::max_pool1d_with_indices(
|
||||
&x.tensor,
|
||||
kernel_size as i64,
|
||||
stride as i64,
|
||||
padding as i64,
|
||||
1,
|
||||
dilation as i64,
|
||||
false,
|
||||
);
|
||||
|
||||
|
@ -208,13 +210,14 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> TchTensor<E, 4> {
|
||||
let tensor = tch::Tensor::max_pool2d(
|
||||
&x.tensor,
|
||||
[kernel_size[0] as i64, kernel_size[1] as i64],
|
||||
[stride[0] as i64, stride[1] as i64],
|
||||
[padding[0] as i64, padding[1] as i64],
|
||||
[1, 1],
|
||||
[dilation[0] as i64, dilation[1] as i64],
|
||||
false,
|
||||
);
|
||||
|
||||
|
@ -226,13 +229,14 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> MaxPool2dWithIndices<TchBackend<E>> {
|
||||
let (tensor, indices) = tch::Tensor::max_pool2d_with_indices(
|
||||
&x.tensor,
|
||||
[kernel_size[0] as i64, kernel_size[1] as i64],
|
||||
[stride[0] as i64, stride[1] as i64],
|
||||
[padding[0] as i64, padding[1] as i64],
|
||||
[1, 1],
|
||||
[dilation[0] as i64, dilation[1] as i64],
|
||||
false,
|
||||
);
|
||||
|
||||
|
@ -244,6 +248,7 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
output_grad: TchTensor<E, 4>,
|
||||
indices: TchTensor<i64, 4>,
|
||||
) -> MaxPool2dBackward<TchBackend<E>> {
|
||||
|
@ -253,7 +258,7 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
[kernel_size[0] as i64, kernel_size[1] as i64],
|
||||
[stride[0] as i64, stride[1] as i64],
|
||||
[padding[0] as i64, padding[1] as i64],
|
||||
[1, 1],
|
||||
[dilation[0] as i64, dilation[1] as i64],
|
||||
false,
|
||||
&indices.tensor,
|
||||
);
|
||||
|
|
|
@ -90,11 +90,18 @@ pub fn max_pool1d<B>(
|
|||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
) -> Tensor<B, 3>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
Tensor::new(B::max_pool1d(x.primitive, kernel_size, stride, padding))
|
||||
Tensor::new(B::max_pool1d(
|
||||
x.primitive,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
))
|
||||
}
|
||||
|
||||
/// Applies a [2D max pooling](crate::ops::ModuleOps::max_pool2d).
|
||||
|
@ -103,11 +110,18 @@ pub fn max_pool2d<B>(
|
|||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> Tensor<B, 4>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
Tensor::new(B::max_pool2d(x.primitive, kernel_size, stride, padding))
|
||||
Tensor::new(B::max_pool2d(
|
||||
x.primitive,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
))
|
||||
}
|
||||
|
||||
/// Applies a [2D avg pooling](crate::ops::ModuleOps::avg_pool2d).
|
||||
|
@ -156,11 +170,12 @@ pub fn max_pool1d_with_indices<B>(
|
|||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
) -> (Tensor<B, 3>, Tensor<B, 3, Int>)
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
let output = B::max_pool1d_with_indices(x.primitive, kernel_size, stride, padding);
|
||||
let output = B::max_pool1d_with_indices(x.primitive, kernel_size, stride, padding, dilation);
|
||||
|
||||
(Tensor::new(output.output), Tensor::new(output.indices))
|
||||
}
|
||||
|
@ -171,11 +186,12 @@ pub fn max_pool2d_with_indices<B>(
|
|||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> (Tensor<B, 4>, Tensor<B, 4, Int>)
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
let output = B::max_pool2d_with_indices(x.primitive, kernel_size, stride, padding);
|
||||
let output = B::max_pool2d_with_indices(x.primitive, kernel_size, stride, padding, dilation);
|
||||
|
||||
(Tensor::new(output.output), Tensor::new(output.indices))
|
||||
}
|
||||
|
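Aside (not part of the commit): a hedged sketch of the updated tensor-level call, assuming the usual `burn::tensor` re-exports; the fifth argument is the new dilation.

use burn::tensor::{backend::Backend, module::max_pool2d, Tensor};

fn corners_pool<B: Backend>(x: Tensor<B, 4>) -> Tensor<B, 4> {
    // A 2x2 kernel with dilation [2, 2] looks at the four corners of a
    // 3x3 neighbourhood instead of four adjacent elements.
    max_pool2d(x, [2, 2], [1, 1], [0, 0], [2, 2])
}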
|
|
@ -337,8 +337,9 @@ pub trait ModuleOps<B: Backend> {
|
|||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
) -> B::TensorPrimitive<3> {
|
||||
pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding)
|
||||
pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding, dilation)
|
||||
}
|
||||
|
||||
/// One dimensional max pooling with indices.
|
||||
|
@ -351,8 +352,9 @@ pub trait ModuleOps<B: Backend> {
|
|||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
) -> MaxPool1dWithIndices<B> {
|
||||
pool::max_pool1d_with_indices_from_2d::<B>(x, kernel_size, stride, padding)
|
||||
pool::max_pool1d_with_indices_from_2d::<B>(x, kernel_size, stride, padding, dilation)
|
||||
}
|
||||
/// Backward pass for the [max pooling 1d](ModuleOps::max_pool1d_with_indices) operation.
|
||||
fn max_pool1d_with_indices_backward(
|
||||
|
@ -360,6 +362,7 @@ pub trait ModuleOps<B: Backend> {
|
|||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
output_grad: B::TensorPrimitive<3>,
|
||||
indices: B::IntTensorPrimitive<3>,
|
||||
) -> MaxPool1dBackward<B> {
|
||||
|
@ -368,6 +371,7 @@ pub trait ModuleOps<B: Backend> {
|
|||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
output_grad,
|
||||
indices,
|
||||
)
|
||||
|
@ -383,6 +387,7 @@ pub trait ModuleOps<B: Backend> {
|
|||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> B::TensorPrimitive<4>;
|
||||
|
||||
/// Two dimensional max pooling with indices.
|
||||
|
@ -395,6 +400,7 @@ pub trait ModuleOps<B: Backend> {
|
|||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> MaxPool2dWithIndices<B>;
|
||||
/// Backward pass for the [max pooling 2d](ModuleOps::max_pool2d_with_indices) operation.
|
||||
fn max_pool2d_with_indices_backward(
|
||||
|
@ -402,6 +408,7 @@ pub trait ModuleOps<B: Backend> {
|
|||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
output_grad: B::TensorPrimitive<4>,
|
||||
indices: B::IntTensorPrimitive<4>,
|
||||
) -> MaxPool2dBackward<B>;
|
||||
|
|
|
@ -85,11 +85,18 @@ pub(crate) fn max_pool1d_from_2d<B: Backend>(
|
|||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
) -> B::TensorPrimitive<3> {
|
||||
let [batch_size, channels, length] = B::shape(&x).dims;
|
||||
|
||||
let x = B::reshape(x, Shape::from([batch_size, channels, length, 1]));
|
||||
let x = B::max_pool2d(x, [kernel_size, 1], [stride, 1], [padding, 0]);
|
||||
let x = B::max_pool2d(
|
||||
x,
|
||||
[kernel_size, 1],
|
||||
[stride, 1],
|
||||
[padding, 0],
|
||||
[dilation, 1],
|
||||
);
|
||||
|
||||
let [batch_size, channels, length, _] = B::shape(&x).dims;
|
||||
|
||||
|
@ -101,11 +108,18 @@ pub(crate) fn max_pool1d_with_indices_from_2d<B: Backend>(
|
|||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
) -> MaxPool1dWithIndices<B> {
|
||||
let [batch_size, channels, length] = B::shape(&x).dims;
|
||||
|
||||
let x = B::reshape(x, Shape::from([batch_size, channels, 1, length]));
|
||||
let x = B::max_pool2d_with_indices(x, [1, kernel_size], [1, stride], [0, padding]);
|
||||
let x = B::max_pool2d_with_indices(
|
||||
x,
|
||||
[1, kernel_size],
|
||||
[1, stride],
|
||||
[0, padding],
|
||||
[1, dilation],
|
||||
);
|
||||
let [batch_size, channels, _, length] = B::shape(&x.output).dims;
|
||||
let output = B::reshape(x.output, Shape::from([batch_size, channels, length]));
|
||||
let indices = B::int_reshape(
|
||||
|
@ -120,6 +134,7 @@ pub(crate) fn max_pool1d_with_indices_backward_from_2d<B: Backend>(
|
|||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
output_grad: B::TensorPrimitive<3>,
|
||||
indices: B::IntTensorPrimitive<3>,
|
||||
) -> MaxPool1dBackward<B> {
|
||||
|
@ -138,6 +153,7 @@ pub(crate) fn max_pool1d_with_indices_backward_from_2d<B: Backend>(
|
|||
[kernel_size, 1],
|
||||
[stride, 1],
|
||||
[padding, 0],
|
||||
[dilation, 1],
|
||||
grad_x,
|
||||
indices,
|
||||
)
|
||||
|
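Aside (not part of the commit): the 1D fallbacks above reuse the 2D kernels with different axis orientations, so the dilation has to land in the slot that matches the pooled axis:

// max_pool1d_from_2d reshapes to [N, C, L, 1] and forwards
//   [kernel, 1], [stride, 1], [padding, 0], [dilation, 1];
// max_pool1d_with_indices_from_2d reshapes to [N, C, 1, L] and forwards
//   [1, kernel], [1, stride], [0, padding], [1, dilation].
// The 1s keep the dummy axis un-pooled; only the slot aligned with the
// original length dimension carries the real parameters.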
|
|
@ -8,11 +8,10 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_max_pool1d_simple() {
|
||||
let batch_size = 2;
|
||||
let channels_in = 2;
|
||||
let kernel_size = 3;
|
||||
let padding = 1;
|
||||
let stride = 1;
|
||||
let dilation = 1;
|
||||
|
||||
let x = TestTensor::from_floats([[
|
||||
[0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221],
|
||||
|
@ -23,56 +22,75 @@ mod tests {
|
|||
[0.8148, 0.9490, 0.9490, 0.9490, 0.7890, 0.5689],
|
||||
]]);
|
||||
|
||||
let output = max_pool1d(x, kernel_size, stride, padding);
|
||||
let output = max_pool1d(x, kernel_size, stride, padding, dilation);
|
||||
|
||||
y.to_data().assert_approx_eq(&output.into_data(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool1d_different_padding_stride_kernel() {
|
||||
let batch_size = 1;
|
||||
let channels_in = 1;
|
||||
let kernel_size = 3;
|
||||
let padding = 1;
|
||||
let stride = 2;
|
||||
let dilation = 1;
|
||||
|
||||
let x = TestTensor::from_floats([[[0.6309, 0.6112, 0.6998, 0.4708]]]);
|
||||
let y = TestTensor::from_floats([[[0.6309, 0.6998]]]);
|
||||
|
||||
let output = max_pool1d(x, kernel_size, stride, padding);
|
||||
let output = max_pool1d(x, kernel_size, stride, padding, dilation);
|
||||
|
||||
y.to_data().assert_approx_eq(&output.into_data(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool1d_with_neg() {
|
||||
let batch_size = 1;
|
||||
let channels_in = 1;
|
||||
let kernel_size = 3;
|
||||
let padding = 1;
|
||||
let stride = 1;
|
||||
let dilation = 1;
|
||||
|
||||
let x = TestTensor::from_floats([[[-0.6309, -0.6112, -0.6998, -0.4708]]]);
|
||||
let y = TestTensor::from_floats([[[-0.6112, -0.6112, -0.4708, -0.4708]]]);
|
||||
|
||||
let output = max_pool1d(x, kernel_size, stride, padding);
|
||||
let output = max_pool1d(x, kernel_size, stride, padding, dilation);
|
||||
|
||||
y.to_data().assert_approx_eq(&output.into_data(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool1d_with_dilation() {
|
||||
let kernel_size = 2;
|
||||
let padding = 1;
|
||||
let stride = 1;
|
||||
let dilation = 2;
|
||||
|
||||
let x = TestTensor::from_floats([[
|
||||
[0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221],
|
||||
[0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689],
|
||||
]]);
|
||||
let y = TestTensor::from_floats([[
|
||||
[0.5474, 0.9861, 0.5474, 0.4477, 0.8221, 0.3548],
|
||||
[0.5474, 0.9490, 0.7890, 0.9490, 0.7890, 0.5537],
|
||||
]]);
|
||||
|
||||
let output = max_pool1d(x, kernel_size, stride, padding, dilation);
|
||||
|
||||
y.to_data().assert_approx_eq(&output.into_data(), 3);
|
||||
}
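Aside (not part of the commit): the output width in this test follows the same size formula the kernels use, (6 + 2*1 - 2*(2 - 1) - 1) / 1 + 1 = 6, matching the six columns of `y`:

assert_eq!((6 + 2 * 1 - 2 * (2 - 1) - 1) / 1 + 1, 6);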
|
||||
|
||||
#[test]
|
||||
fn test_max_pool1d_with_indices() {
|
||||
let batch_size = 1;
|
||||
let channels_in = 1;
|
||||
let kernel_size = 2;
|
||||
let padding = 1;
|
||||
let stride = 1;
|
||||
let dilation = 1;
|
||||
|
||||
let x = TestTensor::from_floats([[[0.2479, 0.6386, 0.3166, 0.5742]]]);
|
||||
let indices = Data::<IntElem, 3>::from([[[0, 1, 1, 3, 3]]]);
|
||||
let y = TestTensor::from_floats([[[0.2479, 0.6386, 0.6386, 0.5742, 0.5742]]]);
|
||||
|
||||
let (output, output_indices) = max_pool1d_with_indices(x, kernel_size, stride, padding);
|
||||
let (output, output_indices) =
|
||||
max_pool1d_with_indices(x, kernel_size, stride, padding, dilation);
|
||||
|
||||
y.to_data().assert_approx_eq(&output.into_data(), 3);
|
||||
assert_eq!(indices.value, output_indices.into_data().value);
|
||||
|
@ -80,17 +98,17 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_max_pool1d_complex() {
|
||||
let batch_size = 1;
|
||||
let channels_in = 1;
|
||||
let kernel_size = 4;
|
||||
let padding = 2;
|
||||
let stride = 1;
|
||||
let dilation = 1;
|
||||
|
||||
let x = TestTensor::from_floats([[[0.5388, 0.0676, 0.7122, 0.8316, 0.0653]]]);
|
||||
let indices = Data::<IntElem, 3>::from([[[0, 2, 3, 3, 3, 3]]]);
|
||||
let y = TestTensor::from_floats([[[0.5388, 0.7122, 0.8316, 0.8316, 0.8316, 0.8316]]]);
|
||||
|
||||
let (output, output_indices) = max_pool1d_with_indices(x, kernel_size, stride, padding);
|
||||
let (output, output_indices) =
|
||||
max_pool1d_with_indices(x, kernel_size, stride, padding, dilation);
|
||||
|
||||
y.to_data().assert_approx_eq(&output.into_data(), 3);
|
||||
assert_eq!(indices.value, output_indices.into_data().value);
|
||||
|
|
|
@ -16,6 +16,8 @@ mod tests {
|
|||
let padding_2 = 1;
|
||||
let stride_1 = 1;
|
||||
let stride_2 = 1;
|
||||
let dilation_1 = 1;
|
||||
let dilation_2 = 1;
|
||||
|
||||
let x = TestTensor::from_floats([
|
||||
[
|
||||
|
@ -99,6 +101,7 @@ mod tests {
|
|||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
[dilation_1, dilation_2],
|
||||
);
|
||||
|
||||
y.to_data().assert_approx_eq(&output.into_data(), 3);
|
||||
|
@ -114,6 +117,8 @@ mod tests {
|
|||
let padding_2 = 0;
|
||||
let stride_1 = 1;
|
||||
let stride_2 = 2;
|
||||
let dilation_1 = 1;
|
||||
let dilation_2 = 1;
|
||||
|
||||
let x = TestTensor::from_floats([[[
|
||||
[0.6309, 0.6112, 0.6998],
|
||||
|
@ -137,6 +142,7 @@ mod tests {
|
|||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
[dilation_1, dilation_2],
|
||||
);
|
||||
|
||||
y.to_data().assert_approx_eq(&output.into_data(), 3);
|
||||
|
@ -152,6 +158,8 @@ mod tests {
|
|||
let padding_2 = 1;
|
||||
let stride_1 = 1;
|
||||
let stride_2 = 1;
|
||||
let dilation_1 = 1;
|
||||
let dilation_2 = 1;
|
||||
|
||||
let x = TestTensor::from_floats([[[
|
||||
[0.6309, 0.6112, 0.6998],
|
||||
|
@ -176,12 +184,51 @@ mod tests {
|
|||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
[dilation_1, dilation_2],
|
||||
);
|
||||
|
||||
y.to_data().assert_approx_eq(&output.into_data(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool2d_with_dilation() {
|
||||
let batch_size = 1;
|
||||
let channels_in = 1;
|
||||
let kernel_size_1 = 2;
|
||||
let kernel_size_2 = 2;
|
||||
let padding_1 = 0;
|
||||
let padding_2 = 0;
|
||||
let stride_1 = 1;
|
||||
let stride_2 = 1;
|
||||
let dilation_1 = 2;
|
||||
let dilation_2 = 2;
|
||||
|
||||
let x = TestTensor::from_floats([[[
|
||||
[0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221],
|
||||
[0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221],
|
||||
[0.9540, 0.9540, 0.9540, 0.9490, 0.7890, 0.7111],
|
||||
[0.9540, 0.9540, 0.9540, 0.8855, 0.8855, 0.8855],
|
||||
[0.9540, 0.9540, 0.9540, 0.9432, 0.8855, 0.8855],
|
||||
[0.5063, 0.9432, 0.9432, 0.9432, 0.8855, 0.8855],
|
||||
]]]);
|
||||
let y = TestTensor::from_floats([[[
|
||||
[0.9861, 0.9861, 0.9540, 0.9490],
|
||||
[0.9861, 0.9861, 0.9540, 0.9490],
|
||||
[0.9540, 0.9540, 0.9540, 0.9490],
|
||||
[0.9540, 0.9540, 0.9540, 0.9432],
|
||||
]]]);
|
||||
|
||||
let output = max_pool2d(
|
||||
x,
|
||||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
[dilation_1, dilation_2],
|
||||
);
|
||||
|
||||
y.to_data().assert_approx_eq(&output.into_data(), 3);
|
||||
}
|
||||
|
||||
fn test_max_pool2d_with_indices() {
|
||||
let batch_size = 1;
|
||||
let channels_in = 1;
|
||||
|
@ -191,6 +238,8 @@ mod tests {
|
|||
let padding_2 = 1;
|
||||
let stride_1 = 1;
|
||||
let stride_2 = 1;
|
||||
let dilation_1 = 1;
|
||||
let dilation_2 = 1;
|
||||
|
||||
let x = TestTensor::from_floats([[[
|
||||
[0.2479, 0.6386, 0.3166, 0.5742],
|
||||
|
@ -218,6 +267,7 @@ mod tests {
|
|||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
[dilation_1, dilation_2],
|
||||
);
|
||||
|
||||
y.to_data().assert_approx_eq(&output.into_data(), 3);
|
||||
|
@ -234,6 +284,8 @@ mod tests {
|
|||
let padding_2 = 1;
|
||||
let stride_1 = 1;
|
||||
let stride_2 = 2;
|
||||
let dilation_1 = 1;
|
||||
let dilation_2 = 1;
|
||||
|
||||
let x = TestTensor::from_floats([[[
|
||||
[0.5388, 0.0676, 0.7122, 0.8316, 0.0653],
|
||||
|
@ -263,6 +315,7 @@ mod tests {
|
|||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
[padding_1, padding_2],
|
||||
[dilation_1, dilation_2],
|
||||
);
|
||||
|
||||
y.to_data().assert_approx_eq(&output.into_data(), 3);
|
||||
|
|
|
@ -41,7 +41,8 @@ pub(crate) fn avg_pool2d<E: WgpuElement>(
|
|||
) -> WgpuTensor<E, 4> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let (info_buffer, output) = build_output_and_info_pool2d(&x, kernel_size, stride, padding);
|
||||
let (info_buffer, output) =
|
||||
build_output_and_info_pool2d(&x, kernel_size, stride, padding, [1, 1]);
|
||||
let kernel = match count_include_pad {
|
||||
true => x
|
||||
.context
|
||||
|
@ -77,7 +78,7 @@ pub(crate) fn avg_pool2d_backward<E: WgpuElement>(
|
|||
.context
|
||||
.create_buffer(num_elems * core::mem::size_of::<E>());
|
||||
let output = WgpuTensor::new(x.context.clone(), x.shape.clone(), buffer);
|
||||
let info_buffer = build_pool2d_info(&x, &grad, kernel_size, stride, padding);
|
||||
let info_buffer = build_pool2d_info(&x, &grad, kernel_size, stride, padding, [1, 1]);
|
||||
|
||||
let kernel =
|
||||
match count_include_pad {
|
||||
|
|
|
@ -9,14 +9,20 @@ pub fn build_output_and_info_pool2d<E: WgpuElement>(
|
|||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> (Arc<Buffer>, WgpuTensor<E, 4>) {
|
||||
let [kernel_height, kernel_width] = kernel_size;
|
||||
let [padding_height, padding_width] = padding;
|
||||
let [stride_height, stride_width] = stride;
|
||||
let [dilation_height, dilation_width] = dilation;
|
||||
let [batch_size, channels, x_height, x_width] = x.shape.dims;
|
||||
|
||||
let out_height = ((x_height + 2 * padding_height - kernel_height) / stride_height) + 1;
|
||||
let out_width = ((x_width + 2 * padding_width - kernel_width) / stride_width) + 1;
|
||||
let out_height = ((x_height + 2 * padding_height - dilation_height * (kernel_height - 1) - 1)
|
||||
/ stride_height)
|
||||
+ 1;
|
||||
let out_width = ((x_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1)
|
||||
/ stride_width)
|
||||
+ 1;
|
||||
let shape_out = Shape::new([batch_size, channels, out_height, out_width]);
|
||||
let num_elems = shape_out.num_elements();
|
||||
|
||||
|
@ -25,7 +31,7 @@ pub fn build_output_and_info_pool2d<E: WgpuElement>(
|
|||
.create_buffer(num_elems * core::mem::size_of::<E>());
|
||||
let output = WgpuTensor::new(x.context.clone(), shape_out, buffer);
|
||||
|
||||
let info_buffer = build_pool2d_info(x, &output, kernel_size, stride, padding);
|
||||
let info_buffer = build_pool2d_info(x, &output, kernel_size, stride, padding, dilation);
|
||||
|
||||
(info_buffer, output)
|
||||
}
|
||||
|
@ -36,8 +42,9 @@ pub fn build_pool2d_info<E: WgpuElement>(
|
|||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> Arc<Buffer> {
|
||||
let mut info: [u32; 22] = [0; 22];
|
||||
let mut info: [u32; 24] = [0; 24];
|
||||
info[0] = input.strides[0] as u32;
|
||||
info[1] = input.strides[1] as u32;
|
||||
info[2] = input.strides[2] as u32;
|
||||
|
@ -62,6 +69,8 @@ pub fn build_pool2d_info<E: WgpuElement>(
|
|||
info[19] = stride[1] as u32;
|
||||
info[20] = padding[0] as u32;
|
||||
info[21] = padding[1] as u32;
|
||||
info[22] = dilation[0] as u32;
|
||||
info[23] = dilation[1] as u32;
|
||||
|
||||
let info_buffer = input
|
||||
.context
|
||||
|
|
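Aside (not part of the commit): the pool info buffer grows from 22 to 24 `u32` words, and its tail now reads:

// info[19] = stride[1]
// info[20] = padding[0]
// info[21] = padding[1]
// info[22] = dilation[0]   // new
// info[23] = dilation[1]   // new
// This is why every pool WGSL template below declares `array<u32, 24>`.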
|
@ -24,10 +24,12 @@ pub(crate) fn max_pool2d<E: WgpuElement>(
|
|||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> WgpuTensor<E, 4> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let (info_buffer, output) = build_output_and_info_pool2d(&x, kernel_size, stride, padding);
|
||||
let (info_buffer, output) =
|
||||
build_output_and_info_pool2d(&x, kernel_size, stride, padding, dilation);
|
||||
let kernel = x
|
||||
.context
|
||||
.compile_static::<KernelSettings<MaxPool2d, E, i32, WORKGROUP, WORKGROUP, 1>>();
|
||||
|
@ -46,10 +48,12 @@ pub(crate) fn max_pool2d_with_indices<E: WgpuElement, I: WgpuElement>(
|
|||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> (WgpuTensor<E, 4>, WgpuTensor<I, 4>) {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let (info_buffer, output) = build_output_and_info_pool2d(&x, kernel_size, stride, padding);
|
||||
let (info_buffer, output) =
|
||||
build_output_and_info_pool2d(&x, kernel_size, stride, padding, dilation);
|
||||
let num_elems = output.shape.num_elements();
|
||||
|
||||
let indices = WgpuTensor::new(
|
||||
|
@ -79,6 +83,7 @@ pub(crate) fn max_pool2d_with_indices_backward<E: WgpuElement, I: WgpuElement>(
|
|||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> WgpuTensor<E, 4> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
|
@ -91,7 +96,7 @@ pub(crate) fn max_pool2d_with_indices_backward<E: WgpuElement, I: WgpuElement>(
|
|||
.create_buffer(num_elems * core::mem::size_of::<E>());
|
||||
let output = WgpuTensor::new(x.context.clone(), x.shape.clone(), buffer);
|
||||
|
||||
let info_buffer = build_pool2d_info(&x, &grad, kernel_size, stride, padding);
|
||||
let info_buffer = build_pool2d_info(&x, &grad, kernel_size, stride, padding, dilation);
|
||||
|
||||
let kernel = x.context.compile_static::<KernelSettings<
|
||||
MaxPool2dWithIndicesBackward,
|
||||
|
@ -122,9 +127,10 @@ mod tests {
|
|||
let kernel_size = [3, 3];
|
||||
let stride = [2, 2];
|
||||
let padding = [1, 1];
|
||||
let dilation = [1, 1];
|
||||
|
||||
let pooled = module::max_pool2d(tensor, kernel_size, stride, padding);
|
||||
let pooled_ref = module::max_pool2d(tensor_ref, kernel_size, stride, padding);
|
||||
let pooled = module::max_pool2d(tensor, kernel_size, stride, padding, dilation);
|
||||
let pooled_ref = module::max_pool2d(tensor_ref, kernel_size, stride, padding, dilation);
|
||||
|
||||
pooled
|
||||
.into_data()
|
||||
|
@ -138,11 +144,12 @@ mod tests {
|
|||
let kernel_size = [3, 3];
|
||||
let stride = [2, 2];
|
||||
let padding = [1, 1];
|
||||
let dilation = [1, 1];
|
||||
|
||||
let (pooled, indices) =
|
||||
module::max_pool2d_with_indices(tensor, kernel_size, stride, padding);
|
||||
module::max_pool2d_with_indices(tensor, kernel_size, stride, padding, dilation);
|
||||
let (pooled_ref, indices_ref) =
|
||||
module::max_pool2d_with_indices(tensor_ref, kernel_size, stride, padding);
|
||||
module::max_pool2d_with_indices(tensor_ref, kernel_size, stride, padding, dilation);
|
||||
|
||||
pooled
|
||||
.into_data()
|
||||
|
@ -159,16 +166,23 @@ mod tests {
|
|||
let kernel_size = [3, 3];
|
||||
let stride = [2, 2];
|
||||
let padding = [1, 1];
|
||||
let dilation = [1, 1];
|
||||
|
||||
let (_, indices) =
|
||||
module::max_pool2d_with_indices(tensor.clone(), kernel_size, stride, padding);
|
||||
let (_, indices_ref) =
|
||||
module::max_pool2d_with_indices(tensor_ref.clone(), kernel_size, stride, padding);
|
||||
module::max_pool2d_with_indices(tensor.clone(), kernel_size, stride, padding, dilation);
|
||||
let (_, indices_ref) = module::max_pool2d_with_indices(
|
||||
tensor_ref.clone(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
);
|
||||
let grad = TestBackend::max_pool2d_with_indices_backward(
|
||||
tensor.into_primitive(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
grad_output.into_primitive(),
|
||||
indices.into_primitive(),
|
||||
)
|
||||
|
@ -178,6 +192,7 @@ mod tests {
|
|||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
grad_output_ref.into_primitive(),
|
||||
indices_ref.into_primitive(),
|
||||
)
|
||||
|
|
|
@ -59,8 +59,9 @@ where
|
|||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> FloatTensor<Self, 4> {
|
||||
kernel::pool::max_pool2d(x, kernel_size, stride, padding)
|
||||
kernel::pool::max_pool2d(x, kernel_size, stride, padding, dilation)
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indices(
|
||||
|
@ -68,9 +69,10 @@ where
|
|||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> MaxPool2dWithIndices<WgpuBackend<G, F, I>> {
|
||||
let (output, indices) =
|
||||
kernel::pool::max_pool2d_with_indices(x, kernel_size, stride, padding);
|
||||
kernel::pool::max_pool2d_with_indices(x, kernel_size, stride, padding, dilation);
|
||||
|
||||
MaxPool2dWithIndices::new(output, indices)
|
||||
}
|
||||
|
@ -80,6 +82,7 @@ where
|
|||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
output_grad: FloatTensor<Self, 4>,
|
||||
indices: IntTensor<Self, 4>,
|
||||
) -> MaxPool2dBackward<WgpuBackend<G, F, I>> {
|
||||
|
@ -90,6 +93,7 @@ where
|
|||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
))
|
||||
}
|
||||
|
||||
|
|
|
@@ -8,7 +8,7 @@ var<storage, read_write> output: array<{{ elem }}>;

 @group(0)
 @binding(2)
-var<storage, read> info: array<u32, 22>;
+var<storage, read> info: array<u32, 24>;

 const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;

@@ -44,6 +44,8 @@ fn main(
     let pool_stride_1 = info[19];
     let padding_0 = info[20];
     let padding_1 = info[21];
+    let dilation_0 = info[22];
+    let dilation_1 = info[23];

     let b = id / output_stride_0 % output_shape_0;
     let c = id / output_stride_1 % output_shape_1;

@@ -53,7 +55,7 @@ fn main(
     var max_val = -32767.0;

     for (var kh = 0u; kh < kernel_size_0; kh++) {
-        let ih = oh * pool_stride_0 + kh;
+        let ih = oh * pool_stride_0 + kh * dilation_0;

         // Padding
         if ih < padding_0 || ih >= input_shape_2 + padding_0 {

@@ -61,7 +63,7 @@ fn main(
         }

         for (var kw = 0u; kw < kernel_size_1; kw++) {
-            let iw = ow * pool_stride_1 + kw;
+            let iw = ow * pool_stride_1 + kw * dilation_1;

             // Padding
             if iw < padding_1 || iw >= input_shape_3 + padding_1 {
|
@ -12,7 +12,7 @@ var<storage, read_write> indices: array<{{ int }}>;
|
|||
|
||||
@group(0)
|
||||
@binding(3)
|
||||
var<storage, read> info: array<u32, 22>;
|
||||
var<storage, read> info: array<u32, 24>;
|
||||
|
||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
||||
|
||||
|
@ -48,6 +48,8 @@ fn main(
|
|||
let pool_stride_1 = info[19];
|
||||
let padding_0 = info[20];
|
||||
let padding_1 = info[21];
|
||||
let dilation_0 = info[22];
|
||||
let dilation_1 = info[23];
|
||||
|
||||
let b = id / output_stride_0 % output_shape_0;
|
||||
let c = id / output_stride_1 % output_shape_1;
|
||||
|
@ -58,7 +60,7 @@ fn main(
|
|||
var index = 0u;
|
||||
|
||||
for (var kh = 0u; kh < kernel_size_0; kh++) {
|
||||
let ih = oh * pool_stride_0 + kh;
|
||||
let ih = oh * pool_stride_0 + kh * dilation_0;
|
||||
|
||||
// Padding
|
||||
if ih < padding_0 || ih >= input_shape_2 + padding_0 {
|
||||
|
@ -66,7 +68,7 @@ fn main(
|
|||
}
|
||||
|
||||
for (var kw = 0u; kw < kernel_size_1; kw++) {
|
||||
let iw = ow * pool_stride_1 + kw;
|
||||
let iw = ow * pool_stride_1 + kw * dilation_1;
|
||||
|
||||
// Padding
|
||||
if iw < padding_1 || iw >= input_shape_3 + padding_1 {
|
||||
|
|
|
@ -13,7 +13,7 @@ var<storage, read_write> output: array<{{ elem }}>;
|
|||
|
||||
@group(0)
|
||||
@binding(3)
|
||||
var<storage, read> info: array<u32, 22>;
|
||||
var<storage, read> info: array<u32, 24>;
|
||||
|
||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
||||
|
||||
|
@ -49,6 +49,8 @@ fn main(
|
|||
let pool_stride_1 = info[19];
|
||||
let padding_0 = info[20];
|
||||
let padding_1 = info[21];
|
||||
let dilation_0 = info[22];
|
||||
let dilation_1 = info[23];
|
||||
|
||||
let b = id / input_stride_0 % input_shape_0;
|
||||
let c = id / input_stride_1 % input_shape_1;
|
||||
|
@ -56,8 +58,8 @@ fn main(
|
|||
let iw = id / input_stride_3 % input_shape_3;
|
||||
|
||||
// The maximum number of overlapping filters that may content the current index.
|
||||
let kms_0 = i32(kernel_size_0) - i32(pool_stride_0);
|
||||
let kms_1 = i32(kernel_size_1) - i32(pool_stride_1);
|
||||
let kms_0 = i32(kernel_size_0 * dilation_0) - i32(pool_stride_0);
|
||||
let kms_1 = i32(kernel_size_1 * dilation_1) - i32(pool_stride_1);
|
||||
|
||||
let oh_start_tmp = (i32(ih + padding_0) - kms_0) / i32(pool_stride_0);
|
||||
let ow_start_tmp = (i32(iw + padding_1) - kms_1) / i32(pool_stride_1);
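Aside (not part of the commit): the backward kernel scans, for each input element, the output positions whose window could have covered it, and that span widens with dilation; hence the `* dilation_*` factors in `kms_0`/`kms_1`. A rough Rust transcription of the changed lines, with variable meanings inferred from the surrounding WGSL:

// kms measures how far a pooling window can extend beyond one stride step
// once dilation is applied; the start of the candidate output range follows.
fn candidate_start(i: i64, padding: i64, kernel: i64, dilation: i64, stride: i64) -> i64 {
    let kms = kernel * dilation - stride;
    (i + padding - kms) / stride
}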