Separating ONNX parsing from burn-import (#1921)

* separating onnx parsing from burn-import

* ran clippy and cargo-fmt

* removed unused deps from onnx-ir

* fixed clippy warnings that were causing run-checks to fail

* removed dead code

* removed unused dependencies from burn-import

* updated contributor-book, updated publish.yml, added readme

* update cargo lock

* formatted md document with prettier, rephrased sentence

* missed the errors with reduce_prod_conversion during merge

* formatted onnx-to-burn-conversion-tool.md, forgot to save
This commit is contained in:
Joshua Ferguson 2024-07-02 15:17:44 -05:00 committed by GitHub
parent 755c0708c4
commit 25348cf181
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 405 additions and 279 deletions

View File

@ -161,3 +161,9 @@ jobs:
with:
crate: burn-import
secrets: inherit
publish-onnx-ir:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main
with:
crate: onnx-ir
secrets: inherit

23
Cargo.lock generated
View File

@ -628,23 +628,19 @@ name = "burn-import"
version = "0.14.0"
dependencies = [
"burn",
"bytemuck",
"candle-core",
"derive-new",
"half",
"log",
"onnx-ir",
"pretty_assertions",
"proc-macro2",
"protobuf",
"protobuf-codegen",
"quote",
"regex",
"rstest",
"rust-format",
"serde",
"serde_json",
"strum",
"strum_macros",
"syn 2.0.68",
"thiserror",
"tracing-core",
@ -3689,6 +3685,23 @@ dependencies = [
"serde",
]
[[package]]
name = "onnx-ir"
version = "0.14.0"
dependencies = [
"bytemuck",
"half",
"log",
"pretty_assertions",
"protobuf",
"protobuf-codegen",
"regex",
"rstest",
"serde",
"strum",
"strum_macros",
]
[[package]]
name = "onnx-tests"
version = "0.14.0"

View File

@ -117,7 +117,8 @@ plug in your operator in terms of \\(x\\) and \\(y\\), and just swap out the var
### Testing autodiff
For testing the `autodiff` operations, please refer to [this section](../getting-started/testing.md).
For testing the `autodiff` operations, please refer to
[this section](../getting-started/testing.md).
## Adding the Op to other backends
@ -199,11 +200,15 @@ Generating the ONNX test files or tests is already covered
about the specific changes you need to make when adding new operators after you have generated the
tests.
The crate is divided into two sections `src/burn` and `src/onnx`. The code under the former
corresponds to the operation you've implemented earlier in this guide, and the latter to the
operations defined in the ONNX specification. So when you are loading a model, the operator is first
parsed to an intermediate representation defined by `src/onnx`, and then mapped to a Burn operation
defined under `src/burn/node`.
Changes will need to be made to both `onnx-ir` and `burn-import`. The code within `onnx-ir` defines
how to parse the nodes in an onnx file and produces the intermediate representation. The code within
`burn-import` is divided into two sections: `src/onnx` and `src/burn`. The code under the former
maps that intermediate representation to one used for code generation and the latter defines how to
generate code for the operator you've implemented earlier in this guide.
So when you are loading a model, the operator is first parsed to an intermediate representation
defined by `burn-import` and then mapped to a Burn operation defined under `src/burn/node`; the
mapping from onnx to burn is aptly defined in `src/onnx/to_burn`
Let's review the changes made for powf starting from `src/burn` and moving to `src/onnx`:
@ -218,17 +223,20 @@ Let's review the changes made for powf starting from `src/burn` and moving to `s
[`{op}_conversion` function](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-import/src/onnx/to_burn.rs#L717)
that maps the ONNX node to the binary type
3. Specify how dimensions for the output should be derived in
[crates/burn-import/src/onnx/dim_inference.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-import/src/onnx/dim_inference.rs#L55)
[crates/onnx-ir/src/dim_inference.rs](https://github.com/tracel-ai/burn/blob/d4ae82b21ac3dd1def01bd380ab7ea4d3293eccb/crates/onnx-ir/src/dim_inference.rs#L17)
And you're done! Congrats, you just fully added a new operation to burn, and we are all one step
closer to the answer to [Are we learning yet?](https://www.arewelearningyet.com/) being "Yes, and
it's freaking fast!". Buy yourself a coffee.
[^supertrait]: for more on supertraits see
[^supertrait]:
for more on supertraits see
[the advanced trait section of the rust book](https://doc.rust-lang.org/book/ch19-03-advanced-traits.html#using-supertraits-to-require-one-traits-functionality-within-another-trait)
[^autodiff]: wiki link for
[^autodiff]:
wiki link for
[automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)
[^absolute_units]: for more information on unit structs see
[^absolute_units]:
for more information on unit structs see
[the defining and instantiating structs section of the rust book](https://doc.rust-lang.org/book/ch05-01-defining-structs.html#unit-like-structs-without-any-fields)

View File

@ -16,7 +16,17 @@ For an introduction to ONNX import in Burn, see
- [Design Goals](#design-goals)
- [Design Decisions](#design-decisions)
- [Adding New Operators](#adding-new-operators)
- [Implementing a New Operator](#implementing-a-new-operator)
- [Implementing a New Operator](#implementing-a-new-operator)
- [Step 1: Visibility](#step-1-visibility)
- [Step 2: Node Implementation](#step-2-node-implementation)
- [Within Onnx-IR](#within-onnx-ir)
- [Within burn-import](#within-burn-import)
- [Step 3: Registering New Operations](#step-3-registering-new-operations)
- [Step 4: Create a Config Function](#step-4-create-a-config-function)
- [Step 5: Dimension Inference](#step-5-dimension-inference)
- [Step 6: Integrate into the Graph Building Process](#step-6-integrate-into-the-graph-building-process)
- [Step 7: Add Newly Supported Op!](#step-7-add-newly-supported-op)
- [Misc:](#misc)
- [Testing](#testing)
- [Resources](#resources)
@ -91,6 +101,22 @@ located in the `src/burn/node/` directory.
### Step 2: Node Implementation
#### Within Onnx-IR
If the node type does not exist within the
[`NodeType` enum](https://github.com/tracel-ai/burn/blob/d4ae82b21ac3dd1def01bd380ab7ea4d3293eccb/crates/onnx-ir/src/ir.rs#L246),
it will need to be added (support for custom operators is planned). If the node might be provided an
input which is a constant or the output of an identity node, it will need to be added to the list of
nodeTypes
[checked for constants](https://github.com/tracel-ai/burn/blob/d4ae82b21ac3dd1def01bd380ab7ea4d3293eccb/crates/onnx-ir/src/from_onnx.rs#L21).
The node will need to be added to `dim_inference`, and in most cases the work parsing side will be
done. If a node requires extra parsing (such as handling an edge case like potentially remapping an
unsqueeze to a reshape) the best place for that is after check constants and prior to dim_inference
in
[`OnnxGraphBuilder::Build`](https://github.com/tracel-ai/burn/blob/d4ae82b21ac3dd1def01bd380ab7ea4d3293eccb/crates/onnx-ir/src/from_onnx.rs#L221)
#### Within burn-import
Create a new file named `<operation_name>.rs` in the `src/burn/node/` directory.
This file will define the structure and functionality of your new operation. By convention, the
necessary information for carrying out an operation is encapsulated within a struct named

View File

@ -20,30 +20,23 @@ pytorch = ["burn/record-item-custom-serde", "thiserror", "zip"]
[dependencies]
burn = { path = "../burn", version = "0.14.0", features = ["ndarray"] }
bytemuck = { workspace = true }
onnx-ir = { path = "../onnx-ir" }
candle-core = { workspace = true }
derive-new = { workspace = true }
half = { workspace = true }
log = { workspace = true }
proc-macro2 = { workspace = true }
protobuf = { workspace = true, features = ["with-bytes"] }
quote = { workspace = true }
regex = { workspace = true }
rust-format = { workspace = true, features = ["token_stream", "post_process"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true, features = ["std"] }
strum = { workspace = true }
strum_macros = { workspace = true }
syn = { workspace = true, features = ["parsing"] }
thiserror = { workspace = true, optional = true }
tracing-core = { workspace = true }
tracing-subscriber = { workspace = true }
zip = { workspace = true, optional = true }
[build-dependencies]
protobuf-codegen = { workspace = true }
[dev-dependencies]
pretty_assertions = { workspace = true }
rstest = { workspace = true }

View File

@ -1,11 +0,0 @@
fn main() {
if cfg!(feature = "onnx") {
// Generate the onnx protobuf files
protobuf_codegen::Codegen::new()
.pure()
.includes(["src"])
.input("src/onnx/protos/onnx.proto")
.cargo_out_dir("onnx-protos")
.run_from_script();
}
}

View File

@ -1,14 +1,3 @@
mod coalesce;
mod dim_inference;
mod from_onnx;
mod ir;
mod node_remap;
mod op_configuration;
mod proto_conversion;
mod protos;
mod to_burn;
pub use to_burn::*;
pub use from_onnx::parse_onnx;
pub use ir::OnnxGraph;

View File

@ -5,8 +5,8 @@ use burn::nn::{
PaddingConfig2d,
};
use super::ir::{ArgType, AttributeValue, Data, Node};
use crate::burn::node::resize::ResizeMode;
use onnx_ir::ir::{ArgType, AttributeValue, Data, Node};
/// Create a Conv1dConfig from the attributes of the node
pub fn conv1d_config(curr: &Node) -> Conv1dConfig {

View File

@ -53,20 +53,24 @@ use crate::{
},
format_tokens,
logger::init_log,
onnx::{
from_onnx::convert_constant_value,
ir::{Node, NodeType},
op_configuration::*,
},
};
use super::{
from_onnx::parse_onnx,
ir::{self, ArgType, Argument, Data, ElementType, OnnxGraph},
op_configuration::{
avg_pool2d_config, clip_config, concat_config, dropout_config, reshape_config,
resize_config, softmax_config,
use super::op_configuration::{
argmax_config, avg_pool1d_config, avg_pool2d_config, batch_norm_config, clip_config,
concat_config, conv1d_config, conv2d_config, conv_transpose2d_config, dropout_config,
expand_config, flatten_config, gather_config, layer_norm_config, leaky_relu_config,
linear_config, log_softmax_config, max_pool1d_config, max_pool2d_config, reduce_max_config,
reduce_mean_config, reduce_min_config, reduce_prod_config, reduce_sum_config, reshape_config,
resize_config, shape_config, slice_config, softmax_config, squeeze_config, transpose_config,
unsqueeze_config,
};
use onnx_ir::{
convert_constant_value,
ir::{
ArgType, Argument as OnnxArgument, Data, ElementType, Node, NodeType, OnnxGraph,
TensorType as OnnxTensorType,
},
parse_onnx,
};
pub use crate::burn::graph::RecordType;
@ -197,6 +201,7 @@ impl ModelGen {
log::debug!("Output file: {:?}", out_file);
let graph = parse_onnx(input.as_ref());
let graph = ParsedOnnxGraph(graph);
if self.development {
// export the graph
@ -231,15 +236,16 @@ impl ModelGen {
log::info!("Model generated");
}
}
impl OnnxGraph {
#[derive(Debug)]
struct ParsedOnnxGraph(OnnxGraph);
impl ParsedOnnxGraph {
/// Converts ONNX graph to Burn graph.
pub fn into_burn<PS: PrecisionSettings + 'static>(self) -> BurnGraph<PS> {
let mut graph = BurnGraph::<PS>::default();
let mut unsupported_ops = vec![];
for node in self.nodes {
for node in self.0.nodes {
match node.node_type {
NodeType::Add => graph.register(Self::add_conversion(node)),
NodeType::ArgMax => graph.register(Self::argmax_conversion(node)),
@ -328,11 +334,13 @@ impl OnnxGraph {
// Get input and output names
let input_names = self
.0
.inputs
.iter()
.map(|input| input.name.clone())
.collect::<Vec<_>>();
let output_names = self
.0
.outputs
.iter()
.map(|output| output.name.clone())
@ -390,13 +398,13 @@ impl OnnxGraph {
ArgType::Shape(_) => panic!("Shape is not supported as constant value."),
};
ConstantNode::new(node.name.clone(), const_value, output.to_type())
ConstantNode::new(node.name.clone(), const_value, Type::from(output))
}
fn random_uniform_conversion(node: Node) -> RandomUniformNode {
let output = node.outputs.first().unwrap();
// cannot use output.to_tensor_type() here, since it drops the shape info...
let output_type = if let Type::Tensor(t) = output.to_type() {
let output_type = if let Type::Tensor(t) = Type::from(output) {
t
} else {
panic!("RandomUniform output type is no Tensor.");
@ -423,7 +431,7 @@ impl OnnxGraph {
fn random_normal_conversion(node: Node) -> RandomNormalNode {
let output = node.outputs.first().unwrap();
// cannot use output.to_tensor_type() here, since it drops the shape info...
let output_type = if let Type::Tensor(t) = output.to_type() {
let output_type = if let Type::Tensor(t) = Type::from(output) {
t
} else {
panic!("RandomNormal output type is no Tensor.");
@ -448,141 +456,141 @@ impl OnnxGraph {
}
fn add_conversion(node: Node) -> BinaryNode {
let lhs = node.inputs.first().unwrap().to_type();
let rhs = node.inputs.get(1).unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let lhs = Type::from(node.inputs.first().unwrap());
let rhs = Type::from(node.inputs.get(1).unwrap());
let output = Type::from(node.outputs.first().unwrap());
BinaryNode::add(lhs, rhs, output)
}
fn sub_conversion(node: Node) -> BinaryNode {
let lhs = node.inputs.first().unwrap().to_type();
let rhs = node.inputs.get(1).unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let lhs = Type::from(node.inputs.first().unwrap());
let rhs = Type::from(node.inputs.get(1).unwrap());
let output = Type::from(node.outputs.first().unwrap());
BinaryNode::sub(lhs, rhs, output)
}
fn mul_conversion(node: Node) -> BinaryNode {
let lhs = node.inputs.first().unwrap().to_type();
let rhs = node.inputs.get(1).unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let lhs = Type::from(node.inputs.first().unwrap());
let rhs = Type::from(node.inputs.get(1).unwrap());
let output = Type::from(node.outputs.first().unwrap());
BinaryNode::mul(lhs, rhs, output)
}
fn div_conversion(node: Node) -> BinaryNode {
let lhs = node.inputs.first().unwrap().to_type();
let rhs = node.inputs.get(1).unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let lhs = Type::from(node.inputs.first().unwrap());
let rhs = Type::from(node.inputs.get(1).unwrap());
let output = Type::from(node.outputs.first().unwrap());
BinaryNode::div(lhs, rhs, output)
}
fn matmul_conversion(node: Node) -> MatmulNode {
let lhs = node.inputs.first().unwrap().to_tensor_type();
let rhs = node.inputs.get(1).unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let lhs = TensorType::from(node.inputs.first().unwrap());
let rhs = TensorType::from(node.inputs.get(1).unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
MatmulNode::new(lhs, rhs, output)
}
fn equal_conversion(node: Node) -> BinaryNode {
let lhs = node.inputs.first().unwrap().to_type();
let rhs = node.inputs.get(1).unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let lhs = Type::from(node.inputs.first().unwrap());
let rhs = Type::from(node.inputs.get(1).unwrap());
let output = Type::from(node.outputs.first().unwrap());
BinaryNode::equal(lhs, rhs, output)
}
fn max_conversion(node: Node) -> BinaryNode {
let lhs = node.inputs.first().unwrap().to_type();
let rhs = node.inputs.get(1).unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let lhs = Type::from(node.inputs.first().unwrap());
let rhs = Type::from(node.inputs.get(1).unwrap());
let output = Type::from(node.outputs.first().unwrap());
BinaryNode::max_pair(lhs, rhs, output)
}
fn erf_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let input = Type::from(node.inputs.first().unwrap());
let output = Type::from(node.outputs.first().unwrap());
UnaryNode::erf(input, output)
}
fn leaky_relu_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let input = Type::from(node.inputs.first().unwrap());
let output = Type::from(node.outputs.first().unwrap());
let alpha = leaky_relu_config(&node);
UnaryNode::leaky_relu(input, output, alpha)
}
fn relu_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let input = Type::from(node.inputs.first().unwrap());
let output = Type::from(node.outputs.first().unwrap());
UnaryNode::relu(input, output)
}
fn gelu_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let input = Type::from(node.inputs.first().unwrap());
let output = Type::from(node.outputs.first().unwrap());
UnaryNode::gelu(input, output)
}
fn log_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let input = Type::from(node.inputs.first().unwrap());
let output = Type::from(node.outputs.first().unwrap());
UnaryNode::log(input, output)
}
fn flatten_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let input = Type::from(node.inputs.first().unwrap());
let output = Type::from(node.outputs.first().unwrap());
let (start_dim, end_dim) = flatten_config(&node);
UnaryNode::flatten(input, output, start_dim, end_dim)
}
fn gather_conversion(node: Node) -> GatherNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let index = node.inputs.get(1).unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let input = TensorType::from(node.inputs.first().unwrap());
let index = TensorType::from(node.inputs.get(1).unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let dim = gather_config(&node);
GatherNode::new(input, index, output, dim)
}
fn gather_elements_conversion(node: Node) -> GatherElementsNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let index = node.inputs.get(1).unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let input = TensorType::from(node.inputs.first().unwrap());
let index = TensorType::from(node.inputs.get(1).unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let dim = gather_config(&node);
GatherElementsNode::new(input, index, output, dim)
}
fn transpose_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let input = Type::from(node.inputs.first().unwrap());
let output = Type::from(node.outputs.first().unwrap());
let perm = transpose_config(&node);
UnaryNode::transpose(input, output, perm)
}
fn cast_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let input = Type::from(node.inputs.first().unwrap());
let output = Type::from(node.outputs.first().unwrap());
UnaryNode::cast(input, output)
}
fn reshape_conversion(node: Node) -> ReshapeNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let shape = reshape_config(&node);
ReshapeNode::new(input, output, shape)
@ -591,10 +599,10 @@ impl OnnxGraph {
fn resize_conversion(node: Node) -> ResizeNode {
let name = &node.name;
let input = node.inputs[0].to_tensor_type();
let output_size = node.inputs[3].to_tensor_type();
let input = TensorType::from(&node.inputs[0]);
let output_size = TensorType::from(&node.inputs[3]);
let output = node.outputs.first().unwrap().to_tensor_type();
let output = TensorType::from(node.outputs.first().unwrap());
let mode = resize_config(&node);
@ -602,15 +610,15 @@ impl OnnxGraph {
}
fn min_conversion(node: Node) -> BinaryNode {
let lhs = node.inputs.first().unwrap().to_type();
let rhs = node.inputs.get(1).unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let lhs = Type::from(node.inputs.first().unwrap());
let rhs = Type::from(node.inputs.get(1).unwrap());
let output = Type::from(node.outputs.first().unwrap());
BinaryNode::min_pair(lhs, rhs, output)
}
fn range_conversion(node: Node) -> RangeNode {
fn convert_arg_to_scalar(arg: &Argument) -> ScalarType {
fn convert_arg_to_scalar(arg: &OnnxArgument) -> ScalarType {
match &arg.ty {
ArgType::Scalar(scalar) => {
ScalarType::new(arg.name.clone(), ScalarKind::from(scalar))
@ -624,7 +632,7 @@ impl OnnxGraph {
_ => panic!("Range node requires scalar inputs"),
}
}
let output = node.outputs.first().unwrap().to_tensor_type();
let output = TensorType::from(node.outputs.first().unwrap());
let start = convert_arg_to_scalar(node.inputs.first().unwrap());
let end = convert_arg_to_scalar(node.inputs.get(1).unwrap());
let step = convert_arg_to_scalar(node.inputs.get(2).unwrap());
@ -633,164 +641,156 @@ impl OnnxGraph {
}
fn reduce_max_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let input = Type::from(node.inputs.first().unwrap());
let output = Type::from(node.outputs.first().unwrap());
let dim = reduce_max_config(&node);
UnaryNode::reduce_max(input, output, dim)
}
fn reduce_min_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let input = Type::from(node.inputs.first().unwrap());
let output = Type::from(node.outputs.first().unwrap());
let dim = reduce_min_config(&node);
UnaryNode::reduce_min(input, output, dim)
}
fn reduce_mean_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let input = Type::from(node.inputs.first().unwrap());
let output = Type::from(node.outputs.first().unwrap());
let dim = reduce_mean_config(&node);
UnaryNode::reduce_mean(input, output, dim)
}
fn reduce_prod_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let input = Type::from(node.inputs.first().unwrap());
let output = Type::from(node.outputs.first().unwrap());
let dim = reduce_prod_config(&node);
UnaryNode::reduce_prod(input, output, dim)
}
fn reduce_sum_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let input = Type::from(node.inputs.first().unwrap());
let output = Type::from(node.outputs.first().unwrap());
let dim = reduce_sum_config(&node);
UnaryNode::reduce_sum(input, output, dim)
}
fn shape_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let input = Type::from(node.inputs.first().unwrap());
let output = Type::from(node.outputs.first().unwrap());
let (start_dim, end_dim) = shape_config(&node);
UnaryNode::shape(input, output, start_dim, end_dim)
}
fn unsqueeze_conversion(node: Node) -> UnsqueezeNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let input = Type::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let dims = unsqueeze_config(&node);
UnsqueezeNode::new(input, output, dims)
}
fn where_conversion(node: Node) -> WhereNode {
let condition = node.inputs.first().unwrap().to_tensor_type();
let x = node.inputs.get(1).unwrap().to_tensor_type();
let y = node.inputs.get(2).unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let condition = TensorType::from(node.inputs.first().unwrap());
let x = TensorType::from(node.inputs.get(1).unwrap());
let y = TensorType::from(node.inputs.get(2).unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
WhereNode::new(condition, x, y, output)
}
fn clip_conversion(node: Node) -> ClipNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let (min, max) = clip_config(&node);
ClipNode::new(input, output, min, max)
}
fn sigmoid_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let input = Type::from(node.inputs.first().unwrap());
let output = Type::from(node.outputs.first().unwrap());
UnaryNode::sigmoid(input, output)
}
fn sin_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let input = Type::from(node.inputs.first().unwrap());
let output = Type::from(node.outputs.first().unwrap());
UnaryNode::sin(input, output)
}
fn slice_conversion(node: Node) -> SliceNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let (starts, ends) = slice_config(&node);
SliceNode::new(input, output, starts, ends)
}
fn sum_conversion(node: Node) -> SumNode {
let inputs = node
.inputs
.iter()
.map(|input| input.to_tensor_type())
.collect();
let output = node.outputs.first().unwrap().to_tensor_type();
let inputs = node.inputs.iter().map(TensorType::from).collect();
let output = TensorType::from(node.outputs.first().unwrap());
SumNode::new(inputs, output)
}
fn reciprocal_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let input = Type::from(node.inputs.first().unwrap());
let output = Type::from(node.outputs.first().unwrap());
UnaryNode::reciprocal(input, output)
}
fn log_softmax_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let input = Type::from(node.inputs.first().unwrap());
let output = Type::from(node.outputs.first().unwrap());
let dim = log_softmax_config(&node);
UnaryNode::log_softmax(input, output, dim)
}
fn softmax_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let input = Type::from(node.inputs.first().unwrap());
let output = Type::from(node.outputs.first().unwrap());
let dim = softmax_config(&node);
UnaryNode::softmax(input, output, dim)
}
fn sqrt_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let input = Type::from(node.inputs.first().unwrap());
let output = Type::from(node.outputs.first().unwrap());
UnaryNode::sqrt(input, output)
}
fn tanh_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let input = Type::from(node.inputs.first().unwrap());
let output = Type::from(node.outputs.first().unwrap());
UnaryNode::tanh(input, output)
}
fn argmax_conversion(node: Node) -> ArgMaxNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let axis = argmax_config(&node);
ArgMaxNode::new(input, output, axis)
}
fn concat_conversion(node: Node) -> ConcatNode {
let inputs = node
.inputs
.iter()
.map(|input| input.to_tensor_type())
.collect();
let inputs = node.inputs.iter().map(TensorType::from).collect();
let output = node.outputs.first().unwrap().to_tensor_type();
let output = TensorType::from(node.outputs.first().unwrap());
let dim = concat_config(&node);
ConcatNode::new(inputs, output, dim)
@ -798,8 +798,8 @@ impl OnnxGraph {
fn linear_conversion<PS: PrecisionSettings>(node: Node) -> LinearNode {
let name = &node.name;
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let config = linear_config(&node);
let weight = extract_data_serialize::<PS::FloatElem>(1, &node).expect("Weight is required");
@ -811,8 +811,8 @@ impl OnnxGraph {
fn dropout_conversion(node: Node) -> DropoutNode {
let name = &node.name;
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let config = dropout_config(&node);
DropoutNode::new(name, input, output, config)
@ -820,8 +820,8 @@ impl OnnxGraph {
fn batch_norm_conversion<PS: PrecisionSettings>(node: Node) -> BatchNormNode {
let config = batch_norm_config(&node);
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let dim = input.dim - 2;
let gamma = extract_data_serialize::<PS::FloatElem>(1, &node).expect("Gamma is required");
@ -848,8 +848,8 @@ impl OnnxGraph {
fn layer_norm_conversion<PS: PrecisionSettings>(node: Node) -> LayerNormNode {
let (config, full_precision) = layer_norm_config(&node);
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
// Scale tensor (aka gamma)
let gamma = extract_data_serialize::<PS::FloatElem>(1, &node).expect("Gamma is required");
@ -862,8 +862,8 @@ impl OnnxGraph {
}
fn conv1d_conversion<PS: PrecisionSettings>(node: Node) -> Conv1dNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let config = conv1d_config(&node);
let bias = node.inputs.len() == 3;
@ -878,8 +878,8 @@ impl OnnxGraph {
}
fn conv2d_conversion<PS: PrecisionSettings>(node: Node) -> Conv2dNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let config = conv2d_config(&node);
let bias = node.inputs.len() == 3;
@ -894,8 +894,8 @@ impl OnnxGraph {
}
fn max_pool1d_conversion(node: Node) -> MaxPool1dNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let config = max_pool1d_config(&node);
let name = &node.name;
@ -903,8 +903,8 @@ impl OnnxGraph {
}
fn max_pool2d_conversion(node: Node) -> MaxPool2dNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let config = max_pool2d_config(&node);
let name = &node.name;
@ -912,16 +912,16 @@ impl OnnxGraph {
}
fn prelu_conversion<PS: PrecisionSettings>(node: Node) -> PReluNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let weight = extract_data_serialize::<PS::FloatElem>(1, &node).unwrap();
let config = PReluConfig::new();
let name = &node.name;
PReluNode::new(name, input, output, weight, config)
}
fn conv_transpose2d_conversion<PS: PrecisionSettings>(node: Node) -> ConvTranspose2dNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let config = conv_transpose2d_config(&node);
let bias = node.inputs.len() == 3;
@ -935,8 +935,8 @@ impl OnnxGraph {
ConvTranspose2dNode::new(name, input, output, weight, bias, config)
}
fn avg_pool_1d_conversion(node: Node) -> AvgPool1dNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let config = avg_pool1d_config(&node);
let name = &node.name;
@ -944,8 +944,8 @@ impl OnnxGraph {
}
fn avg_pool_2d_conversion(node: Node) -> AvgPool2dNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let config = avg_pool2d_config(&node);
let name = &node.name;
@ -953,8 +953,8 @@ impl OnnxGraph {
}
fn global_avg_pool_conversion(node: Node) -> GlobalAvgPoolNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let name = &node.name;
@ -962,71 +962,71 @@ impl OnnxGraph {
}
fn cos_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let input = Type::from(node.inputs.first().unwrap());
let output = Type::from(node.outputs.first().unwrap());
UnaryNode::cos(input, output)
}
fn exp_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let input = Type::from(node.inputs.first().unwrap());
let output = Type::from(node.outputs.first().unwrap());
UnaryNode::exp(input, output)
}
fn expand_conversion(node: Node) -> ExpandNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let shape = expand_config(&node);
ExpandNode::new(input, output, shape)
}
fn neg_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let input = Type::from(node.inputs.first().unwrap());
let output = Type::from(node.outputs.first().unwrap());
UnaryNode::neg(input, output)
}
fn not_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let input = Type::from(node.inputs.first().unwrap());
let output = Type::from(node.outputs.first().unwrap());
UnaryNode::not(input, output)
}
fn greater_conversion(node: Node) -> BinaryNode {
let lhs = node.inputs.first().unwrap().to_type();
let rhs = node.inputs.get(1).unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let lhs = Type::from(node.inputs.first().unwrap());
let rhs = Type::from(node.inputs.get(1).unwrap());
let output = Type::from(node.outputs.first().unwrap());
BinaryNode::greater(lhs, rhs, output)
}
fn less_conversion(node: Node) -> BinaryNode {
let lhs = node.inputs.first().unwrap().to_type();
let rhs = node.inputs.get(1).unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let lhs = Type::from(node.inputs.first().unwrap());
let rhs = Type::from(node.inputs.get(1).unwrap());
let output = Type::from(node.outputs.first().unwrap());
BinaryNode::lower(lhs, rhs, output)
}
fn greater_or_equal_conversion(node: Node) -> BinaryNode {
let lhs = node.inputs.first().unwrap().to_type();
let rhs = node.inputs.get(1).unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let lhs = Type::from(node.inputs.first().unwrap());
let rhs = Type::from(node.inputs.get(1).unwrap());
let output = Type::from(node.outputs.first().unwrap());
BinaryNode::greater_equal(lhs, rhs, output)
}
fn less_or_equal_conversion(node: Node) -> BinaryNode {
let lhs = node.inputs.first().unwrap().to_type();
let rhs = node.inputs.get(1).unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let lhs = Type::from(node.inputs.first().unwrap());
let rhs = Type::from(node.inputs.get(1).unwrap());
let output = Type::from(node.outputs.first().unwrap());
BinaryNode::lower_equal(lhs, rhs, output)
}
fn pow_conversion(node: Node) -> BinaryNode {
let lhs = node.inputs.first().unwrap().to_type();
let rhs = node.inputs.get(1).unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let lhs = Type::from(node.inputs.first().unwrap());
let rhs = Type::from(node.inputs.get(1).unwrap());
let output = Type::from(node.outputs.first().unwrap());
match &rhs {
Type::Tensor(x) => match x.kind {
TensorKind::Int => BinaryNode::powi(lhs, rhs, output),
@ -1043,14 +1043,14 @@ impl OnnxGraph {
}
fn sign_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let input = Type::from(node.inputs.first().unwrap());
let output = Type::from(node.outputs.first().unwrap());
UnaryNode::sign(input, output)
}
fn squeeze_conversion(node: Node) -> SqueezeNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let axes = squeeze_config(&node);
SqueezeNode::new(input, output, axes)
@ -1101,48 +1101,49 @@ fn serialize_data<E: Element>(data: Data, shape: Vec<usize>) -> TensorData {
}
}
impl Argument {
pub fn to_tensor_type(&self) -> TensorType {
match &self.ty {
ArgType::Tensor(ir::TensorType {
impl From<&OnnxArgument> for TensorType {
fn from(arg: &OnnxArgument) -> Self {
match &arg.ty {
ArgType::Tensor(OnnxTensorType {
elem_type: ElementType::Float16 | ElementType::Float32 | ElementType::Float64,
dim,
..
}) => TensorType::new_float(self.name.clone(), *dim),
ArgType::Tensor(ir::TensorType {
}) => TensorType::new_float(arg.name.clone(), *dim),
ArgType::Tensor(OnnxTensorType {
elem_type: ElementType::Int32 | ElementType::Int64,
dim,
..
}) => TensorType::new_int(self.name.clone(), *dim),
ArgType::Tensor(ir::TensorType {
}) => TensorType::new_int(arg.name.clone(), *dim),
ArgType::Tensor(OnnxTensorType {
elem_type: ElementType::Bool,
dim,
..
}) => TensorType::new_bool(self.name.clone(), *dim),
_ => panic!("Can't transform to tensor."),
}) => TensorType::new_bool(arg.name.clone(), *dim),
_ => panic!("Can't transform scalar to tensor."),
}
}
pub fn to_type(&self) -> Type {
match &self.ty {
}
impl From<&OnnxArgument> for Type {
fn from(arg: &OnnxArgument) -> Self {
match &arg.ty {
ArgType::Tensor(tensor) => {
// Treat tensor with dim 0 as scalar
if tensor.dim == 0 {
Type::Scalar(ScalarType::new(
self.name.clone(),
arg.name.clone(),
ScalarKind::from(&tensor.elem_type),
))
} else {
let kind: TensorKind = tensor.elem_type.clone().into();
let dim = tensor.dim;
let name = self.name.clone();
let name = arg.name.clone();
let shape = tensor.shape.clone();
Type::Tensor(TensorType::new(name, dim, kind, shape))
}
}
ArgType::Scalar(elem_type) => {
Type::Scalar(ScalarType::new(self.name.clone(), elem_type.into()))
Type::Scalar(ScalarType::new(arg.name.clone(), elem_type.into()))
}
ArgType::Shape(_shape) => panic!("Can't transform shape to tensor."),
}

31
crates/onnx-ir/Cargo.toml Normal file
View File

@ -0,0 +1,31 @@
[package]
authors = [
"Dilshod Tadjibaev (@antimora)",
"Nathaniel Simard (@nathanielsimard)",
]
description = "Library for parsing ONNX models"
edition.workspace = true
license.workspace = true
name = "onnx-ir"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/onnx-ir"
version.workspace = true
[dependencies]
bytemuck = { workspace = true }
half = { workspace = true }
log = { workspace = true }
protobuf = { workspace = true, features = ["with-bytes"] }
regex = { workspace = true }
serde = { workspace = true, features = ["derive"] }
strum = { workspace = true }
strum_macros = { workspace = true }
[build-dependencies]
protobuf-codegen = { workspace = true }
[dev-dependencies]
pretty_assertions = { workspace = true }
rstest = { workspace = true }

7
crates/onnx-ir/README.md Normal file
View File

@ -0,0 +1,7 @@
# ONNX-IR
A pure rust Onnx Parser. Creates an intermediate representation useful for generating code in any ML/DL framework
For a full list of currently supported operators, please check [here](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md)
To see how to use this for generating burn graphs, see [here](crates/burn-import/src/onnx/to_burn.rs).

9
crates/onnx-ir/build.rs Normal file
View File

@ -0,0 +1,9 @@
fn main() {
// Generate the onnx protobuf files
protobuf_codegen::Codegen::new()
.pure()
.includes(["src"])
.input("src/protos/onnx.proto")
.cargo_out_dir("onnx-protos")
.run_from_script();
}

View File

@ -6,7 +6,7 @@ use super::{
proto_conversion::convert_node_proto,
protos::NodeProto,
};
use crate::onnx::ir::{ArgType, Data, TensorType};
use crate::ir::{ArgType, Data, TensorType};
/// The function transforms the graph into a new one where the nodes are coalesced into a single node.
pub fn coalesce(

View File

@ -3,10 +3,10 @@ use core::panic;
use protobuf::Enum;
use super::{
use crate::{
ir::{ArgType, AttributeValue, Data, ElementType, Node, NodeType, TensorType},
op_configuration::flatten_config,
protos::tensor_proto::DataType,
util::flatten_config,
};
/// Infer the dimension of each output tensor and update them.

View File

@ -4,7 +4,7 @@ use std::{
path::Path,
};
use crate::onnx::node_remap::remap_node_type;
use crate::node_remap::remap_node_type;
use super::{
coalesce::coalesce,
@ -56,9 +56,9 @@ pub struct GraphData {
impl GraphData {
pub(crate) fn new(
inputs: &Vec<ValueInfoProto>,
outputs: &Vec<ValueInfoProto>,
initializers: &Vec<TensorProto>,
inputs: &[ValueInfoProto],
outputs: &[ValueInfoProto],
initializers: &[TensorProto],
) -> Self {
let mut input_name_map = HashMap::new();
let mut input_key_map = HashMap::new();
@ -375,35 +375,32 @@ pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph {
/// properly deleted if nothing else uses it
/// Remap the unsqueeze node to a reshape node
pub(crate) fn remap_unsqueeze_to_reshape(node: &mut Node, out_arg: &Argument) {
match &out_arg.ty {
ArgType::Tensor(output_tensor) => {
let inner = output_tensor
.shape
.clone()
.unwrap()
.into_iter()
.map(|x| x as i64)
.collect::<Vec<i64>>();
let shape_len = inner.len();
let new_rhs_value = Some(Data::Int64s(inner));
//moving the remap to here
let rhs_arg = Argument {
name: format!("{}_generated_const", &node.name),
ty: ArgType::Tensor(TensorType {
elem_type: super::ir::ElementType::Int64,
dim: 1,
shape: Some(vec![shape_len]),
}),
value: new_rhs_value,
passed: false,
};
// ? should this replace the old input (reuse the old key) or should it be a new key
// going with new key for now
node.inputs[1] = rhs_arg;
node.outputs[0] = out_arg.clone();
node.node_type = NodeType::Reshape;
}
_ => {}
if let ArgType::Tensor(output_tensor) = &out_arg.ty {
let inner = output_tensor
.shape
.clone()
.unwrap()
.into_iter()
.map(|x| x as i64)
.collect::<Vec<i64>>();
let shape_len = inner.len();
let new_rhs_value = Some(Data::Int64s(inner));
//moving the remap to here
let rhs_arg = Argument {
name: format!("{}_generated_const", &node.name),
ty: ArgType::Tensor(TensorType {
elem_type: super::ir::ElementType::Int64,
dim: 1,
shape: Some(vec![shape_len]),
}),
value: new_rhs_value,
passed: false,
};
// ? should this replace the old input (reuse the old key) or should it be a new key
// going with new key for now
node.inputs[1] = rhs_arg;
node.outputs[0] = out_arg.clone();
node.node_type = NodeType::Reshape;
}
}
// Define a trait for topological sorting
@ -444,7 +441,7 @@ impl TopologicalSortable for Vec<NodeProto> {
}
/// Get the value of a constant node from its attributes
pub(crate) fn convert_constant_value(node: &Node) -> Argument {
pub fn convert_constant_value(node: &Node) -> Argument {
// A value can be stored in any of these attributes
let keys = [
"value",

View File

@ -3,7 +3,7 @@ use half::f16;
use std::{collections::HashMap, fmt::Formatter};
use strum_macros::{Display, EnumString};
use super::protos::TensorProto;
use crate::protos::TensorProto;
pub type Dim = usize;
pub type Shape = Vec<Dim>;
@ -241,7 +241,7 @@ impl PartialEq for Argument {
}
/// The list of supported node types (ONNX operators and some extra ones to map easily to Burn's ops)
/// Refer: https://github.com/onnx/onnx/blob/main/docs/Operators.md
/// Refer: <https://github.com/onnx/onnx/blob/main/docs/Operators.md>
#[derive(Debug, Hash, Eq, PartialEq, EnumString, Clone, Display)]
pub enum NodeType {
Abs,
@ -444,7 +444,7 @@ pub enum NodeType {
}
/// Truncate the vector display for debug display
fn trunc<T: fmt::Display>(v: &Vec<T>) -> String {
fn trunc<T: fmt::Display>(v: &[T]) -> String {
const BEGIN_INDEX: usize = 0;
const MAX_LEN: usize = 5;
let mut s = String::new();

12
crates/onnx-ir/src/lib.rs Normal file
View File

@ -0,0 +1,12 @@
mod coalesce;
mod dim_inference;
mod from_onnx;
pub mod ir;
mod node_remap;
mod proto_conversion;
mod protos;
mod util;
pub use from_onnx::convert_constant_value;
pub use from_onnx::parse_onnx;
pub use ir::OnnxGraph;

View File

@ -1,6 +1,6 @@
use std::str::{from_utf8, FromStr};
use crate::onnx::ir::TensorType;
use crate::ir::TensorType;
use super::from_onnx::GraphData;
use super::ir::Dim;

View File

@ -0,0 +1,45 @@
use crate::ir::{ArgType, Node};
/// Create a FlattenConfig from the attributes of the node
pub fn flatten_config(curr: &Node) -> (usize, usize) {
// the begin dimension is the first dimension (Default: 1 per ONNX spec)
let mut start_dim: i64 = 1;
// check if the node has only one input
if curr.inputs.len() != 1 {
panic!(
"Flatten: multiple inputs are not supported (got {:?})",
curr.inputs.len()
);
}
// extract the shape of the input tensor
let tensor = match curr.inputs.first().unwrap().clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};
// check if the input tensor has at least 2 dimensions
if tensor.dim < 2 {
panic!(
"Flatten: input tensor must have at least 2 dimensions (got {:?})",
tensor.dim
);
}
// the end dimension is the last dimension
let end_dim = tensor.dim - 1;
// extract the attributes
for (key, value) in curr.attrs.iter() {
if key.as_str() == "axis" {
start_dim = value.clone().into_i64();
}
}
// if beg_dim is negative, it is counted from the end
if start_dim < 0 {
start_dim += tensor.dim as i64;
}
(start_dim as usize, end_dim)
}