From 1554a3c898fcdcfc64b70bf64a074dd3c4be29bf Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Fri, 4 Aug 2023 15:51:51 -0500 Subject: [PATCH] Full support for ONNX scalar operators and Constants (#578) --- Cargo.toml | 2 + burn-core/src/module/param/base.rs | 6 + burn-core/src/module/param/constant.rs | 54 ++++- burn-core/src/record/mod.rs | 2 + burn-import/README.md | 6 +- burn-import/onnx-tests/Cargo.toml | 13 ++ burn-import/onnx-tests/README.md | 32 +++ burn-import/onnx-tests/build.rs | 19 ++ burn-import/onnx-tests/src/lib.rs | 1 + burn-import/onnx-tests/tests/add/add.onnx | Bin 0 -> 519 bytes burn-import/onnx-tests/tests/add/add.py | 57 +++++ .../tests}/concat/concat.onnx | 24 +-- .../tests}/concat/concat.py | 19 +- .../onnx-tests/tests/conv2d/conv2d.onnx | Bin 0 -> 1044 bytes burn-import/onnx-tests/tests/conv2d/conv2d.py | 43 ++++ .../tests}/conv2d/conv2d.rs | 0 burn-import/onnx-tests/tests/div/div.onnx | Bin 0 -> 376 bytes burn-import/onnx-tests/tests/div/div.py | 47 +++++ burn-import/onnx-tests/tests/mul/mul.onnx | Bin 0 -> 519 bytes burn-import/onnx-tests/tests/mul/mul.py | 57 +++++ burn-import/onnx-tests/tests/onnx_tests.rs | 128 +++++++++++ burn-import/onnx-tests/tests/sub/sub.onnx | Bin 0 -> 519 bytes burn-import/onnx-tests/tests/sub/sub.py | 57 +++++ burn-import/src/burn/graph.rs | 166 +++++++++++---- burn-import/src/burn/node/base.rs | 65 ++---- burn-import/src/burn/node/batch_norm.rs | 8 +- burn-import/src/burn/node/binary.rs | 153 +++++++++++--- burn-import/src/burn/node/concat.rs | 23 +- burn-import/src/burn/node/constant.rs | 159 +++++++++++--- burn-import/src/burn/node/conv2d.rs | 8 +- burn-import/src/burn/node/linear.rs | 8 +- burn-import/src/burn/node/matmul.rs | 23 +- burn-import/src/burn/node/max_pool2d.rs | 8 +- burn-import/src/burn/node/reshape.rs | 16 +- burn-import/src/burn/node/unary.rs | 183 ++++++++++++---- burn-import/src/burn/ty.rs | 68 +++++- burn-import/src/onnx/dim_inference.rs | 74 ++++++- burn-import/src/onnx/from_onnx.rs | 48 ++--- burn-import/src/onnx/ir.rs | 13 +- burn-import/src/onnx/to_burn.rs | 199 ++++++++++++++---- burn-import/tests/data/concat/concat.rs | 31 --- burn-import/tests/data/conv2d/conv2d.onnx | Bin 17731 -> 0 bytes burn-import/tests/data/conv2d/conv2d.py | 36 ---- burn-import/tests/data/model1/README.md | 17 -- burn-import/tests/data/model1/model1.onnx | Bin 13253 -> 0 bytes burn-import/tests/data/model1/model1.py | 46 ---- burn-import/tests/data/model1/model1.rs | 62 ------ burn-import/tests/onnx_tests.rs | 42 ---- .../src/bin/{mnist.rs => mnist_inference.rs} | 0 49 files changed, 1462 insertions(+), 561 deletions(-) create mode 100644 burn-import/onnx-tests/Cargo.toml create mode 100644 burn-import/onnx-tests/README.md create mode 100644 burn-import/onnx-tests/build.rs create mode 100644 burn-import/onnx-tests/src/lib.rs create mode 100644 burn-import/onnx-tests/tests/add/add.onnx create mode 100755 burn-import/onnx-tests/tests/add/add.py rename burn-import/{tests/data => onnx-tests/tests}/concat/concat.onnx (64%) rename burn-import/{tests/data => onnx-tests/tests}/concat/concat.py (56%) mode change 100644 => 100755 create mode 100644 burn-import/onnx-tests/tests/conv2d/conv2d.onnx create mode 100755 burn-import/onnx-tests/tests/conv2d/conv2d.py rename burn-import/{tests/data => onnx-tests/tests}/conv2d/conv2d.rs (100%) create mode 100644 burn-import/onnx-tests/tests/div/div.onnx create mode 100755 burn-import/onnx-tests/tests/div/div.py create mode 100644 
burn-import/onnx-tests/tests/mul/mul.onnx
 create mode 100755 burn-import/onnx-tests/tests/mul/mul.py
 create mode 100644 burn-import/onnx-tests/tests/onnx_tests.rs
 create mode 100644 burn-import/onnx-tests/tests/sub/sub.onnx
 create mode 100755 burn-import/onnx-tests/tests/sub/sub.py
 delete mode 100644 burn-import/tests/data/concat/concat.rs
 delete mode 100644 burn-import/tests/data/conv2d/conv2d.onnx
 delete mode 100644 burn-import/tests/data/conv2d/conv2d.py
 delete mode 100644 burn-import/tests/data/model1/README.md
 delete mode 100644 burn-import/tests/data/model1/model1.onnx
 delete mode 100644 burn-import/tests/data/model1/model1.py
 delete mode 100644 burn-import/tests/data/model1/model1.rs
 delete mode 100644 burn-import/tests/onnx_tests.rs
 rename examples/onnx-inference/src/bin/{mnist.rs => mnist_inference.rs} (100%)

diff --git a/Cargo.toml b/Cargo.toml
index 4c3cd3c5c..f2337a4fa 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -11,6 +11,7 @@ members = [
     "burn-dataset",
     "burn-derive",
     "burn-import",
+    "burn-import/onnx-tests",
     "burn-ndarray",
     "burn-no-std-tests",
     "burn-tch",
@@ -29,6 +30,7 @@ dashmap = "5.4.0"
 dirs = "5.0.1"
 fake = "2.6.1"
 flate2 = "1.0.26"
+float-cmp = "0.9.0"
 gix-tempfile = {version = "7.0.0", features = ["signals"]}
 hashbrown = "0.14.0"
 indicatif = "0.17.5"
diff --git a/burn-core/src/module/param/base.rs b/burn-core/src/module/param/base.rs
index 6b405f352..72174cba7 100644
--- a/burn-core/src/module/param/base.rs
+++ b/burn-core/src/module/param/base.rs
@@ -23,6 +23,12 @@ impl<T: Clone> Param<T> {
     pub fn val(&self) -> T {
         self.value.clone()
     }
+
+    /// Execute the given function on the inner value.
+    pub fn map<F: FnOnce(T) -> T>(mut self, func: F) -> Self {
+        self.value = func(self.value);
+        self
+    }
 }

 impl<T> core::ops::Deref for Param<T> {
diff --git a/burn-core/src/module/param/constant.rs b/burn-core/src/module/param/constant.rs
index a34920402..34213808b 100644
--- a/burn-core/src/module/param/constant.rs
+++ b/burn-core/src/module/param/constant.rs
@@ -1,3 +1,5 @@
+use core::marker::PhantomData;
+
 use crate::{
     self as burn,
     module::{ADModule, Module, ModuleMapper, ModuleVisitor},
@@ -135,12 +137,10 @@ impl<B: Backend, const D: usize> Module<B> for Tensor<B, D> {
     }

     fn into_record(self) -> Self::Record {
-        // Treat as a constant and do not record
-        ConstantRecord::new()
+        ConstantRecord
     }

     fn load_record(self, _record: Self::Record) -> Self {
-        // Treat as a constant and do not load
         self
     }
 }
@@ -153,15 +153,49 @@ impl<B: ADBackend, const D: usize> ADModule<B> for Tensor<B, D> {
     }
 }

+impl<B: Backend> Module<B> for PhantomData<B> {
+    type Record = ConstantRecord;
+
+    fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {
+        // Nothing to do
+    }
+
+    fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
+        self
+    }
+
+    fn load_record(self, _record: Self::Record) -> Self {
+        self
+    }
+
+    fn into_record(self) -> Self::Record {
+        ConstantRecord::new()
+    }
+}
+
+impl<B: ADBackend> ADModule<B> for PhantomData<B> {
+    type InnerModule = PhantomData<B::InnerBackend>;
+
+    fn valid(&self) -> Self::InnerModule {
+        PhantomData
+    }
+}
+
 #[cfg(all(test, feature = "std"))]
 mod tests {
+    use core::marker::PhantomData;
+
+    use burn_tensor::backend::Backend;
     use burn_tensor::Tensor;

-    use crate::module::Module;
+    use crate::TestBackend;
     use crate::{
         record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
         TestADBackend,
     };
+    use burn::module::Module;
+
+    use crate as burn;

     #[test]
     fn tensor_load_record_setting() {
@@ -185,4 +219,16 @@ mod tests {
         assert!(!no_grad_is_require_grad);
         assert!(!with_default_is_require_grad);
     }
+
+    #[test]
+    fn empty_module_with_phantom() {
+        #[derive(Module, Debug, new)]
+        struct EmptyModule<B: Backend> {
+            _phantom: PhantomData<B>,
+        }
+
+        let _module = EmptyModule::<TestBackend>::new();
+
+        assert_eq!(core::mem::size_of::<EmptyModule<TestBackend>>(), 0);
+    }
 }
diff --git a/burn-core/src/record/mod.rs b/burn-core/src/record/mod.rs
index 62c0d0922..701dcf5da 100644
--- a/burn-core/src/record/mod.rs
+++ b/burn-core/src/record/mod.rs
@@ -15,3 +15,5 @@ pub use settings::*;
 mod file;
 #[cfg(feature = "std")]
 pub use file::*;
+
+pub use primitive::ParamSerde;
diff --git a/burn-import/README.md b/burn-import/README.md
index 96eeab1c0..bb8b4c2f4 100644
--- a/burn-import/README.md
+++ b/burn-import/README.md
@@ -32,7 +32,7 @@ List taken from [here](https://github.com/onnx/onnx/blob/main/docs/Operators.md)
 - [ ] BitwiseOr
 - [ ] BitwiseXor
 - [ ] BlackmanWindow
-- [ ] Cast
+- [x] Cast
 - [ ] CastLike
 - [ ] Ceil
 - [ ] Celu
@@ -40,9 +40,9 @@
 - [ ] Clip
 - [ ] Col
 - [ ] Compress
-- [ ] Concat
+- [x] Concat
 - [ ] ConcatFromSequence
-- [ ] Constant
+- [x] Constant
 - [ ] ConstantOfShape
 - [ ] Conv
 - [ ] Conv1d
diff --git a/burn-import/onnx-tests/Cargo.toml b/burn-import/onnx-tests/Cargo.toml
new file mode 100644
index 000000000..85d0e60a3
--- /dev/null
+++ b/burn-import/onnx-tests/Cargo.toml
@@ -0,0 +1,13 @@
+[package]
+name = "onnx-tests"
+version = "0.9.0"
+edition = "2021"
+
+[dev-dependencies]
+burn = { path = "../../burn" }
+burn-ndarray = { path = "../../burn-ndarray" }
+serde = { workspace = true }
+float-cmp = { workspace = true }
+
+[build-dependencies]
+burn-import = { path = "../" }
diff --git a/burn-import/onnx-tests/README.md b/burn-import/onnx-tests/README.md
new file mode 100644
index 000000000..b3797877a
--- /dev/null
+++ b/burn-import/onnx-tests/README.md
@@ -0,0 +1,32 @@
+# ONNX Tests
+
+This crate contains ONNX models used to test the conversion of ONNX models into Burn source code
+through the `burn-import` crate. The tests are end-to-end: they ensure that an ONNX model is
+accurately converted into Burn source code, and most importantly, that the converted source code
+compiles without errors and produces the same output as the original ONNX model.
+
+Here is the directory structure of this crate:
+
+- `tests/<model>/`: Contains the ONNX model and the Python script that generates it.
+- `tests/<model>/<model>.onnx`: The ONNX model generated by the script.
+- `tests/<model>/<model>.py`: The Python script that generates the ONNX model using PyTorch.
+- `tests/onnx_tests.rs`: The main test file, where all the tests are contained.
+- `build.rs`: The build script that converts the ONNX models into Burn source code; `cargo test`
+  executes it before running the actual tests.
+
+## Adding new tests
+
+Here are the steps to add a new test:
+
+1. Add a directory with your Python script under `tests/`. Refer to the existing scripts for
+   examples.
+2. Run your Python script to generate the ONNX model, and inspect the model's output on the test
+   data. Use these inputs and outputs in your test.
+3. Make sure the exported model contains the desired operators by checking it with the
+   [Netron](https://github.com/lutzroeder/netron) app. PyTorch will sometimes optimize the model
+   and remove operators that are not necessary for the model to run. If this happens, you can
+   disable constant folding by passing `do_constant_folding=False` to `torch.onnx.export`.
+4. Add an entry to the `build.rs` file to account for the generation of the new ONNX model.
+5. Include a test in `tests/onnx_tests.rs` to test the new ONNX model (see the sketch after this
+   README).
+6. Run `cargo test` to ensure your test passes.
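To make steps 4 and 5 concrete, here is a minimal sketch for a hypothetical `erf` operator test. The `erf` name, input shape, and asserted values are placeholders; the include/forward pattern mirrors the `add` and `concat` tests introduced by this patch.

// In build.rs, add one line to the existing ModelGen chain:
//     .input("tests/erf/erf.onnx")

// In tests/onnx_tests.rs, include the generated source and exercise it:
pub mod erf {
    include!(concat!(env!("OUT_DIR"), "/model/erf.rs"));
}

#[cfg(test)]
mod erf_tests {
    use super::*;
    use burn::tensor::{Shape, Tensor};

    type Backend = burn_ndarray::NdArrayBackend<f32>;

    #[test]
    fn erf() {
        // `new()` assumes the ONNX file carries no weights; use `default()` otherwise.
        let model: erf::Model<Backend> = erf::Model::new();

        let input = Tensor::<Backend, 4>::ones([1, 2, 3, 4]);
        let output = model.forward(input);

        // Replace with the shape and values printed by erf.py.
        assert_eq!(output.shape(), Shape::from([1, 2, 3, 4]));
    }
}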
diff --git a/burn-import/onnx-tests/build.rs b/burn-import/onnx-tests/build.rs new file mode 100644 index 000000000..4b3e932bf --- /dev/null +++ b/burn-import/onnx-tests/build.rs @@ -0,0 +1,19 @@ +use burn_import::onnx::ModelGen; + +fn main() { + // Re-run this build script if the onnx-tests directory changes. + println!("cargo:rerun-if-changed=tests"); + + // Add onnx models. + ModelGen::new() + .input("tests/add/add.onnx") + .input("tests/sub/sub.onnx") + .input("tests/mul/mul.onnx") + .input("tests/div/div.onnx") + .input("tests/concat/concat.onnx") + .input("tests/conv2d/conv2d.onnx") + .out_dir("model/") + .run_from_script(); + + // panic!("Purposefully failing build to output logs."); +} diff --git a/burn-import/onnx-tests/src/lib.rs b/burn-import/onnx-tests/src/lib.rs new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/burn-import/onnx-tests/src/lib.rs @@ -0,0 +1 @@ + diff --git a/burn-import/onnx-tests/tests/add/add.onnx b/burn-import/onnx-tests/tests/add/add.onnx new file mode 100644 index 0000000000000000000000000000000000000000..238726f233bdb343ab56bef62705dfde9f68cb32 GIT binary patch literal 519 zcmd=)ivXjS00TpVJu+Uv$iijK#hss*S7Bx4n3582z=g#IAzpou z9Na<{5LbyANN|~P@jx^?Cl;5)8w&C1gBft8tYE$p3rJLplZ&Y&UxIM~BNLaikO z@rDSaxsgPXd?Nz$4F}K{EDQ__96&C2;X<>H3#*|*{9xxJ*}?`9fOx@x3)OpE=*oo{ z&7~kZjKDgCIKlA}pOsk>rHT@GLSkGZ9E?H&TudB{K+FWh%s|W%C59GEY+NiH+(HaV P(p-$@xHLI22?ziH48?`@ literal 0 HcmV?d00001 diff --git a/burn-import/onnx-tests/tests/add/add.py b/burn-import/onnx-tests/tests/add/add.py new file mode 100755 index 000000000..f65482249 --- /dev/null +++ b/burn-import/onnx-tests/tests/add/add.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/add/add.onnx + +import torch +import torch.nn as nn + + +class Model(nn.Module): + def __init__(self): + # Declare a constant float tensor with ones + self.a = torch.ones(1, 1, 1, 4) + + # Declare a scalar + self.b = 5.0 + super(Model, self).__init__() + + def forward(self, x, k): + + # Add a tensor input and a constant tensor + x = x + self.a + + # Add a scalar constant and a scalar input + d = self.b + k + + # Add a tensor and a scalar + x = x + d + + return x + + +def main(): + + # Export to onnx + model = Model() + model.eval() + device = torch.device("cpu") + onnx_name = "add.onnx" + dummy_input = torch.randn(1, 2, 3, 4, device=device) + + scalar = 2.0 + + torch.onnx.export(model, (dummy_input, scalar), onnx_name, + verbose=False, opset_version=16) + + print("Finished exporting model to {}".format(onnx_name)) + + # Output some test data for use in the test + test_input = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]]) + + print("Test input data: {}, {}".format(test_input, scalar)) + output = model.forward(test_input, scalar) + print("Test output data: {}".format(output)) + + +if __name__ == '__main__': + main() diff --git a/burn-import/tests/data/concat/concat.onnx b/burn-import/onnx-tests/tests/concat/concat.onnx similarity index 64% rename from burn-import/tests/data/concat/concat.onnx rename to burn-import/onnx-tests/tests/concat/concat.onnx index 424f5fab7..239790da2 100644 --- a/burn-import/tests/data/concat/concat.onnx +++ b/burn-import/onnx-tests/tests/concat/concat.onnx @@ -1,4 +1,4 @@ -pytorch2.0.1:£ +pytorch2.0.1:¡ P onnx::Concat_0 onnx::Concat_0/Concat_output_0/Concat"Concat* @@ -9,16 +9,16 @@ P /Concat_output_0 /Concat_output_0 /Concat_output_02 /Concat_1"Concat* -axis  torch_jitZ) -onnx::Concat_0 - +axis  torch_jitZ( +onnx::Concat_0 +  -€ - - b -2 - + + +b +2 +  -€ - - B \ No newline at 
end of file + + +B \ No newline at end of file diff --git a/burn-import/tests/data/concat/concat.py b/burn-import/onnx-tests/tests/concat/concat.py old mode 100644 new mode 100755 similarity index 56% rename from burn-import/tests/data/concat/concat.py rename to burn-import/onnx-tests/tests/concat/concat.py index 2ddd27c45..ce7f70689 --- a/burn-import/tests/data/concat/concat.py +++ b/burn-import/onnx-tests/tests/concat/concat.py @@ -1,9 +1,9 @@ -# used to generate model: burn-import/tests/data/conv2d/conv2d.onnx +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/concat/concat.onnx import torch import torch.nn as nn -import onnx -from onnxoptimizer import optimize class Model(nn.Module): def __init__(self): @@ -24,9 +24,20 @@ def main(): model.eval() device = torch.device("cpu") onnx_name = "concat.onnx" - dummy_input = torch.randn(1,256,13,13, device=device) + dummy_input = torch.randn(1,2,3,5, device=device) torch.onnx.export(model, dummy_input, onnx_name, verbose=False, opset_version=16) + + print("Finished exporting model to {}".format(onnx_name)) + + # Output some test data for use in the test + test_input = torch.randn(1,2,3,5, device=device) + print("Test input data shape: {}".format(test_input.shape)) + output = model.forward(test_input) + + print("Test output data shape: {}".format(output.shape)) + + if __name__ == '__main__': main() diff --git a/burn-import/onnx-tests/tests/conv2d/conv2d.onnx b/burn-import/onnx-tests/tests/conv2d/conv2d.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ae34249e6ba5f915712fdb7d34483957d94b54c2 GIT binary patch literal 1044 zcmZ9LYfPJU6vlabq1{Uty3#SmWvgUBR~fWGSS|1QuUt0?K^#OYGXb{NwT_kA78o&E zRFHLJnFKP#A%Q8uhE>vRI0$*qKVg}K5TZnu=q_@QZmV%A zn9SIS^k+IfKj=vmg{nruZZq3zYuvUvyUQpvGX0E1BV)xUoptq&7(qWH(ClPorz}pp zrMAjtt#Md_CUMBLgB3Yy%r2wICZhrcThM{;eGQs^ z8RAnn1mvBE?c?r>R$M#u5jYBNZ?M&dNFB~8< zYb)TR1E}BgDrk?K<;UhT$@HOmT-cR?Y3H1H823}h)H1hXYQ^rkW_(oeVUD`;DOcEM zpgj2sT3*(Hs`q@3x~~T^n0Q*&^c>ZCgk<@F442;)qi1#<+rLgCkoE*ii&6;REu>tD zkWALOadh8RG*#W><(pY#jYfjEVH1WfUclkLgXFjMND@EN0miPMAi4AkuCJ=eF(1&%$OGDneoY)q?png5;#Cxxb!5TM(0N`-8?Jbu z`%(s(sP$pk!30CW=L-7zTp1J{%Oc+5c#>O{Ow|$(?`zwH@@6m2mt5m#vkPD(74J&zhQ_UOIKj~VlP-@J-l_V0KL=#p2iI*`Yf87ec@z# z!VA=qW93GaO0W;B>HR2%nzSvvH(Nvlofa6_vg5=D_t2~@L~lkkIh{C)*1KaoSMnuH z6@AMuj4ojOTS_uV7x3(;gfwEdl$OLB zNU?(XN-Q8zElw_`l6(oq1&mBwRw(+6gajb^4CCQC*`XqsdTqGSY~jMtEySzul39i@ zj0MD1Vg?diI$Yck7lRZSaG@#`Vlm=)ivXjS00V=A12SH~$iijK#hss*S7Bx4TbdJZz=g#IAzpou z9Na<{5LbyANN|~P@jx^?Cl;5)8w&C1gBft8tYE$p3rJLplZ&Y&UxIM~BNLaikO z@rDSaxsgPXd?Nz$4F}K{EDQ_}96&C2;X<>H3#*|*{9xxJ*}?`9fOx@x3)OpE=*oo{ z&7~kZjKDgCIKlA}pOsk>rHT@GLSkGZ9E?H&TudB{K+FWh%s|W%C59GEY+NiH+(HaV P(p-$@xHLI22?ziH9D9cK literal 0 HcmV?d00001 diff --git a/burn-import/onnx-tests/tests/mul/mul.py b/burn-import/onnx-tests/tests/mul/mul.py new file mode 100755 index 000000000..32a1c168d --- /dev/null +++ b/burn-import/onnx-tests/tests/mul/mul.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/add/add.onnx + +import torch +import torch.nn as nn + + +class Model(nn.Module): + def __init__(self): + # Declare a constant float tensor + self.a = torch.full((1, 1, 1, 4), 3.0) + + # Declare a scalar + self.b = 7.0 + super(Model, self).__init__() + + def forward(self, x, k): + + # Multiply the input by the constant tensor + x = x * self.a + + # Multiply the input scalar by the constant scalar + d = k * self.b + + # Multiply the result of the previous multiplications + x = x * d + + return x + + +def main(): + + # Export to onnx + 
model = Model()
+    model.eval()
+    device = torch.device("cpu")
+    onnx_name = "mul.onnx"
+    dummy_input = torch.randn(1, 2, 3, 4, device=device)
+
+    scalar = 6.0
+
+    torch.onnx.export(model, (dummy_input, scalar), onnx_name,
+                      verbose=False, opset_version=16)
+
+    print("Finished exporting model to {}".format(onnx_name))
+
+    # Output some test data for use in the test
+    test_input = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])
+
+    print("Test input data: {}, {}".format(test_input, scalar))
+    output = model.forward(test_input, scalar)
+    print("Test output data: {}".format(output))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/burn-import/onnx-tests/tests/onnx_tests.rs b/burn-import/onnx-tests/tests/onnx_tests.rs
new file mode 100644
index 000000000..b95a8de31
--- /dev/null
+++ b/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -0,0 +1,128 @@
+pub mod add {
+    include!(concat!(env!("OUT_DIR"), "/model/add.rs"));
+}
+
+pub mod sub {
+    include!(concat!(env!("OUT_DIR"), "/model/sub.rs"));
+}
+
+pub mod mul {
+    include!(concat!(env!("OUT_DIR"), "/model/mul.rs"));
+}
+
+pub mod div {
+    include!(concat!(env!("OUT_DIR"), "/model/div.rs"));
+}
+
+pub mod concat {
+    include!(concat!(env!("OUT_DIR"), "/model/concat.rs"));
+}
+
+pub mod conv2d {
+    include!(concat!(env!("OUT_DIR"), "/model/conv2d.rs"));
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    use burn::tensor::{Data, Shape, Tensor};
+
+    use float_cmp::ApproxEq;
+
+    type Backend = burn_ndarray::NdArrayBackend<f32>;
+
+    #[test]
+    fn add_scalar_to_tensor_and_tensor_to_tensor() {
+        // Initialize the model with weights (loaded from the exported file)
+        let model: add::Model<Backend> = add::Model::default();
+
+        // Run the model
+        let input = Tensor::<Backend, 4>::from_floats([[[[1., 2., 3., 4.]]]]);
+        let scalar = 2f64;
+        let output = model.forward(input, scalar);
+        let expected = Data::from([[[[9., 10., 11., 12.]]]]);
+
+        assert_eq!(output.to_data(), expected);
+    }
+
+    #[test]
+    fn sub_scalar_from_tensor_and_tensor_from_tensor() {
+        // Initialize the model with weights (loaded from the exported file)
+        let model: sub::Model<Backend> = sub::Model::default();
+
+        // Run the model
+        let input = Tensor::<Backend, 4>::from_floats([[[[1., 2., 3., 4.]]]]);
+        let scalar = 3.0f64;
+        let output = model.forward(input, scalar);
+        let expected = Data::from([[[[6., 7., 8., 9.]]]]);
+
+        assert_eq!(output.to_data(), expected);
+    }
+
+    #[test]
+    fn mul_scalar_with_tensor_and_tensor_with_tensor() {
+        // Initialize the model with weights (loaded from the exported file)
+        let model: mul::Model<Backend> = mul::Model::default();
+
+        // Run the model
+        let input = Tensor::<Backend, 4>::from_floats([[[[1., 2., 3., 4.]]]]);
+        let scalar = 6.0f64;
+        let output = model.forward(input, scalar);
+        let expected = Data::from([[[[126., 252., 378., 504.]]]]);
+
+        assert_eq!(output.to_data(), expected);
+    }
+
+    #[test]
+    fn div_tensor_by_scalar_and_tensor_by_tensor() {
+        // Initialize the model without weights (because the exported file does not contain them)
+        let model: div::Model<Backend> = div::Model::new();
+
+        // Run the model
+        let input = Tensor::<Backend, 4>::from_floats([[[[3., 6., 6., 9.]]]]);
+        let scalar1 = 9.0f64;
+        let scalar2 = 3.0f64;
+        let output = model.forward(input, scalar1, scalar2);
+        let expected = Data::from([[[[1., 2., 2., 3.]]]]);
+
+        assert_eq!(output.to_data(), expected);
+    }
+
+    #[test]
+    fn concat_tensors() {
+        // Initialize the model
+        let model: concat::Model<Backend> = concat::Model::new();
+
+        // Run the model
+        let input = Tensor::<Backend, 4>::zeros([1, 2, 3, 5]);
+
+        let output = model.forward(input);
+
+        let expected = Shape::from([1, 18, 3, 5]);
+
assert_eq!(output.shape(), expected); + } + + #[test] + fn conv2d() { + // Initialize the model with weights (loaded from the exported file) + let model: conv2d::Model = conv2d::Model::default(); + + // Run the model with ones as input for easier testing + let input = Tensor::::ones([2, 4, 10, 15]); + + let output = model.forward(input); + + let expected_shape = Shape::from([2, 6, 6, 15]); + assert_eq!(output.shape().clone(), expected_shape); + + // We are using the sum of the output tensor to test the correctness of the conv2d node + // because the output tensor is too large to compare with the expected tensor. + let output_sum = output.sum().into_scalar(); + + let expected_sum = 24.004_995; // from pytorch + + assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2))); + } +} diff --git a/burn-import/onnx-tests/tests/sub/sub.onnx b/burn-import/onnx-tests/tests/sub/sub.onnx new file mode 100644 index 0000000000000000000000000000000000000000..7ffdfc8083cccc46a0c68316488cd8034846bc1a GIT binary patch literal 519 zcmd=)ivXjS00TpVJu+Uv$iijK#hss*S7BupT$&Vbz=g#IAzpou z9Na<{5LbyANN|~P@jx^?Cl;5)8w&C1gBft8tYE$p3rJLplZ&Y&UxIM~BNLaikO z@rDSaxsgPXd?Nz$4F}K{EDQ_+jv$x2aG_bph1F0Yez5bAY+-{4K)hhUh3Y*nbmc;f z=28$HMqnL6oZxtg&&n)`Qbh?oAu%oy4n`pXE+!5}AZ7w$W*}yX5 { default: Option, blank_spaces: bool, gen_new_fn: bool, + graph_input_types: Vec, + graph_output_types: Vec, } impl BurnGraph { @@ -163,20 +165,22 @@ impl BurnGraph { fn build_scope(&mut self) { log::debug!("Building the scope nodes len => '{}'", self.nodes.len()); - let input = self.nodes.first().unwrap(); - - fn to_tensor(ty: Type<'_>) -> Option<&TensorType> { + fn to_tensor(ty: Type) -> Option { match ty { - Type::Tensor(tensor) => Some(tensor), + Type::Tensor(tensor) => Some(tensor.clone()), + Type::Scalar(_) => None, Type::Other(_) => None, } } - input - .input_types() + // Register graph tensor input with 0 as node position + self.graph_input_types + .clone() .into_iter() .flat_map(to_tensor) - .for_each(|tensor| self.scope.tensor_register_variable(tensor, 0)); + .for_each(|tensor| { + self.scope.tensor_register_variable(&tensor, 0); + }); self.nodes .iter() @@ -187,7 +191,7 @@ impl BurnGraph { .flat_map(to_tensor) .for_each(|tensor| { self.scope - .tensor_register_variable(tensor, node_position + 1) + .tensor_register_variable(&tensor, node_position + 1) }) }); @@ -198,7 +202,10 @@ impl BurnGraph { node.input_types() .into_iter() .flat_map(to_tensor) - .for_each(|tensor| self.scope.tensor_register_future_use(tensor, node_position)) + .for_each(|tensor| { + self.scope + .tensor_register_future_use(&tensor, node_position) + }) }); } @@ -240,16 +247,33 @@ impl BurnGraph { let name = field.name(); let ty = field.ty(); - quote! { - #name: #ty, + if matches!(&field, Type::Tensor(_)) { + quote! { + #name: burn::module::Param<#ty>, + } + } else { + quote! { + #name: #ty, + } } }) .for_each(|code| body.extend(code)); - quote! { - #[derive(Module, Debug)] - pub struct Model { - #body + // Add dummy field if no field is present to avoid empty struct + // and make sure we can derive Module trait and use it in a model. + if body.is_empty() { + quote! { + #[derive(Module, Debug)] + pub struct Model { + _phantom: core::marker::PhantomData, + } + } + } else { + quote! { + #[derive(Module, Debug)] + pub struct Model { + #body + } } } } @@ -269,13 +293,24 @@ impl BurnGraph { .map(|field| field.name().clone()) .collect::>(); - quote! { - #[allow(dead_code)] - pub fn new() -> Self { - #body + if fields.is_empty() { + quote! 
{ + #[allow(dead_code)] + pub fn new() -> Self { + Self { + _phantom: core::marker::PhantomData, + } + } + } + } else { + quote! { + #[allow(dead_code)] + pub fn new() -> Self { + #body - Self { - #(#fields,)* + Self { + #(#fields,)* + } } } } @@ -295,12 +330,22 @@ impl BurnGraph { .map(|field| field.name().clone()) .collect::>(); - quote! { - pub fn new_with(record: ModelRecord) -> Self { - #body + if fields.is_empty() { + quote! { + pub fn new_with(_record: ModelRecord) -> Self { + Self { + _phantom: core::marker::PhantomData, + } + } + } + } else { + quote! { + pub fn new_with(record: ModelRecord) -> Self { + #body - Self { - #(#fields,)* + Self { + #(#fields,)* + } } } } @@ -311,26 +356,19 @@ impl BurnGraph { let mut output_type_def = quote! {}; let mut output_return_def = quote! {}; - self.nodes - .first() - .unwrap() - .input_types() - .into_iter() - .for_each(|input| { - let name = input.name(); - let ty = input.ty(); + self.graph_input_types.iter().for_each(|input| { + let name = input.name().clone(); + let ty = input.ty().clone(); - input_def.extend(quote! { - #name: #ty, + input_def.extend(quote! { + #name: #ty, - }) - }); + }) + }); - let output_types = self.nodes.last().unwrap().output_types(); + let multiple_output = self.graph_output_types.len() > 1; - let multiple_output = output_types.len() > 1; - - output_types.into_iter().for_each(|output| { + self.graph_output_types.iter().for_each(|output| { let name = output.name(); let ty = output.ty(); @@ -379,6 +417,48 @@ impl BurnGraph { } } } + + /// Register the input and output types of the graph using the passed in names. + /// The names must be unique and match the names of the inputs and outputs of the nodes. + /// The order will be preserved. + /// + /// # Arguments + /// + /// * `input_names` - The names of the inputs of the graph. + /// * `output_names` - The names of the outputs of the graph. + /// + /// # Panics + /// + /// Panics if the graph is empty. + pub fn register_input_output(&mut self, input_names: Vec, output_names: Vec) { + assert!( + !self.nodes.is_empty(), + "Cannot register input and output types for an empty graph." 
+ ); + + // Get the unique names of each input of the nodes + let mut inputs = HashMap::new(); + let mut outputs = HashMap::new(); + for node in self.nodes.iter() { + for input in node.input_types() { + inputs.insert(input.name().to_string(), input); + } + for output in node.output_types() { + outputs.insert(output.name().to_string(), output); + } + } + + // Get the input and output types of the graph using passed in names + input_names.iter().for_each(|input| { + self.graph_input_types + .push(inputs.get(input).unwrap().clone()); + }); + + output_names.iter().for_each(|output| { + self.graph_output_types + .push(outputs.get(output).unwrap().clone()); + }); + } } #[derive(new)] diff --git a/burn-import/src/burn/node/base.rs b/burn-import/src/burn/node/base.rs index 5783d2d13..1c31fbfc3 100644 --- a/burn-import/src/burn/node/base.rs +++ b/burn-import/src/burn/node/base.rs @@ -77,7 +77,7 @@ pub enum Node { MaxPool2d(MaxPool2dNode), Linear(LinearNode), BatchNorm(BatchNormNode), - Constant(ConstantNode), + Constant(ConstantNode), Unary(UnaryNode), Reshape(ReshapeNode), Concat(ConcatNode), @@ -174,7 +174,6 @@ impl NodeCodegen for Node { #[cfg(test)] pub(crate) mod tests { use crate::burn::{ - codegen::ToTokens, graph::BurnGraph, node::{conv2d::Conv2dNode, matmul::MatmulNode, test::assert_tokens, NodeCodegen}, TensorType, @@ -185,14 +184,18 @@ pub(crate) mod tests { use proc_macro2::TokenStream; use quote::quote; - fn one_node_graph + 'static>( + pub(crate) fn one_node_graph + 'static>( node_gen: T, forward: TokenStream, + input_names: Vec, + output_names: Vec, ) { let mut graph = BurnGraph::::default(); graph.register(node_gen); + graph.register_input_output(input_names, output_names); + let expected = quote! { use burn::{ module::Module, @@ -200,11 +203,15 @@ pub(crate) mod tests { }; #[derive(Module, Debug)] - pub struct Model {} + pub struct Model { + _phantom: core::marker::PhantomData, + } impl Model { - pub fn new_with(record: ModelRecord) -> Self { - Self { } + pub fn new_with(_record: ModelRecord) -> Self { + Self { + _phantom: core::marker::PhantomData, + } } #[allow(clippy::let_and_return)] @@ -215,42 +222,6 @@ pub(crate) mod tests { assert_tokens(graph.codegen(), expected); } - pub(crate) fn codegen_unary_operator< - const N: usize, - T: NodeCodegen + 'static, - >( - node_gen: T, - function: TokenStream, - ) { - let forward = |function, tensor_dim| { - quote! { - pub fn forward(&self, tensor1: Tensor) -> Tensor { - #function - } - } - }; - - one_node_graph(node_gen, forward(function, N.to_tokens())); - } - - pub(crate) fn codegen_binary_operator< - const N: usize, - T: NodeCodegen + 'static, - >( - node_gen: T, - function: TokenStream, - ) { - let forward = |function, tensor_dim| { - quote! { - pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { - #function - } - } - }; - - one_node_graph(node_gen, forward(function, N.to_tokens())); - } - #[test] fn test_codegen_two_nodes() { let mut graph = BurnGraph::::default(); @@ -269,6 +240,11 @@ pub(crate) mod tests { Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid), )); + graph.register_input_output( + vec!["tensor1".to_string(), "tensor2".to_string()], + vec!["tensor4".to_string()], + ); + let expected = quote! { use burn::{ module::Module, @@ -333,6 +309,11 @@ pub(crate) mod tests { TensorType::new_float("output", 4), )); + graph.register_input_output( + vec!["tensor1".to_string(), "tensor2".to_string()], + vec!["output".to_string()], + ); + let expected = quote! 
{ use burn::{ module::Module, diff --git a/burn-import/src/burn/node/batch_norm.rs b/burn-import/src/burn/node/batch_norm.rs index 7b31bb126..246f5afa7 100644 --- a/burn-import/src/burn/node/batch_norm.rs +++ b/burn-import/src/burn/node/batch_norm.rs @@ -102,13 +102,13 @@ macro_rules! batch_norm_serialize { impl NodeCodegen for BatchNormNode { fn input_types(&self) -> Vec { - vec![Type::Tensor(&self.input)] + vec![Type::Tensor(self.input.clone())] } fn output_types(&self) -> Vec { - vec![Type::Tensor(&self.output)] + vec![Type::Tensor(self.output.clone())] } fn field_type(&self) -> Option { - Some(Type::Other(&self.field)) + Some(Type::Other(self.field.clone())) } fn field_init(&self, with_record: bool) -> Option { @@ -181,6 +181,8 @@ mod tests { BatchNormConfig::new(128), )); + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + let expected = quote! { use burn::{ module::Module, diff --git a/burn-import/src/burn/node/binary.rs b/burn-import/src/burn/node/binary.rs index da497e098..b7517a8b3 100644 --- a/burn-import/src/burn/node/binary.rs +++ b/burn-import/src/burn/node/binary.rs @@ -1,5 +1,5 @@ use super::{Node, NodeCodegen}; -use crate::burn::{Scope, TensorType, Type}; +use crate::burn::{Scope, Type}; use burn::record::PrecisionSettings; use proc_macro2::TokenStream; use quote::quote; @@ -32,9 +32,9 @@ type FnPointer = Arc TokenStream>; /// Node for all binary operators. #[derive(Clone, new)] pub struct BinaryNode { - pub lhs: TensorType, - pub rhs: TensorType, - pub output: TensorType, + pub lhs: Type, + pub rhs: Type, + pub output: Type, pub binary_type: BinaryType, function: FnPointer, } @@ -56,17 +56,35 @@ impl std::fmt::Debug for BinaryNode { impl NodeCodegen for BinaryNode { fn output_types(&self) -> Vec { - vec![Type::Tensor(&self.output)] + vec![self.output.clone()] } fn input_types(&self) -> Vec { - vec![Type::Tensor(&self.lhs), Type::Tensor(&self.rhs)] + vec![self.lhs.clone(), self.rhs.clone()] } fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let lhs = scope.tensor_use_owned(&self.lhs, node_position); - let rhs = scope.tensor_use_owned(&self.rhs, node_position); - let output = &self.output.name; + // Get the lhs name in the form of token stream. + let lhs = match &self.lhs { + Type::Tensor(tensor) => scope.tensor_use_owned(tensor, node_position), + Type::Scalar(scalar) => { + let name = scalar.name.clone(); + quote! { #name } + } + _ => panic!("lhs must be a tensor or scalar"), + }; + + // Get the rhs name in the form of token stream + let rhs = match &self.rhs { + Type::Tensor(tensor) => scope.tensor_use_owned(tensor, node_position), + Type::Scalar(scalar) => { + let name = scalar.name.clone(); + quote! { #name } + } + _ => panic!("rhs must be a tensor or scalar"), + }; + + let output = &self.output.name(); let function = (self.function)(lhs, rhs); quote! { @@ -80,27 +98,53 @@ impl NodeCodegen for BinaryNode { } impl BinaryNode { - pub(crate) fn add(lhs: TensorType, rhs: TensorType, output: TensorType) -> Self { - let function = move |lhs, rhs| quote! { #lhs.add(#rhs) }; + pub(crate) fn add(lhs: Type, rhs: Type, output: Type) -> Self { + let function = match (&lhs, &rhs) { + (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.add(#rhs) }, + (Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.add_scalar(#rhs) }, + (Type::Scalar(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #rhs.add_scalar(#lhs) }, + (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! 
{ #lhs + #rhs }, + _ => panic!("Addition is supported for tensor and scalar only"), + }; + Self::new(lhs, rhs, output, BinaryType::Add, Arc::new(function)) } - pub(crate) fn sub(lhs: TensorType, rhs: TensorType, output: TensorType) -> Self { - let function = move |lhs, rhs| quote! { #lhs.sub(#rhs) }; + pub(crate) fn sub(lhs: Type, rhs: Type, output: Type) -> Self { + let function = match (&lhs, &rhs) { + (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.sub(#rhs) }, + (Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.sub_scalar(#rhs) }, + (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs - #rhs }, + _ => panic!("Subtraction is supported for tensor and scalar only"), + }; + Self::new(lhs, rhs, output, BinaryType::Sub, Arc::new(function)) } - pub(crate) fn mul(lhs: TensorType, rhs: TensorType, output: TensorType) -> Self { - let function = move |lhs, rhs| quote! { #lhs.mul(#rhs) }; + pub(crate) fn mul(lhs: Type, rhs: Type, output: Type) -> Self { + let function = match (&lhs, &rhs) { + (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.mul(#rhs) }, + (Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.mul_scalar(#rhs) }, + (Type::Scalar(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #rhs.mul_scalar(#lhs) }, + (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs * #rhs }, + _ => panic!("Multiplication is supported for tensor and scalar only"), + }; + Self::new(lhs, rhs, output, BinaryType::Mul, Arc::new(function)) } - pub(crate) fn div(lhs: TensorType, rhs: TensorType, output: TensorType) -> Self { - let function = move |lhs, rhs| quote! { #lhs.div(#rhs) }; + pub(crate) fn div(lhs: Type, rhs: Type, output: Type) -> Self { + let function = match (&lhs, &rhs) { + (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.div(#rhs) }, + (Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.div_scalar(#rhs) }, + (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs / #rhs }, + _ => panic!("Division is supported for tensor and scalar only"), + }; + Self::new(lhs, rhs, output, BinaryType::Div, Arc::new(function)) } - pub(crate) fn equal(lhs: TensorType, rhs: TensorType, output: TensorType) -> Self { + pub(crate) fn equal(lhs: Type, rhs: Type, output: Type) -> Self { let function = move |lhs, rhs| quote! { #lhs.equal(#rhs) }; Self::new(lhs, rhs, output, BinaryType::Equal, Arc::new(function)) } @@ -110,48 +154,93 @@ impl BinaryNode { mod tests { use super::*; - use crate::burn::node::tests::codegen_binary_operator; - use crate::burn::TensorType; + use crate::burn::node::tests::one_node_graph; + use crate::burn::{ScalarKind, ScalarType, TensorType}; - macro_rules! test_binary_operator { + macro_rules! test_binary_operator_on_tensors { ($operator:ident) => {{ - codegen_binary_operator::<4, _>( + one_node_graph( BinaryNode::$operator( - TensorType::new_float("tensor1", 4), - TensorType::new_float("tensor2", 4), - TensorType::new_float("tensor3", 4), + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + Type::Tensor(TensorType::new_float("tensor3", 4)), ), quote! { - let tensor3 = tensor1.$operator(tensor2); + pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { + let tensor3 = tensor1.$operator(tensor2); - tensor3 + tensor3 + } }, + vec!["tensor1".to_string(), "tensor2".to_string()], + vec!["tensor3".to_string()], + ); + }}; + } + + macro_rules! 
test_binary_operator_on_tensor_and_scalar { + ($operator:ident, $burn_operator:ident) => {{ + one_node_graph( + BinaryNode::$operator( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float32)), + Type::Tensor(TensorType::new_float("tensor3", 4)), + ), + quote! { + pub fn forward(&self, scalar1: f32, tensor1: Tensor) -> Tensor { + let tensor3 = tensor1.$burn_operator(scalar1); + + tensor3 + } + }, + vec!["scalar1".to_string(), "tensor1".to_string()], + vec!["tensor3".to_string()], ); }}; } #[test] fn test_binary_codegen_add() { - test_binary_operator!(add); + test_binary_operator_on_tensors!(add); + } + + #[test] + fn test_binary_codegen_add_scalar() { + test_binary_operator_on_tensor_and_scalar!(add, add_scalar); } #[test] fn test_binary_codegen_sub() { - test_binary_operator!(sub); + test_binary_operator_on_tensors!(sub); + } + + #[test] + fn test_binary_codegen_sub_scalar() { + test_binary_operator_on_tensor_and_scalar!(sub, sub_scalar); } #[test] fn test_binary_codegen_mul() { - test_binary_operator!(mul); + test_binary_operator_on_tensors!(mul); + } + + #[test] + fn test_binary_codegen_mul_scalar() { + test_binary_operator_on_tensor_and_scalar!(mul, mul_scalar); } #[test] fn test_binary_codegen_div() { - test_binary_operator!(div); + test_binary_operator_on_tensors!(div); + } + + #[test] + fn test_binary_codegen_div_scalar() { + test_binary_operator_on_tensor_and_scalar!(div, div_scalar); } #[test] fn test_binary_codegen_equal() { - test_binary_operator!(equal); + test_binary_operator_on_tensors!(equal); } } diff --git a/burn-import/src/burn/node/concat.rs b/burn-import/src/burn/node/concat.rs index 6169d45ca..fcadc7f47 100644 --- a/burn-import/src/burn/node/concat.rs +++ b/burn-import/src/burn/node/concat.rs @@ -14,11 +14,14 @@ pub struct ConcatNode { impl NodeCodegen for ConcatNode { fn output_types(&self) -> Vec { - vec![Type::Tensor(&self.output)] + vec![Type::Tensor(self.output.clone())] } fn input_types(&self) -> Vec { - self.inputs.iter().map(Type::Tensor).collect() + self.inputs + .iter() + .map(|t| Type::Tensor(t.clone())) + .collect() } fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { @@ -65,6 +68,11 @@ mod tests { 1, )); + graph.register_input_output( + vec!["tensor1".to_string(), "tensor2".to_string()], + vec!["tensor3".to_string()], + ); + let expected = quote! 
{ use burn::{ module::Module, @@ -72,12 +80,17 @@ mod tests { }; #[derive(Module, Debug)] - pub struct Model {} + pub struct Model { + _phantom: core::marker::PhantomData, + } impl Model { - pub fn new_with(record: ModelRecord) -> Self { - Self { } + pub fn new_with(_record: ModelRecord) -> Self { + Self { + _phantom: core::marker::PhantomData, + } } + #[allow(clippy::let_and_return)] pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { let tensor3 = burn::tensor::Tensor::cat(vec![tensor1, tensor2], 1); diff --git a/burn-import/src/burn/node/constant.rs b/burn-import/src/burn/node/constant.rs index f5968e97c..1f48e0d5c 100644 --- a/burn-import/src/burn/node/constant.rs +++ b/burn-import/src/burn/node/constant.rs @@ -1,72 +1,177 @@ use super::{Node, NodeCodegen}; -use crate::burn::{OtherType, Scope, Type}; -use burn::record::PrecisionSettings; +use crate::burn::{ScalarKind, ScalarType, Scope, TensorType, ToTokens, Type}; +use burn::{ + module::ParamId, + record::{ParamSerde, PrecisionSettings}, + tensor::DataSerialize, +}; use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; +use serde::Serialize; #[derive(Debug, Clone)] -pub struct ConstantNode { +pub struct ConstantNode { pub name: String, - pub value: ConstantValue, - output_ty: OtherType, + pub value: ConstantValue, + pub output: Type, +} + +#[derive(Debug, Clone)] +pub enum TensorValue { + Float(DataSerialize), + Int(DataSerialize), } #[derive(Debug, Clone, new)] -pub enum ConstantValue { - Int(i32), - Float(f32), - Bool(bool), +pub enum ConstantValue { + /// Float constant. + Float32(f32), + Float64(f64), + + /// Integer constant. + Int32(i32), + Int64(i64), + + /// Tensor constant. + Tensor(TensorType, TensorValue), } -impl ConstantValue { +impl ConstantValue { pub fn ty_tokens(&self) -> TokenStream { match self { - ConstantValue::Int(_) => quote! { i32 }, - ConstantValue::Float(_) => quote! { f32 }, - ConstantValue::Bool(_) => quote! { bool }, + ConstantValue::Float32(_) => quote! { f32 }, + ConstantValue::Float64(_) => quote! { f64 }, + ConstantValue::Int32(_) => quote! { i32 }, + ConstantValue::Int64(_) => quote! { i64 }, + ConstantValue::Tensor(tensor_type, _) => { + let ty = tensor_type.ty(); + quote! { burn::module::Param<#ty>} + } } } pub fn val_tokens(&self) -> TokenStream { match self { - ConstantValue::Int(val) => quote! { #val }, - ConstantValue::Float(val) => quote! { #val }, - ConstantValue::Bool(val) => quote! { #val }, + ConstantValue::Float32(val) => quote! { #val }, + ConstantValue::Float64(val) => quote! { #val }, + ConstantValue::Int32(val) => quote! { #val }, + ConstantValue::Int64(val) => quote! 
{ #val }, + ConstantValue::Tensor(_, _) => { + panic!("Tensor constant is not assignable.") + } } } } -impl ConstantNode { - pub fn new(name: String, value: ConstantValue) -> Self { - let output_ty = OtherType::new(name.clone(), value.ty_tokens()); - +impl ConstantNode { + pub fn new(name: String, value: ConstantValue, output: Type) -> Self { Self { name, value, - output_ty, + output, + } + } + pub fn constant_value_into_type(&self) -> Type { + let name = Ident::new(self.name.as_str(), Span::call_site()); + match &self.value { + ConstantValue::Float32(_) => Type::Scalar(ScalarType { + name, + kind: ScalarKind::Float32, + }), + ConstantValue::Float64(_) => Type::Scalar(ScalarType { + name, + kind: ScalarKind::Float64, + }), + ConstantValue::Int32(_) => Type::Scalar(ScalarType { + name, + kind: ScalarKind::Int32, + }), + ConstantValue::Int64(_) => Type::Scalar(ScalarType { + name, + kind: ScalarKind::Int64, + }), + ConstantValue::Tensor(tensor_type, _) => Type::Tensor(tensor_type.clone()), } } } -impl NodeCodegen for ConstantNode { +impl NodeCodegen for ConstantNode { fn output_types(&self) -> Vec { - vec![Type::Other(&self.output_ty)] + vec![self.output.clone()] } fn input_types(&self) -> Vec { vec![] } + fn field_type(&self) -> Option { + match &self.value { + ConstantValue::Tensor(tensor_type, _) => Some(Type::Tensor(tensor_type.clone())), + _ => None, + } + } + + fn field_init(&self, with_record: bool) -> Option { + match &self.value { + ConstantValue::Tensor(tensor_type, _) => { + let ty = tensor_type.ty(); + let name = Ident::new(self.name.as_ref(), Span::call_site()); + let shape = tensor_type.clone().shape.unwrap().to_tokens(); + let dim = tensor_type.clone().dim.to_tokens(); + + if with_record { + Some(quote! { + let #name = record.#name.map(|tensor| tensor.set_require_grad(false)); + }) + } else { + Some(quote! { + let #name: burn::module::Param<#ty> = burn::module::Param::new( + burn::module::ParamId::new(), + Tensor::::zeros(#shape).set_require_grad(false), + ); + }) + } + } + _ => None, + } + } + fn forward(&self, _scope: &mut Scope, _node_position: usize) -> TokenStream { let name = Ident::new(self.name.as_ref(), Span::call_site()); - let val = self.value.val_tokens(); - let ty = self.value.ty_tokens(); + let output = self.output.name(); - quote! { - let #name: #ty = #val; + match &self.value { + ConstantValue::Tensor(_, _) => { + quote! { + let #output = self.#name.val(); + } + } + _ => { + let val = self.value.val_tokens(); + let ty = self.value.ty_tokens(); + + quote! 
{ + let #output: #ty = #val; + } + } } } fn into_node(self) -> Node { Node::Constant(self) } + + fn field_serialize(&self, serializer: S) -> Result { + if let ConstantValue::Tensor(_, ds) = &self.value { + let data: DataSerialize = match ds { + TensorValue::Float(data) => data.clone().convert(), + TensorValue::Int(data) => data.clone().convert(), + }; + let data = ParamSerde::new(ParamId::new().into_string(), data); + return data.serialize(serializer); + } + + S::serialize_none(serializer) + } } + +// TODO add test missing for constant node (@antimora 8/2/2023) diff --git a/burn-import/src/burn/node/conv2d.rs b/burn-import/src/burn/node/conv2d.rs index 7a332f3de..997e43933 100644 --- a/burn-import/src/burn/node/conv2d.rs +++ b/burn-import/src/burn/node/conv2d.rs @@ -47,13 +47,13 @@ impl Conv2dNode { impl NodeCodegen for Conv2dNode { fn input_types(&self) -> Vec { - vec![Type::Tensor(&self.input)] + vec![Type::Tensor(self.input.clone())] } fn output_types(&self) -> Vec { - vec![Type::Tensor(&self.output)] + vec![Type::Tensor(self.output.clone())] } fn field_type(&self) -> Option { - Some(Type::Other(&self.field)) + Some(Type::Other(self.field.clone())) } fn field_init(&self, with_record: bool) -> Option { @@ -154,6 +154,8 @@ mod tests { Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid), )); + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + let expected = quote! { use burn::{ module::Module, diff --git a/burn-import/src/burn/node/linear.rs b/burn-import/src/burn/node/linear.rs index 2ef65c037..d52e2f04f 100644 --- a/burn-import/src/burn/node/linear.rs +++ b/burn-import/src/burn/node/linear.rs @@ -47,14 +47,14 @@ impl LinearNode { impl NodeCodegen for LinearNode { fn input_types(&self) -> Vec { - vec![Type::Tensor(&self.input)] + vec![Type::Tensor(self.input.clone())] } fn output_types(&self) -> Vec { - vec![Type::Tensor(&self.output)] + vec![Type::Tensor(self.output.clone())] } fn field_type(&self) -> Option { - Some(Type::Other(&self.field)) + Some(Type::Other(self.field.clone())) } fn field_init(&self, with_record: bool) -> Option { @@ -136,6 +136,8 @@ mod tests { LinearConfig::new(128, 128), )); + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + let expected = quote! { use burn::{ module::Module, diff --git a/burn-import/src/burn/node/matmul.rs b/burn-import/src/burn/node/matmul.rs index 77c47bbd4..374d2da0f 100644 --- a/burn-import/src/burn/node/matmul.rs +++ b/burn-import/src/burn/node/matmul.rs @@ -13,11 +13,14 @@ pub struct MatmulNode { impl NodeCodegen for MatmulNode { fn output_types(&self) -> Vec { - vec![Type::Tensor(&self.output)] + vec![Type::Tensor(self.output.clone())] } fn input_types(&self) -> Vec { - vec![Type::Tensor(&self.lhs), Type::Tensor(&self.rhs)] + vec![ + Type::Tensor(self.lhs.clone()), + Type::Tensor(self.rhs.clone()), + ] } fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { @@ -57,6 +60,11 @@ mod tests { TensorType::new_float("tensor3", 4), )); + graph.register_input_output( + vec!["tensor1".to_string(), "tensor2".to_string()], + vec!["tensor3".to_string()], + ); + let expected = quote! 
{ use burn::{ module::Module, @@ -64,12 +72,17 @@ mod tests { }; #[derive(Module, Debug)] - pub struct Model {} + pub struct Model { + _phantom: core::marker::PhantomData, + } impl Model { - pub fn new_with(record: ModelRecord) -> Self { - Self { } + pub fn new_with(_record: ModelRecord) -> Self { + Self { + _phantom: core::marker::PhantomData, + } } + #[allow(clippy::let_and_return)] pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { let tensor3 = tensor1.matmul(tensor2); diff --git a/burn-import/src/burn/node/max_pool2d.rs b/burn-import/src/burn/node/max_pool2d.rs index 9e75864ea..5cfa562f8 100644 --- a/burn-import/src/burn/node/max_pool2d.rs +++ b/burn-import/src/burn/node/max_pool2d.rs @@ -37,13 +37,13 @@ impl MaxPool2dNode { impl NodeCodegen for MaxPool2dNode { fn input_types(&self) -> Vec { - vec![Type::Tensor(&self.input)] + vec![Type::Tensor(self.input.clone())] } fn output_types(&self) -> Vec { - vec![Type::Tensor(&self.output)] + vec![Type::Tensor(self.output.clone())] } fn field_type(&self) -> Option { - Some(Type::Other(&self.field)) + Some(Type::Other(self.field.clone())) } fn field_init(&self, _with_record: bool) -> Option { @@ -109,6 +109,8 @@ mod tests { .with_padding(PaddingConfig2d::Valid), )); + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + let expected = quote! { use burn::{ module::Module, diff --git a/burn-import/src/burn/node/reshape.rs b/burn-import/src/burn/node/reshape.rs index bfb6354ae..578d178ee 100644 --- a/burn-import/src/burn/node/reshape.rs +++ b/burn-import/src/burn/node/reshape.rs @@ -13,11 +13,11 @@ pub struct ReshapeNode { impl NodeCodegen for ReshapeNode { fn output_types(&self) -> Vec { - vec![Type::Tensor(&self.output)] + vec![Type::Tensor(self.output.clone())] } fn input_types(&self) -> Vec { - vec![Type::Tensor(&self.input)] + vec![Type::Tensor(self.input.clone())] } fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { @@ -56,6 +56,8 @@ mod tests { [4, 4, 4, 4].into(), )); + graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]); + let expected = quote! { use burn::{ module::Module, @@ -63,11 +65,15 @@ mod tests { }; #[derive(Module, Debug)] - pub struct Model {} + pub struct Model { + _phantom: core::marker::PhantomData, + } impl Model { - pub fn new_with(record: ModelRecord) -> Self { - Self { } + pub fn new_with(_record: ModelRecord) -> Self { + Self { + _phantom: core::marker::PhantomData, + } } #[allow(clippy::let_and_return)] pub fn forward(&self, tensor1: Tensor) -> Tensor { diff --git a/burn-import/src/burn/node/unary.rs b/burn-import/src/burn/node/unary.rs index a0a30417c..1409d89be 100644 --- a/burn-import/src/burn/node/unary.rs +++ b/burn-import/src/burn/node/unary.rs @@ -1,5 +1,5 @@ use super::{Node, NodeCodegen}; -use crate::burn::{Scope, TensorType, ToTokens, Type}; +use crate::burn::{Scope, ToTokens, Type}; use burn::record::PrecisionSettings; use proc_macro2::TokenStream; use quote::quote; @@ -11,8 +11,8 @@ type FnPointer = Arc TokenStream>; /// Node for all unary operators. #[derive(Clone, new)] pub struct UnaryNode { - pub input: TensorType, - pub output: TensorType, + pub input: Type, + pub output: Type, pub kind: UnaryNodeKind, function: FnPointer, } @@ -20,20 +20,22 @@ pub struct UnaryNode { /// Type of unary node. 
#[derive(Clone)] pub enum UnaryNodeKind { + Cast, Flatten, + LogSoftmax, Relu, Sigmoid, - LogSoftmax, Transpose, } impl UnaryNodeKind { pub fn as_str(&self) -> &str { match self { + Self::Cast => "cast", Self::Flatten => "flatten", + Self::LogSoftmax => "log_softmax", Self::Relu => "relu", Self::Sigmoid => "sigmoid", - Self::LogSoftmax => "log_softmax", Self::Transpose => "transpose", } } @@ -55,16 +57,26 @@ impl std::fmt::Debug for UnaryNode { impl NodeCodegen for UnaryNode { fn output_types(&self) -> Vec { - vec![Type::Tensor(&self.output)] + vec![self.output.clone()] } fn input_types(&self) -> Vec { - vec![Type::Tensor(&self.input)] + vec![self.input.clone()] } fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { - let input = scope.tensor_use_owned(&self.input, node_position); - let output = &self.output.name; + // Get the lhs name in the form of token stream. + let input = match &self.input { + Type::Tensor(tensor) => scope.tensor_use_owned(tensor, node_position), + Type::Scalar(scalar) => { + let name = scalar.name.clone(); + quote! { #name } + } + _ => panic!("lhs must be a tensor or scalar"), + }; + + // let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name(); let function = (self.function)(input); quote! { @@ -78,12 +90,7 @@ impl NodeCodegen for UnaryNode { } impl UnaryNode { - pub(crate) fn flatten( - input: TensorType, - output: TensorType, - start_dim: usize, - end_dim: usize, - ) -> Self { + pub(crate) fn flatten(input: Type, output: Type, start_dim: usize, end_dim: usize) -> Self { let start_dim = start_dim.to_tokens(); let end_dim = end_dim.to_tokens(); let function = move |input| quote! { #input.flatten(#start_dim, #end_dim) }; @@ -91,109 +98,193 @@ impl UnaryNode { Self::new(input, output, UnaryNodeKind::Flatten, Arc::new(function)) } - pub(crate) fn relu(input: TensorType, output: TensorType) -> Self { + pub(crate) fn relu(input: Type, output: Type) -> Self { let function = move |input| quote! { burn::tensor::activation::relu(#input) }; Self::new(input, output, UnaryNodeKind::Relu, Arc::new(function)) } - pub(crate) fn sigmoid(input: TensorType, output: TensorType) -> Self { + pub(crate) fn sigmoid(input: Type, output: Type) -> Self { let function = move |input| quote! { burn::tensor::activation::sigmoid(#input) }; Self::new(input, output, UnaryNodeKind::Sigmoid, Arc::new(function)) } - pub(crate) fn log_softmax(input: TensorType, output: TensorType, dim: usize) -> Self { + pub(crate) fn log_softmax(input: Type, output: Type, dim: usize) -> Self { let dim = dim.to_tokens(); let function = move |input| quote! { burn::tensor::activation::log_softmax(#input, #dim) }; Self::new(input, output, UnaryNodeKind::LogSoftmax, Arc::new(function)) } - pub(crate) fn transpose(input: TensorType, output: TensorType) -> Self { + pub(crate) fn transpose(input: Type, output: Type) -> Self { let function = move |input| quote! { #input.transpose() }; Self::new(input, output, UnaryNodeKind::Transpose, Arc::new(function)) } + + /// Casts the input to the output type. + /// + /// Currently this function only supports the following conversions: + /// 1) scalar -> scalar + /// + /// TODO: Implement the following conversions: + /// 2) tensor int -> tensor float + /// 3) tensor float -> tensor int + /// 4) tensor -> scalar + /// 5) scalar -> tensor + pub(crate) fn cast(input: Type, output: Type) -> Self { + let function = match output.clone() { + Type::Scalar(scalar) => { + let ty = scalar.ty(); + move |input| quote! 
{ #input as #ty } + } + Type::Tensor(_tensor) => { + // TODO: Implement this after tensor Int is implemented (@antimora 8/2/2023) + // TODO: If the input is scalar and the output type is a tensor, + // we should generate another code block. (@antimora 8/4/2023) + // Tensor::from_data(Data::from([#input]).convert()).unsqueeze(); + todo!() + } + + _ => panic!("output must be a tensor"), + }; + + Self::new(input, output, UnaryNodeKind::Cast, Arc::new(function)) + } } #[cfg(test)] mod tests { use super::*; - use crate::burn::node::tests::codegen_unary_operator; - use crate::burn::TensorType; + use crate::burn::node::tests::one_node_graph; + use crate::burn::{ScalarKind, ScalarType, TensorType}; #[test] fn test_unary_codegen_flatten() { - codegen_unary_operator::<4, _>( + one_node_graph( UnaryNode::flatten( - TensorType::new_float("tensor1", 4), - TensorType::new_float("tensor2", 4), + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), 1, 2, ), quote! { - let tensor2 = tensor1.flatten(1, 2); + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.flatten(1, 2); - tensor2 + tensor2 + } }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], ); } #[test] fn test_unary_codegen_relu() { - codegen_unary_operator::<4, _>( + one_node_graph( UnaryNode::relu( - TensorType::new_float("tensor1", 4), - TensorType::new_float("tensor2", 4), + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), ), quote! { - let tensor2 = burn::tensor::activation::relu(tensor1); + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = burn::tensor::activation::relu(tensor1); - tensor2 + tensor2 + } }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], ); } #[test] fn test_unary_codegen_sigmoid() { - codegen_unary_operator::<4, _>( + one_node_graph( UnaryNode::sigmoid( - TensorType::new_float("tensor1", 4), - TensorType::new_float("tensor2", 4), + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), ), quote! { - let tensor2 = burn::tensor::activation::sigmoid(tensor1); + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = burn::tensor::activation::sigmoid(tensor1); - tensor2 + tensor2 + } }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], ); } #[test] fn test_unary_codegen_log_softmax() { - codegen_unary_operator::<4, _>( + one_node_graph( UnaryNode::log_softmax( - TensorType::new_float("tensor1", 4), - TensorType::new_float("tensor2", 4), + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), 1, ), quote! { - let tensor2 = burn::tensor::activation::log_softmax(tensor1, 1); + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = burn::tensor::activation::log_softmax(tensor1, 1); - tensor2 + tensor2 + } }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], ); } #[test] fn test_unary_codegen_transpose() { - codegen_unary_operator::<4, _>( + one_node_graph( UnaryNode::transpose( - TensorType::new_float("tensor1", 4), - TensorType::new_float("tensor2", 4), + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), ), quote! 
-                let tensor2 = tensor1.transpose();
+                pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
+                    let tensor2 = tensor1.transpose();

-                tensor2
+                    tensor2
+                }
             },
+            vec!["tensor1".to_string()],
+            vec!["tensor2".to_string()],
         );
     }
+
+    #[test]
+    fn test_unary_codegen_cast() {
+        one_node_graph(
+            UnaryNode::cast(
+                Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float64)),
+                Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float32)),
+            ),
+            quote! {
+                pub fn forward(&self, scalar1: f64) -> f32 {
+                    let scalar2 = scalar1 as f32;
+
+                    scalar2
+                }
+            },
+            vec!["scalar1".to_string()],
+            vec!["scalar2".to_string()],
+        );
+        one_node_graph(
+            UnaryNode::cast(
+                Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float32)),
+                Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float64)),
+            ),
+            quote! {
+                pub fn forward(&self, scalar1: f32) -> f64 {
+                    let scalar2 = scalar1 as f64;
+
+                    scalar2
+                }
+            },
+            vec!["scalar1".to_string()],
+            vec!["scalar2".to_string()],
+        );
+    }
 }
diff --git a/burn-import/src/burn/ty.rs b/burn-import/src/burn/ty.rs
index 9dedb4dcf..acbb02463 100644
--- a/burn-import/src/burn/ty.rs
+++ b/burn-import/src/burn/ty.rs
@@ -10,64 +10,114 @@ pub struct TensorType {
     pub name: Ident,
     pub dim: usize,
     pub kind: TensorKind,
+    pub shape: Option<Vec<usize>>,
 }

-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Copy)]
 pub enum TensorKind {
     Int,
     Float,
     Bool,
 }

+#[derive(Debug, Clone)]
+pub enum ScalarKind {
+    Int32,
+    Int64,
+    Float32,
+    Float64,
+    Bool,
+}
+
+#[derive(Debug, Clone)]
+pub struct ScalarType {
+    pub name: Ident,
+    pub kind: ScalarKind,
+}
+
 #[derive(Debug, Clone)]
 pub struct OtherType {
     pub name: Ident,
     pub ty: TokenStream,
 }

-pub enum Type<'a> {
-    Tensor(&'a TensorType),
-    Other(&'a OtherType),
+#[derive(Debug, Clone)]
+pub enum Type {
+    /// Tensor type.
+    Tensor(TensorType),
+
+    /// Scalar type.
+    Scalar(ScalarType),
+
+    /// Other type (more flexible type).
+    Other(OtherType),
 }

-impl<'a> Type<'a> {
+impl Type {
     pub fn name(&self) -> &Ident {
         match self {
             Type::Tensor(tensor) => &tensor.name,
+            Type::Scalar(scalar) => &scalar.name,
             Type::Other(other) => &other.name,
         }
     }
     pub fn ty(&self) -> TokenStream {
         match self {
             Type::Tensor(tensor) => tensor.ty(),
+            Type::Scalar(scalar) => scalar.ty(),
             Type::Other(other) => other.ty(),
         }
     }
 }

+impl ScalarType {
+    pub fn new<S: AsRef<str>>(name: S, kind: ScalarKind) -> Self {
+        Self {
+            name: Ident::new(name.as_ref(), Span::call_site()),
+            kind,
+        }
+    }
+    pub fn ty(&self) -> TokenStream {
+        match self.kind {
+            ScalarKind::Int32 => quote! { i32 },
+            ScalarKind::Int64 => quote! { i64 },
+            ScalarKind::Float32 => quote! { f32 },
+            ScalarKind::Float64 => quote! { f64 },
+            ScalarKind::Bool => quote! { bool },
+        }
+    }
+}
+
 impl TensorType {
-    pub fn new<S: AsRef<str>>(name: S, dim: usize, kind: TensorKind) -> Self {
+    pub fn new<S: AsRef<str>>(
+        name: S,
+        dim: usize,
+        kind: TensorKind,
+        shape: Option<Vec<usize>>,
+    ) -> Self {
         Self {
             name: Ident::new(name.as_ref(), Span::call_site()),
             dim,
             kind,
+            shape,
         }
     }

     pub fn new_float<S: AsRef<str>>(name: S, dim: usize) -> Self {
-        Self::new(name, dim, TensorKind::Float)
+        Self::new(name, dim, TensorKind::Float, None)
     }

     pub fn new_int<S: AsRef<str>>(name: S, dim: usize) -> Self {
-        Self::new(name, dim, TensorKind::Int)
+        Self::new(name, dim, TensorKind::Int, None)
     }

     pub fn new_bool<S: AsRef<str>>(name: S, dim: usize) -> Self {
-        Self::new(name, dim, TensorKind::Bool)
+        Self::new(name, dim, TensorKind::Bool, None)
     }

     pub fn ty(&self) -> TokenStream {
         let dim = self.dim.to_tokens();
+        // TODO use passed elem kind and do not assume float (@antimora 8/1/2023)
         quote! { Tensor<B, #dim> }
diff --git a/burn-import/src/onnx/dim_inference.rs b/burn-import/src/onnx/dim_inference.rs
index f53e61214..fb4c05c8f 100644
--- a/burn-import/src/onnx/dim_inference.rs
+++ b/burn-import/src/onnx/dim_inference.rs
@@ -1,8 +1,11 @@
 use std::collections::HashMap;

+use protobuf::Enum;
+
 use super::{
-    ir::{ArgType, Argument, AttributeValue, Node, NodeType, TensorArg},
+    ir::{ArgType, Argument, AttributeValue, ElementType, Node, NodeType, TensorArg},
     op_configuration::flatten_config,
+    protos::tensor_proto::DataType,
 };

 struct TensorDimUpdater {
@@ -70,15 +73,13 @@ pub fn dim_inference(
         NodeType::Sub => same_as_input(node),
         NodeType::Pow => same_as_input(node),
         NodeType::Mul => same_as_input(node),
-        NodeType::Cast => same_as_input(node),
+        NodeType::Cast => cast_update_outputs(node),
         NodeType::Div => same_as_input(node),
         NodeType::Sqrt => same_as_input(node),
         NodeType::Softmax => same_as_input(node),
         NodeType::Erf => same_as_input(node),
         NodeType::ReduceMean => mean_update_outputs(node),
-        NodeType::Constant => {
-            node.outputs[0].ty = ArgType::Constant;
-        }
+        NodeType::Constant => constant_update_outputs(node),
         NodeType::Equal => same_as_input(node),
         NodeType::Shape => shape_update_outputs(node),
         NodeType::Unsqueeze => unsqueeze_update_outputs(node),
@@ -89,7 +90,9 @@ pub fn dim_inference(
         NodeType::Concat => concat_update_outputs(node),
         NodeType::Reshape => reshape_update_outputs(node),
         NodeType::Dropout => same_as_input(node),
-        NodeType::GlobalAveragePool => same_as_input(node), //FIXME use correct output
+
+        //FIXME use correct output for GAP (@antimora 8/1/2023)
+        NodeType::GlobalAveragePool => same_as_input(node),
         _ => todo!(
             "shape inference for {:?} is not implemented",
             node.node_type
@@ -102,6 +105,20 @@
     updater.update_arguments(graph_outputs);
 }

+fn constant_update_outputs(node: &mut Node) {
+    // Fix the tensor dimension of the output when the value is tensor
+    let output = &mut node.outputs[0];
+    match node.attrs.get("value") {
+        Some(value) => match &value {
+            AttributeValue::Tensor(tensor) => {
+                output.ty = ArgType::Tensor(TensorArg { dim: tensor.dim });
+            }
+            _ => {}
+        },
+        None => panic!("Constant node must have a value attribute"),
+    };
+}
+
 /// Infer the shape of the output tensor of a Linear node
 fn linear_update_outputs(node: &mut Node) {
     if node.inputs.len() != 1 {
@@ -119,6 +136,49 @@ fn linear_update_outputs(node: &mut Node) {
     }
 }

+/// Update the output type using "to" attribute
+fn cast_update_outputs(node: &mut Node) {
+    if node.inputs.len() != 1 {
+        panic!("Cast: multiple inputs are not supported");
+    }
+    let output = &mut node.outputs[0];
+
+    // Extract cast type and update the output tensor
+    let elem_type = match node.attrs.get("to") {
+        Some(value) => match &value {
+            AttributeValue::Int64(type_id) => match DataType::from_i32(*type_id as i32).unwrap() {
+                DataType::FLOAT => ElementType::Float32,
+                DataType::INT32 => ElementType::Int32,
+                DataType::INT64 => ElementType::Int64,
+                DataType::DOUBLE => ElementType::Float64,
+                _ => panic!("Cast: unsupported type"),
+            },
+            _ => panic!("'to' attribute must be an Int64"),
+        },
+        None => panic!("Cast node must have a 'to' attribute"),
+    };
+
+    match output.ty.clone() {
+        ArgType::Tensor(tensor) => {
+            if tensor.dim == 0 {
+                // treat 0-dim tensor as scalar
+                output.ty = ArgType::Scalar(elem_type);
+            } else {
+                todo!("Cast: update tensor type");
+                // TODO track the type of the tensor elements (@antimora 8/1/2023)
+                // output.ty = ArgType::Tensor(TensorArg {
+                //     dim: tensor.dim,
+                //     elem_type,
+                // });
+            }
+        }
+        ArgType::Scalar(_scalar) => {
+            output.ty = ArgType::Scalar(elem_type);
+        }
+        _ => panic!("Only tensor input is valid"),
+    }
+}
+
 fn concat_update_outputs(node: &mut Node) {
     let tensor = node
         .inputs
@@ -184,7 +244,7 @@ fn unsqueeze_update_outputs(node: &mut Node) {
     let dim = match node_input.clone().ty {
         ArgType::Tensor(tensor) => tensor.dim,
         ArgType::Shape(dim) => dim,
-        ArgType::Constant => panic!("Needs shape or tensor"),
+        ArgType::Scalar(_) => panic!("Needs shape or tensor"),
     };

     node.outputs[0].ty = ArgType::Tensor(TensorArg { dim: dim + 1 });
diff --git a/burn-import/src/onnx/from_onnx.rs b/burn-import/src/onnx/from_onnx.rs
index abbfae51a..811a00de7 100644
--- a/burn-import/src/onnx/from_onnx.rs
+++ b/burn-import/src/onnx/from_onnx.rs
@@ -36,6 +36,16 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
     let onnx_model: ModelProto =
         Message::parse_from_reader(&mut file).expect("Unable to parse ONNX file");

+    log::debug!("Number of nodes: {:?}", onnx_model.graph.node.len());
+    log::debug!("Number of inputs: {:?}", onnx_model.graph.input.len());
+
+    log::debug!(
+        "Number of initializers: {:?}",
+        onnx_model.graph.initializer.len()
+    );
+
+    log::debug!("Number of outputs: {:?}", onnx_model.graph.output.len());
+
     // Convert the nodes
     let mut nodes: Vec<Node> = vec![];
     for onnx_node in onnx_model.graph.node.iter() {
@@ -46,10 +56,7 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
     move_inputs_to_state(&mut nodes, &onnx_model.graph.initializer);

-    // Get the topological sort of the nodes and the top nodes
-    let (ts, top_nodes) = get_top_nodes(&nodes);
-
-    // Sort the nodes
-    top_sort_nodes(&mut nodes, ts);
+    // Sort the nodes in topological order
+    top_sort_nodes(&mut nodes);

     // Collect inputs, outputs and initializers
     let check_if_initializer: HashSet<String> = onnx_model
         .graph
         .initializer
         .iter()
         .map(|x| x.name.clone())
         .collect();
-    let mut inputs = collect_inputs(&onnx_model, &check_if_initializer, top_nodes);
+    let mut inputs = collect_inputs(&onnx_model, &check_if_initializer);
+
     let mut outputs = collect_outputs(&onnx_model, check_if_initializer);
     let states = collect_states(onnx_model);
@@ -90,10 +98,7 @@ fn collect_states(onnx_model: ModelProto) -> Vec<State> {
     for initializer in onnx_model.graph.initializer.iter() {
         let tensor_proto = initializer.clone();
         let name = tensor_proto.name.clone();
-
-        // FIXME data conversion for the tensor is incorrect
         let tensor: Tensor = tensor_proto.try_into().unwrap();
         let ty = StateType::Tensor(tensor);
         let arg = State { name, ty };
@@ -108,7 +113,6 @@ fn collect_outputs(
     onnx_model: &ModelProto,
     check_if_initializer: HashSet<String>,
 ) -> Vec<Argument> {
-    // TODO: filter out the outputs that are not used in the graph
     let outputs: Vec<Argument> = onnx_model
         .graph
         .output
         .iter()
         .map(|x| Argument::try_from(x.clone()).unwrap())
         .collect();
     outputs
 }

 fn collect_inputs(
     onnx_model: &ModelProto,
     check_if_initializer: &HashSet<String>,
-    top_nodes: HashSet<String>,
 ) -> Vec<Argument> {
+    // Get the unique inputs
     let inputs: Vec<Argument> = onnx_model
         .graph
         .input
         .iter()
         .filter(|x| !check_if_initializer.contains(x.name.as_str()))
-        .filter(|x| top_nodes.contains(&x.name))
         .map(|x| Argument::try_from(x.clone()).unwrap())
         .collect();

     inputs
 }

 /// Sort the nodes in topological order
-fn top_sort_nodes(nodes: &mut Vec<Node>, mut ts: TopologicalSort<Node>) {
+fn top_sort_nodes(nodes: &mut Vec<Node>) {
+    let mut ts = topsort(nodes);
     *nodes = vec![];
     while let Some(node) = ts.pop() {
         nodes.push(node);
     }
 }

-/// Get the top nodes in the graph
-fn get_top_nodes(nodes: &Vec<Node>) -> (TopologicalSort<Node>, HashSet<String>) {
-    // Get the names of the top nodes (first nodes in the graph to receive the input)
-    // Sometimes onnx will pass inputs to be used as weights and biases but they are not truly inputs
-    let ts = topsort(nodes);
-    let mut top_nodes: HashSet<String> = HashSet::new();
-
-    for node in ts.peek_all() {
-        for input in node.inputs.iter() {
-            top_nodes.insert(input.name.clone());
-        }
-    }
-    (ts, top_nodes)
-}
-
 fn to_string(bytes: Vec<u8>) -> String {
     from_utf8(bytes.as_slice()).unwrap().to_string()
 }
diff --git a/burn-import/src/onnx/ir.rs b/burn-import/src/onnx/ir.rs
index 69ef8f856..6902a5c55 100644
--- a/burn-import/src/onnx/ir.rs
+++ b/burn-import/src/onnx/ir.rs
@@ -14,9 +14,9 @@ pub struct Argument {
 #[derive(Debug, Clone)]
 pub enum ArgType {
-    Tensor(TensorArg),
+    Scalar(ElementType),
     Shape(usize),
-    Constant,
+    Tensor(TensorArg),
 }

 #[derive(new, Default, Debug, Clone)]
@@ -142,6 +142,15 @@ impl core::hash::Hash for Argument {
     }
 }

+impl Eq for Argument {}
+
+// Required by HashSet
+impl PartialEq for Argument {
+    fn eq(&self, other: &Self) -> bool {
+        self.name == other.name
+    }
+}
+
 /// The list of supported node types (ONNX operators and some extra ones to map easily to Burn's ops)
 #[derive(Debug, Hash, Eq, PartialEq, EnumString, Clone, Display)]
 pub enum NodeType {
diff --git a/burn-import/src/onnx/to_burn.rs b/burn-import/src/onnx/to_burn.rs
index 78cafeadc..953462c8c 100644
--- a/burn-import/src/onnx/to_burn.rs
+++ b/burn-import/src/onnx/to_burn.rs
@@ -16,7 +16,7 @@ use crate::{
             batch_norm::BatchNormNode,
             binary::BinaryNode,
             concat::ConcatNode,
-            constant::{ConstantNode, ConstantValue},
+            constant::{ConstantNode, ConstantValue, TensorValue},
             conv2d::Conv2dNode,
             linear::LinearNode,
             matmul::MatmulNode,
@@ -24,7 +24,7 @@ use crate::{
             reshape::ReshapeNode,
             unary::UnaryNode,
         },
-        TensorType,
+        ScalarKind, ScalarType, TensorKind, TensorType, Type,
     },
     format_tokens,
     logger::init_log,
@@ -39,7 +39,7 @@ use crate::{
 use super::{
     from_onnx::parse_onnx,
-    ir::{ArgType, Argument, ONNXGraph, State, StateType, Tensor, TensorData},
+    ir::{ArgType, Argument, ElementType, ONNXGraph, State, StateType, Tensor, TensorData},
     op_configuration::concat_config,
 };

@@ -98,6 +98,8 @@ impl ModelGen {
     fn run(&self, is_build_script: bool) {
         log::info!("Starting to convert ONNX to Burn");

         // prepend the out_dir to the cargo_out_dir if this is a build script
         let out_dir = if is_build_script {
             let cargo_out_dir = env::var("OUT_DIR").expect("OUT_DIR env is not set");
@@ -112,6 +114,8 @@ impl ModelGen {
         log::debug!("Output directory: {:?}", out_dir);

         create_dir_all(&out_dir).unwrap();

         for input in self.inputs.iter() {
@@ -122,10 +126,16 @@ impl ModelGen {
+            log::info!("Converting {:?}", input);
             log::debug!("Input file name: {:?}", file_name);
             log::debug!("Output file: {:?}", out_file);

             Self::generate_model(self.development, input, out_file);
         }

         log::info!("Finished converting ONNX to Burn");
     }

     /// Generate model source code and model state.
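For orientation, the run(is_build_script) path above is what a consumer crate drives from its build script. A minimal sketch, assuming a hypothetical crate layout (the file paths are illustrative; ModelGen, input, out_dir, and run_from_script are the burn-import API):

// build.rs of a crate that embeds an ONNX model.
use burn_import::onnx::ModelGen;

fn main() {
    // Regenerate the Rust model source whenever the ONNX file changes.
    println!("cargo:rerun-if-changed=src/model/mnist.onnx");

    // Writes the generated source under OUT_DIR, matching the
    // is_build_script = true branch above.
    ModelGen::new()
        .input("src/model/mnist.onnx")
        .out_dir("model/")
        .run_from_script();
}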
@@ -134,6 +144,10 @@ impl ModelGen {
+        log::info!("Generating model from {:?}", input);
         log::debug!("Development mode: {:?}", development);
         log::debug!("Output file: {:?}", out_file);

         let graph = parse_onnx(input.as_ref());

         if development {
@@ -161,6 +175,8 @@ impl ModelGen {
         fs::write(out_file.with_extension("rs"), code_str).unwrap();

         log::info!("Model generated");
     }
 }

@@ -186,64 +202,112 @@ impl ONNXGraph {
                 NodeType::Relu => graph.register(Self::relu_conversion(node)),
                 NodeType::Flatten => graph.register(Self::flatten_conversion(node)),
                 NodeType::LogSoftmax => graph.register(Self::log_softmax_conversion(node)),
-                NodeType::Constant => graph.register(Self::constant_conversion(node)),
+                NodeType::Constant => graph.register(Self::constant_conversion::<PS>(node)),
                 NodeType::Reshape => graph.register(Self::reshape_conversion(node)),
                 NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)),
                 NodeType::Transpose => graph.register(Self::transpose_conversion(node)),
                 NodeType::Concat => graph.register(Self::concat_conversion(node)),
+                NodeType::Cast => graph.register(Self::cast_conversion(node)),
                 _ => panic!("Unsupported node conversion {}", node.node_type),
             }
         }

+        // Get input and output names
+        let input_names = self
+            .inputs
+            .iter()
+            .map(|input| input.name.clone())
+            .collect::<Vec<String>>();
+        let output_names = self
+            .outputs
+            .iter()
+            .map(|output| output.name.clone())
+            .collect::<Vec<String>>();
+
+        // Register inputs and outputs with the graph
+        graph.register_input_output(input_names, output_names);
+
         graph
     }

-    fn constant_conversion(mut node: Node) -> ConstantNode {
+    fn constant_conversion<PS: PrecisionSettings>(mut node: Node) -> ConstantNode<PS> {
         let output = node.outputs.get(0).unwrap();
+
         let value = node.attrs.remove("value").unwrap();

         let value = match value {
-            AttributeValue::Float32(val) => ConstantValue::Float(val),
-            AttributeValue::Int64(val) => ConstantValue::Int(val as i32),
-            AttributeValue::Float32s(val) => ConstantValue::Float(val[0]),
-            AttributeValue::Int64s(val) => ConstantValue::Int(val[0] as i32),
-            _ => panic!("Unsupported constant node: {:?}", node),
+            AttributeValue::Float32(val) => ConstantValue::Float32(val),
+            AttributeValue::Int64(val) => ConstantValue::Int64(val),
+            AttributeValue::Tensor(tensor) => {
+                if tensor.dim == 0 {
+                    // Treat zero dim tensor as scalar value by extracting the first element
+                    // because PyTorch/ONNX uses zero dim tensor for scalar values
+                    match tensor.data.unwrap() {
+                        TensorData::Float32(val) => ConstantValue::Float32(val[0]),
+                        TensorData::Float64(val) => ConstantValue::Float64(val[0]),
+                        TensorData::Int32(val) => ConstantValue::Int32(val[0]),
+                        TensorData::Int64(val) => ConstantValue::Int64(val[0]),
+                        _ => panic!(
+                            "Unsupported zero dim constant tensor type: {:?}",
+                            tensor.elem_type
+                        ),
+                    }
+                } else {
+                    let ds = match tensor.elem_type {
+                        ElementType::Float32 | ElementType::Float64 => TensorValue::Float(
+                            tensor.clone().into_data_serialize::<PS::FloatElem>(),
+                        ),
+                        ElementType::Int32 | ElementType::Int64 => {
+                            TensorValue::Int(tensor.clone().into_data_serialize::<PS::IntElem>())
+                        }
+                        _ => panic!("Unsupported constant tensor type: {:?}", tensor.elem_type),
+                    };
+
+                    ConstantValue::<PS>::Tensor(
+                        TensorType::new(
+                            node.name.clone(),
+                            tensor.dim,
+                            tensor.elem_type.into(),
+                            tensor.shape,
+                        ),
+                        ds,
+                    )
+                }
+            }
+            _ => panic!("Unsupported constant value: {:?}", value),
        };
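        // To make the mapping above concrete (illustrative values): a zero-dim
        // float32 initializer holding [3.14] becomes ConstantValue::Float32(3.14)
        // and is emitted as a plain f32 in the generated source, while a
        // non-zero-dim int64 tensor is captured as TensorValue::Int data plus
        // its TensorType and recorded as a constant tensor.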
        ConstantNode::new(node.name.clone(), value, output.to_type())
     }

     fn add_conversion(node: Node) -> BinaryNode {
-        // FIXME scalar vs tensor
-        let lhs = node.inputs.get(0).unwrap().to_tensor_type();
-        let rhs = node.inputs.get(1).unwrap().to_tensor_type();
-        let output = node.outputs.get(0).unwrap().to_tensor_type();
+        let lhs = node.inputs.get(0).unwrap().to_type();
+        let rhs = node.inputs.get(1).unwrap().to_type();
+        let output = node.outputs.get(0).unwrap().to_type();

         BinaryNode::add(lhs, rhs, output)
     }

     fn sub_conversion(node: Node) -> BinaryNode {
-        // FIXME scalar vs tensor
-        let lhs = node.inputs.get(0).unwrap().to_tensor_type();
-        let rhs = node.inputs.get(1).unwrap().to_tensor_type();
-        let output = node.outputs.get(0).unwrap().to_tensor_type();
+        let lhs = node.inputs.get(0).unwrap().to_type();
+        let rhs = node.inputs.get(1).unwrap().to_type();
+        let output = node.outputs.get(0).unwrap().to_type();

         BinaryNode::sub(lhs, rhs, output)
     }

     fn mul_conversion(node: Node) -> BinaryNode {
-        // FIXME scalar vs tensor
-        let lhs = node.inputs.get(0).unwrap().to_tensor_type();
-        let rhs = node.inputs.get(1).unwrap().to_tensor_type();
-        let output = node.outputs.get(0).unwrap().to_tensor_type();
+        let lhs = node.inputs.get(0).unwrap().to_type();
+        let rhs = node.inputs.get(1).unwrap().to_type();
+        let output = node.outputs.get(0).unwrap().to_type();

         BinaryNode::mul(lhs, rhs, output)
     }

     fn div_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.get(0).unwrap().to_tensor_type();
-        let rhs = node.inputs.get(1).unwrap().to_tensor_type();
-        let output = node.outputs.get(0).unwrap().to_tensor_type();
+        let lhs = node.inputs.get(0).unwrap().to_type();
+        let rhs = node.inputs.get(1).unwrap().to_type();
+        let output = node.outputs.get(0).unwrap().to_type();

         BinaryNode::div(lhs, rhs, output)
     }
@@ -257,35 +321,42 @@ impl ONNXGraph {
     }

     fn equal_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.get(0).unwrap().to_tensor_type();
-        let rhs = node.inputs.get(1).unwrap().to_tensor_type();
-        let output = node.outputs.get(0).unwrap().to_tensor_type();
+        let lhs = node.inputs.get(0).unwrap().to_type();
+        let rhs = node.inputs.get(1).unwrap().to_type();
+        let output = node.outputs.get(0).unwrap().to_type();

         BinaryNode::equal(lhs, rhs, output)
     }

     fn relu_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.get(0).unwrap().to_tensor_type();
-        let output = node.outputs.get(0).unwrap().to_tensor_type();
+        let input = node.inputs.get(0).unwrap().to_type();
+        let output = node.outputs.get(0).unwrap().to_type();

         UnaryNode::relu(input, output)
     }

     fn flatten_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.get(0).unwrap().to_tensor_type();
-        let output = node.outputs.get(0).unwrap().to_tensor_type();
+        let input = node.inputs.get(0).unwrap().to_type();
+        let output = node.outputs.get(0).unwrap().to_type();
         let (start_dim, end_dim) = flatten_config(&node);

         UnaryNode::flatten(input, output, start_dim, end_dim)
     }

     fn transpose_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.get(0).unwrap().to_tensor_type();
-        let output = node.outputs.get(0).unwrap().to_tensor_type();
+        let input = node.inputs.get(0).unwrap().to_type();
+        let output = node.outputs.get(0).unwrap().to_type();

         UnaryNode::transpose(input, output)
     }

+    fn cast_conversion(node: Node) -> UnaryNode {
+        let input = node.inputs.get(0).unwrap().to_type();
+        let output = node.outputs.get(0).unwrap().to_type();
+
+        UnaryNode::cast(input, output)
+    }
+
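    To make the Cast lowering concrete: for a scalar-to-scalar Cast, UnaryNode::cast above emits a plain `as` conversion. This mirrors the test_unary_codegen_cast expectation earlier in this patch (names are from that test):

    // Generated forward for Cast(float64 -> float32) on a scalar input.
    pub fn forward(&self, scalar1: f64) -> f32 {
        let scalar2 = scalar1 as f32;
        scalar2
    }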
     fn reshape_conversion(mut node: Node) -> ReshapeNode {
         let input = node.inputs.get(0).unwrap().to_tensor_type();
         let output = node.outputs.get(0).unwrap().to_tensor_type();
@@ -299,15 +370,15 @@ impl ONNXGraph {
     }

     fn sigmoid_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.get(0).unwrap().to_tensor_type();
-        let output = node.outputs.get(0).unwrap().to_tensor_type();
+        let input = node.inputs.get(0).unwrap().to_type();
+        let output = node.outputs.get(0).unwrap().to_type();

         UnaryNode::sigmoid(input, output)
     }

     fn log_softmax_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.get(0).unwrap().to_tensor_type();
-        let output = node.outputs.get(0).unwrap().to_tensor_type();
+        let input = node.inputs.get(0).unwrap().to_type();
+        let output = node.outputs.get(0).unwrap().to_type();
         let dim = log_softmax_config(&node);

         UnaryNode::log_softmax(input, output, dim)
@@ -419,8 +490,54 @@ impl Argument {
     pub fn to_tensor_type(&self) -> TensorType {
         match &self.ty {
             ArgType::Tensor(tensor) => TensorType::new_float(self.name.clone(), tensor.dim),
+            _ => panic!("Can't transform to tensor."),
+        }
+    }
+
+    pub fn to_type(&self) -> Type {
+        match &self.ty {
+            ArgType::Tensor(tensor) => {
+                // Treat tensor with dim 0 as scalar
+                if tensor.dim == 0 {
+                    // FIXME Convert to correct scalar type (@antimora 8/1/2023)
+                    // Currently it's not dangerous because we don't use specific scalar type
+                    Type::Scalar(ScalarType::new(self.name.clone(), ScalarKind::Float64))
+                } else {
+                    Type::Tensor(TensorType::new_float(self.name.clone(), tensor.dim))
+                }
+            }
+
+            ArgType::Scalar(elem_type) => {
+                Type::Scalar(ScalarType::new(self.name.clone(), elem_type.into()))
+            }
             ArgType::Shape(_shape) => panic!("Can't transform shape to type."),
-            ArgType::Constant => panic!("Can't transform constant to tensor."),
+        }
+    }
+}
+
+impl From<&ElementType> for ScalarKind {
+    fn from(elem_type: &ElementType) -> Self {
+        match elem_type {
+            ElementType::Float32 => ScalarKind::Float32,
+            ElementType::Float64 => ScalarKind::Float64,
+            ElementType::Int32 => ScalarKind::Int32,
+            ElementType::Int64 => ScalarKind::Int64,
+            ElementType::Bool => ScalarKind::Bool,
+            ElementType::String => panic!("String tensor unsupported"),
+            ElementType::Float16 => panic!("Float16 tensor unsupported"),
+        }
+    }
+}
+
+impl From<ElementType> for TensorKind {
+    fn from(elem_type: ElementType) -> Self {
+        match elem_type {
+            ElementType::Float32 => TensorKind::Float,
+            ElementType::Float64 => TensorKind::Float,
+            ElementType::Int32 => TensorKind::Int,
+            ElementType::Int64 => TensorKind::Int,
+            ElementType::Bool => TensorKind::Bool,
+            _ => panic!("Unsupported tensor type"),
+        }
+    }
+}
diff --git a/burn-import/tests/data/concat/concat.rs b/burn-import/tests/data/concat/concat.rs
deleted file mode 100644
index 87e52e6db..000000000
--- a/burn-import/tests/data/concat/concat.rs
+++ /dev/null
@@ -1,31 +0,0 @@
-// Generated by integration tests
-use burn::{
-    module::Module,
-    tensor::{backend::Backend, Tensor},
-};
-
-
-#[derive(Module, Debug)]
-pub struct Model {}
-
-impl Model {
-    pub fn new_with(record: ModelRecord) -> Self {
-        Self {}
-    }
-
-    #[allow(clippy::let_and_return)]
-    pub fn forward(&self, input1: Tensor, input1: Tensor) -> Tensor {
-        let concat1_out1 = burn::tensor::Tensor::cat(vec![input1.clone(), input1.clone()], 1);
-        let concat2_out1 = burn::tensor::Tensor::cat(
-            vec![
-                input1.clone(),
-                concat1_out1.clone(),
-                concat1_out1.clone(),
-                concat1_out1.clone(),
-                concat1_out1,
-            ],
-            1,
-        );
-        concat2_out1
-    }
-}
diff --git a/burn-import/tests/data/conv2d/conv2d.onnx b/burn-import/tests/data/conv2d/conv2d.onnx
deleted file mode 100644
index a67c34af3ed9f8bf95a4cf76365c8f5822a31744..0000000000000000000000000000000000000000
GIT binary patch
(17731 bytes of base85-encoded ONNX binary data omitted)
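Each of these .onnx fixtures is regenerated from a small PyTorch script like the one deleted below. To illustrate what the new scalar support produces for such graphs, here is a hand-written sketch of the forward pass burn-import would now emit for a graph that divides a tensor by a scalar (the names, shapes, and the div_scalar lowering are illustrative assumptions, not taken verbatim from this patch):

// Sketch of a generated forward once scalar operands are supported.
// `input1` is a regular tensor input; `scalar1` arrives as a plain f64.
pub fn forward(&self, input1: Tensor<B, 4>, scalar1: f64) -> Tensor<B, 4> {
    // A Div node with a tensor lhs and a scalar rhs lowers to a scalar op.
    let div1_out1 = input1.div_scalar(scalar1);
    div1_out1
}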
zQXK$CdNT70K27KN#T{y_BU$RFNy9A+OG z;}P~m$R35I^8T;b@5AX(TlP?FqKAbV8$_0(n!5UeaQs0xI`gb#4sN52_heM4vT^TR zJ7o;4LG!eW9C;ooyV*dc+YZ6EbqhAeUgiE#7Vyn>^bEGcsnBENSqpt^+)I~&b5WEZ zk67C*%V{m9!MPe*SKP{aMnA`RS{*No&|qroK9VwX(NUJg56WcNak_!j9~8TfzrP12 za}}ONwZM4VOoBV%gC7OrX~IG>XzxlT&kM+>c)(Pp6RdE#nT13zfKrNN>ABl*PvnSW zCyo5-yk0sttwsIJH)MUj?jaN3o*!LsP$ybms;u?SHJA-$ox}Y=tE4YUg!S}5sd1fv9+F(L_o2cgy#T;)n zhUn}3DxzxHa7iQN=5}nDJc%juIyh<;b~iI5{`y_mq8jJ2daVWs@aU2Icj z50avHbH`&jZtw3UZ-&I1f-D%1+)N#TrEJXAf_bx9Ncuv@bU&C}j^E{H diff --git a/burn-import/tests/data/conv2d/conv2d.py b/burn-import/tests/data/conv2d/conv2d.py deleted file mode 100644 index 75f9ba0ee..000000000 --- a/burn-import/tests/data/conv2d/conv2d.py +++ /dev/null @@ -1,36 +0,0 @@ -# used to generate model: burn-import/tests/data/conv2d/conv2d.onnx - -import torch -import torch.nn as nn -import onnx -from onnxoptimizer import optimize - -class Model(nn.Module): - def __init__(self): - super(Model, self).__init__() - self.conv1 = nn.Conv2d(16, 36, (3, 5), groups = 2, stride=(2, 1), padding=(4, 2), dilation=(3, 1)) - - def forward(self, x): - x = self.conv1(x) - return x - -def main(): - - # Export to onnx - model = Model() - model.eval() - device = torch.device("cpu") - dummy_input = torch.randn(20, 16, 50, 100, device=device) - torch.onnx.export(model, dummy_input, "conv2d.onnx", - verbose=False, opset_version=16) - - # Apply the optimization pass to simplify the model - onnx_model = onnx.load("conv2d.onnx") - optimized_model = optimize(onnx_model) - - # Save the optimized model - onnx.save(optimized_model, "conv2d.onnx") - - -if __name__ == '__main__': - main() diff --git a/burn-import/tests/data/model1/README.md b/burn-import/tests/data/model1/README.md deleted file mode 100644 index 7df98221b..000000000 --- a/burn-import/tests/data/model1/README.md +++ /dev/null @@ -1,17 +0,0 @@ -# Model1 test data files - -This directory contains the test data for the model1 test. The test data is generated by running the -following command: - -```bash -python3 model1.py -cargo run model1.onnx ./ -``` - -The following files are generated: - -- `model1.onnx`: The ONNX model -- `model1.rs`: The generated Rust code for the model (the path in the comment needs to be fixed for - the test) -- `model1.json`: The data of the model -- `model1.graph.txt`: The IR of the model diff --git a/burn-import/tests/data/model1/model1.onnx b/burn-import/tests/data/model1/model1.onnx deleted file mode 100644 index b2e4916c9a363ff4ee00b59c6ac5c0941cbb28e1..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 13253 zcmb_@d0dWNxA5Iuw^WiOLj#gZlQi7dT1lvo22u$PXw)Q`r%)uB%akEgPv#_~YpsMx zhL8+}Od)efrttMV?|a_oyx%$B`JF$$>;7H8d+)v0y7sX4UV9CDtEy_NB+gDwNSYL- zuVia!ZD}LhA1U20Csm1#Pn@1?X(Lsiln_6|#&TwObYxVrRQiwH#OSbT+B#N$Wvm9b zA%rAMPj15xvDR<*pGu7t|Ad)oOVuVv$A%?GC&W*4k#mvDl~XZOmnuajB}`A$Ys<-- zb&{%22~Ua-j}4g?6_yz8qTr(NCsGHgVq)0jX)cm~{QfM`lB!HgPKur!{vThG_Lu&* z4bj%H8WkQp{qKe;{gpRX{IgVABPad8;8OMYgrvB?JNl2@pB{sLLK zh9ysm@@k_Gi;bT1S4aP@x6A(%$5{Kn(zW$jB|LFjbZkO=r-SGFIONL7wT+)@TtZxU zeDd_TP7x6i1OF`3k+x4x3X6`9j*kq9OPCz4Cuti;X@K;9(IV{*R&H(FNDhzxyOFB@ zQDm(0PncOdsbbiy=xKj3NV-7!Ux}pBh)Mq%FV#QOe@zQ*tG1+-dw5*j-=A0eSDCTm zpAa)msZv;MVpLcsi6rH({%A-QCx$2goo_EyZtGwCG}nLBe^*ZW-+cEUJhT1R!OE<1*)! 
diff --git a/burn-import/tests/data/model1/model1.rs b/burn-import/tests/data/model1/model1.rs
deleted file mode 100644
--- a/burn-import/tests/data/model1/model1.rs
+++ /dev/null
-#[derive(Module, Debug)]
-pub struct Model {
-    conv2d1: Conv2d,
-    batchnormalization1: BatchNorm,
-    linear1: Linear,
-    batchnormalization2: BatchNorm,
-}
-
-impl Model {
-    pub fn new_with(record: ModelRecord) -> Self {
-        let conv2d1 = Conv2dConfig::new([1, 8], [3, 3])
-            .with_stride([1, 1])
-            .with_padding(PaddingConfig2d::Valid)
-            .with_dilation([1, 1])
-            .with_groups(1)
-            .with_bias(true)
-            .init_with(record.conv2d1);
-        let batchnormalization1 = BatchNormConfig::new(8)
-            .with_epsilon(0.000009999999747378752f64)
-            .with_momentum(0.8999999761581421f64)
-            .init_with(record.batchnormalization1);
-        let linear1 = LinearConfig::new(288, 10)
-            .with_bias(true)
-            .init_with(record.linear1);
-        let batchnormalization2 = BatchNormConfig::new(10)
-            .with_epsilon(0.000009999999747378752f64)
-            .with_momentum(0.8999999761581421f64)
-            .init_with(record.batchnormalization2);
-        Self {
-            conv2d1,
-            batchnormalization1,
-            linear1,
-            batchnormalization2,
-        }
-    }
-
-    #[allow(clippy::let_and_return)]
-    pub fn forward(&self, input1: Tensor) -> Tensor {
-        let conv2d1_out1 = self.conv2d1.forward(input1);
-        let relu1_out1 = burn::tensor::activation::relu(conv2d1_out1);
-        let batchnormalization1_out1 = self.batchnormalization1.forward(relu1_out1);
-        let flatten1_out1 = batchnormalization1_out1.flatten(1, 3);
-        let linear1_out1 = self.linear1.forward(flatten1_out1);
-        let batchnormalization2_out1 = self.batchnormalization2.forward(linear1_out1);
-        let logsoftmax1_out1 = burn::tensor::activation::log_softmax(batchnormalization2_out1, 1);
-        logsoftmax1_out1
-    }
-}
diff --git a/burn-import/tests/onnx_tests.rs b/burn-import/tests/onnx_tests.rs
deleted file mode 100644
index 96e97d8b4..000000000
--- a/burn-import/tests/onnx_tests.rs
+++ /dev/null
@@ -1,42 +0,0 @@
-#[cfg(test)]
-#[cfg(feature = "onnx")]
-mod tests {
-    use std::fs::read_to_string;
-    use std::path::Path;
-
-    use burn::record::FullPrecisionSettings;
-    use pretty_assertions::assert_eq;
-    use rstest::*;
-
-    fn code<P: AsRef<Path>>(onnx_path: P) -> String {
-        let graph = burn_import::onnx::parse_onnx(onnx_path.as_ref());
-        let graph = graph
-            .into_burn::<FullPrecisionSettings>()
-            .with_blank_space(true)
-            .with_top_comment(Some("Generated by integration tests".into()));
-
-        burn_import::format_tokens(graph.codegen())
-    }
-
-    #[rstest]
-    #[case::mixed("model1")]
-    #[case::conv2d("conv2d")]
-    #[case::concat("concat")]
-    // #[case::description_here("model2")] <- Add more models here
-    fn test_codegen(#[case] model_name: &str) {
-        let input_file = format!("tests/data/{model_name}/{model_name}.onnx");
-        let source_file = format!("tests/data/{model_name}/{model_name}.rs");
-        let source_expected: String =
-            read_to_string(source_file).expect("Expected source file is missing");
-
-        let generated_code = code(input_file);
-
-        // Uncomment this to update the expected code
-        // println!("Generated code:\n{}", generated_code);
-
-        assert_eq!(
-            source_expected, generated_code,
-            "Expected code is left, actual code is right"
-        );
-    }
-}
diff --git a/examples/onnx-inference/src/bin/mnist.rs b/examples/onnx-inference/src/bin/mnist_inference.rs
similarity index 100%
rename from examples/onnx-inference/src/bin/mnist.rs
rename to examples/onnx-inference/src/bin/mnist_inference.rs
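The deleted source-comparison tests above are superseded by runtime tests in the new burn-import/onnx-tests crate, which execute the generated models and compare outputs. A representative sketch of that pattern, assuming a model named add generated by the build script shown earlier (backend choice, constructor, and assertions are illustrative; float outputs can be compared with tolerance via the newly added float-cmp crate):

// Pull in the code generated by build.rs via ModelGen.
pub mod add {
    include!(concat!(env!("OUT_DIR"), "/model/add.rs"));
}

#[test]
fn add_scalar_to_tensor_and_tensor_to_tensor() {
    type Backend = burn_ndarray::NdArrayBackend<f32>;

    // Initialize the model with the weights stored in the ONNX file.
    let model: add::Model<Backend> = add::Model::default();

    // Run the generated forward pass with a tensor and a scalar input.
    let input = burn::tensor::Tensor::<Backend, 4>::ones([1, 2, 3, 4]);
    let output = model.forward(input, 2f64);

    // A real test would compare output.to_data() against PyTorch's output.
    assert_eq!(output.shape().dims, [1, 2, 3, 4]);
}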