Add foundation for importing ONNX files (#297)

This commit is contained in:
Dilshod Tadjibaev 2023-04-15 09:44:50 -05:00 committed by GitHub
parent a74f26620a
commit df980d534e
27 changed files with 3262 additions and 17 deletions


@ -79,3 +79,8 @@ jobs:
with:
crate: burn-train
test-burn-import:
uses: burn-rs/burn/.github/workflows/test-template.yml@main
with:
crate: burn-import


@ -5,24 +5,25 @@ resolver = "2"
 members = [
     "burn",
-    "burn-core",
-    "burn-train",
-    "burn-derive",
-    "burn-tensor",
-    "burn-tensor-testgen",
-    "burn-dataset",
-    "burn-tch",
-    "burn-ndarray",
     "burn-autodiff",
     "burn-common",
+    "burn-core",
+    "burn-dataset",
+    "burn-derive",
+    "burn-import",
+    "burn-ndarray",
     "burn-no-std-tests",
+    "burn-tch",
+    "burn-tensor-testgen",
+    "burn-tensor",
+    "burn-train",
     "examples/*",
 ]
[workspace.dependencies]
const-random = "0.1.15"
dashmap = "5.4.0"
-dirs = "4.0.0"
+dirs = "5.0.0"
fake = "2.5.0"
flate2 = "1.0.25"
half = {version = "2", features = ["alloc", "num-traits"], default-features = false}
@ -33,6 +34,11 @@ log = "0.4.17"
log4rs = "1.2.0"
spin = {version = "0.9.5", features = ["mutex", "spin_mutex"]}
thiserror = "1.0.39"
proc-macro2 = "1.0.54"
quote = "1.0.26"
syn = "2.0"
strum = "0.24"
strum_macros = "0.24"
#
# The following packages disable the "std" feature for no_std compatibility
@ -44,7 +50,7 @@ rand = {version = "0.8.5", default-features = false, features = ["std_rng"]}# st
rand_distr = {version = "0.4.3", default-features = false}
uuid = {version = "1.3.0", default-features = false}
-bincode = {version = "2.0.0-rc", features = ["alloc", "serde"], default-features = false}
-rmp-serde = {version = "1.1.1"}
 serde = {version = "1.0.155", default-features = false, features = ["derive", "alloc"]}# alloc is for no_std, derive is needed
 serde_json = {version = "1.0.94", default-features = false}
+rmp-serde = {version = "1.1.1"}
+bincode = {version = "2.0.0-rc", features=["alloc", "serde"], default-features = false}

NOTICES.md

@ -0,0 +1,248 @@
# NOTICES AND INFORMATION
This file contains notices and information required by libraries that this
repository copied or derived from.
## PyTorch MNIST Example
**Source**: https://github.com/pytorch/examples/blob/main/mnist/main.py
License: BSD 3-Clause License
Copyright (c) 2017,
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
## ONNX
**Source**: https://github.com/onnx/onnx/blob/main/onnx/onnx.proto3
License: Apache License 2.0
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


@ -16,6 +16,7 @@ version = "0.6.0"
proc-macro = true
[dependencies]
-proc-macro2 = "1.0.52"
-quote = "1.0.26"
-syn = "1.0.109"
+proc-macro2 = {workspace = true}
+quote = {workspace = true}
+# syn = {workspace = true}
+syn = "1.0.76" # TODO upgrade to 2.0

burn-import/Cargo.toml

@ -0,0 +1,36 @@
[package]
authors = [
"Dilshod Tadjibaev (@antimora)",
]
edition = "2021"
license = "MIT/Apache-2.0"
name = "burn-import"
readme = "README.md"
repository = "https://github.com/burn-rs/burn/tree/main/burn-import"
version = "0.6.0"
[features]
default = ["onnx"]
onnx = []
[dependencies]
burn = {path = "../burn", version = "0.6.0"}
burn-ndarray = {path = "../burn-ndarray", version = "0.6.0"}
half = {workspace = true}
proc-macro2 = {workspace = true}
protobuf = {version = "3.2", features = ["with-bytes"]}
quote = {workspace = true}
rust-format = {version = "0.3", features = ["token_stream", "post_process"]}
serde = {workspace = true}
strum = {workspace = true}
strum_macros = {workspace = true}
syn = {workspace = true, features = ["parsing"]}
topological-sort = {version = "0.2.2"}
[build-dependencies]
protobuf-codegen = {version = "3.2"}
[dev-dependencies]
rstest = "0.17.0"

burn-import/README.md

@ -0,0 +1,65 @@
# Burn Import
`burn-import` is a crate designed to facilitate importing models trained in other machine learning
frameworks into the Burn framework. This tool generates a Rust source file that aligns the source
model with Burn's model and converts tensor data into a format compatible with Burn.
Currently under development, `burn-import` supports importing ONNX models with a limited set of
operators.
## Supported ONNX Operators
- Conv2d
- Gemm (Linear layer)
- Flatten
- LogSoftmax
## Usage
### Importing ONNX models
In `build.rs`, add the following:
```rust
use burn_import::onnx::ModelCodeGen;
fn main() {
ModelCodeGen::new()
.input("src/model/mnist.onnx") // Path to the ONNX model
.out_dir("model/") // Directory to output the generated Rust source file (under target/)
.run_from_script();
}
```
Then, add the following to `mod.rs` under `src/model`:
```rust
pub mod mnist {
include!(concat!(env!("OUT_DIR"), "/model/mnist.rs"));
}
```
Finally, in your code, you can use the imported model as follows:
```rust
use burn::tensor;
use burn_ndarray::NdArrayBackend;
use onnx_inference::model::mnist::{Model, INPUT1_SHAPE};
fn main() {
// Create a new model
let model: Model<NdArrayBackend<f32>> = Model::new();
// Create a new input tensor (all zeros for demonstration purposes)
let input = tensor::Tensor::<NdArrayBackend<f32>, 4>::zeros(INPUT1_SHAPE);
// Run the model
let output = model.forward(input);
// Print the output
println!("{:?}", output);
}
```
You can view the working example in the `examples/onnx-inference` directory.

burn-import/build.rs

@ -0,0 +1,11 @@
fn main() {
if cfg!(feature = "onnx") {
// Generate the onnx protobuf files
protobuf_codegen::Codegen::new()
.pure()
.includes(["src"])
.input("src/onnx/protos/onnx.proto")
.cargo_out_dir("onnx-protos")
.run_from_script();
}
}

burn-import/src/lib.rs

@ -0,0 +1,6 @@
#![allow(clippy::ptr_arg)]
#![allow(clippy::single_match)]
#![allow(clippy::upper_case_acronyms)]
#[cfg(feature = "onnx")]
pub mod onnx;


@ -0,0 +1,47 @@
use super::ir::{AttributeValue, Node, NodeType};
/// Coalesces supported node patterns in place, rewriting them into simpler equivalents (currently only Gemm into Linear).
pub fn coalesce(nodes: &mut Vec<Node>) {
for node in nodes.iter_mut() {
match node.node_type {
NodeType::Gemm => convert_gemm(node),
_ => {}
}
}
}
/// This function converts a Gemm node into a Linear node
///
/// Warning: This function is not complete yet.
/// It only supports the case where the Gemm node is a straight linear transformation.
fn convert_gemm(node: &mut Node) {
    // Weights and biases have already been moved to initializers, so a
    // straight linear Gemm is left with exactly one input at this point.
    if node.inputs.len() != 1 {
        panic!("Gemm node must have exactly 1 input");
    }
if node.outputs.len() != 1 {
panic!("Gemm node must have 1 output");
}
let straight_linear = match (
node.attrs.get("alpha"),
node.attrs.get("beta"),
node.attrs.get("transB"),
) {
(
Some(AttributeValue::Float32(alpha)),
Some(AttributeValue::Float32(beta)),
Some(AttributeValue::Int64(trans_b)),
) => *alpha == 1.0 && *beta == 1.0 && *trans_b == 1,
_ => false,
};
if straight_linear {
node.node_type = NodeType::Linear;
node.is_stateful = true;
node.attrs.remove("alpha");
node.attrs.remove("beta");
node.attrs.remove("transB");
} else {
panic!("Full Gemm node not supported yet.");
}
}
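As a concrete illustration, here is a hypothetical unit-test style sketch (not part of this commit) of the rewrite `convert_gemm` performs, using the `Node`, `Argument`, and `Attributes` definitions from `ir.rs`; the node and argument names are made up:

```rust
#[cfg(test)]
mod tests {
    use super::*;
    use crate::onnx::ir::{Argument, Attributes};

    #[test]
    fn straight_linear_gemm_becomes_linear() {
        // Hypothetical Gemm node; weight/bias were already moved to initializers
        let mut attrs = Attributes::new();
        attrs.insert("alpha".to_string(), AttributeValue::Float32(1.0));
        attrs.insert("beta".to_string(), AttributeValue::Float32(1.0));
        attrs.insert("transB".to_string(), AttributeValue::Int64(1));
        let mut nodes = vec![Node {
            node_type: NodeType::Gemm,
            name: "gemm1".to_string(),
            inputs: vec![Argument { name: "input1".to_string(), arg_type: None }],
            outputs: vec![Argument { name: "gemm1_out1".to_string(), arg_type: None }],
            initializers: vec![],
            attrs,
            is_stateful: false,
        }];

        coalesce(&mut nodes);

        // The node was rewritten in place: Gemm -> Linear, now stateful,
        // and the Gemm-specific attributes were dropped.
        assert_eq!(nodes[0].node_type, NodeType::Linear);
        assert!(nodes[0].is_stateful);
        assert!(nodes[0].attrs.is_empty());
    }
}
```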


@ -0,0 +1,574 @@
use std::{
collections::HashSet,
env,
fs::{self, create_dir_all},
path::{Path, PathBuf},
};
use burn::nn::conv::Conv2dPaddingConfig;
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{Ident, Type};
use crate::onnx::{
ir::{ArgType, Node, NodeType},
op_configuration::{conv2d_config, flatten_config, linear_config, log_softmax_config},
};
use super::{convert::parse_onnx, ir::Graph};
use rust_format::{Config, Edition, Formatter, PostProcess, RustFmt};
/// Code generation for onnx files.
#[derive(Debug, Default)]
pub struct ModelCodeGen {
out_dir: Option<PathBuf>,
/// List of onnx files to generate source code from.
inputs: Vec<PathBuf>,
}
/// Generate code from `.onnx` files and save it to the `out_dir`.
impl ModelCodeGen {
pub fn new() -> Self {
Self::default()
}
/// Set output directory.
pub fn out_dir(&mut self, out_dir: &str) -> &mut Self {
let cargo_out_dir = env::var("OUT_DIR").expect("OUT_DIR env is not set");
let mut path = PathBuf::from(cargo_out_dir);
// Append the out_dir to the cargo_out_dir
path.push(Path::new(out_dir));
self.out_dir = Some(path);
self
}
/// Add input file.
pub fn input(&mut self, input: &str) -> &mut Self {
self.inputs.push(input.into());
self
}
/// Run code generation.
///
/// This function is intended to be called from `build.rs` script.
pub fn run_from_script(&self) {
self.run();
}
/// Run code generation.
pub fn run(&self) {
let config = Config::new_str()
.post_proc(PostProcess::ReplaceMarkersAndDocBlocks)
.edition(Edition::Rust2021);
let rust_formatter = RustFmt::from_config(config);
let out_dir = self.out_dir.as_ref().expect("out_dir is not set");
create_dir_all(out_dir).unwrap();
for input in self.inputs.iter() {
let file_name = input.file_stem().unwrap();
let out_file = out_dir.join(file_name);
let out_file = out_file.with_extension("rs");
let model = ModelSourceCode::new(input);
let code_str = rust_formatter.format_tokens(model.body()).unwrap();
fs::write(out_file, code_str).unwrap();
}
}
}
/// A model that can be used to generate code
#[derive(Debug, Clone)]
pub struct ModelSourceCode {
onnx_path: PathBuf,
pub graph: Graph,
}
impl ModelSourceCode {
/// Create a new model from the onnx file
pub fn new<P: AsRef<Path>>(onnx_path: P) -> Self {
let graph = parse_onnx(onnx_path.as_ref());
Self {
onnx_path: onnx_path.as_ref().to_path_buf(),
graph,
}
}
/// Generates source code for the model
pub fn body(&self) -> TokenStream {
let input = "Model"; // TODO make this a parameter
let input = Ident::new(input, Span::call_site());
let declaration = self.declaration(&input);
let file_path = self.onnx_path.to_str().unwrap();
let top_file_comment = format!("Generated from {file_path} by burn-import");
let mut imports: HashSet<String> = HashSet::new();
let implementation = self.implementation(&mut imports);
let import_statements = self.import_statements(&imports);
let shape_constants = self.shape_constants();
//TODO print out the old -> new name mapping
quote! {
_comment_!(#top_file_comment);
_blank_!();
_blank_!();
#import_statements
_blank_!();
#shape_constants
_blank_!();
#declaration
_blank_!();
#[allow(dead_code)]
#[allow(clippy::new_without_default)]
#[allow(clippy::let_and_return)]
#implementation
}
}
fn shape_constants(&self) -> TokenStream {
let input_constants = self.graph.inputs.iter().enumerate().map(|(i, input)| {
let name = format!("INPUT{}_SHAPE", i + 1);
let name = Ident::new(&name, Span::call_site());
let ArgType::Tensor(tensor) = input.clone().arg_type.unwrap();
let dims = tensor.shape;
let dims_count = dims.len();
quote! {
pub const #name: [usize; #dims_count] = [#(#dims),*];
}
});
        let output_constants = self.graph.outputs.iter().enumerate().map(|(i, output)| {
            let name = format!("OUTPUT{}_SHAPE", i + 1);
            let name = Ident::new(&name, Span::call_site());
            let ArgType::Tensor(tensor) = output.clone().arg_type.unwrap();
let dims = tensor.shape;
let dims_count = dims.len();
quote! {
pub const #name: [usize; #dims_count] = [#(#dims),*];
}
});
quote! {
#(#input_constants)*
#(#output_constants)*
}
}
/// Generates import statements for the model
fn import_statements(&self, imports: &HashSet<String>) -> TokenStream {
let mut import_tokens = vec![];
for import in imports.iter() {
let path: syn::Path =
syn::parse_str(import).expect("Unable to parse input string as a path");
import_tokens.push(quote! { #path });
}
quote! {
use burn::{
module::Module,
nn,
tensor::{backend::Backend, Tensor},
};
#(use #import_tokens;)*
}
}
/// Generates the declaration portion of the source code for the model
fn declaration(&self, name: &Ident) -> TokenStream {
let fields = self.declaration_fields();
let mut field_names = vec![];
let mut field_types = vec![];
for (field_name, field_type) in fields.iter() {
field_names.push(field_name);
field_types.push(field_type);
}
quote! {
// TODO add documentation
#[doc = "This is a generated model from an ONNX file"]
#[derive(Module, Debug)]
pub struct #name<B: Backend> {
#(
#field_names: #field_types,
)*
}
}
}
/// Model implementation code
fn implementation(&self, imports: &mut HashSet<String>) -> TokenStream {
let forward_method = self.forward_method(imports);
let new_method = self.new_method();
quote! {
impl<B: Backend> Model<B> {
#new_method
#forward_method
}
}
}
    /// Generates the forward method for the model
fn forward_method(&self, imports: &mut HashSet<String>) -> TokenStream {
let inputs = self.forward_signature_input();
let return_type = self.forward_signature_return();
let results = self.forward_method_results();
let mut call_nodes: Vec<TokenStream> = vec![];
for node in self.graph.nodes.iter() {
if node.is_stateful {
call_nodes.push(Self::node_call_stateful(node));
} else {
call_nodes.push(Self::node_call_stateless(node, imports));
}
}
quote! {
pub fn forward(&self, #(#inputs,)*) -> #return_type {
#(#call_nodes)*
#results
}
}
}
    /// Generates source code for the stateful node calls, e.g. conv, dropout, etc.
fn node_call_stateful(node: &Node) -> TokenStream {
if !node.is_stateful {
panic!("Node must be stateful");
}
let name = Ident::new(&node.name, Span::call_site());
let mut inputs = vec![];
for input in node.inputs.iter() {
let name = Ident::new(&input.name, Span::call_site());
inputs.push(quote! {
#name
});
}
let mut outputs = vec![];
for output in node.outputs.iter() {
let name = Ident::new(&output.name, Span::call_site());
outputs.push(quote! {
#name
});
}
if outputs.len() == 1 {
let output = outputs.pop().unwrap();
quote! {
let #output = self.#name.forward(#(#inputs,)*);
}
} else {
quote! {
let (#(#outputs,)*) = self.#name.forward(#(#inputs,)*);
}
}
}
/// Generates source code for the forward method results
fn forward_method_results(&self) -> TokenStream {
let mut outputs = vec![];
for output in self.graph.outputs.iter() {
let name = Ident::new(&output.name, Span::call_site());
outputs.push(quote! {
#name
});
}
if outputs.len() == 1 {
let output = outputs.pop().unwrap();
quote! {
#output
}
} else {
quote! {
(#(#outputs,)*)
}
}
}
    /// Generates source code for the stateless node calls, e.g. add, mul, etc.
fn node_call_stateless(node: &Node, imports: &mut HashSet<String>) -> TokenStream {
if node.is_stateful {
panic!("Node must be stateless");
}
let mut inputs = vec![];
for input in node.inputs.iter() {
let name = Ident::new(&input.name, Span::call_site());
inputs.push(quote! {
#name
});
}
let mut outputs = vec![];
for output in node.outputs.iter() {
let name = Ident::new(&output.name, Span::call_site());
outputs.push(quote! {
#name
});
}
let rhs = Self::node_call_stateless_rhs(node, imports);
if outputs.len() == 1 {
let output = outputs.pop().unwrap();
quote! {
let #output = #rhs;
}
} else {
quote! {
let (#(#outputs,)*) = #rhs;
}
}
}
    /// Generates source code for the right hand side of stateless node calls, e.g. add, relu, etc.
fn node_call_stateless_rhs(node: &Node, imports: &mut HashSet<String>) -> TokenStream {
let mut inputs = vec![];
for input in node.inputs.iter() {
let name = Ident::new(&input.name, Span::call_site());
inputs.push(quote! {
#name
});
}
let input1 = inputs.pop().unwrap();
match node.node_type {
NodeType::Relu => {
imports.insert("burn::tensor::activation::relu".to_string());
quote! { relu(#input1) }
}
NodeType::LogSoftmax => {
imports.insert("burn::tensor::activation::log_softmax".to_string());
let dim = log_softmax_config(node);
quote! { log_softmax(#input1, #dim) }
}
NodeType::Flatten => {
let (start_dim, end_dim) = flatten_config(node);
quote! { #input1.flatten(#start_dim, #end_dim) }
}
_ => quote! {},
}
}
/// Generates the forward method signature
fn forward_signature_input(&self) -> Vec<TokenStream> {
let mut fields = vec![];
for input in self.graph.inputs.iter() {
let name = Ident::new(&input.name, Span::call_site());
let ty = match input.arg_type.as_ref().unwrap() {
ArgType::Tensor(tensor) => {
let d = &tensor.shape.len();
syn::parse_str::<Type>(format!("Tensor<B, {d}>").as_str()).unwrap()
}
};
fields.push(quote! {
#name: #ty
});
}
fields
}
/// Generates the forward method return signature
fn forward_signature_return(&self) -> TokenStream {
let mut field_types = vec![];
for output in self.graph.outputs.iter() {
let ty = match output.arg_type.as_ref().unwrap() {
ArgType::Tensor(tensor) => {
let d = &tensor.shape.len();
syn::parse_str::<Type>(format!("Tensor<B, {d}>").as_str()).unwrap()
}
};
field_types.push(ty);
}
if field_types.len() == 1 {
// Return one output
quote! {
#(
#field_types
)*
}
} else {
// Return a tuple of the outputs
quote! {
(#(
#field_types,
)*)
}
}
}
/// Generates source code for the initialization method
fn new_method(&self) -> TokenStream {
let initialization_fields = self.initialization_fields();
let field_names = self.graph.nodes.iter().filter(|x| x.is_stateful).map(|x| {
let name = Ident::new(&x.name, Span::call_site());
quote! {
#name
}
});
quote! {
pub fn new() -> Self {
#(
#initialization_fields
)*
Self {
#(
#field_names
),*
}
}
}
}
/// Get the fields for the declaration of the model
fn declaration_fields(&self) -> Vec<(Ident, Type)> {
let mut fields = vec![];
for node in self.graph.nodes.iter().filter(|x| x.is_stateful) {
let node_type = match node.node_type {
NodeType::Conv1d => syn::parse_str::<Type>("nn::conv::Conv1d<B>").unwrap(),
NodeType::Conv2d => syn::parse_str::<Type>("nn::conv::Conv2d<B>").unwrap(),
NodeType::Linear => syn::parse_str::<Type>("nn::Linear<B>").unwrap(),
_ => {
todo!("Node type not implemented: {:?}", node.node_type)
}
};
let node_name = Ident::new(&node.name, Span::call_site());
fields.push((node_name, node_type));
}
fields
}
    /// Generates the per-field initialization statements used by the new method
fn initialization_fields(&self) -> Vec<TokenStream> {
let mut fields = vec![];
for node in self.graph.nodes.iter().filter(|x| x.is_stateful) {
let init_code = match node.node_type {
NodeType::Conv2d => conv2d_init(node),
NodeType::Linear => linear_init(node),
_ => {
todo!("Node type not implemented: {:?}", node.node_type)
}
};
fields.push(init_code);
}
fields
}
}
/// Generates source code for the initialization of a Conv2d node
fn conv2d_init(node: &Node) -> TokenStream {
let node_name = Ident::new(&node.name, Span::call_site());
let config = conv2d_config(node);
let channel_in = config.channels[0];
let channel_out = config.channels[1];
let kernel_size_0 = config.kernel_size[0];
let kernel_size_1 = config.kernel_size[1];
let bias = config.bias;
let padding = match config.padding {
Conv2dPaddingConfig::Valid => quote! { nn::conv::Conv2dPaddingConfig::Valid },
Conv2dPaddingConfig::Same => quote! { nn::conv::Conv2dPaddingConfig::Same },
_ => todo!("Padding ({:?}) not implemented", config.padding),
};
quote! {
let #node_name = nn::conv::Conv2dConfig::new([#channel_in, #channel_out], [#kernel_size_0, #kernel_size_1])
.with_padding(#padding)
.with_bias(#bias)
.init();
}
}
/// Generates source code for the initialization of a Linear node
fn linear_init(node: &Node) -> TokenStream {
let node_name = Ident::new(&node.name, Span::call_site());
let config = linear_config(node);
let bias = config.bias;
let input_size = config.d_input;
let output_size = config.d_output;
quote! {
let #node_name = nn::LinearConfig::new(#input_size, #output_size)
.with_bias(#bias)
.init();
}
}
#[cfg(test)]
mod tests {
use super::*;
use rstest::*;
use rust_format::{Config, Edition, Formatter, PostProcess, RustFmt};
#[fixture]
pub fn model() -> ModelSourceCode {
ModelSourceCode::new("tests/onnx/mnist.onnx")
}
#[rstest]
fn print(model: ModelSourceCode) {
let config = Config::new_str()
.post_proc(PostProcess::ReplaceMarkersAndDocBlocks)
.edition(Edition::Rust2021);
let rustfmt = RustFmt::from_config(config);
let _gen_str = rustfmt.format_tokens(model.body()).unwrap();
// TODO compare the result with the expected output
}
}
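For orientation, here is a rough, abridged sketch of what the generated `mnist.rs` could look like for a minimal conv + linear + log-softmax network. The layer names, channel counts, and shapes below are hypothetical, but each line follows one of the `quote!` templates above:

```rust
// Generated from src/model/mnist.onnx by burn-import
use burn::{
    module::Module,
    nn,
    tensor::{backend::Backend, Tensor},
};
use burn::tensor::activation::log_softmax;

// Hypothetical shapes for a 28x28 grayscale input and 10 classes
pub const INPUT1_SHAPE: [usize; 4] = [1, 1, 28, 28];
pub const OUTPUT1_SHAPE: [usize; 2] = [1, 10];

#[doc = "This is a generated model from an ONNX file"]
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    conv2d1: nn::conv::Conv2d<B>,
    linear1: nn::Linear<B>,
}

impl<B: Backend> Model<B> {
    pub fn new() -> Self {
        // Hypothetical configs: 1 -> 8 channels, 3x3 kernel, valid padding,
        // so 8 * 26 * 26 = 5408 features feed the linear layer
        let conv2d1 = nn::conv::Conv2dConfig::new([1, 8], [3, 3])
            .with_padding(nn::conv::Conv2dPaddingConfig::Valid)
            .with_bias(true)
            .init();
        let linear1 = nn::LinearConfig::new(5408, 10).with_bias(true).init();
        Self { conv2d1, linear1 }
    }

    pub fn forward(&self, input1: Tensor<B, 4>) -> Tensor<B, 2> {
        let conv2d1_out1 = self.conv2d1.forward(input1);
        let flatten1_out1 = conv2d1_out1.flatten(1, 3);
        let linear1_out1 = self.linear1.forward(flatten1_out1);
        let logsoftmax1_out1 = log_softmax(linear1_out1, 1);
        logsoftmax1_out1
    }
}
```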


@ -0,0 +1,523 @@
use std::{
collections::{HashMap, HashSet},
fs::File,
path::Path,
str::{from_utf8, FromStr},
};
use super::coalesce::coalesce;
use super::ir::{
ArgType, Argument, AttributeValue, Attributes, ElementType, Graph, Node, NodeType, Tensor,
TensorData,
};
use super::protos::{
attribute_proto::AttributeType, tensor_proto::DataType, tensor_shape_proto::dimension::Value,
type_proto, AttributeProto, ModelProto, NodeProto, TensorProto, TensorShapeProto,
ValueInfoProto,
};
use super::shape_inference::shape_inference;
use protobuf::{Enum, Message};
use topological_sort::TopologicalSort;
const STATEFUL_NODE_TYPES: [NodeType; 4] = [
NodeType::Conv,
NodeType::BatchNormalization,
NodeType::Dropout,
NodeType::Linear,
];
/// Error type for parsing ONNX model
#[derive(Debug)]
pub enum ParseError {
VariantNotFound,
}
/// Open an onnx file and convert it to a Graph (intermediate representation)
pub fn parse_onnx(onnx_path: &Path) -> Graph {
// Open the file
let mut file = File::open(onnx_path).expect("Unable to open file");
let onnx_model: ModelProto =
Message::parse_from_reader(&mut file).expect("Unable to parse ONNX file");
// Convert the nodes
let mut nodes: Vec<Node> = vec![];
for onnx_node in onnx_model.graph.node.iter() {
nodes.push(convert_node_proto(onnx_node));
}
// Get the names of the initializers
let check_if_initializer: HashSet<String> = onnx_model
.graph
.initializer
.iter()
.map(|x| x.name.clone())
.collect();
// Move inputs to initializers
move_inputs_to_initializer(&mut nodes, &check_if_initializer);
// Get the topological sort of the nodes and the top nodes
let (ts, top_nodes) = get_top_nodes(&nodes);
// Sort the nodes
top_sort_nodes(&mut nodes, ts);
// Collect inputs, outputs and initializers
let mut inputs = collect_inputs(&onnx_model, &check_if_initializer, top_nodes);
let mut outputs = collect_outputs(&onnx_model, check_if_initializer);
let initializers = collect_initializers(onnx_model);
// Coalesce and transform nodes
coalesce(&mut nodes);
// Copy the initializers to the nodes
copy_initializer_info_to_nodes_level(&mut nodes, &initializers);
// Rename nodes and inputs, save the mapping for later
let old_node_names = rename_nodes(&mut nodes);
let old_input_names = rename_inputs(&mut nodes, &mut inputs, &mut outputs);
// Infer shapes and update the inputs and outputs
shape_inference(&mut nodes, &inputs, &mut outputs);
Graph {
nodes,
inputs,
outputs,
initializers,
old_node_names,
old_input_names,
}
}
/// Collect initializers
fn collect_initializers(onnx_model: ModelProto) -> Vec<Argument> {
let mut initializers: Vec<Argument> = vec![];
for initializer in onnx_model.graph.initializer.iter() {
let tensor_proto = initializer.clone();
let name = tensor_proto.name.clone();
// FIXME data conversion for the tensor is incorrect
let tensor: Tensor = tensor_proto.try_into().unwrap();
let arg_type = Some(ArgType::Tensor(tensor));
let arg = Argument { name, arg_type };
initializers.push(arg);
}
initializers
}
/// Collect outputs
fn collect_outputs(
onnx_model: &ModelProto,
check_if_initializer: HashSet<String>,
) -> Vec<Argument> {
// TODO: filter out the outputs that are not used in the graph
let outputs: Vec<Argument> = onnx_model
.graph
.output
.iter()
.filter(|x| !check_if_initializer.contains(x.name.as_str()))
.map(|i| Argument::try_from(i.clone()).unwrap())
.collect();
outputs
}
/// Collect inputs
fn collect_inputs(
onnx_model: &ModelProto,
check_if_initializer: &HashSet<String>,
top_nodes: HashSet<String>,
) -> Vec<Argument> {
let inputs: Vec<Argument> = onnx_model
.graph
.input
.iter()
.filter(|x| !check_if_initializer.contains(x.name.as_str()))
.filter(|x| top_nodes.contains(&x.name))
.map(|x| Argument::try_from(x.clone()).unwrap())
.collect();
inputs
}
/// Sort the nodes in topological order
fn top_sort_nodes(nodes: &mut Vec<Node>, mut ts: TopologicalSort<Node>) {
*nodes = vec![];
while let Some(node) = ts.pop() {
nodes.push(node);
}
}
/// Get the top nodes in the graph
fn get_top_nodes(nodes: &Vec<Node>) -> (TopologicalSort<Node>, HashSet<String>) {
// Get the names of the top nodes (first nodes in the graph to receive the input)
// Sometimes onnx will pass inputs to be used as weights and biases but they are not truly inputs
let ts = topsort(nodes);
let mut top_nodes: HashSet<String> = HashSet::new();
for node in ts.peek_all() {
for input in node.inputs.iter() {
top_nodes.insert(input.name.clone());
}
}
(ts, top_nodes)
}
/// Move a node's inputs and outputs to its initializers if they are in the initializer list
fn move_inputs_to_initializer(nodes: &mut Vec<Node>, check_if_initializer: &HashSet<String>) {
for node in nodes.iter_mut() {
node.initializers = node
.inputs
.iter()
.filter(|x| check_if_initializer.contains(&x.name))
.cloned()
.collect();
// Remove the initializers from the inputs and outputs
node.inputs
.retain(|x| !check_if_initializer.contains(&x.name));
node.outputs
.retain(|x| !check_if_initializer.contains(&x.name));
}
}
fn to_string(bytes: Vec<u8>) -> String {
from_utf8(bytes.as_slice()).unwrap().to_string()
}
fn to_string_vec(bytes: Vec<Vec<u8>>) -> Vec<String> {
bytes.iter().map(|b| to_string(b.clone())).collect()
}
fn convert_shape(shape: Vec<i64>) -> Vec<usize> {
shape.iter().map(|s| *s as usize).collect()
}
/// Convert a TensorProto to a Tensor
impl TryFrom<TensorProto> for Tensor {
type Error = ParseError;
fn try_from(tensor: TensorProto) -> Result<Tensor, Self::Error> {
let (elem_type, data) = match DataType::from_i32(tensor.data_type).unwrap() {
// TODO: check if float is empty and use raw_data instead
DataType::FLOAT => (
ElementType::Float32,
TensorData::Float32s(tensor.float_data),
),
DataType::INT32 => (ElementType::Int32, TensorData::Int32s(tensor.int32_data)),
DataType::INT64 => (ElementType::Int64, TensorData::Int64s(tensor.int64_data)),
DataType::DOUBLE => (
ElementType::Float64,
TensorData::Float64s(tensor.double_data),
),
// TODO : Add more types
_ => {
return Err(ParseError::VariantNotFound);
}
};
let shape = convert_shape(tensor.dims);
let name = tensor.name;
Ok(Tensor {
name: Some(name),
elem_type,
shape,
data: Some(data),
})
}
}
impl TryFrom<TensorShapeProto> for Vec<usize> {
type Error = ParseError;
fn try_from(shape: TensorShapeProto) -> Result<Vec<usize>, Self::Error> {
let mut result = Vec::new();
for dim in shape.dim {
if let Value::DimValue(value) = dim.value.unwrap() {
result.push(value as usize);
}
}
Ok(result)
}
}
/// Convert a type_proto::Tensor to a Tensor (shape only; no data)
impl TryFrom<&type_proto::Tensor> for Tensor {
type Error = ParseError;
fn try_from(tensor: &type_proto::Tensor) -> Result<Tensor, Self::Error> {
let elem_type = match DataType::from_i32(tensor.elem_type).unwrap() {
DataType::FLOAT => ElementType::Float32,
DataType::INT32 => ElementType::Int32,
DataType::INT64 => ElementType::Int64,
DataType::DOUBLE => ElementType::Float64,
// TODO : Add more types
_ => {
return Err(ParseError::VariantNotFound);
}
};
let shape_proto = tensor.shape.clone().unwrap();
let shape: Vec<usize> = shape_proto.try_into().unwrap();
let name = None;
Ok(Tensor {
name,
elem_type,
shape,
data: None,
})
}
}
fn convert_vec_tensor_proto(tensors: Vec<TensorProto>) -> Result<Vec<Tensor>, ParseError> {
let mut result = Vec::new();
for tensor in tensors {
result.push(Tensor::try_from(tensor)?);
}
Ok(result)
}
/// Convert an AttributeProto to an AttributeValue
impl TryFrom<AttributeProto> for AttributeValue {
type Error = ParseError;
fn try_from(attr: AttributeProto) -> Result<AttributeValue, Self::Error> {
let value = match attr.type_.unwrap() {
AttributeType::FLOAT => AttributeValue::Float32(attr.f),
AttributeType::INT => AttributeValue::Int64(attr.i),
AttributeType::STRING => AttributeValue::String(to_string(attr.s)),
// warning: tensor can be empty TODO: check if it is empty
AttributeType::TENSOR => AttributeValue::Tensor(Tensor::try_from(attr.t.unwrap())?),
// Graph is not supported for now
// AttributeType::GRAPH => AttributeValue::Graph(attr.g),
AttributeType::FLOATS => AttributeValue::Float32s(attr.floats),
AttributeType::INTS => AttributeValue::Int64s(attr.ints),
AttributeType::STRINGS => AttributeValue::Strings(to_string_vec(attr.strings)),
AttributeType::TENSORS => {
AttributeValue::Tensors(convert_vec_tensor_proto(attr.tensors)?)
}
// AttributeType::GRAPHS => AttributeValue::Graphs(attr.graphs),
// AttributeType::SPARSE_TENSORS => AttributeValue::SparseTensors(attr.sparse_tensors),
// AttributeType::SPARSE_TENSOR => AttributeValue::SparseTensor(attr.sparse_tensor),
_ => {
return Err(ParseError::VariantNotFound);
}
};
Ok(value)
}
}
/// Convert a vector of AttributeProto to a HashMap of AttributeValue
pub fn convert_vec_attrs_proto(attrs: Vec<AttributeProto>) -> Attributes {
let mut result = Attributes::new();
for attr in attrs {
result.insert(attr.name.clone(), AttributeValue::try_from(attr).unwrap());
}
result
}
pub fn convert_node_proto(node: &NodeProto) -> Node {
let name = node.name.clone();
let inputs = node
.input
.clone()
.into_iter()
.map(|x| Argument {
name: x,
arg_type: None,
})
.collect();
let outputs = node
.output
.clone()
.into_iter()
.map(|x| Argument {
name: x,
arg_type: None,
})
.collect();
let attrs = convert_vec_attrs_proto(node.attribute.clone());
let node_type = NodeType::from_str(node.op_type.as_str()).unwrap();
let is_stateful = STATEFUL_NODE_TYPES.contains(&node_type);
let mut node = Node {
node_type,
name,
inputs,
outputs,
initializers: vec![],
attrs,
is_stateful,
};
remap_node_type(&mut node);
node
}
/// Remap node type to a more specific one
fn remap_node_type(node: &mut Node) {
match node.node_type {
NodeType::Conv => {
if let AttributeValue::Int64s(ints) = node.attrs.get("kernel_shape").unwrap() {
node.node_type = match ints.len() {
1 => NodeType::Conv1d,
2 => NodeType::Conv2d,
_ => todo!(),
};
} else {
panic!("kernel_shape is not an int64s");
}
}
_ => (),
}
}
/// Convert a ValueInfoProto to an Argument
impl TryFrom<ValueInfoProto> for Argument {
type Error = ParseError;
fn try_from(value: ValueInfoProto) -> Result<Argument, Self::Error> {
let name = value.name.clone();
let proto_type = value.type_.unwrap();
let mut arg_type = None;
if proto_type.has_tensor_type() {
let tensor_proto = proto_type.tensor_type();
let tensor: Tensor = tensor_proto.try_into().unwrap();
arg_type = Some(ArgType::Tensor(tensor));
}
Ok(Argument { name, arg_type })
}
}
/// Copy the initializers to the nodes
fn copy_initializer_info_to_nodes_level(nodes: &mut Vec<Node>, initializers: &Vec<Argument>) {
for node in nodes.iter_mut() {
for node_initializer in node.initializers.iter_mut() {
*node_initializer = initializers
.iter()
.find(|x| x.name == node_initializer.name)
.unwrap()
.clone();
}
}
}
/// Rename the nodes in the graph to be unique and return a map of the old names to the new names.
fn rename_nodes(nodes: &mut Vec<Node>) -> HashMap<String, String> {
let mut old_names = HashMap::new();
let mut counter: HashMap<NodeType, usize> = HashMap::new();
for node in nodes.iter_mut() {
// keep track of the number of nodes of each type
counter
.entry(node.node_type.clone())
.and_modify(|e| *e += 1)
.or_insert(1);
let old_name = node.name.clone();
let new_name = format!("{}{}", node.node_type, counter[&node.node_type]).to_lowercase();
node.name = new_name.clone();
old_names.insert(old_name, new_name);
}
old_names
}
/// Rename the inputs in the graph and return a map of the old names to the new names.
///
/// The inputs are renamed to be unique and to be in the format of conv2_in1, conv2_in2, etc.
/// This is done to be consistent with the node naming convention and to allow the names to be used as Rust identifiers.
fn rename_inputs(
nodes: &mut Vec<Node>,
inputs: &mut Vec<Argument>,
outputs: &mut Vec<Argument>,
) -> HashMap<String, String> {
let mut old_names = HashMap::new();
let mut counter = 1;
for input in inputs.iter_mut() {
let old_name = input.name.clone();
let new_name = format!("input{}", counter);
input.name = new_name.clone();
old_names.insert(old_name, new_name);
counter += 1;
}
let mut counter: HashMap<String, usize> = HashMap::new();
for node in nodes.iter_mut() {
// keep track of the number of nodes of each type
counter
.entry(node.name.clone())
.and_modify(|e| *e += 1)
.or_insert(1);
// loop through node inputs and rename them with previously replaced names
for input in node.inputs.iter_mut() {
if let Some(new_name) = old_names.get(&input.name) {
input.name = new_name.clone();
}
}
// loop through node outputs and rename them and store the new name <-> old name mapping
for output in node.outputs.iter_mut() {
let old_name = output.name.clone();
let new_name = format!("{}_out{}", node.name, counter[&node.name]);
output.name = new_name.clone();
old_names.insert(old_name, new_name);
}
}
// Rename the graph outputs
for output in outputs.iter_mut() {
if let Some(new_name) = old_names.get(&output.name) {
output.name = new_name.clone();
}
}
old_names
}
/// Find the node that produces the given output
fn lookup_node_by_output(nodes: &Vec<Node>, input: &str) -> Option<Node> {
for node in nodes.iter() {
if node.outputs.iter().any(|x| x.name == *input) {
return Some(node.clone());
}
}
None
}
/// Sort nodes in topological order
pub fn topsort(nodes: &Vec<Node>) -> TopologicalSort<Node> {
let mut ts = TopologicalSort::new();
for node in nodes.iter() {
for input in node.inputs.iter() {
match lookup_node_by_output(nodes, input.name.as_str()) {
Some(prec) => ts.add_dependency(prec, node.clone()),
None => {}
}
}
}
ts
}
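A minimal, runnable sketch of the `topological-sort` crate primitive this relies on (the node names are hypothetical): an edge is added from each producing node to each consuming node, and `pop` yields a node only after all of its predecessors have been yielded.

```rust
use topological_sort::TopologicalSort;

fn main() {
    let mut ts = TopologicalSort::<&str>::new();
    // conv1 feeds relu1, relu1 feeds flatten1 (hypothetical graph)
    ts.add_dependency("conv1", "relu1");
    ts.add_dependency("relu1", "flatten1");
    // Prints conv1, relu1, flatten1 in dependency order
    while let Some(node) = ts.pop() {
        println!("{node}");
    }
}
```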

burn-import/src/onnx/ir.rs

@ -0,0 +1,305 @@
use half::f16;
use std::collections::HashMap;
use strum_macros::{Display, EnumString};
pub type Shape = Vec<usize>;
#[derive(Debug, Clone)]
pub struct Argument {
pub name: String,
pub arg_type: Option<ArgType>,
}
#[derive(Debug, Clone)]
pub enum ArgType {
Tensor(Tensor),
}
#[derive(Debug, Clone)]
pub struct SparseTensor(Tensor, Tensor, Shape);
#[derive(Debug, Clone)]
pub enum AttributeValue {
Float32(f32),
Int64(i64),
String(String),
Tensor(Tensor),
SparseTensor(SparseTensor),
Float32s(Vec<f32>),
Int64s(Vec<i64>),
Strings(Vec<String>),
Tensors(Vec<Tensor>),
SparseTensors(Vec<SparseTensor>),
}
pub type Attributes = HashMap<String, AttributeValue>;
#[derive(Debug, Clone)]
pub enum ElementType {
Float32,
Float64,
Int32,
Int64,
String,
Float16,
}
#[derive(Debug, Clone)]
pub struct Tensor {
pub name: Option<String>,
pub elem_type: ElementType,
pub shape: Shape,
pub data: Option<TensorData>,
}
#[derive(Debug, Clone)]
pub enum TensorData {
Float16s(Vec<f16>),
Float32s(Vec<f32>),
Float64s(Vec<f64>),
Int32s(Vec<i32>),
Int64s(Vec<i64>),
Strings(Vec<String>),
}
#[derive(Debug, Clone)]
pub struct Graph {
pub nodes: Vec<Node>,
pub inputs: Vec<Argument>,
pub outputs: Vec<Argument>,
pub initializers: Vec<Argument>,
pub old_node_names: HashMap<String, String>,
pub old_input_names: HashMap<String, String>,
}
#[derive(Debug, Clone)]
pub struct Node {
pub node_type: NodeType,
pub name: String,
pub inputs: Vec<Argument>,
pub outputs: Vec<Argument>,
pub initializers: Vec<Argument>,
pub attrs: Attributes,
pub is_stateful: bool,
}
// Required by topological sort
impl PartialEq for Node {
fn eq(&self, other: &Self) -> bool {
self.name == other.name && self.node_type == other.node_type
}
}
// Required by topological sort
impl Eq for Node {}
// Required by topological sort
impl core::hash::Hash for Node {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name.hash(state);
self.node_type.hash(state);
self.inputs.hash(state);
self.outputs.hash(state);
}
}
// Required by topological sort
impl core::hash::Hash for Argument {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name.hash(state);
}
}
/// The list of supported node types (ONNX operators and some extra ones to map easily to Burn's ops)
#[derive(Debug, Hash, Eq, PartialEq, EnumString, Clone, Display)]
pub enum NodeType {
Abs,
Acos,
Acosh,
Add,
And,
ArgMax,
ArgMin,
Asin,
Asinh,
Atan,
Atanh,
AveragePool,
BatchNormalization,
Bernoulli,
BitShift,
BitwiseAnd,
BitwiseNot,
BitwiseOr,
BitwiseXor,
BlackmanWindow,
Cast,
CastLike,
Ceil,
Celu,
CenterCropPad,
Clip,
    Col2Im,
Compress,
Concat,
ConcatFromSequence,
Constant,
ConstantOfShape,
Conv,
Conv1d,
Conv2d,
ConvInteger,
ConvTranspose,
Cos,
Cosh,
CumSum,
DepthToSpace,
DequantizeLinear,
Det,
DFT,
Div,
Dropout,
DynamicQuantizeLinear,
Einsum,
Elu,
Equal,
Erf,
Exp,
Expand,
EyeLike,
Flatten,
Floor,
Gather,
GatherElements,
GatherND,
Gelu,
Gemm,
GlobalAveragePool,
GlobalLpPool,
GlobalMaxPool,
Greater,
GreaterOrEqual,
GridSample,
GroupNormalization,
GRU,
HammingWindow,
HannWindow,
Hardmax,
HardSigmoid,
HardSwish,
Identity,
If,
Im,
InstanceNormalization,
IsInf,
IsNaN,
LayerNormalization,
LeakyRelu,
Less,
LessOrEqual,
Linear,
Log,
LogSoftmax,
Loop,
LpNormalization,
LpPool,
LRN,
LSTM,
MatMul,
MatMulInteger,
Max,
MaxPool,
MaxRoiPool,
MaxUnpool,
Mean,
MeanVarianceNormalization,
MelWeightMatrix,
Min,
Mish,
Mod,
Mul,
Multinomial,
Neg,
NegativeLogLikelihoodLoss,
NonMaxSuppression,
NonZero,
Not,
OneHot,
Optional,
OptionalGetElement,
OptionalHasElement,
Or,
Pad,
Pow,
PRelu,
QLinearConv,
QLinearMatMul,
QuantizeLinear,
RandomNormal,
RandomNormalLike,
RandomUniform,
RandomUniformLike,
Range,
Reciprocal,
    ReduceL1,
    ReduceL2,
ReduceLogSum,
ReduceLogSumExp,
ReduceMax,
ReduceMean,
ReduceMin,
ReduceProd,
ReduceSum,
ReduceSumSquare,
Relu,
Reshape,
Resize,
ReverseSequence,
RNN,
RoiAlign,
Round,
Scan,
Scatter,
ScatterElements,
ScatterND,
Selu,
SequenceAt,
SequenceConstruct,
SequenceEmpty,
SequenceErase,
SequenceInsert,
SequenceLength,
SequenceMap,
Shape,
Shrink,
Sigmoid,
Sign,
Sin,
Sinh,
Size,
Slice,
Softmax,
SoftmaxCrossEntropyLoss,
Softplus,
Softsign,
SpaceToDepth,
Split,
SplitToSequence,
Sqrt,
Squeeze,
STFT,
StringNormalizer,
Sub,
Sum,
Tan,
Tanh,
TfIdfVectorizer,
ThresholdedRelu,
Tile,
TopK,
Transpose,
Trilu,
Unique,
Unsqueeze,
Upsample,
Where,
Xor,
}
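Because `NodeType` derives strum's `EnumString` and `Display`, ONNX `op_type` strings parse directly into variants, and the same names feed the identifier generation in `convert.rs`. A small sketch, assuming `NodeType` is in scope:

```rust
use std::str::FromStr;

fn main() {
    // EnumString: an ONNX op_type string parses directly into a variant
    let node_type = NodeType::from_str("Conv").unwrap();
    assert_eq!(node_type, NodeType::Conv);

    // Display: node renaming builds identifiers like "conv2d1" from the
    // variant name plus a per-type counter
    let name = format!("{}{}", NodeType::Conv2d, 1).to_lowercase();
    assert_eq!(name, "conv2d1");
}
```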


@ -0,0 +1,9 @@
mod coalesce;
mod codegen;
mod convert;
mod ir;
mod op_configuration;
mod protos;
mod shape_inference;
pub use codegen::*;


@ -0,0 +1,181 @@
use burn::nn::{
conv::{Conv2dConfig, Conv2dPaddingConfig},
LinearConfig,
};
use super::ir::{ArgType, AttributeValue, Node};
#[inline(always)]
pub fn attr_value_vec_i64(value: &AttributeValue, target: &mut Vec<i64>) {
if let AttributeValue::Int64s(val) = value {
*target = val.clone();
}
}
#[inline(always)]
pub fn attr_value_i64(value: &AttributeValue, target: &mut i64) {
if let AttributeValue::Int64(val) = value {
*target = *val;
}
}
/// Create a Conv2dConfig from the attributes of the node
pub fn conv2d_config(curr: &Node) -> Conv2dConfig {
let mut kernel_shape = Vec::new();
let mut strides = Vec::new();
let mut pads = Vec::new();
let mut dilations = Vec::new();
    let mut group: i64 = 1; // ONNX default when the attribute is absent
// extract the channels from the weight tensor's shape [out_channels, in_channels, ...]
let ArgType::Tensor(tensor) = curr.initializers.get(0).unwrap().clone().arg_type.unwrap();
// check if the bias is present
let bias = curr.initializers.len() == 2;
// the channels are inverted in the weight tensor
let channels: [usize; 2] = [tensor.shape[1], tensor.shape[0]];
for (key, value) in curr.attrs.iter() {
match key.as_str() {
"kernel_shape" => attr_value_vec_i64(value, &mut kernel_shape),
"strides" => attr_value_vec_i64(value, &mut strides),
"pads" => attr_value_vec_i64(value, &mut pads),
"dilations" => attr_value_vec_i64(value, &mut dilations),
"group" => attr_value_i64(value, &mut group),
_ => {}
}
}
let padding = if pads.iter().all(|&x| x == 0) {
Conv2dPaddingConfig::Valid
} else {
todo!("Conv2d: padding({pads:?}) is not fully supported");
};
    if strides.iter().any(|&x| x != 1) {
        todo!("Conv2d: strides({strides:?}) are not fully supported");
    }
    if dilations.iter().any(|&x| x != 1) {
        todo!("Conv2d: dilations({dilations:?}) are not fully supported");
    }
    if group != 1 {
        todo!("Conv2d: group ({group}) is not fully supported");
    }
Conv2dConfig::new(
channels,
[kernel_shape[0] as usize, kernel_shape[1] as usize],
)
.with_bias(bias)
.with_padding(padding)
}
/// Create a FlattenConfig from the attributes of the node
pub fn flatten_config(curr: &Node) -> (usize, usize) {
// the begin dimension is the first dimension (Default: 1 per ONNX spec)
let mut start_dim: i64 = 1;
// check if the node has only one input
if curr.inputs.len() != 1 {
panic!(
"Flatten: multiple inputs are not supported (got {:?})",
curr.inputs.len()
);
}
// extract the shape of the input tensor
let ArgType::Tensor(tensor) = curr.inputs.get(0).unwrap().clone().arg_type.unwrap();
// check if the input tensor has at least 2 dimensions
if tensor.shape.len() < 2 {
panic!(
"Flatten: input tensor must have at least 2 dimensions (got {:?})",
tensor.shape.len()
);
}
// the end dimension is the last dimension
let end_dim = tensor.shape.len() - 1;
// extract the attributes
for (key, value) in curr.attrs.iter() {
match key.as_str() {
"axis" => attr_value_i64(value, &mut start_dim),
_ => {}
}
}
    // if start_dim is negative, it is counted from the end
if start_dim < 0 {
start_dim += tensor.shape.len() as i64;
}
(start_dim as usize, end_dim)
}
/// Create a LinearConfig from the attributes of the node
pub fn linear_config(node: &Node) -> LinearConfig {
// check if the node has only one input
if node.inputs.len() != 1 {
panic!(
"Linear: multiple inputs are not supported (got {:?})",
node.inputs.len()
);
}
if node.initializers.is_empty() {
panic!("Linear: no initializers found");
}
// extract the shape of the weight tensor
let ArgType::Tensor(tensor) = node.initializers.get(0).unwrap().clone().arg_type.unwrap();
// check if the weight tensor has at least 2 dimensions
if tensor.shape.len() < 2 {
panic!(
"Linear: weight tensor must have at least 2 dimensions (got {:?})",
tensor.shape.len()
);
}
let (out_size, in_size) = (tensor.shape[0], tensor.shape[1]);
// check if the bias is present
let bias = node.initializers.len() == 2;
LinearConfig::new(in_size, out_size).with_bias(bias)
}
/// Create a log_softmax configuration (the softmax dimension) from the attributes of the node
pub fn log_softmax_config(node: &Node) -> usize {
// the axis is the last dimension (Default: 1 per ONNX spec)
let mut axis: i64 = -1;
// check if the node has only one input
if node.inputs.len() != 1 {
panic!(
"LogSoftmax: multiple inputs are not supported (got {:?})",
node.inputs.len()
);
}
// extract the shape of the input tensor
let ArgType::Tensor(tensor) = node.inputs.get(0).unwrap().clone().arg_type.unwrap();
// extract the attributes
for (key, value) in node.attrs.iter() {
match key.as_str() {
"axis" => attr_value_i64(value, &mut axis),
_ => {}
}
}
// if axis is negative, it is counted from the end
if axis < 0 {
axis += tensor.shape.len() as i64;
}
axis as usize
}
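To make the `flatten_config` mapping above concrete: for a 4-D input with the ONNX default `axis = 1`, the returned pair is `(1, 3)`, and the code generator emits a call equivalent to the hypothetical helper below.

```rust
use burn::tensor::{backend::Backend, Tensor};

// Hypothetical helper mirroring the generated code for an ONNX Flatten
// with axis = 1 on a 4-D tensor: dimensions 1..=3 collapse into one,
// e.g. a [1, 8, 26, 26] tensor becomes [1, 5408].
fn flatten_all_but_batch<B: Backend>(input: Tensor<B, 4>) -> Tensor<B, 2> {
    input.flatten(1, 3)
}
```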


@ -0,0 +1,5 @@
mod inner {
include!(concat!(env!("OUT_DIR"), "/onnx-protos/mod.rs"));
}
pub use inner::onnx::*;


@ -0,0 +1,825 @@
//
// WARNING: This file is automatically generated! Please edit onnx.in.proto.
//
// Copied from https://github.com/onnx/onnx/blob/main/onnx/onnx.proto3
// under the following license:
// SPDX-License-Identifier: Apache-2.0
syntax = "proto3";
package onnx;
// Overview
//
// ONNX is an open specification that is comprised of the following components:
//
// 1) A definition of an extensible computation graph model.
// 2) Definitions of standard data types.
// 3) Definitions of built-in operators.
//
// This document describes the syntax of models and their computation graphs,
// as well as the standard data types. Together, they are referred to as the ONNX
// Intermediate Representation, or 'IR' for short.
//
// The normative semantic specification of the ONNX IR is found in docs/IR.md.
// Definitions of the built-in neural network operators may be found in docs/Operators.md.
// Notes
//
// Protobuf compatibility
//
// To simplify framework compatibility, ONNX is defined using the subset of protobuf
// that is compatible with both protobuf v2 and v3. This means that we do not use any
// protobuf features that are only available in one of the two versions.
//
// Here are the most notable contortions we have to carry out to work around
// these limitations:
//
// - No 'map' (added protobuf 3.0). We instead represent mappings as lists
// of key-value pairs, where order does not matter and duplicates
// are not allowed.
// Versioning
//
// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md
//
// To be compatible with both proto2 and proto3, we will use a version number
// that is not defined by the default value but an explicit enum number.
enum Version {
// proto3 requires the first enum value to be zero.
// We add this just to appease the compiler.
_START_VERSION = 0;
// The version field is always serialized and we will use it to store the
// version that the graph is generated from. This helps us set up version
// control.
// For the IR, we are using simple numbers starting with 0x00000001,
// which was the version we published on Oct 10, 2017.
IR_VERSION_2017_10_10 = 0x0000000000000001;
// IR_VERSION 2 published on Oct 30, 2017
// - Added type discriminator to AttributeProto to support proto3 users
IR_VERSION_2017_10_30 = 0x0000000000000002;
// IR VERSION 3 published on Nov 3, 2017
// - For operator versioning:
// - Added new message OperatorSetIdProto
// - Added opset_import in ModelProto
// - For vendor extensions, added domain in NodeProto
IR_VERSION_2017_11_3 = 0x0000000000000003;
// IR VERSION 4 published on Jan 22, 2019
// - Relax constraint that initializers should be a subset of graph inputs
// - Add type BFLOAT16
IR_VERSION_2019_1_22 = 0x0000000000000004;
// IR VERSION 5 published on March 18, 2019
// - Add message TensorAnnotation.
// - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters.
IR_VERSION_2019_3_18 = 0x0000000000000005;
// IR VERSION 6 published on Sep 19, 2019
// - Add support for sparse tensor constants stored in model.
// - Add message SparseTensorProto
// - Add sparse initializers
IR_VERSION_2019_9_19 = 0x0000000000000006;
// IR VERSION 7 published on May 8, 2020
// - Add support to allow function body graph to rely on multiple external operator sets.
// - Add a list to promote inference graph's initializers to global and
// mutable variables. Global variables are visible in all graphs of the
// stored models.
// - Add message TrainingInfoProto to store initialization
// method and training algorithm. The execution of TrainingInfoProto
// can modify the values of mutable variables.
// - Implicitly add inference graph into each TrainingInfoProto's algorithm.
IR_VERSION_2020_5_8 = 0x0000000000000007;
// IR VERSION 8 published on July 30, 2021
// Introduce TypeProto.SparseTensor
// Introduce TypeProto.Optional
// Added a list of FunctionProtos local to the model
// Deprecated since_version and operator status from FunctionProto
IR_VERSION_2021_7_30 = 0x0000000000000008;
// IR VERSION 9 published on TBD
// Added AttributeProto to FunctionProto so that default attribute values can be set.
IR_VERSION = 0x0000000000000009;
}
// Attributes
//
// A named attribute containing either singular float, integer, string, graph,
// and tensor values, or repeated float, integer, string, graph, and tensor values.
// An AttributeProto MUST contain the name field, and *only one* of the
// following content fields, effectively enforcing a C/C++ union equivalent.
message AttributeProto {
// Note: this enum is structurally identical to the OpSchema::AttrType
// enum defined in schema.h. If you rev one, you likely need to rev the other.
enum AttributeType {
UNDEFINED = 0;
FLOAT = 1;
INT = 2;
STRING = 3;
TENSOR = 4;
GRAPH = 5;
SPARSE_TENSOR = 11;
TYPE_PROTO = 13;
FLOATS = 6;
INTS = 7;
STRINGS = 8;
TENSORS = 9;
GRAPHS = 10;
SPARSE_TENSORS = 12;
TYPE_PROTOS = 14;
}
// The name field MUST be present for this version of the IR.
string name = 1; // namespace Attribute
// if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
// In this case, this AttributeProto does not contain data, and it's a reference of attribute
// in parent scope.
// NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph.
string ref_attr_name = 21;
// A human-readable documentation for this attribute. Markdown is allowed.
string doc_string = 13;
// The type field MUST be present for this version of the IR.
// For 0.0.1 versions of the IR, this field was not defined, and
// implementations needed to use has_field heuristics to determine
// which value field was in use. For IR_VERSION 0.0.2 or later, this
// field MUST be set and match the f|i|s|t|... field in use. This
// change was made to accommodate proto3 implementations.
AttributeType type = 20; // discriminator that indicates which field below is in use
// Exactly ONE of the following fields must be present for this version of the IR
float f = 2; // float
int64 i = 3; // int
bytes s = 4; // UTF-8 string
TensorProto t = 5; // tensor value
GraphProto g = 6; // graph
SparseTensorProto sparse_tensor = 22; // sparse tensor value
// Do not use field below, it's deprecated.
// optional ValueProto v = 12; // value - subsumes everything but graph
TypeProto tp = 14; // type proto
repeated float floats = 7; // list of floats
repeated int64 ints = 8; // list of ints
repeated bytes strings = 9; // list of UTF-8 strings
repeated TensorProto tensors = 10; // list of tensors
repeated GraphProto graphs = 11; // list of graph
repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors
repeated TypeProto type_protos = 15;// list of type protos
}
// Defines information on value, including the name, the type, and
// the shape of the value.
message ValueInfoProto {
// This field MUST be present in this version of the IR.
string name = 1; // namespace Value
// This field MUST be present in this version of the IR for
// inputs and outputs of the top-level graph.
TypeProto type = 2;
// A human-readable documentation for this value. Markdown is allowed.
string doc_string = 3;
}
// Nodes
//
// Computation graphs are made up of a DAG of nodes, which represent what is
// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
//
// For example, it can be a node of type "Conv" that takes in an image, a filter
// tensor and a bias tensor, and produces the convolved output.
message NodeProto {
repeated string input = 1; // namespace Value
repeated string output = 2; // namespace Value
// An optional identifier for this node in a graph.
// This field MAY be absent in this version of the IR.
string name = 3; // namespace Node
// The symbolic identifier of the Operator to execute.
string op_type = 4; // namespace Operator
// The domain of the OperatorSet that specifies the operator named by op_type.
string domain = 7; // namespace Domain
// Additional named attributes.
repeated AttributeProto attribute = 5;
// A human-readable documentation for this node. Markdown is allowed.
string doc_string = 6;
}
// Training information
// TrainingInfoProto stores information for training a model.
// In particular, this defines two functionalities: an initialization-step
// and a training-algorithm-step. Initialization resets the model
// back to its original state as if no training has been performed.
// Training algorithm improves the model based on input data.
//
// The semantics of the initialization-step is that the initializers
// in ModelProto.graph and in TrainingInfoProto.algorithm are first
// initialized as specified by the initializers in the graph, and then
// updated by the "initialization_binding" in every instance in
// ModelProto.training_info.
//
// The field "algorithm" defines a computation graph which represents a
// training algorithm's step. After the execution of a
// TrainingInfoProto.algorithm, the initializers specified by "update_binding"
// may be immediately updated. If the targeted training algorithm contains
// consecutive update steps (such as block coordinate descent methods),
// the user needs to create a TrainingInfoProto for each step.
message TrainingInfoProto {
// This field describes a graph to compute the initial tensors
// upon starting the training process. Initialization graph has no input
// and can have multiple outputs. Usually, trainable tensors in neural
// networks are randomly initialized. To achieve that, for each tensor,
// the user can put a random number operator such as RandomNormal or
// RandomUniform in TrainingInfoProto.initialization.node and assign its
// random output to the specific tensor using "initialization_binding".
// This graph can also set the initializers in "algorithm" in the same
// TrainingInfoProto; a use case is resetting the number of training
// iteration to zero.
//
// By default, this field is an empty graph and its evaluation does not
// produce any output. Thus, no initializer would be changed by default.
GraphProto initialization = 1;
// This field represents a training algorithm step. Given required inputs,
// it computes outputs to update initializers in its own or inference graph's
// initializer lists. In general, this field contains loss node, gradient node,
// optimizer node, increment of iteration count.
//
// An execution of the training algorithm step is performed by executing the
// graph obtained by combining the inference graph (namely "ModelProto.graph")
// and the "algorithm" graph. That is, the actual the actual
// input/initializer/output/node/value_info/sparse_initializer list of
// the training graph is the concatenation of
// "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer"
// and "algorithm.input/initializer/output/node/value_info/sparse_initializer"
// in that order. This combined graph must satisfy the normal ONNX conditions.
// Now, let's provide a visualization of graph combination for clarity.
// Let the inference graph (i.e., "ModelProto.graph") be
// tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d
// and the "algorithm" graph be
// tensor_d -> Add -> tensor_e
// The combination process results in
// tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e
//
// Notice that an input of a node in the "algorithm" graph may reference the
// output of a node in the inference graph (but not the other way round). Also, inference
// node cannot reference inputs of "algorithm". With these restrictions, inference graph
// can always be run independently without training information.
//
// By default, this field is an empty graph and its evaluation does not
// produce any output. Evaluating the default training step never
// updates any initializers.
GraphProto algorithm = 2;
// This field specifies the bindings from the outputs of "initialization" to
// some initializers in "ModelProto.graph.initializer" and
// the "algorithm.initializer" in the same TrainingInfoProto.
// See "update_binding" below for details.
//
// By default, this field is empty and no initializer would be changed
// by the execution of "initialization".
repeated StringStringEntryProto initialization_binding = 3;
// Gradient-based training is usually an iterative procedure. In one gradient
// descent iteration, we apply
//
// x = x - r * g
//
// where "x" is the optimized tensor, "r" stands for learning rate, and "g" is
// gradient of "x" with respect to a chosen loss. To avoid adding assignments
// into the training graph, we split the update equation into
//
// y = x - r * g
// x = y
//
// The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To
// tell that "y" should be assigned to "x", the field "update_binding" may
// contain a key-value pair of strings, "x" (key of StringStringEntryProto)
// and "y" (value of StringStringEntryProto).
// For a neural network with multiple trainable (mutable) tensors, there can
// be multiple key-value pairs in "update_binding".
//
// The initializers that appear as keys in "update_binding" are considered
// mutable variables. This implies some behaviors
// as described below.
//
// 1. We have only unique keys in all "update_binding"s so that two
// variables may not have the same name. This ensures that one
// variable is assigned up to once.
// 2. The keys must appear in names of "ModelProto.graph.initializer" or
// "TrainingInfoProto.algorithm.initializer".
// 3. The values must be output names of "algorithm" or "ModelProto.graph.output".
// 4. Mutable variables are initialized to the value specified by the
// corresponding initializer, and then potentially updated by
// "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s.
//
// This field usually contains names of trainable tensors
// (in ModelProto.graph), optimizer states such as momentums in advanced
// stochastic gradient methods (in TrainingInfoProto.graph),
// and number of training iterations (in TrainingInfoProto.graph).
//
// By default, this field is empty and no initializer would be changed
// by the execution of "algorithm".
repeated StringStringEntryProto update_binding = 4;
}
// Models
//
// ModelProto is a top-level file/container format for bundling a ML model and
// associating its computation graph with metadata.
//
// The semantics of the model are described by the associated GraphProto's.
message ModelProto {
// The version of the IR this model targets. See Version enum above.
// This field MUST be present.
int64 ir_version = 1;
// The OperatorSets this model relies on.
// All ModelProtos MUST have at least one entry that
// specifies which version of the ONNX OperatorSet is
// being imported.
//
// All nodes in the ModelProto's graph will bind against the operator
// with the same-domain/same-op_type operator with the HIGHEST version
// in the referenced operator sets.
repeated OperatorSetIdProto opset_import = 8;
// The name of the framework or tool used to generate this model.
// This field SHOULD be present to indicate which implementation/tool/framework
// emitted the model.
string producer_name = 2;
// The version of the framework or tool used to generate this model.
// This field SHOULD be present to indicate which implementation/tool/framework
// emitted the model.
string producer_version = 3;
// Domain name of the model.
// We use reverse domain names as name space indicators. For example:
// `com.facebook.fair` or `com.microsoft.cognitiveservices`
//
// Together with `model_version` and GraphProto.name, this forms the unique identity of
// the graph.
string domain = 4;
// The version of the graph encoded. See Version enum below.
int64 model_version = 5;
// A human-readable documentation for this model. Markdown is allowed.
string doc_string = 6;
// The parameterized graph that is evaluated to execute the model.
GraphProto graph = 7;
// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 14;
// Training-specific information. Sequentially executing all stored
// `TrainingInfoProto.algorithm`s and assigning their outputs following
// the corresponding `TrainingInfoProto.update_binding`s is one training
// iteration. Similarly, to initialize the model
// (as if training hasn't happened), the user should sequentially execute
// all stored `TrainingInfoProto.initialization`s and assigns their outputs
// using `TrainingInfoProto.initialization_binding`s.
//
// If this field is empty, the training behavior of the model is undefined.
repeated TrainingInfoProto training_info = 20;
// A list of function protos local to the model.
//
// Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain".
// In case of any conflicts, the behavior (whether the model-local functions are given higher priority,
// or standard operator sets are given higher priority, or this is treated as an error) is defined by
// the runtimes.
//
// The operator sets imported by FunctionProto should be compatible with the ones
// imported by ModelProto and other model local FunctionProtos.
// Example, if same operator set say 'A' is imported by a FunctionProto and ModelProto
// or by 2 FunctionProtos then versions for the operator set may be different but,
// the operator schema returned for op_type, domain, version combination
// for both the versions should be same for every node in the function body.
//
// One FunctionProto can reference other FunctionProto in the model, however, recursive reference
// is not allowed.
repeated FunctionProto functions = 25;
};
// StringStringEntryProto follows the pattern for cross-proto-version maps.
// See https://developers.google.com/protocol-buffers/docs/proto3#maps
message StringStringEntryProto {
string key = 1;
string value = 2;
};
message TensorAnnotation {
string tensor_name = 1;
// <key, value> pairs to annotate tensor specified by <tensor_name> above.
// The keys used in the mapping below must be pre-defined in ONNX spec.
// For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as
// quantization parameter keys.
repeated StringStringEntryProto quant_parameter_tensor_names = 2;
}
// Graphs
//
// A graph defines the computational logic of a model and is comprised of a parameterized
// list of nodes that form a directed acyclic graph based on their inputs and outputs.
// This is the equivalent of the "network" or "graph" in many deep learning
// frameworks.
message GraphProto {
// The nodes in the graph, sorted topologically.
repeated NodeProto node = 1;
// The name of the graph.
string name = 2; // namespace Graph
// A list of named tensor values, used to specify constant inputs of the graph.
// Each initializer (both TensorProto as well SparseTensorProto) MUST have a name.
// The name MUST be unique across both initializer and sparse_initializer,
// but the name MAY also appear in the input list.
repeated TensorProto initializer = 5;
// Initializers (see above) stored in sparse format.
repeated SparseTensorProto sparse_initializer = 15;
// A human-readable documentation for this graph. Markdown is allowed.
string doc_string = 10;
// The inputs and outputs of the graph.
repeated ValueInfoProto input = 11;
repeated ValueInfoProto output = 12;
// Information for the values in the graph. The ValueInfoProto.name's
// must be distinct. It is optional for a value to appear in value_info list.
repeated ValueInfoProto value_info = 13;
// This field carries information to indicate the mapping among a tensor and its
// quantization parameter tensors. For example:
// For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated,
// which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
repeated TensorAnnotation quantization_annotation = 14;
reserved 3, 4, 6 to 9;
reserved "ir_version", "producer_version", "producer_tag", "domain";
}
// Tensors
//
// A serialized tensor value.
message TensorProto {
enum DataType {
UNDEFINED = 0;
// Basic types.
FLOAT = 1; // float
UINT8 = 2; // uint8_t
INT8 = 3; // int8_t
UINT16 = 4; // uint16_t
INT16 = 5; // int16_t
INT32 = 6; // int32_t
INT64 = 7; // int64_t
STRING = 8; // string
BOOL = 9; // bool
// IEEE754 half-precision floating-point format (16 bits wide).
// This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
FLOAT16 = 10;
DOUBLE = 11;
UINT32 = 12;
UINT64 = 13;
COMPLEX64 = 14; // complex with float32 real and imaginary components
COMPLEX128 = 15; // complex with float64 real and imaginary components
// Non-IEEE floating-point format based on IEEE754 single-precision
// floating-point number truncated to 16 bits.
// This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
BFLOAT16 = 16;
// Future extensions go here.
}
// The shape of the tensor.
repeated int64 dims = 1;
// The data type of the tensor.
// This field MUST have a valid TensorProto.DataType value
int32 data_type = 2;
// For very large tensors, we may want to store them in chunks, in which
// case the following fields will specify the segment that is stored in
// the current TensorProto.
message Segment {
int64 begin = 1;
int64 end = 2;
}
Segment segment = 3;
// Tensor content must be organized in row-major order.
//
// Depending on the data_type field, exactly one of the fields below with
// name ending in _data is used to store the elements of the tensor.
// For float and complex64 values
// Complex64 tensors are encoded as a single array of floats,
// with the real components appearing in odd numbered positions,
// and the corresponding imaginary component appearing in the
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
// is encoded as [1.0, 2.0, 3.0, 4.0]).
// When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
repeated float float_data = 4 [packed = true];
// For int32, uint8, int8, uint16, int16, bool, and float16 values
// float16 values must be bit-wise converted to an uint16_t prior
// to writing to the buffer.
// When this field is present, the data_type field MUST be
// INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16 or BFLOAT16
repeated int32 int32_data = 5 [packed = true];
// For strings.
// Each element of string_data is a UTF-8 encoded Unicode
// string. No trailing null, no leading BOM. The protobuf "string"
// scalar type is not used to match ML community conventions.
// When this field is present, the data_type field MUST be STRING
repeated bytes string_data = 6;
// For int64.
// When this field is present, the data_type field MUST be INT64
repeated int64 int64_data = 7 [packed = true];
// Optionally, a name for the tensor.
string name = 8; // namespace Value
// A human-readable documentation for this tensor. Markdown is allowed.
string doc_string = 12;
// Serializations can either use one of the fields above, or use this
// raw bytes field. The only exception is the string case, where one is
// required to store the content in the repeated bytes string_data field.
//
// When this raw_data field is used to store tensor value, elements MUST
// be stored in as fixed-width, little-endian order.
// Floating-point data types MUST be stored in IEEE 754 format.
// Complex64 elements must be written as two consecutive FLOAT values, real component first.
// Complex128 elements must be written as two consecutive DOUBLE values, real component first.
// Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
//
// Note: the advantage of specific field rather than the raw_data field is
// that in some cases (e.g. int data), protobuf does a better packing via
// variable length storage, and may lead to smaller binary footprint.
// When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
bytes raw_data = 9;
// Data can be stored inside the protobuf file using type-specific fields or raw_data.
// Alternatively, raw bytes data can be stored in an external file, using the external_data field.
// external_data stores key-value pairs describing data location. Recognized keys are:
// - "location" (required) - POSIX filesystem path relative to the directory where the ONNX
// protobuf model was stored
// - "offset" (optional) - position of byte at which stored data begins. Integer stored as string.
// Offset values SHOULD be multiples of 4096 (page size) to enable mmap support.
// - "length" (optional) - number of bytes containing data. Integer stored as string.
// - "checksum" (optional) - SHA1 digest of file specified in under 'location' key.
repeated StringStringEntryProto external_data = 13;
// Location of the data for this tensor. MUST be one of:
// - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field.
// - EXTERNAL - data stored in an external location as described by external_data field.
enum DataLocation {
DEFAULT = 0;
EXTERNAL = 1;
}
// If value not set, data is stored in raw_data (if set) otherwise in type-specified field.
DataLocation data_location = 14;
// For double
// Complex128 tensors are encoded as a single array of doubles,
// with the real components appearing in odd numbered positions,
// and the corresponding imaginary component appearing in the
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
// is encoded as [1.0, 2.0, 3.0, 4.0]).
// When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
repeated double double_data = 10 [packed = true];
// For uint64 and uint32 values
// When this field is present, the data_type field MUST be
// UINT32 or UINT64
repeated uint64 uint64_data = 11 [packed = true];
}
// A serialized sparse-tensor value
message SparseTensorProto {
// The sequence of non-default values are encoded as a tensor of shape [NNZ].
// The default-value is zero for numeric tensors, and empty-string for string tensors.
// values must have a non-empty name present which serves as a name for SparseTensorProto
// when used in sparse_initializer list.
TensorProto values = 1;
// The indices of the non-default values, which may be stored in one of two formats.
// (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value
// corresponding to the j-th index of the i-th value (in the values tensor).
// (b) Indices can be a tensor of shape [NNZ], in which case the i-th value
// must be the linearized-index of the i-th value (in the values tensor).
// The linearized-index can be converted into an index tuple (k_1,...,k_rank)
// using the shape provided below.
// The indices must appear in ascending order without duplication.
// In the first format, the ordering is lexicographic-ordering:
// e.g., index-value [1,4] must appear before [2,1]
TensorProto indices = 2;
// The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank]
repeated int64 dims = 3;
}
// Defines a tensor shape. A dimension can be either an integer value
// or a symbolic variable. A symbolic variable represents an unknown
// dimension.
message TensorShapeProto {
message Dimension {
oneof value {
int64 dim_value = 1;
string dim_param = 2; // namespace Shape
};
// Standard denotation can optionally be used to denote tensor
// dimensions with standard semantic descriptions to ensure
// that operations are applied to the correct axis of a tensor.
// Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
// for pre-defined dimension denotations.
string denotation = 3;
};
repeated Dimension dim = 1;
}
// Types
//
// The standard ONNX data types.
message TypeProto {
message Tensor {
// This field MUST NOT have the value of UNDEFINED
// This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
int32 elem_type = 1;
TensorShapeProto shape = 2;
}
// repeated T
message Sequence {
// The type and optional shape of each element of the sequence.
// This field MUST be present for this version of the IR.
TypeProto elem_type = 1;
};
// map<K,V>
message Map {
// This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
// This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING
int32 key_type = 1;
// This field MUST be present for this version of the IR.
TypeProto value_type = 2;
};
// wrapper for Tensor, Sequence, or Map
message Optional {
// The type and optional shape of the element wrapped.
// This field MUST be present for this version of the IR.
// Possible values correspond to OptionalProto.DataType enum
TypeProto elem_type = 1;
};
message SparseTensor {
// This field MUST NOT have the value of UNDEFINED
// This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
int32 elem_type = 1;
TensorShapeProto shape = 2;
}
oneof value {
// The type of a tensor.
Tensor tensor_type = 1;
// NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values
// as input and output to graphs and nodes. These types are needed to naturally
// support classical ML operators. DNN operators SHOULD restrict their input
// and output types to tensors.
// The type of a sequence.
Sequence sequence_type = 4;
// The type of a map.
Map map_type = 5;
// The type of an optional.
Optional optional_type = 9;
// Type of the sparse tensor
SparseTensor sparse_tensor_type = 8;
}
// An optional denotation can be used to denote the whole
// type with a standard semantic description as to what is
// stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition
// for pre-defined type denotations.
string denotation = 6;
}
// Operator Sets
//
// OperatorSets are uniquely identified by a (domain, opset_version) pair.
message OperatorSetIdProto {
// The domain of the operator set being identified.
// The empty string ("") or absence of this field implies the operator
// set that is defined as part of the ONNX specification.
// This field MUST be present in this version of the IR when referring to any other operator set.
string domain = 1;
// The version of the operator set being identified.
// This field MUST be present in this version of the IR.
int64 version = 2;
}
// Operator/function status.
enum OperatorStatus {
EXPERIMENTAL = 0;
STABLE = 1;
}
message FunctionProto {
// The name of the function, similar usage of op_type in OperatorProto.
// Combined with FunctionProto.domain, this forms the unique identity of
// the FunctionProto.
string name = 1;
// Deprecated since IR Version 8
// optional int64 since_version = 2;
reserved 2;
reserved "since_version";
// Deprecated since IR Version 8
// optional OperatorStatus status = 3;
reserved 3;
reserved "status";
// The inputs and outputs of the function.
repeated string input = 4;
repeated string output = 5;
// The attribute parameters of the function.
// It is for function parameters without default values.
repeated string attribute = 6;
// The attribute protos of the function.
// It is for function attributes with default values.
// A function attribute shall be represented either as
// a string attribute or an AttributeProto, not both.
repeated AttributeProto attribute_proto = 11;
// The nodes in the function.
repeated NodeProto node = 7;
// A human-readable documentation for this function. Markdown is allowed.
string doc_string = 8;
// The OperatorSets this function body (graph) relies on.
//
// All nodes in the function body (graph) will bind against the operator
// with the same-domain/same-op_type operator with the HIGHEST version
// in the referenced operator sets. This means at most one version can be relied
// upon for one domain.
//
// The operator sets imported by FunctionProto should be compatible with the ones
// imported by ModelProto. Example, if same operator set say 'A' is imported by FunctionProto
// and ModelProto then versions for the operator set may be different but,
// the operator schema returned for op_type, domain, version combination
// for both the versions should be same.
repeated OperatorSetIdProto opset_import = 9;
// The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
// the FunctionProto.
string domain = 10;
}
// For using protobuf-lite
option optimize_for = LITE_RUNTIME;
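
Not part of the commit, but a rough sketch of how the bindings generated from this proto might be consumed, assuming the `protobuf` crate's generated API (`Message::parse_from_bytes`) and the `onnx` module re-exported earlier; the import path and function name are illustrative:

// Hedged sketch: parse an ONNX file into the generated ModelProto type.
use protobuf::Message;
use crate::onnx::ModelProto; // assumed path to the generated bindings above

fn load_model(path: &str) -> ModelProto {
    // Read the serialized model and decode it with the generated protobuf code.
    let bytes = std::fs::read(path).expect("failed to read the ONNX file");
    ModelProto::parse_from_bytes(&bytes).expect("failed to parse the ONNX model")
}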

View File

@ -0,0 +1,187 @@
use std::collections::HashMap;
use burn::tensor;
use burn_ndarray::NdArrayBackend;
use super::{
ir::{ArgType, Argument, Node, NodeType, Tensor},
op_configuration::{conv2d_config, flatten_config, linear_config},
};
/// Infer the shape of each node and replace the shape of the output tensor
pub fn shape_inference(
nodes: &mut Vec<Node>,
graph_inputs: &Vec<Argument>,
graph_outputs: &mut Vec<Argument>,
) {
let mut prev_outputs: HashMap<String, Argument> = HashMap::new();
for input in graph_inputs.iter() {
prev_outputs.insert(input.name.clone(), input.clone());
}
for node in nodes.iter_mut() {
match node.node_type {
NodeType::Conv2d => conv2d(node, &prev_outputs),
NodeType::Linear => linear(node, &prev_outputs),
NodeType::Relu => relu(node, &prev_outputs),
NodeType::Flatten => flatten(node, &prev_outputs),
NodeType::LogSoftmax => log_softmax(node, &prev_outputs),
_ => todo!(
"shape inference for {:?} is not implemented",
node.node_type
),
}
for output in node.outputs.iter() {
prev_outputs.insert(output.name.clone(), output.clone());
}
}
// Update the outputs of the graph from prev_outputs
for output in graph_outputs.iter_mut() {
let arg = prev_outputs.get(output.name.as_str()).unwrap();
output.arg_type = arg.arg_type.clone();
}
}
/// Infers the shape of a Linear node and replaces the shape of the output tensor.
fn linear(curr: &mut Node, prev_outputs: &HashMap<String, Argument>) {
if curr.inputs.len() != 1 {
panic!("Linear: multiple inputs are not supported");
}
// Fill in the missing information about the input tensor from the previous outputs
let prev_node_output = prev_outputs.get(curr.inputs[0].name.as_str()).unwrap();
curr.inputs[0].arg_type = prev_node_output.arg_type.clone();
// Extract the configuration of the linear layer (inputs are known)
let config = linear_config(curr);
// Replace the output tensor
let curr_input = &mut curr.inputs[0];
let ArgType::Tensor(tensor) = curr_input.clone().arg_type.unwrap();
let mut new_shape = tensor.shape.clone();
// Update the last dimension of the shape to the layer's output size
new_shape[tensor.shape.len() - 1] = config.d_output;
// Update the output tensor
curr.outputs[0].arg_type = Some(ArgType::Tensor(Tensor {
name: None,
shape: new_shape,
data: None,
elem_type: tensor.elem_type,
}));
}
/// Infers the shape of a Relu node and replaces the shape of the output tensor.
fn relu(curr: &mut Node, prev_outputs: &HashMap<String, Argument>) {
if curr.inputs.len() != 1 {
panic!("Relu: multiple inputs are not supported");
}
// Fill in the missing information about the input tensor from the previous outputs
let prev_node_output = prev_outputs.get(curr.inputs[0].name.as_str()).unwrap();
curr.inputs[0].arg_type = prev_node_output.arg_type.clone();
curr.outputs[0].arg_type = prev_node_output.arg_type.clone();
}
/// Infers the shape of a Flatten node and replaces the shape of the output tensor.
fn flatten(curr: &mut Node, prev_outputs: &HashMap<String, Argument>) {
if curr.inputs.len() != 1 {
panic!("Flatten: multiple inputs are not supported");
}
// Fill in the missing information about the input tensor from the previous outputs
let prev_node_output = prev_outputs.get(curr.inputs[0].name.as_str()).unwrap();
curr.inputs[0].arg_type = prev_node_output.arg_type.clone();
let curr_input = &mut curr.inputs[0];
let ArgType::Tensor(tensor) = curr_input.clone().arg_type.unwrap();
let input_shape = tensor.shape;
let (start_dim, end_dim) = flatten_config(curr);
// Calculate the new shape with the same logic as the flatten op.
// The output tensor's rank is not known at compile time, so the
// new shape has to be computed here during shape inference.
let mut new_dims = vec![0; input_shape.len() - (end_dim - start_dim)];
let mut flatten_dims = 1;
for i in input_shape[start_dim..=end_dim].iter() {
flatten_dims *= i;
}
new_dims[..start_dim].copy_from_slice(&input_shape[..start_dim]);
new_dims[start_dim] = flatten_dims;
new_dims[start_dim + 1..].copy_from_slice(&input_shape[end_dim + 1..]);
curr.outputs[0].arg_type = Some(ArgType::Tensor(Tensor {
name: None,
shape: new_dims,
data: None,
elem_type: tensor.elem_type,
}));
}
/// Infers the shape of a LogSoftmax node and replaces the shape of the output tensor.
fn log_softmax(curr: &mut Node, prev_outputs: &HashMap<String, Argument>) {
if curr.inputs.len() != 1 {
panic!("LogSoftmax: multiple inputs are not supported");
}
// Fill in the missing information about the input tensor from the previous outputs
let prev_node_output = prev_outputs.get(curr.inputs[0].name.as_str()).unwrap();
curr.inputs[0].arg_type = prev_node_output.arg_type.clone();
curr.outputs[0].arg_type = prev_node_output.arg_type.clone();
}
/// Infers the shape of a Conv2d node and replaces the shape of the output tensor.
///
/// The shape of the output tensor is calculated by running the actual convolution operation.
fn conv2d(curr: &mut Node, prev_outputs: &HashMap<String, Argument>) {
if curr.inputs.len() != 1 {
panic!("Conv2d: multiple inputs are not supported");
}
// Fill in the missing information about the input tensor from the previous outputs
let curr_input = &mut curr.inputs[0];
let prev = prev_outputs.get(curr_input.name.as_str()).unwrap();
curr_input.arg_type = prev.arg_type.clone();
// extract the element type and shape of the input tensor
let ArgType::Tensor(tensor) = curr_input.clone().arg_type.unwrap();
let elem_type = tensor.elem_type;
if tensor.shape.len() != 4 {
panic!("Conv2d: input tensor must be 4D");
}
let mut input_shape: [usize; 4] = [0; 4];
input_shape.copy_from_slice(tensor.shape.as_slice());
// using the real configuration, run through op and calculate an actual shape of the output tensor
let config = conv2d_config(curr);
let conv2d = config.init();
let input = tensor::Tensor::<NdArrayBackend<f32>, 4>::zeros(input_shape);
let output = conv2d.forward(input);
let output_shape = output.shape().dims.to_vec();
curr.outputs[0].arg_type = Some(ArgType::Tensor(Tensor {
name: None,
shape: output_shape,
data: None,
elem_type,
}));
}
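
For a quick sanity check of the flatten shape math above, here is a standalone sketch (plain Rust, no Burn dependency) that mirrors the computation; the [1, 24, 22, 22] shape matches the activation flattened by the MNIST model later in this commit:

// Standalone mirror of the flatten shape computation in `flatten` above.
fn flatten_shape(input_shape: &[usize], start_dim: usize, end_dim: usize) -> Vec<usize> {
    let mut new_dims = vec![0; input_shape.len() - (end_dim - start_dim)];
    let flattened: usize = input_shape[start_dim..=end_dim].iter().product();
    new_dims[..start_dim].copy_from_slice(&input_shape[..start_dim]);
    new_dims[start_dim] = flattened;
    new_dims[start_dim + 1..].copy_from_slice(&input_shape[end_dim + 1..]);
    new_dims
}

fn main() {
    // Flattening dims 1..=3 of [1, 24, 22, 22]: 24 * 22 * 22 = 11616.
    assert_eq!(flatten_shape(&[1, 24, 22, 22], 1, 3), vec![1, 11616]);
}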

Binary file not shown.

View File

@ -12,6 +12,6 @@ edition = "2021"
proc-macro = true
[dependencies]
syn = "1.0.109"
quote = "1.0.26"
proc-macro2 = "1.0.52"
syn = {workspace = true}
quote = {workspace = true}
proc-macro2 = {workspace = true}

View File

@ -0,0 +1,22 @@
[package]
authors = ["Dilshod Tadjibaev (@antimora)"]
edition = "2021"
license = "MIT/Apache-2.0"
name = "onnx-inference"
publish = false
version = "0.6.0"
[features]
default = []
[dependencies]
burn = {path = "../../burn"}
burn-ndarray = {path = "../../burn-ndarray"}
serde = {workspace = true}
[dev-dependencies]
burn-dataset = {path = "../../burn-dataset"}
[build-dependencies]
burn-import = {path = "../../burn-import"}

View File

@ -0,0 +1,4 @@
# ONNX Inference
This crate provides a simple example of importing an ONNX model into Burn.

View File

@ -0,0 +1,9 @@
use burn_import::onnx::ModelCodeGen;
fn main() {
// Generate the model code from the ONNX file.
ModelCodeGen::new()
.input("src/model/mnist.onnx")
.out_dir("model/")
.run_from_script();
}
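
The generated Rust source is written under `OUT_DIR` (here `model/mnist.rs`) and is pulled into the crate with an `include!`, as shown in `src/model/mod.rs` at the end of this commit.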

View File

@ -0,0 +1,155 @@
# Originally copied and modified from: https://github.com/pytorch/examples/blob/main/mnist/main.py
# under the following license: BSD-3-Clause license
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 8, 3)
self.conv2 = nn.Conv2d(8, 16, 3)
self.conv3 = nn.Conv2d(16, 24, 3)
self.dropout1 = nn.Dropout(0.3)
self.fc1 = nn.Linear(24 * 22 * 22, 32)
self.fc2 = nn.Linear(32, 10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = self.conv3(x)
x = F.relu(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout1(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
if args.dry_run:
break
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
def main():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=14, metavar='N',
help='number of epochs to train (default: 14)')
parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
help='learning rate (default: 1.0)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--no-mps', action='store_true', default=False,
help='disables macOS GPU training')
parser.add_argument('--dry-run', action='store_true', default=False,
help='quickly check a single pass')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--save-model', action='store_true', default=False,
help='For Saving the current Model')
parser.add_argument('--export-onnx', action='store_true', default=True,
help='For Saving the current Model in ONNX format')
args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()
use_mps = not args.no_mps and torch.backends.mps.is_available()
torch.manual_seed(args.seed)
if use_cuda:
device = torch.device("cuda")
elif use_mps:
device = torch.device("mps")
else:
device = torch.device("cpu")
train_kwargs = {'batch_size': args.batch_size}
test_kwargs = {'batch_size': args.test_batch_size}
if use_cuda:
cuda_kwargs = {'num_workers': 1,
'pin_memory': True,
'shuffle': True}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
dataset1 = datasets.MNIST('/tmp/mnist-data', train=True, download=True,
transform=transform)
dataset2 = datasets.MNIST('/tmp/mnist-data', train=False,
transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch)
test(model, device, test_loader)
scheduler.step()
if args.save_model:
torch.save(model, "mnist.pt")
if args.export_onnx:
dummy_input = torch.randn(1, 1, 28, 28, device=device)
torch.onnx.export(model, dummy_input, "mnist.onnx")
if __name__ == '__main__':
main()
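
Running this script with the default flags trains the small CNN and writes `mnist.onnx` to the working directory; the example's `build.rs` above expects the file at `src/model/mnist.onnx`, so it presumably gets copied there before code generation.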

View File

@ -0,0 +1 @@
pub mod model;

View File

@ -0,0 +1,17 @@
use burn::tensor;
use burn_ndarray::NdArrayBackend;
use onnx_inference::model::mnist::{Model, INPUT1_SHAPE};
fn main() {
// Create a new model
let model: Model<NdArrayBackend<f32>> = Model::new();
// Create a new input tensor (all zeros for demonstration purposes)
let input = tensor::Tensor::<NdArrayBackend<f32>, 4>::zeros(INPUT1_SHAPE);
// Run the model
let output = model.forward(input);
// Print the output
println!("{:?}", output);
}
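
With the MNIST model defined above (a final `Linear(32, 10)` followed by `log_softmax`), the printed output should be a `[1, 10]` tensor of log-probabilities, one per digit class.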

Binary file not shown.

View File

@ -0,0 +1,3 @@
pub mod mnist {
include!(concat!(env!("OUT_DIR"), "/model/mnist.rs"));
}