Separating ONNX parsing from burn-import (#1921)

* separating onnx parsing from burn-import

* ran clippy and cargo-fmt

* removed unused deps from onnx-ir

* fixed clippy warnings that were causing run-checks to fail

* removed dead code

* removed unused dependencies from burn-import

* updated contributor-book, updated publish.yml, added readme

* update cargo lock

* formatted md document with prettier, rephrased sentence

* missed the errors with reduce_prod_conversion during merge

* formatted onnx-to-burn-conversion-tool.md, forgot to save
Joshua Ferguson 2024-07-02 15:17:44 -05:00 committed by GitHub
parent 755c0708c4
commit 25348cf181
22 changed files with 405 additions and 279 deletions

.github/workflows/publish.yml

@@ -161,3 +161,9 @@ jobs:
     with:
       crate: burn-import
     secrets: inherit
+  publish-onnx-ir:
+    uses: tracel-ai/burn/.github/workflows/publish-template.yml@main
+    with:
+      crate: onnx-ir
+    secrets: inherit

Cargo.lock (generated)

@@ -628,23 +628,19 @@ name = "burn-import"
 version = "0.14.0"
 dependencies = [
  "burn",
- "bytemuck",
  "candle-core",
  "derive-new",
  "half",
  "log",
+ "onnx-ir",
  "pretty_assertions",
  "proc-macro2",
- "protobuf",
- "protobuf-codegen",
  "quote",
  "regex",
  "rstest",
  "rust-format",
  "serde",
  "serde_json",
- "strum",
- "strum_macros",
  "syn 2.0.68",
  "thiserror",
  "tracing-core",
@@ -3689,6 +3685,23 @@ dependencies = [
  "serde",
 ]
 
+[[package]]
+name = "onnx-ir"
+version = "0.14.0"
+dependencies = [
+ "bytemuck",
+ "half",
+ "log",
+ "pretty_assertions",
+ "protobuf",
+ "protobuf-codegen",
+ "regex",
+ "rstest",
+ "serde",
+ "strum",
+ "strum_macros",
+]
+
 [[package]]
 name = "onnx-tests"
 version = "0.14.0"

contributor-book/src/guides/adding-a-new-operation-to-burn.md

@@ -117,7 +117,8 @@ plug in your operator in terms of \\(x\\) and \\(y\\), and just swap out the var
 ### Testing autodiff
 
-For testing the `autodiff` operations, please refer to [this section](../getting-started/testing.md).
+For testing the `autodiff` operations, please refer to
+[this section](../getting-started/testing.md).
 
 ## Adding the Op to other backends
@@ -199,11 +200,15 @@ Generating the ONNX test files or tests is already covered
 about the specific changes you need to make when adding new operators after you have generated the
 tests.
 
-The crate is divided into two sections `src/burn` and `src/onnx`. The code under the former
-corresponds to the operation you've implemented earlier in this guide, and the latter to the
-operations defined in the ONNX specification. So when you are loading a model, the operator is first
-parsed to an intermediate representation defined by `src/onnx`, and then mapped to a Burn operation
-defined under `src/burn/node`.
+Changes will need to be made to both `onnx-ir` and `burn-import`. The code within `onnx-ir` defines
+how to parse the nodes in an ONNX file and produces the intermediate representation. The code within
+`burn-import` is divided into two sections: `src/onnx` and `src/burn`. The code under the former
+maps that intermediate representation to the one used for code generation, and the latter defines
+how to generate code for the operator you've implemented earlier in this guide.
+So when you are loading a model, the operator is first parsed to an intermediate representation
+defined in `onnx-ir` and then mapped to a Burn operation defined under `src/burn/node`; the
+mapping from ONNX to Burn is aptly defined in `src/onnx/to_burn`.
 
 Let's review the changes made for powf starting from `src/burn` and moving to `src/onnx`:
@@ -218,17 +223,20 @@ Let's review the changes made for powf starting from `src/burn` and moving to `s
    [`{op}_conversion` function](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-import/src/onnx/to_burn.rs#L717)
    that maps the ONNX node to the binary type
 3. Specify how dimensions for the output should be derived in
-   [crates/burn-import/src/onnx/dim_inference.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-import/src/onnx/dim_inference.rs#L55)
+   [crates/onnx-ir/src/dim_inference.rs](https://github.com/tracel-ai/burn/blob/d4ae82b21ac3dd1def01bd380ab7ea4d3293eccb/crates/onnx-ir/src/dim_inference.rs#L17)
 
 And you're done! Congrats, you just fully added a new operation to burn, and we are all one step
 closer to the answer to [Are we learning yet?](https://www.arewelearningyet.com/) being "Yes, and
 it's freaking fast!". Buy yourself a coffee.
 
-[^supertrait]: for more on supertraits see
+[^supertrait]:
+    for more on supertraits see
     [the advanced trait section of the rust book](https://doc.rust-lang.org/book/ch19-03-advanced-traits.html#using-supertraits-to-require-one-traits-functionality-within-another-trait)
 
-[^autodiff]: wiki link for
+[^autodiff]:
+    wiki link for
     [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)
 
-[^absolute_units]: for more information on unit structs see
+[^absolute_units]:
+    for more information on unit structs see
    [the defining and instantiating structs section of the rust book](https://doc.rust-lang.org/book/ch05-01-defining-structs.html#unit-like-structs-without-any-fields)
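For a binary operator like powf, the conversion on the burn-import side follows a small, uniform pattern. As a condensed sketch (assuming the post-split imports introduced in this commit; the full `pow_conversion` in the `to_burn.rs` diff below is the authoritative version):

```rust
// Sketch of an `{op}_conversion` function: onnx-ir arguments are converted
// into burn-import's codegen types, then wrapped in the matching node.
fn pow_conversion(node: Node) -> BinaryNode {
    let lhs = Type::from(node.inputs.first().unwrap());
    let rhs = Type::from(node.inputs.get(1).unwrap());
    let output = Type::from(node.outputs.first().unwrap());

    match &rhs {
        // An integer exponent lowers to powi, a float exponent to powf.
        Type::Tensor(x) => match x.kind {
            TensorKind::Int => BinaryNode::powi(lhs, rhs, output),
            TensorKind::Float => BinaryNode::powf(lhs, rhs, output),
            _ => panic!("pow function requires an int or float exponent"),
        },
        _ => panic!("this sketch only handles tensor exponents"),
    }
}
```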

contributor-book/src/guides/onnx-to-burn-conversion-tool.md

@@ -16,7 +16,17 @@ For an introduction to ONNX import in Burn, see
 - [Design Goals](#design-goals)
 - [Design Decisions](#design-decisions)
 - [Adding New Operators](#adding-new-operators)
 - [Implementing a New Operator](#implementing-a-new-operator)
+  - [Step 1: Visibility](#step-1-visibility)
+  - [Step 2: Node Implementation](#step-2-node-implementation)
+    - [Within Onnx-IR](#within-onnx-ir)
+    - [Within burn-import](#within-burn-import)
+  - [Step 3: Registering New Operations](#step-3-registering-new-operations)
+  - [Step 4: Create a Config Function](#step-4-create-a-config-function)
+  - [Step 5: Dimension Inference](#step-5-dimension-inference)
+  - [Step 6: Integrate into the Graph Building Process](#step-6-integrate-into-the-graph-building-process)
+  - [Step 7: Add Newly Supported Op!](#step-7-add-newly-supported-op)
+  - [Misc:](#misc)
 - [Testing](#testing)
 - [Resources](#resources)
@@ -91,6 +101,22 @@ located in the `src/burn/node/` directory.
 ### Step 2: Node Implementation
 
+#### Within Onnx-IR
+
+If the node type does not exist within the
+[`NodeType` enum](https://github.com/tracel-ai/burn/blob/d4ae82b21ac3dd1def01bd380ab7ea4d3293eccb/crates/onnx-ir/src/ir.rs#L246),
+it will need to be added (support for custom operators is planned). If the node might be provided an
+input which is a constant or the output of an identity node, it will need to be added to the list of
+node types
+[checked for constants](https://github.com/tracel-ai/burn/blob/d4ae82b21ac3dd1def01bd380ab7ea4d3293eccb/crates/onnx-ir/src/from_onnx.rs#L21).
+The node will need to be added to `dim_inference`, and in most cases the work on the parsing side
+will then be done. If a node requires extra parsing (such as handling an edge case like potentially
+remapping an unsqueeze to a reshape), the best place for that is after the constant check and prior
+to `dim_inference` in
+[`OnnxGraphBuilder::build`](https://github.com/tracel-ai/burn/blob/d4ae82b21ac3dd1def01bd380ab7ea4d3293eccb/crates/onnx-ir/src/from_onnx.rs#L221).
+
+#### Within burn-import
+
 Create a new file named `<operation_name>.rs` in the `src/burn/node/` directory.
 This file will define the structure and functionality of your new operation. By convention, the
 necessary information for carrying out an operation is encapsulated within a struct named
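To make the onnx-ir half of Step 2 concrete, here is a minimal, hypothetical sketch of a dimension-inference arm (names and signatures simplified; the real dispatch lives in `crates/onnx-ir/src/dim_inference.rs` and covers every supported `NodeType`):

```rust
use onnx_ir::ir::{Node, NodeType};

// Many operators simply propagate the input type to the output, so the
// dispatch calls a small updater like this hypothetical helper.
fn same_as_input(node: &mut Node) {
    node.outputs[0].ty = node.inputs[0].ty.clone();
}

// Sketch of the dispatch: a newly supported operator gets its own arm here.
fn infer_dims(node: &mut Node) {
    match node.node_type {
        NodeType::Relu => same_as_input(node),
        // Add an arm for your new NodeType here.
        _ => todo!("dimension inference for {:?}", node.node_type),
    }
}
```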

crates/burn-import/Cargo.toml

@@ -20,30 +20,23 @@ pytorch = ["burn/record-item-custom-serde", "thiserror", "zip"]
 [dependencies]
 burn = { path = "../burn", version = "0.14.0", features = ["ndarray"] }
+onnx-ir = { path = "../onnx-ir" }
-bytemuck = { workspace = true }
 candle-core = { workspace = true }
 derive-new = { workspace = true }
 half = { workspace = true }
 log = { workspace = true }
 proc-macro2 = { workspace = true }
-protobuf = { workspace = true, features = ["with-bytes"] }
 quote = { workspace = true }
 regex = { workspace = true }
 rust-format = { workspace = true, features = ["token_stream", "post_process"] }
 serde = { workspace = true, features = ["derive"] }
 serde_json = { workspace = true, features = ["std"] }
-strum = { workspace = true }
-strum_macros = { workspace = true }
 syn = { workspace = true, features = ["parsing"] }
 thiserror = { workspace = true, optional = true }
 tracing-core = { workspace = true }
 tracing-subscriber = { workspace = true }
 zip = { workspace = true, optional = true }
 
-[build-dependencies]
-protobuf-codegen = { workspace = true }
-
 [dev-dependencies]
 pretty_assertions = { workspace = true }
 rstest = { workspace = true }

crates/burn-import/build.rs (deleted)

@@ -1,11 +0,0 @@
fn main() {
    if cfg!(feature = "onnx") {
        // Generate the onnx protobuf files
        protobuf_codegen::Codegen::new()
            .pure()
            .includes(["src"])
            .input("src/onnx/protos/onnx.proto")
            .cargo_out_dir("onnx-protos")
            .run_from_script();
    }
}

crates/burn-import/src/onnx/mod.rs

@@ -1,14 +1,3 @@
-mod coalesce;
-mod dim_inference;
-mod from_onnx;
-mod ir;
-mod node_remap;
 mod op_configuration;
-mod proto_conversion;
-mod protos;
 mod to_burn;
 
 pub use to_burn::*;
-pub use from_onnx::parse_onnx;
-pub use ir::OnnxGraph;

crates/burn-import/src/onnx/op_configuration.rs

@@ -5,8 +5,8 @@ use burn::nn::{
     PaddingConfig2d,
 };
 
-use super::ir::{ArgType, AttributeValue, Data, Node};
 use crate::burn::node::resize::ResizeMode;
+use onnx_ir::ir::{ArgType, AttributeValue, Data, Node};
 
 /// Create a Conv1dConfig from the attributes of the node
 pub fn conv1d_config(curr: &Node) -> Conv1dConfig {

crates/burn-import/src/onnx/to_burn.rs

@@ -53,20 +53,24 @@ use crate::{
     },
     format_tokens,
     logger::init_log,
-    onnx::{
-        from_onnx::convert_constant_value,
-        ir::{Node, NodeType},
-        op_configuration::*,
-    },
 };
 
-use super::{
-    from_onnx::parse_onnx,
-    ir::{self, ArgType, Argument, Data, ElementType, OnnxGraph},
-    op_configuration::{
-        avg_pool2d_config, clip_config, concat_config, dropout_config, reshape_config,
-        resize_config, softmax_config,
-    },
+use super::op_configuration::{
+    argmax_config, avg_pool1d_config, avg_pool2d_config, batch_norm_config, clip_config,
+    concat_config, conv1d_config, conv2d_config, conv_transpose2d_config, dropout_config,
+    expand_config, flatten_config, gather_config, layer_norm_config, leaky_relu_config,
+    linear_config, log_softmax_config, max_pool1d_config, max_pool2d_config, reduce_max_config,
+    reduce_mean_config, reduce_min_config, reduce_prod_config, reduce_sum_config, reshape_config,
+    resize_config, shape_config, slice_config, softmax_config, squeeze_config, transpose_config,
+    unsqueeze_config,
+};
+use onnx_ir::{
+    convert_constant_value,
+    ir::{
+        ArgType, Argument as OnnxArgument, Data, ElementType, Node, NodeType, OnnxGraph,
+        TensorType as OnnxTensorType,
+    },
+    parse_onnx,
 };
 
 pub use crate::burn::graph::RecordType;
@@ -197,6 +201,7 @@ impl ModelGen {
         log::debug!("Output file: {:?}", out_file);
 
         let graph = parse_onnx(input.as_ref());
+        let graph = ParsedOnnxGraph(graph);
 
         if self.development {
             // export the graph
@@ -231,15 +236,16 @@
         log::info!("Model generated");
     }
 }
 
-impl OnnxGraph {
+#[derive(Debug)]
+struct ParsedOnnxGraph(OnnxGraph);
+
+impl ParsedOnnxGraph {
     /// Converts ONNX graph to Burn graph.
     pub fn into_burn<PS: PrecisionSettings + 'static>(self) -> BurnGraph<PS> {
         let mut graph = BurnGraph::<PS>::default();
 
         let mut unsupported_ops = vec![];
 
-        for node in self.nodes {
+        for node in self.0.nodes {
             match node.node_type {
                 NodeType::Add => graph.register(Self::add_conversion(node)),
                 NodeType::ArgMax => graph.register(Self::argmax_conversion(node)),
@@ -328,11 +334,13 @@
         // Get input and output names
         let input_names = self
+            .0
             .inputs
             .iter()
             .map(|input| input.name.clone())
             .collect::<Vec<_>>();
         let output_names = self
+            .0
             .outputs
             .iter()
             .map(|output| output.name.clone())
@@ -390,13 +398,13 @@
             ArgType::Shape(_) => panic!("Shape is not supported as constant value."),
         };
 
-        ConstantNode::new(node.name.clone(), const_value, output.to_type())
+        ConstantNode::new(node.name.clone(), const_value, Type::from(output))
     }
 
     fn random_uniform_conversion(node: Node) -> RandomUniformNode {
         let output = node.outputs.first().unwrap();
         // cannot use output.to_tensor_type() here, since it drops the shape info...
-        let output_type = if let Type::Tensor(t) = output.to_type() {
+        let output_type = if let Type::Tensor(t) = Type::from(output) {
             t
         } else {
             panic!("RandomUniform output type is no Tensor.");
@@ -423,7 +431,7 @@
     fn random_normal_conversion(node: Node) -> RandomNormalNode {
         let output = node.outputs.first().unwrap();
         // cannot use output.to_tensor_type() here, since it drops the shape info...
-        let output_type = if let Type::Tensor(t) = output.to_type() {
+        let output_type = if let Type::Tensor(t) = Type::from(output) {
             t
         } else {
             panic!("RandomNormal output type is no Tensor.");
@@ -448,141 +456,141 @@
     }
 
     fn add_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.first().unwrap().to_type();
-        let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let lhs = Type::from(node.inputs.first().unwrap());
+        let rhs = Type::from(node.inputs.get(1).unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
 
         BinaryNode::add(lhs, rhs, output)
     }
 
     fn sub_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.first().unwrap().to_type();
-        let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let lhs = Type::from(node.inputs.first().unwrap());
+        let rhs = Type::from(node.inputs.get(1).unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
 
         BinaryNode::sub(lhs, rhs, output)
     }
 
     fn mul_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.first().unwrap().to_type();
-        let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let lhs = Type::from(node.inputs.first().unwrap());
+        let rhs = Type::from(node.inputs.get(1).unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
 
         BinaryNode::mul(lhs, rhs, output)
     }
 
     fn div_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.first().unwrap().to_type();
-        let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let lhs = Type::from(node.inputs.first().unwrap());
+        let rhs = Type::from(node.inputs.get(1).unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
 
         BinaryNode::div(lhs, rhs, output)
     }
 
     fn matmul_conversion(node: Node) -> MatmulNode {
-        let lhs = node.inputs.first().unwrap().to_tensor_type();
-        let rhs = node.inputs.get(1).unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let lhs = TensorType::from(node.inputs.first().unwrap());
+        let rhs = TensorType::from(node.inputs.get(1).unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
 
         MatmulNode::new(lhs, rhs, output)
     }
 
     fn equal_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.first().unwrap().to_type();
-        let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let lhs = Type::from(node.inputs.first().unwrap());
+        let rhs = Type::from(node.inputs.get(1).unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
 
         BinaryNode::equal(lhs, rhs, output)
     }
 
     fn max_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.first().unwrap().to_type();
-        let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let lhs = Type::from(node.inputs.first().unwrap());
+        let rhs = Type::from(node.inputs.get(1).unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
 
         BinaryNode::max_pair(lhs, rhs, output)
     }
 
     fn erf_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
 
         UnaryNode::erf(input, output)
     }
 
     fn leaky_relu_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         let alpha = leaky_relu_config(&node);
 
         UnaryNode::leaky_relu(input, output, alpha)
     }
 
     fn relu_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
 
         UnaryNode::relu(input, output)
     }
 
     fn gelu_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
 
         UnaryNode::gelu(input, output)
     }
 
     fn log_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
 
         UnaryNode::log(input, output)
     }
 
     fn flatten_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         let (start_dim, end_dim) = flatten_config(&node);
 
         UnaryNode::flatten(input, output, start_dim, end_dim)
     }
 
     fn gather_conversion(node: Node) -> GatherNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let index = node.inputs.get(1).unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let index = TensorType::from(node.inputs.get(1).unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let dim = gather_config(&node);
 
         GatherNode::new(input, index, output, dim)
     }
 
     fn gather_elements_conversion(node: Node) -> GatherElementsNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let index = node.inputs.get(1).unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let index = TensorType::from(node.inputs.get(1).unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let dim = gather_config(&node);
 
         GatherElementsNode::new(input, index, output, dim)
     }
 
     fn transpose_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         let perm = transpose_config(&node);
 
         UnaryNode::transpose(input, output, perm)
     }
 
     fn cast_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
 
         UnaryNode::cast(input, output)
     }
 
     fn reshape_conversion(node: Node) -> ReshapeNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let shape = reshape_config(&node);
 
         ReshapeNode::new(input, output, shape)
@@ -591,10 +599,10 @@
 
     fn resize_conversion(node: Node) -> ResizeNode {
         let name = &node.name;
 
-        let input = node.inputs[0].to_tensor_type();
-        let output_size = node.inputs[3].to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(&node.inputs[0]);
+        let output_size = TensorType::from(&node.inputs[3]);
+        let output = TensorType::from(node.outputs.first().unwrap());
 
         let mode = resize_config(&node);
@@ -602,15 +610,15 @@
     }
 
     fn min_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.first().unwrap().to_type();
-        let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let lhs = Type::from(node.inputs.first().unwrap());
+        let rhs = Type::from(node.inputs.get(1).unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
 
         BinaryNode::min_pair(lhs, rhs, output)
     }
 
     fn range_conversion(node: Node) -> RangeNode {
-        fn convert_arg_to_scalar(arg: &Argument) -> ScalarType {
+        fn convert_arg_to_scalar(arg: &OnnxArgument) -> ScalarType {
             match &arg.ty {
                 ArgType::Scalar(scalar) => {
                     ScalarType::new(arg.name.clone(), ScalarKind::from(scalar))
@@ -624,7 +632,7 @@ impl OnnxGraph {
                 _ => panic!("Range node requires scalar inputs"),
             }
         }
 
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let output = TensorType::from(node.outputs.first().unwrap());
         let start = convert_arg_to_scalar(node.inputs.first().unwrap());
         let end = convert_arg_to_scalar(node.inputs.get(1).unwrap());
         let step = convert_arg_to_scalar(node.inputs.get(2).unwrap());
@@ -633,164 +641,156 @@ impl OnnxGraph {
     }
 
     fn reduce_max_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         let dim = reduce_max_config(&node);
 
         UnaryNode::reduce_max(input, output, dim)
     }
 
     fn reduce_min_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
        let dim = reduce_min_config(&node);
 
         UnaryNode::reduce_min(input, output, dim)
     }
 
     fn reduce_mean_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         let dim = reduce_mean_config(&node);
 
         UnaryNode::reduce_mean(input, output, dim)
     }
 
     fn reduce_prod_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         let dim = reduce_prod_config(&node);
 
         UnaryNode::reduce_prod(input, output, dim)
     }
 
     fn reduce_sum_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         let dim = reduce_sum_config(&node);
 
         UnaryNode::reduce_sum(input, output, dim)
     }
 
     fn shape_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         let (start_dim, end_dim) = shape_config(&node);
 
         UnaryNode::shape(input, output, start_dim, end_dim)
     }
 
     fn unsqueeze_conversion(node: Node) -> UnsqueezeNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let dims = unsqueeze_config(&node);
 
         UnsqueezeNode::new(input, output, dims)
     }
 
     fn where_conversion(node: Node) -> WhereNode {
-        let condition = node.inputs.first().unwrap().to_tensor_type();
-        let x = node.inputs.get(1).unwrap().to_tensor_type();
-        let y = node.inputs.get(2).unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let condition = TensorType::from(node.inputs.first().unwrap());
+        let x = TensorType::from(node.inputs.get(1).unwrap());
+        let y = TensorType::from(node.inputs.get(2).unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
 
         WhereNode::new(condition, x, y, output)
     }
 
     fn clip_conversion(node: Node) -> ClipNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let (min, max) = clip_config(&node);
 
         ClipNode::new(input, output, min, max)
     }
 
     fn sigmoid_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
 
         UnaryNode::sigmoid(input, output)
     }
 
     fn sin_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
 
         UnaryNode::sin(input, output)
     }
 
     fn slice_conversion(node: Node) -> SliceNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let (starts, ends) = slice_config(&node);
 
         SliceNode::new(input, output, starts, ends)
     }
 
     fn sum_conversion(node: Node) -> SumNode {
-        let inputs = node
-            .inputs
-            .iter()
-            .map(|input| input.to_tensor_type())
-            .collect();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let inputs = node.inputs.iter().map(TensorType::from).collect();
+        let output = TensorType::from(node.outputs.first().unwrap());
 
         SumNode::new(inputs, output)
     }
 
     fn reciprocal_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
 
         UnaryNode::reciprocal(input, output)
     }
 
     fn log_softmax_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         let dim = log_softmax_config(&node);
 
         UnaryNode::log_softmax(input, output, dim)
     }
 
     fn softmax_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         let dim = softmax_config(&node);
 
         UnaryNode::softmax(input, output, dim)
     }
 
     fn sqrt_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
 
         UnaryNode::sqrt(input, output)
     }
 
     fn tanh_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
 
         UnaryNode::tanh(input, output)
     }
 
     fn argmax_conversion(node: Node) -> ArgMaxNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let axis = argmax_config(&node);
 
         ArgMaxNode::new(input, output, axis)
     }
 
     fn concat_conversion(node: Node) -> ConcatNode {
-        let inputs = node
-            .inputs
-            .iter()
-            .map(|input| input.to_tensor_type())
-            .collect();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let inputs = node.inputs.iter().map(TensorType::from).collect();
+        let output = TensorType::from(node.outputs.first().unwrap());
         let dim = concat_config(&node);
 
         ConcatNode::new(inputs, output, dim)
@@ -798,8 +798,8 @@ impl OnnxGraph {
     fn linear_conversion<PS: PrecisionSettings>(node: Node) -> LinearNode {
         let name = &node.name;
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let config = linear_config(&node);
 
         let weight = extract_data_serialize::<PS::FloatElem>(1, &node).expect("Weight is required");
@@ -811,8 +811,8 @@ impl OnnxGraph {
     fn dropout_conversion(node: Node) -> DropoutNode {
         let name = &node.name;
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let config = dropout_config(&node);
 
         DropoutNode::new(name, input, output, config)
@@ -820,8 +820,8 @@ impl OnnxGraph {
     fn batch_norm_conversion<PS: PrecisionSettings>(node: Node) -> BatchNormNode {
         let config = batch_norm_config(&node);
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let dim = input.dim - 2;
 
         let gamma = extract_data_serialize::<PS::FloatElem>(1, &node).expect("Gamma is required");
@@ -848,8 +848,8 @@ impl OnnxGraph {
     fn layer_norm_conversion<PS: PrecisionSettings>(node: Node) -> LayerNormNode {
         let (config, full_precision) = layer_norm_config(&node);
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
 
         // Scale tensor (aka gamma)
         let gamma = extract_data_serialize::<PS::FloatElem>(1, &node).expect("Gamma is required");
@@ -862,8 +862,8 @@ impl OnnxGraph {
     }
 
     fn conv1d_conversion<PS: PrecisionSettings>(node: Node) -> Conv1dNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let config = conv1d_config(&node);
 
         let bias = node.inputs.len() == 3;
@@ -878,8 +878,8 @@ impl OnnxGraph {
     }
 
     fn conv2d_conversion<PS: PrecisionSettings>(node: Node) -> Conv2dNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let config = conv2d_config(&node);
 
         let bias = node.inputs.len() == 3;
@@ -894,8 +894,8 @@ impl OnnxGraph {
     }
 
     fn max_pool1d_conversion(node: Node) -> MaxPool1dNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let config = max_pool1d_config(&node);
 
         let name = &node.name;
@@ -903,8 +903,8 @@ impl OnnxGraph {
     }
 
     fn max_pool2d_conversion(node: Node) -> MaxPool2dNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let config = max_pool2d_config(&node);
 
         let name = &node.name;
@@ -912,16 +912,16 @@ impl OnnxGraph {
     }
 
     fn prelu_conversion<PS: PrecisionSettings>(node: Node) -> PReluNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let weight = extract_data_serialize::<PS::FloatElem>(1, &node).unwrap();
         let config = PReluConfig::new();
         let name = &node.name;
         PReluNode::new(name, input, output, weight, config)
     }
 
     fn conv_transpose2d_conversion<PS: PrecisionSettings>(node: Node) -> ConvTranspose2dNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let config = conv_transpose2d_config(&node);
 
         let bias = node.inputs.len() == 3;
@@ -935,8 +935,8 @@ impl OnnxGraph {
         ConvTranspose2dNode::new(name, input, output, weight, bias, config)
     }
 
     fn avg_pool_1d_conversion(node: Node) -> AvgPool1dNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let config = avg_pool1d_config(&node);
 
         let name = &node.name;
@@ -944,8 +944,8 @@ impl OnnxGraph {
     }
 
     fn avg_pool_2d_conversion(node: Node) -> AvgPool2dNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let config = avg_pool2d_config(&node);
 
         let name = &node.name;
@@ -953,8 +953,8 @@ impl OnnxGraph {
     }
 
     fn global_avg_pool_conversion(node: Node) -> GlobalAvgPoolNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
 
         let name = &node.name;
@@ -962,71 +962,71 @@ impl OnnxGraph {
     }
 
     fn cos_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
 
         UnaryNode::cos(input, output)
     }
 
     fn exp_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
 
         UnaryNode::exp(input, output)
     }
 
     fn expand_conversion(node: Node) -> ExpandNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let shape = expand_config(&node);
 
         ExpandNode::new(input, output, shape)
     }
 
     fn neg_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
 
         UnaryNode::neg(input, output)
     }
 
     fn not_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
 
         UnaryNode::not(input, output)
     }
 
     fn greater_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.first().unwrap().to_type();
-        let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let lhs = Type::from(node.inputs.first().unwrap());
+        let rhs = Type::from(node.inputs.get(1).unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
 
         BinaryNode::greater(lhs, rhs, output)
     }
 
     fn less_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.first().unwrap().to_type();
-        let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let lhs = Type::from(node.inputs.first().unwrap());
+        let rhs = Type::from(node.inputs.get(1).unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
 
         BinaryNode::lower(lhs, rhs, output)
     }
 
     fn greater_or_equal_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.first().unwrap().to_type();
-        let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let lhs = Type::from(node.inputs.first().unwrap());
+        let rhs = Type::from(node.inputs.get(1).unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
 
         BinaryNode::greater_equal(lhs, rhs, output)
     }
 
     fn less_or_equal_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.first().unwrap().to_type();
-        let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let lhs = Type::from(node.inputs.first().unwrap());
+        let rhs = Type::from(node.inputs.get(1).unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
 
         BinaryNode::lower_equal(lhs, rhs, output)
     }
 
     fn pow_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.first().unwrap().to_type();
-        let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let lhs = Type::from(node.inputs.first().unwrap());
+        let rhs = Type::from(node.inputs.get(1).unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         match &rhs {
             Type::Tensor(x) => match x.kind {
                 TensorKind::Int => BinaryNode::powi(lhs, rhs, output),
@@ -1043,14 +1043,14 @@ impl OnnxGraph {
     }
 
     fn sign_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         UnaryNode::sign(input, output)
     }
 
     fn squeeze_conversion(node: Node) -> SqueezeNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let axes = squeeze_config(&node);
 
         SqueezeNode::new(input, output, axes)
@@ -1101,48 +1101,49 @@ fn serialize_data<E: Element>(data: Data, shape: Vec<usize>) -> TensorData {
     }
 }
 
-impl Argument {
-    pub fn to_tensor_type(&self) -> TensorType {
-        match &self.ty {
-            ArgType::Tensor(ir::TensorType {
+impl From<&OnnxArgument> for TensorType {
+    fn from(arg: &OnnxArgument) -> Self {
+        match &arg.ty {
+            ArgType::Tensor(OnnxTensorType {
                 elem_type: ElementType::Float16 | ElementType::Float32 | ElementType::Float64,
                 dim,
                 ..
-            }) => TensorType::new_float(self.name.clone(), *dim),
-            ArgType::Tensor(ir::TensorType {
+            }) => TensorType::new_float(arg.name.clone(), *dim),
+            ArgType::Tensor(OnnxTensorType {
                 elem_type: ElementType::Int32 | ElementType::Int64,
                 dim,
                 ..
-            }) => TensorType::new_int(self.name.clone(), *dim),
-            ArgType::Tensor(ir::TensorType {
+            }) => TensorType::new_int(arg.name.clone(), *dim),
+            ArgType::Tensor(OnnxTensorType {
                 elem_type: ElementType::Bool,
                 dim,
                 ..
-            }) => TensorType::new_bool(self.name.clone(), *dim),
-            _ => panic!("Can't transform to tensor."),
+            }) => TensorType::new_bool(arg.name.clone(), *dim),
+            _ => panic!("Can't transform scalar to tensor."),
         }
     }
+}
 
-    pub fn to_type(&self) -> Type {
-        match &self.ty {
+impl From<&OnnxArgument> for Type {
+    fn from(arg: &OnnxArgument) -> Self {
+        match &arg.ty {
             ArgType::Tensor(tensor) => {
                 // Treat tensor with dim 0 as scalar
                 if tensor.dim == 0 {
                     Type::Scalar(ScalarType::new(
-                        self.name.clone(),
+                        arg.name.clone(),
                         ScalarKind::from(&tensor.elem_type),
                     ))
                 } else {
                     let kind: TensorKind = tensor.elem_type.clone().into();
                     let dim = tensor.dim;
-                    let name = self.name.clone();
+                    let name = arg.name.clone();
                     let shape = tensor.shape.clone();
                     Type::Tensor(TensorType::new(name, dim, kind, shape))
                 }
             }
             ArgType::Scalar(elem_type) => {
-                Type::Scalar(ScalarType::new(self.name.clone(), elem_type.into()))
+                Type::Scalar(ScalarType::new(arg.name.clone(), elem_type.into()))
             }
             ArgType::Shape(_shape) => panic!("Can't transform shape to tensor."),
         }

crates/onnx-ir/Cargo.toml (new file)

@@ -0,0 +1,31 @@
[package]
authors = [
    "Dilshod Tadjibaev (@antimora)",
    "Nathaniel Simard (@nathanielsimard)",
]
description = "Library for parsing ONNX models"
edition.workspace = true
license.workspace = true
name = "onnx-ir"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/onnx-ir"
version.workspace = true

[dependencies]
bytemuck = { workspace = true }
half = { workspace = true }
log = { workspace = true }
protobuf = { workspace = true, features = ["with-bytes"] }
regex = { workspace = true }
serde = { workspace = true, features = ["derive"] }
strum = { workspace = true }
strum_macros = { workspace = true }

[build-dependencies]
protobuf-codegen = { workspace = true }

[dev-dependencies]
pretty_assertions = { workspace = true }
rstest = { workspace = true }

crates/onnx-ir/README.md (new file)

@@ -0,0 +1,7 @@
# ONNX-IR

A pure Rust ONNX parser. It creates an intermediate representation useful for generating code in any ML/DL framework.

For a full list of currently supported operators, please check [here](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md).

To see how to use this for generating Burn graphs, see [here](crates/burn-import/src/onnx/to_burn.rs).
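As a usage sketch (assuming only the public API re-exported from `crates/onnx-ir/src/lib.rs` below: `parse_onnx` and `OnnxGraph`):

```rust
use std::path::Path;

use onnx_ir::{parse_onnx, OnnxGraph};

fn main() {
    // Parse an ONNX file into onnx-ir's intermediate representation.
    let graph: OnnxGraph = parse_onnx(Path::new("model.onnx"));

    // Each node carries its operator type plus typed inputs and outputs,
    // ready to be mapped into a target framework's codegen types.
    for node in &graph.nodes {
        println!("{} -> {:?}", node.name, node.node_type);
    }
}
```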

crates/onnx-ir/build.rs (new file)

@@ -0,0 +1,9 @@
fn main() {
    // Generate the onnx protobuf files
    protobuf_codegen::Codegen::new()
        .pure()
        .includes(["src"])
        .input("src/protos/onnx.proto")
        .cargo_out_dir("onnx-protos")
        .run_from_script();
}

crates/onnx-ir/src/coalesce.rs

@@ -6,7 +6,7 @@ use super::{
     proto_conversion::convert_node_proto,
     protos::NodeProto,
 };
-use crate::onnx::ir::{ArgType, Data, TensorType};
+use crate::ir::{ArgType, Data, TensorType};
 
 /// The function transforms the graph into a new one where the nodes are coalesced into a single node.
 pub fn coalesce(

crates/onnx-ir/src/dim_inference.rs

@@ -3,10 +3,10 @@ use core::panic;
 
 use protobuf::Enum;
 
-use super::{
+use crate::{
     ir::{ArgType, AttributeValue, Data, ElementType, Node, NodeType, TensorType},
-    op_configuration::flatten_config,
     protos::tensor_proto::DataType,
+    util::flatten_config,
 };
 
 /// Infer the dimension of each output tensor and update them.

crates/onnx-ir/src/from_onnx.rs

@@ -4,7 +4,7 @@ use std::{
     path::Path,
 };
 
-use crate::onnx::node_remap::remap_node_type;
+use crate::node_remap::remap_node_type;
 
 use super::{
     coalesce::coalesce,
@@ -56,9 +56,9 @@ pub struct GraphData {
 impl GraphData {
     pub(crate) fn new(
-        inputs: &Vec<ValueInfoProto>,
-        outputs: &Vec<ValueInfoProto>,
-        initializers: &Vec<TensorProto>,
+        inputs: &[ValueInfoProto],
+        outputs: &[ValueInfoProto],
+        initializers: &[TensorProto],
     ) -> Self {
         let mut input_name_map = HashMap::new();
         let mut input_key_map = HashMap::new();
@@ -375,35 +375,32 @@ pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph {
 /// properly deleted if nothing else uses it
 /// Remap the unsqueeze node to a reshape node
 pub(crate) fn remap_unsqueeze_to_reshape(node: &mut Node, out_arg: &Argument) {
-    match &out_arg.ty {
-        ArgType::Tensor(output_tensor) => {
-            let inner = output_tensor
-                .shape
-                .clone()
-                .unwrap()
-                .into_iter()
-                .map(|x| x as i64)
-                .collect::<Vec<i64>>();
-            let shape_len = inner.len();
-            let new_rhs_value = Some(Data::Int64s(inner));
-            //moving the remap to here
-            let rhs_arg = Argument {
-                name: format!("{}_generated_const", &node.name),
-                ty: ArgType::Tensor(TensorType {
-                    elem_type: super::ir::ElementType::Int64,
-                    dim: 1,
-                    shape: Some(vec![shape_len]),
-                }),
-                value: new_rhs_value,
-                passed: false,
-            };
-            // ? should this replace the old input (reuse the old key) or should it be a new key
-            // going with new key for now
-            node.inputs[1] = rhs_arg;
-            node.outputs[0] = out_arg.clone();
-            node.node_type = NodeType::Reshape;
-        }
-        _ => {}
+    if let ArgType::Tensor(output_tensor) = &out_arg.ty {
+        let inner = output_tensor
+            .shape
+            .clone()
+            .unwrap()
+            .into_iter()
+            .map(|x| x as i64)
+            .collect::<Vec<i64>>();
+        let shape_len = inner.len();
+        let new_rhs_value = Some(Data::Int64s(inner));
+        //moving the remap to here
+        let rhs_arg = Argument {
+            name: format!("{}_generated_const", &node.name),
+            ty: ArgType::Tensor(TensorType {
+                elem_type: super::ir::ElementType::Int64,
+                dim: 1,
+                shape: Some(vec![shape_len]),
+            }),
+            value: new_rhs_value,
+            passed: false,
+        };
+        // ? should this replace the old input (reuse the old key) or should it be a new key
+        // going with new key for now
+        node.inputs[1] = rhs_arg;
+        node.outputs[0] = out_arg.clone();
+        node.node_type = NodeType::Reshape;
     }
 }
 
 // Define a trait for topological sorting
@@ -444,7 +441,7 @@ impl TopologicalSortable for Vec<NodeProto> {
 }
 
 /// Get the value of a constant node from its attributes
-pub(crate) fn convert_constant_value(node: &Node) -> Argument {
+pub fn convert_constant_value(node: &Node) -> Argument {
     // A value can be stored in any of these attributes
     let keys = [
         "value",

crates/onnx-ir/src/ir.rs

@@ -3,7 +3,7 @@ use half::f16;
 use std::{collections::HashMap, fmt::Formatter};
 use strum_macros::{Display, EnumString};
 
-use super::protos::TensorProto;
+use crate::protos::TensorProto;
 
 pub type Dim = usize;
 pub type Shape = Vec<Dim>;
@@ -241,7 +241,7 @@ impl PartialEq for Argument {
 }
 
 /// The list of supported node types (ONNX operators and some extra ones to map easily to Burn's ops)
-/// Refer: https://github.com/onnx/onnx/blob/main/docs/Operators.md
+/// Refer: <https://github.com/onnx/onnx/blob/main/docs/Operators.md>
 #[derive(Debug, Hash, Eq, PartialEq, EnumString, Clone, Display)]
 pub enum NodeType {
     Abs,
@@ -444,7 +444,7 @@ pub enum NodeType {
 }
 
 /// Truncate the vector display for debug display
-fn trunc<T: fmt::Display>(v: &Vec<T>) -> String {
+fn trunc<T: fmt::Display>(v: &[T]) -> String {
     const BEGIN_INDEX: usize = 0;
     const MAX_LEN: usize = 5;
     let mut s = String::new();

crates/onnx-ir/src/lib.rs (new file)

@@ -0,0 +1,12 @@
mod coalesce;
mod dim_inference;
mod from_onnx;
pub mod ir;
mod node_remap;
mod proto_conversion;
mod protos;
mod util;

pub use from_onnx::convert_constant_value;
pub use from_onnx::parse_onnx;
pub use ir::OnnxGraph;

crates/onnx-ir/src/proto_conversion.rs

@@ -1,6 +1,6 @@
 use std::str::{from_utf8, FromStr};
 
-use crate::onnx::ir::TensorType;
+use crate::ir::TensorType;
 
 use super::from_onnx::GraphData;
 use super::ir::Dim;

crates/onnx-ir/src/util.rs (new file)

@@ -0,0 +1,45 @@
use crate::ir::{ArgType, Node};

/// Create a FlattenConfig from the attributes of the node
pub fn flatten_config(curr: &Node) -> (usize, usize) {
    // the begin dimension is the first dimension (Default: 1 per ONNX spec)
    let mut start_dim: i64 = 1;

    // check if the node has only one input
    if curr.inputs.len() != 1 {
        panic!(
            "Flatten: multiple inputs are not supported (got {:?})",
            curr.inputs.len()
        );
    }

    // extract the shape of the input tensor
    let tensor = match curr.inputs.first().unwrap().clone().ty {
        ArgType::Tensor(tensor) => tensor,
        _ => panic!("Only tensor input is valid"),
    };

    // check if the input tensor has at least 2 dimensions
    if tensor.dim < 2 {
        panic!(
            "Flatten: input tensor must have at least 2 dimensions (got {:?})",
            tensor.dim
        );
    }

    // the end dimension is the last dimension
    let end_dim = tensor.dim - 1;

    // extract the attributes
    for (key, value) in curr.attrs.iter() {
        if key.as_str() == "axis" {
            start_dim = value.clone().into_i64();
        }
    }

    // if beg_dim is negative, it is counted from the end
    if start_dim < 0 {
        start_dim += tensor.dim as i64;
    }

    (start_dim as usize, end_dim)
}