diff --git a/Cargo.toml b/Cargo.toml index f2337a4fa..b80df0513 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,7 +55,6 @@ strum_macros = "0.24" syn = {version = "2.0", features = ["full", "extra-traits"]} tempfile = "3.6.0" thiserror = "1.0.40" -topological-sort = "0.2.2" # WGPU stuff futures-intrusive = "0.5" diff --git a/burn-import/Cargo.toml b/burn-import/Cargo.toml index 1c25df5d8..2696322ca 100644 --- a/burn-import/Cargo.toml +++ b/burn-import/Cargo.toml @@ -35,7 +35,6 @@ serde_json = {workspace = true, features = ["std"]} strum = {workspace = true} strum_macros = {workspace = true} syn = {workspace = true, features = ["parsing"]} -topological-sort = {workspace = true} [build-dependencies] protobuf-codegen = {workspace = true} diff --git a/burn-import/src/burn/graph.rs b/burn-import/src/burn/graph.rs index 28fe9f834..3f1f34396 100644 --- a/burn-import/src/burn/graph.rs +++ b/burn-import/src/burn/graph.rs @@ -455,8 +455,12 @@ impl BurnGraph { }); output_names.iter().for_each(|output| { - self.graph_output_types - .push(outputs.get(output).unwrap().clone()); + self.graph_output_types.push( + outputs + .get(output) + .unwrap_or_else(|| panic!("Output type is not found for {output}")) + .clone(), + ); }); } } diff --git a/burn-import/src/onnx/from_onnx.rs b/burn-import/src/onnx/from_onnx.rs index 811a00de7..70b5821ee 100644 --- a/burn-import/src/onnx/from_onnx.rs +++ b/burn-import/src/onnx/from_onnx.rs @@ -1,5 +1,5 @@ use std::{ - collections::{HashMap, HashSet}, + collections::HashMap, fs::File, path::Path, str::{from_utf8, FromStr}, @@ -19,7 +19,6 @@ use super::{coalesce::coalesce, ir::StateType}; use bytemuck::cast_slice; use protobuf::{Enum, Message}; -use topological_sort::TopologicalSort; /// Error type for parsing ONNX model #[derive(Debug)] @@ -28,6 +27,20 @@ pub enum ParseError { } /// Open an onnx file and convert it to a Graph (intermediate representation) +/// +/// # Arguments +/// +/// * `onnx_path` - Path to the onnx file +/// +/// # Returns +/// +/// * `ONNXGraph` - The graph representation of the onnx file +/// +/// # Panics +/// +/// * If the file cannot be opened +/// * If the file cannot be parsed +/// * If the nodes are not topologically sorted pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph { log::info!("Parsing ONNX file: {}", onnx_path.display()); @@ -52,29 +65,35 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph { nodes.push(convert_node_proto(onnx_node)); } + // ONNX nodes must be topologically sorted per spec: + // https://github.com/onnx/onnx/blob/main/docs/IR.md#graphs + assert!(nodes.is_top_sorted(), "Nodes are not topologically sorted"); + // Move inputs to initializers move_inputs_to_state(&mut nodes, &onnx_model.graph.initializer); - // Get the topological sort of the nodes and the top nodes - top_sort_nodes(&mut nodes); - - // Collect inputs, outputs and initializers - let check_if_initializer: HashSet = onnx_model - .graph - .initializer - .iter() - .map(|x| x.name.clone()) - .collect(); - let mut inputs = collect_inputs(&onnx_model, &check_if_initializer); - - let mut outputs = collect_outputs(&onnx_model, check_if_initializer); - let states = collect_states(onnx_model); - // Coalesce and transform nodes coalesce(&mut nodes); // Rename nodes and inputs, save the mapping for later let old_node_names = rename_nodes(&mut nodes); + + // This function collects the inputs of an ONNX model and returns them as a vector of Arguments. + let mut inputs = onnx_model + .graph + .input + .iter() + .map(|x| Argument::try_from(x.clone()).unwrap()) + .collect(); + + // Map each output in the model's graph to an Argument and collect them into a vector. + let mut outputs = onnx_model + .graph + .output + .iter() + .map(|x| Argument::try_from(x.clone()).unwrap()) + .collect(); + let old_input_names = rename_inputs(&mut nodes, &mut inputs, &mut outputs); // Infer shapes and update the inputs and outputs @@ -86,71 +105,11 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph { nodes, inputs, outputs, - states, old_node_names, old_input_names, } } -/// Collect initializers -fn collect_states(onnx_model: ModelProto) -> Vec { - let mut initializers = Vec::new(); - - for initializer in onnx_model.graph.initializer.iter() { - let tensor_proto = initializer.clone(); - let name = tensor_proto.name.clone(); - let tensor: Tensor = tensor_proto.try_into().unwrap(); - let ty = StateType::Tensor(tensor); - let arg = State { name, ty }; - - initializers.push(arg); - } - initializers -} - -/// Collect outputs -fn collect_outputs( - onnx_model: &ModelProto, - check_if_initializer: HashSet, -) -> Vec { - let outputs: Vec = onnx_model - .graph - .output - .iter() - .filter(|x| !check_if_initializer.contains(x.name.as_str())) - .map(|i| Argument::try_from(i.clone()).unwrap()) - .collect(); - outputs -} - -/// Collect inputs -fn collect_inputs( - onnx_model: &ModelProto, - check_if_initializer: &HashSet, -) -> Vec { - // Get the unique inputs - let inputs: Vec = onnx_model - .graph - .input - .iter() - .filter(|x| !check_if_initializer.contains(x.name.as_str())) - // .filter(|x| top_nodes.contains(&x.name)) - .map(|x| Argument::try_from(x.clone()).unwrap()) - .collect(); - - // Convert to a vector and return - inputs.into_iter().collect() -} - -/// Sort the nodes in topological order -fn top_sort_nodes(nodes: &mut Vec) { - let mut ts = topsort(nodes); - *nodes = vec![]; - while let Some(node) = ts.pop() { - nodes.push(node); - } -} - fn to_string(bytes: Vec) -> String { from_utf8(bytes.as_slice()).unwrap().to_string() } @@ -426,14 +385,23 @@ impl TryFrom for State { } } +// This function moves inputs that are also present in the initializer to the node's states vector. +// It also removes inputs that are already present in the states vector. fn move_inputs_to_state(nodes: &mut Vec, initializer: &[TensorProto]) { + // Iterate over each node in the graph nodes.iter_mut().for_each(|node| { + // Create a new vector to hold the node's states let mut node_states = Vec::new(); + // Create a new vector to hold the node's inputs let mut inputs = Vec::new(); + // Iterate over each input in the node's inputs vector for input in node.inputs.iter() { + // Iterate over each tensor in the initializer for init in initializer.iter() { + // If the input name matches the tensor name in the initializer if init.name == input.name { + // Add the tensor to the node's states vector node_states.push(State { name: init.name.clone(), ty: StateType::Tensor(init.clone().try_into().unwrap()), @@ -442,7 +410,10 @@ fn move_inputs_to_state(nodes: &mut Vec, initializer: &[TensorProto]) { } } + // Swap the node's inputs vector with the temporary inputs vector core::mem::swap(&mut inputs, &mut node.inputs); + + // Filter out inputs that are already present in the node's states vector node.inputs = inputs .into_iter() .filter(|input| { @@ -455,6 +426,8 @@ fn move_inputs_to_state(nodes: &mut Vec, initializer: &[TensorProto]) { true }) .collect(); + + // Set the node's states vector to the temporary node_states vector node.states = node_states; }); } @@ -482,7 +455,7 @@ fn rename_nodes(nodes: &mut Vec) -> HashMap { old_names } -/// Rename the inputs in the graph and return a map of the old names to the new names. +/// Rename the inputs and output in the graph and return a map of the old names to the new names. /// /// The inputs are renamed to be unique and to be in the format of conv2_in1, conv2_in2, etc. /// This is done to be consistent with the naming convention of the nodes and allow to be used as rust identifiers. @@ -493,13 +466,13 @@ fn rename_inputs( ) -> HashMap { let mut old_names = HashMap::new(); + // rename all graph input names to follow input1, input2, input3, etc. + // (assumes the input names are already unique) let mut counter = 1; for input in inputs.iter_mut() { let old_name = input.name.clone(); let new_name = format!("input{}", counter); - input.name = new_name.clone(); - old_names.insert(old_name, new_name); counter += 1; } @@ -513,13 +486,6 @@ fn rename_inputs( .and_modify(|e| *e += 1) .or_insert(1); - // loop through node inputs and rename them with previously replaced names - for input in node.inputs.iter_mut() { - if let Some(new_name) = old_names.get(&input.name) { - input.name = new_name.clone(); - } - } - // loop through node outputs and rename them and store the new name <-> old name mapping for output in node.outputs.iter_mut() { let old_name = output.name.clone(); @@ -529,48 +495,63 @@ fn rename_inputs( } } + for node in nodes.iter_mut() { + // loop through node inputs and rename them with previously replaced names + for input in node.inputs.iter_mut() { + if let Some(new_name) = old_names.get(&input.name) { + input.name = new_name.clone(); + } else { + panic!("Input {} not found in old_names", input.name); + } + } + } + // Rename the graph outputs for output in outputs.iter_mut() { if let Some(new_name) = old_names.get(&output.name) { output.name = new_name.clone(); + } else { + println!("{:#?}", old_names); + panic!("Output {} not found in old_names", output.name); } } old_names } -/// Find the node that produces the given output -fn lookup_node_by_output(nodes: &Vec, input: &str) -> Option { - for node in nodes.iter() { - if node.outputs.iter().any(|x| x.name == *input) { - return Some(node.clone()); - } - } - None +// Define a trait for topological sorting +trait TopologicalSortable { + fn is_top_sorted(&self) -> bool; } -/// Sort nodes in topological order -pub fn topsort(nodes: &Vec) -> TopologicalSort { - if nodes.is_empty() { - panic!("No nodes to sort"); - } +impl TopologicalSortable for Vec { + fn is_top_sorted(&self) -> bool { + // Create a hashmap to store the position of each node in the vector + let position: HashMap = self + .iter() + .enumerate() + .map(|(idx, node)| (node.name.clone(), idx)) + .collect(); - let mut ts = TopologicalSort::new(); - - // If there is only one node, then it is the only dependency - if nodes.len() == 1 { - ts.insert(nodes[0].clone()); - return ts; - } - - for node in nodes.iter() { - for input in node.inputs.iter() { - match lookup_node_by_output(nodes, input.name.as_str()) { - Some(prec) => ts.add_dependency(prec, node.clone()), - None => {} + // Iterate over each node in the vector + for node in self { + // Iterate over each output of the node + for output in &node.outputs { + // Iterate over each other node in the vector + for other_node in self { + // If the other node has an input that matches the current output + if other_node.inputs.contains(output) { + // If the position of the current node is greater than the position of the other node + if position[&node.name] > position[&other_node.name] { + // The vector is not topologically sorted + return false; + } + } + } } } - } - ts + // The vector is topologically sorted + true + } } diff --git a/burn-import/src/onnx/ir.rs b/burn-import/src/onnx/ir.rs index 6902a5c55..8d074192d 100644 --- a/burn-import/src/onnx/ir.rs +++ b/burn-import/src/onnx/ir.rs @@ -84,9 +84,6 @@ pub struct ONNXGraph { /// The outputs of the graph. pub outputs: Vec, - /// The states of the graph. - pub states: Vec, - /// The original node names. pub old_node_names: HashMap,