diff --git a/crates/burn-import/src/onnx/coalesce.rs b/crates/burn-import/src/onnx/coalesce.rs index c3d5d93d2..efccadeb2 100644 --- a/crates/burn-import/src/onnx/coalesce.rs +++ b/crates/burn-import/src/onnx/coalesce.rs @@ -1,7 +1,7 @@ use std::{iter::Peekable, slice::Iter}; use super::{ - from_onnx::OnnxGraphIO, + from_onnx::GraphData, ir::{AttributeValue, Node, NodeType}, proto_conversion::convert_node_proto, protos::NodeProto, @@ -12,12 +12,12 @@ use crate::onnx::ir::{ArgType, Data, TensorType}; pub fn coalesce( node: &mut Node, nodes_iter: &mut Peekable>, - graph_io: &OnnxGraphIO, + graph_data: &GraphData, ) { match node.node_type { NodeType::Gemm => convert_gemm_to_linear(node), NodeType::MatMul => { - convert_matmul_to_linear(node, nodes_iter, graph_io); + convert_matmul_to_linear(node, nodes_iter, graph_data); } _ => {} } @@ -120,7 +120,7 @@ fn transpose_flattened(matrix: Vec, rows: usize, cols: usize) -> Vec pub(crate) fn convert_matmul_to_linear( node: &mut Node, iter_mut: &mut Peekable>, - graph_io: &OnnxGraphIO, + graph_data: &GraphData, ) { if node.inputs.len() != 2 { panic!("MatMul node must have 2 inputs"); @@ -141,9 +141,10 @@ pub(crate) fn convert_matmul_to_linear( // Convert the node to Linear node.node_type = NodeType::Linear; + log::debug!("peeking next node for bias conversion"); // Check the next node for potential conversion if let Some(peek_node) = iter_mut.peek() { - let peek_node = convert_node_proto(peek_node, graph_io).clone(); + let peek_node = convert_node_proto(peek_node, graph_data); if is_add_node_with_bias(&peek_node, node) { convert_and_remove_add_node(&peek_node, node); diff --git a/crates/burn-import/src/onnx/dim_inference.rs b/crates/burn-import/src/onnx/dim_inference.rs index a70502f9c..35029be06 100644 --- a/crates/burn-import/src/onnx/dim_inference.rs +++ b/crates/burn-import/src/onnx/dim_inference.rs @@ -4,14 +4,13 @@ use core::panic; use protobuf::Enum; use super::{ - from_onnx::OnnxGraphIO, ir::{ArgType, AttributeValue, Data, ElementType, Node, NodeType, TensorType}, op_configuration::flatten_config, protos::tensor_proto::DataType, }; /// Infer the dimension of each output tensor and update them. -pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) { +pub fn dim_inference(node: &mut Node) { match node.node_type { NodeType::Add => same_as_input(node), NodeType::ArgMax => argmax_update_outputs(node), @@ -81,8 +80,6 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) { // Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated. _ => temporary_pass_through_stub(node), } - - graph_io.update_tensor_output(node); } fn constant_update_outputs(node: &mut Node) { diff --git a/crates/burn-import/src/onnx/from_onnx.rs b/crates/burn-import/src/onnx/from_onnx.rs index fbcba0dbc..6a1d61514 100644 --- a/crates/burn-import/src/onnx/from_onnx.rs +++ b/crates/burn-import/src/onnx/from_onnx.rs @@ -4,11 +4,12 @@ use std::{ path::Path, }; -use crate::onnx::{node_remap::remap_node_type, proto_conversion::convert_node_proto}; +use crate::onnx::node_remap::remap_node_type; use super::{ coalesce::coalesce, ir::{Data, OnnxGraph, TensorType}, + proto_conversion::convert_node_proto, protos::{ModelProto, NodeProto, TensorProto, ValueInfoProto}, }; @@ -30,288 +31,213 @@ const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 10] = [ NodeType::Squeeze, ]; -#[derive(Debug)] +#[derive(Debug, Clone)] pub(crate) enum IOEntry { In(usize), - Out(usize), - Node(usize), + Node(usize, usize), } -pub(crate) struct OnnxGraphIO { - /// The inputs for the Graph - pub(crate) inputs: Vec, - /// The outputs for the Graph - pub(crate) outputs: Vec, - /// Initializers +pub struct GraphData { + /// The nodes that have been processed, used to copy the outputs to a child node + processed_nodes: Vec, + /// The inputs of the graph + inputs: Vec, + /// The outputs of the graph + outputs: Vec, + /// The initializers of the graph pub(crate) initializers: HashMap, - ///updated names of outputs of node not stored in the graph - node_out: Vec, - pub(crate) old_io_names: HashMap, + /// Maps the original input name to a graph input + input_name_map: HashMap, + /// Maps the updated input name to the original input name. Required to check if the input is an initializer + input_key_map: HashMap, } -impl OnnxGraphIO { +impl GraphData { pub(crate) fn new( inputs: &Vec, outputs: &Vec, initializers: &Vec, ) -> Self { - let mut old_io_names = HashMap::new(); - let mut in_count = 1; + let mut input_name_map = HashMap::new(); + let mut input_key_map = HashMap::new(); + let constants = initializers .iter() .map(|x| (x.name.clone(), Argument::from_initializer(x))) .collect::>(); - + let outputs = outputs + .iter() + .map(|x| Argument::try_from(x.clone()).unwrap()) + .collect::>(); let inputs = inputs .iter() .enumerate() .map(|(i, x)| { - let in_name = format!("input{}", in_count); - old_io_names.insert(x.name.clone(), IOEntry::In(i)); + let in_name = format!("input{}", i + 1); + + input_name_map.insert(x.name.clone(), IOEntry::In(i)); + input_key_map.insert(in_name.clone(), x.name.clone()); + let mut arg = Argument::try_from(x.clone()).unwrap(); if let Some(initial_arg) = constants.get(&x.name) { if arg.value.is_none() { + log::warn!("Input {} is also an initializer. Initializer as default values are currently not supported", x.name); arg.copy_value(initial_arg); } } - in_count += 1; arg.name = in_name; arg }) .collect::>(); - - let outputs = outputs - .iter() - .enumerate() - .map(|(i, x)| { - old_io_names.insert(x.name.clone(), IOEntry::Out(i)); - Argument::try_from(x.clone()).unwrap() - }) - .collect::>(); - - let constants = initializers - .iter() - .map(|x| (x.name.clone(), Argument::from_initializer(x))) - .collect::>(); - Self { inputs, outputs, initializers: constants, - node_out: Vec::new(), - old_io_names, + processed_nodes: Vec::new(), + input_name_map, + input_key_map, } } - fn update_name(&mut self, arg: &Argument, new_name: &str) { - match self.old_io_names.get(&arg.name) { - Some(IOEntry::In(_)) => { - panic!("input names are set from the beginning"); - } - Some(IOEntry::Out(i)) => { - let arg = self.outputs.get_mut(*i).unwrap(); - arg.name = new_name.to_string(); - } - Some(IOEntry::Node(i)) => { - let arg = self.node_out.get_mut(*i).unwrap(); - arg.name = new_name.to_string(); - } + /// Get the value of an input from the original input name. Used during proto conversion + pub(crate) fn init_in(&self, proto_str: &str) -> Argument { + match self.input_name_map.get(proto_str) { None => { - //Constants, Casts wound up here before API changes - panic!( - "Tried to update the name of {} to {} but entry doesn't exist in the map", - arg.name, new_name - ) - } - } - } - - /// Used to initialize the input arguments for nodes. Names need to remain the same because - /// currently the old names are the key for accessing the Argument - pub fn init_in(&self, proto_str: String) -> Argument { - match self.old_io_names.get(&proto_str) { - None => { - if let Some(init_arg) = self.initializers.get(&proto_str) { + //NOTE: if initializers are guaranteed to be unique, (I think they are + //need to confirm) then we could pop the initializer from the map + if let Some(init_arg) = self.initializers.get(proto_str) { init_arg.clone() } else { - Argument::new(proto_str) + log::warn!( + "Input {} not found, should only happen when peeking", + proto_str + ); + Argument::new(proto_str.to_string()) } } - - Some(IOEntry::In(i)) => { - let mut arg = self.inputs[*i].clone(); - - arg.name = proto_str; - arg.passed = true; - arg - } - Some(IOEntry::Node(i)) => { - let mut arg = self.node_out[*i].clone(); - arg.name = proto_str; - arg - } - Some(IOEntry::Out(_)) => { - panic!("graph output {} can't be a Node input", &proto_str) - } + Some(IOEntry::In(i)) => self.inputs[*i].clone(), + Some(IOEntry::Node(i, j)) => self.processed_nodes[*i].outputs[*j].clone(), } } - fn insert(&mut self, arg: &Argument, new_name: &str) { - if let Some(idx) = self.old_io_names.get(&arg.name) { - if let IOEntry::Node(idx) = idx { - if self.node_out[*idx].name == arg.name { - self.node_out[*idx].name = new_name.to_string(); - return; + /// Mark the graph_inputs to a node as passed, unless they are also initializers + fn mark_input_passed(&mut self, node: &Node) { + // we have to double map the inputs because the input might be replaced by an initializer + node.inputs.iter().for_each(|node_input| { + if let Some(old_input_name) = self.input_key_map.get(&node_input.name) { + if !self.initializers.contains_key(old_input_name) { + match self.input_name_map.get(old_input_name) { + Some(IOEntry::In(i)) => self.inputs[*i].passed = true, + _ => { + panic!("Should not happen, please report this error"); + } + } } - } else { - panic!("arg entry with old name {} is a graph IO", &arg.name); } - } - - let idx = self.node_out.len(); - self.old_io_names - .insert(arg.name.clone(), IOEntry::Node(idx)); - self.node_out.push(arg.clone()); - self.node_out[idx].name = new_name.to_string(); + }); } - /// Copies node outputs to graph IO. Used at the end of dim inference. - pub(crate) fn update_tensor_output(&mut self, node: &Node) { - for node_output in node.outputs.iter() { - match self.old_io_names.get(&node_output.name) { - Some(IOEntry::In(i)) => { - let arg = self.inputs.get_mut(*i).unwrap(); - arg.copy_value(node_output); - } - Some(IOEntry::Out(i)) => { - let arg = self.outputs.get_mut(*i).unwrap(); - arg.copy_value(node_output); - //Set the output to passed since it's been altered by a Node - arg.passed = true; - } - Some(IOEntry::Node(_)) => { - panic!("This output is from another node"); - } - None => { - log::debug!("inserting with name {:?}", &node_output.name); - let idx = self.node_out.len(); - self.old_io_names - .insert(node_output.name.clone(), IOEntry::Node(idx)); - self.node_out.push(node_output.clone()); - } - } + /// This function does three things: + /// 1. marks the inputs as passed + /// 2. maps the old output names to the node output + /// 3. renames the node output + fn add_node(&mut self, mut node: Node) { + log::debug!("adding node {:?}", &node.name); + self.mark_input_passed(&node); + let mut out_count = 1; + for output in node.outputs.iter_mut() { + self.input_name_map.insert( + output.name.clone(), + IOEntry::Node(self.processed_nodes.len(), 0), + ); + output.name = format!("{}_out{}", node.name, out_count); + out_count += 1; } + self.processed_nodes.push(node); } - /// Used by handle unsqeeze to remap the output of a node to a new name - /// expected match if it exists is either a graph input or graph output - pub(crate) fn get_node_output(&self, old_name: &str) -> Option<&Argument> { - match self.old_io_names.get(old_name) { - Some(IOEntry::In(i)) => self.inputs.get(*i), - Some(IOEntry::Out(i)) => self.outputs.get(*i), - Some(IOEntry::Node(_)) => panic!("This is a node output"), - None => None, - } + /// Consumes the graph data and returns the processed nodes, filtered inputs and outputs + fn consume(mut self) -> (Vec, Vec, Vec) { + self.inputs.retain(|x| x.passed); + let outputs = self + .outputs + .into_iter() + .filter_map(|x| match self.input_name_map.get(&x.name) { + Some(IOEntry::Node(i, j)) => Some(self.processed_nodes[*i].outputs[*j].clone()), + _ => None, + }) + .collect(); + (self.processed_nodes, self.inputs, outputs) } - /// Get the updated name of a Node Input, which should be - /// either a graph input or a node output. - /// Will return None if the it isn't a graph input or node output(like an initializer) - /// Will panic if it's a graph output - fn get_new_name(&mut self, old_name: &str) -> Option { - match self.old_io_names.get(old_name) { - Some(IOEntry::In(i)) => { - //FIXME: technically in the spec, initializers are default values - //for optional inputs, but implementing that would require reworking - //the way the graph is built, and it's not clear burn users are using initializers - //in that way - // see https://github.com/onnx/onnx/issues/2660 - if self.initializers.contains_key(old_name) { - None - } else { - //set the input as passed since a node is referencing it - self.inputs[*i].passed = true; - Some(self.inputs[*i].name.clone()) - } - } - Some(IOEntry::Out(_)) => { - panic!( - "you just tried to get an updated name on a graph output: {}", - old_name - ) - } - Some(IOEntry::Node(i)) => Some(self.node_out[*i].name.clone()), - None => None, - } + /// Used to get the output of the graph by name. Only used to remap unsqueeze nodes + pub fn get_graph_output(&self, name: &str) -> Option<&Argument> { + self.outputs.iter().find(|x| x.name == name) + } + + // Since Nodes are added at the end of conversion, the current index is the length of the processed nodes + /// Get the current index of the processed nodes. Useful when lifting values or marking nodes for removal + pub fn get_current_index(&self) -> usize { + self.processed_nodes.len() } } #[derive(Default)] pub(crate) struct OnnxGraphBuilder { - nodes: Vec, - inputs: Vec, - outputs: Vec, - /// Counter for node names, used for renaming nodes - node_name_counter: HashMap, - /// Nodes to remove + /// Nodes to remove. Note may be moved to graph data if we implement support for custom ops nodes_to_remove: HashSet, /// Map from constant node output names to indices of constant nodes constants_map: HashMap, + /// Node types that should be lifted to constants constants_types: HashSet, /// Map from identity node output names to indices of identity nodes identity_idx: HashMap, + node_name_counter: HashMap, } impl OnnxGraphBuilder { - pub(crate) fn node_gen(&mut self, model_proto: &ModelProto) { + pub(crate) fn build(mut self, model_proto: &ModelProto) -> OnnxGraph { self.constants_types = LIFT_CONSTANTS_FOR_NODE_TYPES.into_iter().collect(); - let mut graph_io = OnnxGraphIO::new( + let mut graph_data = GraphData::new( &model_proto.graph.input, &model_proto.graph.output, &model_proto.graph.initializer, ); - self.nodes = Vec::with_capacity(model_proto.graph.node.len()); - let mut and_idx = 0; let mut node_iter = model_proto.graph.node.iter().peekable(); while let Some(node_proto) = node_iter.next() { - let mut node = convert_node_proto(node_proto, &graph_io); + let mut node = convert_node_proto(node_proto, &graph_data); remap_node_type(&mut node); - - coalesce(&mut node, &mut node_iter, &graph_io); self.handle_node_renaming(&mut node); - self.handle_identity(&mut node, and_idx); - self.check_constants(&mut node, and_idx, &mut graph_io); - self.handle_unsqueeze(&mut node, &graph_io); + coalesce(&mut node, &mut node_iter, &graph_data); + self.handle_identity(&mut node, &graph_data); + self.check_constants(&mut node, &graph_data); + // NOTE: potential start of custom functions + // can filter, coalesce, or modify the nodes here + // args : node, peek_iter, graph_data + self.handle_unsqueeze(&mut node, &graph_data); - dim_inference(&mut node, &mut graph_io); - - rename_io(&mut node, &mut graph_io); - - self.nodes.push(node); - and_idx += 1; + dim_inference(&mut node); + graph_data.add_node(node); } - let mut i = 0; - self.nodes.retain(|_x| { - let res = !self.nodes_to_remove.contains(&i); - i += 1; - res - }); - let OnnxGraphIO { - mut inputs, - mut outputs, - .. - } = graph_io; - + let (mut processed_nodes, inputs, outputs) = graph_data.consume(); // Remove the graph inputs/output that are not used by any node - remove_unused_graph_inputs(&mut inputs, &mut outputs); - self.inputs = inputs; - self.outputs = outputs; + let mut i = 0; + processed_nodes.retain(|_| { + let keep = !self.nodes_to_remove.contains(&i); + i += 1; + keep + }); + OnnxGraph { + nodes: processed_nodes, + inputs, + outputs, + } } fn handle_node_renaming(&mut self, node: &mut Node) { @@ -328,17 +254,20 @@ impl OnnxGraphBuilder { node.name.clone_from(&new_name); } - fn check_constants(&mut self, node: &mut Node, i: usize, _graph_io: &mut OnnxGraphIO) { + fn check_constants(&mut self, node: &mut Node, graph_data: &GraphData) { if node.node_type == NodeType::Constant || (node.node_type == NodeType::Identity && node.inputs[0].value.is_some()) { - self.constants_map.insert(node.outputs[0].name.clone(), i); + self.constants_map.insert( + format!("{}_out{}", &node.name, 1), + graph_data.get_current_index(), + ); } else if self.constants_types.contains(&node.node_type) { log::debug!("checking node {} for constants", &node.name); for input in node.inputs.iter_mut().skip(1) { log::debug!("checking input {:?} for const", input); if let Some(const_idx) = self.constants_map.get(&input.name) { - let constant = &self.nodes[*const_idx]; + let constant = &graph_data.processed_nodes[*const_idx]; log::debug!( "input {} matched constant node {}", &input.name, @@ -362,29 +291,29 @@ impl OnnxGraphBuilder { /// Check if the unsqueeze node has a rhs value (rhs is constant) and if not remap it to a reshape /// Needs to be called after node renaming to ensure that the rhs name is correct /// Needs to be called after constant lifting to ensure that the rhs value exists - fn handle_unsqueeze(&mut self, node: &mut Node, graph_io: &OnnxGraphIO) { + fn handle_unsqueeze(&mut self, node: &mut Node, graph_data: &GraphData) { if node.node_type == NodeType::Unsqueeze && node.inputs.len() > 1 && node.inputs[1].value.is_none() { - if let Some(in_arg) = graph_io.get_node_output(&node.outputs[0].name) { - remap_unsqueeze_to_reshape(node, in_arg); + //if the output has a shape, it's only because it's a graph output + if let Some(out_arg) = graph_data.get_graph_output(&node.outputs[0].name) { + remap_unsqueeze_to_reshape(node, out_arg); } } } - fn handle_identity(&mut self, node: &mut Node, i: usize) { + fn handle_identity(&mut self, node: &mut Node, graph_data: &GraphData) { if node.node_type == NodeType::Identity && node.inputs[0].value.is_none() { log::debug!("\nfound identity node:\n{:?}\n", &node); + let i = graph_data.get_current_index(); //map the output name to check for pass through values - self.identity_idx.insert(node.outputs[0].name.clone(), i); + self.identity_idx.insert(format!("{}_out1", &node.name), i); self.nodes_to_remove.insert(i); } else { - //NOTE: it might be possible to rework the API to handle all "per input" operations - //in a new function that operates on each input. node.inputs.iter_mut().for_each(|x| { if let Some(identity_idx) = self.identity_idx.get(&x.name) { - let input_name = &self.nodes[*identity_idx].inputs[0].name; + let input_name = &graph_data.processed_nodes[*identity_idx].inputs[0].name; x.name.clone_from(input_name); } @@ -431,118 +360,50 @@ pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph { ); log::debug!("Number of outputs: {:?}", onnx_model.graph.output.len()); - let mut builder = OnnxGraphBuilder::default(); - builder.node_gen(&onnx_model); - - let OnnxGraphBuilder { - nodes, - inputs: inner_inputs, - outputs: inner_outputs, - .. - } = builder; + let builder = OnnxGraphBuilder::default(); + let graph = builder.build(&onnx_model); log::info!("Finished parsing ONNX file: {}", onnx_path.display()); - OnnxGraph { - nodes, - inputs: inner_inputs, - outputs: inner_outputs, - } + graph } /// Remap the unsqueeze node to a reshape node, Should only be called after /// node renaming has been done. avoids marking rhs as passed so that it can be /// properly deleted if nothing else uses it -fn remap_unsqueeze_to_reshape(node: &mut Node, out_arg: &Argument) { - match node.outputs[0].ty { - ArgType::Tensor(ref mut tensor_type) => { - if let ArgType::Tensor(arg_tensor) = &out_arg.ty { - tensor_type.shape.clone_from(&arg_tensor.shape); - let inner = arg_tensor - .shape - .clone() - .unwrap() - .into_iter() - .map(|x| x as i64) - .collect::>(); - let shape_len = inner.len(); - let new_rhs_value = Some(Data::Int64s(inner)); - //moving the remap to here - let rhs_arg = Argument { - name: format!("{}_generated_const", node.name), - ty: ArgType::Tensor(TensorType { - elem_type: super::ir::ElementType::Int64, - dim: 1, - shape: Some(vec![shape_len]), - }), - value: new_rhs_value, - passed: false, - }; - node.inputs[1] = rhs_arg; - node.outputs[0] = out_arg.clone(); - node.node_type = NodeType::Reshape; - } +/// Remap the unsqueeze node to a reshape node +pub(crate) fn remap_unsqueeze_to_reshape(node: &mut Node, out_arg: &Argument) { + match &out_arg.ty { + ArgType::Tensor(output_tensor) => { + let inner = output_tensor + .shape + .clone() + .unwrap() + .into_iter() + .map(|x| x as i64) + .collect::>(); + let shape_len = inner.len(); + let new_rhs_value = Some(Data::Int64s(inner)); + //moving the remap to here + let rhs_arg = Argument { + name: format!("{}_generated_const", &node.name), + ty: ArgType::Tensor(TensorType { + elem_type: super::ir::ElementType::Int64, + dim: 1, + shape: Some(vec![shape_len]), + }), + value: new_rhs_value, + passed: false, + }; + // ? should this replace the old input (reuse the old key) or should it be a new key + // going with new key for now + node.inputs[1] = rhs_arg; + node.outputs[0] = out_arg.clone(); + node.node_type = NodeType::Reshape; } _ => {} } } - -/// Rename the inputs and output in the graph and return a map of -/// the old names to the new names. -/// -/// The inputs are renamed to be unique and to be in the format of -/// conv2_in1, conv2_in2, etc. This is done to be consistent with -/// the naming convention of the nodes and allow to be used as rust identifiers. -/// Rename the inputs and output in the graph and return a map of -/// the old names to the new names. -fn rename_io(node: &mut Node, graph_io: &mut OnnxGraphIO) { - log::debug!("checking inputs for node {:?}", &node.name); - for node_input in node.inputs.iter_mut() { - if let Some(input_name) = graph_io.get_new_name(&node_input.name) { - node_input.passed = true; - node_input.name.clone_from(&input_name); - } else { - node_input.name = "".to_string(); - node_input.passed = false; - } - } - let mut out_count = 1; - if node.node_type == NodeType::Constant || node.node_type == NodeType::Identity { - let new_name = format!("{}_out{}", node.name, out_count); - graph_io.insert(&node.outputs[0], &new_name); - node.outputs[0].name.clone_from(&new_name); - log::debug!("Found {} constant", new_name); - } else { - for output in node.outputs.iter_mut() { - log::debug!("output name: {}", &output.name); - - let new_name = format!("{}_out{}", node.name, out_count); - - graph_io.update_name(output, &new_name); - - output.name.clone_from(&new_name); - out_count += 1; - } - } -} - -/// Removes the graph inputs/output that are not used by any node. -/// -/// In older ONNX models, the inputs and outputs are not always used by the nodes. -/// For example, the input could be used as a state instead of an input. Since the -/// inputs with initializers are moved to the states vector, the inputs vector could -/// contain unused inputs. The same is true for the outputs. -/// -/// Generally, it's a good idea to remove unused inputs/outputs because it makes the -/// generated code cleaner and easier to read. -fn remove_unused_graph_inputs(inputs: &mut Vec, outputs: &mut Vec) { - // Remove inputs that are not used by any node - inputs.retain(|input| input.passed); - - // Remove outputs that are not used by any node - outputs.retain(|output| output.passed); -} - // Define a trait for topological sorting trait TopologicalSortable { fn is_top_sorted(&self) -> bool; diff --git a/crates/burn-import/src/onnx/proto_conversion.rs b/crates/burn-import/src/onnx/proto_conversion.rs index 2f1481281..740db218e 100644 --- a/crates/burn-import/src/onnx/proto_conversion.rs +++ b/crates/burn-import/src/onnx/proto_conversion.rs @@ -2,7 +2,7 @@ use std::str::{from_utf8, FromStr}; use crate::onnx::ir::TensorType; -use super::from_onnx::OnnxGraphIO; +use super::from_onnx::GraphData; use super::ir::Dim; use super::ir::{ ArgType, Argument, AttributeValue, Attributes, Data, ElementType, Node, NodeType, Tensor, @@ -180,19 +180,18 @@ pub fn convert_vec_attrs_proto(attrs: Vec) -> Attributes { result } -pub fn convert_node_proto(node: &NodeProto, graph_io: &OnnxGraphIO) -> Node { +pub fn convert_node_proto(node: &NodeProto, graph_data: &GraphData) -> Node { let name = node.name.clone(); log::debug!("Converting ONNX node with type {:?}", node.op_type.as_str()); - let inputs = node - .input - .clone() - .into_iter() - .map(|x| graph_io.init_in(x)) - .collect(); + let inputs = node.input.iter().map(|x| graph_data.init_in(x)).collect(); - let outputs = node.output.clone().into_iter().map(Argument::new).collect(); + let outputs = node + .output + .iter() + .map(|x| Argument::new(x.to_string())) + .collect(); let attrs = convert_vec_attrs_proto(node.attribute.clone());