mirror of https://github.com/tracel-ai/burn.git
Separating ONNX parsing from burn-import (#1921)
* separating onnx parsing from burn-import
* ran clippy and cargo-fmt
* removed unused deps from onnx-ir
* fixed clippy warnings that were causing run-checks to fail
* removed dead code
* removed unused dependencies from burn-import
* updated contributor-book, updated publish.yml, added readme
* update cargo lock
* formatted md document with prettier, rephrased sentence
* missed the errors with reduce_prod_conversion during merge
* formatted onnx-to-burn-conversion-tool.md, forgot to save
parent 755c0708c4
commit 25348cf181
@@ -161,3 +161,9 @@ jobs:
     with:
       crate: burn-import
     secrets: inherit
+
+  publish-onnx-ir:
+    uses: tracel-ai/burn/.github/workflows/publish-template.yml@main
+    with:
+      crate: onnx-ir
+    secrets: inherit
@@ -628,23 +628,19 @@ name = "burn-import"
 version = "0.14.0"
 dependencies = [
  "burn",
- "bytemuck",
  "candle-core",
  "derive-new",
  "half",
  "log",
+ "onnx-ir",
  "pretty_assertions",
  "proc-macro2",
- "protobuf",
- "protobuf-codegen",
  "quote",
  "regex",
  "rstest",
  "rust-format",
  "serde",
  "serde_json",
- "strum",
- "strum_macros",
  "syn 2.0.68",
  "thiserror",
  "tracing-core",
@@ -3689,6 +3685,23 @@ dependencies = [
  "serde",
 ]

+[[package]]
+name = "onnx-ir"
+version = "0.14.0"
+dependencies = [
+ "bytemuck",
+ "half",
+ "log",
+ "pretty_assertions",
+ "protobuf",
+ "protobuf-codegen",
+ "regex",
+ "rstest",
+ "serde",
+ "strum",
+ "strum_macros",
+]
+
 [[package]]
 name = "onnx-tests"
 version = "0.14.0"
@@ -117,7 +117,8 @@ plug in your operator in terms of \\(x\\) and \\(y\\), and just swap out the var

 ### Testing autodiff

-For testing the `autodiff` operations, please refer to [this section](../getting-started/testing.md).
+For testing the `autodiff` operations, please refer to
+[this section](../getting-started/testing.md).

 ## Adding the Op to other backends

@@ -199,11 +200,15 @@ Generating the ONNX test files or tests is already covered
 about the specific changes you need to make when adding new operators after you have generated the
 tests.

-The crate is divided into two sections `src/burn` and `src/onnx`. The code under the former
-corresponds to the operation you've implemented earlier in this guide, and the latter to the
-operations defined in the ONNX specification. So when you are loading a model, the operator is first
-parsed to an intermediate representation defined by `src/onnx`, and then mapped to a Burn operation
-defined under `src/burn/node`.
+Changes will need to be made to both `onnx-ir` and `burn-import`. The code within `onnx-ir` defines
+how to parse the nodes in an onnx file and produces the intermediate representation. The code within
+`burn-import` is divided into two sections: `src/onnx` and `src/burn`. The code under the former
+maps that intermediate representation to one used for code generation and the latter defines how to
+generate code for the operator you've implemented earlier in this guide.

+So when you are loading a model, the operator is first parsed to an intermediate representation
+defined by `burn-import` and then mapped to a Burn operation defined under `src/burn/node`; the
+mapping from onnx to burn is aptly defined in `src/onnx/to_burn`
+
 Let's review the changes made for powf starting from `src/burn` and moving to `src/onnx`:

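For orientation, a conversion function in burn-import's `src/onnx/to_burn` has roughly the following shape after this split. This is a minimal sketch modeled on the `relu_conversion` that appears in the hunks further down; the `use` paths for the burn-import types are abbreviated and assumed here.

```rust
// Sketch only: mirrors the post-change relu_conversion shown later in this diff.
// `Node` comes from the onnx-ir crate; `Type` and `UnaryNode` are burn-import
// code-generation types (their exact module paths are assumed, imports omitted).
use onnx_ir::ir::Node;

fn relu_conversion(node: Node) -> UnaryNode {
    // onnx-ir arguments are converted into burn-import codegen types via From impls.
    let input = Type::from(node.inputs.first().unwrap());
    let output = Type::from(node.outputs.first().unwrap());

    UnaryNode::relu(input, output)
}
```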
@@ -218,17 +223,20 @@ Let's review the changes made for powf starting from `src/burn` and moving to `s
   [`{op}_conversion` function](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-import/src/onnx/to_burn.rs#L717)
   that maps the ONNX node to the binary type
 3. Specify how dimensions for the output should be derived in
-   [crates/burn-import/src/onnx/dim_inference.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-import/src/onnx/dim_inference.rs#L55)
+   [crates/onnx-ir/src/dim_inference.rs](https://github.com/tracel-ai/burn/blob/d4ae82b21ac3dd1def01bd380ab7ea4d3293eccb/crates/onnx-ir/src/dim_inference.rs#L17)

 And you're done! Congrats, you just fully added a new operation to burn, and we are all one step
 closer to the answer to [Are we learning yet?](https://www.arewelearningyet.com/) being "Yes, and
 it's freaking fast!". Buy yourself a coffee.

-[^supertrait]: for more on supertraits see
+[^supertrait]:
+    for more on supertraits see
     [the advanced trait section of the rust book](https://doc.rust-lang.org/book/ch19-03-advanced-traits.html#using-supertraits-to-require-one-traits-functionality-within-another-trait)

-[^autodiff]: wiki link for
+[^autodiff]:
+    wiki link for
     [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)

-[^absolute_units]: for more information on unit structs see
+[^absolute_units]:
+    for more information on unit structs see
     [the defining and instantiating structs section of the rust book](https://doc.rust-lang.org/book/ch05-01-defining-structs.html#unit-like-structs-without-any-fields)
@@ -16,7 +16,17 @@ For an introduction to ONNX import in Burn, see
 - [Design Goals](#design-goals)
 - [Design Decisions](#design-decisions)
 - [Adding New Operators](#adding-new-operators)
   - [Implementing a New Operator](#implementing-a-new-operator)
+    - [Step 1: Visibility](#step-1-visibility)
+    - [Step 2: Node Implementation](#step-2-node-implementation)
+      - [Within Onnx-IR](#within-onnx-ir)
+      - [Within burn-import](#within-burn-import)
+    - [Step 3: Registering New Operations](#step-3-registering-new-operations)
+    - [Step 4: Create a Config Function](#step-4-create-a-config-function)
+    - [Step 5: Dimension Inference](#step-5-dimension-inference)
+    - [Step 6: Integrate into the Graph Building Process](#step-6-integrate-into-the-graph-building-process)
+    - [Step 7: Add Newly Supported Op!](#step-7-add-newly-supported-op)
+  - [Misc:](#misc)
 - [Testing](#testing)
 - [Resources](#resources)

@@ -91,6 +101,22 @@ located in the `src/burn/node/` directory.

 ### Step 2: Node Implementation

+#### Within Onnx-IR
+
+If the node type does not exist within the
+[`NodeType` enum](https://github.com/tracel-ai/burn/blob/d4ae82b21ac3dd1def01bd380ab7ea4d3293eccb/crates/onnx-ir/src/ir.rs#L246),
+it will need to be added (support for custom operators is planned). If the node might be provided an
+input which is a constant or the output of an identity node, it will need to be added to the list of
+nodeTypes
+[checked for constants](https://github.com/tracel-ai/burn/blob/d4ae82b21ac3dd1def01bd380ab7ea4d3293eccb/crates/onnx-ir/src/from_onnx.rs#L21).
+The node will need to be added to `dim_inference`, and in most cases the work parsing side will be
+done. If a node requires extra parsing (such as handling an edge case like potentially remapping an
+unsqueeze to a reshape) the best place for that is after check constants and prior to dim_inference
+in
+[`OnnxGraphBuilder::Build`](https://github.com/tracel-ai/burn/blob/d4ae82b21ac3dd1def01bd380ab7ea4d3293eccb/crates/onnx-ir/src/from_onnx.rs#L221)
+
+#### Within burn-import
+
 Create a new file named `<operation_name>.rs` in the `src/burn/node/` directory.
 This file will define the structure and functionality of your new operation. By convention, the
 necessary information for carrying out an operation is encapsulated within a struct named
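The struct convention described above can be pictured as follows; this is a hypothetical sketch inferred from the `GatherNode::new(input, index, output, dim)` call that appears later in this diff, with the field names, types, and `derive-new` usage assumed rather than taken from the actual node file.

```rust
// Hypothetical sketch of a node struct in src/burn/node/<operation_name>.rs.
// Fields are inferred from GatherNode::new(input, index, output, dim) as used
// later in this diff; TensorType stands for burn-import's codegen tensor type.
use derive_new::new;

#[derive(Debug, Clone, new)]
pub struct GatherNode {
    pub input: TensorType,  // tensor to gather from
    pub index: TensorType,  // indices tensor
    pub output: TensorType, // gathered result
    pub dim: usize,         // dimension along which to gather
}
```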
@@ -20,30 +20,23 @@ pytorch = ["burn/record-item-custom-serde", "thiserror", "zip"]

 [dependencies]
 burn = { path = "../burn", version = "0.14.0", features = ["ndarray"] }
+onnx-ir = { path = "../onnx-ir" }
-bytemuck = { workspace = true }
 candle-core = { workspace = true }
 derive-new = { workspace = true }
 half = { workspace = true }
 log = { workspace = true }
 proc-macro2 = { workspace = true }
-protobuf = { workspace = true, features = ["with-bytes"] }
 quote = { workspace = true }
 regex = { workspace = true }
 rust-format = { workspace = true, features = ["token_stream", "post_process"] }
 serde = { workspace = true, features = ["derive"] }
 serde_json = { workspace = true, features = ["std"] }
-strum = { workspace = true }
-strum_macros = { workspace = true }
 syn = { workspace = true, features = ["parsing"] }
 thiserror = { workspace = true, optional = true }
 tracing-core = { workspace = true }
 tracing-subscriber = { workspace = true }
 zip = { workspace = true, optional = true }

-[build-dependencies]
-protobuf-codegen = { workspace = true }
-
 [dev-dependencies]
 pretty_assertions = { workspace = true }
 rstest = { workspace = true }
@@ -1,11 +0,0 @@
-fn main() {
-    if cfg!(feature = "onnx") {
-        // Generate the onnx protobuf files
-        protobuf_codegen::Codegen::new()
-            .pure()
-            .includes(["src"])
-            .input("src/onnx/protos/onnx.proto")
-            .cargo_out_dir("onnx-protos")
-            .run_from_script();
-    }
-}
@@ -1,14 +1,3 @@
-mod coalesce;
-mod dim_inference;
-mod from_onnx;
-mod ir;
-mod node_remap;
 mod op_configuration;
-mod proto_conversion;
-mod protos;
 mod to_burn;

 pub use to_burn::*;
-
-pub use from_onnx::parse_onnx;
-pub use ir::OnnxGraph;
@@ -5,8 +5,8 @@ use burn::nn::{
     PaddingConfig2d,
 };

-use super::ir::{ArgType, AttributeValue, Data, Node};
 use crate::burn::node::resize::ResizeMode;
+use onnx_ir::ir::{ArgType, AttributeValue, Data, Node};

 /// Create a Conv1dConfig from the attributes of the node
 pub fn conv1d_config(curr: &Node) -> Conv1dConfig {
@@ -53,20 +53,24 @@ use crate::{
     },
     format_tokens,
     logger::init_log,
-    onnx::{
-        from_onnx::convert_constant_value,
-        ir::{Node, NodeType},
-        op_configuration::*,
-    },
 };

-use super::{
-    from_onnx::parse_onnx,
-    ir::{self, ArgType, Argument, Data, ElementType, OnnxGraph},
-    op_configuration::{
-        avg_pool2d_config, clip_config, concat_config, dropout_config, reshape_config,
-        resize_config, softmax_config,
+use super::op_configuration::{
+    argmax_config, avg_pool1d_config, avg_pool2d_config, batch_norm_config, clip_config,
+    concat_config, conv1d_config, conv2d_config, conv_transpose2d_config, dropout_config,
+    expand_config, flatten_config, gather_config, layer_norm_config, leaky_relu_config,
+    linear_config, log_softmax_config, max_pool1d_config, max_pool2d_config, reduce_max_config,
+    reduce_mean_config, reduce_min_config, reduce_prod_config, reduce_sum_config, reshape_config,
+    resize_config, shape_config, slice_config, softmax_config, squeeze_config, transpose_config,
+    unsqueeze_config,
+};
+use onnx_ir::{
+    convert_constant_value,
+    ir::{
+        ArgType, Argument as OnnxArgument, Data, ElementType, Node, NodeType, OnnxGraph,
+        TensorType as OnnxTensorType,
     },
+    parse_onnx,
 };

 pub use crate::burn::graph::RecordType;
@@ -197,6 +201,7 @@ impl ModelGen {
         log::debug!("Output file: {:?}", out_file);

         let graph = parse_onnx(input.as_ref());
+        let graph = ParsedOnnxGraph(graph);

         if self.development {
             // export the graph
@@ -231,15 +236,16 @@ impl ModelGen {
         log::info!("Model generated");
     }
 }

-impl OnnxGraph {
+#[derive(Debug)]
+struct ParsedOnnxGraph(OnnxGraph);
+
+impl ParsedOnnxGraph {
     /// Converts ONNX graph to Burn graph.
     pub fn into_burn<PS: PrecisionSettings + 'static>(self) -> BurnGraph<PS> {
         let mut graph = BurnGraph::<PS>::default();

         let mut unsupported_ops = vec![];

-        for node in self.nodes {
+        for node in self.0.nodes {
             match node.node_type {
                 NodeType::Add => graph.register(Self::add_conversion(node)),
                 NodeType::ArgMax => graph.register(Self::argmax_conversion(node)),
@@ -328,11 +334,13 @@ impl OnnxGraph {

         // Get input and output names
         let input_names = self
+            .0
             .inputs
             .iter()
             .map(|input| input.name.clone())
             .collect::<Vec<_>>();
         let output_names = self
+            .0
             .outputs
             .iter()
             .map(|output| output.name.clone())
@@ -390,13 +398,13 @@ impl OnnxGraph {
             ArgType::Shape(_) => panic!("Shape is not supported as constant value."),
         };

-        ConstantNode::new(node.name.clone(), const_value, output.to_type())
+        ConstantNode::new(node.name.clone(), const_value, Type::from(output))
     }

     fn random_uniform_conversion(node: Node) -> RandomUniformNode {
         let output = node.outputs.first().unwrap();
         // cannot use output.to_tensor_type() here, since it drops the shape info...
-        let output_type = if let Type::Tensor(t) = output.to_type() {
+        let output_type = if let Type::Tensor(t) = Type::from(output) {
             t
         } else {
             panic!("RandomUniform output type is no Tensor.");
@@ -423,7 +431,7 @@ impl OnnxGraph {
     fn random_normal_conversion(node: Node) -> RandomNormalNode {
         let output = node.outputs.first().unwrap();
         // cannot use output.to_tensor_type() here, since it drops the shape info...
-        let output_type = if let Type::Tensor(t) = output.to_type() {
+        let output_type = if let Type::Tensor(t) = Type::from(output) {
             t
         } else {
             panic!("RandomNormal output type is no Tensor.");
@@ -448,141 +456,141 @@ impl OnnxGraph {
     }

     fn add_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.first().unwrap().to_type();
-        let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let lhs = Type::from(node.inputs.first().unwrap());
+        let rhs = Type::from(node.inputs.get(1).unwrap());
+        let output = Type::from(node.outputs.first().unwrap());

         BinaryNode::add(lhs, rhs, output)
     }

     fn sub_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.first().unwrap().to_type();
-        let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let lhs = Type::from(node.inputs.first().unwrap());
+        let rhs = Type::from(node.inputs.get(1).unwrap());
+        let output = Type::from(node.outputs.first().unwrap());

         BinaryNode::sub(lhs, rhs, output)
     }

     fn mul_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.first().unwrap().to_type();
-        let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let lhs = Type::from(node.inputs.first().unwrap());
+        let rhs = Type::from(node.inputs.get(1).unwrap());
+        let output = Type::from(node.outputs.first().unwrap());

         BinaryNode::mul(lhs, rhs, output)
     }

     fn div_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.first().unwrap().to_type();
-        let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let lhs = Type::from(node.inputs.first().unwrap());
+        let rhs = Type::from(node.inputs.get(1).unwrap());
+        let output = Type::from(node.outputs.first().unwrap());

         BinaryNode::div(lhs, rhs, output)
     }

     fn matmul_conversion(node: Node) -> MatmulNode {
-        let lhs = node.inputs.first().unwrap().to_tensor_type();
-        let rhs = node.inputs.get(1).unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let lhs = TensorType::from(node.inputs.first().unwrap());
+        let rhs = TensorType::from(node.inputs.get(1).unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());

         MatmulNode::new(lhs, rhs, output)
     }

     fn equal_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.first().unwrap().to_type();
-        let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let lhs = Type::from(node.inputs.first().unwrap());
+        let rhs = Type::from(node.inputs.get(1).unwrap());
+        let output = Type::from(node.outputs.first().unwrap());

         BinaryNode::equal(lhs, rhs, output)
     }

     fn max_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.first().unwrap().to_type();
-        let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let lhs = Type::from(node.inputs.first().unwrap());
+        let rhs = Type::from(node.inputs.get(1).unwrap());
+        let output = Type::from(node.outputs.first().unwrap());

         BinaryNode::max_pair(lhs, rhs, output)
     }

     fn erf_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());

         UnaryNode::erf(input, output)
     }

     fn leaky_relu_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         let alpha = leaky_relu_config(&node);

         UnaryNode::leaky_relu(input, output, alpha)
     }

     fn relu_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());

         UnaryNode::relu(input, output)
     }

     fn gelu_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());

         UnaryNode::gelu(input, output)
     }

     fn log_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());

         UnaryNode::log(input, output)
     }

     fn flatten_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         let (start_dim, end_dim) = flatten_config(&node);

         UnaryNode::flatten(input, output, start_dim, end_dim)
     }

     fn gather_conversion(node: Node) -> GatherNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let index = node.inputs.get(1).unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let index = TensorType::from(node.inputs.get(1).unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let dim = gather_config(&node);

         GatherNode::new(input, index, output, dim)
     }

     fn gather_elements_conversion(node: Node) -> GatherElementsNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let index = node.inputs.get(1).unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let index = TensorType::from(node.inputs.get(1).unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let dim = gather_config(&node);

         GatherElementsNode::new(input, index, output, dim)
     }

     fn transpose_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         let perm = transpose_config(&node);

         UnaryNode::transpose(input, output, perm)
     }

     fn cast_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());

         UnaryNode::cast(input, output)
     }

     fn reshape_conversion(node: Node) -> ReshapeNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let shape = reshape_config(&node);

         ReshapeNode::new(input, output, shape)
@@ -591,10 +599,10 @@ impl OnnxGraph {
     fn resize_conversion(node: Node) -> ResizeNode {
         let name = &node.name;

-        let input = node.inputs[0].to_tensor_type();
-        let output_size = node.inputs[3].to_tensor_type();
+        let input = TensorType::from(&node.inputs[0]);
+        let output_size = TensorType::from(&node.inputs[3]);

-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let output = TensorType::from(node.outputs.first().unwrap());

         let mode = resize_config(&node);

@@ -602,15 +610,15 @@ impl OnnxGraph {
     }

     fn min_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.first().unwrap().to_type();
-        let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let lhs = Type::from(node.inputs.first().unwrap());
+        let rhs = Type::from(node.inputs.get(1).unwrap());
+        let output = Type::from(node.outputs.first().unwrap());

         BinaryNode::min_pair(lhs, rhs, output)
     }

     fn range_conversion(node: Node) -> RangeNode {
-        fn convert_arg_to_scalar(arg: &Argument) -> ScalarType {
+        fn convert_arg_to_scalar(arg: &OnnxArgument) -> ScalarType {
             match &arg.ty {
                 ArgType::Scalar(scalar) => {
                     ScalarType::new(arg.name.clone(), ScalarKind::from(scalar))
@@ -624,7 +632,7 @@ impl OnnxGraph {
                 _ => panic!("Range node requires scalar inputs"),
             }
         }
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let output = TensorType::from(node.outputs.first().unwrap());
         let start = convert_arg_to_scalar(node.inputs.first().unwrap());
         let end = convert_arg_to_scalar(node.inputs.get(1).unwrap());
         let step = convert_arg_to_scalar(node.inputs.get(2).unwrap());
@@ -633,164 +641,156 @@ impl OnnxGraph {
     }

     fn reduce_max_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         let dim = reduce_max_config(&node);

         UnaryNode::reduce_max(input, output, dim)
     }

     fn reduce_min_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         let dim = reduce_min_config(&node);

         UnaryNode::reduce_min(input, output, dim)
     }

     fn reduce_mean_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         let dim = reduce_mean_config(&node);

         UnaryNode::reduce_mean(input, output, dim)
     }

     fn reduce_prod_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         let dim = reduce_prod_config(&node);

         UnaryNode::reduce_prod(input, output, dim)
     }

     fn reduce_sum_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         let dim = reduce_sum_config(&node);

         UnaryNode::reduce_sum(input, output, dim)
     }

     fn shape_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         let (start_dim, end_dim) = shape_config(&node);

         UnaryNode::shape(input, output, start_dim, end_dim)
     }

     fn unsqueeze_conversion(node: Node) -> UnsqueezeNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let dims = unsqueeze_config(&node);

         UnsqueezeNode::new(input, output, dims)
     }

     fn where_conversion(node: Node) -> WhereNode {
-        let condition = node.inputs.first().unwrap().to_tensor_type();
-        let x = node.inputs.get(1).unwrap().to_tensor_type();
-        let y = node.inputs.get(2).unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let condition = TensorType::from(node.inputs.first().unwrap());
+        let x = TensorType::from(node.inputs.get(1).unwrap());
+        let y = TensorType::from(node.inputs.get(2).unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());

         WhereNode::new(condition, x, y, output)
     }

     fn clip_conversion(node: Node) -> ClipNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let (min, max) = clip_config(&node);

         ClipNode::new(input, output, min, max)
     }

     fn sigmoid_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());

         UnaryNode::sigmoid(input, output)
     }

     fn sin_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());

         UnaryNode::sin(input, output)
     }

     fn slice_conversion(node: Node) -> SliceNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let (starts, ends) = slice_config(&node);

         SliceNode::new(input, output, starts, ends)
     }

     fn sum_conversion(node: Node) -> SumNode {
-        let inputs = node
-            .inputs
-            .iter()
-            .map(|input| input.to_tensor_type())
-            .collect();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let inputs = node.inputs.iter().map(TensorType::from).collect();
+        let output = TensorType::from(node.outputs.first().unwrap());

         SumNode::new(inputs, output)
     }

     fn reciprocal_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());

         UnaryNode::reciprocal(input, output)
     }

     fn log_softmax_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         let dim = log_softmax_config(&node);

         UnaryNode::log_softmax(input, output, dim)
     }

     fn softmax_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         let dim = softmax_config(&node);

         UnaryNode::softmax(input, output, dim)
     }

     fn sqrt_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());

         UnaryNode::sqrt(input, output)
     }

     fn tanh_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());

         UnaryNode::tanh(input, output)
     }

     fn argmax_conversion(node: Node) -> ArgMaxNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let axis = argmax_config(&node);

         ArgMaxNode::new(input, output, axis)
     }

     fn concat_conversion(node: Node) -> ConcatNode {
-        let inputs = node
-            .inputs
-            .iter()
-            .map(|input| input.to_tensor_type())
-            .collect();
+        let inputs = node.inputs.iter().map(TensorType::from).collect();

-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let output = TensorType::from(node.outputs.first().unwrap());
         let dim = concat_config(&node);

         ConcatNode::new(inputs, output, dim)
@@ -798,8 +798,8 @@ impl OnnxGraph {

     fn linear_conversion<PS: PrecisionSettings>(node: Node) -> LinearNode {
         let name = &node.name;
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let config = linear_config(&node);

         let weight = extract_data_serialize::<PS::FloatElem>(1, &node).expect("Weight is required");
@@ -811,8 +811,8 @@ impl OnnxGraph {

     fn dropout_conversion(node: Node) -> DropoutNode {
         let name = &node.name;
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let config = dropout_config(&node);

         DropoutNode::new(name, input, output, config)
@@ -820,8 +820,8 @@ impl OnnxGraph {

     fn batch_norm_conversion<PS: PrecisionSettings>(node: Node) -> BatchNormNode {
         let config = batch_norm_config(&node);
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let dim = input.dim - 2;

         let gamma = extract_data_serialize::<PS::FloatElem>(1, &node).expect("Gamma is required");
@@ -848,8 +848,8 @@ impl OnnxGraph {

     fn layer_norm_conversion<PS: PrecisionSettings>(node: Node) -> LayerNormNode {
         let (config, full_precision) = layer_norm_config(&node);
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());

         // Scale tensor (aka gamma)
         let gamma = extract_data_serialize::<PS::FloatElem>(1, &node).expect("Gamma is required");
@@ -862,8 +862,8 @@ impl OnnxGraph {
     }

     fn conv1d_conversion<PS: PrecisionSettings>(node: Node) -> Conv1dNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let config = conv1d_config(&node);

         let bias = node.inputs.len() == 3;
@@ -878,8 +878,8 @@ impl OnnxGraph {
     }

     fn conv2d_conversion<PS: PrecisionSettings>(node: Node) -> Conv2dNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let config = conv2d_config(&node);

         let bias = node.inputs.len() == 3;
@@ -894,8 +894,8 @@ impl OnnxGraph {
     }

     fn max_pool1d_conversion(node: Node) -> MaxPool1dNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let config = max_pool1d_config(&node);

         let name = &node.name;
@@ -903,8 +903,8 @@ impl OnnxGraph {
     }

     fn max_pool2d_conversion(node: Node) -> MaxPool2dNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let config = max_pool2d_config(&node);

         let name = &node.name;
@@ -912,16 +912,16 @@ impl OnnxGraph {
     }

     fn prelu_conversion<PS: PrecisionSettings>(node: Node) -> PReluNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let weight = extract_data_serialize::<PS::FloatElem>(1, &node).unwrap();
         let config = PReluConfig::new();
         let name = &node.name;
         PReluNode::new(name, input, output, weight, config)
     }
     fn conv_transpose2d_conversion<PS: PrecisionSettings>(node: Node) -> ConvTranspose2dNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let config = conv_transpose2d_config(&node);

         let bias = node.inputs.len() == 3;
@@ -935,8 +935,8 @@ impl OnnxGraph {
         ConvTranspose2dNode::new(name, input, output, weight, bias, config)
     }
     fn avg_pool_1d_conversion(node: Node) -> AvgPool1dNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let config = avg_pool1d_config(&node);

         let name = &node.name;
||||||
|
@ -944,8 +944,8 @@ impl OnnxGraph {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn avg_pool_2d_conversion(node: Node) -> AvgPool2dNode {
|
fn avg_pool_2d_conversion(node: Node) -> AvgPool2dNode {
|
||||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
let input = TensorType::from(node.inputs.first().unwrap());
|
||||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
let output = TensorType::from(node.outputs.first().unwrap());
|
||||||
let config = avg_pool2d_config(&node);
|
let config = avg_pool2d_config(&node);
|
||||||
|
|
||||||
let name = &node.name;
|
let name = &node.name;
|
||||||
|
@@ -953,8 +953,8 @@ impl OnnxGraph {
     }

     fn global_avg_pool_conversion(node: Node) -> GlobalAvgPoolNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());

         let name = &node.name;

@@ -962,71 +962,71 @@ impl OnnxGraph {
     }

     fn cos_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());

         UnaryNode::cos(input, output)
     }

     fn exp_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());

         UnaryNode::exp(input, output)
     }

     fn expand_conversion(node: Node) -> ExpandNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let shape = expand_config(&node);

         ExpandNode::new(input, output, shape)
     }

     fn neg_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         UnaryNode::neg(input, output)
     }

     fn not_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         UnaryNode::not(input, output)
     }

     fn greater_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.first().unwrap().to_type();
-        let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let lhs = Type::from(node.inputs.first().unwrap());
+        let rhs = Type::from(node.inputs.get(1).unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         BinaryNode::greater(lhs, rhs, output)
     }

     fn less_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.first().unwrap().to_type();
-        let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let lhs = Type::from(node.inputs.first().unwrap());
+        let rhs = Type::from(node.inputs.get(1).unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         BinaryNode::lower(lhs, rhs, output)
     }

     fn greater_or_equal_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.first().unwrap().to_type();
-        let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let lhs = Type::from(node.inputs.first().unwrap());
+        let rhs = Type::from(node.inputs.get(1).unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         BinaryNode::greater_equal(lhs, rhs, output)
     }

     fn less_or_equal_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.first().unwrap().to_type();
-        let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let lhs = Type::from(node.inputs.first().unwrap());
+        let rhs = Type::from(node.inputs.get(1).unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         BinaryNode::lower_equal(lhs, rhs, output)
     }

     fn pow_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.first().unwrap().to_type();
-        let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let lhs = Type::from(node.inputs.first().unwrap());
+        let rhs = Type::from(node.inputs.get(1).unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         match &rhs {
             Type::Tensor(x) => match x.kind {
                 TensorKind::Int => BinaryNode::powi(lhs, rhs, output),
@@ -1043,14 +1043,14 @@ impl OnnxGraph {
     }

     fn sign_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.first().unwrap().to_type();
-        let output = node.outputs.first().unwrap().to_type();
+        let input = Type::from(node.inputs.first().unwrap());
+        let output = Type::from(node.outputs.first().unwrap());
         UnaryNode::sign(input, output)
     }

     fn squeeze_conversion(node: Node) -> SqueezeNode {
-        let input = node.inputs.first().unwrap().to_tensor_type();
-        let output = node.outputs.first().unwrap().to_tensor_type();
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
         let axes = squeeze_config(&node);

         SqueezeNode::new(input, output, axes)
@@ -1101,48 +1101,49 @@ fn serialize_data<E: Element>(data: Data, shape: Vec<usize>) -> TensorData {
     }
 }

-impl Argument {
-    pub fn to_tensor_type(&self) -> TensorType {
-        match &self.ty {
-            ArgType::Tensor(ir::TensorType {
+impl From<&OnnxArgument> for TensorType {
+    fn from(arg: &OnnxArgument) -> Self {
+        match &arg.ty {
+            ArgType::Tensor(OnnxTensorType {
                 elem_type: ElementType::Float16 | ElementType::Float32 | ElementType::Float64,
                 dim,
                 ..
-            }) => TensorType::new_float(self.name.clone(), *dim),
-            ArgType::Tensor(ir::TensorType {
+            }) => TensorType::new_float(arg.name.clone(), *dim),
+            ArgType::Tensor(OnnxTensorType {
                 elem_type: ElementType::Int32 | ElementType::Int64,
                 dim,
                 ..
-            }) => TensorType::new_int(self.name.clone(), *dim),
-            ArgType::Tensor(ir::TensorType {
+            }) => TensorType::new_int(arg.name.clone(), *dim),
+            ArgType::Tensor(OnnxTensorType {
                 elem_type: ElementType::Bool,
                 dim,
                 ..
-            }) => TensorType::new_bool(self.name.clone(), *dim),
-            _ => panic!("Can't transform to tensor."),
+            }) => TensorType::new_bool(arg.name.clone(), *dim),
+            _ => panic!("Can't transform scalar to tensor."),
         }
     }
+}

-    pub fn to_type(&self) -> Type {
-        match &self.ty {
+impl From<&OnnxArgument> for Type {
+    fn from(arg: &OnnxArgument) -> Self {
+        match &arg.ty {
             ArgType::Tensor(tensor) => {
                 // Treat tensor with dim 0 as scalar
                 if tensor.dim == 0 {
                     Type::Scalar(ScalarType::new(
-                        self.name.clone(),
+                        arg.name.clone(),
                         ScalarKind::from(&tensor.elem_type),
                     ))
                 } else {
                     let kind: TensorKind = tensor.elem_type.clone().into();
                     let dim = tensor.dim;
-                    let name = self.name.clone();
+                    let name = arg.name.clone();
                     let shape = tensor.shape.clone();
                     Type::Tensor(TensorType::new(name, dim, kind, shape))
                 }
             }

             ArgType::Scalar(elem_type) => {
-                Type::Scalar(ScalarType::new(self.name.clone(), elem_type.into()))
+                Type::Scalar(ScalarType::new(arg.name.clone(), elem_type.into()))
             }
             ArgType::Shape(_shape) => panic!("Can't transform shape to tensor."),
         }
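Replacing the `to_type`/`to_tensor_type` methods with `From` implementations is what lets the conversion functions above write `Type::from(...)` and `TensorType::from(...)`. As a reminder of the pattern (a generic sketch with illustrative names, not code from this diff), implementing `From` also gives call sites the matching `.into()` for free via the standard blanket impl:

```rust
struct Wrapped(i64);

impl From<&i64> for Wrapped {
    fn from(value: &i64) -> Self {
        Wrapped(*value)
    }
}

fn main() {
    let x = 5_i64;
    // Both forms are available once `From` is implemented.
    let _a = Wrapped::from(&x);
    let _b: Wrapped = (&x).into();
}
```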
@@ -0,0 +1,31 @@
+[package]
+authors = [
+    "Dilshod Tadjibaev (@antimora)",
+    "Nathaniel Simard (@nathanielsimard)",
+]
+description = "Library for parsing ONNX models"
+edition.workspace = true
+license.workspace = true
+name = "onnx-ir"
+readme.workspace = true
+repository = "https://github.com/tracel-ai/burn/tree/main/crates/onnx-ir"
+version.workspace = true
+
+
+[dependencies]
+bytemuck = { workspace = true }
+half = { workspace = true }
+log = { workspace = true }
+protobuf = { workspace = true, features = ["with-bytes"] }
+regex = { workspace = true }
+serde = { workspace = true, features = ["derive"] }
+strum = { workspace = true }
+strum_macros = { workspace = true }
+
+
+[build-dependencies]
+protobuf-codegen = { workspace = true }
+
+[dev-dependencies]
+pretty_assertions = { workspace = true }
+rstest = { workspace = true }
@@ -0,0 +1,7 @@
+# ONNX-IR
+
+A pure Rust ONNX parser. It creates an intermediate representation useful for generating code in any ML/DL framework.
+
+For a full list of currently supported operators, please check [here](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md).
+
+To see how this is used to generate Burn graphs, see [here](crates/burn-import/src/onnx/to_burn.rs).
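A minimal usage sketch of the new crate (assuming only the `parse_onnx` and `OnnxGraph` re-exports added in the crate root below; `model.onnx` is a placeholder path):

```rust
use std::path::Path;

use onnx_ir::{parse_onnx, OnnxGraph};

fn main() {
    // Parse an ONNX file into the framework-agnostic IR.
    let graph: OnnxGraph = parse_onnx(Path::new("model.onnx"));
    // A code generator such as burn-import can then walk `graph`.
    let _ = graph;
}
```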
@@ -0,0 +1,9 @@
+fn main() {
+    // Generate the onnx protobuf files
+    protobuf_codegen::Codegen::new()
+        .pure()
+        .includes(["src"])
+        .input("src/protos/onnx.proto")
+        .cargo_out_dir("onnx-protos")
+        .run_from_script();
+}
@@ -6,7 +6,7 @@ use super::{
     proto_conversion::convert_node_proto,
     protos::NodeProto,
 };
-use crate::onnx::ir::{ArgType, Data, TensorType};
+use crate::ir::{ArgType, Data, TensorType};

 /// The function transforms the graph into a new one where the nodes are coalesced into a single node.
 pub fn coalesce(
@@ -3,10 +3,10 @@ use core::panic;

 use protobuf::Enum;

-use super::{
+use crate::{
     ir::{ArgType, AttributeValue, Data, ElementType, Node, NodeType, TensorType},
-    op_configuration::flatten_config,
     protos::tensor_proto::DataType,
+    util::flatten_config,
 };

 /// Infer the dimension of each output tensor and update them.
@@ -4,7 +4,7 @@ use std::{
     path::Path,
 };

-use crate::onnx::node_remap::remap_node_type;
+use crate::node_remap::remap_node_type;

 use super::{
     coalesce::coalesce,
@@ -56,9 +56,9 @@ pub struct GraphData {

 impl GraphData {
     pub(crate) fn new(
-        inputs: &Vec<ValueInfoProto>,
-        outputs: &Vec<ValueInfoProto>,
-        initializers: &Vec<TensorProto>,
+        inputs: &[ValueInfoProto],
+        outputs: &[ValueInfoProto],
+        initializers: &[TensorProto],
     ) -> Self {
         let mut input_name_map = HashMap::new();
         let mut input_key_map = HashMap::new();
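The `&Vec<T>` to `&[T]` change here (and in `trunc` further down) follows the usual Clippy guidance: a slice parameter still accepts a borrowed `Vec` unchanged at the call sites, while also allowing arrays and sub-slices. A generic sketch with illustrative names:

```rust
fn count(items: &[u32]) -> usize {
    items.len()
}

fn main() {
    let v = vec![1, 2, 3];
    assert_eq!(count(&v), 3); // &Vec<u32> coerces to &[u32]
    assert_eq!(count(&v[..2]), 2); // sub-slices work too
    assert_eq!(count(&[7, 8]), 2); // so do arrays
}
```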
@@ -375,35 +375,32 @@ pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph {
 /// properly deleted if nothing else uses it
 /// Remap the unsqueeze node to a reshape node
 pub(crate) fn remap_unsqueeze_to_reshape(node: &mut Node, out_arg: &Argument) {
-    match &out_arg.ty {
-        ArgType::Tensor(output_tensor) => {
-            let inner = output_tensor
-                .shape
-                .clone()
-                .unwrap()
-                .into_iter()
-                .map(|x| x as i64)
-                .collect::<Vec<i64>>();
-            let shape_len = inner.len();
-            let new_rhs_value = Some(Data::Int64s(inner));
-            //moving the remap to here
-            let rhs_arg = Argument {
-                name: format!("{}_generated_const", &node.name),
-                ty: ArgType::Tensor(TensorType {
-                    elem_type: super::ir::ElementType::Int64,
-                    dim: 1,
-                    shape: Some(vec![shape_len]),
-                }),
-                value: new_rhs_value,
-                passed: false,
-            };
-            // ? should this replace the old input (reuse the old key) or should it be a new key
-            // going with new key for now
-            node.inputs[1] = rhs_arg;
-            node.outputs[0] = out_arg.clone();
-            node.node_type = NodeType::Reshape;
-        }
-        _ => {}
+    if let ArgType::Tensor(output_tensor) = &out_arg.ty {
+        let inner = output_tensor
+            .shape
+            .clone()
+            .unwrap()
+            .into_iter()
+            .map(|x| x as i64)
+            .collect::<Vec<i64>>();
+        let shape_len = inner.len();
+        let new_rhs_value = Some(Data::Int64s(inner));
+        //moving the remap to here
+        let rhs_arg = Argument {
+            name: format!("{}_generated_const", &node.name),
+            ty: ArgType::Tensor(TensorType {
+                elem_type: super::ir::ElementType::Int64,
+                dim: 1,
+                shape: Some(vec![shape_len]),
+            }),
+            value: new_rhs_value,
+            passed: false,
+        };
+        // ? should this replace the old input (reuse the old key) or should it be a new key
+        // going with new key for now
+        node.inputs[1] = rhs_arg;
+        node.outputs[0] = out_arg.clone();
+        node.node_type = NodeType::Reshape;
     }
 }
 // Define a trait for topological sorting
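The rewrite of `remap_unsqueeze_to_reshape` is the standard single-arm-match cleanup: a `match` whose only non-empty arm binds one pattern reads better as `if let` and drops a level of indentation, with no change in behavior. A generic sketch of the same transformation (illustrative types only, not from this diff):

```rust
enum Arg {
    Tensor(usize),
    Scalar,
}

fn rank_before(arg: &Arg) -> usize {
    // Old shape: match with an empty catch-all arm.
    match arg {
        Arg::Tensor(dim) => *dim,
        _ => 0,
    }
}

fn rank_after(arg: &Arg) -> usize {
    // New shape: if let with the same behavior.
    if let Arg::Tensor(dim) = arg {
        *dim
    } else {
        0
    }
}

fn main() {
    assert_eq!(rank_before(&Arg::Scalar), rank_after(&Arg::Scalar));
    assert_eq!(rank_before(&Arg::Tensor(3)), rank_after(&Arg::Tensor(3)));
}
```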
@@ -444,7 +441,7 @@ impl TopologicalSortable for Vec<NodeProto> {
     }
 }

 /// Get the value of a constant node from its attributes
-pub(crate) fn convert_constant_value(node: &Node) -> Argument {
+pub fn convert_constant_value(node: &Node) -> Argument {
     // A value can be stored in any of these attributes
     let keys = [
         "value",
@@ -3,7 +3,7 @@ use half::f16;
 use std::{collections::HashMap, fmt::Formatter};
 use strum_macros::{Display, EnumString};

-use super::protos::TensorProto;
+use crate::protos::TensorProto;

 pub type Dim = usize;
 pub type Shape = Vec<Dim>;
@@ -241,7 +241,7 @@ impl PartialEq for Argument {
 }

 /// The list of supported node types (ONNX operators and some extra ones to map easily to Burn's ops)
-/// Refer: https://github.com/onnx/onnx/blob/main/docs/Operators.md
+/// Refer: <https://github.com/onnx/onnx/blob/main/docs/Operators.md>
 #[derive(Debug, Hash, Eq, PartialEq, EnumString, Clone, Display)]
 pub enum NodeType {
     Abs,
@@ -444,7 +444,7 @@ pub enum NodeType {
 }

 /// Truncate the vector display for debug display
-fn trunc<T: fmt::Display>(v: &Vec<T>) -> String {
+fn trunc<T: fmt::Display>(v: &[T]) -> String {
     const BEGIN_INDEX: usize = 0;
     const MAX_LEN: usize = 5;
     let mut s = String::new();
@@ -0,0 +1,12 @@
+mod coalesce;
+mod dim_inference;
+mod from_onnx;
+pub mod ir;
+mod node_remap;
+mod proto_conversion;
+mod protos;
+mod util;
+
+pub use from_onnx::convert_constant_value;
+pub use from_onnx::parse_onnx;
+pub use ir::OnnxGraph;
@@ -1,6 +1,6 @@
 use std::str::{from_utf8, FromStr};

-use crate::onnx::ir::TensorType;
+use crate::ir::TensorType;

 use super::from_onnx::GraphData;
 use super::ir::Dim;
@@ -0,0 +1,45 @@
+use crate::ir::{ArgType, Node};
+/// Create a FlattenConfig from the attributes of the node
+pub fn flatten_config(curr: &Node) -> (usize, usize) {
+    // the begin dimension is the first dimension (Default: 1 per ONNX spec)
+    let mut start_dim: i64 = 1;
+
+    // check if the node has only one input
+    if curr.inputs.len() != 1 {
+        panic!(
+            "Flatten: multiple inputs are not supported (got {:?})",
+            curr.inputs.len()
+        );
+    }
+
+    // extract the shape of the input tensor
+    let tensor = match curr.inputs.first().unwrap().clone().ty {
+        ArgType::Tensor(tensor) => tensor,
+        _ => panic!("Only tensor input is valid"),
+    };
+
+    // check if the input tensor has at least 2 dimensions
+    if tensor.dim < 2 {
+        panic!(
+            "Flatten: input tensor must have at least 2 dimensions (got {:?})",
+            tensor.dim
+        );
+    }
+
+    // the end dimension is the last dimension
+    let end_dim = tensor.dim - 1;
+
+    // extract the attributes
+    for (key, value) in curr.attrs.iter() {
+        if key.as_str() == "axis" {
+            start_dim = value.clone().into_i64();
+        }
+    }
+
+    // if beg_dim is negative, it is counted from the end
+    if start_dim < 0 {
+        start_dim += tensor.dim as i64;
+    }
+
+    (start_dim as usize, end_dim)
+}
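As a quick check of the axis handling above: the ONNX `axis` attribute may be negative, in which case it is counted from the end, and the flatten range always ends at the last dimension. A standalone sketch of just that arithmetic (a hypothetical helper for illustration, not part of the crate):

```rust
fn flatten_range(rank: usize, axis: i64) -> (usize, usize) {
    let mut start_dim = axis;
    if start_dim < 0 {
        start_dim += rank as i64;
    }
    (start_dim as usize, rank - 1)
}

fn main() {
    assert_eq!(flatten_range(4, 1), (1, 3));
    assert_eq!(flatten_range(4, -2), (2, 3)); // -2 on a rank-4 tensor means dim 2
}
```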