Remove topological sort of nodes (#589)

ONNX nodes should already be topologically sorted per the ONNX spec, so the explicit topological sort is removed. Instead, we now assert that the nodes are topologically sorted.
This commit is contained in:
Dilshod Tadjibaev 2023-08-06 09:51:28 -05:00 committed by GitHub
parent ce8a175aa4
commit 5cc32cc8cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 102 additions and 122 deletions

View File

@ -55,7 +55,6 @@ strum_macros = "0.24"
syn = {version = "2.0", features = ["full", "extra-traits"]} syn = {version = "2.0", features = ["full", "extra-traits"]}
tempfile = "3.6.0" tempfile = "3.6.0"
thiserror = "1.0.40" thiserror = "1.0.40"
topological-sort = "0.2.2"
# WGPU stuff # WGPU stuff
futures-intrusive = "0.5" futures-intrusive = "0.5"

View File

@ -35,7 +35,6 @@ serde_json = {workspace = true, features = ["std"]}
strum = {workspace = true} strum = {workspace = true}
strum_macros = {workspace = true} strum_macros = {workspace = true}
syn = {workspace = true, features = ["parsing"]} syn = {workspace = true, features = ["parsing"]}
topological-sort = {workspace = true}
[build-dependencies] [build-dependencies]
protobuf-codegen = {workspace = true} protobuf-codegen = {workspace = true}

View File

@ -455,8 +455,12 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
}); });
output_names.iter().for_each(|output| { output_names.iter().for_each(|output| {
self.graph_output_types self.graph_output_types.push(
.push(outputs.get(output).unwrap().clone()); outputs
.get(output)
.unwrap_or_else(|| panic!("Output type is not found for {output}"))
.clone(),
);
}); });
} }
} }

View File

@ -1,5 +1,5 @@
use std::{ use std::{
collections::{HashMap, HashSet}, collections::HashMap,
fs::File, fs::File,
path::Path, path::Path,
str::{from_utf8, FromStr}, str::{from_utf8, FromStr},
@ -19,7 +19,6 @@ use super::{coalesce::coalesce, ir::StateType};
use bytemuck::cast_slice; use bytemuck::cast_slice;
use protobuf::{Enum, Message}; use protobuf::{Enum, Message};
use topological_sort::TopologicalSort;
/// Error type for parsing ONNX model /// Error type for parsing ONNX model
#[derive(Debug)] #[derive(Debug)]
@ -28,6 +27,20 @@ pub enum ParseError {
} }
/// Open an onnx file and convert it to a Graph (intermediate representation) /// Open an onnx file and convert it to a Graph (intermediate representation)
///
/// # Arguments
///
/// * `onnx_path` - Path to the onnx file
///
/// # Returns
///
/// * `ONNXGraph` - The graph representation of the onnx file
///
/// # Panics
///
/// * If the file cannot be opened
/// * If the file cannot be parsed
/// * If the nodes are not topologically sorted
pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph { pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
log::info!("Parsing ONNX file: {}", onnx_path.display()); log::info!("Parsing ONNX file: {}", onnx_path.display());
@ -52,29 +65,35 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
nodes.push(convert_node_proto(onnx_node)); nodes.push(convert_node_proto(onnx_node));
} }
// ONNX nodes must be topologically sorted per spec:
// https://github.com/onnx/onnx/blob/main/docs/IR.md#graphs
assert!(nodes.is_top_sorted(), "Nodes are not topologically sorted");
// Move inputs to initializers // Move inputs to initializers
move_inputs_to_state(&mut nodes, &onnx_model.graph.initializer); move_inputs_to_state(&mut nodes, &onnx_model.graph.initializer);
// Get the topological sort of the nodes and the top nodes
top_sort_nodes(&mut nodes);
// Collect inputs, outputs and initializers
let check_if_initializer: HashSet<String> = onnx_model
.graph
.initializer
.iter()
.map(|x| x.name.clone())
.collect();
let mut inputs = collect_inputs(&onnx_model, &check_if_initializer);
let mut outputs = collect_outputs(&onnx_model, check_if_initializer);
let states = collect_states(onnx_model);
// Coalesce and transform nodes // Coalesce and transform nodes
coalesce(&mut nodes); coalesce(&mut nodes);
// Rename nodes and inputs, save the mapping for later // Rename nodes and inputs, save the mapping for later
let old_node_names = rename_nodes(&mut nodes); let old_node_names = rename_nodes(&mut nodes);
// This function collects the inputs of an ONNX model and returns them as a vector of Arguments.
let mut inputs = onnx_model
.graph
.input
.iter()
.map(|x| Argument::try_from(x.clone()).unwrap())
.collect();
// Map each output in the model's graph to an Argument and collect them into a vector.
let mut outputs = onnx_model
.graph
.output
.iter()
.map(|x| Argument::try_from(x.clone()).unwrap())
.collect();
let old_input_names = rename_inputs(&mut nodes, &mut inputs, &mut outputs); let old_input_names = rename_inputs(&mut nodes, &mut inputs, &mut outputs);
// Infer shapes and update the inputs and outputs // Infer shapes and update the inputs and outputs
@ -86,71 +105,11 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
nodes, nodes,
inputs, inputs,
outputs, outputs,
states,
old_node_names, old_node_names,
old_input_names, old_input_names,
} }
} }
/// Collect initializers
fn collect_states(onnx_model: ModelProto) -> Vec<State> {
let mut initializers = Vec::new();
for initializer in onnx_model.graph.initializer.iter() {
let tensor_proto = initializer.clone();
let name = tensor_proto.name.clone();
let tensor: Tensor = tensor_proto.try_into().unwrap();
let ty = StateType::Tensor(tensor);
let arg = State { name, ty };
initializers.push(arg);
}
initializers
}
/// Collect outputs
fn collect_outputs(
onnx_model: &ModelProto,
check_if_initializer: HashSet<String>,
) -> Vec<Argument> {
let outputs: Vec<Argument> = onnx_model
.graph
.output
.iter()
.filter(|x| !check_if_initializer.contains(x.name.as_str()))
.map(|i| Argument::try_from(i.clone()).unwrap())
.collect();
outputs
}
/// Collect inputs
fn collect_inputs(
onnx_model: &ModelProto,
check_if_initializer: &HashSet<String>,
) -> Vec<Argument> {
// Get the unique inputs
let inputs: Vec<Argument> = onnx_model
.graph
.input
.iter()
.filter(|x| !check_if_initializer.contains(x.name.as_str()))
// .filter(|x| top_nodes.contains(&x.name))
.map(|x| Argument::try_from(x.clone()).unwrap())
.collect();
// Convert to a vector and return
inputs.into_iter().collect()
}
/// Sort the nodes in topological order
fn top_sort_nodes(nodes: &mut Vec<Node>) {
let mut ts = topsort(nodes);
*nodes = vec![];
while let Some(node) = ts.pop() {
nodes.push(node);
}
}
fn to_string(bytes: Vec<u8>) -> String { fn to_string(bytes: Vec<u8>) -> String {
from_utf8(bytes.as_slice()).unwrap().to_string() from_utf8(bytes.as_slice()).unwrap().to_string()
} }
@ -426,14 +385,23 @@ impl TryFrom<ValueInfoProto> for State {
} }
} }
// This function moves inputs that are also present in the initializer to the node's states vector.
// It also removes inputs that are already present in the states vector.
fn move_inputs_to_state(nodes: &mut Vec<Node>, initializer: &[TensorProto]) { fn move_inputs_to_state(nodes: &mut Vec<Node>, initializer: &[TensorProto]) {
// Iterate over each node in the graph
nodes.iter_mut().for_each(|node| { nodes.iter_mut().for_each(|node| {
// Create a new vector to hold the node's states
let mut node_states = Vec::new(); let mut node_states = Vec::new();
// Create a new vector to hold the node's inputs
let mut inputs = Vec::new(); let mut inputs = Vec::new();
// Iterate over each input in the node's inputs vector
for input in node.inputs.iter() { for input in node.inputs.iter() {
// Iterate over each tensor in the initializer
for init in initializer.iter() { for init in initializer.iter() {
// If the input name matches the tensor name in the initializer
if init.name == input.name { if init.name == input.name {
// Add the tensor to the node's states vector
node_states.push(State { node_states.push(State {
name: init.name.clone(), name: init.name.clone(),
ty: StateType::Tensor(init.clone().try_into().unwrap()), ty: StateType::Tensor(init.clone().try_into().unwrap()),
@ -442,7 +410,10 @@ fn move_inputs_to_state(nodes: &mut Vec<Node>, initializer: &[TensorProto]) {
} }
} }
// Swap the node's inputs vector with the temporary inputs vector
core::mem::swap(&mut inputs, &mut node.inputs); core::mem::swap(&mut inputs, &mut node.inputs);
// Filter out inputs that are already present in the node's states vector
node.inputs = inputs node.inputs = inputs
.into_iter() .into_iter()
.filter(|input| { .filter(|input| {
@ -455,6 +426,8 @@ fn move_inputs_to_state(nodes: &mut Vec<Node>, initializer: &[TensorProto]) {
true true
}) })
.collect(); .collect();
// Set the node's states vector to the temporary node_states vector
node.states = node_states; node.states = node_states;
}); });
} }
@ -482,7 +455,7 @@ fn rename_nodes(nodes: &mut Vec<Node>) -> HashMap<String, String> {
old_names old_names
} }
/// Rename the inputs in the graph and return a map of the old names to the new names. /// Rename the inputs and output in the graph and return a map of the old names to the new names.
/// ///
/// The inputs are renamed to be unique and to be in the format of conv2_in1, conv2_in2, etc. /// The inputs are renamed to be unique and to be in the format of conv2_in1, conv2_in2, etc.
/// This is done to be consistent with the naming convention of the nodes and allow to be used as rust identifiers. /// This is done to be consistent with the naming convention of the nodes and allow to be used as rust identifiers.
@ -493,13 +466,13 @@ fn rename_inputs(
) -> HashMap<String, String> { ) -> HashMap<String, String> {
let mut old_names = HashMap::new(); let mut old_names = HashMap::new();
// rename all graph input names to follow input1, input2, input3, etc.
// (assumes the input names are already unique)
let mut counter = 1; let mut counter = 1;
for input in inputs.iter_mut() { for input in inputs.iter_mut() {
let old_name = input.name.clone(); let old_name = input.name.clone();
let new_name = format!("input{}", counter); let new_name = format!("input{}", counter);
input.name = new_name.clone(); input.name = new_name.clone();
old_names.insert(old_name, new_name); old_names.insert(old_name, new_name);
counter += 1; counter += 1;
} }
@ -513,13 +486,6 @@ fn rename_inputs(
.and_modify(|e| *e += 1) .and_modify(|e| *e += 1)
.or_insert(1); .or_insert(1);
// loop through node inputs and rename them with previously replaced names
for input in node.inputs.iter_mut() {
if let Some(new_name) = old_names.get(&input.name) {
input.name = new_name.clone();
}
}
// loop through node outputs and rename them and store the new name <-> old name mapping // loop through node outputs and rename them and store the new name <-> old name mapping
for output in node.outputs.iter_mut() { for output in node.outputs.iter_mut() {
let old_name = output.name.clone(); let old_name = output.name.clone();
@ -529,48 +495,63 @@ fn rename_inputs(
} }
} }
for node in nodes.iter_mut() {
// loop through node inputs and rename them with previously replaced names
for input in node.inputs.iter_mut() {
if let Some(new_name) = old_names.get(&input.name) {
input.name = new_name.clone();
} else {
panic!("Input {} not found in old_names", input.name);
}
}
}
// Rename the graph outputs // Rename the graph outputs
for output in outputs.iter_mut() { for output in outputs.iter_mut() {
if let Some(new_name) = old_names.get(&output.name) { if let Some(new_name) = old_names.get(&output.name) {
output.name = new_name.clone(); output.name = new_name.clone();
} else {
println!("{:#?}", old_names);
panic!("Output {} not found in old_names", output.name);
} }
} }
old_names old_names
} }
/// Find the node that produces the given output // Define a trait for topological sorting
fn lookup_node_by_output(nodes: &Vec<Node>, input: &str) -> Option<Node> { trait TopologicalSortable {
for node in nodes.iter() { fn is_top_sorted(&self) -> bool;
if node.outputs.iter().any(|x| x.name == *input) {
return Some(node.clone());
}
}
None
} }
/// Sort nodes in topological order impl TopologicalSortable for Vec<Node> {
pub fn topsort(nodes: &Vec<Node>) -> TopologicalSort<Node> { fn is_top_sorted(&self) -> bool {
if nodes.is_empty() { // Create a hashmap to store the position of each node in the vector
panic!("No nodes to sort"); let position: HashMap<String, usize> = self
} .iter()
.enumerate()
.map(|(idx, node)| (node.name.clone(), idx))
.collect();
let mut ts = TopologicalSort::new(); // Iterate over each node in the vector
for node in self {
// If there is only one node, then it is the only dependency // Iterate over each output of the node
if nodes.len() == 1 { for output in &node.outputs {
ts.insert(nodes[0].clone()); // Iterate over each other node in the vector
return ts; for other_node in self {
} // If the other node has an input that matches the current output
if other_node.inputs.contains(output) {
for node in nodes.iter() { // If the position of the current node is greater than the position of the other node
for input in node.inputs.iter() { if position[&node.name] > position[&other_node.name] {
match lookup_node_by_output(nodes, input.name.as_str()) { // The vector is not topologically sorted
Some(prec) => ts.add_dependency(prec, node.clone()), return false;
None => {} }
}
}
} }
} }
}
ts // The vector is topologically sorted
true
}
} }

View File

@ -84,9 +84,6 @@ pub struct ONNXGraph {
/// The outputs of the graph. /// The outputs of the graph.
pub outputs: Vec<Argument>, pub outputs: Vec<Argument>,
/// The states of the graph.
pub states: Vec<State>,
/// The original node names. /// The original node names.
pub old_node_names: HashMap<String, String>, pub old_node_names: HashMap<String, String>,