Support count_include_pad attr in avg_pool2d ONNX (#978)

Fixes #636
This commit is contained in:
Dilshod Tadjibaev 2023-11-21 12:21:12 -06:00 committed by GitHub
parent cb616ed72c
commit 445f41bb7b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 58 additions and 22 deletions

View File

@ -20,7 +20,7 @@ pub struct AvgPool2dConfig {
pub padding: PaddingConfig2d,
/// If the padding is counted in the denominator when computing the average.
#[config(default = "true")]
count_include_pad: bool,
pub count_include_pad: bool,
}
/// Applies a 2D avg pooling over input tensors.

View File

@ -10,14 +10,20 @@ class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
# TODO when https://github.com/burn-rs/burn/issues/636 is resolved, test this with a model
# that uses `count_include_pad=False` and padding=(2, 1)
self.pool2d = nn.AvgPool2d((4, 2), stride=(
2, 1), padding=(0, 0), count_include_pad=False)
self.pool2d1 = nn.AvgPool2d((4, 2), stride=(
2, 1))
def forward(self, x):
x = self.pool2d(x)
return x
self.pool2d2 = nn.AvgPool2d((4, 2), stride=(
2, 1), padding=(2, 1), count_include_pad=True)
self.pool2d3 = nn.AvgPool2d((4, 2), stride=(
2, 1), padding=(2, 1), count_include_pad=False)
def forward(self, x1, x2, x3):
y1 = self.pool2d1(x1)
y2 = self.pool2d2(x2)
y3 = self.pool2d3(x3)
return y1, y2, y3
def main():
@ -33,18 +39,22 @@ def main():
device = torch.device("cpu")
file_name = "avg_pool2d.onnx"
test_input = torch.randn(1, 1, 5, 5, device=device)
torch.onnx.export(model, test_input, file_name,
input1 = torch.randn(1, 1, 5, 5, device=device)
torch.onnx.export(model, (input1, input1, input1), file_name,
verbose=False, opset_version=16)
print("Finished exporting model to {}".format(file_name))
# Output some test data for use in the test
print("Test input data shape of ones: {}".format(test_input.shape))
print("Test input data of ones: {}".format(test_input))
output = model.forward(test_input)
print("Test output data shape: {}".format(output.shape))
print("Test output: {}".format(output))
print("Test input data shape: {}".format(input1.shape))
print("Test input data: {}".format(input1))
output1, output2, output3 = model.forward(input1, input1, input1)
print("Test output1 data shape: {}".format(output1.shape))
print("Test output2 data shape: {}".format(output2.shape))
print("Test output3 data shape: {}".format(output3.shape))
print("Test output1: {}".format(output1))
print("Test output2: {}".format(output2))
print("Test output3: {}".format(output3))
if __name__ == '__main__':

View File

@ -362,10 +362,30 @@ mod tests {
[-1.805, -0.476, 0.205, 0.338, 1.353],
[0.374, 0.013, 0.774, -0.109, -0.271],
]]]);
let output = model.forward(input);
let expected = Data::from([[[[0.008, -0.131, -0.208, 0.425]]]]);
let (output1, output2, output3) = model.forward(input.clone(), input.clone(), input);
let expected1 = Data::from([[[[0.008, -0.131, -0.208, 0.425]]]]);
let expected2 = Data::from([[[
[-0.045, 0.202, -0.050, -0.295, 0.162, 0.160],
[-0.176, 0.008, -0.131, -0.208, 0.425, 0.319],
[-0.084, -0.146, 0.017, 0.170, 0.216, 0.125],
]]]);
let expected3 = Data::from([[[
[-0.182, 0.404, -0.100, -0.590, 0.324, 0.638],
[-0.352, 0.008, -0.131, -0.208, 0.425, 0.638],
[-0.224, -0.195, 0.023, 0.226, 0.288, 0.335],
]]]);
output.to_data().assert_approx_eq(&expected, 3);
let expected_shape1 = Shape::from([1, 1, 1, 4]);
let expected_shape2 = Shape::from([1, 1, 3, 6]);
let expected_shape3 = Shape::from([1, 1, 3, 6]);
assert_eq!(output1.shape(), expected_shape1);
assert_eq!(output2.shape(), expected_shape2);
assert_eq!(output3.shape(), expected_shape3);
output1.to_data().assert_approx_eq(&expected1, 3);
output2.to_data().assert_approx_eq(&expected2, 3);
output3.to_data().assert_approx_eq(&expected3, 3);
}
#[test]

View File

@ -51,6 +51,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for AvgPool2dNode {
let kernel_size = self.config.kernel_size.to_tokens();
let strides = self.config.strides.to_tokens();
let padding = self.config.padding.to_tokens();
let count_include_pad = self.config.count_include_pad;
let init_line = quote! {
init();
@ -60,6 +61,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for AvgPool2dNode {
let #name = AvgPool2dConfig::new(#kernel_size)
.with_strides(#strides)
.with_padding(#padding)
.with_count_include_pad(#count_include_pad)
.#init_line
};
@ -137,6 +139,7 @@ mod tests {
let avg_pool2d = AvgPool2dConfig::new([3, 3])
.with_strides([1, 1])
.with_padding(PaddingConfig2d::Valid)
.with_count_include_pad(true)
.init();
Self {

View File

@ -129,6 +129,7 @@ pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig {
let mut strides = vec![1, 1];
let mut pads = vec![0, 0, 0, 0];
let mut count_include_pad: i64 = 0;
let mut ceil_mode: i64 = 0;
for (key, value) in curr.attrs.iter() {
match key.as_str() {
@ -136,19 +137,21 @@ pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig {
"strides" => strides = value.clone().into_i64s(),
"pads" => pads = value.clone().into_i64s(),
"count_include_pad" => count_include_pad = value.clone().into_i64(),
"ceil_mode" => ceil_mode = value.clone().into_i64(),
_ => {}
}
}
let padding = padding_config(&pads);
if count_include_pad == 1 && padding != PaddingConfig2d::Valid {
todo!("AvgPool2d: count_include_pad is not supported. See https://github.com/burn-rs/burn/issues/636");
if ceil_mode == 1 {
panic!("ceil_mode is not supported");
}
let padding = padding_config(&pads);
AvgPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize])
.with_strides([strides[0] as usize, strides[1] as usize])
.with_padding(padding)
.with_count_include_pad(count_include_pad == 1)
}
/// Create a FlattenConfig from the attributes of the node