mirror of https://github.com/tracel-ai/burn.git
Remove topological sort of nodes (#589)
ONNX nodes should already be topologically sorted per the spec, so we are removing the topological sort; instead, we now check that the nodes are topologically sorted.
This commit is contained in:
parent
ce8a175aa4
commit
5cc32cc8cb
|
@ -55,7 +55,6 @@ strum_macros = "0.24"
|
|||
syn = {version = "2.0", features = ["full", "extra-traits"]}
|
||||
tempfile = "3.6.0"
|
||||
thiserror = "1.0.40"
|
||||
topological-sort = "0.2.2"
|
||||
|
||||
# WGPU stuff
|
||||
futures-intrusive = "0.5"
|
||||
|
|
|
@ -35,7 +35,6 @@ serde_json = {workspace = true, features = ["std"]}
|
|||
strum = {workspace = true}
|
||||
strum_macros = {workspace = true}
|
||||
syn = {workspace = true, features = ["parsing"]}
|
||||
topological-sort = {workspace = true}
|
||||
|
||||
[build-dependencies]
|
||||
protobuf-codegen = {workspace = true}
|
||||
|
|
|
@ -455,8 +455,12 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
|
|||
});
|
||||
|
||||
output_names.iter().for_each(|output| {
|
||||
self.graph_output_types
|
||||
.push(outputs.get(output).unwrap().clone());
|
||||
self.graph_output_types.push(
|
||||
outputs
|
||||
.get(output)
|
||||
.unwrap_or_else(|| panic!("Output type is not found for {output}"))
|
||||
.clone(),
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
collections::HashMap,
|
||||
fs::File,
|
||||
path::Path,
|
||||
str::{from_utf8, FromStr},
|
||||
|
@ -19,7 +19,6 @@ use super::{coalesce::coalesce, ir::StateType};
|
|||
|
||||
use bytemuck::cast_slice;
|
||||
use protobuf::{Enum, Message};
|
||||
use topological_sort::TopologicalSort;
|
||||
|
||||
/// Error type for parsing ONNX model
|
||||
#[derive(Debug)]
|
||||
|
@ -28,6 +27,20 @@ pub enum ParseError {
|
|||
}
|
||||
|
||||
/// Open an onnx file and convert it to a Graph (intermediate representation)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `onnx_path` - Path to the onnx file
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `ONNXGraph` - The graph representation of the onnx file
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// * If the file cannot be opened
|
||||
/// * If the file cannot be parsed
|
||||
/// * If the nodes are not topologically sorted
|
||||
pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
|
||||
log::info!("Parsing ONNX file: {}", onnx_path.display());
|
||||
|
||||
|
@ -52,29 +65,35 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
|
|||
nodes.push(convert_node_proto(onnx_node));
|
||||
}
|
||||
|
||||
// ONNX nodes must be topologically sorted per spec:
|
||||
// https://github.com/onnx/onnx/blob/main/docs/IR.md#graphs
|
||||
assert!(nodes.is_top_sorted(), "Nodes are not topologically sorted");
|
||||
|
||||
// Move inputs to initializers
|
||||
move_inputs_to_state(&mut nodes, &onnx_model.graph.initializer);
|
||||
|
||||
// Get the topological sort of the nodes and the top nodes
|
||||
top_sort_nodes(&mut nodes);
|
||||
|
||||
// Collect inputs, outputs and initializers
|
||||
let check_if_initializer: HashSet<String> = onnx_model
|
||||
.graph
|
||||
.initializer
|
||||
.iter()
|
||||
.map(|x| x.name.clone())
|
||||
.collect();
|
||||
let mut inputs = collect_inputs(&onnx_model, &check_if_initializer);
|
||||
|
||||
let mut outputs = collect_outputs(&onnx_model, check_if_initializer);
|
||||
let states = collect_states(onnx_model);
|
||||
|
||||
// Coalesce and transform nodes
|
||||
coalesce(&mut nodes);
|
||||
|
||||
// Rename nodes and inputs, save the mapping for later
|
||||
let old_node_names = rename_nodes(&mut nodes);
|
||||
|
||||
// Map each input in the model's graph to an Argument and collect them into a vector.
|
||||
let mut inputs = onnx_model
|
||||
.graph
|
||||
.input
|
||||
.iter()
|
||||
.map(|x| Argument::try_from(x.clone()).unwrap())
|
||||
.collect();
|
||||
|
||||
// Map each output in the model's graph to an Argument and collect them into a vector.
|
||||
let mut outputs = onnx_model
|
||||
.graph
|
||||
.output
|
||||
.iter()
|
||||
.map(|x| Argument::try_from(x.clone()).unwrap())
|
||||
.collect();
|
||||
|
||||
let old_input_names = rename_inputs(&mut nodes, &mut inputs, &mut outputs);
|
||||
|
||||
// Infer shapes and update the inputs and outputs
|
||||
|
@ -86,71 +105,11 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
|
|||
nodes,
|
||||
inputs,
|
||||
outputs,
|
||||
states,
|
||||
old_node_names,
|
||||
old_input_names,
|
||||
}
|
||||
}
|
||||
|
||||
/// Collect initializers
|
||||
fn collect_states(onnx_model: ModelProto) -> Vec<State> {
|
||||
let mut initializers = Vec::new();
|
||||
|
||||
for initializer in onnx_model.graph.initializer.iter() {
|
||||
let tensor_proto = initializer.clone();
|
||||
let name = tensor_proto.name.clone();
|
||||
let tensor: Tensor = tensor_proto.try_into().unwrap();
|
||||
let ty = StateType::Tensor(tensor);
|
||||
let arg = State { name, ty };
|
||||
|
||||
initializers.push(arg);
|
||||
}
|
||||
initializers
|
||||
}
|
||||
|
||||
/// Collect outputs
|
||||
fn collect_outputs(
|
||||
onnx_model: &ModelProto,
|
||||
check_if_initializer: HashSet<String>,
|
||||
) -> Vec<Argument> {
|
||||
let outputs: Vec<Argument> = onnx_model
|
||||
.graph
|
||||
.output
|
||||
.iter()
|
||||
.filter(|x| !check_if_initializer.contains(x.name.as_str()))
|
||||
.map(|i| Argument::try_from(i.clone()).unwrap())
|
||||
.collect();
|
||||
outputs
|
||||
}
|
||||
|
||||
/// Collect inputs
|
||||
fn collect_inputs(
|
||||
onnx_model: &ModelProto,
|
||||
check_if_initializer: &HashSet<String>,
|
||||
) -> Vec<Argument> {
|
||||
// Get the unique inputs
|
||||
let inputs: Vec<Argument> = onnx_model
|
||||
.graph
|
||||
.input
|
||||
.iter()
|
||||
.filter(|x| !check_if_initializer.contains(x.name.as_str()))
|
||||
// .filter(|x| top_nodes.contains(&x.name))
|
||||
.map(|x| Argument::try_from(x.clone()).unwrap())
|
||||
.collect();
|
||||
|
||||
// Convert to a vector and return
|
||||
inputs.into_iter().collect()
|
||||
}
|
||||
|
||||
/// Sort the nodes in topological order
|
||||
fn top_sort_nodes(nodes: &mut Vec<Node>) {
|
||||
let mut ts = topsort(nodes);
|
||||
*nodes = vec![];
|
||||
while let Some(node) = ts.pop() {
|
||||
nodes.push(node);
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts an owned byte buffer into a `String`.
///
/// The buffer is consumed and validated in place, so its allocation is reused
/// instead of being copied into a fresh `String`.
///
/// # Panics
///
/// Panics if `bytes` is not valid UTF-8.
fn to_string(bytes: Vec<u8>) -> String {
    String::from_utf8(bytes).expect("bytes should be valid UTF-8")
}
|
||||
|
@ -426,14 +385,23 @@ impl TryFrom<ValueInfoProto> for State {
|
|||
}
|
||||
}
|
||||
|
||||
// This function moves inputs that are also present in the initializer to the node's states vector.
|
||||
// It also removes inputs that are already present in the states vector.
|
||||
fn move_inputs_to_state(nodes: &mut Vec<Node>, initializer: &[TensorProto]) {
|
||||
// Iterate over each node in the graph
|
||||
nodes.iter_mut().for_each(|node| {
|
||||
// Create a new vector to hold the node's states
|
||||
let mut node_states = Vec::new();
|
||||
// Create a new vector to hold the node's inputs
|
||||
let mut inputs = Vec::new();
|
||||
|
||||
// Iterate over each input in the node's inputs vector
|
||||
for input in node.inputs.iter() {
|
||||
// Iterate over each tensor in the initializer
|
||||
for init in initializer.iter() {
|
||||
// If the input name matches the tensor name in the initializer
|
||||
if init.name == input.name {
|
||||
// Add the tensor to the node's states vector
|
||||
node_states.push(State {
|
||||
name: init.name.clone(),
|
||||
ty: StateType::Tensor(init.clone().try_into().unwrap()),
|
||||
|
@ -442,7 +410,10 @@ fn move_inputs_to_state(nodes: &mut Vec<Node>, initializer: &[TensorProto]) {
|
|||
}
|
||||
}
|
||||
|
||||
// Swap the node's inputs vector with the temporary inputs vector
|
||||
core::mem::swap(&mut inputs, &mut node.inputs);
|
||||
|
||||
// Filter out inputs that are already present in the node's states vector
|
||||
node.inputs = inputs
|
||||
.into_iter()
|
||||
.filter(|input| {
|
||||
|
@ -455,6 +426,8 @@ fn move_inputs_to_state(nodes: &mut Vec<Node>, initializer: &[TensorProto]) {
|
|||
true
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Set the node's states vector to the temporary node_states vector
|
||||
node.states = node_states;
|
||||
});
|
||||
}
|
||||
|
@ -482,7 +455,7 @@ fn rename_nodes(nodes: &mut Vec<Node>) -> HashMap<String, String> {
|
|||
old_names
|
||||
}
|
||||
|
||||
/// Rename the inputs in the graph and return a map of the old names to the new names.
|
||||
/// Rename the inputs and outputs in the graph and return a map of the old names to the new names.
|
||||
///
|
||||
/// The inputs are renamed to be unique and to be in the format of conv2_in1, conv2_in2, etc.
|
||||
/// This is done to be consistent with the naming convention of the nodes and to allow the names to be used as Rust identifiers.
|
||||
|
@ -493,13 +466,13 @@ fn rename_inputs(
|
|||
) -> HashMap<String, String> {
|
||||
let mut old_names = HashMap::new();
|
||||
|
||||
// rename all graph input names to follow input1, input2, input3, etc.
|
||||
// (assumes the input names are already unique)
|
||||
let mut counter = 1;
|
||||
for input in inputs.iter_mut() {
|
||||
let old_name = input.name.clone();
|
||||
let new_name = format!("input{}", counter);
|
||||
|
||||
input.name = new_name.clone();
|
||||
|
||||
old_names.insert(old_name, new_name);
|
||||
counter += 1;
|
||||
}
|
||||
|
@ -513,13 +486,6 @@ fn rename_inputs(
|
|||
.and_modify(|e| *e += 1)
|
||||
.or_insert(1);
|
||||
|
||||
// loop through node inputs and rename them with previously replaced names
|
||||
for input in node.inputs.iter_mut() {
|
||||
if let Some(new_name) = old_names.get(&input.name) {
|
||||
input.name = new_name.clone();
|
||||
}
|
||||
}
|
||||
|
||||
// loop through node outputs and rename them and store the new name <-> old name mapping
|
||||
for output in node.outputs.iter_mut() {
|
||||
let old_name = output.name.clone();
|
||||
|
@ -529,48 +495,63 @@ fn rename_inputs(
|
|||
}
|
||||
}
|
||||
|
||||
for node in nodes.iter_mut() {
|
||||
// loop through node inputs and rename them with previously replaced names
|
||||
for input in node.inputs.iter_mut() {
|
||||
if let Some(new_name) = old_names.get(&input.name) {
|
||||
input.name = new_name.clone();
|
||||
} else {
|
||||
panic!("Input {} not found in old_names", input.name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Rename the graph outputs
|
||||
for output in outputs.iter_mut() {
|
||||
if let Some(new_name) = old_names.get(&output.name) {
|
||||
output.name = new_name.clone();
|
||||
} else {
|
||||
println!("{:#?}", old_names);
|
||||
panic!("Output {} not found in old_names", output.name);
|
||||
}
|
||||
}
|
||||
|
||||
old_names
|
||||
}
|
||||
|
||||
/// Find the node that produces the given output
|
||||
fn lookup_node_by_output(nodes: &Vec<Node>, input: &str) -> Option<Node> {
|
||||
for node in nodes.iter() {
|
||||
if node.outputs.iter().any(|x| x.name == *input) {
|
||||
return Some(node.clone());
|
||||
}
|
||||
}
|
||||
None
|
||||
// Define a trait for topological sorting
|
||||
trait TopologicalSortable {
|
||||
fn is_top_sorted(&self) -> bool;
|
||||
}
|
||||
|
||||
/// Sort nodes in topological order
|
||||
pub fn topsort(nodes: &Vec<Node>) -> TopologicalSort<Node> {
|
||||
if nodes.is_empty() {
|
||||
panic!("No nodes to sort");
|
||||
}
|
||||
impl TopologicalSortable for Vec<Node> {
|
||||
fn is_top_sorted(&self) -> bool {
|
||||
// Create a hashmap to store the position of each node in the vector
|
||||
let position: HashMap<String, usize> = self
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, node)| (node.name.clone(), idx))
|
||||
.collect();
|
||||
|
||||
let mut ts = TopologicalSort::new();
|
||||
|
||||
// If there is only one node, then it is the only dependency
|
||||
if nodes.len() == 1 {
|
||||
ts.insert(nodes[0].clone());
|
||||
return ts;
|
||||
}
|
||||
|
||||
for node in nodes.iter() {
|
||||
for input in node.inputs.iter() {
|
||||
match lookup_node_by_output(nodes, input.name.as_str()) {
|
||||
Some(prec) => ts.add_dependency(prec, node.clone()),
|
||||
None => {}
|
||||
// Iterate over each node in the vector
|
||||
for node in self {
|
||||
// Iterate over each output of the node
|
||||
for output in &node.outputs {
|
||||
// Iterate over each other node in the vector
|
||||
for other_node in self {
|
||||
// If the other node has an input that matches the current output
|
||||
if other_node.inputs.contains(output) {
|
||||
// If the position of the current node is greater than the position of the other node
|
||||
if position[&node.name] > position[&other_node.name] {
|
||||
// The vector is not topologically sorted
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ts
|
||||
// The vector is topologically sorted
|
||||
true
|
||||
}
|
||||
}
|
||||
|
|
|
@ -84,9 +84,6 @@ pub struct ONNXGraph {
|
|||
/// The outputs of the graph.
|
||||
pub outputs: Vec<Argument>,
|
||||
|
||||
/// The states of the graph.
|
||||
pub states: Vec<State>,
|
||||
|
||||
/// The original node names.
|
||||
pub old_node_names: HashMap<String, String>,
|
||||
|
||||
|
|
Loading…
Reference in New Issue