mirror of https://github.com/tracel-ai/burn.git
Add foundation for importing ONNX files (#297)
This commit is contained in:
parent
a74f26620a
commit
df980d534e
|
@ -79,3 +79,8 @@ jobs:
|
|||
with:
|
||||
crate: burn-train
|
||||
|
||||
test-burn-import:
|
||||
uses: burn-rs/burn/.github/workflows/test-template.yml@main
|
||||
with:
|
||||
crate: burn-import
|
||||
|
||||
|
|
28
Cargo.toml
28
Cargo.toml
|
@ -5,24 +5,25 @@ resolver = "2"
|
|||
|
||||
members = [
|
||||
"burn",
|
||||
"burn-core",
|
||||
"burn-train",
|
||||
"burn-derive",
|
||||
"burn-tensor",
|
||||
"burn-tensor-testgen",
|
||||
"burn-dataset",
|
||||
"burn-tch",
|
||||
"burn-ndarray",
|
||||
"burn-autodiff",
|
||||
"burn-common",
|
||||
"burn-core",
|
||||
"burn-dataset",
|
||||
"burn-derive",
|
||||
"burn-import",
|
||||
"burn-ndarray",
|
||||
"burn-no-std-tests",
|
||||
"burn-tch",
|
||||
"burn-tensor-testgen",
|
||||
"burn-tensor",
|
||||
"burn-train",
|
||||
"examples/*",
|
||||
]
|
||||
|
||||
[workspace.dependencies]
|
||||
const-random = "0.1.15"
|
||||
dashmap = "5.4.0"
|
||||
dirs = "4.0.0"
|
||||
dirs = "5.0.0"
|
||||
fake = "2.5.0"
|
||||
flate2 = "1.0.25"
|
||||
half = {version = "2", features = ["alloc", "num-traits"], default-features = false}
|
||||
|
@ -33,6 +34,11 @@ log = "0.4.17"
|
|||
log4rs = "1.2.0"
|
||||
spin = {version = "0.9.5", features = ["mutex", "spin_mutex"]}
|
||||
thiserror = "1.0.39"
|
||||
proc-macro2 = "1.0.54"
|
||||
quote = "1.0.26"
|
||||
syn = "2.0"
|
||||
strum = "0.24"
|
||||
strum_macros = "0.24"
|
||||
|
||||
#
|
||||
# The following packages disable the "std" feature for no_std compatibility
|
||||
|
@ -44,7 +50,7 @@ rand = {version = "0.8.5", default-features = false, features = ["std_rng"]}# st
|
|||
rand_distr = {version = "0.4.3", default-features = false}
|
||||
uuid = {version = "1.3.0", default-features = false}
|
||||
|
||||
bincode = {version = "2.0.0-rc", features = ["alloc", "serde"], default-features = false}
|
||||
rmp-serde = {version = "1.1.1"}
|
||||
serde = {version = "1.0.155", default-features = false, features = ["derive", "alloc"]}# alloc is for no_std, derive is needed
|
||||
serde_json = {version = "1.0.94", default-features = false}
|
||||
rmp-serde = {version = "1.1.1"}
|
||||
bincode = {version = "2.0.0-rc", features=["alloc", "serde"], default-features = false}
|
||||
|
|
|
@ -0,0 +1,248 @@
|
|||
# NOTICES AND INFORMATION
|
||||
|
||||
This file contains notices and information required by libraries that this
|
||||
repository copied or derived from.
|
||||
|
||||
## PyTorch MNIST Example
|
||||
|
||||
**Source**: https://github.com/pytorch/examples/blob/main/mnist/main.py
|
||||
|
||||
License: BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2017,
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
* Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
## ONNX
|
||||
|
||||
**Source**: https://github.com/onnx/onnx/blob/main/onnx/onnx.proto3
|
||||
|
||||
License: Apache License 2.0
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
|
|
@ -16,6 +16,7 @@ version = "0.6.0"
|
|||
proc-macro = true
|
||||
|
||||
[dependencies]
|
||||
proc-macro2 = "1.0.52"
|
||||
quote = "1.0.26"
|
||||
syn = "1.0.109"
|
||||
proc-macro2 = {workspace = true}
|
||||
quote = {workspace = true}
|
||||
# syn = {workspace = true}
|
||||
syn = "1.0.76" # TODO upgrade to 2.0
|
|
@ -0,0 +1,36 @@
|
|||
[package]
|
||||
authors = [
|
||||
"Dilshod Tadjibaev (@antimora)",
|
||||
]
|
||||
edition = "2021"
|
||||
license = "MIT/Apache-2.0"
|
||||
name = "burn-import"
|
||||
readme = "README.md"
|
||||
repository = "https://github.com/burn-rs/burn/tree/main/burn-import"
|
||||
|
||||
version = "0.6.0"
|
||||
|
||||
[features]
|
||||
default = ["onnx"]
|
||||
onnx = []
|
||||
|
||||
[dependencies]
|
||||
burn = {path = "../burn", version = "0.6.0"}
|
||||
burn-ndarray = {path = "../burn-ndarray", version = "0.6.0"}
|
||||
|
||||
half = {workspace = true}
|
||||
proc-macro2 = {workspace = true}
|
||||
protobuf = {version = "3.2", features = ["with-bytes"]}
|
||||
quote = {workspace = true}
|
||||
rust-format = {version = "0.3", features = ["token_stream", "post_process"]}
|
||||
serde = {workspace = true}
|
||||
strum = {workspace = true}
|
||||
strum_macros = {workspace = true}
|
||||
syn = {workspace = true, features = ["parsing"]}
|
||||
topological-sort = {version = "0.2.2"}
|
||||
|
||||
[build-dependencies]
|
||||
protobuf-codegen = {version = "3.2"}
|
||||
|
||||
[dev-dependencies]
|
||||
rstest = "0.17.0"
|
|
@ -0,0 +1,65 @@
|
|||
# Burn Import
|
||||
|
||||
`burn-import` is a crate designed to facilitate importing models trained in other machine learning
|
||||
frameworks into the Burn framework. This tool generates a Rust source file that aligns the source
|
||||
model with Burn's model and converts tensor data into a format compatible with Burn.
|
||||
|
||||
Currently under development, `burn-import` supports importing ONNX models with a limited set of
|
||||
operators.
|
||||
|
||||
## Supported ONNX Operators
|
||||
|
||||
- Conv2d
|
||||
- Gemm (Linear layer)
|
||||
- Flatten
|
||||
- LogSoftmax
|
||||
|
||||
## Usage
|
||||
|
||||
### Importing ONNX models
|
||||
|
||||
In `build.rs`, add the following:
|
||||
|
||||
```rust
|
||||
use burn_import::onnx::ModelCodeGen;
|
||||
|
||||
fn main() {
|
||||
ModelCodeGen::new()
|
||||
.input("src/model/mnist.onnx") // Path to the ONNX model
|
||||
.out_dir("model/") // Directory to output the generated Rust source file (under target/)
|
||||
.run_from_script();
|
||||
}
|
||||
```
|
||||
|
||||
Then, add the following to mod.rs under `src/model`:
|
||||
|
||||
```rust
|
||||
pub mod mnist {
|
||||
include!(concat!(env!("OUT_DIR"), "/model/mnist.rs"));
|
||||
}
|
||||
```
|
||||
|
||||
Finally, in your code, you can use the imported model as follows:
|
||||
|
||||
```rust
|
||||
use burn::tensor;
|
||||
use burn_ndarray::NdArrayBackend;
|
||||
use onnx_inference::model::mnist::{Model, INPUT1_SHAPE};
|
||||
|
||||
fn main() {
|
||||
|
||||
// Create a new model
|
||||
let model: Model<NdArrayBackend<f32>> = Model::new();
|
||||
|
||||
// Create a new input tensor (all zeros for demonstration purposes)
|
||||
let input = tensor::Tensor::<NdArrayBackend<f32>, 4>::zeros(INPUT1_SHAPE);
|
||||
|
||||
// Run the model
|
||||
let output = model.forward(input);
|
||||
|
||||
// Print the output
|
||||
println!("{:?}", output);
|
||||
}
|
||||
```
|
||||
|
||||
You can view the working example in the `examples/onnx-inference` directory.
|
|
@ -0,0 +1,11 @@
|
|||
fn main() {
|
||||
if cfg!(feature = "onnx") {
|
||||
// Generate the onnx protobuf files
|
||||
protobuf_codegen::Codegen::new()
|
||||
.pure()
|
||||
.includes(["src"])
|
||||
.input("src/onnx/protos/onnx.proto")
|
||||
.cargo_out_dir("onnx-protos")
|
||||
.run_from_script();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
#![allow(clippy::ptr_arg)]
|
||||
#![allow(clippy::single_match)]
|
||||
#![allow(clippy::upper_case_acronyms)]
|
||||
|
||||
#[cfg(feature = "onnx")]
|
||||
pub mod onnx;
|
|
@ -0,0 +1,47 @@
|
|||
use super::ir::{AttributeValue, Node, NodeType};
|
||||
|
||||
/// The function transforms the graph into a new one where the nodes are coalesced into a single node.
|
||||
pub fn coalesce(nodes: &mut Vec<Node>) {
|
||||
for node in nodes.iter_mut() {
|
||||
match node.node_type {
|
||||
NodeType::Gemm => convert_gemm(node),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// This function converts a Gemm node into a Linear node
|
||||
///
|
||||
/// Warning: This function is not complete yet.
|
||||
/// It only supports the case where the Gemm node is a straight linear transformation.
|
||||
fn convert_gemm(node: &mut Node) {
|
||||
if node.inputs.len() != 1 {
|
||||
panic!("Gemm node must have 3 inputs");
|
||||
}
|
||||
|
||||
if node.outputs.len() != 1 {
|
||||
panic!("Gemm node must have 1 output");
|
||||
}
|
||||
let straight_linear = match (
|
||||
node.attrs.get("alpha"),
|
||||
node.attrs.get("beta"),
|
||||
node.attrs.get("transB"),
|
||||
) {
|
||||
(
|
||||
Some(AttributeValue::Float32(alpha)),
|
||||
Some(AttributeValue::Float32(beta)),
|
||||
Some(AttributeValue::Int64(trans_b)),
|
||||
) => *alpha == 1.0 && *beta == 1.0 && *trans_b == 1,
|
||||
_ => false,
|
||||
};
|
||||
|
||||
if straight_linear {
|
||||
node.node_type = NodeType::Linear;
|
||||
node.is_stateful = true;
|
||||
node.attrs.remove("alpha");
|
||||
node.attrs.remove("beta");
|
||||
node.attrs.remove("transB");
|
||||
} else {
|
||||
panic!("Full Gemm node not supported yet.");
|
||||
}
|
||||
}
|
|
@ -0,0 +1,574 @@
|
|||
use std::{
|
||||
collections::HashSet,
|
||||
env,
|
||||
fs::{self, create_dir_all},
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use burn::nn::conv::Conv2dPaddingConfig;
|
||||
use proc_macro2::{Span, TokenStream};
|
||||
use quote::quote;
|
||||
use syn::{Ident, Type};
|
||||
|
||||
use crate::onnx::{
|
||||
ir::{ArgType, Node, NodeType},
|
||||
op_configuration::{conv2d_config, flatten_config, linear_config, log_softmax_config},
|
||||
};
|
||||
|
||||
use super::{convert::parse_onnx, ir::Graph};
|
||||
|
||||
use rust_format::{Config, Edition, Formatter, PostProcess, RustFmt};
|
||||
|
||||
/// Code generation for onnx files.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ModelCodeGen {
|
||||
out_dir: Option<PathBuf>,
|
||||
|
||||
/// List of onnx files to generate source code from.
|
||||
inputs: Vec<PathBuf>,
|
||||
}
|
||||
|
||||
/// Generate code from `.onnx` files and save it to the `out_dir`.
|
||||
impl ModelCodeGen {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Set output directory.
|
||||
pub fn out_dir(&mut self, out_dir: &str) -> &mut Self {
|
||||
let cargo_out_dir = env::var("OUT_DIR").expect("OUT_DIR env is not set");
|
||||
let mut path = PathBuf::from(cargo_out_dir);
|
||||
|
||||
// Append the out_dir to the cargo_out_dir
|
||||
path.push(Path::new(out_dir));
|
||||
|
||||
self.out_dir = Some(path);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add input file.
|
||||
pub fn input(&mut self, input: &str) -> &mut Self {
|
||||
self.inputs.push(input.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Run code generation.
|
||||
///
|
||||
/// This function is intended to be called from `build.rs` script.
|
||||
pub fn run_from_script(&self) {
|
||||
self.run();
|
||||
}
|
||||
|
||||
/// Run code generation.
|
||||
pub fn run(&self) {
|
||||
let config = Config::new_str()
|
||||
.post_proc(PostProcess::ReplaceMarkersAndDocBlocks)
|
||||
.edition(Edition::Rust2021);
|
||||
|
||||
let rust_formatter = RustFmt::from_config(config);
|
||||
|
||||
let out_dir = self.out_dir.as_ref().expect("out_dir is not set");
|
||||
create_dir_all(out_dir).unwrap();
|
||||
|
||||
for input in self.inputs.iter() {
|
||||
let file_name = input.file_stem().unwrap();
|
||||
let out_file = out_dir.join(file_name);
|
||||
let out_file = out_file.with_extension("rs");
|
||||
|
||||
let model = ModelSourceCode::new(input);
|
||||
let code_str = rust_formatter.format_tokens(model.body()).unwrap();
|
||||
|
||||
fs::write(out_file, code_str).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A model that can be used to generate code
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ModelSourceCode {
|
||||
onnx_path: PathBuf,
|
||||
pub graph: Graph,
|
||||
}
|
||||
|
||||
impl ModelSourceCode {
|
||||
/// Create a new model from the onnx file
|
||||
pub fn new<P: AsRef<Path>>(onnx_path: P) -> Self {
|
||||
let graph = parse_onnx(onnx_path.as_ref());
|
||||
Self {
|
||||
onnx_path: onnx_path.as_ref().to_path_buf(),
|
||||
graph,
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates source code for the model
|
||||
pub fn body(&self) -> TokenStream {
|
||||
let input = "Model"; // TODO make this a parameter
|
||||
let input = Ident::new(input, Span::call_site());
|
||||
|
||||
let declaration = self.declaration(&input);
|
||||
|
||||
let file_path = self.onnx_path.to_str().unwrap();
|
||||
|
||||
let top_file_comment = format!("Generated from {file_path} by burn-import");
|
||||
|
||||
let mut imports: HashSet<String> = HashSet::new();
|
||||
|
||||
let implementation = self.implementation(&mut imports);
|
||||
|
||||
let import_statements = self.import_statements(&imports);
|
||||
|
||||
let shape_constants = self.shape_constants();
|
||||
|
||||
//TODO print out the old -> new name mapping
|
||||
quote! {
|
||||
_comment_!(#top_file_comment);
|
||||
_blank_!();
|
||||
_blank_!();
|
||||
#import_statements
|
||||
_blank_!();
|
||||
#shape_constants
|
||||
_blank_!();
|
||||
#declaration
|
||||
_blank_!();
|
||||
#[allow(dead_code)]
|
||||
#[allow(clippy::new_without_default)]
|
||||
#[allow(clippy::let_and_return)]
|
||||
#implementation
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
fn shape_constants(&self) -> TokenStream {
|
||||
let input_constants = self.graph.inputs.iter().enumerate().map(|(i, input)| {
|
||||
let name = format!("INPUT{}_SHAPE", i + 1);
|
||||
let name = Ident::new(&name, Span::call_site());
|
||||
let ArgType::Tensor(tensor) = input.clone().arg_type.unwrap();
|
||||
let dims = tensor.shape;
|
||||
let dims_count = dims.len();
|
||||
quote! {
|
||||
pub const #name: [usize; #dims_count] = [#(#dims),*];
|
||||
}
|
||||
});
|
||||
|
||||
let output_constants = self.graph.outputs.iter().enumerate().map(|(i, input)| {
|
||||
let name = format!("OUTPUT{}_SHAPE", i + 1);
|
||||
let name = Ident::new(&name, Span::call_site());
|
||||
let ArgType::Tensor(tensor) = input.clone().arg_type.unwrap();
|
||||
let dims = tensor.shape;
|
||||
let dims_count = dims.len();
|
||||
quote! {
|
||||
pub const #name: [usize; #dims_count] = [#(#dims),*];
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
#(#input_constants)*
|
||||
#(#output_constants)*
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates import statements for the model
|
||||
fn import_statements(&self, imports: &HashSet<String>) -> TokenStream {
|
||||
let mut import_tokens = vec![];
|
||||
|
||||
for import in imports.iter() {
|
||||
let path: syn::Path =
|
||||
syn::parse_str(import).expect("Unable to parse input string as a path");
|
||||
|
||||
import_tokens.push(quote! { #path });
|
||||
}
|
||||
|
||||
quote! {
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn,
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
#(use #import_tokens;)*
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates the declaration portion of the source code for the model
|
||||
fn declaration(&self, name: &Ident) -> TokenStream {
|
||||
let fields = self.declaration_fields();
|
||||
|
||||
let mut field_names = vec![];
|
||||
let mut field_types = vec![];
|
||||
|
||||
for (field_name, field_type) in fields.iter() {
|
||||
field_names.push(field_name);
|
||||
field_types.push(field_type);
|
||||
}
|
||||
|
||||
quote! {
|
||||
// TODO add documentation
|
||||
#[doc = "This is a generated model from an ONNX file"]
|
||||
#[derive(Module, Debug)]
|
||||
pub struct #name<B: Backend> {
|
||||
#(
|
||||
#field_names: #field_types,
|
||||
)*
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
/// Model implementation code
|
||||
fn implementation(&self, imports: &mut HashSet<String>) -> TokenStream {
|
||||
let forward_method = self.forward_method(imports);
|
||||
|
||||
let new_method = self.new_method();
|
||||
|
||||
quote! {
|
||||
impl<B: Backend> Model<B> {
|
||||
#new_method
|
||||
#forward_method
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates the new method for the model
|
||||
fn forward_method(&self, imports: &mut HashSet<String>) -> TokenStream {
|
||||
let inputs = self.forward_signature_input();
|
||||
let return_type = self.forward_signature_return();
|
||||
let results = self.forward_method_results();
|
||||
|
||||
let mut call_nodes: Vec<TokenStream> = vec![];
|
||||
|
||||
for node in self.graph.nodes.iter() {
|
||||
if node.is_stateful {
|
||||
call_nodes.push(Self::node_call_stateful(node));
|
||||
} else {
|
||||
call_nodes.push(Self::node_call_stateless(node, imports));
|
||||
}
|
||||
}
|
||||
|
||||
quote! {
|
||||
pub fn forward(&self, #(#inputs,)*) -> #return_type {
|
||||
#(#call_nodes)*
|
||||
#results
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates source code for the stateful node calls, i.e. conv, dropout, etc.
|
||||
fn node_call_stateful(node: &Node) -> TokenStream {
|
||||
if !node.is_stateful {
|
||||
panic!("Node must be stateful");
|
||||
}
|
||||
|
||||
let name = Ident::new(&node.name, Span::call_site());
|
||||
|
||||
let mut inputs = vec![];
|
||||
|
||||
for input in node.inputs.iter() {
|
||||
let name = Ident::new(&input.name, Span::call_site());
|
||||
inputs.push(quote! {
|
||||
#name
|
||||
});
|
||||
}
|
||||
|
||||
let mut outputs = vec![];
|
||||
|
||||
for output in node.outputs.iter() {
|
||||
let name = Ident::new(&output.name, Span::call_site());
|
||||
outputs.push(quote! {
|
||||
#name
|
||||
});
|
||||
}
|
||||
|
||||
if outputs.len() == 1 {
|
||||
let output = outputs.pop().unwrap();
|
||||
quote! {
|
||||
let #output = self.#name.forward(#(#inputs,)*);
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
let (#(#outputs,)*) = self.#name.forward(#(#inputs,)*);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates source code for the forward method results
|
||||
fn forward_method_results(&self) -> TokenStream {
|
||||
let mut outputs = vec![];
|
||||
for output in self.graph.outputs.iter() {
|
||||
let name = Ident::new(&output.name, Span::call_site());
|
||||
outputs.push(quote! {
|
||||
#name
|
||||
});
|
||||
}
|
||||
if outputs.len() == 1 {
|
||||
let output = outputs.pop().unwrap();
|
||||
quote! {
|
||||
#output
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
(#(#outputs,)*)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates source code for the stateless node calls, i.e. add, mul, etc.
|
||||
fn node_call_stateless(node: &Node, imports: &mut HashSet<String>) -> TokenStream {
|
||||
if node.is_stateful {
|
||||
panic!("Node must be stateless");
|
||||
}
|
||||
|
||||
let mut inputs = vec![];
|
||||
|
||||
for input in node.inputs.iter() {
|
||||
let name = Ident::new(&input.name, Span::call_site());
|
||||
inputs.push(quote! {
|
||||
#name
|
||||
});
|
||||
}
|
||||
|
||||
let mut outputs = vec![];
|
||||
|
||||
for output in node.outputs.iter() {
|
||||
let name = Ident::new(&output.name, Span::call_site());
|
||||
outputs.push(quote! {
|
||||
#name
|
||||
});
|
||||
}
|
||||
|
||||
let rhs = Self::node_call_stateless_rhs(node, imports);
|
||||
|
||||
if outputs.len() == 1 {
|
||||
let output = outputs.pop().unwrap();
|
||||
quote! {
|
||||
let #output = #rhs;
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
let (#(#outputs,)*) = #rhs;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates source code for the right hand side stateless node calls, i.e. add, relu, etc.
|
||||
fn node_call_stateless_rhs(node: &Node, imports: &mut HashSet<String>) -> TokenStream {
|
||||
let mut inputs = vec![];
|
||||
|
||||
for input in node.inputs.iter() {
|
||||
let name = Ident::new(&input.name, Span::call_site());
|
||||
inputs.push(quote! {
|
||||
#name
|
||||
});
|
||||
}
|
||||
|
||||
let input1 = inputs.pop().unwrap();
|
||||
|
||||
match node.node_type {
|
||||
NodeType::Relu => {
|
||||
imports.insert("burn::tensor::activation::relu".to_string());
|
||||
|
||||
quote! { relu(#input1) }
|
||||
}
|
||||
NodeType::LogSoftmax => {
|
||||
imports.insert("burn::tensor::activation::log_softmax".to_string());
|
||||
let dim = log_softmax_config(node);
|
||||
|
||||
quote! { log_softmax(#input1, #dim) }
|
||||
}
|
||||
NodeType::Flatten => {
|
||||
let (start_dim, end_dim) = flatten_config(node);
|
||||
|
||||
quote! { #input1.flatten(#start_dim, #end_dim) }
|
||||
}
|
||||
_ => quote! {},
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates the forward method signature
|
||||
fn forward_signature_input(&self) -> Vec<TokenStream> {
|
||||
let mut fields = vec![];
|
||||
|
||||
for input in self.graph.inputs.iter() {
|
||||
let name = Ident::new(&input.name, Span::call_site());
|
||||
|
||||
let ty = match input.arg_type.as_ref().unwrap() {
|
||||
ArgType::Tensor(tensor) => {
|
||||
let d = &tensor.shape.len();
|
||||
syn::parse_str::<Type>(format!("Tensor<B, {d}>").as_str()).unwrap()
|
||||
}
|
||||
};
|
||||
|
||||
fields.push(quote! {
|
||||
#name: #ty
|
||||
});
|
||||
}
|
||||
fields
|
||||
}
|
||||
|
||||
/// Generates the forward method return signature
|
||||
fn forward_signature_return(&self) -> TokenStream {
|
||||
let mut field_types = vec![];
|
||||
|
||||
for output in self.graph.outputs.iter() {
|
||||
let ty = match output.arg_type.as_ref().unwrap() {
|
||||
ArgType::Tensor(tensor) => {
|
||||
let d = &tensor.shape.len();
|
||||
syn::parse_str::<Type>(format!("Tensor<B, {d}>").as_str()).unwrap()
|
||||
}
|
||||
};
|
||||
|
||||
field_types.push(ty);
|
||||
}
|
||||
|
||||
if field_types.len() == 1 {
|
||||
// Return one output
|
||||
quote! {
|
||||
#(
|
||||
#field_types
|
||||
)*
|
||||
}
|
||||
} else {
|
||||
// Return a tuple of the outputs
|
||||
quote! {
|
||||
(#(
|
||||
#field_types,
|
||||
)*)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates source code for the initialization method
|
||||
fn new_method(&self) -> TokenStream {
|
||||
let initialization_fields = self.initialization_fields();
|
||||
|
||||
let field_names = self.graph.nodes.iter().filter(|x| x.is_stateful).map(|x| {
|
||||
let name = Ident::new(&x.name, Span::call_site());
|
||||
quote! {
|
||||
#name
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
pub fn new() -> Self {
|
||||
#(
|
||||
#initialization_fields
|
||||
)*
|
||||
|
||||
Self {
|
||||
#(
|
||||
#field_names
|
||||
),*
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the fields for the declaration of the model
|
||||
fn declaration_fields(&self) -> Vec<(Ident, Type)> {
|
||||
let mut fields = vec![];
|
||||
|
||||
for node in self.graph.nodes.iter().filter(|x| x.is_stateful) {
|
||||
let node_type = match node.node_type {
|
||||
NodeType::Conv1d => syn::parse_str::<Type>("nn::conv::Conv1d<B>").unwrap(),
|
||||
NodeType::Conv2d => syn::parse_str::<Type>("nn::conv::Conv2d<B>").unwrap(),
|
||||
NodeType::Linear => syn::parse_str::<Type>("nn::Linear<B>").unwrap(),
|
||||
_ => {
|
||||
todo!("Node type not implemented: {:?}", node.node_type)
|
||||
}
|
||||
};
|
||||
|
||||
let node_name = Ident::new(&node.name, Span::call_site());
|
||||
|
||||
fields.push((node_name, node_type));
|
||||
}
|
||||
|
||||
fields
|
||||
}
|
||||
|
||||
/// Generates source code for the initialization method
|
||||
fn initialization_fields(&self) -> Vec<TokenStream> {
|
||||
let mut fields = vec![];
|
||||
|
||||
for node in self.graph.nodes.iter().filter(|x| x.is_stateful) {
|
||||
let init_code = match node.node_type {
|
||||
NodeType::Conv2d => conv2d_init(node),
|
||||
NodeType::Linear => linear_init(node),
|
||||
_ => {
|
||||
todo!("Node type not implemented: {:?}", node.node_type)
|
||||
}
|
||||
};
|
||||
|
||||
fields.push(init_code);
|
||||
}
|
||||
|
||||
fields
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates source code for the initialization of a Conv2d node
|
||||
fn conv2d_init(node: &Node) -> TokenStream {
|
||||
let node_name = Ident::new(&node.name, Span::call_site());
|
||||
|
||||
let config = conv2d_config(node);
|
||||
|
||||
let channel_in = config.channels[0];
|
||||
let channel_out = config.channels[1];
|
||||
let kernel_size_0 = config.kernel_size[0];
|
||||
let kernel_size_1 = config.kernel_size[1];
|
||||
let bias = config.bias;
|
||||
|
||||
let padding = match config.padding {
|
||||
Conv2dPaddingConfig::Valid => quote! { nn::conv::Conv2dPaddingConfig::Valid },
|
||||
Conv2dPaddingConfig::Same => quote! { nn::conv::Conv2dPaddingConfig::Same },
|
||||
_ => todo!("Padding ({:?}) not implemented", config.padding),
|
||||
};
|
||||
|
||||
quote! {
|
||||
let #node_name = nn::conv::Conv2dConfig::new([#channel_in, #channel_out], [#kernel_size_0, #kernel_size_1])
|
||||
.with_padding(#padding)
|
||||
.with_bias(#bias)
|
||||
.init();
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates source code for the initialization of a Linear node
|
||||
fn linear_init(node: &Node) -> TokenStream {
|
||||
let node_name = Ident::new(&node.name, Span::call_site());
|
||||
let config = linear_config(node);
|
||||
|
||||
let bias = config.bias;
|
||||
let input_size = config.d_input;
|
||||
let output_size = config.d_output;
|
||||
|
||||
quote! {
|
||||
let #node_name = nn::LinearConfig::new(#input_size, #output_size)
|
||||
.with_bias(#bias)
|
||||
.init();
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rstest::*;
|
||||
use rust_format::{Config, Edition, Formatter, PostProcess, RustFmt};
|
||||
|
||||
#[fixture]
|
||||
pub fn model() -> ModelSourceCode {
|
||||
ModelSourceCode::new("tests/onnx/mnist.onnx")
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
fn print(model: ModelSourceCode) {
|
||||
let config = Config::new_str()
|
||||
.post_proc(PostProcess::ReplaceMarkersAndDocBlocks)
|
||||
.edition(Edition::Rust2021);
|
||||
|
||||
let rustfmt = RustFmt::from_config(config);
|
||||
|
||||
let _gen_str = rustfmt.format_tokens(model.body()).unwrap();
|
||||
|
||||
// TODO compare the result with the expected output
|
||||
}
|
||||
}
|
|
@ -0,0 +1,523 @@
|
|||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
fs::File,
|
||||
path::Path,
|
||||
str::{from_utf8, FromStr},
|
||||
};
|
||||
|
||||
use super::coalesce::coalesce;
|
||||
use super::ir::{
|
||||
ArgType, Argument, AttributeValue, Attributes, ElementType, Graph, Node, NodeType, Tensor,
|
||||
TensorData,
|
||||
};
|
||||
use super::protos::{
|
||||
attribute_proto::AttributeType, tensor_proto::DataType, tensor_shape_proto::dimension::Value,
|
||||
type_proto, AttributeProto, ModelProto, NodeProto, TensorProto, TensorShapeProto,
|
||||
ValueInfoProto,
|
||||
};
|
||||
use super::shape_inference::shape_inference;
|
||||
|
||||
use protobuf::{Enum, Message};
|
||||
use topological_sort::TopologicalSort;
|
||||
|
||||
const STATEFUL_NODE_TYPES: [NodeType; 4] = [
|
||||
NodeType::Conv,
|
||||
NodeType::BatchNormalization,
|
||||
NodeType::Dropout,
|
||||
NodeType::Linear,
|
||||
];
|
||||
|
||||
/// Error type for parsing ONNX model
|
||||
#[derive(Debug)]
|
||||
pub enum ParseError {
|
||||
VariantNotFound,
|
||||
}
|
||||
|
||||
/// Open an onnx file and convert it to a Graph (intermediate representation)
|
||||
pub fn parse_onnx(onnx_path: &Path) -> Graph {
|
||||
// Open the file
|
||||
let mut file = File::open(onnx_path).expect("Unable to open file");
|
||||
let onnx_model: ModelProto =
|
||||
Message::parse_from_reader(&mut file).expect("Unable to parse ONNX file");
|
||||
|
||||
// Convert the nodes
|
||||
let mut nodes: Vec<Node> = vec![];
|
||||
for onnx_node in onnx_model.graph.node.iter() {
|
||||
nodes.push(convert_node_proto(onnx_node));
|
||||
}
|
||||
|
||||
// Get the names of the initializers
|
||||
let check_if_initializer: HashSet<String> = onnx_model
|
||||
.graph
|
||||
.initializer
|
||||
.iter()
|
||||
.map(|x| x.name.clone())
|
||||
.collect();
|
||||
|
||||
// Move inputs to initializers
|
||||
move_inputs_to_initializer(&mut nodes, &check_if_initializer);
|
||||
|
||||
// Get the topological sort of the nodes and the top nodes
|
||||
let (ts, top_nodes) = get_top_nodes(&nodes);
|
||||
|
||||
// Sort the nodes
|
||||
top_sort_nodes(&mut nodes, ts);
|
||||
|
||||
// Collect inputs, outputs and initializers
|
||||
let mut inputs = collect_inputs(&onnx_model, &check_if_initializer, top_nodes);
|
||||
let mut outputs = collect_outputs(&onnx_model, check_if_initializer);
|
||||
let initializers = collect_initializers(onnx_model);
|
||||
|
||||
// Coalesce and transform nodes
|
||||
coalesce(&mut nodes);
|
||||
|
||||
// Copy the initializers to the nodes
|
||||
copy_initializer_info_to_nodes_level(&mut nodes, &initializers);
|
||||
|
||||
// Rename nodes and inputs, save the mapping for later
|
||||
let old_node_names = rename_nodes(&mut nodes);
|
||||
let old_input_names = rename_inputs(&mut nodes, &mut inputs, &mut outputs);
|
||||
|
||||
// Infer shapes and update the inputs and outputs
|
||||
shape_inference(&mut nodes, &inputs, &mut outputs);
|
||||
|
||||
Graph {
|
||||
nodes,
|
||||
inputs,
|
||||
outputs,
|
||||
initializers,
|
||||
old_node_names,
|
||||
old_input_names,
|
||||
}
|
||||
}
|
||||
|
||||
/// Collect initializers
|
||||
fn collect_initializers(onnx_model: ModelProto) -> Vec<Argument> {
|
||||
let mut initializers: Vec<Argument> = vec![];
|
||||
for initializer in onnx_model.graph.initializer.iter() {
|
||||
let tensor_proto = initializer.clone();
|
||||
|
||||
let name = tensor_proto.name.clone();
|
||||
|
||||
// FIXME data conversion for the tensor is incorrect
|
||||
let tensor: Tensor = tensor_proto.try_into().unwrap();
|
||||
let arg_type = Some(ArgType::Tensor(tensor));
|
||||
let arg = Argument { name, arg_type };
|
||||
initializers.push(arg);
|
||||
}
|
||||
initializers
|
||||
}
|
||||
|
||||
/// Collect outputs
|
||||
fn collect_outputs(
|
||||
onnx_model: &ModelProto,
|
||||
check_if_initializer: HashSet<String>,
|
||||
) -> Vec<Argument> {
|
||||
// TODO: filter out the outputs that are not used in the graph
|
||||
let outputs: Vec<Argument> = onnx_model
|
||||
.graph
|
||||
.output
|
||||
.iter()
|
||||
.filter(|x| !check_if_initializer.contains(x.name.as_str()))
|
||||
.map(|i| Argument::try_from(i.clone()).unwrap())
|
||||
.collect();
|
||||
outputs
|
||||
}
|
||||
|
||||
/// Collect inputs
|
||||
fn collect_inputs(
|
||||
onnx_model: &ModelProto,
|
||||
check_if_initializer: &HashSet<String>,
|
||||
top_nodes: HashSet<String>,
|
||||
) -> Vec<Argument> {
|
||||
let inputs: Vec<Argument> = onnx_model
|
||||
.graph
|
||||
.input
|
||||
.iter()
|
||||
.filter(|x| !check_if_initializer.contains(x.name.as_str()))
|
||||
.filter(|x| top_nodes.contains(&x.name))
|
||||
.map(|x| Argument::try_from(x.clone()).unwrap())
|
||||
.collect();
|
||||
inputs
|
||||
}
|
||||
|
||||
/// Sort the nodes in topological order
|
||||
fn top_sort_nodes(nodes: &mut Vec<Node>, mut ts: TopologicalSort<Node>) {
|
||||
*nodes = vec![];
|
||||
while let Some(node) = ts.pop() {
|
||||
nodes.push(node);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the top nodes in the graph
|
||||
fn get_top_nodes(nodes: &Vec<Node>) -> (TopologicalSort<Node>, HashSet<String>) {
|
||||
// Get the names of the top nodes (first nodes in the graph to receive the input)
|
||||
// Sometimes onnx will pass inputs to be used as weights and biases but they are not truly inputs
|
||||
let ts = topsort(nodes);
|
||||
let mut top_nodes: HashSet<String> = HashSet::new();
|
||||
|
||||
for node in ts.peek_all() {
|
||||
for input in node.inputs.iter() {
|
||||
top_nodes.insert(input.name.clone());
|
||||
}
|
||||
}
|
||||
(ts, top_nodes)
|
||||
}
|
||||
|
||||
/// Move nodes's inputs and outputs to initializers if they are in the initializer list
|
||||
fn move_inputs_to_initializer(nodes: &mut Vec<Node>, check_if_initializer: &HashSet<String>) {
|
||||
for node in nodes.iter_mut() {
|
||||
node.initializers = node
|
||||
.inputs
|
||||
.iter()
|
||||
.filter(|x| check_if_initializer.contains(&x.name))
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
// Remove the initializers from the inputs and outputs
|
||||
node.inputs
|
||||
.retain(|x| !check_if_initializer.contains(&x.name));
|
||||
node.outputs
|
||||
.retain(|x| !check_if_initializer.contains(&x.name));
|
||||
}
|
||||
}
|
||||
|
||||
fn to_string(bytes: Vec<u8>) -> String {
|
||||
from_utf8(bytes.as_slice()).unwrap().to_string()
|
||||
}
|
||||
|
||||
fn to_string_vec(bytes: Vec<Vec<u8>>) -> Vec<String> {
|
||||
bytes.iter().map(|b| to_string(b.clone())).collect()
|
||||
}
|
||||
|
||||
fn convert_shape(shape: Vec<i64>) -> Vec<usize> {
|
||||
shape.iter().map(|s| *s as usize).collect()
|
||||
}
|
||||
|
||||
/// Convert a vector of AttributeProto to a HashMap of AttributeValue
|
||||
impl TryFrom<TensorProto> for Tensor {
|
||||
type Error = ParseError;
|
||||
fn try_from(tensor: TensorProto) -> Result<Tensor, Self::Error> {
|
||||
let (elem_type, data) = match DataType::from_i32(tensor.data_type).unwrap() {
|
||||
// TODO: check if float is empty and use raw_data instead
|
||||
DataType::FLOAT => (
|
||||
ElementType::Float32,
|
||||
TensorData::Float32s(tensor.float_data),
|
||||
),
|
||||
DataType::INT32 => (ElementType::Int32, TensorData::Int32s(tensor.int32_data)),
|
||||
DataType::INT64 => (ElementType::Int64, TensorData::Int64s(tensor.int64_data)),
|
||||
DataType::DOUBLE => (
|
||||
ElementType::Float64,
|
||||
TensorData::Float64s(tensor.double_data),
|
||||
),
|
||||
|
||||
// TODO : Add more types
|
||||
_ => {
|
||||
return Err(ParseError::VariantNotFound);
|
||||
}
|
||||
};
|
||||
let shape = convert_shape(tensor.dims);
|
||||
let name = tensor.name;
|
||||
|
||||
Ok(Tensor {
|
||||
name: Some(name),
|
||||
elem_type,
|
||||
shape,
|
||||
data: Some(data),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<TensorShapeProto> for Vec<usize> {
|
||||
type Error = ParseError;
|
||||
fn try_from(shape: TensorShapeProto) -> Result<Vec<usize>, Self::Error> {
|
||||
let mut result = Vec::new();
|
||||
|
||||
for dim in shape.dim {
|
||||
if let Value::DimValue(value) = dim.value.unwrap() {
|
||||
result.push(value as usize);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a vector of AttributeProto to a HashMap of AttributeValue
|
||||
impl TryFrom<&type_proto::Tensor> for Tensor {
|
||||
type Error = ParseError;
|
||||
fn try_from(tensor: &type_proto::Tensor) -> Result<Tensor, Self::Error> {
|
||||
let elem_type = match DataType::from_i32(tensor.elem_type).unwrap() {
|
||||
DataType::FLOAT => ElementType::Float32,
|
||||
DataType::INT32 => ElementType::Int32,
|
||||
DataType::INT64 => ElementType::Int64,
|
||||
DataType::DOUBLE => ElementType::Float64,
|
||||
|
||||
// TODO : Add more types
|
||||
_ => {
|
||||
return Err(ParseError::VariantNotFound);
|
||||
}
|
||||
};
|
||||
|
||||
let shape_proto = tensor.shape.clone().unwrap();
|
||||
let shape: Vec<usize> = shape_proto.try_into().unwrap();
|
||||
|
||||
let name = None;
|
||||
|
||||
Ok(Tensor {
|
||||
name,
|
||||
elem_type,
|
||||
shape,
|
||||
data: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_vec_tensor_proto(tensors: Vec<TensorProto>) -> Result<Vec<Tensor>, ParseError> {
|
||||
let mut result = Vec::new();
|
||||
for tensor in tensors {
|
||||
result.push(Tensor::try_from(tensor)?);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Convert a vector of AttributeProto to a HashMap of AttributeValue
|
||||
impl TryFrom<AttributeProto> for AttributeValue {
|
||||
type Error = ParseError;
|
||||
|
||||
fn try_from(attr: AttributeProto) -> Result<AttributeValue, Self::Error> {
|
||||
let value = match attr.type_.unwrap() {
|
||||
AttributeType::FLOAT => AttributeValue::Float32(attr.f),
|
||||
AttributeType::INT => AttributeValue::Int64(attr.i),
|
||||
AttributeType::STRING => AttributeValue::String(to_string(attr.s)),
|
||||
|
||||
// warning: tensor can be empty TODO: check if it is empty
|
||||
AttributeType::TENSOR => AttributeValue::Tensor(Tensor::try_from(attr.t.unwrap())?),
|
||||
|
||||
// Graph is not supported for now
|
||||
// AttributeType::GRAPH => AttributeValue::Graph(attr.g),
|
||||
AttributeType::FLOATS => AttributeValue::Float32s(attr.floats),
|
||||
AttributeType::INTS => AttributeValue::Int64s(attr.ints),
|
||||
AttributeType::STRINGS => AttributeValue::Strings(to_string_vec(attr.strings)),
|
||||
AttributeType::TENSORS => {
|
||||
AttributeValue::Tensors(convert_vec_tensor_proto(attr.tensors)?)
|
||||
}
|
||||
// AttributeType::GRAPHS => AttributeValue::Graphs(attr.graphs),
|
||||
// AttributeType::SPARSE_TENSORS => AttributeValue::SparseTensors(attr.sparse_tensors),
|
||||
// AttributeType::SPARSE_TENSOR => AttributeValue::SparseTensor(attr.sparse_tensor),
|
||||
_ => {
|
||||
return Err(ParseError::VariantNotFound);
|
||||
}
|
||||
};
|
||||
|
||||
Ok(value)
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a vector of AttributeProto to a HashMap of AttributeValue
|
||||
pub fn convert_vec_attrs_proto(attrs: Vec<AttributeProto>) -> Attributes {
|
||||
let mut result = Attributes::new();
|
||||
for attr in attrs {
|
||||
result.insert(attr.name.clone(), AttributeValue::try_from(attr).unwrap());
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
pub fn convert_node_proto(node: &NodeProto) -> Node {
|
||||
let name = node.name.clone();
|
||||
let inputs = node
|
||||
.input
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|x| Argument {
|
||||
name: x,
|
||||
arg_type: None,
|
||||
})
|
||||
.collect();
|
||||
let outputs = node
|
||||
.output
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|x| Argument {
|
||||
name: x,
|
||||
arg_type: None,
|
||||
})
|
||||
.collect();
|
||||
let attrs = convert_vec_attrs_proto(node.attribute.clone());
|
||||
|
||||
let node_type = NodeType::from_str(node.op_type.as_str()).unwrap();
|
||||
|
||||
let is_stateful = STATEFUL_NODE_TYPES.contains(&node_type);
|
||||
|
||||
let mut node = Node {
|
||||
node_type,
|
||||
name,
|
||||
inputs,
|
||||
outputs,
|
||||
initializers: vec![],
|
||||
attrs,
|
||||
is_stateful,
|
||||
};
|
||||
|
||||
remap_node_type(&mut node);
|
||||
|
||||
node
|
||||
}
|
||||
|
||||
/// Remap node type to a more specific one
|
||||
fn remap_node_type(node: &mut Node) {
|
||||
match node.node_type {
|
||||
NodeType::Conv => {
|
||||
if let AttributeValue::Int64s(ints) = node.attrs.get("kernel_shape").unwrap() {
|
||||
node.node_type = match ints.len() {
|
||||
1 => NodeType::Conv1d,
|
||||
2 => NodeType::Conv2d,
|
||||
_ => todo!(),
|
||||
};
|
||||
} else {
|
||||
panic!("kernel_shape is not an int64s");
|
||||
}
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a vector of AttributeProto to a HashMap of AttributeValue
|
||||
impl TryFrom<ValueInfoProto> for Argument {
|
||||
type Error = ParseError;
|
||||
|
||||
fn try_from(value: ValueInfoProto) -> Result<Argument, Self::Error> {
|
||||
let name = value.name.clone();
|
||||
let proto_type = value.type_.unwrap();
|
||||
|
||||
let mut arg_type = None;
|
||||
|
||||
if proto_type.has_tensor_type() {
|
||||
let tensor_proto = proto_type.tensor_type();
|
||||
|
||||
let tensor: Tensor = tensor_proto.try_into().unwrap();
|
||||
|
||||
arg_type = Some(ArgType::Tensor(tensor));
|
||||
}
|
||||
Ok(Argument { name, arg_type })
|
||||
}
|
||||
}
|
||||
|
||||
/// Copy the initializers to the nodes
|
||||
fn copy_initializer_info_to_nodes_level(nodes: &mut Vec<Node>, initializers: &Vec<Argument>) {
|
||||
for node in nodes.iter_mut() {
|
||||
for node_initializer in node.initializers.iter_mut() {
|
||||
*node_initializer = initializers
|
||||
.iter()
|
||||
.find(|x| x.name == node_initializer.name)
|
||||
.unwrap()
|
||||
.clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Rename the nodes in the graph to be unique and return a map of the old names to the new names.
|
||||
fn rename_nodes(nodes: &mut Vec<Node>) -> HashMap<String, String> {
|
||||
let mut old_names = HashMap::new();
|
||||
let mut counter: HashMap<NodeType, usize> = HashMap::new();
|
||||
|
||||
for node in nodes.iter_mut() {
|
||||
// keep track of the number of nodes of each type
|
||||
counter
|
||||
.entry(node.node_type.clone())
|
||||
.and_modify(|e| *e += 1)
|
||||
.or_insert(1);
|
||||
|
||||
let old_name = node.name.clone();
|
||||
let new_name = format!("{}{}", node.node_type, counter[&node.node_type]).to_lowercase();
|
||||
|
||||
node.name = new_name.clone();
|
||||
|
||||
old_names.insert(old_name, new_name);
|
||||
}
|
||||
|
||||
old_names
|
||||
}
|
||||
|
||||
/// Rename the inputs in the graph and return a map of the old names to the new names.
|
||||
///
|
||||
/// The inputs are renamed to be unique and to be in the format of conv2_in1, conv2_in2, etc.
|
||||
/// This is done to be consistent with the naming convention of the nodes and allow to be used as rust identifiers.
|
||||
fn rename_inputs(
|
||||
nodes: &mut Vec<Node>,
|
||||
inputs: &mut Vec<Argument>,
|
||||
outputs: &mut Vec<Argument>,
|
||||
) -> HashMap<String, String> {
|
||||
let mut old_names = HashMap::new();
|
||||
|
||||
let mut counter = 1;
|
||||
for input in inputs.iter_mut() {
|
||||
let old_name = input.name.clone();
|
||||
let new_name = format!("input{}", counter);
|
||||
|
||||
input.name = new_name.clone();
|
||||
|
||||
old_names.insert(old_name, new_name);
|
||||
counter += 1;
|
||||
}
|
||||
|
||||
let mut counter: HashMap<String, usize> = HashMap::new();
|
||||
|
||||
for node in nodes.iter_mut() {
|
||||
// keep track of the number of nodes of each type
|
||||
counter
|
||||
.entry(node.name.clone())
|
||||
.and_modify(|e| *e += 1)
|
||||
.or_insert(1);
|
||||
|
||||
// loop through node inputs and rename them with previously replaced names
|
||||
for input in node.inputs.iter_mut() {
|
||||
if let Some(new_name) = old_names.get(&input.name) {
|
||||
input.name = new_name.clone();
|
||||
}
|
||||
}
|
||||
|
||||
// loop through node outputs and rename them and store the new name <-> old name mapping
|
||||
for output in node.outputs.iter_mut() {
|
||||
let old_name = output.name.clone();
|
||||
let new_name = format!("{}_out{}", node.name, counter[&node.name]);
|
||||
output.name = new_name.clone();
|
||||
old_names.insert(old_name, new_name);
|
||||
}
|
||||
}
|
||||
|
||||
// Rename the graph outputs
|
||||
for output in outputs.iter_mut() {
|
||||
if let Some(new_name) = old_names.get(&output.name) {
|
||||
output.name = new_name.clone();
|
||||
}
|
||||
}
|
||||
|
||||
old_names
|
||||
}
|
||||
|
||||
/// Find the node that produces the given output
|
||||
fn lookup_node_by_output(nodes: &Vec<Node>, input: &str) -> Option<Node> {
|
||||
for node in nodes.iter() {
|
||||
if node.outputs.iter().any(|x| x.name == *input) {
|
||||
return Some(node.clone());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Sort nodes in topological order
|
||||
pub fn topsort(nodes: &Vec<Node>) -> TopologicalSort<Node> {
|
||||
let mut ts = TopologicalSort::new();
|
||||
|
||||
for node in nodes.iter() {
|
||||
for input in node.inputs.iter() {
|
||||
match lookup_node_by_output(nodes, input.name.as_str()) {
|
||||
Some(prec) => ts.add_dependency(prec, node.clone()),
|
||||
None => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ts
|
||||
}
|
|
@ -0,0 +1,305 @@
|
|||
use half::f16;
|
||||
use std::collections::HashMap;
|
||||
use strum_macros::{Display, EnumString};
|
||||
|
||||
pub type Shape = Vec<usize>;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Argument {
|
||||
pub name: String,
|
||||
pub arg_type: Option<ArgType>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ArgType {
|
||||
Tensor(Tensor),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SparseTensor(Tensor, Tensor, Shape);
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AttributeValue {
|
||||
Float32(f32),
|
||||
Int64(i64),
|
||||
String(String),
|
||||
Tensor(Tensor),
|
||||
SparseTensor(SparseTensor),
|
||||
Float32s(Vec<f32>),
|
||||
Int64s(Vec<i64>),
|
||||
Strings(Vec<String>),
|
||||
Tensors(Vec<Tensor>),
|
||||
SparseTensors(Vec<SparseTensor>),
|
||||
}
|
||||
pub type Attributes = HashMap<String, AttributeValue>;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ElementType {
|
||||
Float32,
|
||||
Float64,
|
||||
Int32,
|
||||
Int64,
|
||||
String,
|
||||
Float16,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Tensor {
|
||||
pub name: Option<String>,
|
||||
pub elem_type: ElementType,
|
||||
pub shape: Shape,
|
||||
pub data: Option<TensorData>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum TensorData {
|
||||
Float16s(Vec<f16>),
|
||||
Float32s(Vec<f32>),
|
||||
Float64s(Vec<f64>),
|
||||
Int32s(Vec<i32>),
|
||||
Int64s(Vec<i64>),
|
||||
Strings(Vec<String>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Graph {
|
||||
pub nodes: Vec<Node>,
|
||||
pub inputs: Vec<Argument>,
|
||||
pub outputs: Vec<Argument>,
|
||||
pub initializers: Vec<Argument>,
|
||||
pub old_node_names: HashMap<String, String>,
|
||||
pub old_input_names: HashMap<String, String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Node {
|
||||
pub node_type: NodeType,
|
||||
pub name: String,
|
||||
pub inputs: Vec<Argument>,
|
||||
pub outputs: Vec<Argument>,
|
||||
pub initializers: Vec<Argument>,
|
||||
pub attrs: Attributes,
|
||||
pub is_stateful: bool,
|
||||
}
|
||||
|
||||
// Required by topological sort
|
||||
impl PartialEq for Node {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.name == other.name && self.node_type == other.node_type
|
||||
}
|
||||
}
|
||||
|
||||
// Required by topological sort
|
||||
impl Eq for Node {}
|
||||
|
||||
// Required by topological sort
|
||||
impl core::hash::Hash for Node {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.name.hash(state);
|
||||
self.node_type.hash(state);
|
||||
self.inputs.hash(state);
|
||||
self.outputs.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
// Required by topological sort
|
||||
impl core::hash::Hash for Argument {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.name.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
/// The list of supported node types (ONNX operators and some extra ones to map easily to Burn's ops)
|
||||
#[derive(Debug, Hash, Eq, PartialEq, EnumString, Clone, Display)]
|
||||
pub enum NodeType {
|
||||
Abs,
|
||||
Acos,
|
||||
Acosh,
|
||||
Add,
|
||||
And,
|
||||
ArgMax,
|
||||
ArgMin,
|
||||
Asin,
|
||||
Asinh,
|
||||
Atan,
|
||||
Atanh,
|
||||
AveragePool,
|
||||
BatchNormalization,
|
||||
Bernoulli,
|
||||
BitShift,
|
||||
BitwiseAnd,
|
||||
BitwiseNot,
|
||||
BitwiseOr,
|
||||
BitwiseXor,
|
||||
BlackmanWindow,
|
||||
Cast,
|
||||
CastLike,
|
||||
Ceil,
|
||||
Celu,
|
||||
CenterCropPad,
|
||||
Clip,
|
||||
Col,
|
||||
Compress,
|
||||
Concat,
|
||||
ConcatFromSequence,
|
||||
Constant,
|
||||
ConstantOfShape,
|
||||
Conv,
|
||||
Conv1d,
|
||||
Conv2d,
|
||||
ConvInteger,
|
||||
ConvTranspose,
|
||||
Cos,
|
||||
Cosh,
|
||||
CumSum,
|
||||
DepthToSpace,
|
||||
DequantizeLinear,
|
||||
Det,
|
||||
DFT,
|
||||
Div,
|
||||
Dropout,
|
||||
DynamicQuantizeLinear,
|
||||
Einsum,
|
||||
Elu,
|
||||
Equal,
|
||||
Erf,
|
||||
Exp,
|
||||
Expand,
|
||||
EyeLike,
|
||||
Flatten,
|
||||
Floor,
|
||||
Gather,
|
||||
GatherElements,
|
||||
GatherND,
|
||||
Gelu,
|
||||
Gemm,
|
||||
GlobalAveragePool,
|
||||
GlobalLpPool,
|
||||
GlobalMaxPool,
|
||||
Greater,
|
||||
GreaterOrEqual,
|
||||
GridSample,
|
||||
GroupNormalization,
|
||||
GRU,
|
||||
HammingWindow,
|
||||
HannWindow,
|
||||
Hardmax,
|
||||
HardSigmoid,
|
||||
HardSwish,
|
||||
Identity,
|
||||
If,
|
||||
Im,
|
||||
InstanceNormalization,
|
||||
IsInf,
|
||||
IsNaN,
|
||||
LayerNormalization,
|
||||
LeakyRelu,
|
||||
Less,
|
||||
LessOrEqual,
|
||||
Linear,
|
||||
Log,
|
||||
LogSoftmax,
|
||||
Loop,
|
||||
LpNormalization,
|
||||
LpPool,
|
||||
LRN,
|
||||
LSTM,
|
||||
MatMul,
|
||||
MatMulInteger,
|
||||
Max,
|
||||
MaxPool,
|
||||
MaxRoiPool,
|
||||
MaxUnpool,
|
||||
Mean,
|
||||
MeanVarianceNormalization,
|
||||
MelWeightMatrix,
|
||||
Min,
|
||||
Mish,
|
||||
Mod,
|
||||
Mul,
|
||||
Multinomial,
|
||||
Neg,
|
||||
NegativeLogLikelihoodLoss,
|
||||
NonMaxSuppression,
|
||||
NonZero,
|
||||
Not,
|
||||
OneHot,
|
||||
Optional,
|
||||
OptionalGetElement,
|
||||
OptionalHasElement,
|
||||
Or,
|
||||
Pad,
|
||||
Pow,
|
||||
PRelu,
|
||||
QLinearConv,
|
||||
QLinearMatMul,
|
||||
QuantizeLinear,
|
||||
RandomNormal,
|
||||
RandomNormalLike,
|
||||
RandomUniform,
|
||||
RandomUniformLike,
|
||||
Range,
|
||||
Reciprocal,
|
||||
ReduceL,
|
||||
ReduceLogSum,
|
||||
ReduceLogSumExp,
|
||||
ReduceMax,
|
||||
ReduceMean,
|
||||
ReduceMin,
|
||||
ReduceProd,
|
||||
ReduceSum,
|
||||
ReduceSumSquare,
|
||||
Relu,
|
||||
Reshape,
|
||||
Resize,
|
||||
ReverseSequence,
|
||||
RNN,
|
||||
RoiAlign,
|
||||
Round,
|
||||
Scan,
|
||||
Scatter,
|
||||
ScatterElements,
|
||||
ScatterND,
|
||||
Selu,
|
||||
SequenceAt,
|
||||
SequenceConstruct,
|
||||
SequenceEmpty,
|
||||
SequenceErase,
|
||||
SequenceInsert,
|
||||
SequenceLength,
|
||||
SequenceMap,
|
||||
Shape,
|
||||
Shrink,
|
||||
Sigmoid,
|
||||
Sign,
|
||||
Sin,
|
||||
Sinh,
|
||||
Size,
|
||||
Slice,
|
||||
Softmax,
|
||||
SoftmaxCrossEntropyLoss,
|
||||
Softplus,
|
||||
Softsign,
|
||||
SpaceToDepth,
|
||||
Split,
|
||||
SplitToSequence,
|
||||
Sqrt,
|
||||
Squeeze,
|
||||
STFT,
|
||||
StringNormalizer,
|
||||
Sub,
|
||||
Sum,
|
||||
Tan,
|
||||
Tanh,
|
||||
TfIdfVectorizer,
|
||||
ThresholdedRelu,
|
||||
Tile,
|
||||
TopK,
|
||||
Transpose,
|
||||
Trilu,
|
||||
Unique,
|
||||
Unsqueeze,
|
||||
Upsample,
|
||||
Where,
|
||||
Xor,
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
mod coalesce;
|
||||
mod codegen;
|
||||
mod convert;
|
||||
mod ir;
|
||||
mod op_configuration;
|
||||
mod protos;
|
||||
mod shape_inference;
|
||||
|
||||
pub use codegen::*;
|
|
@ -0,0 +1,181 @@
|
|||
use burn::nn::{
|
||||
conv::{Conv2dConfig, Conv2dPaddingConfig},
|
||||
LinearConfig,
|
||||
};
|
||||
|
||||
use super::ir::{ArgType, AttributeValue, Node};
|
||||
|
||||
#[inline(always)]
|
||||
pub fn attr_value_vec_i64(value: &AttributeValue, target: &mut Vec<i64>) {
|
||||
if let AttributeValue::Int64s(val) = value {
|
||||
*target = val.clone();
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn attr_value_i64(value: &AttributeValue, target: &mut i64) {
|
||||
if let AttributeValue::Int64(val) = value {
|
||||
*target = *val;
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a Conv2dConfig from the attributes of the node
|
||||
pub fn conv2d_config(curr: &Node) -> Conv2dConfig {
|
||||
let mut kernel_shape = Vec::new();
|
||||
let mut strides = Vec::new();
|
||||
let mut pads = Vec::new();
|
||||
let mut dilations = Vec::new();
|
||||
let mut group: i64 = 0;
|
||||
|
||||
// extract the channels from the weight tensor's shape [out_channels, in_channels, ...]
|
||||
let ArgType::Tensor(tensor) = curr.initializers.get(0).unwrap().clone().arg_type.unwrap();
|
||||
|
||||
// check if the bias is present
|
||||
let bias = curr.initializers.len() == 2;
|
||||
|
||||
// the channels are inverted in the weight tensor
|
||||
let channels: [usize; 2] = [tensor.shape[1], tensor.shape[0]];
|
||||
|
||||
for (key, value) in curr.attrs.iter() {
|
||||
match key.as_str() {
|
||||
"kernel_shape" => attr_value_vec_i64(value, &mut kernel_shape),
|
||||
"strides" => attr_value_vec_i64(value, &mut strides),
|
||||
"pads" => attr_value_vec_i64(value, &mut pads),
|
||||
"dilations" => attr_value_vec_i64(value, &mut dilations),
|
||||
"group" => attr_value_i64(value, &mut group),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let padding = if pads.iter().all(|&x| x == 0) {
|
||||
Conv2dPaddingConfig::Valid
|
||||
} else {
|
||||
todo!("Conv2d: padding({pads:?}) is not fully supported");
|
||||
};
|
||||
|
||||
if strides.iter().all(|&x| x != 1) {
|
||||
todo!("Conv2d: strides({strides:?}) are not fully supported");
|
||||
};
|
||||
|
||||
if dilations.iter().all(|&x| x != 1) {
|
||||
todo!("Conv2d: dilations({dilations:?}) are not fully supported");
|
||||
};
|
||||
|
||||
if group != 1 {
|
||||
todo!("Conv2d: group ({group}) is not fully supported");
|
||||
};
|
||||
|
||||
Conv2dConfig::new(
|
||||
channels,
|
||||
[kernel_shape[0] as usize, kernel_shape[1] as usize],
|
||||
)
|
||||
.with_bias(bias)
|
||||
.with_padding(padding)
|
||||
}
|
||||
|
||||
/// Create a FlattenConfig from the attributes of the node
|
||||
pub fn flatten_config(curr: &Node) -> (usize, usize) {
|
||||
// the begin dimension is the first dimension (Default: 1 per ONNX spec)
|
||||
let mut start_dim: i64 = 1;
|
||||
|
||||
// check if the node has only one input
|
||||
if curr.inputs.len() != 1 {
|
||||
panic!(
|
||||
"Flatten: multiple inputs are not supported (got {:?})",
|
||||
curr.inputs.len()
|
||||
);
|
||||
}
|
||||
|
||||
// extract the shape of the input tensor
|
||||
let ArgType::Tensor(tensor) = curr.inputs.get(0).unwrap().clone().arg_type.unwrap();
|
||||
|
||||
// check if the input tensor has at least 2 dimensions
|
||||
if tensor.shape.len() < 2 {
|
||||
panic!(
|
||||
"Flatten: input tensor must have at least 2 dimensions (got {:?})",
|
||||
tensor.shape.len()
|
||||
);
|
||||
}
|
||||
|
||||
// the end dimension is the last dimension
|
||||
let end_dim = tensor.shape.len() - 1;
|
||||
|
||||
// extract the attributes
|
||||
for (key, value) in curr.attrs.iter() {
|
||||
match key.as_str() {
|
||||
"axis" => attr_value_i64(value, &mut start_dim),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// if beg_dim is negative, it is counted from the end
|
||||
if start_dim < 0 {
|
||||
start_dim += tensor.shape.len() as i64;
|
||||
}
|
||||
|
||||
(start_dim as usize, end_dim)
|
||||
}
|
||||
|
||||
/// Create a LinearConfig from the attributes of the node
|
||||
pub fn linear_config(node: &Node) -> LinearConfig {
|
||||
// check if the node has only one input
|
||||
if node.inputs.len() != 1 {
|
||||
panic!(
|
||||
"Linear: multiple inputs are not supported (got {:?})",
|
||||
node.inputs.len()
|
||||
);
|
||||
}
|
||||
|
||||
if node.initializers.is_empty() {
|
||||
panic!("Linear: no initializers found");
|
||||
}
|
||||
|
||||
// extract the shape of the weight tensor
|
||||
let ArgType::Tensor(tensor) = node.initializers.get(0).unwrap().clone().arg_type.unwrap();
|
||||
|
||||
// check if the weight tensor has at least 2 dimensions
|
||||
if tensor.shape.len() < 2 {
|
||||
panic!(
|
||||
"Linear: weight tensor must have at least 2 dimensions (got {:?})",
|
||||
tensor.shape.len()
|
||||
);
|
||||
}
|
||||
let (out_size, in_size) = (tensor.shape[0], tensor.shape[1]);
|
||||
|
||||
// check if the bias is present
|
||||
let bias = node.initializers.len() == 2;
|
||||
|
||||
LinearConfig::new(in_size, out_size).with_bias(bias)
|
||||
}
|
||||
|
||||
/// Create log_softmax config from the attributes of the node
|
||||
pub fn log_softmax_config(node: &Node) -> usize {
|
||||
// the axis is the last dimension (Default: 1 per ONNX spec)
|
||||
let mut axis: i64 = -1;
|
||||
|
||||
// check if the node has only one input
|
||||
if node.inputs.len() != 1 {
|
||||
panic!(
|
||||
"LogSoftmax: multiple inputs are not supported (got {:?})",
|
||||
node.inputs.len()
|
||||
);
|
||||
}
|
||||
|
||||
// extract the shape of the input tensor
|
||||
let ArgType::Tensor(tensor) = node.inputs.get(0).unwrap().clone().arg_type.unwrap();
|
||||
|
||||
// extract the attributes
|
||||
for (key, value) in node.attrs.iter() {
|
||||
match key.as_str() {
|
||||
"axis" => attr_value_i64(value, &mut axis),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// if axis is negative, it is counted from the end
|
||||
if axis < 0 {
|
||||
axis += tensor.shape.len() as i64;
|
||||
}
|
||||
|
||||
axis as usize
|
||||
}
|
|
@ -0,0 +1,5 @@
|
|||
mod inner {
|
||||
include!(concat!(env!("OUT_DIR"), "/onnx-protos/mod.rs"));
|
||||
}
|
||||
|
||||
pub use inner::onnx::*;
|
|
@ -0,0 +1,825 @@
|
|||
//
|
||||
// WARNING: This file is automatically generated! Please edit onnx.in.proto.
|
||||
//
|
||||
|
||||
// Copied from https://github.com/onnx/onnx/blob/main/onnx/onnx.proto3
|
||||
// under the following license:
|
||||
|
||||
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package onnx;
|
||||
|
||||
// Overview
|
||||
//
|
||||
// ONNX is an open specification that is comprised of the following components:
|
||||
//
|
||||
// 1) A definition of an extensible computation graph model.
|
||||
// 2) Definitions of standard data types.
|
||||
// 3) Definitions of built-in operators.
|
||||
//
|
||||
// This document describes the syntax of models and their computation graphs,
|
||||
// as well as the standard data types. Together, they are referred to as the ONNX
|
||||
// Intermediate Representation, or 'IR' for short.
|
||||
//
|
||||
// The normative semantic specification of the ONNX IR is found in docs/IR.md.
|
||||
// Definitions of the built-in neural network operators may be found in docs/Operators.md.
|
||||
|
||||
// Notes
|
||||
//
|
||||
// Protobuf compatibility
|
||||
//
|
||||
// To simplify framework compatibility, ONNX is defined using the subset of protobuf
|
||||
// that is compatible with both protobuf v2 and v3. This means that we do not use any
|
||||
// protobuf features that are only available in one of the two versions.
|
||||
//
|
||||
// Here are the most notable contortions we have to carry out to work around
|
||||
// these limitations:
|
||||
//
|
||||
// - No 'map' (added protobuf 3.0). We instead represent mappings as lists
|
||||
// of key-value pairs, where order does not matter and duplicates
|
||||
// are not allowed.
|
||||
|
||||
|
||||
// Versioning
|
||||
//
|
||||
// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md
|
||||
//
|
||||
// To be compatible with both proto2 and proto3, we will use a version number
|
||||
// that is not defined by the default value but an explicit enum number.
|
||||
enum Version {
|
||||
// proto3 requires the first enum value to be zero.
|
||||
// We add this just to appease the compiler.
|
||||
_START_VERSION = 0;
|
||||
// The version field is always serialized and we will use it to store the
|
||||
// version that the graph is generated from. This helps us set up version
|
||||
// control.
|
||||
// For the IR, we are using simple numbers starting with 0x00000001,
|
||||
// which was the version we published on Oct 10, 2017.
|
||||
IR_VERSION_2017_10_10 = 0x0000000000000001;
|
||||
|
||||
// IR_VERSION 2 published on Oct 30, 2017
|
||||
// - Added type discriminator to AttributeProto to support proto3 users
|
||||
IR_VERSION_2017_10_30 = 0x0000000000000002;
|
||||
|
||||
// IR VERSION 3 published on Nov 3, 2017
|
||||
// - For operator versioning:
|
||||
// - Added new message OperatorSetIdProto
|
||||
// - Added opset_import in ModelProto
|
||||
// - For vendor extensions, added domain in NodeProto
|
||||
IR_VERSION_2017_11_3 = 0x0000000000000003;
|
||||
|
||||
// IR VERSION 4 published on Jan 22, 2019
|
||||
// - Relax constraint that initializers should be a subset of graph inputs
|
||||
// - Add type BFLOAT16
|
||||
IR_VERSION_2019_1_22 = 0x0000000000000004;
|
||||
|
||||
// IR VERSION 5 published on March 18, 2019
|
||||
// - Add message TensorAnnotation.
|
||||
// - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters.
|
||||
IR_VERSION_2019_3_18 = 0x0000000000000005;
|
||||
|
||||
// IR VERSION 6 published on Sep 19, 2019
|
||||
// - Add support for sparse tensor constants stored in model.
|
||||
// - Add message SparseTensorProto
|
||||
// - Add sparse initializers
|
||||
IR_VERSION_2019_9_19 = 0x0000000000000006;
|
||||
|
||||
// IR VERSION 7 published on May 8, 2020
|
||||
// - Add support to allow function body graph to rely on multiple external opreator sets.
|
||||
// - Add a list to promote inference graph's initializers to global and
|
||||
// mutable variables. Global variables are visible in all graphs of the
|
||||
// stored models.
|
||||
// - Add message TrainingInfoProto to store initialization
|
||||
// method and training algorithm. The execution of TrainingInfoProto
|
||||
// can modify the values of mutable variables.
|
||||
// - Implicitly add inference graph into each TrainingInfoProto's algorithm.
|
||||
IR_VERSION_2020_5_8 = 0x0000000000000007;
|
||||
|
||||
// IR VERSION 8 published on July 30, 2021
|
||||
// Introduce TypeProto.SparseTensor
|
||||
// Introduce TypeProto.Optional
|
||||
// Added a list of FunctionProtos local to the model
|
||||
// Deprecated since_version and operator status from FunctionProto
|
||||
IR_VERSION_2021_7_30 = 0x0000000000000008;
|
||||
|
||||
// IR VERSION 9 published on TBD
|
||||
// Added AttributeProto to FunctionProto so that default attribute values can be set.
|
||||
IR_VERSION = 0x0000000000000009;
|
||||
}
|
||||
|
||||
// Attributes
|
||||
//
|
||||
// A named attribute containing either singular float, integer, string, graph,
|
||||
// and tensor values, or repeated float, integer, string, graph, and tensor values.
|
||||
// An AttributeProto MUST contain the name field, and *only one* of the
|
||||
// following content fields, effectively enforcing a C/C++ union equivalent.
|
||||
message AttributeProto {
|
||||
|
||||
// Note: this enum is structurally identical to the OpSchema::AttrType
|
||||
// enum defined in schema.h. If you rev one, you likely need to rev the other.
|
||||
enum AttributeType {
|
||||
UNDEFINED = 0;
|
||||
FLOAT = 1;
|
||||
INT = 2;
|
||||
STRING = 3;
|
||||
TENSOR = 4;
|
||||
GRAPH = 5;
|
||||
SPARSE_TENSOR = 11;
|
||||
TYPE_PROTO = 13;
|
||||
|
||||
FLOATS = 6;
|
||||
INTS = 7;
|
||||
STRINGS = 8;
|
||||
TENSORS = 9;
|
||||
GRAPHS = 10;
|
||||
SPARSE_TENSORS = 12;
|
||||
TYPE_PROTOS = 14;
|
||||
}
|
||||
|
||||
// The name field MUST be present for this version of the IR.
|
||||
string name = 1; // namespace Attribute
|
||||
|
||||
// if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
|
||||
// In this case, this AttributeProto does not contain data, and it's a reference of attribute
|
||||
// in parent scope.
|
||||
// NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph.
|
||||
string ref_attr_name = 21;
|
||||
|
||||
// A human-readable documentation for this attribute. Markdown is allowed.
|
||||
string doc_string = 13;
|
||||
|
||||
// The type field MUST be present for this version of the IR.
|
||||
// For 0.0.1 versions of the IR, this field was not defined, and
|
||||
// implementations needed to use has_field heuristics to determine
|
||||
// which value field was in use. For IR_VERSION 0.0.2 or later, this
|
||||
// field MUST be set and match the f|i|s|t|... field in use. This
|
||||
// change was made to accommodate proto3 implementations.
|
||||
AttributeType type = 20; // discriminator that indicates which field below is in use
|
||||
|
||||
// Exactly ONE of the following fields must be present for this version of the IR
|
||||
float f = 2; // float
|
||||
int64 i = 3; // int
|
||||
bytes s = 4; // UTF-8 string
|
||||
TensorProto t = 5; // tensor value
|
||||
GraphProto g = 6; // graph
|
||||
SparseTensorProto sparse_tensor = 22; // sparse tensor value
|
||||
// Do not use field below, it's deprecated.
|
||||
// optional ValueProto v = 12; // value - subsumes everything but graph
|
||||
TypeProto tp = 14; // type proto
|
||||
|
||||
repeated float floats = 7; // list of floats
|
||||
repeated int64 ints = 8; // list of ints
|
||||
repeated bytes strings = 9; // list of UTF-8 strings
|
||||
repeated TensorProto tensors = 10; // list of tensors
|
||||
repeated GraphProto graphs = 11; // list of graph
|
||||
repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors
|
||||
repeated TypeProto type_protos = 15;// list of type protos
|
||||
}
|
||||
|
||||
// Defines information on value, including the name, the type, and
|
||||
// the shape of the value.
|
||||
message ValueInfoProto {
|
||||
// This field MUST be present in this version of the IR.
|
||||
string name = 1; // namespace Value
|
||||
// This field MUST be present in this version of the IR for
|
||||
// inputs and outputs of the top-level graph.
|
||||
TypeProto type = 2;
|
||||
// A human-readable documentation for this value. Markdown is allowed.
|
||||
string doc_string = 3;
|
||||
}
|
||||
|
||||
// Nodes
|
||||
//
|
||||
// Computation graphs are made up of a DAG of nodes, which represent what is
|
||||
// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
|
||||
//
|
||||
// For example, it can be a node of type "Conv" that takes in an image, a filter
|
||||
// tensor and a bias tensor, and produces the convolved output.
|
||||
message NodeProto {
|
||||
repeated string input = 1; // namespace Value
|
||||
repeated string output = 2; // namespace Value
|
||||
|
||||
// An optional identifier for this node in a graph.
|
||||
// This field MAY be absent in ths version of the IR.
|
||||
string name = 3; // namespace Node
|
||||
|
||||
// The symbolic identifier of the Operator to execute.
|
||||
string op_type = 4; // namespace Operator
|
||||
// The domain of the OperatorSet that specifies the operator named by op_type.
|
||||
string domain = 7; // namespace Domain
|
||||
|
||||
// Additional named attributes.
|
||||
repeated AttributeProto attribute = 5;
|
||||
|
||||
// A human-readable documentation for this node. Markdown is allowed.
|
||||
string doc_string = 6;
|
||||
}
|
||||
|
||||
// Training information
|
||||
// TrainingInfoProto stores information for training a model.
|
||||
// In particular, this defines two functionalities: an initialization-step
|
||||
// and a training-algorithm-step. Initialization resets the model
|
||||
// back to its original state as if no training has been performed.
|
||||
// Training algorithm improves the model based on input data.
|
||||
//
|
||||
// The semantics of the initialization-step is that the initializers
|
||||
// in ModelProto.graph and in TrainingInfoProto.algorithm are first
|
||||
// initialized as specified by the initializers in the graph, and then
|
||||
// updated by the "initialization_binding" in every instance in
|
||||
// ModelProto.training_info.
|
||||
//
|
||||
// The field "algorithm" defines a computation graph which represents a
|
||||
// training algorithm's step. After the execution of a
|
||||
// TrainingInfoProto.algorithm, the initializers specified by "update_binding"
|
||||
// may be immediately updated. If the targeted training algorithm contains
|
||||
// consecutive update steps (such as block coordinate descent methods),
|
||||
// the user needs to create a TrainingInfoProto for each step.
|
||||
message TrainingInfoProto {
|
||||
// This field describes a graph to compute the initial tensors
|
||||
// upon starting the training process. Initialization graph has no input
|
||||
// and can have multiple outputs. Usually, trainable tensors in neural
|
||||
// networks are randomly initialized. To achieve that, for each tensor,
|
||||
// the user can put a random number operator such as RandomNormal or
|
||||
// RandomUniform in TrainingInfoProto.initialization.node and assign its
|
||||
// random output to the specific tensor using "initialization_binding".
|
||||
// This graph can also set the initializers in "algorithm" in the same
|
||||
// TrainingInfoProto; a use case is resetting the number of training
|
||||
// iteration to zero.
|
||||
//
|
||||
// By default, this field is an empty graph and its evaluation does not
|
||||
// produce any output. Thus, no initializer would be changed by default.
|
||||
GraphProto initialization = 1;
|
||||
|
||||
// This field represents a training algorithm step. Given required inputs,
|
||||
// it computes outputs to update initializers in its own or inference graph's
|
||||
// initializer lists. In general, this field contains loss node, gradient node,
|
||||
// optimizer node, increment of iteration count.
|
||||
//
|
||||
// An execution of the training algorithm step is performed by executing the
|
||||
// graph obtained by combining the inference graph (namely "ModelProto.graph")
|
||||
// and the "algorithm" graph. That is, the actual the actual
|
||||
// input/initializer/output/node/value_info/sparse_initializer list of
|
||||
// the training graph is the concatenation of
|
||||
// "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer"
|
||||
// and "algorithm.input/initializer/output/node/value_info/sparse_initializer"
|
||||
// in that order. This combined graph must satisfy the normal ONNX conditions.
|
||||
// Now, let's provide a visualization of graph combination for clarity.
|
||||
// Let the inference graph (i.e., "ModelProto.graph") be
|
||||
// tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d
|
||||
// and the "algorithm" graph be
|
||||
// tensor_d -> Add -> tensor_e
|
||||
// The combination process results
|
||||
// tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e
|
||||
//
|
||||
// Notice that an input of a node in the "algorithm" graph may reference the
|
||||
// output of a node in the inference graph (but not the other way round). Also, inference
|
||||
// node cannot reference inputs of "algorithm". With these restrictions, inference graph
|
||||
// can always be run independently without training information.
|
||||
//
|
||||
// By default, this field is an empty graph and its evaluation does not
|
||||
// produce any output. Evaluating the default training step never
|
||||
// update any initializers.
|
||||
GraphProto algorithm = 2;
|
||||
|
||||
// This field specifies the bindings from the outputs of "initialization" to
|
||||
// some initializers in "ModelProto.graph.initializer" and
|
||||
// the "algorithm.initializer" in the same TrainingInfoProto.
|
||||
// See "update_binding" below for details.
|
||||
//
|
||||
// By default, this field is empty and no initializer would be changed
|
||||
// by the execution of "initialization".
|
||||
repeated StringStringEntryProto initialization_binding = 3;
|
||||
|
||||
// Gradient-based training is usually an iterative procedure. In one gradient
|
||||
// descent iteration, we apply
|
||||
//
|
||||
// x = x - r * g
|
||||
//
|
||||
// where "x" is the optimized tensor, "r" stands for learning rate, and "g" is
|
||||
// gradient of "x" with respect to a chosen loss. To avoid adding assignments
|
||||
// into the training graph, we split the update equation into
|
||||
//
|
||||
// y = x - r * g
|
||||
// x = y
|
||||
//
|
||||
// The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To
|
||||
// tell that "y" should be assigned to "x", the field "update_binding" may
|
||||
// contain a key-value pair of strings, "x" (key of StringStringEntryProto)
|
||||
// and "y" (value of StringStringEntryProto).
|
||||
// For a neural network with multiple trainable (mutable) tensors, there can
|
||||
// be multiple key-value pairs in "update_binding".
|
||||
//
|
||||
// The initializers appears as keys in "update_binding" are considered
|
||||
// mutable variables. This implies some behaviors
|
||||
// as described below.
|
||||
//
|
||||
// 1. We have only unique keys in all "update_binding"s so that two
|
||||
// variables may not have the same name. This ensures that one
|
||||
// variable is assigned up to once.
|
||||
// 2. The keys must appear in names of "ModelProto.graph.initializer" or
|
||||
// "TrainingInfoProto.algorithm.initializer".
|
||||
// 3. The values must be output names of "algorithm" or "ModelProto.graph.output".
|
||||
// 4. Mutable variables are initialized to the value specified by the
|
||||
// corresponding initializer, and then potentially updated by
|
||||
// "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s.
|
||||
//
|
||||
// This field usually contains names of trainable tensors
|
||||
// (in ModelProto.graph), optimizer states such as momentums in advanced
|
||||
// stochastic gradient methods (in TrainingInfoProto.graph),
|
||||
// and number of training iterations (in TrainingInfoProto.graph).
|
||||
//
|
||||
// By default, this field is empty and no initializer would be changed
|
||||
// by the execution of "algorithm".
|
||||
repeated StringStringEntryProto update_binding = 4;
|
||||
}
|
||||
|
||||
// Models
|
||||
//
|
||||
// ModelProto is a top-level file/container format for bundling a ML model and
|
||||
// associating its computation graph with metadata.
|
||||
//
|
||||
// The semantics of the model are described by the associated GraphProto's.
|
||||
message ModelProto {
|
||||
// The version of the IR this model targets. See Version enum above.
|
||||
// This field MUST be present.
|
||||
int64 ir_version = 1;
|
||||
|
||||
// The OperatorSets this model relies on.
|
||||
// All ModelProtos MUST have at least one entry that
|
||||
// specifies which version of the ONNX OperatorSet is
|
||||
// being imported.
|
||||
//
|
||||
// All nodes in the ModelProto's graph will bind against the operator
|
||||
// with the same-domain/same-op_type operator with the HIGHEST version
|
||||
// in the referenced operator sets.
|
||||
repeated OperatorSetIdProto opset_import = 8;
|
||||
|
||||
// The name of the framework or tool used to generate this model.
|
||||
// This field SHOULD be present to indicate which implementation/tool/framework
|
||||
// emitted the model.
|
||||
string producer_name = 2;
|
||||
|
||||
// The version of the framework or tool used to generate this model.
|
||||
// This field SHOULD be present to indicate which implementation/tool/framework
|
||||
// emitted the model.
|
||||
string producer_version = 3;
|
||||
|
||||
// Domain name of the model.
|
||||
// We use reverse domain names as name space indicators. For example:
|
||||
// `com.facebook.fair` or `com.microsoft.cognitiveservices`
|
||||
//
|
||||
// Together with `model_version` and GraphProto.name, this forms the unique identity of
|
||||
// the graph.
|
||||
string domain = 4;
|
||||
|
||||
// The version of the graph encoded. See Version enum below.
|
||||
int64 model_version = 5;
|
||||
|
||||
// A human-readable documentation for this model. Markdown is allowed.
|
||||
string doc_string = 6;
|
||||
|
||||
// The parameterized graph that is evaluated to execute the model.
|
||||
GraphProto graph = 7;
|
||||
|
||||
// Named metadata values; keys should be distinct.
|
||||
repeated StringStringEntryProto metadata_props = 14;
|
||||
|
||||
// Training-specific information. Sequentially executing all stored
|
||||
// `TrainingInfoProto.algorithm`s and assigning their outputs following
|
||||
// the corresponding `TrainingInfoProto.update_binding`s is one training
|
||||
// iteration. Similarly, to initialize the model
|
||||
// (as if training hasn't happened), the user should sequentially execute
|
||||
// all stored `TrainingInfoProto.initialization`s and assigns their outputs
|
||||
// using `TrainingInfoProto.initialization_binding`s.
|
||||
//
|
||||
// If this field is empty, the training behavior of the model is undefined.
|
||||
repeated TrainingInfoProto training_info = 20;
|
||||
|
||||
// A list of function protos local to the model.
|
||||
//
|
||||
// Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain".
|
||||
// In case of any conflicts the behavior (whether the model local functions are given higher priority,
|
||||
// or standard opserator sets are given higher priotity or this is treated as error) is defined by
|
||||
// the runtimes.
|
||||
//
|
||||
// The operator sets imported by FunctionProto should be compatible with the ones
|
||||
// imported by ModelProto and other model local FunctionProtos.
|
||||
// Example, if same operator set say 'A' is imported by a FunctionProto and ModelProto
|
||||
// or by 2 FunctionProtos then versions for the operator set may be different but,
|
||||
// the operator schema returned for op_type, domain, version combination
|
||||
// for both the versions should be same for every node in the function body.
|
||||
//
|
||||
// One FunctionProto can reference other FunctionProto in the model, however, recursive reference
|
||||
// is not allowed.
|
||||
repeated FunctionProto functions = 25;
|
||||
};
|
||||
|
||||
// StringStringEntryProto follows the pattern for cross-proto-version maps.
|
||||
// See https://developers.google.com/protocol-buffers/docs/proto3#maps
|
||||
message StringStringEntryProto {
|
||||
string key = 1;
|
||||
string value = 2;
|
||||
};
|
||||
|
||||
message TensorAnnotation {
|
||||
string tensor_name = 1;
|
||||
// <key, value> pairs to annotate tensor specified by <tensor_name> above.
|
||||
// The keys used in the mapping below must be pre-defined in ONNX spec.
|
||||
// For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as
|
||||
// quantization parameter keys.
|
||||
repeated StringStringEntryProto quant_parameter_tensor_names = 2;
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Graphs
|
||||
//
|
||||
// A graph defines the computational logic of a model and is comprised of a parameterized
|
||||
// list of nodes that form a directed acyclic graph based on their inputs and outputs.
|
||||
// This is the equivalent of the "network" or "graph" in many deep learning
|
||||
// frameworks.
|
||||
message GraphProto {
|
||||
// The nodes in the graph, sorted topologically.
|
||||
repeated NodeProto node = 1;
|
||||
|
||||
// The name of the graph.
|
||||
string name = 2; // namespace Graph
|
||||
|
||||
// A list of named tensor values, used to specify constant inputs of the graph.
|
||||
// Each initializer (both TensorProto as well SparseTensorProto) MUST have a name.
|
||||
// The name MUST be unique across both initializer and sparse_initializer,
|
||||
// but the name MAY also appear in the input list.
|
||||
repeated TensorProto initializer = 5;
|
||||
|
||||
// Initializers (see above) stored in sparse format.
|
||||
repeated SparseTensorProto sparse_initializer = 15;
|
||||
|
||||
// A human-readable documentation for this graph. Markdown is allowed.
|
||||
string doc_string = 10;
|
||||
|
||||
// The inputs and outputs of the graph.
|
||||
repeated ValueInfoProto input = 11;
|
||||
repeated ValueInfoProto output = 12;
|
||||
|
||||
// Information for the values in the graph. The ValueInfoProto.name's
|
||||
// must be distinct. It is optional for a value to appear in value_info list.
|
||||
repeated ValueInfoProto value_info = 13;
|
||||
|
||||
// This field carries information to indicate the mapping among a tensor and its
|
||||
// quantization parameter tensors. For example:
|
||||
// For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated,
|
||||
// which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
|
||||
repeated TensorAnnotation quantization_annotation = 14;
|
||||
|
||||
reserved 3, 4, 6 to 9;
|
||||
reserved "ir_version", "producer_version", "producer_tag", "domain";
|
||||
}
|
||||
|
||||
// Tensors
|
||||
//
|
||||
// A serialized tensor value.
|
||||
message TensorProto {
|
||||
enum DataType {
|
||||
UNDEFINED = 0;
|
||||
// Basic types.
|
||||
FLOAT = 1; // float
|
||||
UINT8 = 2; // uint8_t
|
||||
INT8 = 3; // int8_t
|
||||
UINT16 = 4; // uint16_t
|
||||
INT16 = 5; // int16_t
|
||||
INT32 = 6; // int32_t
|
||||
INT64 = 7; // int64_t
|
||||
STRING = 8; // string
|
||||
BOOL = 9; // bool
|
||||
|
||||
// IEEE754 half-precision floating-point format (16 bits wide).
|
||||
// This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
|
||||
FLOAT16 = 10;
|
||||
|
||||
DOUBLE = 11;
|
||||
UINT32 = 12;
|
||||
UINT64 = 13;
|
||||
COMPLEX64 = 14; // complex with float32 real and imaginary components
|
||||
COMPLEX128 = 15; // complex with float64 real and imaginary components
|
||||
|
||||
// Non-IEEE floating-point format based on IEEE754 single-precision
|
||||
// floating-point number truncated to 16 bits.
|
||||
// This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
|
||||
BFLOAT16 = 16;
|
||||
|
||||
// Future extensions go here.
|
||||
}
|
||||
|
||||
// The shape of the tensor.
|
||||
repeated int64 dims = 1;
|
||||
|
||||
// The data type of the tensor.
|
||||
// This field MUST have a valid TensorProto.DataType value
|
||||
int32 data_type = 2;
|
||||
|
||||
// For very large tensors, we may want to store them in chunks, in which
|
||||
// case the following fields will specify the segment that is stored in
|
||||
// the current TensorProto.
|
||||
message Segment {
|
||||
int64 begin = 1;
|
||||
int64 end = 2;
|
||||
}
|
||||
Segment segment = 3;
|
||||
|
||||
// Tensor content must be organized in row-major order.
|
||||
//
|
||||
// Depending on the data_type field, exactly one of the fields below with
|
||||
// name ending in _data is used to store the elements of the tensor.
|
||||
|
||||
// For float and complex64 values
|
||||
// Complex64 tensors are encoded as a single array of floats,
|
||||
// with the real components appearing in odd numbered positions,
|
||||
// and the corresponding imaginary component appearing in the
|
||||
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
|
||||
// is encoded as [1.0, 2.0 ,3.0 ,4.0]
|
||||
// When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
|
||||
repeated float float_data = 4 [packed = true];
|
||||
|
||||
// For int32, uint8, int8, uint16, int16, bool, and float16 values
|
||||
// float16 values must be bit-wise converted to an uint16_t prior
|
||||
// to writing to the buffer.
|
||||
// When this field is present, the data_type field MUST be
|
||||
// INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16 or BFLOAT16
|
||||
repeated int32 int32_data = 5 [packed = true];
|
||||
|
||||
// For strings.
|
||||
// Each element of string_data is a UTF-8 encoded Unicode
|
||||
// string. No trailing null, no leading BOM. The protobuf "string"
|
||||
// scalar type is not used to match ML community conventions.
|
||||
// When this field is present, the data_type field MUST be STRING
|
||||
repeated bytes string_data = 6;
|
||||
|
||||
// For int64.
|
||||
// When this field is present, the data_type field MUST be INT64
|
||||
repeated int64 int64_data = 7 [packed = true];
|
||||
|
||||
// Optionally, a name for the tensor.
|
||||
string name = 8; // namespace Value
|
||||
|
||||
// A human-readable documentation for this tensor. Markdown is allowed.
|
||||
string doc_string = 12;
|
||||
|
||||
// Serializations can either use one of the fields above, or use this
|
||||
// raw bytes field. The only exception is the string case, where one is
|
||||
// required to store the content in the repeated bytes string_data field.
|
||||
//
|
||||
// When this raw_data field is used to store tensor value, elements MUST
|
||||
// be stored in as fixed-width, little-endian order.
|
||||
// Floating-point data types MUST be stored in IEEE 754 format.
|
||||
// Complex64 elements must be written as two consecutive FLOAT values, real component first.
|
||||
// Complex128 elements must be written as two consecutive DOUBLE values, real component first.
|
||||
// Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
|
||||
//
|
||||
// Note: the advantage of specific field rather than the raw_data field is
|
||||
// that in some cases (e.g. int data), protobuf does a better packing via
|
||||
// variable length storage, and may lead to smaller binary footprint.
|
||||
// When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
|
||||
bytes raw_data = 9;
|
||||
|
||||
// Data can be stored inside the protobuf file using type-specific fields or raw_data.
|
||||
// Alternatively, raw bytes data can be stored in an external file, using the external_data field.
|
||||
// external_data stores key-value pairs describing data location. Recognized keys are:
|
||||
// - "location" (required) - POSIX filesystem path relative to the directory where the ONNX
|
||||
// protobuf model was stored
|
||||
// - "offset" (optional) - position of byte at which stored data begins. Integer stored as string.
|
||||
// Offset values SHOULD be multiples 4096 (page size) to enable mmap support.
|
||||
// - "length" (optional) - number of bytes containing data. Integer stored as string.
|
||||
// - "checksum" (optional) - SHA1 digest of file specified in under 'location' key.
|
||||
repeated StringStringEntryProto external_data = 13;
|
||||
|
||||
// Location of the data for this tensor. MUST be one of:
|
||||
// - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field.
|
||||
// - EXTERNAL - data stored in an external location as described by external_data field.
|
||||
enum DataLocation {
|
||||
DEFAULT = 0;
|
||||
EXTERNAL = 1;
|
||||
}
|
||||
|
||||
// If value not set, data is stored in raw_data (if set) otherwise in type-specified field.
|
||||
DataLocation data_location = 14;
|
||||
|
||||
// For double
|
||||
// Complex128 tensors are encoded as a single array of doubles,
|
||||
// with the real components appearing in odd numbered positions,
|
||||
// and the corresponding imaginary component appearing in the
|
||||
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
|
||||
// is encoded as [1.0, 2.0 ,3.0 ,4.0]
|
||||
// When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
|
||||
repeated double double_data = 10 [packed = true];
|
||||
|
||||
// For uint64 and uint32 values
|
||||
// When this field is present, the data_type field MUST be
|
||||
// UINT32 or UINT64
|
||||
repeated uint64 uint64_data = 11 [packed = true];
|
||||
}
|
||||
|
||||
// A serialized sparse-tensor value
|
||||
message SparseTensorProto {
|
||||
// The sequence of non-default values are encoded as a tensor of shape [NNZ].
|
||||
// The default-value is zero for numeric tensors, and empty-string for string tensors.
|
||||
// values must have a non-empty name present which serves as a name for SparseTensorProto
|
||||
// when used in sparse_initializer list.
|
||||
TensorProto values = 1;
|
||||
|
||||
// The indices of the non-default values, which may be stored in one of two formats.
|
||||
// (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value
|
||||
// corresponding to the j-th index of the i-th value (in the values tensor).
|
||||
// (b) Indices can be a tensor of shape [NNZ], in which case the i-th value
|
||||
// must be the linearized-index of the i-th value (in the values tensor).
|
||||
// The linearized-index can be converted into an index tuple (k_1,...,k_rank)
|
||||
// using the shape provided below.
|
||||
// The indices must appear in ascending order without duplication.
|
||||
// In the first format, the ordering is lexicographic-ordering:
|
||||
// e.g., index-value [1,4] must appear before [2,1]
|
||||
TensorProto indices = 2;
|
||||
|
||||
// The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank]
|
||||
repeated int64 dims = 3;
|
||||
}
|
||||
|
||||
// Defines a tensor shape. A dimension can be either an integer value
|
||||
// or a symbolic variable. A symbolic variable represents an unknown
|
||||
// dimension.
|
||||
message TensorShapeProto {
|
||||
message Dimension {
|
||||
oneof value {
|
||||
int64 dim_value = 1;
|
||||
string dim_param = 2; // namespace Shape
|
||||
};
|
||||
// Standard denotation can optionally be used to denote tensor
|
||||
// dimensions with standard semantic descriptions to ensure
|
||||
// that operations are applied to the correct axis of a tensor.
|
||||
// Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
|
||||
// for pre-defined dimension denotations.
|
||||
string denotation = 3;
|
||||
};
|
||||
repeated Dimension dim = 1;
|
||||
}
|
||||
|
||||
// Types
|
||||
//
|
||||
// The standard ONNX data types.
|
||||
message TypeProto {
|
||||
|
||||
message Tensor {
|
||||
// This field MUST NOT have the value of UNDEFINED
|
||||
// This field MUST have a valid TensorProto.DataType value
|
||||
// This field MUST be present for this version of the IR.
|
||||
int32 elem_type = 1;
|
||||
TensorShapeProto shape = 2;
|
||||
}
|
||||
|
||||
// repeated T
|
||||
message Sequence {
|
||||
// The type and optional shape of each element of the sequence.
|
||||
// This field MUST be present for this version of the IR.
|
||||
TypeProto elem_type = 1;
|
||||
};
|
||||
|
||||
// map<K,V>
|
||||
message Map {
|
||||
// This field MUST have a valid TensorProto.DataType value
|
||||
// This field MUST be present for this version of the IR.
|
||||
// This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING
|
||||
int32 key_type = 1;
|
||||
// This field MUST be present for this version of the IR.
|
||||
TypeProto value_type = 2;
|
||||
};
|
||||
|
||||
// wrapper for Tensor, Sequence, or Map
|
||||
message Optional {
|
||||
// The type and optional shape of the element wrapped.
|
||||
// This field MUST be present for this version of the IR.
|
||||
// Possible values correspond to OptionalProto.DataType enum
|
||||
TypeProto elem_type = 1;
|
||||
};
|
||||
|
||||
|
||||
message SparseTensor {
|
||||
// This field MUST NOT have the value of UNDEFINED
|
||||
// This field MUST have a valid TensorProto.DataType value
|
||||
// This field MUST be present for this version of the IR.
|
||||
int32 elem_type = 1;
|
||||
TensorShapeProto shape = 2;
|
||||
}
|
||||
|
||||
|
||||
oneof value {
|
||||
// The type of a tensor.
|
||||
Tensor tensor_type = 1;
|
||||
|
||||
// NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values
|
||||
// as input and output to graphs and nodes. These types are needed to naturally
|
||||
// support classical ML operators. DNN operators SHOULD restrict their input
|
||||
// and output types to tensors.
|
||||
|
||||
// The type of a sequence.
|
||||
Sequence sequence_type = 4;
|
||||
|
||||
// The type of a map.
|
||||
Map map_type = 5;
|
||||
|
||||
// The type of an optional.
|
||||
Optional optional_type = 9;
|
||||
|
||||
|
||||
// Type of the sparse tensor
|
||||
SparseTensor sparse_tensor_type = 8;
|
||||
|
||||
}
|
||||
|
||||
// An optional denotation can be used to denote the whole
|
||||
// type with a standard semantic description as to what is
|
||||
// stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition
|
||||
// for pre-defined type denotations.
|
||||
string denotation = 6;
|
||||
}
|
||||
|
||||
// Operator Sets
|
||||
//
|
||||
// OperatorSets are uniquely identified by a (domain, opset_version) pair.
|
||||
message OperatorSetIdProto {
|
||||
// The domain of the operator set being identified.
|
||||
// The empty string ("") or absence of this field implies the operator
|
||||
// set that is defined as part of the ONNX specification.
|
||||
// This field MUST be present in this version of the IR when referring to any other operator set.
|
||||
string domain = 1;
|
||||
|
||||
// The version of the operator set being identified.
|
||||
// This field MUST be present in this version of the IR.
|
||||
int64 version = 2;
|
||||
}
|
||||
|
||||
// Operator/function status.
|
||||
enum OperatorStatus {
|
||||
EXPERIMENTAL = 0;
|
||||
STABLE = 1;
|
||||
}
|
||||
|
||||
message FunctionProto {
|
||||
// The name of the function, similar usage of op_type in OperatorProto.
|
||||
// Combined with FunctionProto.domain, this forms the unique identity of
|
||||
// the FunctionProto.
|
||||
string name = 1;
|
||||
|
||||
// Deprecated since IR Version 8
|
||||
// optional int64 since_version = 2;
|
||||
reserved 2;
|
||||
reserved "since_version";
|
||||
|
||||
// Deprecated since IR Version 8
|
||||
// optional OperatorStatus status = 3;
|
||||
reserved 3;
|
||||
reserved "status";
|
||||
|
||||
// The inputs and outputs of the function.
|
||||
repeated string input = 4;
|
||||
repeated string output = 5;
|
||||
|
||||
// The attribute parameters of the function.
|
||||
// It is for function parameters without default values.
|
||||
repeated string attribute = 6;
|
||||
|
||||
// The attribute protos of the function.
|
||||
// It is for function attributes with default values.
|
||||
// A function attribute shall be represented either as
|
||||
// a string attribute or an AttributeProto, not both.
|
||||
repeated AttributeProto attribute_proto = 11;
|
||||
|
||||
// The nodes in the function.
|
||||
repeated NodeProto node = 7;
|
||||
// A human-readable documentation for this function. Markdown is allowed.
|
||||
string doc_string = 8;
|
||||
|
||||
// The OperatorSets this function body (graph) relies on.
|
||||
//
|
||||
// All nodes in the function body (graph) will bind against the operator
|
||||
// with the same-domain/same-op_type operator with the HIGHEST version
|
||||
// in the referenced operator sets. This means at most one version can be relied
|
||||
// for one domain.
|
||||
//
|
||||
// The operator sets imported by FunctionProto should be compatible with the ones
|
||||
// imported by ModelProto. Example, if same operator set say 'A' is imported by FunctionProto
|
||||
// and ModelProto then versions for the operator set may be different but,
|
||||
// the operator schema returned for op_type, domain, version combination
|
||||
// for both the versions should be same.
|
||||
|
||||
repeated OperatorSetIdProto opset_import = 9;
|
||||
|
||||
// The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
|
||||
// the FunctionProto.
|
||||
string domain = 10;
|
||||
}
|
||||
|
||||
|
||||
// For using protobuf-lite
|
||||
option optimize_for = LITE_RUNTIME;
|
|
@ -0,0 +1,187 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use burn::tensor;
|
||||
|
||||
use burn_ndarray::NdArrayBackend;
|
||||
|
||||
use super::{
|
||||
ir::{ArgType, Argument, Node, NodeType, Tensor},
|
||||
op_configuration::{conv2d_config, flatten_config, linear_config},
|
||||
};
|
||||
|
||||
/// Infer the shape of each node and replace the shape of the output tensor
|
||||
pub fn shape_inference(
|
||||
nodes: &mut Vec<Node>,
|
||||
graph_inputs: &Vec<Argument>,
|
||||
graph_outputs: &mut Vec<Argument>,
|
||||
) {
|
||||
let mut prev_outputs: HashMap<String, Argument> = HashMap::new();
|
||||
|
||||
for output in graph_inputs.iter() {
|
||||
prev_outputs.insert(output.name.clone(), output.clone());
|
||||
}
|
||||
|
||||
for node in nodes.iter_mut() {
|
||||
match node.node_type {
|
||||
NodeType::Conv2d => conv2d(node, &prev_outputs),
|
||||
NodeType::Linear => linear(node, &prev_outputs),
|
||||
NodeType::Relu => relu(node, &prev_outputs),
|
||||
NodeType::Flatten => flatten(node, &prev_outputs),
|
||||
NodeType::LogSoftmax => log_softmax(node, &prev_outputs),
|
||||
_ => todo!(
|
||||
"shape inference for {:?} is not implemented",
|
||||
node.node_type
|
||||
),
|
||||
}
|
||||
|
||||
for output in node.outputs.iter() {
|
||||
prev_outputs.insert(output.name.clone(), output.clone());
|
||||
}
|
||||
}
|
||||
|
||||
//update the outputs of the graph from prev_outputs
|
||||
for output in graph_outputs.iter_mut() {
|
||||
let arg = prev_outputs.get(output.name.as_str()).unwrap();
|
||||
output.arg_type = arg.arg_type.clone();
|
||||
}
|
||||
}
|
||||
|
||||
/// Infer the shape of the output tensor of a Conv2d node
|
||||
fn linear(curr: &mut Node, prev_outpus: &HashMap<String, Argument>) {
|
||||
if curr.inputs.len() != 1 {
|
||||
panic!("Linear: multiple inputs are not supported");
|
||||
}
|
||||
|
||||
// Fill in the missing information about the input tensor from the previous outputs
|
||||
let prev_node_output = prev_outpus.get(curr.inputs[0].name.as_str()).unwrap();
|
||||
curr.inputs[0].arg_type = prev_node_output.arg_type.clone();
|
||||
|
||||
// Extract the configuration of the linear layer (inputs are known)
|
||||
let config = linear_config(curr);
|
||||
|
||||
// Replace the output tensor
|
||||
let curr_input = &mut curr.inputs[0];
|
||||
let ArgType::Tensor(tensor) = curr_input.clone().arg_type.unwrap();
|
||||
let mut new_shape = tensor.shape.clone();
|
||||
// Update the last dimension of the shape
|
||||
new_shape[tensor.shape.len() - 1] = config.d_input;
|
||||
|
||||
// Update the output tensor
|
||||
curr.outputs[0].arg_type = Some(ArgType::Tensor(Tensor {
|
||||
name: None,
|
||||
shape: new_shape,
|
||||
data: None,
|
||||
elem_type: tensor.elem_type,
|
||||
}));
|
||||
}
|
||||
|
||||
/// Infers the shape of a Relu node and replaces the shape of the output tensor.
|
||||
fn relu(curr: &mut Node, prev_outpus: &HashMap<String, Argument>) {
|
||||
if curr.inputs.len() != 1 {
|
||||
panic!("Relu: multiple inputs are not supported");
|
||||
}
|
||||
|
||||
// Fill in the missing information about the input tensor from the previous outputs
|
||||
let prev_node_output = prev_outpus.get(curr.inputs[0].name.as_str()).unwrap();
|
||||
|
||||
curr.inputs[0].arg_type = prev_node_output.arg_type.clone();
|
||||
|
||||
curr.outputs[0].arg_type = prev_node_output.arg_type.clone();
|
||||
}
|
||||
|
||||
/// Infers the shape of a Flatten node and replaces the shape of the output tensor.
|
||||
fn flatten(curr: &mut Node, prev_outpus: &HashMap<String, Argument>) {
|
||||
if curr.inputs.len() != 1 {
|
||||
panic!("Flatten: multiple inputs are not supported");
|
||||
}
|
||||
|
||||
// Fill in the missing information about the input tensor from the previous outputs
|
||||
let prev_node_output = prev_outpus.get(curr.inputs[0].name.as_str()).unwrap();
|
||||
|
||||
curr.inputs[0].arg_type = prev_node_output.arg_type.clone();
|
||||
|
||||
let curr_input = &mut curr.inputs[0];
|
||||
|
||||
let ArgType::Tensor(tensor) = curr_input.clone().arg_type.unwrap();
|
||||
|
||||
let input_shape = tensor.shape;
|
||||
|
||||
let (start_dim, end_dim) = flatten_config(curr);
|
||||
|
||||
// calculate the new shape (code is taken from the flatten op)
|
||||
// use the same logic as in the flatten op
|
||||
// unfortunately the output tensor's dimensions (D2) are not known at compile time
|
||||
// that's why we have to calculate the new shape at runtime
|
||||
let mut new_dims = vec![0; input_shape.len() - (end_dim - start_dim)];
|
||||
let mut flatten_dims = 1;
|
||||
for i in input_shape[start_dim..=end_dim].iter() {
|
||||
flatten_dims *= i;
|
||||
}
|
||||
new_dims[..start_dim].copy_from_slice(&input_shape[..start_dim]);
|
||||
new_dims[start_dim] = flatten_dims;
|
||||
new_dims[start_dim + 1..].copy_from_slice(&input_shape[end_dim + 1..]);
|
||||
|
||||
curr.outputs[0].arg_type = Some(ArgType::Tensor(Tensor {
|
||||
name: None,
|
||||
shape: new_dims,
|
||||
data: None,
|
||||
elem_type: tensor.elem_type,
|
||||
}));
|
||||
}
|
||||
|
||||
/// Infers the shape of a LogSoftmax node and replaces the shape of the output tensor.
|
||||
fn log_softmax(curr: &mut Node, prev_outpus: &HashMap<String, Argument>) {
|
||||
if curr.inputs.len() != 1 {
|
||||
panic!("LogSoftmax: multiple inputs are not supported");
|
||||
}
|
||||
|
||||
// Fill in the missing information about the input tensor from the previous outputs
|
||||
let prev_node_output = prev_outpus.get(curr.inputs[0].name.as_str()).unwrap();
|
||||
curr.inputs[0].arg_type = prev_node_output.arg_type.clone();
|
||||
curr.outputs[0].arg_type = prev_node_output.arg_type.clone();
|
||||
}
|
||||
|
||||
/// Infers the shape of a Conv2d node and replaces the shape of the output tensor.
|
||||
///
|
||||
/// The shape of the output tensor is calculated by running the actual convolution operation.
|
||||
fn conv2d(curr: &mut Node, prev_outpus: &HashMap<String, Argument>) {
|
||||
// copy the type from the previous output to the current input
|
||||
|
||||
if curr.inputs.len() != 1 {
|
||||
panic!("Conv2d: multiple inputs are not supported");
|
||||
}
|
||||
|
||||
// Fill in the missing information about the input tensor from the previous outputs
|
||||
let curr_input = &mut curr.inputs[0];
|
||||
let prev = prev_outpus.get(curr_input.name.as_str()).unwrap();
|
||||
curr_input.arg_type = prev.arg_type.clone();
|
||||
|
||||
// extract the channels from the weight tensor's shape [out_channels, in_channels, ...]
|
||||
let ArgType::Tensor(tensor) = curr_input.clone().arg_type.unwrap();
|
||||
|
||||
let elem_type = tensor.elem_type;
|
||||
|
||||
if tensor.shape.len() != 4 {
|
||||
panic!("Conv2d: input tensor must be 4D");
|
||||
}
|
||||
|
||||
let mut input_shape: [usize; 4] = [0; 4];
|
||||
input_shape.copy_from_slice(tensor.shape.as_slice());
|
||||
|
||||
// using the real configuration, run through op and calculate an actual shape of the output tensor
|
||||
let config = conv2d_config(curr);
|
||||
|
||||
let conv2d = config.init();
|
||||
|
||||
let input = tensor::Tensor::<NdArrayBackend<f32>, 4>::zeros(input_shape);
|
||||
let output = conv2d.forward(input);
|
||||
|
||||
let output_shape = output.shape().dims.to_vec();
|
||||
|
||||
curr.outputs[0].arg_type = Some(ArgType::Tensor(Tensor {
|
||||
name: None,
|
||||
shape: output_shape,
|
||||
data: None,
|
||||
elem_type,
|
||||
}));
|
||||
}
|
Binary file not shown.
|
@ -12,6 +12,6 @@ edition = "2021"
|
|||
proc-macro = true
|
||||
|
||||
[dependencies]
|
||||
syn = "1.0.109"
|
||||
quote = "1.0.26"
|
||||
proc-macro2 = "1.0.52"
|
||||
syn = {workspace = true}
|
||||
quote = {workspace = true}
|
||||
proc-macro2 = {workspace = true}
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
[package]
|
||||
authors = ["Dilshod Tadjibaev (@antimora)"]
|
||||
edition = "2021"
|
||||
license = "MIT/Apache-2.0"
|
||||
name = "onnx-inference"
|
||||
publish = false
|
||||
version = "0.6.0"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
|
||||
[dependencies]
|
||||
burn = {path = "../../burn"}
|
||||
burn-ndarray = {path = "../../burn-ndarray"}
|
||||
serde = {workspace = true}
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
burn-dataset = {path = "../../burn-dataset"}
|
||||
|
||||
[build-dependencies]
|
||||
burn-import = {path = "../../burn-import"}
|
|
@ -0,0 +1,4 @@
|
|||
# ONNX Inference
|
||||
|
||||
This crate provides a simple example for importing ONNX model to Burn.
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
use burn_import::onnx::ModelCodeGen;
|
||||
|
||||
fn main() {
|
||||
// Generate the model code from the ONNX file.
|
||||
ModelCodeGen::new()
|
||||
.input("src/model/mnist.onnx")
|
||||
.out_dir("model/")
|
||||
.run_from_script();
|
||||
}
|
|
@ -0,0 +1,155 @@
|
|||
# Originally copied and modified from: https://github.com/pytorch/examples/blob/main/mnist/main.py
|
||||
# under the following license: BSD-3-Clause license
|
||||
|
||||
from __future__ import print_function
|
||||
import argparse
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torchvision import datasets, transforms
|
||||
from torch.optim.lr_scheduler import StepLR
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.conv1 = nn.Conv2d(1, 8, 3)
|
||||
self.conv2 = nn.Conv2d(8, 16, 3)
|
||||
self.conv3 = nn.Conv2d(16, 24, 3)
|
||||
self.dropout1 = nn.Dropout(0.3)
|
||||
self.fc1 = nn.Linear(24 * 22 * 22, 32)
|
||||
self.fc2 = nn.Linear(32, 10)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = F.relu(x)
|
||||
x = self.conv2(x)
|
||||
x = F.relu(x)
|
||||
x = self.conv3(x)
|
||||
x = F.relu(x)
|
||||
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.fc1(x)
|
||||
x = F.relu(x)
|
||||
x = self.dropout1(x)
|
||||
|
||||
x = self.fc2(x)
|
||||
output = F.log_softmax(x, dim=1)
|
||||
return output
|
||||
|
||||
|
||||
def train(args, model, device, train_loader, optimizer, epoch):
|
||||
model.train()
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
data, target = data.to(device), target.to(device)
|
||||
optimizer.zero_grad()
|
||||
output = model(data)
|
||||
loss = F.nll_loss(output, target)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
if batch_idx % args.log_interval == 0:
|
||||
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
||||
epoch, batch_idx * len(data), len(train_loader.dataset),
|
||||
100. * batch_idx / len(train_loader), loss.item()))
|
||||
if args.dry_run:
|
||||
break
|
||||
|
||||
|
||||
def test(model, device, test_loader):
|
||||
model.eval()
|
||||
test_loss = 0
|
||||
correct = 0
|
||||
with torch.no_grad():
|
||||
for data, target in test_loader:
|
||||
data, target = data.to(device), target.to(device)
|
||||
output = model(data)
|
||||
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
|
||||
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
|
||||
correct += pred.eq(target.view_as(pred)).sum().item()
|
||||
|
||||
test_loss /= len(test_loader.dataset)
|
||||
|
||||
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
|
||||
test_loss, correct, len(test_loader.dataset),
|
||||
100. * correct / len(test_loader.dataset)))
|
||||
|
||||
|
||||
def main():
|
||||
# Training settings
|
||||
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
|
||||
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
|
||||
help='input batch size for training (default: 64)')
|
||||
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
|
||||
help='input batch size for testing (default: 1000)')
|
||||
parser.add_argument('--epochs', type=int, default=14, metavar='N',
|
||||
help='number of epochs to train (default: 14)')
|
||||
parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
|
||||
help='learning rate (default: 1.0)')
|
||||
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
|
||||
help='Learning rate step gamma (default: 0.7)')
|
||||
parser.add_argument('--no-cuda', action='store_true', default=False,
|
||||
help='disables CUDA training')
|
||||
parser.add_argument('--no-mps', action='store_true', default=False,
|
||||
help='disables macOS GPU training')
|
||||
parser.add_argument('--dry-run', action='store_true', default=False,
|
||||
help='quickly check a single pass')
|
||||
parser.add_argument('--seed', type=int, default=1, metavar='S',
|
||||
help='random seed (default: 1)')
|
||||
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
|
||||
help='how many batches to wait before logging training status')
|
||||
parser.add_argument('--save-model', action='store_true', default=False,
|
||||
help='For Saving the current Model')
|
||||
parser.add_argument('--export-onnx', action='store_true', default=True,
|
||||
help='For Saving the current Model in ONNX format')
|
||||
args = parser.parse_args()
|
||||
use_cuda = not args.no_cuda and torch.cuda.is_available()
|
||||
use_mps = not args.no_mps and torch.backends.mps.is_available()
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
if use_cuda:
|
||||
device = torch.device("cuda")
|
||||
elif use_mps:
|
||||
device = torch.device("mps")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
||||
train_kwargs = {'batch_size': args.batch_size}
|
||||
test_kwargs = {'batch_size': args.test_batch_size}
|
||||
if use_cuda:
|
||||
cuda_kwargs = {'num_workers': 1,
|
||||
'pin_memory': True,
|
||||
'shuffle': True}
|
||||
train_kwargs.update(cuda_kwargs)
|
||||
test_kwargs.update(cuda_kwargs)
|
||||
|
||||
transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307,), (0.3081,))
|
||||
])
|
||||
dataset1 = datasets.MNIST('/tmp/mnist-data', train=True, download=True,
|
||||
transform=transform)
|
||||
dataset2 = datasets.MNIST('/tmp/mnist-data', train=False,
|
||||
transform=transform)
|
||||
train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
|
||||
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
|
||||
|
||||
model = Net().to(device)
|
||||
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
|
||||
|
||||
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
|
||||
for epoch in range(1, args.epochs + 1):
|
||||
train(args, model, device, train_loader, optimizer, epoch)
|
||||
test(model, device, test_loader)
|
||||
scheduler.step()
|
||||
|
||||
if args.save_model:
|
||||
torch.save(model, "mnist.pt")
|
||||
|
||||
if args.export_onnx:
|
||||
dummy_input = torch.randn(1, 1, 28, 28, device=device)
|
||||
torch.onnx.export(model, dummy_input, "mnist.onnx")
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1 @@
|
|||
pub mod model;
|
|
@ -0,0 +1,17 @@
|
|||
use burn::tensor;
|
||||
use burn_ndarray::NdArrayBackend;
|
||||
use onnx_inference::model::mnist::{Model, INPUT1_SHAPE};
|
||||
|
||||
fn main() {
|
||||
// Create a new model
|
||||
let model: Model<NdArrayBackend<f32>> = Model::new();
|
||||
|
||||
// Create a new input tensor (all zeros for demonstration purposes)
|
||||
let input = tensor::Tensor::<NdArrayBackend<f32>, 4>::zeros(INPUT1_SHAPE);
|
||||
|
||||
// Run the model
|
||||
let output = model.forward(input);
|
||||
|
||||
// Print the output
|
||||
println!("{:?}", output);
|
||||
}
|
Binary file not shown.
|
@ -0,0 +1,3 @@
|
|||
pub mod mnist {
|
||||
include!(concat!(env!("OUT_DIR"), "/model/mnist.rs"));
|
||||
}
|
Loading…
Reference in New Issue