mirror of https://github.com/tracel-ai/burn.git
Fix bugs with MaxPool2d in ONNX conversation (#623)
This commit is contained in:
parent
4c663b4cb7
commit
b79fa0748f
|
@ -16,6 +16,7 @@ fn main() {
|
|||
.input("tests/global_avr_pool/global_avr_pool.onnx")
|
||||
.input("tests/softmax/softmax.onnx")
|
||||
.input("tests/log_softmax/log_softmax.onnx")
|
||||
.input("tests/maxpool2d/maxpool2d.onnx")
|
||||
.out_dir("model/")
|
||||
.run_from_script();
|
||||
|
||||
|
|
Binary file not shown.
|
@ -0,0 +1,49 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# used to generate model: maxpool2d1.onnx
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
|
||||
# TODO support dilation=(3, 1) (see https://github.com/burn-rs/burn/issues/622)
|
||||
self.maxpool2d1 = nn.MaxPool2d((4, 2), stride=(2, 1), padding=(2, 1))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.maxpool2d1(x)
|
||||
return x
|
||||
|
||||
|
||||
def main():
|
||||
# Set seed for reproducibility
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Print options
|
||||
torch.set_printoptions(precision=3)
|
||||
|
||||
# Export to onnx
|
||||
model = Model()
|
||||
model.eval()
|
||||
device = torch.device("cpu")
|
||||
|
||||
file_name = "maxpool2d.onnx"
|
||||
test_input = torch.randn(1, 1, 5, 5, device=device)
|
||||
torch.onnx.export(model, test_input, file_name,
|
||||
verbose=False, opset_version=16)
|
||||
|
||||
print("Finished exporting model to {}".format(file_name))
|
||||
|
||||
# Output some test data for use in the test
|
||||
print("Test input data shape of ones: {}".format(test_input.shape))
|
||||
print("Test input data of ones: {}".format(test_input))
|
||||
output = model.forward(test_input)
|
||||
print("Test output data shape: {}".format(output.shape))
|
||||
print("Test output: {}".format(output))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -20,7 +20,8 @@ include_models!(
|
|||
dropout,
|
||||
global_avr_pool,
|
||||
softmax,
|
||||
log_softmax
|
||||
log_softmax,
|
||||
maxpool2d
|
||||
);
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -209,4 +210,27 @@ mod tests {
|
|||
|
||||
assert_eq!(output.to_data(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn maxpool2d() {
|
||||
// Initialize the model without weights (because the exported file does not contain them)
|
||||
let model: maxpool2d::Model<Backend> = maxpool2d::Model::new();
|
||||
|
||||
// Run the model
|
||||
let input = Tensor::<Backend, 4>::from_floats([[[
|
||||
[1.927, 1.487, 0.901, -2.106, 0.678],
|
||||
[-1.235, -0.043, -1.605, -0.752, -0.687],
|
||||
[-0.493, 0.241, -1.111, 0.092, -2.317],
|
||||
[-0.217, -1.385, -0.396, 0.803, -0.622],
|
||||
[-0.592, -0.063, -0.829, 0.331, -1.558],
|
||||
]]]);
|
||||
let output = model.forward(input);
|
||||
let expected = Data::from([[[
|
||||
[1.927, 1.927, 1.487, 0.901, 0.678, 0.678],
|
||||
[1.927, 1.927, 1.487, 0.901, 0.803, 0.678],
|
||||
[-0.217, 0.241, 0.241, 0.803, 0.803, -0.622],
|
||||
]]]);
|
||||
|
||||
assert_eq!(output.to_data(), expected);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@ impl MaxPool2dNode {
|
|||
field: OtherType::new(
|
||||
name,
|
||||
quote! {
|
||||
MaxPool2d<B>
|
||||
MaxPool2d
|
||||
},
|
||||
),
|
||||
input,
|
||||
|
@ -75,6 +75,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for MaxPool2dNode {
|
|||
let #output = self.#field.forward(#input);
|
||||
}
|
||||
}
|
||||
|
||||
fn register_imports(&self, imports: &mut BurnImports) {
|
||||
imports.register("burn::nn::PaddingConfig2d");
|
||||
imports.register("burn::nn::pool::MaxPool2d");
|
||||
|
@ -84,6 +85,10 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for MaxPool2dNode {
|
|||
fn into_node(self) -> Node<PS> {
|
||||
Node::MaxPool2d(self)
|
||||
}
|
||||
|
||||
fn field_serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
|
||||
S::serialize_none(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -122,7 +127,7 @@ mod tests {
|
|||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model <B: Backend> {
|
||||
max_pool2d: MaxPool2d<B>,
|
||||
max_pool2d: MaxPool2d,
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
}
|
||||
|
||||
|
|
|
@ -75,16 +75,22 @@ pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig {
|
|||
let mut kernel_shape = Vec::new();
|
||||
let mut strides = Vec::new();
|
||||
let mut pads = Vec::new();
|
||||
let mut dilations = Vec::new();
|
||||
|
||||
for (key, value) in curr.attrs.iter() {
|
||||
match key.as_str() {
|
||||
"kernel_shape" => attr_value_vec_i64(value, &mut kernel_shape),
|
||||
"strides" => attr_value_vec_i64(value, &mut strides),
|
||||
"pads" => attr_value_vec_i64(value, &mut pads),
|
||||
"dilations" => attr_value_vec_i64(value, &mut dilations),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if !dilations.is_empty() && (dilations[0] != 1 || dilations[1] != 1) {
|
||||
todo!("MaxPool2d: dilations are not supported. See https://github.com/burn-rs/burn/issues/622");
|
||||
}
|
||||
|
||||
let padding = padding_config(&pads);
|
||||
|
||||
MaxPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize])
|
||||
|
|
Loading…
Reference in New Issue