Fix bugs with MaxPool2d in ONNX conversation (#623)

This commit is contained in:
Dilshod Tadjibaev 2023-08-10 08:35:15 -05:00 committed by GitHub
parent 4c663b4cb7
commit b79fa0748f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 88 additions and 3 deletions

View File

@ -16,6 +16,7 @@ fn main() {
.input("tests/global_avr_pool/global_avr_pool.onnx")
.input("tests/softmax/softmax.onnx")
.input("tests/log_softmax/log_softmax.onnx")
.input("tests/maxpool2d/maxpool2d.onnx")
.out_dir("model/")
.run_from_script();

Binary file not shown.

View File

@ -0,0 +1,49 @@
#!/usr/bin/env python3
# used to generate model: maxpool2d1.onnx
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
# TODO support dilation=(3, 1) (see https://github.com/burn-rs/burn/issues/622)
self.maxpool2d1 = nn.MaxPool2d((4, 2), stride=(2, 1), padding=(2, 1))
def forward(self, x):
x = self.maxpool2d1(x)
return x
def main():
# Set seed for reproducibility
torch.manual_seed(42)
# Print options
torch.set_printoptions(precision=3)
# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
file_name = "maxpool2d.onnx"
test_input = torch.randn(1, 1, 5, 5, device=device)
torch.onnx.export(model, test_input, file_name,
verbose=False, opset_version=16)
print("Finished exporting model to {}".format(file_name))
# Output some test data for use in the test
print("Test input data shape of ones: {}".format(test_input.shape))
print("Test input data of ones: {}".format(test_input))
output = model.forward(test_input)
print("Test output data shape: {}".format(output.shape))
print("Test output: {}".format(output))
if __name__ == '__main__':
main()

View File

@ -20,7 +20,8 @@ include_models!(
dropout,
global_avr_pool,
softmax,
log_softmax
log_softmax,
maxpool2d
);
#[cfg(test)]
@ -209,4 +210,27 @@ mod tests {
assert_eq!(output.to_data(), expected);
}
#[test]
fn maxpool2d() {
// Initialize the model without weights (because the exported file does not contain them)
let model: maxpool2d::Model<Backend> = maxpool2d::Model::new();
// Run the model
let input = Tensor::<Backend, 4>::from_floats([[[
[1.927, 1.487, 0.901, -2.106, 0.678],
[-1.235, -0.043, -1.605, -0.752, -0.687],
[-0.493, 0.241, -1.111, 0.092, -2.317],
[-0.217, -1.385, -0.396, 0.803, -0.622],
[-0.592, -0.063, -0.829, 0.331, -1.558],
]]]);
let output = model.forward(input);
let expected = Data::from([[[
[1.927, 1.927, 1.487, 0.901, 0.678, 0.678],
[1.927, 1.927, 1.487, 0.901, 0.803, 0.678],
[-0.217, 0.241, 0.241, 0.803, 0.803, -0.622],
]]]);
assert_eq!(output.to_data(), expected);
}
}

View File

@ -25,7 +25,7 @@ impl MaxPool2dNode {
field: OtherType::new(
name,
quote! {
MaxPool2d<B>
MaxPool2d
},
),
input,
@ -75,6 +75,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for MaxPool2dNode {
let #output = self.#field.forward(#input);
}
}
fn register_imports(&self, imports: &mut BurnImports) {
imports.register("burn::nn::PaddingConfig2d");
imports.register("burn::nn::pool::MaxPool2d");
@ -84,6 +85,10 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for MaxPool2dNode {
fn into_node(self) -> Node<PS> {
Node::MaxPool2d(self)
}
fn field_serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
S::serialize_none(serializer)
}
}
#[cfg(test)]
@ -122,7 +127,7 @@ mod tests {
#[derive(Module, Debug)]
pub struct Model <B: Backend> {
max_pool2d: MaxPool2d<B>,
max_pool2d: MaxPool2d,
phantom: core::marker::PhantomData<B>,
}

View File

@ -75,16 +75,22 @@ pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig {
let mut kernel_shape = Vec::new();
let mut strides = Vec::new();
let mut pads = Vec::new();
let mut dilations = Vec::new();
for (key, value) in curr.attrs.iter() {
match key.as_str() {
"kernel_shape" => attr_value_vec_i64(value, &mut kernel_shape),
"strides" => attr_value_vec_i64(value, &mut strides),
"pads" => attr_value_vec_i64(value, &mut pads),
"dilations" => attr_value_vec_i64(value, &mut dilations),
_ => {}
}
}
if !dilations.is_empty() && (dilations[0] != 1 || dilations[1] != 1) {
todo!("MaxPool2d: dilations are not supported. See https://github.com/burn-rs/burn/issues/622");
}
let padding = padding_config(&pads);
MaxPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize])