Optimize argument handling and improve ONNX graph building (#1857)

* draft for alternative burn import design

* passes onnx test, fails to build example

* pushing to test example on main

* fixed the issue with the example

* passes the test now

* spring cleaning and minor code changes

* removed pub visibility from most graph_data fields and functions

* comment fixes

* went ahead and removed the constant check for now

* removed unused function arg
This commit is contained in:
Joshua Ferguson 2024-06-10 14:06:54 -05:00 committed by GitHub
parent 9a32e53e65
commit effce28b72
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 184 additions and 326 deletions

View File

@ -1,7 +1,7 @@
use std::{iter::Peekable, slice::Iter};
use super::{
from_onnx::OnnxGraphIO,
from_onnx::GraphData,
ir::{AttributeValue, Node, NodeType},
proto_conversion::convert_node_proto,
protos::NodeProto,
@ -12,12 +12,12 @@ use crate::onnx::ir::{ArgType, Data, TensorType};
pub fn coalesce(
node: &mut Node,
nodes_iter: &mut Peekable<Iter<NodeProto>>,
graph_io: &OnnxGraphIO,
graph_data: &GraphData,
) {
match node.node_type {
NodeType::Gemm => convert_gemm_to_linear(node),
NodeType::MatMul => {
convert_matmul_to_linear(node, nodes_iter, graph_io);
convert_matmul_to_linear(node, nodes_iter, graph_data);
}
_ => {}
}
@ -120,7 +120,7 @@ fn transpose_flattened<T: Copy>(matrix: Vec<T>, rows: usize, cols: usize) -> Vec
pub(crate) fn convert_matmul_to_linear(
node: &mut Node,
iter_mut: &mut Peekable<Iter<NodeProto>>,
graph_io: &OnnxGraphIO,
graph_data: &GraphData,
) {
if node.inputs.len() != 2 {
panic!("MatMul node must have 2 inputs");
@ -141,9 +141,10 @@ pub(crate) fn convert_matmul_to_linear(
// Convert the node to Linear
node.node_type = NodeType::Linear;
log::debug!("peeking next node for bias conversion");
// Check the next node for potential conversion
if let Some(peek_node) = iter_mut.peek() {
let peek_node = convert_node_proto(peek_node, graph_io).clone();
let peek_node = convert_node_proto(peek_node, graph_data);
if is_add_node_with_bias(&peek_node, node) {
convert_and_remove_add_node(&peek_node, node);

View File

@ -4,14 +4,13 @@ use core::panic;
use protobuf::Enum;
use super::{
from_onnx::OnnxGraphIO,
ir::{ArgType, AttributeValue, Data, ElementType, Node, NodeType, TensorType},
op_configuration::flatten_config,
protos::tensor_proto::DataType,
};
/// Infer the dimension of each output tensor and update them.
pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
pub fn dim_inference(node: &mut Node) {
match node.node_type {
NodeType::Add => same_as_input(node),
NodeType::ArgMax => argmax_update_outputs(node),
@ -81,8 +80,6 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
// Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated.
_ => temporary_pass_through_stub(node),
}
graph_io.update_tensor_output(node);
}
fn constant_update_outputs(node: &mut Node) {

View File

@ -4,11 +4,12 @@ use std::{
path::Path,
};
use crate::onnx::{node_remap::remap_node_type, proto_conversion::convert_node_proto};
use crate::onnx::node_remap::remap_node_type;
use super::{
coalesce::coalesce,
ir::{Data, OnnxGraph, TensorType},
proto_conversion::convert_node_proto,
protos::{ModelProto, NodeProto, TensorProto, ValueInfoProto},
};
@ -30,288 +31,213 @@ const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 10] = [
NodeType::Squeeze,
];
#[derive(Debug)]
#[derive(Debug, Clone)]
pub(crate) enum IOEntry {
In(usize),
Out(usize),
Node(usize),
Node(usize, usize),
}
pub(crate) struct OnnxGraphIO {
/// The inputs for the Graph
pub(crate) inputs: Vec<Argument>,
/// The outputs for the Graph
pub(crate) outputs: Vec<Argument>,
/// Initializers
pub struct GraphData {
/// The nodes that have been processed, used to copy the outputs to a child node
processed_nodes: Vec<Node>,
/// The inputs of the graph
inputs: Vec<Argument>,
/// The outputs of the graph
outputs: Vec<Argument>,
/// The initializers of the graph
pub(crate) initializers: HashMap<String, Argument>,
///updated names of outputs of node not stored in the graph
node_out: Vec<Argument>,
pub(crate) old_io_names: HashMap<String, IOEntry>,
/// Maps the original input name to a graph input
input_name_map: HashMap<String, IOEntry>,
/// Maps the updated input name to the original input name. Required to check if the input is an initializer
input_key_map: HashMap<String, String>,
}
impl OnnxGraphIO {
impl GraphData {
pub(crate) fn new(
inputs: &Vec<ValueInfoProto>,
outputs: &Vec<ValueInfoProto>,
initializers: &Vec<TensorProto>,
) -> Self {
let mut old_io_names = HashMap::new();
let mut in_count = 1;
let mut input_name_map = HashMap::new();
let mut input_key_map = HashMap::new();
let constants = initializers
.iter()
.map(|x| (x.name.clone(), Argument::from_initializer(x)))
.collect::<HashMap<String, Argument>>();
let outputs = outputs
.iter()
.map(|x| Argument::try_from(x.clone()).unwrap())
.collect::<Vec<Argument>>();
let inputs = inputs
.iter()
.enumerate()
.map(|(i, x)| {
let in_name = format!("input{}", in_count);
old_io_names.insert(x.name.clone(), IOEntry::In(i));
let in_name = format!("input{}", i + 1);
input_name_map.insert(x.name.clone(), IOEntry::In(i));
input_key_map.insert(in_name.clone(), x.name.clone());
let mut arg = Argument::try_from(x.clone()).unwrap();
if let Some(initial_arg) = constants.get(&x.name) {
if arg.value.is_none() {
log::warn!("Input {} is also an initializer. Initializer as default values are currently not supported", x.name);
arg.copy_value(initial_arg);
}
}
in_count += 1;
arg.name = in_name;
arg
})
.collect::<Vec<Argument>>();
let outputs = outputs
.iter()
.enumerate()
.map(|(i, x)| {
old_io_names.insert(x.name.clone(), IOEntry::Out(i));
Argument::try_from(x.clone()).unwrap()
})
.collect::<Vec<Argument>>();
let constants = initializers
.iter()
.map(|x| (x.name.clone(), Argument::from_initializer(x)))
.collect::<HashMap<String, Argument>>();
Self {
inputs,
outputs,
initializers: constants,
node_out: Vec::new(),
old_io_names,
processed_nodes: Vec::new(),
input_name_map,
input_key_map,
}
}
fn update_name(&mut self, arg: &Argument, new_name: &str) {
match self.old_io_names.get(&arg.name) {
Some(IOEntry::In(_)) => {
panic!("input names are set from the beginning");
}
Some(IOEntry::Out(i)) => {
let arg = self.outputs.get_mut(*i).unwrap();
arg.name = new_name.to_string();
}
Some(IOEntry::Node(i)) => {
let arg = self.node_out.get_mut(*i).unwrap();
arg.name = new_name.to_string();
}
/// Get the value of an input from the original input name. Used during proto conversion
pub(crate) fn init_in(&self, proto_str: &str) -> Argument {
match self.input_name_map.get(proto_str) {
None => {
//Constants, Casts wound up here before API changes
panic!(
"Tried to update the name of {} to {} but entry doesn't exist in the map",
arg.name, new_name
)
}
}
}
/// Used to initialize the input arguments for nodes. Names need to remain the same because
/// currently the old names are the key for accessing the Argument
pub fn init_in(&self, proto_str: String) -> Argument {
match self.old_io_names.get(&proto_str) {
None => {
if let Some(init_arg) = self.initializers.get(&proto_str) {
//NOTE: if initializers are guaranteed to be unique, (I think they are
//need to confirm) then we could pop the initializer from the map
if let Some(init_arg) = self.initializers.get(proto_str) {
init_arg.clone()
} else {
Argument::new(proto_str)
log::warn!(
"Input {} not found, should only happen when peeking",
proto_str
);
Argument::new(proto_str.to_string())
}
}
Some(IOEntry::In(i)) => {
let mut arg = self.inputs[*i].clone();
arg.name = proto_str;
arg.passed = true;
arg
}
Some(IOEntry::Node(i)) => {
let mut arg = self.node_out[*i].clone();
arg.name = proto_str;
arg
}
Some(IOEntry::Out(_)) => {
panic!("graph output {} can't be a Node input", &proto_str)
}
Some(IOEntry::In(i)) => self.inputs[*i].clone(),
Some(IOEntry::Node(i, j)) => self.processed_nodes[*i].outputs[*j].clone(),
}
}
fn insert(&mut self, arg: &Argument, new_name: &str) {
if let Some(idx) = self.old_io_names.get(&arg.name) {
if let IOEntry::Node(idx) = idx {
if self.node_out[*idx].name == arg.name {
self.node_out[*idx].name = new_name.to_string();
return;
/// Mark the graph_inputs to a node as passed, unless they are also initializers
fn mark_input_passed(&mut self, node: &Node) {
// we have to double map the inputs because the input might be replaced by an initializer
node.inputs.iter().for_each(|node_input| {
if let Some(old_input_name) = self.input_key_map.get(&node_input.name) {
if !self.initializers.contains_key(old_input_name) {
match self.input_name_map.get(old_input_name) {
Some(IOEntry::In(i)) => self.inputs[*i].passed = true,
_ => {
panic!("Should not happen, please report this error");
}
}
}
} else {
panic!("arg entry with old name {} is a graph IO", &arg.name);
}
}
let idx = self.node_out.len();
self.old_io_names
.insert(arg.name.clone(), IOEntry::Node(idx));
self.node_out.push(arg.clone());
self.node_out[idx].name = new_name.to_string();
});
}
/// Copies node outputs to graph IO. Used at the end of dim inference.
pub(crate) fn update_tensor_output(&mut self, node: &Node) {
for node_output in node.outputs.iter() {
match self.old_io_names.get(&node_output.name) {
Some(IOEntry::In(i)) => {
let arg = self.inputs.get_mut(*i).unwrap();
arg.copy_value(node_output);
}
Some(IOEntry::Out(i)) => {
let arg = self.outputs.get_mut(*i).unwrap();
arg.copy_value(node_output);
//Set the output to passed since it's been altered by a Node
arg.passed = true;
}
Some(IOEntry::Node(_)) => {
panic!("This output is from another node");
}
None => {
log::debug!("inserting with name {:?}", &node_output.name);
let idx = self.node_out.len();
self.old_io_names
.insert(node_output.name.clone(), IOEntry::Node(idx));
self.node_out.push(node_output.clone());
}
}
/// This function does three things:
/// 1. marks the inputs as passed
/// 2. maps the old output names to the node output
/// 3. renames the node output
fn add_node(&mut self, mut node: Node) {
log::debug!("adding node {:?}", &node.name);
self.mark_input_passed(&node);
let mut out_count = 1;
for output in node.outputs.iter_mut() {
self.input_name_map.insert(
output.name.clone(),
IOEntry::Node(self.processed_nodes.len(), 0),
);
output.name = format!("{}_out{}", node.name, out_count);
out_count += 1;
}
self.processed_nodes.push(node);
}
/// Used by handle unsqeeze to remap the output of a node to a new name
/// expected match if it exists is either a graph input or graph output
pub(crate) fn get_node_output(&self, old_name: &str) -> Option<&Argument> {
match self.old_io_names.get(old_name) {
Some(IOEntry::In(i)) => self.inputs.get(*i),
Some(IOEntry::Out(i)) => self.outputs.get(*i),
Some(IOEntry::Node(_)) => panic!("This is a node output"),
None => None,
}
/// Consumes the graph data and returns the processed nodes, filtered inputs and outputs
fn consume(mut self) -> (Vec<Node>, Vec<Argument>, Vec<Argument>) {
self.inputs.retain(|x| x.passed);
let outputs = self
.outputs
.into_iter()
.filter_map(|x| match self.input_name_map.get(&x.name) {
Some(IOEntry::Node(i, j)) => Some(self.processed_nodes[*i].outputs[*j].clone()),
_ => None,
})
.collect();
(self.processed_nodes, self.inputs, outputs)
}
/// Get the updated name of a Node Input, which should be
/// either a graph input or a node output.
/// Will return None if the it isn't a graph input or node output(like an initializer)
/// Will panic if it's a graph output
fn get_new_name(&mut self, old_name: &str) -> Option<String> {
match self.old_io_names.get(old_name) {
Some(IOEntry::In(i)) => {
//FIXME: technically in the spec, initializers are default values
//for optional inputs, but implementing that would require reworking
//the way the graph is built, and it's not clear burn users are using initializers
//in that way
// see https://github.com/onnx/onnx/issues/2660
if self.initializers.contains_key(old_name) {
None
} else {
//set the input as passed since a node is referencing it
self.inputs[*i].passed = true;
Some(self.inputs[*i].name.clone())
}
}
Some(IOEntry::Out(_)) => {
panic!(
"you just tried to get an updated name on a graph output: {}",
old_name
)
}
Some(IOEntry::Node(i)) => Some(self.node_out[*i].name.clone()),
None => None,
}
/// Used to get the output of the graph by name. Only used to remap unsqueeze nodes
pub fn get_graph_output(&self, name: &str) -> Option<&Argument> {
self.outputs.iter().find(|x| x.name == name)
}
// Since Nodes are added at the end of conversion, the current index is the length of the processed nodes
/// Get the current index of the processed nodes. Useful when lifting values or marking nodes for removal
pub fn get_current_index(&self) -> usize {
self.processed_nodes.len()
}
}
#[derive(Default)]
pub(crate) struct OnnxGraphBuilder {
nodes: Vec<Node>,
inputs: Vec<Argument>,
outputs: Vec<Argument>,
/// Counter for node names, used for renaming nodes
node_name_counter: HashMap<NodeType, usize>,
/// Nodes to remove
/// Nodes to remove. Note may be moved to graph data if we implement support for custom ops
nodes_to_remove: HashSet<usize>,
/// Map from constant node output names to indices of constant nodes
constants_map: HashMap<String, usize>,
/// Node types that should be lifted to constants
constants_types: HashSet<NodeType>,
/// Map from identity node output names to indices of identity nodes
identity_idx: HashMap<String, usize>,
node_name_counter: HashMap<NodeType, usize>,
}
impl OnnxGraphBuilder {
pub(crate) fn node_gen(&mut self, model_proto: &ModelProto) {
pub(crate) fn build(mut self, model_proto: &ModelProto) -> OnnxGraph {
self.constants_types = LIFT_CONSTANTS_FOR_NODE_TYPES.into_iter().collect();
let mut graph_io = OnnxGraphIO::new(
let mut graph_data = GraphData::new(
&model_proto.graph.input,
&model_proto.graph.output,
&model_proto.graph.initializer,
);
self.nodes = Vec::with_capacity(model_proto.graph.node.len());
let mut and_idx = 0;
let mut node_iter = model_proto.graph.node.iter().peekable();
while let Some(node_proto) = node_iter.next() {
let mut node = convert_node_proto(node_proto, &graph_io);
let mut node = convert_node_proto(node_proto, &graph_data);
remap_node_type(&mut node);
coalesce(&mut node, &mut node_iter, &graph_io);
self.handle_node_renaming(&mut node);
self.handle_identity(&mut node, and_idx);
self.check_constants(&mut node, and_idx, &mut graph_io);
self.handle_unsqueeze(&mut node, &graph_io);
coalesce(&mut node, &mut node_iter, &graph_data);
self.handle_identity(&mut node, &graph_data);
self.check_constants(&mut node, &graph_data);
// NOTE: potential start of custom functions
// can filter, coalesce, or modify the nodes here
// args : node, peek_iter, graph_data
self.handle_unsqueeze(&mut node, &graph_data);
dim_inference(&mut node, &mut graph_io);
rename_io(&mut node, &mut graph_io);
self.nodes.push(node);
and_idx += 1;
dim_inference(&mut node);
graph_data.add_node(node);
}
let mut i = 0;
self.nodes.retain(|_x| {
let res = !self.nodes_to_remove.contains(&i);
i += 1;
res
});
let OnnxGraphIO {
mut inputs,
mut outputs,
..
} = graph_io;
let (mut processed_nodes, inputs, outputs) = graph_data.consume();
// Remove the graph inputs/output that are not used by any node
remove_unused_graph_inputs(&mut inputs, &mut outputs);
self.inputs = inputs;
self.outputs = outputs;
let mut i = 0;
processed_nodes.retain(|_| {
let keep = !self.nodes_to_remove.contains(&i);
i += 1;
keep
});
OnnxGraph {
nodes: processed_nodes,
inputs,
outputs,
}
}
fn handle_node_renaming(&mut self, node: &mut Node) {
@ -328,17 +254,20 @@ impl OnnxGraphBuilder {
node.name.clone_from(&new_name);
}
fn check_constants(&mut self, node: &mut Node, i: usize, _graph_io: &mut OnnxGraphIO) {
fn check_constants(&mut self, node: &mut Node, graph_data: &GraphData) {
if node.node_type == NodeType::Constant
|| (node.node_type == NodeType::Identity && node.inputs[0].value.is_some())
{
self.constants_map.insert(node.outputs[0].name.clone(), i);
self.constants_map.insert(
format!("{}_out{}", &node.name, 1),
graph_data.get_current_index(),
);
} else if self.constants_types.contains(&node.node_type) {
log::debug!("checking node {} for constants", &node.name);
for input in node.inputs.iter_mut().skip(1) {
log::debug!("checking input {:?} for const", input);
if let Some(const_idx) = self.constants_map.get(&input.name) {
let constant = &self.nodes[*const_idx];
let constant = &graph_data.processed_nodes[*const_idx];
log::debug!(
"input {} matched constant node {}",
&input.name,
@ -362,29 +291,29 @@ impl OnnxGraphBuilder {
/// Check if the unsqueeze node has a rhs value (rhs is constant) and if not remap it to a reshape
/// Needs to be called after node renaming to ensure that the rhs name is correct
/// Needs to be called after constant lifting to ensure that the rhs value exists
fn handle_unsqueeze(&mut self, node: &mut Node, graph_io: &OnnxGraphIO) {
fn handle_unsqueeze(&mut self, node: &mut Node, graph_data: &GraphData) {
if node.node_type == NodeType::Unsqueeze
&& node.inputs.len() > 1
&& node.inputs[1].value.is_none()
{
if let Some(in_arg) = graph_io.get_node_output(&node.outputs[0].name) {
remap_unsqueeze_to_reshape(node, in_arg);
//if the output has a shape, it's only because it's a graph output
if let Some(out_arg) = graph_data.get_graph_output(&node.outputs[0].name) {
remap_unsqueeze_to_reshape(node, out_arg);
}
}
}
fn handle_identity(&mut self, node: &mut Node, i: usize) {
fn handle_identity(&mut self, node: &mut Node, graph_data: &GraphData) {
if node.node_type == NodeType::Identity && node.inputs[0].value.is_none() {
log::debug!("\nfound identity node:\n{:?}\n", &node);
let i = graph_data.get_current_index();
//map the output name to check for pass through values
self.identity_idx.insert(node.outputs[0].name.clone(), i);
self.identity_idx.insert(format!("{}_out1", &node.name), i);
self.nodes_to_remove.insert(i);
} else {
//NOTE: it might be possible to rework the API to handle all "per input" operations
//in a new function that operates on each input.
node.inputs.iter_mut().for_each(|x| {
if let Some(identity_idx) = self.identity_idx.get(&x.name) {
let input_name = &self.nodes[*identity_idx].inputs[0].name;
let input_name = &graph_data.processed_nodes[*identity_idx].inputs[0].name;
x.name.clone_from(input_name);
}
@ -431,118 +360,50 @@ pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph {
);
log::debug!("Number of outputs: {:?}", onnx_model.graph.output.len());
let mut builder = OnnxGraphBuilder::default();
builder.node_gen(&onnx_model);
let OnnxGraphBuilder {
nodes,
inputs: inner_inputs,
outputs: inner_outputs,
..
} = builder;
let builder = OnnxGraphBuilder::default();
let graph = builder.build(&onnx_model);
log::info!("Finished parsing ONNX file: {}", onnx_path.display());
OnnxGraph {
nodes,
inputs: inner_inputs,
outputs: inner_outputs,
}
graph
}
/// Remap the unsqueeze node to a reshape node, Should only be called after
/// node renaming has been done. avoids marking rhs as passed so that it can be
/// properly deleted if nothing else uses it
fn remap_unsqueeze_to_reshape(node: &mut Node, out_arg: &Argument) {
match node.outputs[0].ty {
ArgType::Tensor(ref mut tensor_type) => {
if let ArgType::Tensor(arg_tensor) = &out_arg.ty {
tensor_type.shape.clone_from(&arg_tensor.shape);
let inner = arg_tensor
.shape
.clone()
.unwrap()
.into_iter()
.map(|x| x as i64)
.collect::<Vec<i64>>();
let shape_len = inner.len();
let new_rhs_value = Some(Data::Int64s(inner));
//moving the remap to here
let rhs_arg = Argument {
name: format!("{}_generated_const", node.name),
ty: ArgType::Tensor(TensorType {
elem_type: super::ir::ElementType::Int64,
dim: 1,
shape: Some(vec![shape_len]),
}),
value: new_rhs_value,
passed: false,
};
node.inputs[1] = rhs_arg;
node.outputs[0] = out_arg.clone();
node.node_type = NodeType::Reshape;
}
/// Remap the unsqueeze node to a reshape node
pub(crate) fn remap_unsqueeze_to_reshape(node: &mut Node, out_arg: &Argument) {
match &out_arg.ty {
ArgType::Tensor(output_tensor) => {
let inner = output_tensor
.shape
.clone()
.unwrap()
.into_iter()
.map(|x| x as i64)
.collect::<Vec<i64>>();
let shape_len = inner.len();
let new_rhs_value = Some(Data::Int64s(inner));
//moving the remap to here
let rhs_arg = Argument {
name: format!("{}_generated_const", &node.name),
ty: ArgType::Tensor(TensorType {
elem_type: super::ir::ElementType::Int64,
dim: 1,
shape: Some(vec![shape_len]),
}),
value: new_rhs_value,
passed: false,
};
// ? should this replace the old input (reuse the old key) or should it be a new key
// going with new key for now
node.inputs[1] = rhs_arg;
node.outputs[0] = out_arg.clone();
node.node_type = NodeType::Reshape;
}
_ => {}
}
}
/// Rename the inputs and output in the graph and return a map of
/// the old names to the new names.
///
/// The inputs are renamed to be unique and to be in the format of
/// conv2_in1, conv2_in2, etc. This is done to be consistent with
/// the naming convention of the nodes and allow to be used as rust identifiers.
/// Rename the inputs and output in the graph and return a map of
/// the old names to the new names.
fn rename_io(node: &mut Node, graph_io: &mut OnnxGraphIO) {
log::debug!("checking inputs for node {:?}", &node.name);
for node_input in node.inputs.iter_mut() {
if let Some(input_name) = graph_io.get_new_name(&node_input.name) {
node_input.passed = true;
node_input.name.clone_from(&input_name);
} else {
node_input.name = "".to_string();
node_input.passed = false;
}
}
let mut out_count = 1;
if node.node_type == NodeType::Constant || node.node_type == NodeType::Identity {
let new_name = format!("{}_out{}", node.name, out_count);
graph_io.insert(&node.outputs[0], &new_name);
node.outputs[0].name.clone_from(&new_name);
log::debug!("Found {} constant", new_name);
} else {
for output in node.outputs.iter_mut() {
log::debug!("output name: {}", &output.name);
let new_name = format!("{}_out{}", node.name, out_count);
graph_io.update_name(output, &new_name);
output.name.clone_from(&new_name);
out_count += 1;
}
}
}
/// Removes the graph inputs/output that are not used by any node.
///
/// In older ONNX models, the inputs and outputs are not always used by the nodes.
/// For example, the input could be used as a state instead of an input. Since the
/// inputs with initializers are moved to the states vector, the inputs vector could
/// contain unused inputs. The same is true for the outputs.
///
/// Generally, it's a good idea to remove unused inputs/outputs because it makes the
/// generated code cleaner and easier to read.
fn remove_unused_graph_inputs(inputs: &mut Vec<Argument>, outputs: &mut Vec<Argument>) {
// Remove inputs that are not used by any node
inputs.retain(|input| input.passed);
// Remove outputs that are not used by any node
outputs.retain(|output| output.passed);
}
// Define a trait for topological sorting
trait TopologicalSortable {
fn is_top_sorted(&self) -> bool;

View File

@ -2,7 +2,7 @@ use std::str::{from_utf8, FromStr};
use crate::onnx::ir::TensorType;
use super::from_onnx::OnnxGraphIO;
use super::from_onnx::GraphData;
use super::ir::Dim;
use super::ir::{
ArgType, Argument, AttributeValue, Attributes, Data, ElementType, Node, NodeType, Tensor,
@ -180,19 +180,18 @@ pub fn convert_vec_attrs_proto(attrs: Vec<AttributeProto>) -> Attributes {
result
}
pub fn convert_node_proto(node: &NodeProto, graph_io: &OnnxGraphIO) -> Node {
pub fn convert_node_proto(node: &NodeProto, graph_data: &GraphData) -> Node {
let name = node.name.clone();
log::debug!("Converting ONNX node with type {:?}", node.op_type.as_str());
let inputs = node
.input
.clone()
.into_iter()
.map(|x| graph_io.init_in(x))
.collect();
let inputs = node.input.iter().map(|x| graph_data.init_in(x)).collect();
let outputs = node.output.clone().into_iter().map(Argument::new).collect();
let outputs = node
.output
.iter()
.map(|x| Argument::new(x.to_string()))
.collect();
let attrs = convert_vec_attrs_proto(node.attribute.clone());