mirror of https://github.com/tracel-ai/burn.git
Separating ONNX parsing from burn-import (#1921)
* separating onnx parsing from burn-import * ran clippy and cargo-fmt * removed unused deps from onnx-ir * fixed clippy warnings that were causing run-checks to fail * removed dead code * removed unused dependencies from burn-import * updated contributor-book, updated publish.yml, added readme * update cargo lock * formatted md document with prettier, rephrased sentence * missed the errors with reduce_prod_conversion during merge * formatted onnx-to-burn-conversion-tool.md, forgot to save
This commit is contained in:
parent
755c0708c4
commit
25348cf181
|
@ -161,3 +161,9 @@ jobs:
|
|||
with:
|
||||
crate: burn-import
|
||||
secrets: inherit
|
||||
|
||||
publish-onnx-ir:
|
||||
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main
|
||||
with:
|
||||
crate: onnx-ir
|
||||
secrets: inherit
|
||||
|
|
|
@ -628,23 +628,19 @@ name = "burn-import"
|
|||
version = "0.14.0"
|
||||
dependencies = [
|
||||
"burn",
|
||||
"bytemuck",
|
||||
"candle-core",
|
||||
"derive-new",
|
||||
"half",
|
||||
"log",
|
||||
"onnx-ir",
|
||||
"pretty_assertions",
|
||||
"proc-macro2",
|
||||
"protobuf",
|
||||
"protobuf-codegen",
|
||||
"quote",
|
||||
"regex",
|
||||
"rstest",
|
||||
"rust-format",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"strum",
|
||||
"strum_macros",
|
||||
"syn 2.0.68",
|
||||
"thiserror",
|
||||
"tracing-core",
|
||||
|
@ -3689,6 +3685,23 @@ dependencies = [
|
|||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "onnx-ir"
|
||||
version = "0.14.0"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"half",
|
||||
"log",
|
||||
"pretty_assertions",
|
||||
"protobuf",
|
||||
"protobuf-codegen",
|
||||
"regex",
|
||||
"rstest",
|
||||
"serde",
|
||||
"strum",
|
||||
"strum_macros",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "onnx-tests"
|
||||
version = "0.14.0"
|
||||
|
|
|
@ -117,7 +117,8 @@ plug in your operator in terms of \\(x\\) and \\(y\\), and just swap out the var
|
|||
|
||||
### Testing autodiff
|
||||
|
||||
For testing the `autodiff` operations, please refer to [this section](../getting-started/testing.md).
|
||||
For testing the `autodiff` operations, please refer to
|
||||
[this section](../getting-started/testing.md).
|
||||
|
||||
## Adding the Op to other backends
|
||||
|
||||
|
@ -199,11 +200,15 @@ Generating the ONNX test files or tests is already covered
|
|||
about the specific changes you need to make when adding new operators after you have generated the
|
||||
tests.
|
||||
|
||||
The crate is divided into two sections `src/burn` and `src/onnx`. The code under the former
|
||||
corresponds to the operation you've implemented earlier in this guide, and the latter to the
|
||||
operations defined in the ONNX specification. So when you are loading a model, the operator is first
|
||||
parsed to an intermediate representation defined by `src/onnx`, and then mapped to a Burn operation
|
||||
defined under `src/burn/node`.
|
||||
Changes will need to be made to both `onnx-ir` and `burn-import`. The code within `onnx-ir` defines
|
||||
how to parse the nodes in an onnx file and produces the intermediate representation. The code within
|
||||
`burn-import` is divided into two sections: `src/onnx` and `src/burn`. The code under the former
|
||||
maps that intermediate representation to one used for code generation and the latter defines how to
|
||||
generate code for the operator you've implemented earlier in this guide.
|
||||
|
||||
So when you are loading a model, the operator is first parsed to an intermediate representation
|
||||
defined by `burn-import` and then mapped to a Burn operation defined under `src/burn/node`; the
|
||||
mapping from onnx to burn is aptly defined in `src/onnx/to_burn`
|
||||
|
||||
Let's review the changes made for powf starting from `src/burn` and moving to `src/onnx`:
|
||||
|
||||
|
@ -218,17 +223,20 @@ Let's review the changes made for powf starting from `src/burn` and moving to `s
|
|||
[`{op}_conversion` function](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-import/src/onnx/to_burn.rs#L717)
|
||||
that maps the ONNX node to the binary type
|
||||
3. Specify how dimensions for the output should be derived in
|
||||
[crates/burn-import/src/onnx/dim_inference.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-import/src/onnx/dim_inference.rs#L55)
|
||||
[crates/onnx-ir/src/dim_inference.rs](https://github.com/tracel-ai/burn/blob/d4ae82b21ac3dd1def01bd380ab7ea4d3293eccb/crates/onnx-ir/src/dim_inference.rs#L17)
|
||||
|
||||
And you're done! Congrats, you just fully added a new operation to burn, and we are all one step
|
||||
closer to the answer to [Are we learning yet?](https://www.arewelearningyet.com/) being "Yes, and
|
||||
it's freaking fast!". Buy yourself a coffee.
|
||||
|
||||
[^supertrait]: for more on supertraits see
|
||||
[^supertrait]:
|
||||
for more on supertraits see
|
||||
[the advanced trait section of the rust book](https://doc.rust-lang.org/book/ch19-03-advanced-traits.html#using-supertraits-to-require-one-traits-functionality-within-another-trait)
|
||||
|
||||
[^autodiff]: wiki link for
|
||||
[^autodiff]:
|
||||
wiki link for
|
||||
[automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)
|
||||
|
||||
[^absolute_units]: for more information on unit structs see
|
||||
[^absolute_units]:
|
||||
for more information on unit structs see
|
||||
[the defining and instantiating structs section of the rust book](https://doc.rust-lang.org/book/ch05-01-defining-structs.html#unit-like-structs-without-any-fields)
|
||||
|
|
|
@ -16,7 +16,17 @@ For an introduction to ONNX import in Burn, see
|
|||
- [Design Goals](#design-goals)
|
||||
- [Design Decisions](#design-decisions)
|
||||
- [Adding New Operators](#adding-new-operators)
|
||||
- [Implementing a New Operator](#implementing-a-new-operator)
|
||||
- [Implementing a New Operator](#implementing-a-new-operator)
|
||||
- [Step 1: Visibility](#step-1-visibility)
|
||||
- [Step 2: Node Implementation](#step-2-node-implementation)
|
||||
- [Within Onnx-IR](#within-onnx-ir)
|
||||
- [Within burn-import](#within-burn-import)
|
||||
- [Step 3: Registering New Operations](#step-3-registering-new-operations)
|
||||
- [Step 4: Create a Config Function](#step-4-create-a-config-function)
|
||||
- [Step 5: Dimension Inference](#step-5-dimension-inference)
|
||||
- [Step 6: Integrate into the Graph Building Process](#step-6-integrate-into-the-graph-building-process)
|
||||
- [Step 7: Add Newly Supported Op!](#step-7-add-newly-supported-op)
|
||||
- [Misc:](#misc)
|
||||
- [Testing](#testing)
|
||||
- [Resources](#resources)
|
||||
|
||||
|
@ -91,6 +101,22 @@ located in the `src/burn/node/` directory.
|
|||
|
||||
### Step 2: Node Implementation
|
||||
|
||||
#### Within Onnx-IR
|
||||
|
||||
If the node type does not exist within the
|
||||
[`NodeType` enum](https://github.com/tracel-ai/burn/blob/d4ae82b21ac3dd1def01bd380ab7ea4d3293eccb/crates/onnx-ir/src/ir.rs#L246),
|
||||
it will need to be added (support for custom operators is planned). If the node might be provided an
|
||||
input which is a constant or the output of an identity node, it will need to be added to the list of
|
||||
nodeTypes
|
||||
[checked for constants](https://github.com/tracel-ai/burn/blob/d4ae82b21ac3dd1def01bd380ab7ea4d3293eccb/crates/onnx-ir/src/from_onnx.rs#L21).
|
||||
The node will need to be added to `dim_inference`, and in most cases the work parsing side will be
|
||||
done. If a node requires extra parsing (such as handling an edge case like potentially remapping an
|
||||
unsqueeze to a reshape) the best place for that is after check constants and prior to dim_inference
|
||||
in
|
||||
[`OnnxGraphBuilder::Build`](https://github.com/tracel-ai/burn/blob/d4ae82b21ac3dd1def01bd380ab7ea4d3293eccb/crates/onnx-ir/src/from_onnx.rs#L221)
|
||||
|
||||
#### Within burn-import
|
||||
|
||||
Create a new file named `<operation_name>.rs` in the `src/burn/node/` directory.
|
||||
This file will define the structure and functionality of your new operation. By convention, the
|
||||
necessary information for carrying out an operation is encapsulated within a struct named
|
||||
|
|
|
@ -20,30 +20,23 @@ pytorch = ["burn/record-item-custom-serde", "thiserror", "zip"]
|
|||
|
||||
[dependencies]
|
||||
burn = { path = "../burn", version = "0.14.0", features = ["ndarray"] }
|
||||
|
||||
bytemuck = { workspace = true }
|
||||
onnx-ir = { path = "../onnx-ir" }
|
||||
candle-core = { workspace = true }
|
||||
derive-new = { workspace = true }
|
||||
half = { workspace = true }
|
||||
log = { workspace = true }
|
||||
proc-macro2 = { workspace = true }
|
||||
protobuf = { workspace = true, features = ["with-bytes"] }
|
||||
quote = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
rust-format = { workspace = true, features = ["token_stream", "post_process"] }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true, features = ["std"] }
|
||||
strum = { workspace = true }
|
||||
strum_macros = { workspace = true }
|
||||
syn = { workspace = true, features = ["parsing"] }
|
||||
thiserror = { workspace = true, optional = true }
|
||||
tracing-core = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
zip = { workspace = true, optional = true }
|
||||
|
||||
[build-dependencies]
|
||||
protobuf-codegen = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions = { workspace = true }
|
||||
rstest = { workspace = true }
|
||||
|
|
|
@ -1,11 +0,0 @@
|
|||
fn main() {
|
||||
if cfg!(feature = "onnx") {
|
||||
// Generate the onnx protobuf files
|
||||
protobuf_codegen::Codegen::new()
|
||||
.pure()
|
||||
.includes(["src"])
|
||||
.input("src/onnx/protos/onnx.proto")
|
||||
.cargo_out_dir("onnx-protos")
|
||||
.run_from_script();
|
||||
}
|
||||
}
|
|
@ -1,14 +1,3 @@
|
|||
mod coalesce;
|
||||
mod dim_inference;
|
||||
mod from_onnx;
|
||||
mod ir;
|
||||
mod node_remap;
|
||||
mod op_configuration;
|
||||
mod proto_conversion;
|
||||
mod protos;
|
||||
mod to_burn;
|
||||
|
||||
pub use to_burn::*;
|
||||
|
||||
pub use from_onnx::parse_onnx;
|
||||
pub use ir::OnnxGraph;
|
||||
|
|
|
@ -5,8 +5,8 @@ use burn::nn::{
|
|||
PaddingConfig2d,
|
||||
};
|
||||
|
||||
use super::ir::{ArgType, AttributeValue, Data, Node};
|
||||
use crate::burn::node::resize::ResizeMode;
|
||||
use onnx_ir::ir::{ArgType, AttributeValue, Data, Node};
|
||||
|
||||
/// Create a Conv1dConfig from the attributes of the node
|
||||
pub fn conv1d_config(curr: &Node) -> Conv1dConfig {
|
||||
|
|
|
@ -53,20 +53,24 @@ use crate::{
|
|||
},
|
||||
format_tokens,
|
||||
logger::init_log,
|
||||
onnx::{
|
||||
from_onnx::convert_constant_value,
|
||||
ir::{Node, NodeType},
|
||||
op_configuration::*,
|
||||
},
|
||||
};
|
||||
|
||||
use super::{
|
||||
from_onnx::parse_onnx,
|
||||
ir::{self, ArgType, Argument, Data, ElementType, OnnxGraph},
|
||||
op_configuration::{
|
||||
avg_pool2d_config, clip_config, concat_config, dropout_config, reshape_config,
|
||||
resize_config, softmax_config,
|
||||
use super::op_configuration::{
|
||||
argmax_config, avg_pool1d_config, avg_pool2d_config, batch_norm_config, clip_config,
|
||||
concat_config, conv1d_config, conv2d_config, conv_transpose2d_config, dropout_config,
|
||||
expand_config, flatten_config, gather_config, layer_norm_config, leaky_relu_config,
|
||||
linear_config, log_softmax_config, max_pool1d_config, max_pool2d_config, reduce_max_config,
|
||||
reduce_mean_config, reduce_min_config, reduce_prod_config, reduce_sum_config, reshape_config,
|
||||
resize_config, shape_config, slice_config, softmax_config, squeeze_config, transpose_config,
|
||||
unsqueeze_config,
|
||||
};
|
||||
use onnx_ir::{
|
||||
convert_constant_value,
|
||||
ir::{
|
||||
ArgType, Argument as OnnxArgument, Data, ElementType, Node, NodeType, OnnxGraph,
|
||||
TensorType as OnnxTensorType,
|
||||
},
|
||||
parse_onnx,
|
||||
};
|
||||
|
||||
pub use crate::burn::graph::RecordType;
|
||||
|
@ -197,6 +201,7 @@ impl ModelGen {
|
|||
log::debug!("Output file: {:?}", out_file);
|
||||
|
||||
let graph = parse_onnx(input.as_ref());
|
||||
let graph = ParsedOnnxGraph(graph);
|
||||
|
||||
if self.development {
|
||||
// export the graph
|
||||
|
@ -231,15 +236,16 @@ impl ModelGen {
|
|||
log::info!("Model generated");
|
||||
}
|
||||
}
|
||||
|
||||
impl OnnxGraph {
|
||||
#[derive(Debug)]
|
||||
struct ParsedOnnxGraph(OnnxGraph);
|
||||
impl ParsedOnnxGraph {
|
||||
/// Converts ONNX graph to Burn graph.
|
||||
pub fn into_burn<PS: PrecisionSettings + 'static>(self) -> BurnGraph<PS> {
|
||||
let mut graph = BurnGraph::<PS>::default();
|
||||
|
||||
let mut unsupported_ops = vec![];
|
||||
|
||||
for node in self.nodes {
|
||||
for node in self.0.nodes {
|
||||
match node.node_type {
|
||||
NodeType::Add => graph.register(Self::add_conversion(node)),
|
||||
NodeType::ArgMax => graph.register(Self::argmax_conversion(node)),
|
||||
|
@ -328,11 +334,13 @@ impl OnnxGraph {
|
|||
|
||||
// Get input and output names
|
||||
let input_names = self
|
||||
.0
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|input| input.name.clone())
|
||||
.collect::<Vec<_>>();
|
||||
let output_names = self
|
||||
.0
|
||||
.outputs
|
||||
.iter()
|
||||
.map(|output| output.name.clone())
|
||||
|
@ -390,13 +398,13 @@ impl OnnxGraph {
|
|||
ArgType::Shape(_) => panic!("Shape is not supported as constant value."),
|
||||
};
|
||||
|
||||
ConstantNode::new(node.name.clone(), const_value, output.to_type())
|
||||
ConstantNode::new(node.name.clone(), const_value, Type::from(output))
|
||||
}
|
||||
|
||||
fn random_uniform_conversion(node: Node) -> RandomUniformNode {
|
||||
let output = node.outputs.first().unwrap();
|
||||
// cannot use output.to_tensor_type() here, since it drops the shape info...
|
||||
let output_type = if let Type::Tensor(t) = output.to_type() {
|
||||
let output_type = if let Type::Tensor(t) = Type::from(output) {
|
||||
t
|
||||
} else {
|
||||
panic!("RandomUniform output type is no Tensor.");
|
||||
|
@ -423,7 +431,7 @@ impl OnnxGraph {
|
|||
fn random_normal_conversion(node: Node) -> RandomNormalNode {
|
||||
let output = node.outputs.first().unwrap();
|
||||
// cannot use output.to_tensor_type() here, since it drops the shape info...
|
||||
let output_type = if let Type::Tensor(t) = output.to_type() {
|
||||
let output_type = if let Type::Tensor(t) = Type::from(output) {
|
||||
t
|
||||
} else {
|
||||
panic!("RandomNormal output type is no Tensor.");
|
||||
|
@ -448,141 +456,141 @@ impl OnnxGraph {
|
|||
}
|
||||
|
||||
fn add_conversion(node: Node) -> BinaryNode {
|
||||
let lhs = node.inputs.first().unwrap().to_type();
|
||||
let rhs = node.inputs.get(1).unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let lhs = Type::from(node.inputs.first().unwrap());
|
||||
let rhs = Type::from(node.inputs.get(1).unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
|
||||
BinaryNode::add(lhs, rhs, output)
|
||||
}
|
||||
|
||||
fn sub_conversion(node: Node) -> BinaryNode {
|
||||
let lhs = node.inputs.first().unwrap().to_type();
|
||||
let rhs = node.inputs.get(1).unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let lhs = Type::from(node.inputs.first().unwrap());
|
||||
let rhs = Type::from(node.inputs.get(1).unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
|
||||
BinaryNode::sub(lhs, rhs, output)
|
||||
}
|
||||
|
||||
fn mul_conversion(node: Node) -> BinaryNode {
|
||||
let lhs = node.inputs.first().unwrap().to_type();
|
||||
let rhs = node.inputs.get(1).unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let lhs = Type::from(node.inputs.first().unwrap());
|
||||
let rhs = Type::from(node.inputs.get(1).unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
|
||||
BinaryNode::mul(lhs, rhs, output)
|
||||
}
|
||||
|
||||
fn div_conversion(node: Node) -> BinaryNode {
|
||||
let lhs = node.inputs.first().unwrap().to_type();
|
||||
let rhs = node.inputs.get(1).unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let lhs = Type::from(node.inputs.first().unwrap());
|
||||
let rhs = Type::from(node.inputs.get(1).unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
|
||||
BinaryNode::div(lhs, rhs, output)
|
||||
}
|
||||
|
||||
fn matmul_conversion(node: Node) -> MatmulNode {
|
||||
let lhs = node.inputs.first().unwrap().to_tensor_type();
|
||||
let rhs = node.inputs.get(1).unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let lhs = TensorType::from(node.inputs.first().unwrap());
|
||||
let rhs = TensorType::from(node.inputs.get(1).unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
|
||||
MatmulNode::new(lhs, rhs, output)
|
||||
}
|
||||
|
||||
fn equal_conversion(node: Node) -> BinaryNode {
|
||||
let lhs = node.inputs.first().unwrap().to_type();
|
||||
let rhs = node.inputs.get(1).unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let lhs = Type::from(node.inputs.first().unwrap());
|
||||
let rhs = Type::from(node.inputs.get(1).unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
|
||||
BinaryNode::equal(lhs, rhs, output)
|
||||
}
|
||||
|
||||
fn max_conversion(node: Node) -> BinaryNode {
|
||||
let lhs = node.inputs.first().unwrap().to_type();
|
||||
let rhs = node.inputs.get(1).unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let lhs = Type::from(node.inputs.first().unwrap());
|
||||
let rhs = Type::from(node.inputs.get(1).unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
|
||||
BinaryNode::max_pair(lhs, rhs, output)
|
||||
}
|
||||
|
||||
fn erf_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
|
||||
UnaryNode::erf(input, output)
|
||||
}
|
||||
|
||||
fn leaky_relu_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
let alpha = leaky_relu_config(&node);
|
||||
|
||||
UnaryNode::leaky_relu(input, output, alpha)
|
||||
}
|
||||
|
||||
fn relu_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
|
||||
UnaryNode::relu(input, output)
|
||||
}
|
||||
|
||||
fn gelu_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
|
||||
UnaryNode::gelu(input, output)
|
||||
}
|
||||
|
||||
fn log_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
|
||||
UnaryNode::log(input, output)
|
||||
}
|
||||
|
||||
fn flatten_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
let (start_dim, end_dim) = flatten_config(&node);
|
||||
|
||||
UnaryNode::flatten(input, output, start_dim, end_dim)
|
||||
}
|
||||
|
||||
fn gather_conversion(node: Node) -> GatherNode {
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
let index = node.inputs.get(1).unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let input = TensorType::from(node.inputs.first().unwrap());
|
||||
let index = TensorType::from(node.inputs.get(1).unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
let dim = gather_config(&node);
|
||||
|
||||
GatherNode::new(input, index, output, dim)
|
||||
}
|
||||
|
||||
fn gather_elements_conversion(node: Node) -> GatherElementsNode {
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
let index = node.inputs.get(1).unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let input = TensorType::from(node.inputs.first().unwrap());
|
||||
let index = TensorType::from(node.inputs.get(1).unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
let dim = gather_config(&node);
|
||||
|
||||
GatherElementsNode::new(input, index, output, dim)
|
||||
}
|
||||
|
||||
fn transpose_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
let perm = transpose_config(&node);
|
||||
|
||||
UnaryNode::transpose(input, output, perm)
|
||||
}
|
||||
|
||||
fn cast_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
|
||||
UnaryNode::cast(input, output)
|
||||
}
|
||||
|
||||
fn reshape_conversion(node: Node) -> ReshapeNode {
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let input = TensorType::from(node.inputs.first().unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
let shape = reshape_config(&node);
|
||||
|
||||
ReshapeNode::new(input, output, shape)
|
||||
|
@ -591,10 +599,10 @@ impl OnnxGraph {
|
|||
fn resize_conversion(node: Node) -> ResizeNode {
|
||||
let name = &node.name;
|
||||
|
||||
let input = node.inputs[0].to_tensor_type();
|
||||
let output_size = node.inputs[3].to_tensor_type();
|
||||
let input = TensorType::from(&node.inputs[0]);
|
||||
let output_size = TensorType::from(&node.inputs[3]);
|
||||
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
|
||||
let mode = resize_config(&node);
|
||||
|
||||
|
@ -602,15 +610,15 @@ impl OnnxGraph {
|
|||
}
|
||||
|
||||
fn min_conversion(node: Node) -> BinaryNode {
|
||||
let lhs = node.inputs.first().unwrap().to_type();
|
||||
let rhs = node.inputs.get(1).unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let lhs = Type::from(node.inputs.first().unwrap());
|
||||
let rhs = Type::from(node.inputs.get(1).unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
|
||||
BinaryNode::min_pair(lhs, rhs, output)
|
||||
}
|
||||
|
||||
fn range_conversion(node: Node) -> RangeNode {
|
||||
fn convert_arg_to_scalar(arg: &Argument) -> ScalarType {
|
||||
fn convert_arg_to_scalar(arg: &OnnxArgument) -> ScalarType {
|
||||
match &arg.ty {
|
||||
ArgType::Scalar(scalar) => {
|
||||
ScalarType::new(arg.name.clone(), ScalarKind::from(scalar))
|
||||
|
@ -624,7 +632,7 @@ impl OnnxGraph {
|
|||
_ => panic!("Range node requires scalar inputs"),
|
||||
}
|
||||
}
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
let start = convert_arg_to_scalar(node.inputs.first().unwrap());
|
||||
let end = convert_arg_to_scalar(node.inputs.get(1).unwrap());
|
||||
let step = convert_arg_to_scalar(node.inputs.get(2).unwrap());
|
||||
|
@ -633,164 +641,156 @@ impl OnnxGraph {
|
|||
}
|
||||
|
||||
fn reduce_max_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
let dim = reduce_max_config(&node);
|
||||
|
||||
UnaryNode::reduce_max(input, output, dim)
|
||||
}
|
||||
|
||||
fn reduce_min_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
let dim = reduce_min_config(&node);
|
||||
|
||||
UnaryNode::reduce_min(input, output, dim)
|
||||
}
|
||||
|
||||
fn reduce_mean_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
let dim = reduce_mean_config(&node);
|
||||
|
||||
UnaryNode::reduce_mean(input, output, dim)
|
||||
}
|
||||
|
||||
fn reduce_prod_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
let dim = reduce_prod_config(&node);
|
||||
|
||||
UnaryNode::reduce_prod(input, output, dim)
|
||||
}
|
||||
|
||||
fn reduce_sum_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
let dim = reduce_sum_config(&node);
|
||||
|
||||
UnaryNode::reduce_sum(input, output, dim)
|
||||
}
|
||||
|
||||
fn shape_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
let (start_dim, end_dim) = shape_config(&node);
|
||||
|
||||
UnaryNode::shape(input, output, start_dim, end_dim)
|
||||
}
|
||||
|
||||
fn unsqueeze_conversion(node: Node) -> UnsqueezeNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
let dims = unsqueeze_config(&node);
|
||||
|
||||
UnsqueezeNode::new(input, output, dims)
|
||||
}
|
||||
|
||||
fn where_conversion(node: Node) -> WhereNode {
|
||||
let condition = node.inputs.first().unwrap().to_tensor_type();
|
||||
let x = node.inputs.get(1).unwrap().to_tensor_type();
|
||||
let y = node.inputs.get(2).unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let condition = TensorType::from(node.inputs.first().unwrap());
|
||||
let x = TensorType::from(node.inputs.get(1).unwrap());
|
||||
let y = TensorType::from(node.inputs.get(2).unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
|
||||
WhereNode::new(condition, x, y, output)
|
||||
}
|
||||
|
||||
fn clip_conversion(node: Node) -> ClipNode {
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let input = TensorType::from(node.inputs.first().unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
let (min, max) = clip_config(&node);
|
||||
|
||||
ClipNode::new(input, output, min, max)
|
||||
}
|
||||
|
||||
fn sigmoid_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
|
||||
UnaryNode::sigmoid(input, output)
|
||||
}
|
||||
|
||||
fn sin_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
|
||||
UnaryNode::sin(input, output)
|
||||
}
|
||||
|
||||
fn slice_conversion(node: Node) -> SliceNode {
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let input = TensorType::from(node.inputs.first().unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
let (starts, ends) = slice_config(&node);
|
||||
|
||||
SliceNode::new(input, output, starts, ends)
|
||||
}
|
||||
|
||||
fn sum_conversion(node: Node) -> SumNode {
|
||||
let inputs = node
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|input| input.to_tensor_type())
|
||||
.collect();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let inputs = node.inputs.iter().map(TensorType::from).collect();
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
|
||||
SumNode::new(inputs, output)
|
||||
}
|
||||
|
||||
fn reciprocal_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
|
||||
UnaryNode::reciprocal(input, output)
|
||||
}
|
||||
|
||||
fn log_softmax_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
let dim = log_softmax_config(&node);
|
||||
|
||||
UnaryNode::log_softmax(input, output, dim)
|
||||
}
|
||||
|
||||
fn softmax_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
let dim = softmax_config(&node);
|
||||
|
||||
UnaryNode::softmax(input, output, dim)
|
||||
}
|
||||
|
||||
fn sqrt_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
|
||||
UnaryNode::sqrt(input, output)
|
||||
}
|
||||
|
||||
fn tanh_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
|
||||
UnaryNode::tanh(input, output)
|
||||
}
|
||||
|
||||
fn argmax_conversion(node: Node) -> ArgMaxNode {
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let input = TensorType::from(node.inputs.first().unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
let axis = argmax_config(&node);
|
||||
|
||||
ArgMaxNode::new(input, output, axis)
|
||||
}
|
||||
|
||||
fn concat_conversion(node: Node) -> ConcatNode {
|
||||
let inputs = node
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|input| input.to_tensor_type())
|
||||
.collect();
|
||||
let inputs = node.inputs.iter().map(TensorType::from).collect();
|
||||
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
let dim = concat_config(&node);
|
||||
|
||||
ConcatNode::new(inputs, output, dim)
|
||||
|
@ -798,8 +798,8 @@ impl OnnxGraph {
|
|||
|
||||
fn linear_conversion<PS: PrecisionSettings>(node: Node) -> LinearNode {
|
||||
let name = &node.name;
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let input = TensorType::from(node.inputs.first().unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
let config = linear_config(&node);
|
||||
|
||||
let weight = extract_data_serialize::<PS::FloatElem>(1, &node).expect("Weight is required");
|
||||
|
@ -811,8 +811,8 @@ impl OnnxGraph {
|
|||
|
||||
fn dropout_conversion(node: Node) -> DropoutNode {
|
||||
let name = &node.name;
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let input = TensorType::from(node.inputs.first().unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
let config = dropout_config(&node);
|
||||
|
||||
DropoutNode::new(name, input, output, config)
|
||||
|
@ -820,8 +820,8 @@ impl OnnxGraph {
|
|||
|
||||
fn batch_norm_conversion<PS: PrecisionSettings>(node: Node) -> BatchNormNode {
|
||||
let config = batch_norm_config(&node);
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let input = TensorType::from(node.inputs.first().unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
let dim = input.dim - 2;
|
||||
|
||||
let gamma = extract_data_serialize::<PS::FloatElem>(1, &node).expect("Gamma is required");
|
||||
|
@ -848,8 +848,8 @@ impl OnnxGraph {
|
|||
|
||||
fn layer_norm_conversion<PS: PrecisionSettings>(node: Node) -> LayerNormNode {
|
||||
let (config, full_precision) = layer_norm_config(&node);
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let input = TensorType::from(node.inputs.first().unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
|
||||
// Scale tensor (aka gamma)
|
||||
let gamma = extract_data_serialize::<PS::FloatElem>(1, &node).expect("Gamma is required");
|
||||
|
@ -862,8 +862,8 @@ impl OnnxGraph {
|
|||
}
|
||||
|
||||
fn conv1d_conversion<PS: PrecisionSettings>(node: Node) -> Conv1dNode {
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let input = TensorType::from(node.inputs.first().unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
let config = conv1d_config(&node);
|
||||
|
||||
let bias = node.inputs.len() == 3;
|
||||
|
@ -878,8 +878,8 @@ impl OnnxGraph {
|
|||
}
|
||||
|
||||
fn conv2d_conversion<PS: PrecisionSettings>(node: Node) -> Conv2dNode {
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let input = TensorType::from(node.inputs.first().unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
let config = conv2d_config(&node);
|
||||
|
||||
let bias = node.inputs.len() == 3;
|
||||
|
@ -894,8 +894,8 @@ impl OnnxGraph {
|
|||
}
|
||||
|
||||
fn max_pool1d_conversion(node: Node) -> MaxPool1dNode {
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let input = TensorType::from(node.inputs.first().unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
let config = max_pool1d_config(&node);
|
||||
|
||||
let name = &node.name;
|
||||
|
@ -903,8 +903,8 @@ impl OnnxGraph {
|
|||
}
|
||||
|
||||
fn max_pool2d_conversion(node: Node) -> MaxPool2dNode {
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let input = TensorType::from(node.inputs.first().unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
let config = max_pool2d_config(&node);
|
||||
|
||||
let name = &node.name;
|
||||
|
@ -912,16 +912,16 @@ impl OnnxGraph {
|
|||
}
|
||||
|
||||
fn prelu_conversion<PS: PrecisionSettings>(node: Node) -> PReluNode {
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let input = TensorType::from(node.inputs.first().unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
let weight = extract_data_serialize::<PS::FloatElem>(1, &node).unwrap();
|
||||
let config = PReluConfig::new();
|
||||
let name = &node.name;
|
||||
PReluNode::new(name, input, output, weight, config)
|
||||
}
|
||||
fn conv_transpose2d_conversion<PS: PrecisionSettings>(node: Node) -> ConvTranspose2dNode {
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let input = TensorType::from(node.inputs.first().unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
let config = conv_transpose2d_config(&node);
|
||||
|
||||
let bias = node.inputs.len() == 3;
|
||||
|
@ -935,8 +935,8 @@ impl OnnxGraph {
|
|||
ConvTranspose2dNode::new(name, input, output, weight, bias, config)
|
||||
}
|
||||
fn avg_pool_1d_conversion(node: Node) -> AvgPool1dNode {
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let input = TensorType::from(node.inputs.first().unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
let config = avg_pool1d_config(&node);
|
||||
|
||||
let name = &node.name;
|
||||
|
@ -944,8 +944,8 @@ impl OnnxGraph {
|
|||
}
|
||||
|
||||
fn avg_pool_2d_conversion(node: Node) -> AvgPool2dNode {
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let input = TensorType::from(node.inputs.first().unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
let config = avg_pool2d_config(&node);
|
||||
|
||||
let name = &node.name;
|
||||
|
@ -953,8 +953,8 @@ impl OnnxGraph {
|
|||
}
|
||||
|
||||
fn global_avg_pool_conversion(node: Node) -> GlobalAvgPoolNode {
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let input = TensorType::from(node.inputs.first().unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
|
||||
let name = &node.name;
|
||||
|
||||
|
@ -962,71 +962,71 @@ impl OnnxGraph {
|
|||
}
|
||||
|
||||
fn cos_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
|
||||
UnaryNode::cos(input, output)
|
||||
}
|
||||
|
||||
fn exp_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
|
||||
UnaryNode::exp(input, output)
|
||||
}
|
||||
|
||||
fn expand_conversion(node: Node) -> ExpandNode {
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let input = TensorType::from(node.inputs.first().unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
let shape = expand_config(&node);
|
||||
|
||||
ExpandNode::new(input, output, shape)
|
||||
}
|
||||
|
||||
fn neg_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
UnaryNode::neg(input, output)
|
||||
}
|
||||
|
||||
fn not_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
UnaryNode::not(input, output)
|
||||
}
|
||||
|
||||
fn greater_conversion(node: Node) -> BinaryNode {
|
||||
let lhs = node.inputs.first().unwrap().to_type();
|
||||
let rhs = node.inputs.get(1).unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let lhs = Type::from(node.inputs.first().unwrap());
|
||||
let rhs = Type::from(node.inputs.get(1).unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
BinaryNode::greater(lhs, rhs, output)
|
||||
}
|
||||
|
||||
fn less_conversion(node: Node) -> BinaryNode {
|
||||
let lhs = node.inputs.first().unwrap().to_type();
|
||||
let rhs = node.inputs.get(1).unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let lhs = Type::from(node.inputs.first().unwrap());
|
||||
let rhs = Type::from(node.inputs.get(1).unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
BinaryNode::lower(lhs, rhs, output)
|
||||
}
|
||||
|
||||
fn greater_or_equal_conversion(node: Node) -> BinaryNode {
|
||||
let lhs = node.inputs.first().unwrap().to_type();
|
||||
let rhs = node.inputs.get(1).unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let lhs = Type::from(node.inputs.first().unwrap());
|
||||
let rhs = Type::from(node.inputs.get(1).unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
BinaryNode::greater_equal(lhs, rhs, output)
|
||||
}
|
||||
|
||||
fn less_or_equal_conversion(node: Node) -> BinaryNode {
|
||||
let lhs = node.inputs.first().unwrap().to_type();
|
||||
let rhs = node.inputs.get(1).unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let lhs = Type::from(node.inputs.first().unwrap());
|
||||
let rhs = Type::from(node.inputs.get(1).unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
BinaryNode::lower_equal(lhs, rhs, output)
|
||||
}
|
||||
|
||||
fn pow_conversion(node: Node) -> BinaryNode {
|
||||
let lhs = node.inputs.first().unwrap().to_type();
|
||||
let rhs = node.inputs.get(1).unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let lhs = Type::from(node.inputs.first().unwrap());
|
||||
let rhs = Type::from(node.inputs.get(1).unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
match &rhs {
|
||||
Type::Tensor(x) => match x.kind {
|
||||
TensorKind::Int => BinaryNode::powi(lhs, rhs, output),
|
||||
|
@ -1043,14 +1043,14 @@ impl OnnxGraph {
|
|||
}
|
||||
|
||||
fn sign_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let input = Type::from(node.inputs.first().unwrap());
|
||||
let output = Type::from(node.outputs.first().unwrap());
|
||||
UnaryNode::sign(input, output)
|
||||
}
|
||||
|
||||
fn squeeze_conversion(node: Node) -> SqueezeNode {
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let input = TensorType::from(node.inputs.first().unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
let axes = squeeze_config(&node);
|
||||
|
||||
SqueezeNode::new(input, output, axes)
|
||||
|
@ -1101,48 +1101,49 @@ fn serialize_data<E: Element>(data: Data, shape: Vec<usize>) -> TensorData {
|
|||
}
|
||||
}
|
||||
|
||||
impl Argument {
|
||||
pub fn to_tensor_type(&self) -> TensorType {
|
||||
match &self.ty {
|
||||
ArgType::Tensor(ir::TensorType {
|
||||
impl From<&OnnxArgument> for TensorType {
|
||||
fn from(arg: &OnnxArgument) -> Self {
|
||||
match &arg.ty {
|
||||
ArgType::Tensor(OnnxTensorType {
|
||||
elem_type: ElementType::Float16 | ElementType::Float32 | ElementType::Float64,
|
||||
dim,
|
||||
..
|
||||
}) => TensorType::new_float(self.name.clone(), *dim),
|
||||
ArgType::Tensor(ir::TensorType {
|
||||
}) => TensorType::new_float(arg.name.clone(), *dim),
|
||||
ArgType::Tensor(OnnxTensorType {
|
||||
elem_type: ElementType::Int32 | ElementType::Int64,
|
||||
dim,
|
||||
..
|
||||
}) => TensorType::new_int(self.name.clone(), *dim),
|
||||
ArgType::Tensor(ir::TensorType {
|
||||
}) => TensorType::new_int(arg.name.clone(), *dim),
|
||||
ArgType::Tensor(OnnxTensorType {
|
||||
elem_type: ElementType::Bool,
|
||||
dim,
|
||||
..
|
||||
}) => TensorType::new_bool(self.name.clone(), *dim),
|
||||
_ => panic!("Can't transform to tensor."),
|
||||
}) => TensorType::new_bool(arg.name.clone(), *dim),
|
||||
_ => panic!("Can't transform scalar to tensor."),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_type(&self) -> Type {
|
||||
match &self.ty {
|
||||
}
|
||||
impl From<&OnnxArgument> for Type {
|
||||
fn from(arg: &OnnxArgument) -> Self {
|
||||
match &arg.ty {
|
||||
ArgType::Tensor(tensor) => {
|
||||
// Treat tensor with dim 0 as scalar
|
||||
if tensor.dim == 0 {
|
||||
Type::Scalar(ScalarType::new(
|
||||
self.name.clone(),
|
||||
arg.name.clone(),
|
||||
ScalarKind::from(&tensor.elem_type),
|
||||
))
|
||||
} else {
|
||||
let kind: TensorKind = tensor.elem_type.clone().into();
|
||||
let dim = tensor.dim;
|
||||
let name = self.name.clone();
|
||||
let name = arg.name.clone();
|
||||
let shape = tensor.shape.clone();
|
||||
Type::Tensor(TensorType::new(name, dim, kind, shape))
|
||||
}
|
||||
}
|
||||
|
||||
ArgType::Scalar(elem_type) => {
|
||||
Type::Scalar(ScalarType::new(self.name.clone(), elem_type.into()))
|
||||
Type::Scalar(ScalarType::new(arg.name.clone(), elem_type.into()))
|
||||
}
|
||||
ArgType::Shape(_shape) => panic!("Can't transform shape to tensor."),
|
||||
}
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
[package]
|
||||
authors = [
|
||||
"Dilshod Tadjibaev (@antimora)",
|
||||
"Nathaniel Simard (@nathanielsimard)",
|
||||
]
|
||||
description = "Library for parsing ONNX models"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
name = "onnx-ir"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/onnx-ir"
|
||||
version.workspace = true
|
||||
|
||||
|
||||
[dependencies]
|
||||
bytemuck = { workspace = true }
|
||||
half = { workspace = true }
|
||||
log = { workspace = true }
|
||||
protobuf = { workspace = true, features = ["with-bytes"] }
|
||||
regex = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
strum = { workspace = true }
|
||||
strum_macros = { workspace = true }
|
||||
|
||||
|
||||
[build-dependencies]
|
||||
protobuf-codegen = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions = { workspace = true }
|
||||
rstest = { workspace = true }
|
|
@ -0,0 +1,7 @@
|
|||
# ONNX-IR
|
||||
|
||||
A pure rust Onnx Parser. Creates an intermediate representation useful for generating code in any ML/DL framework
|
||||
|
||||
For a full list of currently supported operators, please check [here](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md)
|
||||
|
||||
To see how to use this for generating burn graphs, see [here](crates/burn-import/src/onnx/to_burn.rs).
|
|
@ -0,0 +1,9 @@
|
|||
fn main() {
|
||||
// Generate the onnx protobuf files
|
||||
protobuf_codegen::Codegen::new()
|
||||
.pure()
|
||||
.includes(["src"])
|
||||
.input("src/protos/onnx.proto")
|
||||
.cargo_out_dir("onnx-protos")
|
||||
.run_from_script();
|
||||
}
|
|
@ -6,7 +6,7 @@ use super::{
|
|||
proto_conversion::convert_node_proto,
|
||||
protos::NodeProto,
|
||||
};
|
||||
use crate::onnx::ir::{ArgType, Data, TensorType};
|
||||
use crate::ir::{ArgType, Data, TensorType};
|
||||
|
||||
/// The function transforms the graph into a new one where the nodes are coalesced into a single node.
|
||||
pub fn coalesce(
|
|
@ -3,10 +3,10 @@ use core::panic;
|
|||
|
||||
use protobuf::Enum;
|
||||
|
||||
use super::{
|
||||
use crate::{
|
||||
ir::{ArgType, AttributeValue, Data, ElementType, Node, NodeType, TensorType},
|
||||
op_configuration::flatten_config,
|
||||
protos::tensor_proto::DataType,
|
||||
util::flatten_config,
|
||||
};
|
||||
|
||||
/// Infer the dimension of each output tensor and update them.
|
|
@ -4,7 +4,7 @@ use std::{
|
|||
path::Path,
|
||||
};
|
||||
|
||||
use crate::onnx::node_remap::remap_node_type;
|
||||
use crate::node_remap::remap_node_type;
|
||||
|
||||
use super::{
|
||||
coalesce::coalesce,
|
||||
|
@ -56,9 +56,9 @@ pub struct GraphData {
|
|||
|
||||
impl GraphData {
|
||||
pub(crate) fn new(
|
||||
inputs: &Vec<ValueInfoProto>,
|
||||
outputs: &Vec<ValueInfoProto>,
|
||||
initializers: &Vec<TensorProto>,
|
||||
inputs: &[ValueInfoProto],
|
||||
outputs: &[ValueInfoProto],
|
||||
initializers: &[TensorProto],
|
||||
) -> Self {
|
||||
let mut input_name_map = HashMap::new();
|
||||
let mut input_key_map = HashMap::new();
|
||||
|
@ -375,35 +375,32 @@ pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph {
|
|||
/// properly deleted if nothing else uses it
|
||||
/// Remap the unsqueeze node to a reshape node
|
||||
pub(crate) fn remap_unsqueeze_to_reshape(node: &mut Node, out_arg: &Argument) {
|
||||
match &out_arg.ty {
|
||||
ArgType::Tensor(output_tensor) => {
|
||||
let inner = output_tensor
|
||||
.shape
|
||||
.clone()
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|x| x as i64)
|
||||
.collect::<Vec<i64>>();
|
||||
let shape_len = inner.len();
|
||||
let new_rhs_value = Some(Data::Int64s(inner));
|
||||
//moving the remap to here
|
||||
let rhs_arg = Argument {
|
||||
name: format!("{}_generated_const", &node.name),
|
||||
ty: ArgType::Tensor(TensorType {
|
||||
elem_type: super::ir::ElementType::Int64,
|
||||
dim: 1,
|
||||
shape: Some(vec![shape_len]),
|
||||
}),
|
||||
value: new_rhs_value,
|
||||
passed: false,
|
||||
};
|
||||
// ? should this replace the old input (reuse the old key) or should it be a new key
|
||||
// going with new key for now
|
||||
node.inputs[1] = rhs_arg;
|
||||
node.outputs[0] = out_arg.clone();
|
||||
node.node_type = NodeType::Reshape;
|
||||
}
|
||||
_ => {}
|
||||
if let ArgType::Tensor(output_tensor) = &out_arg.ty {
|
||||
let inner = output_tensor
|
||||
.shape
|
||||
.clone()
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|x| x as i64)
|
||||
.collect::<Vec<i64>>();
|
||||
let shape_len = inner.len();
|
||||
let new_rhs_value = Some(Data::Int64s(inner));
|
||||
//moving the remap to here
|
||||
let rhs_arg = Argument {
|
||||
name: format!("{}_generated_const", &node.name),
|
||||
ty: ArgType::Tensor(TensorType {
|
||||
elem_type: super::ir::ElementType::Int64,
|
||||
dim: 1,
|
||||
shape: Some(vec![shape_len]),
|
||||
}),
|
||||
value: new_rhs_value,
|
||||
passed: false,
|
||||
};
|
||||
// ? should this replace the old input (reuse the old key) or should it be a new key
|
||||
// going with new key for now
|
||||
node.inputs[1] = rhs_arg;
|
||||
node.outputs[0] = out_arg.clone();
|
||||
node.node_type = NodeType::Reshape;
|
||||
}
|
||||
}
|
||||
// Define a trait for topological sorting
|
||||
|
@ -444,7 +441,7 @@ impl TopologicalSortable for Vec<NodeProto> {
|
|||
}
|
||||
|
||||
/// Get the value of a constant node from its attributes
|
||||
pub(crate) fn convert_constant_value(node: &Node) -> Argument {
|
||||
pub fn convert_constant_value(node: &Node) -> Argument {
|
||||
// A value can be stored in any of these attributes
|
||||
let keys = [
|
||||
"value",
|
|
@ -3,7 +3,7 @@ use half::f16;
|
|||
use std::{collections::HashMap, fmt::Formatter};
|
||||
use strum_macros::{Display, EnumString};
|
||||
|
||||
use super::protos::TensorProto;
|
||||
use crate::protos::TensorProto;
|
||||
|
||||
pub type Dim = usize;
|
||||
pub type Shape = Vec<Dim>;
|
||||
|
@ -241,7 +241,7 @@ impl PartialEq for Argument {
|
|||
}
|
||||
|
||||
/// The list of supported node types (ONNX operators and some extra ones to map easily to Burn's ops)
|
||||
/// Refer: https://github.com/onnx/onnx/blob/main/docs/Operators.md
|
||||
/// Refer: <https://github.com/onnx/onnx/blob/main/docs/Operators.md>
|
||||
#[derive(Debug, Hash, Eq, PartialEq, EnumString, Clone, Display)]
|
||||
pub enum NodeType {
|
||||
Abs,
|
||||
|
@ -444,7 +444,7 @@ pub enum NodeType {
|
|||
}
|
||||
|
||||
/// Truncate the vector display for debug display
|
||||
fn trunc<T: fmt::Display>(v: &Vec<T>) -> String {
|
||||
fn trunc<T: fmt::Display>(v: &[T]) -> String {
|
||||
const BEGIN_INDEX: usize = 0;
|
||||
const MAX_LEN: usize = 5;
|
||||
let mut s = String::new();
|
|
@ -0,0 +1,12 @@
|
|||
mod coalesce;
|
||||
mod dim_inference;
|
||||
mod from_onnx;
|
||||
pub mod ir;
|
||||
mod node_remap;
|
||||
mod proto_conversion;
|
||||
mod protos;
|
||||
mod util;
|
||||
|
||||
pub use from_onnx::convert_constant_value;
|
||||
pub use from_onnx::parse_onnx;
|
||||
pub use ir::OnnxGraph;
|
|
@ -1,6 +1,6 @@
|
|||
use std::str::{from_utf8, FromStr};
|
||||
|
||||
use crate::onnx::ir::TensorType;
|
||||
use crate::ir::TensorType;
|
||||
|
||||
use super::from_onnx::GraphData;
|
||||
use super::ir::Dim;
|
|
@ -0,0 +1,45 @@
|
|||
use crate::ir::{ArgType, Node};
|
||||
/// Create a FlattenConfig from the attributes of the node
|
||||
pub fn flatten_config(curr: &Node) -> (usize, usize) {
|
||||
// the begin dimension is the first dimension (Default: 1 per ONNX spec)
|
||||
let mut start_dim: i64 = 1;
|
||||
|
||||
// check if the node has only one input
|
||||
if curr.inputs.len() != 1 {
|
||||
panic!(
|
||||
"Flatten: multiple inputs are not supported (got {:?})",
|
||||
curr.inputs.len()
|
||||
);
|
||||
}
|
||||
|
||||
// extract the shape of the input tensor
|
||||
let tensor = match curr.inputs.first().unwrap().clone().ty {
|
||||
ArgType::Tensor(tensor) => tensor,
|
||||
_ => panic!("Only tensor input is valid"),
|
||||
};
|
||||
|
||||
// check if the input tensor has at least 2 dimensions
|
||||
if tensor.dim < 2 {
|
||||
panic!(
|
||||
"Flatten: input tensor must have at least 2 dimensions (got {:?})",
|
||||
tensor.dim
|
||||
);
|
||||
}
|
||||
|
||||
// the end dimension is the last dimension
|
||||
let end_dim = tensor.dim - 1;
|
||||
|
||||
// extract the attributes
|
||||
for (key, value) in curr.attrs.iter() {
|
||||
if key.as_str() == "axis" {
|
||||
start_dim = value.clone().into_i64();
|
||||
}
|
||||
}
|
||||
|
||||
// if beg_dim is negative, it is counted from the end
|
||||
if start_dim < 0 {
|
||||
start_dim += tensor.dim as i64;
|
||||
}
|
||||
|
||||
(start_dim as usize, end_dim)
|
||||
}
|
Loading…
Reference in New Issue