From 445f41bb7b7062f29a2b033dcd79c2737aaef2f3 Mon Sep 17 00:00:00 2001
From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
Date: Tue, 21 Nov 2023 12:21:12 -0600
Subject: [PATCH] Support count_include_pad attr in avg_pool2d ONNX (#978)
Fixes #636
---
burn-core/src/nn/pool/avg_pool2d.rs | 2 +-
.../tests/avg_pool2d/avg_pool2d.onnx | Bin 232 -> 763 bytes
.../onnx-tests/tests/avg_pool2d/avg_pool2d.py | 38 +++++++++++-------
burn-import/onnx-tests/tests/onnx_tests.rs | 26 ++++++++++--
burn-import/src/burn/node/avg_pool2d.rs | 3 ++
burn-import/src/onnx/op_configuration.rs | 11 +++--
6 files changed, 58 insertions(+), 22 deletions(-)
diff --git a/burn-core/src/nn/pool/avg_pool2d.rs b/burn-core/src/nn/pool/avg_pool2d.rs
index f3bb2b60e..8c83c3c3d 100644
--- a/burn-core/src/nn/pool/avg_pool2d.rs
+++ b/burn-core/src/nn/pool/avg_pool2d.rs
@@ -20,7 +20,7 @@ pub struct AvgPool2dConfig {
pub padding: PaddingConfig2d,
/// If the padding is counted in the denominator when computing the average.
#[config(default = "true")]
- count_include_pad: bool,
+ pub count_include_pad: bool,
}
/// Applies a 2D avg pooling over input tensors.
diff --git a/burn-import/onnx-tests/tests/avg_pool2d/avg_pool2d.onnx b/burn-import/onnx-tests/tests/avg_pool2d/avg_pool2d.onnx
index bee96743283b81f901e7ae6551aa2e3f17a5cf79..e30064b69a5f627cda48f0aa29d405bae54336f3 100644
GIT binary patch
literal 763
zcmchU-%7(U6o;4C+IsY&i&M(H^RnQ!w%bjws!!mxH=(v+Ep|yrJJGip?9F@<6CGn0
zt$z?C9Fmjo7x*}FUE;{sR;auj+tDPP%#t@7wg}Dz=PyZe`^;3fWcNbUbVg7-9!;e_
zM@2Y2K-Z@SP2j;v&Z?TOM8W*Q2I0VmVJ_CZr4`TXb-}323LhDA0cTH4aaPl&%p^;#
z6mO6d48hWoCLMe(bLD~4w5lptlj2mrb$TpAg64b6=wB?g&(7(`mJvaVy=DAo%a194
zGs~|m39Pb;)1}H}`Eb+y`Ij(q5%B=KXjiu0?e2I;eie`U;*VE=;tpsOECc5q@JPdB
L0~dYp4tw4^hwQdM
delta 121
zcmey(`hwAdgF}eDpt2;tC^
NodeCodegen for AvgPool2dNode {
let kernel_size = self.config.kernel_size.to_tokens();
let strides = self.config.strides.to_tokens();
let padding = self.config.padding.to_tokens();
+ let count_include_pad = self.config.count_include_pad;
let init_line = quote! {
init();
@@ -60,6 +61,7 @@ impl NodeCodegen for AvgPool2dNode {
let #name = AvgPool2dConfig::new(#kernel_size)
.with_strides(#strides)
.with_padding(#padding)
+ .with_count_include_pad(#count_include_pad)
.#init_line
};
@@ -137,6 +139,7 @@ mod tests {
let avg_pool2d = AvgPool2dConfig::new([3, 3])
.with_strides([1, 1])
.with_padding(PaddingConfig2d::Valid)
+ .with_count_include_pad(true)
.init();
Self {
diff --git a/burn-import/src/onnx/op_configuration.rs b/burn-import/src/onnx/op_configuration.rs
index e9bf01878..d7dc4d23b 100644
--- a/burn-import/src/onnx/op_configuration.rs
+++ b/burn-import/src/onnx/op_configuration.rs
@@ -129,6 +129,7 @@ pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig {
let mut strides = vec![1, 1];
let mut pads = vec![0, 0, 0, 0];
let mut count_include_pad: i64 = 0;
+ let mut ceil_mode: i64 = 0;
for (key, value) in curr.attrs.iter() {
match key.as_str() {
@@ -136,19 +137,21 @@ pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig {
"strides" => strides = value.clone().into_i64s(),
"pads" => pads = value.clone().into_i64s(),
"count_include_pad" => count_include_pad = value.clone().into_i64(),
+ "ceil_mode" => ceil_mode = value.clone().into_i64(),
_ => {}
}
}
- let padding = padding_config(&pads);
-
- if count_include_pad == 1 && padding != PaddingConfig2d::Valid {
- todo!("AvgPool2d: count_include_pad is not supported. See https://github.com/burn-rs/burn/issues/636");
+ if ceil_mode == 1 {
+ panic!("ceil_mode is not supported");
}
+ let padding = padding_config(&pads);
+
AvgPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize])
.with_strides([strides[0] as usize, strides[1] as usize])
.with_padding(padding)
+ .with_count_include_pad(count_include_pad == 1)
}
/// Create a FlattenConfig from the attributes of the node