diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1d5d787ca..dd30a4a71 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -79,3 +79,8 @@ jobs: with: crate: burn-train + test-burn-import: + uses: burn-rs/burn/.github/workflows/test-template.yml@main + with: + crate: burn-import + diff --git a/Cargo.toml b/Cargo.toml index c1b2e7e10..1b8f7cf8d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,24 +5,25 @@ resolver = "2" members = [ "burn", - "burn-core", - "burn-train", - "burn-derive", - "burn-tensor", - "burn-tensor-testgen", - "burn-dataset", - "burn-tch", - "burn-ndarray", "burn-autodiff", "burn-common", + "burn-core", + "burn-dataset", + "burn-derive", + "burn-import", + "burn-ndarray", "burn-no-std-tests", + "burn-tch", + "burn-tensor-testgen", + "burn-tensor", + "burn-train", "examples/*", ] [workspace.dependencies] const-random = "0.1.15" dashmap = "5.4.0" -dirs = "4.0.0" +dirs = "5.0.0" fake = "2.5.0" flate2 = "1.0.25" half = {version = "2", features = ["alloc", "num-traits"], default-features = false} @@ -33,6 +34,11 @@ log = "0.4.17" log4rs = "1.2.0" spin = {version = "0.9.5", features = ["mutex", "spin_mutex"]} thiserror = "1.0.39" +proc-macro2 = "1.0.54" +quote = "1.0.26" +syn = "2.0" +strum = "0.24" +strum_macros = "0.24" # # The following packages disable the "std" feature for no_std compatibility @@ -44,7 +50,7 @@ rand = {version = "0.8.5", default-features = false, features = ["std_rng"]}# st rand_distr = {version = "0.4.3", default-features = false} uuid = {version = "1.3.0", default-features = false} +bincode = {version = "2.0.0-rc", features = ["alloc", "serde"], default-features = false} +rmp-serde = {version = "1.1.1"} serde = {version = "1.0.155", default-features = false, features = ["derive", "alloc"]}# alloc is for no_std, derive is needed serde_json = {version = "1.0.94", default-features = false} -rmp-serde = {version = "1.1.1"} -bincode = {version = "2.0.0-rc", features=["alloc", "serde"], default-features = false} diff --git a/NOTICES.md b/NOTICES.md new file mode 100644 index 000000000..f0f65e020 --- /dev/null +++ b/NOTICES.md @@ -0,0 +1,248 @@ +# NOTICES AND INFORMATION + +This file contains notices and information required by libraries that this +repository copied or derived from. + +## PyTorch MNIST Example + +**Source**: https://github.com/pytorch/examples/blob/main/mnist/main.py + +License: BSD 3-Clause License + +Copyright (c) 2017, +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +## ONNX + +**Source**: https://github.com/onnx/onnx/blob/main/onnx/onnx.proto3 + +License: Apache License 2.0 + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + + diff --git a/burn-derive/Cargo.toml b/burn-derive/Cargo.toml index 158218193..581d8bd09 100644 --- a/burn-derive/Cargo.toml +++ b/burn-derive/Cargo.toml @@ -16,6 +16,7 @@ version = "0.6.0" proc-macro = true [dependencies] -proc-macro2 = "1.0.52" -quote = "1.0.26" -syn = "1.0.109" +proc-macro2 = {workspace = true} +quote = {workspace = true} +# syn = {workspace = true} +syn = "1.0.76" # TODO upgrade to 2.0 \ No newline at end of file diff --git a/burn-import/Cargo.toml b/burn-import/Cargo.toml new file mode 100644 index 000000000..6b05176b0 --- /dev/null +++ b/burn-import/Cargo.toml @@ -0,0 +1,36 @@ +[package] +authors = [ + "Dilshod Tadjibaev (@antimora)", +] +edition = "2021" +license = "MIT/Apache-2.0" +name = "burn-import" +readme = "README.md" +repository = "https://github.com/burn-rs/burn/tree/main/burn-import" + +version = "0.6.0" + +[features] +default = ["onnx"] +onnx = [] + +[dependencies] +burn = {path = "../burn", version = "0.6.0"} +burn-ndarray = {path = "../burn-ndarray", version = "0.6.0"} + +half = {workspace = true} +proc-macro2 = {workspace = true} +protobuf = {version = "3.2", features = ["with-bytes"]} +quote = {workspace = true} +rust-format = {version = "0.3", features = ["token_stream", "post_process"]} +serde = {workspace = true} +strum = {workspace = true} +strum_macros = {workspace = true} +syn = {workspace = true, features = ["parsing"]} +topological-sort = {version = "0.2.2"} + +[build-dependencies] +protobuf-codegen = {version = "3.2"} + +[dev-dependencies] +rstest = "0.17.0" diff --git a/burn-import/README.md b/burn-import/README.md new file mode 100644 index 000000000..eba672c1d --- /dev/null +++ b/burn-import/README.md @@ -0,0 +1,65 @@ +# Burn Import + +`burn-import` is a crate designed to facilitate importing models trained in other machine learning +frameworks into the Burn framework. This tool generates a Rust source file that aligns the source +model with Burn's model and converts tensor data into a format compatible with Burn. + +Currently under development, `burn-import` supports importing ONNX models with a limited set of +operators. 
+
+## Supported ONNX Operators
+
+- Conv2d
+- Gemm (Linear layer)
+- Flatten
+- LogSoftmax
+
+## Usage
+
+### Importing ONNX models
+
+In `build.rs`, add the following:
+
+```rust
+use burn_import::onnx::ModelCodeGen;
+
+fn main() {
+    ModelCodeGen::new()
+        .input("src/model/mnist.onnx") // Path to the ONNX model
+        .out_dir("model/") // Directory for the generated Rust source file (under target/)
+        .run_from_script();
+}
+```
+
+Then, add the following to `mod.rs` under `src/model`:
+
+```rust
+pub mod mnist {
+    include!(concat!(env!("OUT_DIR"), "/model/mnist.rs"));
+}
+```
+
+Finally, in your code, you can use the imported model as follows:
+
+```rust
+use burn::tensor;
+use burn_ndarray::NdArrayBackend;
+use onnx_inference::model::mnist::{Model, INPUT1_SHAPE};
+
+fn main() {
+    // Create a new model
+    let model: Model<NdArrayBackend<f32>> = Model::new();
+
+    // Create a new input tensor (all zeros for demonstration purposes)
+    let input = tensor::Tensor::<NdArrayBackend<f32>, 4>::zeros(INPUT1_SHAPE);
+
+    // Run the model
+    let output = model.forward(input);
+
+    // Print the output
+    println!("{:?}", output);
+}
+```
+
+You can view the working example in the `examples/onnx-inference` directory.
diff --git a/burn-import/build.rs b/burn-import/build.rs
new file mode 100644
index 000000000..f6cd68165
--- /dev/null
+++ b/burn-import/build.rs
@@ -0,0 +1,11 @@
+fn main() {
+    if cfg!(feature = "onnx") {
+        // Generate the onnx protobuf files
+        protobuf_codegen::Codegen::new()
+            .pure()
+            .includes(["src"])
+            .input("src/onnx/protos/onnx.proto")
+            .cargo_out_dir("onnx-protos")
+            .run_from_script();
+    }
+}
diff --git a/burn-import/src/lib.rs b/burn-import/src/lib.rs
new file mode 100644
index 000000000..41d03d931
--- /dev/null
+++ b/burn-import/src/lib.rs
@@ -0,0 +1,6 @@
+#![allow(clippy::ptr_arg)]
+#![allow(clippy::single_match)]
+#![allow(clippy::upper_case_acronyms)]
+
+#[cfg(feature = "onnx")]
+pub mod onnx;
diff --git a/burn-import/src/onnx/coalesce.rs b/burn-import/src/onnx/coalesce.rs
new file mode 100644
index 000000000..1ea80d7b0
--- /dev/null
+++ b/burn-import/src/onnx/coalesce.rs
@@ -0,0 +1,47 @@
+use super::ir::{AttributeValue, Node, NodeType};
+
+/// Transforms the graph so that groups of nodes are coalesced into a single node.
+pub fn coalesce(nodes: &mut Vec<Node>) {
+    for node in nodes.iter_mut() {
+        match node.node_type {
+            NodeType::Gemm => convert_gemm(node),
+            _ => {}
+        }
+    }
+}
+
+/// This function converts a Gemm node into a Linear node.
+///
+/// Warning: This function is not complete yet.
+/// It only supports the case where the Gemm node is a straight linear transformation.
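+///
+/// In ONNX, `Gemm` computes `alpha * A' * B' + beta * C`; the straight case
+/// (`alpha == 1.0`, `beta == 1.0`, `transB == 1`) is exactly a `Linear` layer
+/// (`y = x * W^T + b`), which is the only form handled below.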
+fn convert_gemm(node: &mut Node) {
+    if node.inputs.len() != 1 {
+        panic!("Gemm node must have exactly 1 input (weights and bias are extracted as initializers)");
+    }
+
+    if node.outputs.len() != 1 {
+        panic!("Gemm node must have 1 output");
+    }
+    let straight_linear = match (
+        node.attrs.get("alpha"),
+        node.attrs.get("beta"),
+        node.attrs.get("transB"),
+    ) {
+        (
+            Some(AttributeValue::Float32(alpha)),
+            Some(AttributeValue::Float32(beta)),
+            Some(AttributeValue::Int64(trans_b)),
+        ) => *alpha == 1.0 && *beta == 1.0 && *trans_b == 1,
+        _ => false,
+    };
+
+    if straight_linear {
+        node.node_type = NodeType::Linear;
+        node.is_stateful = true;
+        node.attrs.remove("alpha");
+        node.attrs.remove("beta");
+        node.attrs.remove("transB");
+    } else {
+        panic!("Full Gemm node not supported yet.");
+    }
+}
diff --git a/burn-import/src/onnx/codegen.rs b/burn-import/src/onnx/codegen.rs
new file mode 100644
index 000000000..5caa6dd5a
--- /dev/null
+++ b/burn-import/src/onnx/codegen.rs
@@ -0,0 +1,574 @@
+use std::{
+    collections::HashSet,
+    env,
+    fs::{self, create_dir_all},
+    path::{Path, PathBuf},
+};
+
+use burn::nn::conv::Conv2dPaddingConfig;
+use proc_macro2::{Span, TokenStream};
+use quote::quote;
+use syn::{Ident, Type};
+
+use crate::onnx::{
+    ir::{ArgType, Node, NodeType},
+    op_configuration::{conv2d_config, flatten_config, linear_config, log_softmax_config},
+};
+
+use super::{convert::parse_onnx, ir::Graph};
+
+use rust_format::{Config, Edition, Formatter, PostProcess, RustFmt};
+
+/// Code generation for onnx files.
+#[derive(Debug, Default)]
+pub struct ModelCodeGen {
+    out_dir: Option<PathBuf>,
+
+    /// List of onnx files to generate source code from.
+    inputs: Vec<PathBuf>,
+}
+
+/// Generate code from `.onnx` files and save it to the `out_dir`.
+impl ModelCodeGen {
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Set output directory.
+    pub fn out_dir(&mut self, out_dir: &str) -> &mut Self {
+        let cargo_out_dir = env::var("OUT_DIR").expect("OUT_DIR env is not set");
+        let mut path = PathBuf::from(cargo_out_dir);
+
+        // Append the out_dir to the cargo_out_dir
+        path.push(Path::new(out_dir));
+
+        self.out_dir = Some(path);
+        self
+    }
+
+    /// Add input file.
+    pub fn input(&mut self, input: &str) -> &mut Self {
+        self.inputs.push(input.into());
+        self
+    }
+
+    /// Run code generation.
+    ///
+    /// This function is intended to be called from the `build.rs` script.
+    pub fn run_from_script(&self) {
+        self.run();
+    }
+
+    /// Run code generation.
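+    ///
+    /// Parses every registered ONNX file and writes the generated Rust source
+    /// to `<out_dir>/<input file stem>.rs`.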
+    pub fn run(&self) {
+        let config = Config::new_str()
+            .post_proc(PostProcess::ReplaceMarkersAndDocBlocks)
+            .edition(Edition::Rust2021);
+
+        let rust_formatter = RustFmt::from_config(config);
+
+        let out_dir = self.out_dir.as_ref().expect("out_dir is not set");
+        create_dir_all(out_dir).unwrap();
+
+        for input in self.inputs.iter() {
+            let file_name = input.file_stem().unwrap();
+            let out_file = out_dir.join(file_name);
+            let out_file = out_file.with_extension("rs");
+
+            let model = ModelSourceCode::new(input);
+            let code_str = rust_formatter.format_tokens(model.body()).unwrap();
+
+            fs::write(out_file, code_str).unwrap();
+        }
+    }
+}
+
+/// A model that can be used to generate code
+#[derive(Debug, Clone)]
+pub struct ModelSourceCode {
+    onnx_path: PathBuf,
+    pub graph: Graph,
+}
+
+impl ModelSourceCode {
+    /// Create a new model from the onnx file
+    pub fn new<P: AsRef<Path>>(onnx_path: P) -> Self {
+        let graph = parse_onnx(onnx_path.as_ref());
+        Self {
+            onnx_path: onnx_path.as_ref().to_path_buf(),
+            graph,
+        }
+    }
+
+    /// Generates source code for the model
+    pub fn body(&self) -> TokenStream {
+        let input = "Model"; // TODO make this a parameter
+        let input = Ident::new(input, Span::call_site());
+
+        let declaration = self.declaration(&input);
+
+        let file_path = self.onnx_path.to_str().unwrap();
+
+        let top_file_comment = format!("Generated from {file_path} by burn-import");
+
+        let mut imports: HashSet<String> = HashSet::new();
+
+        let implementation = self.implementation(&mut imports);
+
+        let import_statements = self.import_statements(&imports);
+
+        let shape_constants = self.shape_constants();
+
+        //TODO print out the old -> new name mapping
+        quote! {
+            _comment_!(#top_file_comment);
+            _blank_!();
+            _blank_!();
+            #import_statements
+            _blank_!();
+            #shape_constants
+            _blank_!();
+            #declaration
+            _blank_!();
+            #[allow(dead_code)]
+            #[allow(clippy::new_without_default)]
+            #[allow(clippy::let_and_return)]
+            #implementation
+
+        }
+    }
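+
+    /// Emits one `pub const` per graph input/output, e.g. (values here are
+    /// illustrative) `pub const INPUT1_SHAPE: [usize; 4] = [1, 1, 28, 28];`.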
+    fn shape_constants(&self) -> TokenStream {
+        let input_constants = self.graph.inputs.iter().enumerate().map(|(i, input)| {
+            let name = format!("INPUT{}_SHAPE", i + 1);
+            let name = Ident::new(&name, Span::call_site());
+            let ArgType::Tensor(tensor) = input.clone().arg_type.unwrap();
+            let dims = tensor.shape;
+            let dims_count = dims.len();
+            quote! {
+                pub const #name: [usize; #dims_count] = [#(#dims),*];
+            }
+        });
+
+        let output_constants = self.graph.outputs.iter().enumerate().map(|(i, output)| {
+            let name = format!("OUTPUT{}_SHAPE", i + 1);
+            let name = Ident::new(&name, Span::call_site());
+            let ArgType::Tensor(tensor) = output.clone().arg_type.unwrap();
+            let dims = tensor.shape;
+            let dims_count = dims.len();
+            quote! {
+                pub const #name: [usize; #dims_count] = [#(#dims),*];
+            }
+        });
+
+        quote! {
+            #(#input_constants)*
+            #(#output_constants)*
+        }
+    }
+
+    /// Generates import statements for the model
+    fn import_statements(&self, imports: &HashSet<String>) -> TokenStream {
+        let mut import_tokens = vec![];
+
+        for import in imports.iter() {
+            let path: syn::Path =
+                syn::parse_str(import).expect("Unable to parse input string as a path");
+
+            import_tokens.push(quote! { #path });
+        }
+
+        quote! {
+            use burn::{
+                module::Module,
+                nn,
+                tensor::{backend::Backend, Tensor},
+            };
+
+            #(use #import_tokens;)*
+        }
+    }
+
+    /// Generates the declaration portion of the source code for the model
+    fn declaration(&self, name: &Ident) -> TokenStream {
+        let fields = self.declaration_fields();
+
+        let mut field_names = vec![];
+        let mut field_types = vec![];
+
+        for (field_name, field_type) in fields.iter() {
+            field_names.push(field_name);
+            field_types.push(field_type);
+        }
+
+        quote! {
+            // TODO add documentation
+            #[doc = "This is a generated model from an ONNX file"]
+            #[derive(Module, Debug)]
+            pub struct #name<B: Backend> {
+                #(
+                    #field_names: #field_types,
+                )*
+            }
+
+        }
+    }
+
+    /// Model implementation code
+    fn implementation(&self, imports: &mut HashSet<String>) -> TokenStream {
+        let forward_method = self.forward_method(imports);
+
+        let new_method = self.new_method();
+
+        quote! {
+            impl<B: Backend> Model<B> {
+                #new_method
+                #forward_method
+            }
+        }
+    }
+
+    /// Generates the forward method for the model
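+    ///
+    /// For a single-input, single-output graph this renders as, e.g.,
+    /// `pub fn forward(&self, input1: Tensor<B, 4>) -> Tensor<B, 2> { ... }`
+    /// (dimensions here are illustrative).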
+    fn forward_method(&self, imports: &mut HashSet<String>) -> TokenStream {
+        let inputs = self.forward_signature_input();
+        let return_type = self.forward_signature_return();
+        let results = self.forward_method_results();
+
+        let mut call_nodes: Vec<TokenStream> = vec![];
+
+        for node in self.graph.nodes.iter() {
+            if node.is_stateful {
+                call_nodes.push(Self::node_call_stateful(node));
+            } else {
+                call_nodes.push(Self::node_call_stateless(node, imports));
+            }
+        }
+
+        quote! {
+            pub fn forward(&self, #(#inputs,)*) -> #return_type {
+                #(#call_nodes)*
+                #results
+            }
+        }
+    }
+
+    /// Generates source code for the stateful node calls, i.e. conv, dropout, etc.
+    fn node_call_stateful(node: &Node) -> TokenStream {
+        if !node.is_stateful {
+            panic!("Node must be stateful");
+        }
+
+        let name = Ident::new(&node.name, Span::call_site());
+
+        let mut inputs = vec![];
+
+        for input in node.inputs.iter() {
+            let name = Ident::new(&input.name, Span::call_site());
+            inputs.push(quote! {
+                #name
+            });
+        }
+
+        let mut outputs = vec![];
+
+        for output in node.outputs.iter() {
+            let name = Ident::new(&output.name, Span::call_site());
+            outputs.push(quote! {
+                #name
+            });
+        }
+
+        if outputs.len() == 1 {
+            let output = outputs.pop().unwrap();
+            quote! {
+                let #output = self.#name.forward(#(#inputs,)*);
+            }
+        } else {
+            quote! {
+                let (#(#outputs,)*) = self.#name.forward(#(#inputs,)*);
+            }
+        }
+    }
+
+    /// Generates source code for the forward method results
+    fn forward_method_results(&self) -> TokenStream {
+        let mut outputs = vec![];
+        for output in self.graph.outputs.iter() {
+            let name = Ident::new(&output.name, Span::call_site());
+            outputs.push(quote! {
+                #name
+            });
+        }
+        if outputs.len() == 1 {
+            let output = outputs.pop().unwrap();
+            quote! {
+                #output
+            }
+        } else {
+            quote! {
+                (#(#outputs,)*)
+            }
+        }
+    }
+
+    /// Generates source code for the stateless node calls, i.e. add, mul, etc.
+    fn node_call_stateless(node: &Node, imports: &mut HashSet<String>) -> TokenStream {
+        if node.is_stateful {
+            panic!("Node must be stateless");
+        }
+
+        let mut inputs = vec![];
+
+        for input in node.inputs.iter() {
+            let name = Ident::new(&input.name, Span::call_site());
+            inputs.push(quote! {
+                #name
+            });
+        }
+
+        let mut outputs = vec![];
+
+        for output in node.outputs.iter() {
+            let name = Ident::new(&output.name, Span::call_site());
+            outputs.push(quote! {
+                #name
+            });
+        }
+
+        let rhs = Self::node_call_stateless_rhs(node, imports);
+
+        if outputs.len() == 1 {
+            let output = outputs.pop().unwrap();
+            quote! {
+                let #output = #rhs;
+            }
+        } else {
+            quote! {
+                let (#(#outputs,)*) = #rhs;
+            }
+        }
+    }
+
+    /// Generates source code for the right hand side stateless node calls, i.e. add, relu, etc.
+    fn node_call_stateless_rhs(node: &Node, imports: &mut HashSet<String>) -> TokenStream {
+        let mut inputs = vec![];
+
+        for input in node.inputs.iter() {
+            let name = Ident::new(&input.name, Span::call_site());
+            inputs.push(quote! {
+                #name
+            });
+        }
+
+        let input1 = inputs.pop().unwrap();
+
+        match node.node_type {
+            NodeType::Relu => {
+                imports.insert("burn::tensor::activation::relu".to_string());
+
+                quote! { relu(#input1) }
+            }
+            NodeType::LogSoftmax => {
+                imports.insert("burn::tensor::activation::log_softmax".to_string());
+                let dim = log_softmax_config(node);
+
+                quote! { log_softmax(#input1, #dim) }
+            }
+            NodeType::Flatten => {
+                let (start_dim, end_dim) = flatten_config(node);
+
+                quote! { #input1.flatten(#start_dim, #end_dim) }
+            }
+            _ => quote! {},
+        }
+    }
+
+    /// Generates the forward method signature
+    fn forward_signature_input(&self) -> Vec<TokenStream> {
+        let mut fields = vec![];
+
+        for input in self.graph.inputs.iter() {
+            let name = Ident::new(&input.name, Span::call_site());
+
+            let ty = match input.arg_type.as_ref().unwrap() {
+                ArgType::Tensor(tensor) => {
+                    let d = &tensor.shape.len();
+                    syn::parse_str::<Type>(format!("Tensor<B, {d}>").as_str()).unwrap()
+                }
+            };
+
+            fields.push(quote! {
+                #name: #ty
+            });
+        }
+        fields
+    }
+
+    /// Generates the forward method return signature
+    fn forward_signature_return(&self) -> TokenStream {
+        let mut field_types = vec![];
+
+        for output in self.graph.outputs.iter() {
+            let ty = match output.arg_type.as_ref().unwrap() {
+                ArgType::Tensor(tensor) => {
+                    let d = &tensor.shape.len();
+                    syn::parse_str::<Type>(format!("Tensor<B, {d}>").as_str()).unwrap()
+                }
+            };
+
+            field_types.push(ty);
+        }
+
+        if field_types.len() == 1 {
+            // Return one output
+            quote! {
+                #(
+                    #field_types
+                )*
+            }
+        } else {
+            // Return a tuple of the outputs
+            quote! {
+                (#(
+                    #field_types,
+                )*)
+            }
+        }
+    }
+
+    /// Generates source code for the initialization method
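+    ///
+    /// Emits one `let <field> = <...>Config::new(...).init();` binding per
+    /// stateful node and gathers the bindings into the returned `Self { ... }`.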
+    fn new_method(&self) -> TokenStream {
+        let initialization_fields = self.initialization_fields();
+
+        let field_names = self.graph.nodes.iter().filter(|x| x.is_stateful).map(|x| {
+            let name = Ident::new(&x.name, Span::call_site());
+            quote! {
+                #name
+            }
+        });
+
+        quote! {
+            pub fn new() -> Self {
+                #(
+                    #initialization_fields
+                )*
+
+                Self {
+                    #(
+                        #field_names
+                    ),*
+                }
+            }
+        }
+    }
+
+    /// Get the fields for the declaration of the model
+    fn declaration_fields(&self) -> Vec<(Ident, Type)> {
+        let mut fields = vec![];
+
+        for node in self.graph.nodes.iter().filter(|x| x.is_stateful) {
+            let node_type = match node.node_type {
+                NodeType::Conv1d => syn::parse_str::<Type>("nn::conv::Conv1d<B>").unwrap(),
+                NodeType::Conv2d => syn::parse_str::<Type>("nn::conv::Conv2d<B>").unwrap(),
+                NodeType::Linear => syn::parse_str::<Type>("nn::Linear<B>").unwrap(),
+                _ => {
+                    todo!("Node type not implemented: {:?}", node.node_type)
+                }
+            };
+
+            let node_name = Ident::new(&node.name, Span::call_site());
+
+            fields.push((node_name, node_type));
+        }
+
+        fields
+    }
+
+    /// Generates source code for the initialization method
+    fn initialization_fields(&self) -> Vec<TokenStream> {
+        let mut fields = vec![];
+
+        for node in self.graph.nodes.iter().filter(|x| x.is_stateful) {
+            let init_code = match node.node_type {
+                NodeType::Conv2d => conv2d_init(node),
+                NodeType::Linear => linear_init(node),
+                _ => {
+                    todo!("Node type not implemented: {:?}", node.node_type)
+                }
+            };
+
+            fields.push(init_code);
+        }
+
+        fields
+    }
+}
+
+/// Generates source code for the initialization of a Conv2d node
+fn conv2d_init(node: &Node) -> TokenStream {
+    let node_name = Ident::new(&node.name, Span::call_site());
+
+    let config = conv2d_config(node);
+
+    let channel_in = config.channels[0];
+    let channel_out = config.channels[1];
+    let kernel_size_0 = config.kernel_size[0];
+    let kernel_size_1 = config.kernel_size[1];
+    let bias = config.bias;
+
+    let padding = match config.padding {
+        Conv2dPaddingConfig::Valid => quote! { nn::conv::Conv2dPaddingConfig::Valid },
+        Conv2dPaddingConfig::Same => quote! { nn::conv::Conv2dPaddingConfig::Same },
+        _ => todo!("Padding ({:?}) not implemented", config.padding),
+    };
+
+    quote! {
+        let #node_name = nn::conv::Conv2dConfig::new([#channel_in, #channel_out], [#kernel_size_0, #kernel_size_1])
+            .with_padding(#padding)
+            .with_bias(#bias)
+            .init();
+
+    }
+}
+
+/// Generates source code for the initialization of a Linear node
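+///
+/// For a 784x10 fully connected layer this expands to, e.g.,
+/// `let linear1 = nn::LinearConfig::new(784, 10).with_bias(true).init();`
+/// (sizes here are illustrative).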
+fn linear_init(node: &Node) -> TokenStream {
+    let node_name = Ident::new(&node.name, Span::call_site());
+    let config = linear_config(node);
+
+    let bias = config.bias;
+    let input_size = config.d_input;
+    let output_size = config.d_output;
+
+    quote! {
+        let #node_name = nn::LinearConfig::new(#input_size, #output_size)
+            .with_bias(#bias)
+            .init();
+
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use rstest::*;
+    use rust_format::{Config, Edition, Formatter, PostProcess, RustFmt};
+
+    #[fixture]
+    pub fn model() -> ModelSourceCode {
+        ModelSourceCode::new("tests/onnx/mnist.onnx")
+    }
+
+    #[rstest]
+    fn print(model: ModelSourceCode) {
+        let config = Config::new_str()
+            .post_proc(PostProcess::ReplaceMarkersAndDocBlocks)
+            .edition(Edition::Rust2021);
+
+        let rustfmt = RustFmt::from_config(config);
+
+        let _gen_str = rustfmt.format_tokens(model.body()).unwrap();
+
+        // TODO compare the result with the expected output
+    }
+}
diff --git a/burn-import/src/onnx/convert.rs b/burn-import/src/onnx/convert.rs
new file mode 100644
index 000000000..f9e0bdebc
--- /dev/null
+++ b/burn-import/src/onnx/convert.rs
@@ -0,0 +1,523 @@
+use std::{
+    collections::{HashMap, HashSet},
+    fs::File,
+    path::Path,
+    str::{from_utf8, FromStr},
+};
+
+use super::coalesce::coalesce;
+use super::ir::{
+    ArgType, Argument, AttributeValue, Attributes, ElementType, Graph, Node, NodeType, Tensor,
+    TensorData,
+};
+use super::protos::{
+    attribute_proto::AttributeType, tensor_proto::DataType, tensor_shape_proto::dimension::Value,
+    type_proto, AttributeProto, ModelProto, NodeProto, TensorProto, TensorShapeProto,
+    ValueInfoProto,
+};
+use super::shape_inference::shape_inference;
+
+use protobuf::{Enum, Message};
+use topological_sort::TopologicalSort;
+
+const STATEFUL_NODE_TYPES: [NodeType; 4] = [
+    NodeType::Conv,
+    NodeType::BatchNormalization,
+    NodeType::Dropout,
+    NodeType::Linear,
+];
+
+/// Error type for parsing ONNX model
+#[derive(Debug)]
+pub enum ParseError {
+    VariantNotFound,
+}
+
+/// Open an onnx file and convert it to a Graph (intermediate representation)
+pub fn parse_onnx(onnx_path: &Path) -> Graph {
+    // Open the file
+    let mut file = File::open(onnx_path).expect("Unable to open file");
+    let onnx_model: ModelProto =
+        Message::parse_from_reader(&mut file).expect("Unable to parse ONNX file");
+
+    // Convert the nodes
+    let mut nodes: Vec<Node> = vec![];
+    for onnx_node in onnx_model.graph.node.iter() {
+        nodes.push(convert_node_proto(onnx_node));
+    }
+
+    // Get the names of the initializers
+    let check_if_initializer: HashSet<String> = onnx_model
+        .graph
+        .initializer
+        .iter()
+        .map(|x| x.name.clone())
+        .collect();
+
+    // Move inputs to initializers
+    move_inputs_to_initializer(&mut nodes, &check_if_initializer);
+
+    // Get the topological sort of the nodes and the top nodes
+    let (ts, top_nodes) = get_top_nodes(&nodes);
+
+    // Sort the nodes
+    top_sort_nodes(&mut nodes, ts);
+
+    // Collect inputs, outputs and initializers
+    let mut inputs = collect_inputs(&onnx_model, &check_if_initializer, top_nodes);
+    let mut outputs = collect_outputs(&onnx_model, check_if_initializer);
+    let initializers = collect_initializers(onnx_model);
+
+    // Coalesce and transform nodes
+    coalesce(&mut nodes);
+
+    // Copy the initializers to the nodes
+    copy_initializer_info_to_nodes_level(&mut nodes, &initializers);
+
+    // Rename nodes and inputs, save the mapping for later
+    let old_node_names = rename_nodes(&mut nodes);
+    let old_input_names = rename_inputs(&mut nodes, &mut inputs, &mut outputs);
+
+    // Infer shapes and update the inputs and outputs
+    shape_inference(&mut nodes, &inputs, &mut outputs);
+
+    Graph {
+        nodes,
+        inputs,
+        outputs,
+        initializers,
+        old_node_names,
+        old_input_names,
+    }
+}
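+
+// ONNX stores trained weights and biases in the graph's `initializer` list;
+// they are converted to `Argument`s here so they can later be attached to the
+// stateful nodes that consume them.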
+/// Collect initializers
+fn collect_initializers(onnx_model: ModelProto) -> Vec<Argument> {
+    let mut initializers: Vec<Argument> = vec![];
+    for initializer in onnx_model.graph.initializer.iter() {
+        let tensor_proto = initializer.clone();
+
+        let name = tensor_proto.name.clone();
+
+        // FIXME data conversion for the tensor is incorrect
+        let tensor: Tensor = tensor_proto.try_into().unwrap();
+        let arg_type = Some(ArgType::Tensor(tensor));
+        let arg = Argument { name, arg_type };
+        initializers.push(arg);
+    }
+    initializers
+}
+
+/// Collect outputs
+fn collect_outputs(
+    onnx_model: &ModelProto,
+    check_if_initializer: HashSet<String>,
+) -> Vec<Argument> {
+    // TODO: filter out the outputs that are not used in the graph
+    let outputs: Vec<Argument> = onnx_model
+        .graph
+        .output
+        .iter()
+        .filter(|x| !check_if_initializer.contains(x.name.as_str()))
+        .map(|i| Argument::try_from(i.clone()).unwrap())
+        .collect();
+    outputs
+}
+
+/// Collect inputs
+fn collect_inputs(
+    onnx_model: &ModelProto,
+    check_if_initializer: &HashSet<String>,
+    top_nodes: HashSet<String>,
+) -> Vec<Argument> {
+    let inputs: Vec<Argument> = onnx_model
+        .graph
+        .input
+        .iter()
+        .filter(|x| !check_if_initializer.contains(x.name.as_str()))
+        .filter(|x| top_nodes.contains(&x.name))
+        .map(|x| Argument::try_from(x.clone()).unwrap())
+        .collect();
+    inputs
+}
+
+/// Sort the nodes in topological order
+fn top_sort_nodes(nodes: &mut Vec<Node>, mut ts: TopologicalSort<Node>) {
+    *nodes = vec![];
+    while let Some(node) = ts.pop() {
+        nodes.push(node);
+    }
+}
+
+/// Get the top nodes in the graph
+fn get_top_nodes(nodes: &Vec<Node>) -> (TopologicalSort<Node>, HashSet<String>) {
+    // Get the names of the top nodes (first nodes in the graph to receive the input)
+    // Sometimes onnx will pass inputs to be used as weights and biases but they are not truly inputs
+    let ts = topsort(nodes);
+    let mut top_nodes: HashSet<String> = HashSet::new();
+
+    for node in ts.peek_all() {
+        for input in node.inputs.iter() {
+            top_nodes.insert(input.name.clone());
+        }
+    }
+    (ts, top_nodes)
+}
+
+/// Move a node's inputs and outputs to its initializers if they are in the initializer list
+fn move_inputs_to_initializer(nodes: &mut Vec<Node>, check_if_initializer: &HashSet<String>) {
+    for node in nodes.iter_mut() {
+        node.initializers = node
+            .inputs
+            .iter()
+            .filter(|x| check_if_initializer.contains(&x.name))
+            .cloned()
+            .collect();
+
+        // Remove the initializers from the inputs and outputs
+        node.inputs
+            .retain(|x| !check_if_initializer.contains(&x.name));
+        node.outputs
+            .retain(|x| !check_if_initializer.contains(&x.name));
+    }
+}
+
+fn to_string(bytes: Vec<u8>) -> String {
+    from_utf8(bytes.as_slice()).unwrap().to_string()
+}
+
+fn to_string_vec(bytes: Vec<Vec<u8>>) -> Vec<String> {
+    bytes.iter().map(|b| to_string(b.clone())).collect()
+}
+
+fn convert_shape(shape: Vec<i64>) -> Vec<usize> {
+    shape.iter().map(|s| *s as usize).collect()
+}
+
+/// Convert a TensorProto to a Tensor
+impl TryFrom<TensorProto> for Tensor {
+    type Error = ParseError;
+    fn try_from(tensor: TensorProto) -> Result<Tensor, Self::Error> {
+        let (elem_type, data) = match DataType::from_i32(tensor.data_type).unwrap() {
+            // TODO: check if float is empty and use raw_data instead
+            DataType::FLOAT => (
+                ElementType::Float32,
+                TensorData::Float32s(tensor.float_data),
+            ),
+            DataType::INT32 => (ElementType::Int32, TensorData::Int32s(tensor.int32_data)),
+            DataType::INT64 => (ElementType::Int64, TensorData::Int64s(tensor.int64_data)),
+            DataType::DOUBLE => (
+                ElementType::Float64,
+                TensorData::Float64s(tensor.double_data),
+            ),
+
+            // TODO : Add more types
+            _ => {
+                return Err(ParseError::VariantNotFound);
+            }
+        };
+        let shape = convert_shape(tensor.dims);
+        let name = tensor.name;
+
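+        // Note: when a model stores its data in `raw_data`, the typed fields
+        // read above are empty, so the tensor comes out without data (this is
+        // the FIXME mentioned in `collect_initializers`).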
+        Ok(Tensor {
+            name: Some(name),
+            elem_type,
+            shape,
+            data: Some(data),
+        })
+    }
+}
+
+impl TryFrom<TensorShapeProto> for Vec<usize> {
+    type Error = ParseError;
+    fn try_from(shape: TensorShapeProto) -> Result<Vec<usize>, Self::Error> {
+        let mut result = Vec::new();
+
+        for dim in shape.dim {
+            if let Value::DimValue(value) = dim.value.unwrap() {
+                result.push(value as usize);
+            }
+        }
+
+        Ok(result)
+    }
+}
+
+/// Convert a tensor type proto to a Tensor
+impl TryFrom<&type_proto::Tensor> for Tensor {
+    type Error = ParseError;
+    fn try_from(tensor: &type_proto::Tensor) -> Result<Tensor, Self::Error> {
+        let elem_type = match DataType::from_i32(tensor.elem_type).unwrap() {
+            DataType::FLOAT => ElementType::Float32,
+            DataType::INT32 => ElementType::Int32,
+            DataType::INT64 => ElementType::Int64,
+            DataType::DOUBLE => ElementType::Float64,
+
+            // TODO : Add more types
+            _ => {
+                return Err(ParseError::VariantNotFound);
+            }
+        };
+
+        let shape_proto = tensor.shape.clone().unwrap();
+        let shape: Vec<usize> = shape_proto.try_into().unwrap();
+
+        let name = None;
+
+        Ok(Tensor {
+            name,
+            elem_type,
+            shape,
+            data: None,
+        })
+    }
+}
+
+fn convert_vec_tensor_proto(tensors: Vec<TensorProto>) -> Result<Vec<Tensor>, ParseError> {
+    let mut result = Vec::new();
+    for tensor in tensors {
+        result.push(Tensor::try_from(tensor)?);
+    }
+    Ok(result)
+}
+
+/// Convert an AttributeProto to an AttributeValue
+impl TryFrom<AttributeProto> for AttributeValue {
+    type Error = ParseError;
+
+    fn try_from(attr: AttributeProto) -> Result<AttributeValue, Self::Error> {
+        let value = match attr.type_.unwrap() {
+            AttributeType::FLOAT => AttributeValue::Float32(attr.f),
+            AttributeType::INT => AttributeValue::Int64(attr.i),
+            AttributeType::STRING => AttributeValue::String(to_string(attr.s)),
+
+            // warning: tensor can be empty TODO: check if it is empty
+            AttributeType::TENSOR => AttributeValue::Tensor(Tensor::try_from(attr.t.unwrap())?),
+
+            // Graph is not supported for now
+            // AttributeType::GRAPH => AttributeValue::Graph(attr.g),
+            AttributeType::FLOATS => AttributeValue::Float32s(attr.floats),
+            AttributeType::INTS => AttributeValue::Int64s(attr.ints),
+            AttributeType::STRINGS => AttributeValue::Strings(to_string_vec(attr.strings)),
+            AttributeType::TENSORS => {
+                AttributeValue::Tensors(convert_vec_tensor_proto(attr.tensors)?)
+            }
+            // AttributeType::GRAPHS => AttributeValue::Graphs(attr.graphs),
+            // AttributeType::SPARSE_TENSORS => AttributeValue::SparseTensors(attr.sparse_tensors),
+            // AttributeType::SPARSE_TENSOR => AttributeValue::SparseTensor(attr.sparse_tensor),
+            _ => {
+                return Err(ParseError::VariantNotFound);
+            }
+        };
+
+        Ok(value)
+    }
+}
+
+/// Convert a vector of AttributeProto to a HashMap of AttributeValue
+pub fn convert_vec_attrs_proto(attrs: Vec<AttributeProto>) -> Attributes {
+    let mut result = Attributes::new();
+    for attr in attrs {
+        result.insert(attr.name.clone(), AttributeValue::try_from(attr).unwrap());
+    }
+    result
+}
+
+pub fn convert_node_proto(node: &NodeProto) -> Node {
+    let name = node.name.clone();
+    let inputs = node
+        .input
+        .clone()
+        .into_iter()
+        .map(|x| Argument {
+            name: x,
+            arg_type: None,
+        })
+        .collect();
+    let outputs = node
+        .output
+        .clone()
+        .into_iter()
+        .map(|x| Argument {
+            name: x,
+            arg_type: None,
+        })
+        .collect();
+    let attrs = convert_vec_attrs_proto(node.attribute.clone());
+
+    let node_type = NodeType::from_str(node.op_type.as_str()).unwrap();
+
+    let is_stateful = STATEFUL_NODE_TYPES.contains(&node_type);
+
+    let mut node = Node {
+        node_type,
+        name,
+        inputs,
+        outputs,
+        initializers: vec![],
+        attrs,
+        is_stateful,
+    };
+
+    remap_node_type(&mut node);
+
+    node
+}
+
+/// Remap node type to a more specific one
+fn remap_node_type(node: &mut Node) {
+    match node.node_type {
+        NodeType::Conv => {
+            if let AttributeValue::Int64s(ints) = node.attrs.get("kernel_shape").unwrap() {
+                node.node_type = match ints.len() {
+                    1 => NodeType::Conv1d,
+                    2 => NodeType::Conv2d,
+                    _ => todo!(),
+                };
+            } else {
+                panic!("kernel_shape is not an int64s");
+            }
+        }
+        _ => (),
+    }
+}
+
+/// Convert a ValueInfoProto to an Argument
+impl TryFrom<ValueInfoProto> for Argument {
+    type Error = ParseError;
+
+    fn try_from(value: ValueInfoProto) -> Result<Argument, Self::Error> {
+        let name = value.name.clone();
+        let proto_type = value.type_.unwrap();
+
+        let mut arg_type = None;
+
+        if proto_type.has_tensor_type() {
+            let tensor_proto = proto_type.tensor_type();
+
+            let tensor: Tensor = tensor_proto.try_into().unwrap();
+
+            arg_type = Some(ArgType::Tensor(tensor));
+        }
+        Ok(Argument { name, arg_type })
+    }
+}
+
+/// Copy the initializers to the nodes
+fn copy_initializer_info_to_nodes_level(nodes: &mut Vec<Node>, initializers: &Vec<Argument>) {
+    for node in nodes.iter_mut() {
+        for node_initializer in node.initializers.iter_mut() {
+            *node_initializer = initializers
+                .iter()
+                .find(|x| x.name == node_initializer.name)
+                .unwrap()
+                .clone();
+        }
+    }
+}
+
+/// Rename the nodes in the graph to be unique and return a map of the old names to the new names.
+fn rename_nodes(nodes: &mut Vec<Node>) -> HashMap<String, String> {
+    let mut old_names = HashMap::new();
+    let mut counter: HashMap<NodeType, usize> = HashMap::new();
+
+    for node in nodes.iter_mut() {
+        // keep track of the number of nodes of each type
+        counter
+            .entry(node.node_type.clone())
+            .and_modify(|e| *e += 1)
+            .or_insert(1);
+
+        let old_name = node.name.clone();
+        let new_name = format!("{}{}", node.node_type, counter[&node.node_type]).to_lowercase();
+
+        node.name = new_name.clone();
+
+        old_names.insert(old_name, new_name);
+    }
+
+    old_names
+}
+
+/// Rename the inputs in the graph and return a map of the old names to the new names.
+///
+/// The inputs are renamed to be unique and to follow the format of conv2_in1, conv2_in2, etc.
+/// This is done to be consistent with the naming convention of the nodes and to allow them to be used as Rust identifiers.
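+/// For example, the ONNX graph input `data` becomes `input1`, and the first
+/// output of node `conv2d1` becomes `conv2d1_out1`.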
+fn rename_inputs(
+    nodes: &mut Vec<Node>,
+    inputs: &mut Vec<Argument>,
+    outputs: &mut Vec<Argument>,
+) -> HashMap<String, String> {
+    let mut old_names = HashMap::new();
+
+    let mut counter = 1;
+    for input in inputs.iter_mut() {
+        let old_name = input.name.clone();
+        let new_name = format!("input{}", counter);
+
+        input.name = new_name.clone();
+
+        old_names.insert(old_name, new_name);
+        counter += 1;
+    }
+
+    let mut counter: HashMap<String, usize> = HashMap::new();
+
+    for node in nodes.iter_mut() {
+        // keep track of the number of nodes of each type
+        counter
+            .entry(node.name.clone())
+            .and_modify(|e| *e += 1)
+            .or_insert(1);
+
+        // loop through node inputs and rename them with previously replaced names
+        for input in node.inputs.iter_mut() {
+            if let Some(new_name) = old_names.get(&input.name) {
+                input.name = new_name.clone();
+            }
+        }
+
+        // loop through node outputs and rename them and store the new name <-> old name mapping
+        for output in node.outputs.iter_mut() {
+            let old_name = output.name.clone();
+            let new_name = format!("{}_out{}", node.name, counter[&node.name]);
+            output.name = new_name.clone();
+            old_names.insert(old_name, new_name);
+        }
+    }
+
+    // Rename the graph outputs
+    for output in outputs.iter_mut() {
+        if let Some(new_name) = old_names.get(&output.name) {
+            output.name = new_name.clone();
+        }
+    }
+
+    old_names
+}
+
+/// Find the node that produces the given output
+fn lookup_node_by_output(nodes: &Vec<Node>, input: &str) -> Option<Node> {
+    for node in nodes.iter() {
+        if node.outputs.iter().any(|x| x.name == *input) {
+            return Some(node.clone());
+        }
+    }
+    None
+}
+
+/// Sort nodes in topological order
+pub fn topsort(nodes: &Vec<Node>) -> TopologicalSort<Node> {
+    let mut ts = TopologicalSort::new();
+
+    for node in nodes.iter() {
+        for input in node.inputs.iter() {
+            match lookup_node_by_output(nodes, input.name.as_str()) {
+                Some(prec) => ts.add_dependency(prec, node.clone()),
+                None => {}
+            }
+        }
+    }
+
+    ts
+}
diff --git a/burn-import/src/onnx/ir.rs b/burn-import/src/onnx/ir.rs
new file mode 100644
index 000000000..69cd16177
--- /dev/null
+++ b/burn-import/src/onnx/ir.rs
@@ -0,0 +1,305 @@
+use half::f16;
+use std::collections::HashMap;
+use strum_macros::{Display, EnumString};
+
+pub type Shape = Vec<usize>;
+
+#[derive(Debug, Clone)]
+pub struct Argument {
+    pub name: String,
+    pub arg_type: Option<ArgType>,
+}
+
+#[derive(Debug, Clone)]
+pub enum ArgType {
+    Tensor(Tensor),
+}
+
+#[derive(Debug, Clone)]
+pub struct SparseTensor(Tensor, Tensor, Shape);
+
+#[derive(Debug, Clone)]
+pub enum AttributeValue {
+    Float32(f32),
+    Int64(i64),
+    String(String),
+    Tensor(Tensor),
+    SparseTensor(SparseTensor),
+    Float32s(Vec<f32>),
+    Int64s(Vec<i64>),
+    Strings(Vec<String>),
+    Tensors(Vec<Tensor>),
+    SparseTensors(Vec<SparseTensor>),
+}
+pub type Attributes = HashMap<String, AttributeValue>;
+
+#[derive(Debug, Clone)]
+pub enum ElementType {
+    Float32,
+    Float64,
+    Int32,
+    Int64,
+    String,
+    Float16,
+}
+
+#[derive(Debug, Clone)]
+pub struct Tensor {
+    pub name: Option<String>,
+    pub elem_type: ElementType,
+    pub shape: Shape,
+    pub data: Option<TensorData>,
+}
+
+#[derive(Debug, Clone)]
+pub enum TensorData {
+    Float16s(Vec<f16>),
+    Float32s(Vec<f32>),
+    Float64s(Vec<f64>),
+    Int32s(Vec<i32>),
+    Int64s(Vec<i64>),
+    Strings(Vec<String>),
+}
+
+#[derive(Debug, Clone)]
+pub struct Graph {
+    pub nodes: Vec<Node>,
+    pub inputs: Vec<Argument>,
+    pub outputs: Vec<Argument>,
+    pub initializers: Vec<Argument>,
+    pub old_node_names: HashMap<String, String>,
+    pub old_input_names: HashMap<String, String>,
+}
+
+#[derive(Debug, Clone)]
+pub struct Node {
+    pub node_type: NodeType,
+    pub name: String,
+    pub inputs: Vec<Argument>,
+    pub outputs: Vec<Argument>,
+    pub initializers: Vec<Argument>,
+    pub attrs: Attributes,
+    pub is_stateful: bool,
+}
+
+// Required by topological sort
+impl PartialEq for Node {
+    fn eq(&self, other: &Self) -> bool {
+        self.name == other.name && self.node_type == other.node_type
+    }
+}
+
+// Required by topological sort
+impl Eq for Node {}
+
+// Required by topological sort
+impl core::hash::Hash for Node {
+    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
+        self.name.hash(state);
+        self.node_type.hash(state);
+        self.inputs.hash(state);
+        self.outputs.hash(state);
+    }
+}
+
+// Required by topological sort
+impl core::hash::Hash for Argument {
+    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
+        self.name.hash(state);
+    }
+}
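+
+// `EnumString` lets `NodeType::from_str` turn an ONNX `op_type` string such as
+// "Conv" into a variant, and `Display` is used when generating the lowercased
+// node names (e.g. `conv2d1`).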
+/// The list of supported node types (ONNX operators and some extra ones to map easily to Burn's ops)
+#[derive(Debug, Hash, Eq, PartialEq, EnumString, Clone, Display)]
+pub enum NodeType {
+    Abs,
+    Acos,
+    Acosh,
+    Add,
+    And,
+    ArgMax,
+    ArgMin,
+    Asin,
+    Asinh,
+    Atan,
+    Atanh,
+    AveragePool,
+    BatchNormalization,
+    Bernoulli,
+    BitShift,
+    BitwiseAnd,
+    BitwiseNot,
+    BitwiseOr,
+    BitwiseXor,
+    BlackmanWindow,
+    Cast,
+    CastLike,
+    Ceil,
+    Celu,
+    CenterCropPad,
+    Clip,
+    Col,
+    Compress,
+    Concat,
+    ConcatFromSequence,
+    Constant,
+    ConstantOfShape,
+    Conv,
+    Conv1d,
+    Conv2d,
+    ConvInteger,
+    ConvTranspose,
+    Cos,
+    Cosh,
+    CumSum,
+    DepthToSpace,
+    DequantizeLinear,
+    Det,
+    DFT,
+    Div,
+    Dropout,
+    DynamicQuantizeLinear,
+    Einsum,
+    Elu,
+    Equal,
+    Erf,
+    Exp,
+    Expand,
+    EyeLike,
+    Flatten,
+    Floor,
+    Gather,
+    GatherElements,
+    GatherND,
+    Gelu,
+    Gemm,
+    GlobalAveragePool,
+    GlobalLpPool,
+    GlobalMaxPool,
+    Greater,
+    GreaterOrEqual,
+    GridSample,
+    GroupNormalization,
+    GRU,
+    HammingWindow,
+    HannWindow,
+    Hardmax,
+    HardSigmoid,
+    HardSwish,
+    Identity,
+    If,
+    Im,
+    InstanceNormalization,
+    IsInf,
+    IsNaN,
+    LayerNormalization,
+    LeakyRelu,
+    Less,
+    LessOrEqual,
+    Linear,
+    Log,
+    LogSoftmax,
+    Loop,
+    LpNormalization,
+    LpPool,
+    LRN,
+    LSTM,
+    MatMul,
+    MatMulInteger,
+    Max,
+    MaxPool,
+    MaxRoiPool,
+    MaxUnpool,
+    Mean,
+    MeanVarianceNormalization,
+    MelWeightMatrix,
+    Min,
+    Mish,
+    Mod,
+    Mul,
+    Multinomial,
+    Neg,
+    NegativeLogLikelihoodLoss,
+    NonMaxSuppression,
+    NonZero,
+    Not,
+    OneHot,
+    Optional,
+    OptionalGetElement,
+    OptionalHasElement,
+    Or,
+    Pad,
+    Pow,
+    PRelu,
+    QLinearConv,
+    QLinearMatMul,
+    QuantizeLinear,
+    RandomNormal,
+    RandomNormalLike,
+    RandomUniform,
+    RandomUniformLike,
+    Range,
+    Reciprocal,
+    ReduceL,
+    ReduceLogSum,
+    ReduceLogSumExp,
+    ReduceMax,
+    ReduceMean,
+    ReduceMin,
+    ReduceProd,
+    ReduceSum,
+    ReduceSumSquare,
+    Relu,
+    Reshape,
+    Resize,
+    ReverseSequence,
+    RNN,
+    RoiAlign,
+    Round,
+    Scan,
+    Scatter,
+    ScatterElements,
+    ScatterND,
+    Selu,
+    SequenceAt,
+    SequenceConstruct,
+    SequenceEmpty,
+    SequenceErase,
+    SequenceInsert,
+    SequenceLength,
+    SequenceMap,
+    Shape,
+    Shrink,
+    Sigmoid,
+    Sign,
+    Sin,
+    Sinh,
+    Size,
+    Slice,
+    Softmax,
+    SoftmaxCrossEntropyLoss,
+    Softplus,
+    Softsign,
+    SpaceToDepth,
+    Split,
+    SplitToSequence,
+    Sqrt,
+    Squeeze,
+    STFT,
+    StringNormalizer,
+    Sub,
+    Sum,
+    Tan,
+    Tanh,
+    TfIdfVectorizer,
+    ThresholdedRelu,
+    Tile,
+    TopK,
+    Transpose,
+    Trilu,
+    Unique,
+    Unsqueeze,
+    Upsample,
+    Where,
+    Xor,
+}
diff --git a/burn-import/src/onnx/mod.rs b/burn-import/src/onnx/mod.rs
new file mode 100644
index 000000000..a1bf879fd
--- /dev/null
+++ b/burn-import/src/onnx/mod.rs
@@ -0,0 +1,9 @@
+mod coalesce;
+mod codegen;
+mod convert;
+mod ir;
+mod op_configuration;
+mod protos;
+mod shape_inference;
+
+pub use codegen::*;
diff --git a/burn-import/src/onnx/op_configuration.rs b/burn-import/src/onnx/op_configuration.rs
new file mode 100644
index 000000000..5fb2aeba3
--- /dev/null
+++ b/burn-import/src/onnx/op_configuration.rs
@@ -0,0 +1,181 @@
+use burn::nn::{
+    conv::{Conv2dConfig, Conv2dPaddingConfig},
+    LinearConfig,
+};
+
+use super::ir::{ArgType, AttributeValue, Node};
+
+#[inline(always)]
+pub fn attr_value_vec_i64(value: &AttributeValue, target: &mut Vec<i64>) {
+    if let AttributeValue::Int64s(val) = value {
+        *target = val.clone();
+    }
+}
+
+#[inline(always)]
+pub fn attr_value_i64(value: &AttributeValue, target: &mut i64) {
+    if let AttributeValue::Int64(val) = value {
+        *target = *val;
+    }
+}
+
+/// Create a Conv2dConfig from the attributes of the node
+pub fn conv2d_config(curr: &Node) -> Conv2dConfig {
+    let mut kernel_shape = Vec::new();
+    let mut strides = Vec::new();
+    let mut pads = Vec::new();
+    let mut dilations = Vec::new();
+    let mut group: i64 = 1; // ONNX default is 1
+
+    // extract the channels from the weight tensor's shape [out_channels, in_channels, ...]
+    let ArgType::Tensor(tensor) = curr.initializers.get(0).unwrap().clone().arg_type.unwrap();
+
+    // check if the bias is present
+    let bias = curr.initializers.len() == 2;
+
+    // the channels are inverted in the weight tensor
+    let channels: [usize; 2] = [tensor.shape[1], tensor.shape[0]];
+
+    for (key, value) in curr.attrs.iter() {
+        match key.as_str() {
+            "kernel_shape" => attr_value_vec_i64(value, &mut kernel_shape),
+            "strides" => attr_value_vec_i64(value, &mut strides),
+            "pads" => attr_value_vec_i64(value, &mut pads),
+            "dilations" => attr_value_vec_i64(value, &mut dilations),
+            "group" => attr_value_i64(value, &mut group),
+            _ => {}
+        }
+    }
+
+    let padding = if pads.iter().all(|&x| x == 0) {
+        Conv2dPaddingConfig::Valid
+    } else {
+        todo!("Conv2d: padding({pads:?}) is not fully supported");
+    };
+
+    if strides.iter().any(|&x| x != 1) {
+        todo!("Conv2d: strides({strides:?}) are not fully supported");
+    };
+
+    if dilations.iter().any(|&x| x != 1) {
+        todo!("Conv2d: dilations({dilations:?}) are not fully supported");
+    };
+
+    if group != 1 {
+        todo!("Conv2d: group ({group}) is not fully supported");
+    };
+
+    Conv2dConfig::new(
+        channels,
+        [kernel_shape[0] as usize, kernel_shape[1] as usize],
+    )
+    .with_bias(bias)
+    .with_padding(padding)
+}
+
+/// Create a FlattenConfig from the attributes of the node
+pub fn flatten_config(curr: &Node) -> (usize, usize) {
+    // the start dimension defaults to the first dimension (Default: 1 per ONNX spec)
+    let mut start_dim: i64 = 1;
+
+    // check if the node has only one input
+    if curr.inputs.len() != 1 {
+        panic!(
+            "Flatten: multiple inputs are not supported (got {:?})",
+            curr.inputs.len()
+        );
+    }
+
+    // extract the shape of the input tensor
+    let ArgType::Tensor(tensor) = curr.inputs.get(0).unwrap().clone().arg_type.unwrap();
+
+    // check if the input tensor has at least 2 dimensions
+    if tensor.shape.len() < 2 {
+        panic!(
+            "Flatten: input tensor must have at least 2 dimensions (got {:?})",
+            tensor.shape.len()
+        );
+    }
+
+    // the end dimension is the last dimension
+    let end_dim = tensor.shape.len() - 1;
+
+    // extract the attributes
+    for (key, value) in curr.attrs.iter() {
+        match key.as_str() {
+            "axis" => attr_value_i64(value, &mut start_dim),
+            _ => {}
+        }
+    }
+
+    // if start_dim is negative, it is counted from the end
+    if start_dim < 0 {
+        start_dim += tensor.shape.len() as i64;
+    }
+
+    (start_dim as usize, end_dim)
+}
+
+/// Create a LinearConfig from the attributes of the node
+pub fn linear_config(node: &Node) -> LinearConfig {
+    // check if the node has only one input
+    if node.inputs.len() != 1 {
+        panic!(
+            "Linear: multiple
inputs are not supported (got {:?})", + node.inputs.len() + ); + } + + if node.initializers.is_empty() { + panic!("Linear: no initializers found"); + } + + // extract the shape of the weight tensor + let ArgType::Tensor(tensor) = node.initializers.get(0).unwrap().clone().arg_type.unwrap(); + + // check if the weight tensor has at least 2 dimensions + if tensor.shape.len() < 2 { + panic!( + "Linear: weight tensor must have at least 2 dimensions (got {:?})", + tensor.shape.len() + ); + } + let (out_size, in_size) = (tensor.shape[0], tensor.shape[1]); + + // check if the bias is present + let bias = node.initializers.len() == 2; + + LinearConfig::new(in_size, out_size).with_bias(bias) +} + +/// Create log_softmax config from the attributes of the node +pub fn log_softmax_config(node: &Node) -> usize { + // the axis is the last dimension (Default: 1 per ONNX spec) + let mut axis: i64 = -1; + + // check if the node has only one input + if node.inputs.len() != 1 { + panic!( + "LogSoftmax: multiple inputs are not supported (got {:?})", + node.inputs.len() + ); + } + + // extract the shape of the input tensor + let ArgType::Tensor(tensor) = node.inputs.get(0).unwrap().clone().arg_type.unwrap(); + + // extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axis" => attr_value_i64(value, &mut axis), + _ => {} + } + } + + // if axis is negative, it is counted from the end + if axis < 0 { + axis += tensor.shape.len() as i64; + } + + axis as usize +} diff --git a/burn-import/src/onnx/protos/mod.rs b/burn-import/src/onnx/protos/mod.rs new file mode 100644 index 000000000..328e850e7 --- /dev/null +++ b/burn-import/src/onnx/protos/mod.rs @@ -0,0 +1,5 @@ +mod inner { + include!(concat!(env!("OUT_DIR"), "/onnx-protos/mod.rs")); +} + +pub use inner::onnx::*; diff --git a/burn-import/src/onnx/protos/onnx.proto b/burn-import/src/onnx/protos/onnx.proto new file mode 100644 index 000000000..3e4126ddd --- /dev/null +++ b/burn-import/src/onnx/protos/onnx.proto @@ -0,0 +1,825 @@ +// +// WARNING: This file is automatically generated! Please edit onnx.in.proto. +// + +// Copied from https://github.com/onnx/onnx/blob/main/onnx/onnx.proto3 +// under the following license: + + +// SPDX-License-Identifier: Apache-2.0 + + +syntax = "proto3"; + +package onnx; + +// Overview +// +// ONNX is an open specification that is comprised of the following components: +// +// 1) A definition of an extensible computation graph model. +// 2) Definitions of standard data types. +// 3) Definitions of built-in operators. +// +// This document describes the syntax of models and their computation graphs, +// as well as the standard data types. Together, they are referred to as the ONNX +// Intermediate Representation, or 'IR' for short. +// +// The normative semantic specification of the ONNX IR is found in docs/IR.md. +// Definitions of the built-in neural network operators may be found in docs/Operators.md. + +// Notes +// +// Protobuf compatibility +// +// To simplify framework compatibility, ONNX is defined using the subset of protobuf +// that is compatible with both protobuf v2 and v3. This means that we do not use any +// protobuf features that are only available in one of the two versions. +// +// Here are the most notable contortions we have to carry out to work around +// these limitations: +// +// - No 'map' (added protobuf 3.0). We instead represent mappings as lists +// of key-value pairs, where order does not matter and duplicates +// are not allowed. 
+ + +// Versioning +// +// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md +// +// To be compatible with both proto2 and proto3, we will use a version number +// that is not defined by the default value but an explicit enum number. +enum Version { + // proto3 requires the first enum value to be zero. + // We add this just to appease the compiler. + _START_VERSION = 0; + // The version field is always serialized and we will use it to store the + // version that the graph is generated from. This helps us set up version + // control. + // For the IR, we are using simple numbers starting with 0x00000001, + // which was the version we published on Oct 10, 2017. + IR_VERSION_2017_10_10 = 0x0000000000000001; + + // IR_VERSION 2 published on Oct 30, 2017 + // - Added type discriminator to AttributeProto to support proto3 users + IR_VERSION_2017_10_30 = 0x0000000000000002; + + // IR VERSION 3 published on Nov 3, 2017 + // - For operator versioning: + // - Added new message OperatorSetIdProto + // - Added opset_import in ModelProto + // - For vendor extensions, added domain in NodeProto + IR_VERSION_2017_11_3 = 0x0000000000000003; + + // IR VERSION 4 published on Jan 22, 2019 + // - Relax constraint that initializers should be a subset of graph inputs + // - Add type BFLOAT16 + IR_VERSION_2019_1_22 = 0x0000000000000004; + + // IR VERSION 5 published on March 18, 2019 + // - Add message TensorAnnotation. + // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters. + IR_VERSION_2019_3_18 = 0x0000000000000005; + + // IR VERSION 6 published on Sep 19, 2019 + // - Add support for sparse tensor constants stored in model. + // - Add message SparseTensorProto + // - Add sparse initializers + IR_VERSION_2019_9_19 = 0x0000000000000006; + + // IR VERSION 7 published on May 8, 2020 + // - Add support to allow function body graph to rely on multiple external opreator sets. + // - Add a list to promote inference graph's initializers to global and + // mutable variables. Global variables are visible in all graphs of the + // stored models. + // - Add message TrainingInfoProto to store initialization + // method and training algorithm. The execution of TrainingInfoProto + // can modify the values of mutable variables. + // - Implicitly add inference graph into each TrainingInfoProto's algorithm. + IR_VERSION_2020_5_8 = 0x0000000000000007; + + // IR VERSION 8 published on July 30, 2021 + // Introduce TypeProto.SparseTensor + // Introduce TypeProto.Optional + // Added a list of FunctionProtos local to the model + // Deprecated since_version and operator status from FunctionProto + IR_VERSION_2021_7_30 = 0x0000000000000008; + + // IR VERSION 9 published on TBD + // Added AttributeProto to FunctionProto so that default attribute values can be set. + IR_VERSION = 0x0000000000000009; +} + +// Attributes +// +// A named attribute containing either singular float, integer, string, graph, +// and tensor values, or repeated float, integer, string, graph, and tensor values. +// An AttributeProto MUST contain the name field, and *only one* of the +// following content fields, effectively enforcing a C/C++ union equivalent. +message AttributeProto { + + // Note: this enum is structurally identical to the OpSchema::AttrType + // enum defined in schema.h. If you rev one, you likely need to rev the other. 
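+  // For example, an integer attribute axis = 1 is carried as an AttributeProto
+  // with name = "axis", type = INT, and i = 1; every other value field is left unset.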
+ enum AttributeType { + UNDEFINED = 0; + FLOAT = 1; + INT = 2; + STRING = 3; + TENSOR = 4; + GRAPH = 5; + SPARSE_TENSOR = 11; + TYPE_PROTO = 13; + + FLOATS = 6; + INTS = 7; + STRINGS = 8; + TENSORS = 9; + GRAPHS = 10; + SPARSE_TENSORS = 12; + TYPE_PROTOS = 14; + } + + // The name field MUST be present for this version of the IR. + string name = 1; // namespace Attribute + + // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. + // In this case, this AttributeProto does not contain data, and it's a reference of attribute + // in parent scope. + // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. + string ref_attr_name = 21; + + // A human-readable documentation for this attribute. Markdown is allowed. + string doc_string = 13; + + // The type field MUST be present for this version of the IR. + // For 0.0.1 versions of the IR, this field was not defined, and + // implementations needed to use has_field heuristics to determine + // which value field was in use. For IR_VERSION 0.0.2 or later, this + // field MUST be set and match the f|i|s|t|... field in use. This + // change was made to accommodate proto3 implementations. + AttributeType type = 20; // discriminator that indicates which field below is in use + + // Exactly ONE of the following fields must be present for this version of the IR + float f = 2; // float + int64 i = 3; // int + bytes s = 4; // UTF-8 string + TensorProto t = 5; // tensor value + GraphProto g = 6; // graph + SparseTensorProto sparse_tensor = 22; // sparse tensor value + // Do not use field below, it's deprecated. + // optional ValueProto v = 12; // value - subsumes everything but graph + TypeProto tp = 14; // type proto + + repeated float floats = 7; // list of floats + repeated int64 ints = 8; // list of ints + repeated bytes strings = 9; // list of UTF-8 strings + repeated TensorProto tensors = 10; // list of tensors + repeated GraphProto graphs = 11; // list of graph + repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors + repeated TypeProto type_protos = 15;// list of type protos +} + +// Defines information on value, including the name, the type, and +// the shape of the value. +message ValueInfoProto { + // This field MUST be present in this version of the IR. + string name = 1; // namespace Value + // This field MUST be present in this version of the IR for + // inputs and outputs of the top-level graph. + TypeProto type = 2; + // A human-readable documentation for this value. Markdown is allowed. + string doc_string = 3; +} + +// Nodes +// +// Computation graphs are made up of a DAG of nodes, which represent what is +// commonly called a "layer" or "pipeline stage" in machine learning frameworks. +// +// For example, it can be a node of type "Conv" that takes in an image, a filter +// tensor and a bias tensor, and produces the convolved output. +message NodeProto { + repeated string input = 1; // namespace Value + repeated string output = 2; // namespace Value + + // An optional identifier for this node in a graph. + // This field MAY be absent in ths version of the IR. + string name = 3; // namespace Node + + // The symbolic identifier of the Operator to execute. + string op_type = 4; // namespace Operator + // The domain of the OperatorSet that specifies the operator named by op_type. + string domain = 7; // namespace Domain + + // Additional named attributes. + repeated AttributeProto attribute = 5; + + // A human-readable documentation for this node. 
Markdown is allowed. + string doc_string = 6; +} + +// Training information +// TrainingInfoProto stores information for training a model. +// In particular, this defines two functionalities: an initialization-step +// and a training-algorithm-step. Initialization resets the model +// back to its original state as if no training has been performed. +// Training algorithm improves the model based on input data. +// +// The semantics of the initialization-step is that the initializers +// in ModelProto.graph and in TrainingInfoProto.algorithm are first +// initialized as specified by the initializers in the graph, and then +// updated by the "initialization_binding" in every instance in +// ModelProto.training_info. +// +// The field "algorithm" defines a computation graph which represents a +// training algorithm's step. After the execution of a +// TrainingInfoProto.algorithm, the initializers specified by "update_binding" +// may be immediately updated. If the targeted training algorithm contains +// consecutive update steps (such as block coordinate descent methods), +// the user needs to create a TrainingInfoProto for each step. +message TrainingInfoProto { + // This field describes a graph to compute the initial tensors + // upon starting the training process. Initialization graph has no input + // and can have multiple outputs. Usually, trainable tensors in neural + // networks are randomly initialized. To achieve that, for each tensor, + // the user can put a random number operator such as RandomNormal or + // RandomUniform in TrainingInfoProto.initialization.node and assign its + // random output to the specific tensor using "initialization_binding". + // This graph can also set the initializers in "algorithm" in the same + // TrainingInfoProto; a use case is resetting the number of training + // iteration to zero. + // + // By default, this field is an empty graph and its evaluation does not + // produce any output. Thus, no initializer would be changed by default. + GraphProto initialization = 1; + + // This field represents a training algorithm step. Given required inputs, + // it computes outputs to update initializers in its own or inference graph's + // initializer lists. In general, this field contains loss node, gradient node, + // optimizer node, increment of iteration count. + // + // An execution of the training algorithm step is performed by executing the + // graph obtained by combining the inference graph (namely "ModelProto.graph") + // and the "algorithm" graph. That is, the actual the actual + // input/initializer/output/node/value_info/sparse_initializer list of + // the training graph is the concatenation of + // "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer" + // and "algorithm.input/initializer/output/node/value_info/sparse_initializer" + // in that order. This combined graph must satisfy the normal ONNX conditions. + // Now, let's provide a visualization of graph combination for clarity. + // Let the inference graph (i.e., "ModelProto.graph") be + // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d + // and the "algorithm" graph be + // tensor_d -> Add -> tensor_e + // The combination process results + // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e + // + // Notice that an input of a node in the "algorithm" graph may reference the + // output of a node in the inference graph (but not the other way round). Also, inference + // node cannot reference inputs of "algorithm". 
With these restrictions, inference graph + // can always be run independently without training information. + // + // By default, this field is an empty graph and its evaluation does not + // produce any output. Evaluating the default training step never + // update any initializers. + GraphProto algorithm = 2; + + // This field specifies the bindings from the outputs of "initialization" to + // some initializers in "ModelProto.graph.initializer" and + // the "algorithm.initializer" in the same TrainingInfoProto. + // See "update_binding" below for details. + // + // By default, this field is empty and no initializer would be changed + // by the execution of "initialization". + repeated StringStringEntryProto initialization_binding = 3; + + // Gradient-based training is usually an iterative procedure. In one gradient + // descent iteration, we apply + // + // x = x - r * g + // + // where "x" is the optimized tensor, "r" stands for learning rate, and "g" is + // gradient of "x" with respect to a chosen loss. To avoid adding assignments + // into the training graph, we split the update equation into + // + // y = x - r * g + // x = y + // + // The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To + // tell that "y" should be assigned to "x", the field "update_binding" may + // contain a key-value pair of strings, "x" (key of StringStringEntryProto) + // and "y" (value of StringStringEntryProto). + // For a neural network with multiple trainable (mutable) tensors, there can + // be multiple key-value pairs in "update_binding". + // + // The initializers appears as keys in "update_binding" are considered + // mutable variables. This implies some behaviors + // as described below. + // + // 1. We have only unique keys in all "update_binding"s so that two + // variables may not have the same name. This ensures that one + // variable is assigned up to once. + // 2. The keys must appear in names of "ModelProto.graph.initializer" or + // "TrainingInfoProto.algorithm.initializer". + // 3. The values must be output names of "algorithm" or "ModelProto.graph.output". + // 4. Mutable variables are initialized to the value specified by the + // corresponding initializer, and then potentially updated by + // "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s. + // + // This field usually contains names of trainable tensors + // (in ModelProto.graph), optimizer states such as momentums in advanced + // stochastic gradient methods (in TrainingInfoProto.graph), + // and number of training iterations (in TrainingInfoProto.graph). + // + // By default, this field is empty and no initializer would be changed + // by the execution of "algorithm". + repeated StringStringEntryProto update_binding = 4; +} + +// Models +// +// ModelProto is a top-level file/container format for bundling a ML model and +// associating its computation graph with metadata. +// +// The semantics of the model are described by the associated GraphProto's. +message ModelProto { + // The version of the IR this model targets. See Version enum above. + // This field MUST be present. + int64 ir_version = 1; + + // The OperatorSets this model relies on. + // All ModelProtos MUST have at least one entry that + // specifies which version of the ONNX OperatorSet is + // being imported. + // + // All nodes in the ModelProto's graph will bind against the operator + // with the same-domain/same-op_type operator with the HIGHEST version + // in the referenced operator sets. 
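+  // For example, a model that imports the default operator set at version 13
+  // binds each "Conv" node to the highest "Conv" definition available in
+  // version 13 of that set.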
+ repeated OperatorSetIdProto opset_import = 8; + + // The name of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + string producer_name = 2; + + // The version of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + string producer_version = 3; + + // Domain name of the model. + // We use reverse domain names as name space indicators. For example: + // `com.facebook.fair` or `com.microsoft.cognitiveservices` + // + // Together with `model_version` and GraphProto.name, this forms the unique identity of + // the graph. + string domain = 4; + + // The version of the graph encoded. See Version enum below. + int64 model_version = 5; + + // A human-readable documentation for this model. Markdown is allowed. + string doc_string = 6; + + // The parameterized graph that is evaluated to execute the model. + GraphProto graph = 7; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 14; + + // Training-specific information. Sequentially executing all stored + // `TrainingInfoProto.algorithm`s and assigning their outputs following + // the corresponding `TrainingInfoProto.update_binding`s is one training + // iteration. Similarly, to initialize the model + // (as if training hasn't happened), the user should sequentially execute + // all stored `TrainingInfoProto.initialization`s and assigns their outputs + // using `TrainingInfoProto.initialization_binding`s. + // + // If this field is empty, the training behavior of the model is undefined. + repeated TrainingInfoProto training_info = 20; + + // A list of function protos local to the model. + // + // Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain". + // In case of any conflicts the behavior (whether the model local functions are given higher priority, + // or standard opserator sets are given higher priotity or this is treated as error) is defined by + // the runtimes. + // + // The operator sets imported by FunctionProto should be compatible with the ones + // imported by ModelProto and other model local FunctionProtos. + // Example, if same operator set say 'A' is imported by a FunctionProto and ModelProto + // or by 2 FunctionProtos then versions for the operator set may be different but, + // the operator schema returned for op_type, domain, version combination + // for both the versions should be same for every node in the function body. + // + // One FunctionProto can reference other FunctionProto in the model, however, recursive reference + // is not allowed. + repeated FunctionProto functions = 25; +}; + +// StringStringEntryProto follows the pattern for cross-proto-version maps. +// See https://developers.google.com/protocol-buffers/docs/proto3#maps +message StringStringEntryProto { + string key = 1; + string value = 2; +}; + +message TensorAnnotation { + string tensor_name = 1; + // pairs to annotate tensor specified by above. + // The keys used in the mapping below must be pre-defined in ONNX spec. + // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as + // quantization parameter keys. 
+ repeated StringStringEntryProto quant_parameter_tensor_names = 2; +} + + + +// Graphs +// +// A graph defines the computational logic of a model and is comprised of a parameterized +// list of nodes that form a directed acyclic graph based on their inputs and outputs. +// This is the equivalent of the "network" or "graph" in many deep learning +// frameworks. +message GraphProto { + // The nodes in the graph, sorted topologically. + repeated NodeProto node = 1; + + // The name of the graph. + string name = 2; // namespace Graph + + // A list of named tensor values, used to specify constant inputs of the graph. + // Each initializer (both TensorProto as well SparseTensorProto) MUST have a name. + // The name MUST be unique across both initializer and sparse_initializer, + // but the name MAY also appear in the input list. + repeated TensorProto initializer = 5; + + // Initializers (see above) stored in sparse format. + repeated SparseTensorProto sparse_initializer = 15; + + // A human-readable documentation for this graph. Markdown is allowed. + string doc_string = 10; + + // The inputs and outputs of the graph. + repeated ValueInfoProto input = 11; + repeated ValueInfoProto output = 12; + + // Information for the values in the graph. The ValueInfoProto.name's + // must be distinct. It is optional for a value to appear in value_info list. + repeated ValueInfoProto value_info = 13; + + // This field carries information to indicate the mapping among a tensor and its + // quantization parameter tensors. For example: + // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated, + // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model. + repeated TensorAnnotation quantization_annotation = 14; + + reserved 3, 4, 6 to 9; + reserved "ir_version", "producer_version", "producer_tag", "domain"; +} + +// Tensors +// +// A serialized tensor value. +message TensorProto { + enum DataType { + UNDEFINED = 0; + // Basic types. + FLOAT = 1; // float + UINT8 = 2; // uint8_t + INT8 = 3; // int8_t + UINT16 = 4; // uint16_t + INT16 = 5; // int16_t + INT32 = 6; // int32_t + INT64 = 7; // int64_t + STRING = 8; // string + BOOL = 9; // bool + + // IEEE754 half-precision floating-point format (16 bits wide). + // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits. + FLOAT16 = 10; + + DOUBLE = 11; + UINT32 = 12; + UINT64 = 13; + COMPLEX64 = 14; // complex with float32 real and imaginary components + COMPLEX128 = 15; // complex with float64 real and imaginary components + + // Non-IEEE floating-point format based on IEEE754 single-precision + // floating-point number truncated to 16 bits. + // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. + BFLOAT16 = 16; + + // Future extensions go here. + } + + // The shape of the tensor. + repeated int64 dims = 1; + + // The data type of the tensor. + // This field MUST have a valid TensorProto.DataType value + int32 data_type = 2; + + // For very large tensors, we may want to store them in chunks, in which + // case the following fields will specify the segment that is stored in + // the current TensorProto. + message Segment { + int64 begin = 1; + int64 end = 2; + } + Segment segment = 3; + + // Tensor content must be organized in row-major order. + // + // Depending on the data_type field, exactly one of the fields below with + // name ending in _data is used to store the elements of the tensor. 
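+  // For example, a FLOAT tensor with dims [2, 3] stores its six elements in
+  // float_data in the order a(0,0), a(0,1), a(0,2), a(1,0), a(1,1), a(1,2).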
+ + // For float and complex64 values + // Complex64 tensors are encoded as a single array of floats, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component appearing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. + repeated float float_data = 4 [packed = true]; + + // For int32, uint8, int8, uint16, int16, bool, and float16 values + // float16 values must be bit-wise converted to an uint16_t prior + // to writing to the buffer. + // When this field is present, the data_type field MUST be + // INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16 or BFLOAT16 + repeated int32 int32_data = 5 [packed = true]; + + // For strings. + // Each element of string_data is a UTF-8 encoded Unicode + // string. No trailing null, no leading BOM. The protobuf "string" + // scalar type is not used to match ML community conventions. + // When this field is present, the data_type field MUST be STRING + repeated bytes string_data = 6; + + // For int64. + // When this field is present, the data_type field MUST be INT64 + repeated int64 int64_data = 7 [packed = true]; + + // Optionally, a name for the tensor. + string name = 8; // namespace Value + + // A human-readable documentation for this tensor. Markdown is allowed. + string doc_string = 12; + + // Serializations can either use one of the fields above, or use this + // raw bytes field. The only exception is the string case, where one is + // required to store the content in the repeated bytes string_data field. + // + // When this raw_data field is used to store tensor value, elements MUST + // be stored in as fixed-width, little-endian order. + // Floating-point data types MUST be stored in IEEE 754 format. + // Complex64 elements must be written as two consecutive FLOAT values, real component first. + // Complex128 elements must be written as two consecutive DOUBLE values, real component first. + // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + // + // Note: the advantage of specific field rather than the raw_data field is + // that in some cases (e.g. int data), protobuf does a better packing via + // variable length storage, and may lead to smaller binary footprint. + // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED + bytes raw_data = 9; + + // Data can be stored inside the protobuf file using type-specific fields or raw_data. + // Alternatively, raw bytes data can be stored in an external file, using the external_data field. + // external_data stores key-value pairs describing data location. Recognized keys are: + // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX + // protobuf model was stored + // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string. + // Offset values SHOULD be multiples 4096 (page size) to enable mmap support. + // - "length" (optional) - number of bytes containing data. Integer stored as string. + // - "checksum" (optional) - SHA1 digest of file specified in under 'location' key. + repeated StringStringEntryProto external_data = 13; + + // Location of the data for this tensor. MUST be one of: + // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field. 
+ // - EXTERNAL - data stored in an external location as described by external_data field. + enum DataLocation { + DEFAULT = 0; + EXTERNAL = 1; + } + + // If value not set, data is stored in raw_data (if set) otherwise in type-specified field. + DataLocation data_location = 14; + + // For double + // Complex128 tensors are encoded as a single array of doubles, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component appearing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 + repeated double double_data = 10 [packed = true]; + + // For uint64 and uint32 values + // When this field is present, the data_type field MUST be + // UINT32 or UINT64 + repeated uint64 uint64_data = 11 [packed = true]; +} + +// A serialized sparse-tensor value +message SparseTensorProto { + // The sequence of non-default values are encoded as a tensor of shape [NNZ]. + // The default-value is zero for numeric tensors, and empty-string for string tensors. + // values must have a non-empty name present which serves as a name for SparseTensorProto + // when used in sparse_initializer list. + TensorProto values = 1; + + // The indices of the non-default values, which may be stored in one of two formats. + // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value + // corresponding to the j-th index of the i-th value (in the values tensor). + // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value + // must be the linearized-index of the i-th value (in the values tensor). + // The linearized-index can be converted into an index tuple (k_1,...,k_rank) + // using the shape provided below. + // The indices must appear in ascending order without duplication. + // In the first format, the ordering is lexicographic-ordering: + // e.g., index-value [1,4] must appear before [2,1] + TensorProto indices = 2; + + // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank] + repeated int64 dims = 3; +} + +// Defines a tensor shape. A dimension can be either an integer value +// or a symbolic variable. A symbolic variable represents an unknown +// dimension. +message TensorShapeProto { + message Dimension { + oneof value { + int64 dim_value = 1; + string dim_param = 2; // namespace Shape + }; + // Standard denotation can optionally be used to denote tensor + // dimensions with standard semantic descriptions to ensure + // that operations are applied to the correct axis of a tensor. + // Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition + // for pre-defined dimension denotations. + string denotation = 3; + }; + repeated Dimension dim = 1; +} + +// Types +// +// The standard ONNX data types. +message TypeProto { + + message Tensor { + // This field MUST NOT have the value of UNDEFINED + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + int32 elem_type = 1; + TensorShapeProto shape = 2; + } + + // repeated T + message Sequence { + // The type and optional shape of each element of the sequence. + // This field MUST be present for this version of the IR. + TypeProto elem_type = 1; + }; + + // map + message Map { + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. 
+ // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING + int32 key_type = 1; + // This field MUST be present for this version of the IR. + TypeProto value_type = 2; + }; + + // wrapper for Tensor, Sequence, or Map + message Optional { + // The type and optional shape of the element wrapped. + // This field MUST be present for this version of the IR. + // Possible values correspond to OptionalProto.DataType enum + TypeProto elem_type = 1; + }; + + + message SparseTensor { + // This field MUST NOT have the value of UNDEFINED + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + int32 elem_type = 1; + TensorShapeProto shape = 2; + } + + + oneof value { + // The type of a tensor. + Tensor tensor_type = 1; + + // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values + // as input and output to graphs and nodes. These types are needed to naturally + // support classical ML operators. DNN operators SHOULD restrict their input + // and output types to tensors. + + // The type of a sequence. + Sequence sequence_type = 4; + + // The type of a map. + Map map_type = 5; + + // The type of an optional. + Optional optional_type = 9; + + + // Type of the sparse tensor + SparseTensor sparse_tensor_type = 8; + + } + + // An optional denotation can be used to denote the whole + // type with a standard semantic description as to what is + // stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition + // for pre-defined type denotations. + string denotation = 6; +} + +// Operator Sets +// +// OperatorSets are uniquely identified by a (domain, opset_version) pair. +message OperatorSetIdProto { + // The domain of the operator set being identified. + // The empty string ("") or absence of this field implies the operator + // set that is defined as part of the ONNX specification. + // This field MUST be present in this version of the IR when referring to any other operator set. + string domain = 1; + + // The version of the operator set being identified. + // This field MUST be present in this version of the IR. + int64 version = 2; +} + +// Operator/function status. +enum OperatorStatus { + EXPERIMENTAL = 0; + STABLE = 1; +} + +message FunctionProto { + // The name of the function, similar usage of op_type in OperatorProto. + // Combined with FunctionProto.domain, this forms the unique identity of + // the FunctionProto. + string name = 1; + + // Deprecated since IR Version 8 + // optional int64 since_version = 2; + reserved 2; + reserved "since_version"; + + // Deprecated since IR Version 8 + // optional OperatorStatus status = 3; + reserved 3; + reserved "status"; + + // The inputs and outputs of the function. + repeated string input = 4; + repeated string output = 5; + + // The attribute parameters of the function. + // It is for function parameters without default values. + repeated string attribute = 6; + + // The attribute protos of the function. + // It is for function attributes with default values. + // A function attribute shall be represented either as + // a string attribute or an AttributeProto, not both. + repeated AttributeProto attribute_proto = 11; + + // The nodes in the function. + repeated NodeProto node = 7; + // A human-readable documentation for this function. Markdown is allowed. + string doc_string = 8; + + // The OperatorSets this function body (graph) relies on. 
+  //
+  // All nodes in the function body (graph) will bind against the operator
+  // with the same-domain/same-op_type operator with the HIGHEST version
+  // in the referenced operator sets. This means at most one version can be relied
+  // for one domain.
+  //
+  // The operator sets imported by FunctionProto should be compatible with the ones
+  // imported by ModelProto. Example, if same operator set say 'A' is imported by FunctionProto
+  // and ModelProto then versions for the operator set may be different but,
+  // the operator schema returned for op_type, domain, version combination
+  // for both the versions should be same.
+
+  repeated OperatorSetIdProto opset_import = 9;
+
+  // The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
+  // the FunctionProto.
+  string domain = 10;
+}
+
+
+// For using protobuf-lite
+option optimize_for = LITE_RUNTIME;
diff --git a/burn-import/src/onnx/shape_inference.rs b/burn-import/src/onnx/shape_inference.rs
new file mode 100644
index 000000000..745351c53
--- /dev/null
+++ b/burn-import/src/onnx/shape_inference.rs
@@ -0,0 +1,187 @@
+use std::collections::HashMap;
+
+use burn::tensor;
+
+use burn_ndarray::NdArrayBackend;
+
+use super::{
+    ir::{ArgType, Argument, Node, NodeType, Tensor},
+    op_configuration::{conv2d_config, flatten_config, linear_config},
+};
+
+/// Infer the shape of each node's output and replace the shape of the graph's output tensors
+pub fn shape_inference(
+    nodes: &mut Vec<Node>,
+    graph_inputs: &Vec<Argument>,
+    graph_outputs: &mut Vec<Argument>,
+) {
+    let mut prev_outputs: HashMap<String, Argument> = HashMap::new();
+
+    for input in graph_inputs.iter() {
+        prev_outputs.insert(input.name.clone(), input.clone());
+    }
+
+    for node in nodes.iter_mut() {
+        match node.node_type {
+            NodeType::Conv2d => conv2d(node, &prev_outputs),
+            NodeType::Linear => linear(node, &prev_outputs),
+            NodeType::Relu => relu(node, &prev_outputs),
+            NodeType::Flatten => flatten(node, &prev_outputs),
+            NodeType::LogSoftmax => log_softmax(node, &prev_outputs),
+            _ => todo!(
+                "shape inference for {:?} is not implemented",
+                node.node_type
+            ),
+        }
+
+        for output in node.outputs.iter() {
+            prev_outputs.insert(output.name.clone(), output.clone());
+        }
+    }
+
+    // update the outputs of the graph from prev_outputs
+    for output in graph_outputs.iter_mut() {
+        let arg = prev_outputs.get(output.name.as_str()).unwrap();
+        output.arg_type = arg.arg_type.clone();
+    }
+}
+
+/// Infers the shape of a Linear node and replaces the shape of the output tensor.
+fn linear(curr: &mut Node, prev_outputs: &HashMap<String, Argument>) {
+    if curr.inputs.len() != 1 {
+        panic!("Linear: multiple inputs are not supported");
+    }
+
+    // Fill in the missing information about the input tensor from the previous outputs
+    let prev_node_output = prev_outputs.get(curr.inputs[0].name.as_str()).unwrap();
+    curr.inputs[0].arg_type = prev_node_output.arg_type.clone();
+
+    // Extract the configuration of the linear layer (inputs are known)
+    let config = linear_config(curr);
+
+    // Replace the output tensor
+    let curr_input = &mut curr.inputs[0];
+    let ArgType::Tensor(tensor) = curr_input.clone().arg_type.unwrap();
+    let mut new_shape = tensor.shape.clone();
+    // The last dimension of the output is the layer's output size
+    new_shape[tensor.shape.len() - 1] = config.d_output;
+
+    // Update the output tensor
+    curr.outputs[0].arg_type = Some(ArgType::Tensor(Tensor {
+        name: None,
+        shape: new_shape,
+        data: None,
+        elem_type: tensor.elem_type,
+    }));
+}
+
+/// Infers the shape of a Relu node and replaces the shape of the output tensor.
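+/// Relu is element-wise, so the output reuses the input's shape and element type unchanged.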
+fn relu(curr: &mut Node, prev_outputs: &HashMap<String, Argument>) {
+    if curr.inputs.len() != 1 {
+        panic!("Relu: multiple inputs are not supported");
+    }
+
+    // Fill in the missing information about the input tensor from the previous outputs
+    let prev_node_output = prev_outputs.get(curr.inputs[0].name.as_str()).unwrap();
+
+    curr.inputs[0].arg_type = prev_node_output.arg_type.clone();
+
+    curr.outputs[0].arg_type = prev_node_output.arg_type.clone();
+}
+
+/// Infers the shape of a Flatten node and replaces the shape of the output tensor.
+fn flatten(curr: &mut Node, prev_outputs: &HashMap<String, Argument>) {
+    if curr.inputs.len() != 1 {
+        panic!("Flatten: multiple inputs are not supported");
+    }
+
+    // Fill in the missing information about the input tensor from the previous outputs
+    let prev_node_output = prev_outputs.get(curr.inputs[0].name.as_str()).unwrap();
+
+    curr.inputs[0].arg_type = prev_node_output.arg_type.clone();
+
+    let curr_input = &mut curr.inputs[0];
+
+    let ArgType::Tensor(tensor) = curr_input.clone().arg_type.unwrap();
+
+    let input_shape = tensor.shape;
+
+    let (start_dim, end_dim) = flatten_config(curr);
+
+    // Calculate the new shape with the same logic as the flatten op itself:
+    // the dimensions in [start_dim, end_dim] collapse into a single dimension,
+    // while the dimensions outside that range are copied over unchanged.
+    let mut new_dims = vec![0; input_shape.len() - (end_dim - start_dim)];
+    let mut flatten_dims = 1;
+    for i in input_shape[start_dim..=end_dim].iter() {
+        flatten_dims *= i;
+    }
+    new_dims[..start_dim].copy_from_slice(&input_shape[..start_dim]);
+    new_dims[start_dim] = flatten_dims;
+    new_dims[start_dim + 1..].copy_from_slice(&input_shape[end_dim + 1..]);
+
+    curr.outputs[0].arg_type = Some(ArgType::Tensor(Tensor {
+        name: None,
+        shape: new_dims,
+        data: None,
+        elem_type: tensor.elem_type,
+    }));
+}
+
+/// Infers the shape of a LogSoftmax node and replaces the shape of the output tensor.
+fn log_softmax(curr: &mut Node, prev_outputs: &HashMap<String, Argument>) {
+    if curr.inputs.len() != 1 {
+        panic!("LogSoftmax: multiple inputs are not supported");
+    }
+
+    // Fill in the missing information about the input tensor from the previous outputs
+    let prev_node_output = prev_outputs.get(curr.inputs[0].name.as_str()).unwrap();
+    curr.inputs[0].arg_type = prev_node_output.arg_type.clone();
+    curr.outputs[0].arg_type = prev_node_output.arg_type.clone();
+}
+
+/// Infers the shape of a Conv2d node and replaces the shape of the output tensor.
+///
+/// The shape of the output tensor is calculated by running the actual convolution operation.
+fn conv2d(curr: &mut Node, prev_outputs: &HashMap<String, Argument>) {
+    // copy the type from the previous output to the current input
+    if curr.inputs.len() != 1 {
+        panic!("Conv2d: multiple inputs are not supported");
+    }
+
+    // Fill in the missing information about the input tensor from the previous outputs
+    let curr_input = &mut curr.inputs[0];
+    let prev = prev_outputs.get(curr_input.name.as_str()).unwrap();
+    curr_input.arg_type = prev.arg_type.clone();
+
+    // extract the type and shape of the input tensor
+    let ArgType::Tensor(tensor) = curr_input.clone().arg_type.unwrap();
+
+    let elem_type = tensor.elem_type;
+
+    if tensor.shape.len() != 4 {
+        panic!("Conv2d: input tensor must be 4D");
+    }
+
+    let mut input_shape: [usize; 4] = [0; 4];
+    input_shape.copy_from_slice(tensor.shape.as_slice());
+
+    // using the real configuration, run the op and calculate the actual shape of the output tensor
+    let config = conv2d_config(curr);
+
+    let conv2d = config.init();
+
+    let input = tensor::Tensor::<NdArrayBackend<f32>, 4>::zeros(input_shape);
+    let output = conv2d.forward(input);
+
+    let output_shape = output.shape().dims.to_vec();
+
+    curr.outputs[0].arg_type = Some(ArgType::Tensor(Tensor {
+        name: None,
+        shape: output_shape,
+        data: None,
+        elem_type,
+    }));
+}
diff --git a/burn-import/tests/onnx/mnist.onnx b/burn-import/tests/onnx/mnist.onnx
new file mode 100644
index 000000000..863c1b39c
Binary files /dev/null and b/burn-import/tests/onnx/mnist.onnx differ
diff --git a/burn-tensor-testgen/Cargo.toml b/burn-tensor-testgen/Cargo.toml
index bfaa9e392..b96a1e4c2 100644
--- a/burn-tensor-testgen/Cargo.toml
+++ b/burn-tensor-testgen/Cargo.toml
@@ -12,6 +12,6 @@ edition = "2021"
 proc-macro = true
 
 [dependencies]
-syn = "1.0.109"
-quote = "1.0.26"
-proc-macro2 = "1.0.52"
+syn = {workspace = true}
+quote = {workspace = true}
+proc-macro2 = {workspace = true}
diff --git a/examples/onnx-inference/Cargo.toml b/examples/onnx-inference/Cargo.toml
new file mode 100644
index 000000000..031d89748
--- /dev/null
+++ b/examples/onnx-inference/Cargo.toml
@@ -0,0 +1,22 @@
+[package]
+authors = ["Dilshod Tadjibaev (@antimora)"]
+edition = "2021"
+license = "MIT/Apache-2.0"
+name = "onnx-inference"
+publish = false
+version = "0.6.0"
+
+[features]
+default = []
+
+[dependencies]
+burn = {path = "../../burn"}
+burn-ndarray = {path = "../../burn-ndarray"}
+serde = {workspace = true}
+
+
+[dev-dependencies]
+burn-dataset = {path = "../../burn-dataset"}
+
+[build-dependencies]
+burn-import = {path = "../../burn-import"}
diff --git a/examples/onnx-inference/README.md b/examples/onnx-inference/README.md
new file mode 100644
index 000000000..8667fa232
--- /dev/null
+++ b/examples/onnx-inference/README.md
@@ -0,0 +1,4 @@
+# ONNX Inference
+
+This crate provides a simple example of importing an ONNX model into Burn.
+
diff --git a/examples/onnx-inference/build.rs b/examples/onnx-inference/build.rs
new file mode 100644
index 000000000..8c3dae406
--- /dev/null
+++ b/examples/onnx-inference/build.rs
@@ -0,0 +1,9 @@
+use burn_import::onnx::ModelCodeGen;
+
+fn main() {
+    // Generate the model code from the ONNX file.
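+    // The generated Rust source lands under OUT_DIR (here "model/mnist.rs") and
+    // is spliced into the crate by the include! in src/model/mod.rs.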
+ ModelCodeGen::new() + .input("src/model/mnist.onnx") + .out_dir("model/") + .run_from_script(); +} diff --git a/examples/onnx-inference/pytorch/mnist.py b/examples/onnx-inference/pytorch/mnist.py new file mode 100644 index 000000000..a2e52b7e2 --- /dev/null +++ b/examples/onnx-inference/pytorch/mnist.py @@ -0,0 +1,155 @@ +# Originally copied and modified from: https://github.com/pytorch/examples/blob/main/mnist/main.py +# under the following license: BSD-3-Clause license + +from __future__ import print_function +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms +from torch.optim.lr_scheduler import StepLR + + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 8, 3) + self.conv2 = nn.Conv2d(8, 16, 3) + self.conv3 = nn.Conv2d(16, 24, 3) + self.dropout1 = nn.Dropout(0.3) + self.fc1 = nn.Linear(24 * 22 * 22, 32) + self.fc2 = nn.Linear(32, 10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = self.conv3(x) + x = F.relu(x) + + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout1(x) + + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + + +def train(args, model, device, train_loader, optimizer, epoch): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + if batch_idx % args.log_interval == 0: + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( + epoch, batch_idx * len(data), len(train_loader.dataset), + 100. * batch_idx / len(train_loader), loss.item())) + if args.dry_run: + break + + +def test(model, device, test_loader): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss + pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + correct += pred.eq(target.view_as(pred)).sum().item() + + test_loss /= len(test_loader.dataset) + + print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( + test_loss, correct, len(test_loader.dataset), + 100. 
* correct / len(test_loader.dataset))) + + +def main(): + # Training settings + parser = argparse.ArgumentParser(description='PyTorch MNIST Example') + parser.add_argument('--batch-size', type=int, default=64, metavar='N', + help='input batch size for training (default: 64)') + parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', + help='input batch size for testing (default: 1000)') + parser.add_argument('--epochs', type=int, default=14, metavar='N', + help='number of epochs to train (default: 14)') + parser.add_argument('--lr', type=float, default=1.0, metavar='LR', + help='learning rate (default: 1.0)') + parser.add_argument('--gamma', type=float, default=0.7, metavar='M', + help='Learning rate step gamma (default: 0.7)') + parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') + parser.add_argument('--no-mps', action='store_true', default=False, + help='disables macOS GPU training') + parser.add_argument('--dry-run', action='store_true', default=False, + help='quickly check a single pass') + parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') + parser.add_argument('--log-interval', type=int, default=10, metavar='N', + help='how many batches to wait before logging training status') + parser.add_argument('--save-model', action='store_true', default=False, + help='For Saving the current Model') + parser.add_argument('--export-onnx', action='store_true', default=True, + help='For Saving the current Model in ONNX format') + args = parser.parse_args() + use_cuda = not args.no_cuda and torch.cuda.is_available() + use_mps = not args.no_mps and torch.backends.mps.is_available() + + torch.manual_seed(args.seed) + + if use_cuda: + device = torch.device("cuda") + elif use_mps: + device = torch.device("mps") + else: + device = torch.device("cpu") + + train_kwargs = {'batch_size': args.batch_size} + test_kwargs = {'batch_size': args.test_batch_size} + if use_cuda: + cuda_kwargs = {'num_workers': 1, + 'pin_memory': True, + 'shuffle': True} + train_kwargs.update(cuda_kwargs) + test_kwargs.update(cuda_kwargs) + + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + dataset1 = datasets.MNIST('/tmp/mnist-data', train=True, download=True, + transform=transform) + dataset2 = datasets.MNIST('/tmp/mnist-data', train=False, + transform=transform) + train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) + test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) + + model = Net().to(device) + optimizer = optim.Adadelta(model.parameters(), lr=args.lr) + + scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) + for epoch in range(1, args.epochs + 1): + train(args, model, device, train_loader, optimizer, epoch) + test(model, device, test_loader) + scheduler.step() + + if args.save_model: + torch.save(model, "mnist.pt") + + if args.export_onnx: + dummy_input = torch.randn(1, 1, 28, 28, device=device) + torch.onnx.export(model, dummy_input, "mnist.onnx") + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/examples/onnx-inference/src/lib.rs b/examples/onnx-inference/src/lib.rs new file mode 100644 index 000000000..65880be0e --- /dev/null +++ b/examples/onnx-inference/src/lib.rs @@ -0,0 +1 @@ +pub mod model; diff --git a/examples/onnx-inference/src/main.rs b/examples/onnx-inference/src/main.rs new file mode 100644 index 000000000..477c6bbc9 --- /dev/null +++ 
b/examples/onnx-inference/src/main.rs
@@ -0,0 +1,17 @@
+use burn::tensor;
+use burn_ndarray::NdArrayBackend;
+use onnx_inference::model::mnist::{Model, INPUT1_SHAPE};
+
+fn main() {
+    // Create a new model
+    let model: Model<NdArrayBackend<f32>> = Model::new();
+
+    // Create a new input tensor (all zeros for demonstration purposes)
+    let input = tensor::Tensor::<NdArrayBackend<f32>, 4>::zeros(INPUT1_SHAPE);
+
+    // Run the model
+    let output = model.forward(input);
+
+    // Print the output
+    println!("{:?}", output);
+}
diff --git a/examples/onnx-inference/src/model/mnist.onnx b/examples/onnx-inference/src/model/mnist.onnx
new file mode 100644
index 000000000..863c1b39c
Binary files /dev/null and b/examples/onnx-inference/src/model/mnist.onnx differ
diff --git a/examples/onnx-inference/src/model/mod.rs b/examples/onnx-inference/src/model/mod.rs
new file mode 100644
index 000000000..4c821cafd
--- /dev/null
+++ b/examples/onnx-inference/src/model/mod.rs
@@ -0,0 +1,3 @@
+pub mod mnist {
+    include!(concat!(env!("OUT_DIR"), "/model/mnist.rs"));
+}
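Taken together: build.rs compiles src/model/mnist.onnx into Rust source under OUT_DIR, and the include! above exposes it as onnx_inference::model::mnist. Below is a minimal sketch of exercising the generated model from a test. The Model and INPUT1_SHAPE names come from main.rs above; the expected [1, 10] output shape is an assumption based on the 10-class PyTorch net earlier in this change, not something the generated code guarantees.

#[cfg(test)]
mod tests {
    use burn_ndarray::NdArrayBackend;
    use onnx_inference::model::mnist::{Model, INPUT1_SHAPE};

    #[test]
    fn forward_returns_ten_logits() {
        // Build the generated model and feed it an all-zeros input of the declared shape.
        let model: Model<NdArrayBackend<f32>> = Model::new();
        let input = burn::tensor::Tensor::<NdArrayBackend<f32>, 4>::zeros(INPUT1_SHAPE);

        let output = model.forward(input);

        // Assumed: one log-probability per MNIST digit class.
        assert_eq!(output.shape().dims, [1, 10]);
    }
}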