From 445f41bb7b7062f29a2b033dcd79c2737aaef2f3 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Tue, 21 Nov 2023 12:21:12 -0600 Subject: [PATCH] Support count_include_pad attr in avg_pool2d ONNX (#978) Fixes #636 --- burn-core/src/nn/pool/avg_pool2d.rs | 2 +- .../tests/avg_pool2d/avg_pool2d.onnx | Bin 232 -> 763 bytes .../onnx-tests/tests/avg_pool2d/avg_pool2d.py | 38 +++++++++++------- burn-import/onnx-tests/tests/onnx_tests.rs | 26 ++++++++++-- burn-import/src/burn/node/avg_pool2d.rs | 3 ++ burn-import/src/onnx/op_configuration.rs | 11 +++-- 6 files changed, 58 insertions(+), 22 deletions(-) diff --git a/burn-core/src/nn/pool/avg_pool2d.rs b/burn-core/src/nn/pool/avg_pool2d.rs index f3bb2b60e..8c83c3c3d 100644 --- a/burn-core/src/nn/pool/avg_pool2d.rs +++ b/burn-core/src/nn/pool/avg_pool2d.rs @@ -20,7 +20,7 @@ pub struct AvgPool2dConfig { pub padding: PaddingConfig2d, /// If the padding is counted in the denominator when computing the average. #[config(default = "true")] - count_include_pad: bool, + pub count_include_pad: bool, } /// Applies a 2D avg pooling over input tensors. diff --git a/burn-import/onnx-tests/tests/avg_pool2d/avg_pool2d.onnx b/burn-import/onnx-tests/tests/avg_pool2d/avg_pool2d.onnx index bee96743283b81f901e7ae6551aa2e3f17a5cf79..e30064b69a5f627cda48f0aa29d405bae54336f3 100644 GIT binary patch literal 763 zcmchU-%7(U6o;4C+IsY&i&M(H^RnQ!w%bjws!!mxH=(v+Ep|yrJJGip?9F@<6CGn0 zt$z?C9Fmjo7x*}FUE;{sR;auj+tDPP%#t@7wg}Dz=PyZe`^;3fWcNbUbVg7-9!;e_ zM@2Y2K-Z@SP2j;v&Z?TOM8W*Q2I0VmVJ_CZr4`TXb-}323LhDA0cTH4aaPl&%p^;# z6mO6d48hWoCLMe(bLD~4w5lptlj2mrb$TpAg64b6=wB?g&(7(`mJvaVy=DAo%a194 zGs~|m39Pb;)1}H}`Eb+y`Ij(q5%B=KXjiu0?e2I;eie`U;*VE=;tpsOECc5q@JPdB L0~dYp4tw4^hwQdM delta 121 zcmey(`hwAdgF}eDpt2;tC^ NodeCodegen for AvgPool2dNode { let kernel_size = self.config.kernel_size.to_tokens(); let strides = self.config.strides.to_tokens(); let padding = self.config.padding.to_tokens(); + let count_include_pad = self.config.count_include_pad; let init_line = quote! { init(); @@ -60,6 +61,7 @@ impl NodeCodegen for AvgPool2dNode { let #name = AvgPool2dConfig::new(#kernel_size) .with_strides(#strides) .with_padding(#padding) + .with_count_include_pad(#count_include_pad) .#init_line }; @@ -137,6 +139,7 @@ mod tests { let avg_pool2d = AvgPool2dConfig::new([3, 3]) .with_strides([1, 1]) .with_padding(PaddingConfig2d::Valid) + .with_count_include_pad(true) .init(); Self { diff --git a/burn-import/src/onnx/op_configuration.rs b/burn-import/src/onnx/op_configuration.rs index e9bf01878..d7dc4d23b 100644 --- a/burn-import/src/onnx/op_configuration.rs +++ b/burn-import/src/onnx/op_configuration.rs @@ -129,6 +129,7 @@ pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig { let mut strides = vec![1, 1]; let mut pads = vec![0, 0, 0, 0]; let mut count_include_pad: i64 = 0; + let mut ceil_mode: i64 = 0; for (key, value) in curr.attrs.iter() { match key.as_str() { @@ -136,19 +137,21 @@ pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig { "strides" => strides = value.clone().into_i64s(), "pads" => pads = value.clone().into_i64s(), "count_include_pad" => count_include_pad = value.clone().into_i64(), + "ceil_mode" => ceil_mode = value.clone().into_i64(), _ => {} } } - let padding = padding_config(&pads); - - if count_include_pad == 1 && padding != PaddingConfig2d::Valid { - todo!("AvgPool2d: count_include_pad is not supported. See https://github.com/burn-rs/burn/issues/636"); + if ceil_mode == 1 { + panic!("ceil_mode is not supported"); } + let padding = padding_config(&pads); + AvgPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize]) .with_strides([strides[0] as usize, strides[1] as usize]) .with_padding(padding) + .with_count_include_pad(count_include_pad == 1) } /// Create a FlattenConfig from the attributes of the node